commit 68cebecab91f74e9e84b81c8b3df5c4d927d41a8
parent 169059461d0a617fe1e28955f95f9850c7332c01
Author: n0tr1v <n0tr1v@protonmail.com>
Date: Mon, 12 Dec 2022 21:22:59 -0500
strongly typed forum IDs/UUIDs
Diffstat:
4 files changed, 60 insertions(+), 51 deletions(-)
diff --git a/pkg/database/tableUserForumThreadSubscriptions.go b/pkg/database/tableUserForumThreadSubscriptions.go
@@ -8,7 +8,7 @@ import (
type UserForumThreadSubscription struct {
UserID UserID
- ThreadID int64
+ ThreadID ForumThreadID
CreatedAt time.Time
User User
}
@@ -19,21 +19,21 @@ func (s *UserForumThreadSubscription) DoSave() {
}
}
-func SubscribeToForumThread(userID UserID, threadID int64) (err error) {
+func SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error) {
return DB.Create(&UserForumThreadSubscription{UserID: userID, ThreadID: threadID}).Error
}
-func UnsubscribeFromForumThread(userID UserID, threadID int64) (err error) {
+func UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error) {
return DB.Delete(&UserForumThreadSubscription{}, "user_id = ? AND thread_id = ?", userID, threadID).Error
}
-func IsUserSubscribedToForumThread(userID UserID, threadID int64) bool {
+func IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool {
var count int64
DB.Model(UserForumThreadSubscription{}).Where("user_id = ? AND thread_id = ?", userID, threadID).Count(&count)
return count == 1
}
-func GetUsersSubscribedToForumThread(threadID int64) (out []UserForumThreadSubscription, err error) {
+func GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error) {
err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error
return
}
diff --git a/pkg/database/table_forum_threads.go b/pkg/database/table_forum_threads.go
@@ -15,19 +15,24 @@ import (
bf "github.com/russross/blackfriday/v2"
)
+type ForumCategoryID int64
+
type ForumCategory struct {
- ID int64
+ ID ForumCategoryID
Idx int64
Name string
Slug string
}
+type ForumThreadID int64
+type ForumThreadUUID string
+
type ForumThread struct {
- ID int64
- UUID string
+ ID ForumThreadID
+ UUID ForumThreadUUID
Name string
UserID UserID
- CategoryID int64
+ CategoryID ForumCategoryID
CreatedAt time.Time
User User
Category ForumCategory
@@ -49,19 +54,23 @@ func GetForumCategoryBySlug(slug string) (out ForumCategory, err error) {
return
}
+type ForumMessageID int64
+
+type ForumMessageUUID string
+
type ForumMessage struct {
- ID int64
- UUID string
+ ID ForumMessageID
+ UUID ForumMessageUUID
Message string
UserID UserID
- ThreadID int64
+ ThreadID ForumThreadID
CreatedAt time.Time
User User
}
type ForumReadRecord struct {
UserID UserID
- ThreadID int64
+ ThreadID ForumThreadID
ReadAt time.Time
}
@@ -145,21 +154,21 @@ func (m *ForumMessage) Escape() string {
return res
}
-func GetForumMessage(messageID int64) (out ForumMessage, err error) {
+func GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error) {
err = DB.First(&out, "id = ?", messageID).Error
return
}
-func GetForumMessageByUUID(messageUUID string) (out ForumMessage, err error) {
+func GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error) {
err = DB.First(&out, "uuid = ?", messageUUID).Error
return
}
-func DeleteForumMessageByID(messageID int64) error {
+func DeleteForumMessageByID(messageID ForumMessageID) error {
return DB.Where("id = ?", messageID).Delete(&ForumMessage{}).Error
}
-func DeleteForumThreadByID(threadID int64) error {
+func DeleteForumThreadByID(threadID ForumThreadID) error {
return DB.Where("id = ?", threadID).Delete(&ForumThread{}).Error
}
@@ -168,17 +177,17 @@ func (m *ForumMessage) CanEdit() bool {
return true
}
-func GetForumThread(threadID int64) (out ForumThread, err error) {
+func GetForumThread(threadID ForumThreadID) (out ForumThread, err error) {
err = DB.First(&out, "id = ? AND is_club = 1", threadID).Error
return
}
-func GetForumThreadByID(threadID int64) (out ForumThread, err error) {
+func GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error) {
err = DB.First(&out, "id = ? AND is_club = 0", threadID).Error
return
}
-func GetForumThreadByUUID(threadUUID string) (out ForumThread, err error) {
+func GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error) {
err = DB.First(&out, "uuid = ? AND is_club = 0", threadUUID).Error
return
}
@@ -219,7 +228,7 @@ ORDER BY t.id DESC`, userID).Scan(&out).Error
return
}
-func GetPublicForumCategoryThreads(userID UserID, categoryID int64) (out []ForumThreadAug, err error) {
+func GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error) {
err = DB.Raw(`SELECT t.*,
u.username as author,
u.chat_color as author_chat_color,
@@ -245,7 +254,7 @@ ORDER BY m.created_at DESC, t.id DESC`, userID, categoryID).Scan(&out).Error
return
}
-func GetPublicForumThreadsSearch(userID int64) (out []ForumThreadAug, err error) {
+func GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error) {
err = DB.Raw(`SELECT t.*,
u.username as author,
u.chat_color as author_chat_color,
@@ -271,7 +280,7 @@ ORDER BY m.created_at DESC, t.id DESC`, userID).Scan(&out).Error
return
}
-func GetThreadMessages(threadID int64) (out []ForumMessage, err error) {
+func GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error) {
err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error
return
}
diff --git a/pkg/web/handlers/api/v1/handlers.go b/pkg/web/handlers/api/v1/handlers.go
@@ -236,7 +236,7 @@ func UnsubscribeHandler(c echo.Context) error {
func ThreadSubscribeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
- threadUUID := c.Param("threadUUID")
+ threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
thread, err := database.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -247,7 +247,7 @@ func ThreadSubscribeHandler(c echo.Context) error {
func ThreadUnsubscribeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
- threadUUID := c.Param("threadUUID")
+ threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
thread, err := database.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -319,7 +319,7 @@ func ChatDeleteMessageHandler(c echo.Context) error {
func ClubDeleteMessageHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
- messageID := utils.DoParseInt64(c.Param("messageID"))
+ messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID")))
msg, err := database.GetForumMessage(messageID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
diff --git a/pkg/web/handlers/handlers.go b/pkg/web/handlers/handlers.go
@@ -1641,7 +1641,7 @@ func ForumCategoryHandler(c echo.Context) error {
func ThreadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
- threadUUID := c.Param("threadUUID")
+ threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
thread, err := database.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
@@ -1768,7 +1768,7 @@ func ThreadReplyHandler(c echo.Context) error {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"})
}
- threadUUID := c.Param("threadUUID")
+ threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
thread, err := database.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
@@ -1782,7 +1782,7 @@ func ThreadReplyHandler(c echo.Context) error {
data.ErrorMessage = "Message must have at least 3 characters"
return c.Render(http.StatusOK, "thread-reply", data)
}
- message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
+ message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
if err := database.DB.Create(&message).Error; err != nil {
logrus.Error(err)
}
@@ -1794,7 +1794,7 @@ func ThreadReplyHandler(c echo.Context) error {
database.CreateNotification(msg, sub.UserID)
}
}
- return c.Redirect(http.StatusFound, "/t/"+thread.UUID)
+ return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID))
}
return c.Render(http.StatusOK, "thread-reply", data)
@@ -1802,7 +1802,7 @@ func ThreadReplyHandler(c echo.Context) error {
func ClubThreadReplyHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
- threadID := utils.DoParseInt64(c.Param("threadID"))
+ threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID")))
thread, err := database.GetForumThread(threadID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
@@ -1817,9 +1817,9 @@ func ClubThreadReplyHandler(c echo.Context) error {
data.ErrorMessage = "Message must have at least 3 characters"
return c.Render(http.StatusOK, "club.new-thread", data)
}
- message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
+ message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
database.DB.Create(&message)
- return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(thread.ID))
+ return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID)))
}
return c.Render(http.StatusOK, "club.thread-reply", data)
@@ -1833,7 +1833,7 @@ func ThreadDeleteMessageHandler(c echo.Context) error {
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"})
}
- messageUUID := c.Param("messageUUID")
+ messageUUID := database.ForumMessageUUID(c.Param("messageUUID"))
msg, err := database.GetForumMessageByUUID(messageUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -1858,7 +1858,7 @@ func ThreadDeleteMessageHandler(c echo.Context) error {
if err := database.DeleteForumMessageByID(msg.ID); err != nil {
logrus.Error(err)
}
- return c.Redirect(http.StatusFound, "/t/"+data.Thread.UUID)
+ return c.Redirect(http.StatusFound, "/t/"+string(data.Thread.UUID))
}
return c.Render(http.StatusOK, "thread-message-delete", data)
@@ -1960,7 +1960,7 @@ func ThreadEditHandler(c echo.Context) error {
if !authUser.IsAdmin {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- threadUUID := c.Param("threadUUID")
+ threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
thread, err := database.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -1969,7 +1969,7 @@ func ThreadEditHandler(c echo.Context) error {
data.Thread = thread
if c.Request().Method == http.MethodPost {
- thread.CategoryID = utils.DoParseInt64(c.Request().PostFormValue("category_id"))
+ thread.CategoryID = database.ForumCategoryID(utils.DoParseInt64(c.Request().PostFormValue("category_id")))
thread.DoSave()
return c.Redirect(http.StatusFound, "/forum")
}
@@ -1985,7 +1985,7 @@ func ThreadDeleteHandler(c echo.Context) error {
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"})
}
- threadUUID := c.Param("threadUUID")
+ threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
thread, err := database.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -2016,8 +2016,8 @@ func ThreadEditMessageHandler(c echo.Context) error {
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Account must be at least 3 days old", Redirect: c.Request().Referer(), Type: "alert-danger"})
}
- threadUUID := c.Param("threadUUID")
- messageUUID := c.Param("messageUUID")
+ threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
+ messageUUID := database.ForumMessageUUID(c.Param("messageUUID"))
thread, err := database.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
@@ -2042,7 +2042,7 @@ func ThreadEditMessageHandler(c echo.Context) error {
}
msg.Message = data.Message
msg.DoSave()
- return c.Redirect(http.StatusFound, "/t/"+thread.UUID)
+ return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID))
}
return c.Render(http.StatusOK, "thread-reply", data)
@@ -2050,8 +2050,8 @@ func ThreadEditMessageHandler(c echo.Context) error {
func ClubThreadEditMessageHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
- threadID := utils.DoParseInt64(c.Param("threadID"))
- messageID := utils.DoParseInt64(c.Param("messageID"))
+ threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID")))
+ messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID")))
thread, err := database.GetForumThread(threadID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
@@ -2077,7 +2077,7 @@ func ClubThreadEditMessageHandler(c echo.Context) error {
}
msg.Message = data.Message
msg.DoSave()
- return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(thread.ID))
+ return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID)))
}
return c.Render(http.StatusOK, "club.thread-reply", data)
@@ -2104,12 +2104,12 @@ func NewThreadHandler(c echo.Context) error {
data.ErrorMessage = "Thread message must have at least 3-20000 characters"
return c.Render(http.StatusOK, "new-thread", data)
}
- thread := database.ForumThread{UUID: uuid.New().String(), Name: data.ThreadName, UserID: authUser.ID, CategoryID: 1}
+ thread := database.ForumThread{UUID: database.ForumThreadUUID(uuid.New().String()), Name: data.ThreadName, UserID: authUser.ID, CategoryID: 1}
database.DB.Create(&thread)
- message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
+ message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
database.DB.Create(&message)
_ = database.SubscribeToForumThread(authUser.ID, thread.ID)
- return c.Redirect(http.StatusFound, "/t/"+thread.UUID)
+ return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID))
}
return c.Render(http.StatusOK, "new-thread", data)
@@ -2131,11 +2131,11 @@ func ClubNewThreadHandler(c echo.Context) error {
data.ErrorMessage = "Thread name must have at least 3 characters"
return c.Render(http.StatusOK, "club.new-thread", data)
}
- thread := database.ForumThread{UUID: uuid.New().String(), Name: data.ThreadName, UserID: authUser.ID}
+ thread := database.ForumThread{UUID: database.ForumThreadUUID(uuid.New().String()), Name: data.ThreadName, UserID: authUser.ID}
database.DB.Create(&thread)
- message := database.ForumMessage{UUID: uuid.New().String(), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
+ message := database.ForumMessage{UUID: database.ForumMessageUUID(uuid.New().String()), Message: data.Message, UserID: authUser.ID, ThreadID: thread.ID}
database.DB.Create(&message)
- return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(thread.ID))
+ return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID)))
}
return c.Render(http.StatusOK, "club.new-thread", data)
@@ -2150,7 +2150,7 @@ func ClubMembersHandler(c echo.Context) error {
func ClubThreadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
- threadID := utils.DoParseInt64(c.Param("threadID"))
+ threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID")))
thread, err := database.GetForumThread(threadID)
if err != nil {
return c.Redirect(http.StatusFound, "/")