dkforest

A forum and chat platform (onion)
git clone https://git.dasho.dev/n0tr1v/dkforest.git
Log | Files | Refs | LICENSE

commit 68cebecab91f74e9e84b81c8b3df5c4d927d41a8
parent 169059461d0a617fe1e28955f95f9850c7332c01
Author: n0tr1v <n0tr1v@protonmail.com>
Date:   Mon, 12 Dec 2022 21:22:59 -0500

strongly typed forum IDs/UUIDs

Diffstat:
Mpkg/database/tableUserForumThreadSubscriptions.go | 10+++++-----
Mpkg/database/table_forum_threads.go | 45+++++++++++++++++++++++++++------------------
Mpkg/web/handlers/api/v1/handlers.go | 6+++---
Mpkg/web/handlers/handlers.go | 50+++++++++++++++++++++++++-------------------------
4 files changed, 60 insertions(+), 51 deletions(-)

diff --git a/pkg/database/tableUserForumThreadSubscriptions.go b/pkg/database/tableUserForumThreadSubscriptions.go @@ -8,7 +8,7 @@ import ( type UserForumThreadSubscription struct { UserID UserID - ThreadID int64 + ThreadID ForumThreadID CreatedAt time.Time User User } @@ -19,21 +19,21 @@ func (s *UserForumThreadSubscription) DoSave() { } } -func SubscribeToForumThread(userID UserID, threadID int64) (err error) { +func SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error) { return DB.Create(&UserForumThreadSubscription{UserID: userID, ThreadID: threadID}).Error } -func UnsubscribeFromForumThread(userID UserID, threadID int64) (err error) { +func UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error) { return DB.Delete(&UserForumThreadSubscription{}, "user_id = ? AND thread_id = ?", userID, threadID).Error } -func IsUserSubscribedToForumThread(userID UserID, threadID int64) bool { +func IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool { var count int64 DB.Model(UserForumThreadSubscription{}).Where("user_id = ? AND thread_id = ?", userID, threadID).Count(&count) return count == 1 } -func GetUsersSubscribedToForumThread(threadID int64) (out []UserForumThreadSubscription, err error) { +func GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error) { err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error return } diff --git a/pkg/database/table_forum_threads.go b/pkg/database/table_forum_threads.go @@ -15,19 +15,24 @@ import ( bf "github.com/russross/blackfriday/v2" ) +type ForumCategoryID int64 + type ForumCategory struct { - ID int64 + ID ForumCategoryID Idx int64 Name string Slug string } +type ForumThreadID int64 +type ForumThreadUUID string + type ForumThread struct { - ID int64 - UUID string + ID ForumThreadID + UUID ForumThreadUUID Name string UserID UserID - CategoryID int64 + CategoryID ForumCategoryID CreatedAt time.Time User User Category ForumCategory @@ -49,19 +54,23 @@ func GetForumCategoryBySlug(slug string) (out ForumCategory, err error) { return } +type ForumMessageID int64 + +type ForumMessageUUID string + type ForumMessage struct { - ID int64 - UUID string + ID ForumMessageID + UUID ForumMessageUUID Message string UserID UserID - ThreadID int64 + ThreadID ForumThreadID CreatedAt time.Time User User } type ForumReadRecord struct { UserID UserID - ThreadID int64 + ThreadID ForumThreadID ReadAt time.Time } @@ -145,21 +154,21 @@ func (m *ForumMessage) Escape() string { return res } -func GetForumMessage(messageID int64) (out ForumMessage, err error) { +func GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error) { err = DB.First(&out, "id = ?", messageID).Error return } -func GetForumMessageByUUID(messageUUID string) (out ForumMessage, err error) { +func GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error) { err = DB.First(&out, "uuid = ?", messageUUID).Error return } -func DeleteForumMessageByID(messageID int64) error { +func DeleteForumMessageByID(messageID ForumMessageID) error { return DB.Where("id = ?", messageID).Delete(&ForumMessage{}).Error } -func DeleteForumThreadByID(threadID int64) error { +func DeleteForumThreadByID(threadID ForumThreadID) error { return DB.Where("id = ?", threadID).Delete(&ForumThread{}).Error } @@ -168,17 +177,17 @@ func (m *ForumMessage) CanEdit() bool { return true } -func GetForumThread(threadID int64) (out ForumThread, err error) { +func GetForumThread(threadID ForumThreadID) (out ForumThread, err error) { err = DB.First(&out, "id = ? AND is_club = 1", threadID).Error return } -func GetForumThreadByID(threadID int64) (out ForumThread, err error) { +func GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error) { err = DB.First(&out, "id = ? AND is_club = 0", threadID).Error return } -func GetForumThreadByUUID(threadUUID string) (out ForumThread, err error) { +func GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error) { err = DB.First(&out, "uuid = ? AND is_club = 0", threadUUID).Error return } @@ -219,7 +228,7 @@ ORDER BY t.id DESC`, userID).Scan(&out).Error return } -func GetPublicForumCategoryThreads(userID UserID, categoryID int64) (out []ForumThreadAug, err error) { +func GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error) { err = DB.Raw(`SELECT t.*, u.username as author, u.chat_color as author_chat_color, @@ -245,7 +254,7 @@ ORDER BY m.created_at DESC, t.id DESC`, userID, categoryID).Scan(&out).Error return } -func GetPublicForumThreadsSearch(userID int64) (out []ForumThreadAug, err error) { +func GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error) { err = DB.Raw(`SELECT t.*, u.username as author, u.chat_color as author_chat_color, @@ -271,7 +280,7 @@ ORDER BY m.created_at DESC, t.id DESC`, userID).Scan(&out).Error return } -func GetThreadMessages(threadID int64) (out []ForumMessage, err error) { +func GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error) { err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error return } diff --git a/pkg/web/handlers/api/v1/handlers.go b/pkg/web/handlers/api/v1/handlers.go @@ -236,7 +236,7 @@ func UnsubscribeHandler(c echo.Context) error { func ThreadSubscribeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) - threadUUID := c.Param("threadUUID") + threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) thread, err := database.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -247,7 +247,7 @@ func ThreadSubscribeHandler(c echo.Context) error { func ThreadUnsubscribeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) - threadUUID := c.Param("threadUUID") + threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) thread, err := database.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -319,7 +319,7 @@ func ChatDeleteMessageHandler(c echo.Context) error { func ClubDeleteMessageHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) - messageID := utils.DoParseInt64(c.Param("messageID")) + messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID"))) msg, err := database.GetForumMessage(messageID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) diff --git a/pkg/web/handlers/handlers.go b/pkg/web/handlers/handlers.go @@ -1641,7 +1641,7 @@ func ForumCategoryHandler(c echo.Context) error { func ThreadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) - threadUUID := c.Param("threadUUID") + threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) thread, err := database.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, "/") @@ -1768,7 +1768,7 @@ func ThreadReplyHandler(c echo.Context) error { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"}) } - threadUUID := c.Param("threadUUID") + threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) thread, err := database.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, "/") @@ -1782,7 +1782,7 @@ func ThreadReplyHandler(c echo.Context) error { data.ErrorMessage = "Message must have at least 3 characters" return c.Render(http.StatusOK, "thread-reply", data) } - message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} + message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} if err := database.DB.Create(&message).Error; err != nil { logrus.Error(err) } @@ -1794,7 +1794,7 @@ func ThreadReplyHandler(c echo.Context) error { database.CreateNotification(msg, sub.UserID) } } - return c.Redirect(http.StatusFound, "/t/"+thread.UUID) + return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID)) } return c.Render(http.StatusOK, "thread-reply", data) @@ -1802,7 +1802,7 @@ func ThreadReplyHandler(c echo.Context) error { func ClubThreadReplyHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) - threadID := utils.DoParseInt64(c.Param("threadID")) + threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID"))) thread, err := database.GetForumThread(threadID) if err != nil { return c.Redirect(http.StatusFound, "/") @@ -1817,9 +1817,9 @@ func ClubThreadReplyHandler(c echo.Context) error { data.ErrorMessage = "Message must have at least 3 characters" return c.Render(http.StatusOK, "club.new-thread", data) } - message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} + message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} database.DB.Create(&message) - return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(thread.ID)) + return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID))) } return c.Render(http.StatusOK, "club.thread-reply", data) @@ -1833,7 +1833,7 @@ func ThreadDeleteMessageHandler(c echo.Context) error { if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"}) } - messageUUID := c.Param("messageUUID") + messageUUID := database.ForumMessageUUID(c.Param("messageUUID")) msg, err := database.GetForumMessageByUUID(messageUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -1858,7 +1858,7 @@ func ThreadDeleteMessageHandler(c echo.Context) error { if err := database.DeleteForumMessageByID(msg.ID); err != nil { logrus.Error(err) } - return c.Redirect(http.StatusFound, "/t/"+data.Thread.UUID) + return c.Redirect(http.StatusFound, "/t/"+string(data.Thread.UUID)) } return c.Render(http.StatusOK, "thread-message-delete", data) @@ -1960,7 +1960,7 @@ func ThreadEditHandler(c echo.Context) error { if !authUser.IsAdmin { return c.Redirect(http.StatusFound, c.Request().Referer()) } - threadUUID := c.Param("threadUUID") + threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) thread, err := database.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -1969,7 +1969,7 @@ func ThreadEditHandler(c echo.Context) error { data.Thread = thread if c.Request().Method == http.MethodPost { - thread.CategoryID = utils.DoParseInt64(c.Request().PostFormValue("category_id")) + thread.CategoryID = database.ForumCategoryID(utils.DoParseInt64(c.Request().PostFormValue("category_id"))) thread.DoSave() return c.Redirect(http.StatusFound, "/forum") } @@ -1985,7 +1985,7 @@ func ThreadDeleteHandler(c echo.Context) error { if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"}) } - threadUUID := c.Param("threadUUID") + threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) thread, err := database.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -2016,8 +2016,8 @@ func ThreadEditMessageHandler(c echo.Context) error { if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"}) } - threadUUID := c.Param("threadUUID") - messageUUID := c.Param("messageUUID") + threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) + messageUUID := database.ForumMessageUUID(c.Param("messageUUID")) thread, err := database.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, "/") @@ -2042,7 +2042,7 @@ func ThreadEditMessageHandler(c echo.Context) error { } msg.Message = data.Message msg.DoSave() - return c.Redirect(http.StatusFound, "/t/"+thread.UUID) + return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID)) } return c.Render(http.StatusOK, "thread-reply", data) @@ -2050,8 +2050,8 @@ func ThreadEditMessageHandler(c echo.Context) error { func ClubThreadEditMessageHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) - threadID := utils.DoParseInt64(c.Param("threadID")) - messageID := utils.DoParseInt64(c.Param("messageID")) + threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID"))) + messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID"))) thread, err := database.GetForumThread(threadID) if err != nil { return c.Redirect(http.StatusFound, "/") @@ -2077,7 +2077,7 @@ func ClubThreadEditMessageHandler(c echo.Context) error { } msg.Message = data.Message msg.DoSave() - return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(thread.ID)) + return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID))) } return c.Render(http.StatusOK, "club.thread-reply", data) @@ -2104,12 +2104,12 @@ func NewThreadHandler(c echo.Context) error { data.ErrorMessage = "Thread message must have at least 3-20000 characters" return c.Render(http.StatusOK, "new-thread", data) } - thread := database.ForumThread{UUID: uuid.New().String(), Name: data.ThreadName, UserID: authUser.ID, CategoryID: 1} + thread := database.ForumThread{UUID: database.ForumThreadUUID(uuid.New().String()), Name: data.ThreadName, UserID: authUser.ID, CategoryID: 1} database.DB.Create(&thread) - message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} + message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} database.DB.Create(&message) _ = database.SubscribeToForumThread(authUser.ID, thread.ID) - return c.Redirect(http.StatusFound, "/t/"+thread.UUID) + return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID)) } return c.Render(http.StatusOK, "new-thread", data) @@ -2131,11 +2131,11 @@ func ClubNewThreadHandler(c echo.Context) error { data.ErrorMessage = "Thread name must have at least 3 characters" return c.Render(http.StatusOK, "club.new-thread", data) } - thread := database.ForumThread{UUID: uuid.New().String(), Name: data.ThreadName, UserID: authUser.ID} + thread := database.ForumThread{UUID: database.ForumThreadUUID(uuid.New().String()), Name: data.ThreadName, UserID: authUser.ID} database.DB.Create(&thread) - message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} + message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID} database.DB.Create(&message) - return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(thread.ID)) + return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID))) } return c.Render(http.StatusOK, "club.new-thread", data) @@ -2150,7 +2150,7 @@ func ClubMembersHandler(c echo.Context) error { func ClubThreadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) - threadID := utils.DoParseInt64(c.Param("threadID")) + threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID"))) thread, err := database.GetForumThread(threadID) if err != nil { return c.Redirect(http.StatusFound, "/")