dkforest

A forum and chat platform (onion)
git clone https://git.dasho.dev/n0tr1v/dkforest.git
Log | Files | Refs | LICENSE

commit 79e79afef727561361b6944139c09d693fea8944
parent c07c542c3830bc897a3289eaf143e77e93e21a29
Author: n0tr1v <n0tr1v@protonmail.com>
Date:   Fri,  3 Mar 2023 03:10:02 -0800

get rid of database global variable and all side effects

Diffstat:
Mpkg/actions/actions.go | 72+++++++++++++++++++++++++++++-------------------------------------------
Mpkg/database/database.go | 236++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
Mpkg/database/tableAuditLog.go | 8++++----
Mpkg/database/tableBadges.go | 8++++----
Mpkg/database/tableCaptchaRequests.go | 4++--
Mpkg/database/tableChatInbox.go | 34+++++++++++++++++-----------------
Mpkg/database/tableChatMessages.go | 70+++++++++++++++++++++++++++++++++++-----------------------------------
Mpkg/database/tableChatReactions.go | 8++++----
Mpkg/database/tableChatReadMarkers.go | 10+++++-----
Mpkg/database/tableChatRoomGroups.go | 48++++++++++++++++++++++++------------------------
Mpkg/database/tableChatRoomWhitelistedUsers.go | 20++++++++++----------
Mpkg/database/tableChatRooms.go | 53+++++++++++++++++++++++++++--------------------------
Mpkg/database/tableDownloads.go | 12++++++------
Mpkg/database/tableFiledrops.go | 24++++++++++++------------
Mpkg/database/tableGists.go | 8++++----
Mpkg/database/tableIgnoredMessages.go | 8++++----
Mpkg/database/tableIgnoredUsers.go | 16++++++++--------
Mpkg/database/tableInvitations.go | 24++++++++++++------------
Mpkg/database/tableKarmaHistory.go | 4++--
Mpkg/database/tableLinks.go | 104++++++++++++++++++++++++++++++++++++++++----------------------------------------
Mpkg/database/tableNotifications.go | 44++++++++++++++++++++++----------------------
Mpkg/database/tableOnionBlacklist.go | 4++--
Mpkg/database/tablePmBlacklistedUsers.go | 28++++++++++++++--------------
Mpkg/database/tablePmWhitelistedUsers.go | 24++++++++++++------------
Mpkg/database/tableProhibitedPasswords.go | 4++--
Mpkg/database/tableSecurityLogs.go | 12++++++------
Mpkg/database/tableSessions.go | 34+++++++++++++++++-----------------
Mpkg/database/tableSettings.go | 14+++++++-------
Mpkg/database/tableSnippets.go | 12++++++------
Mpkg/database/tableUploads.go | 40++++++++++++++++++++--------------------
Mpkg/database/tableUserForumThreadSubscriptions.go | 20++++++++++----------
Mpkg/database/tableUserPrivateNotes.go | 10+++++-----
Mpkg/database/tableUserPublicNotes.go | 10+++++-----
Mpkg/database/tableUserRoomSubscriptions.go | 16++++++++--------
Mpkg/database/tableUsers.go | 134++++++++++++++++++++++++++++++++++++++++----------------------------------------
Mpkg/database/tableXmrInvoices.go | 14+++++++-------
Mpkg/database/table_forum_threads.go | 78+++++++++++++++++++++++++++++++++++++++---------------------------------------
Mpkg/database/utils/utils.go | 40++++++++++++++++++++--------------------
Mpkg/global/global.go | 8++++----
Mpkg/template/templates.go | 6+++++-
Mpkg/web/handlers/admin.go | 93+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
Mpkg/web/handlers/api/v1/bangInterceptor.go | 4++--
Mpkg/web/handlers/api/v1/battleship.go | 15++++++++-------
Mpkg/web/handlers/api/v1/chess.go | 25+++++++++++++------------
Mpkg/web/handlers/api/v1/handlers.go | 120++++++++++++++++++++++++++++++++++++++++++++++---------------------------------
Mpkg/web/handlers/api/v1/msgInterceptor.go | 50+++++++++++++++++++++++++-------------------------
Mpkg/web/handlers/api/v1/slashInterceptor.go | 206++++++++++++++++++++++++++++++++++++++++----------------------------------------
Mpkg/web/handlers/api/v1/snippetInterceptor.go | 6+++---
Mpkg/web/handlers/api/v1/spamInterceptor.go | 18+++++++++---------
Mpkg/web/handlers/api/v1/topBarHandler.go | 82+++++++++++++++++++++++++++++++++++++++++--------------------------------------
Mpkg/web/handlers/api/v1/uploadInterceptor.go | 8++++----
Mpkg/web/handlers/api/v1/werewolf.go | 54++++++++++++++++++++++++++++--------------------------
Mpkg/web/handlers/chat.go | 57+++++++++++++++++++++++++++++----------------------------
Mpkg/web/handlers/club.go | 30++++++++++++++++++------------
Mpkg/web/handlers/handlers.go | 566+++++++++++++++++++++++++++++++++++++++++++++----------------------------------
Mpkg/web/middlewares/middlewares.go | 17++++++++++++++---
Mpkg/web/public/views/pages/club/thread.gohtml | 2+-
Mpkg/web/public/views/pages/thread.gohtml | 2+-
Mpkg/web/web.go | 25++++++++++++++-----------
59 files changed, 1529 insertions(+), 1174 deletions(-)

diff --git a/pkg/actions/actions.go b/pkg/actions/actions.go @@ -47,22 +47,21 @@ func Start(c *cli.Context) error { ensureProjectHome() - initDB() - defer database.DB.Close() + db := database.DB() - runMigrations() + runMigrations(db) - config.IsFirstUse.Store(isFirstUse()) + config.IsFirstUse.Store(isFirstUse(db)) captcha.SetCustomStore(captcha.NewMemoryStore(captcha.CollectNum, 120*time.Second)) if false { - err := database.DB.Debug().Exec(` + err := db.DB().Debug().Exec(` `).Error logrus.Error(err) } - settings := database.GetSettings() + settings := db.GetSettings() config.ProtectHome.Store(settings.ProtectHome) config.HomeUsersList.Store(settings.HomeUsersList) config.ForceLoginCaptcha.Store(settings.ForceLoginCaptcha) @@ -76,22 +75,22 @@ func Start(c *cli.Context) error { config.Xmr() - utils.SGo(func() { cleanupDatabase() }) + utils.SGo(func() { cleanupDatabase(db) }) utils.SGo(func() { managers.ActiveUsers.CleanupUsersCache() }) - utils.SGo(func() { xmrWatch() }) + utils.SGo(func() { xmrWatch(db) }) utils.SGo(func() { openBrowser(noBrowser, int64(port)) }) - v1.ChessInstance = v1.NewChess() - v1.BattleshipInstance = v1.NewBattleship() - v1.WWInstance = v1.NewWerewolf() + v1.ChessInstance = v1.NewChess(db) + v1.BattleshipInstance = v1.NewBattleship(db) + v1.WWInstance = v1.NewWerewolf(db) - web.Start(host, port) + web.Start(db, host, port) return nil } -func isFirstUse() bool { +func isFirstUse(db *database.DkfDB) bool { var count int64 - database.DB.Model(database.User{}).Count(&count) + db.DB().Model(database.User{}).Count(&count) return count <= 0 } @@ -105,7 +104,7 @@ func openBrowser(noBrowser bool, port int64) { } } -func runMigrations() { +func runMigrations(db *database.DkfDB) { logrus.Info("running migrations") migrations := &migrate.AssetMigrationSource{ Asset: config.MigrationsFs.ReadFile, @@ -119,25 +118,15 @@ func runMigrations() { }, Dir: "migrations", } - database.DB.Exec("PRAGMA foreign_keys=OFF") - n, err := migrate.Exec(database.DB.DB(), "sqlite3", migrations, migrate.Up) + db.DB().Exec("PRAGMA foreign_keys=OFF") + n, err := migrate.Exec(db.DB().DB(), "sqlite3", migrations, migrate.Up) if err != nil { panic(err) } - database.DB.Exec("PRAGMA foreign_keys=ON") + db.DB().Exec("PRAGMA foreign_keys=ON") logrus.Infof("applied %d migrations", n) } -func initDB() { - dbPath := filepath.Join(config.Global.ProjectPath(), config.DbFileName) - db, err := database.OpenSqlite3DB(dbPath) - if err != nil { - logrus.Fatal("Failed to open sqlite3 db : " + err.Error()) - return - } - database.DB = db -} - // Ensure the project folder is created properly func ensureProjectHome() { config.Global.SetProjectPath(utils.MustGetDefaultProjectPath()) @@ -215,7 +204,7 @@ func (f LogFormatter) Format(e *logrus.Entry) ([]byte, error) { return buffer.Bytes(), nil } -func xmrWatch() { +func xmrWatch(db *database.DkfDB) { var once utils.Once for { select { @@ -227,7 +216,7 @@ func xmrWatch() { continue } for _, transfer := range transfers.In { - invoice, err := database.GetXmrInvoiceByAddress(transfer.Address) + invoice, err := db.GetXmrInvoiceByAddress(transfer.Address) if err != nil { logrus.Error(err, transfer.TxID) continue @@ -236,7 +225,7 @@ func xmrWatch() { invoice.Confirmations = int64(transfer.Confirmations) amount := int64(transfer.Amount) invoice.AmountReceived = &amount - invoice.DoSave() + invoice.DoSave(db) if origConfirmations >= 10 { continue } else if transfer.Confirmations < 10 { @@ -249,7 +238,7 @@ func xmrWatch() { } } -func cleanupDatabase() { +func cleanupDatabase(db *database.DkfDB) { var once utils.Once for { select { @@ -257,13 +246,13 @@ func cleanupDatabase() { case <-time.After(1 * time.Hour): } start := time.Now() - database.DeleteOldSessions() - database.DeleteOldUploads() - database.DeleteOldChatMessages() - database.DeleteOldPrivateChatRooms() - database.DeleteOldCaptchaRequests() - database.DeleteOldAuditLogs() - database.DeleteOldSecurityLogs() + db.DeleteOldSessions() + db.DeleteOldUploads() + db.DeleteOldChatMessages() + db.DeleteOldPrivateChatRooms() + db.DeleteOldCaptchaRequests() + db.DeleteOldAuditLogs() + db.DeleteOldSecurityLogs() logrus.Debugf("done cleaning database, took %s", time.Since(start)) } } @@ -276,9 +265,6 @@ func BuildProhibitedPasswords(c *cli.Context) error { ensureProjectHome() - initDB() - defer database.DB.Close() - readFile, _ := os.Open("rockyou.txt") fileScanner := bufio.NewScanner(readFile) fileScanner.Split(bufio.ScanLines) @@ -287,7 +273,7 @@ func BuildProhibitedPasswords(c *cli.Context) error { rows = append(rows, database.ProhibitedPassword{Password: fileScanner.Text()}) } readFile.Close() - if err := gormbulk.BulkInsert(database.DB, rows, 10000); err != nil { + if err := gormbulk.BulkInsert(database.DB().DB(), rows, 10000); err != nil { logrus.Error(err) } return nil diff --git a/pkg/database/database.go b/pkg/database/database.go @@ -5,26 +5,236 @@ import ( "fmt" "github.com/jinzhu/gorm" "github.com/mattn/go-sqlite3" + "github.com/sirupsen/logrus" "net/url" "path/filepath" "strings" + "sync" "time" ) -// DB ... -var DB *gorm.DB +type DkfDB struct { + db *gorm.DB +} -// OpenSqlite3DB ... -func OpenSqlite3DB(path string) (*gorm.DB, error) { - db, err := gorm.Open("sqlite3", path) - if err != nil { - return nil, err - } - db.DB().SetMaxIdleConns(1) // 10 - db.DB().SetMaxOpenConns(1) // 25 - db.LogMode(false) - db.Exec("PRAGMA foreign_keys=ON") - return db, nil +func (d *DkfDB) DB() *gorm.DB { + return d.db +} + +// Compile time checks to ensure type satisfies IDkfDB interface +var _ IDkfDB = (*DkfDB)(nil) + +type IDkfDB interface { + AddBlacklistedUser(userID, blacklistedUserID UserID) + AddLinkCategory(linkID, categoryID int64) (err error) + AddLinkTag(linkID, tagID int64) (err error) + AddUserToRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (out ChatRoomUserGroup, err error) + AddWhitelistedUser(userID, whitelistedUserID UserID) + CanRenameTo(oldUsername, newUsername string) error + CanUseUsername(username string, isFirstUser bool) error + ClearRoomGroup(roomID RoomID, groupID GroupID) (err error) + CreateChatReaction(userID UserID, messageID, reaction int64) error + CreateChatRoomGroup(roomID RoomID, name, color string) (out ChatRoomGroup, err error) + CreateDownload(userID UserID, filename string) (out Download, err error) + CreateEncryptedUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) + CreateFiledrop() (out Filedrop, err error) + CreateInboxMessage(msg string, roomID RoomID, fromUserID, toUserID UserID, isPm, moderators bool, msgID *int64) + CreateInvitation(userID UserID) (out Invitation, err error) + CreateKarmaHistory(karma int64, description string, userID UserID, fromUserID *int64) (out KarmaHistory, err error) + CreateKickMsg(kickedUser, kickedByUser User) + CreateLink(url, title, description, shorthand string) (out Link, err error) + CreateLinkMirror(linkID int64, link string) (out LinksMirror, err error) + CreateLinkPgp(linkID int64, title, description, publicKey string) (out LinksPgp, err error) + CreateLinksCategory(category string) (out LinksCategory, err error) + CreateLinksTag(tag string) (out LinksTag, err error) + CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID *UserID) (out ChatMessage, err error) + CreateNotification(msg string, userID UserID) + CreateOrEditMessage(editMsg *ChatMessage, message, raw, roomKey string, roomID RoomID, fromUserID UserID, toUserID *UserID, upload *Upload, groupID *GroupID, hellbanMsg, modMsg, systemMsg bool) (int64, error) + CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool) (out ChatRoom, err error) + CreateSecurityLog(userID UserID, typ int64) + CreateSession(userID UserID, userAgent string) (Session, error) + CreateSessionNotification(msg string, sessionToken string) + CreateSnippet(userID UserID, name, text string) (out Snippet, err error) + CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error + CreateUnkickMsg(kickedUser, kickedByUser User) + CreateUpload(fileName string, content []byte, userID UserID) (*Upload, error) + CreateUser(username, password, repassword string, registrationDuration int64, signupInfoEnc string) (User, UserErrors) + CreateUserBadge(userID UserID, badgeID int64) error + CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error) + DeWhitelistUser(roomID RoomID, userID UserID) (err error) + DeleteAllChatInbox(userID UserID) error + DeleteAllNotifications(userID UserID) error + DeleteAllSessionNotifications(sessionToken string) error + DeleteChatInboxMessageByChatMessageID(chatMessageID int64) error + DeleteChatInboxMessageByID(messageID int64) error + DeleteChatMessageByUUID(messageUUID string) error + DeleteChatRoomByID(id RoomID) + DeleteChatRoomGroup(roomID RoomID, name string) (err error) + DeleteChatRoomGroups(roomID RoomID) (err error) + DeleteChatRoomMessages(roomID RoomID) error + DeleteDownloadByID(downloadID int64) (err error) + DeleteForumMessageByID(messageID ForumMessageID) error + DeleteForumThreadByID(threadID ForumThreadID) error + DeleteLinkByID(id int64) error + DeleteLinkCategories(linkID int64) error + DeleteLinkMirrorByID(id int64) error + DeleteLinkPgpByID(id int64) error + DeleteLinkTags(linkID int64) error + DeleteNotificationByID(notificationID int64) error + DeleteOldAuditLogs() + DeleteOldCaptchaRequests() + DeleteOldChatMessages() + DeleteOldPrivateChatRooms() + DeleteOldSecurityLogs() + DeleteOldSessions() + DeleteOldUploads() + DeleteReaction(userID UserID, messageID, reaction int64) error + DeleteSessionByToken(token string) error + DeleteSessionNotificationByID(sessionNotificationID int64) error + DeleteSnippet(userID UserID, name string) + DeleteUserByID(userID UserID) (err error) + DeleteUserChatInboxMessages(userID UserID) error + DeleteUserChatMessages(userID UserID) error + DeleteUserOtherSessions(userID UserID, currentToken string) error + DeleteUserSessionByToken(userID UserID, token string) error + DeleteUserSessions(userID UserID) error + DoCreateSession(userID UserID, userAgent string) Session + GetActiveUserSessions(userID UserID) (out []Session) + GetCategories() (out []CategoriesResult, err error) + GetChatMessages(roomID RoomID, username string, userID UserID, displayPms PmDisplayMode, mentionsOnly, displayHellbanned, displayIgnored, displayModerators, displayIgnoredMessages bool) (out ChatMessages, err error) + GetChatRoomByID(roomID RoomID) (out ChatRoom, err error) + GetChatRoomByName(roomName string) (out ChatRoom, err error) + GetChessSubscribers() (out []User, err error) + GetClubForumThreads(userID UserID) (out []ForumThreadAug, err error) + GetClubMembers() (out []User, err error) + GetFiledropByFileName(fileName string) (out Filedrop, err error) + GetFiledropByUUID(uuid string) (out Filedrop, err error) + GetFiledrops() (out []Filedrop, err error) + GetForumCategories() (out []ForumCategory, err error) + GetForumCategoryBySlug(slug string) (out ForumCategory, err error) + GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error) + GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error) + GetForumThread(threadID ForumThreadID) (out ForumThread, err error) + GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error) + GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error) + GetForumThreads() (out []ForumThread, err error) + GetGistByUUID(uuid string) (out Gist, err error) + GetIgnoredByUsers(userID UserID) (out []IgnoredUser, err error) + GetIgnoredUsers(userID UserID) (out []IgnoredUser, err error) + GetLinkByID(linkID int64) (out Link, err error) + GetLinkByShorthand(shorthand string) (out Link, err error) + GetLinkByUUID(linkUUID string) (out Link, err error) + GetLinkCategories(linkID int64) (out []LinksCategory, err error) + GetLinkMirrorByID(id int64) (out LinksMirror, err error) + GetLinkMirrors(linkID int64) (out []LinksMirror, err error) + GetLinkPgpByID(id int64) (out LinksPgp, err error) + GetLinkPgps(linkID int64) (out []LinksPgp, err error) + GetLinkTags(linkID int64) (out []LinksTag, err error) + GetLinks() (out []Link, err error) + GetListedChatRooms(userID UserID) (out []ChatRoomAug, err error) + GetModeratorsUsers() (out []User, err error) + GetOfficialChatRooms() (out []ChatRoom, err error) + GetOfficialChatRooms1(userID UserID) (out []ChatRoomAug, err error) + GetOnionBlacklist(hash string) (out OnionBlacklist, err error) + GetPmBlacklistedByUsers(userID UserID) (out []PmBlacklistedUsers, err error) + GetPmBlacklistedUsers(userID UserID) (out []PmBlacklistedUsers, err error) + GetPmWhitelistedUsers(userID UserID) (out []PmWhitelistedUsers, err error) + GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error) + GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error) + GetRecentLinks() (out []Link, err error) + GetRecentUsersCount() int64 + GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out ChatMessage, err error) + GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, err error) + GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error) + GetRoomChatMessagesByDate(roomID RoomID, dt time.Time) (out []ChatMessage, err error) + GetRoomGroupByName(roomID RoomID, groupName string) (out ChatRoomGroup, err error) + GetRoomGroupUsers(roomID RoomID, groupID GroupID) (out []ChatRoomUserGroup, err error) + GetRoomGroups(roomID RoomID) (out []ChatRoomGroup, err error) + GetSecurityLogs(userID UserID) (out []SecurityLog, err error) + GetSettings() (out Settings) + GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error) + GetUnusedInvitationByToken(token string) (out Invitation, err error) + GetUploadByFileName(filename string) (out Upload, err error) + GetUploadByID(uploadID UploadID) (out Upload, err error) + GetUploads() (out []Upload, err error) + GetUserByApiKey(user *User, apiKey string) error + GetUserByID(userID UserID) (out User, err error) + GetUserBySessionKey(user *User, sessionKey string) error + GetUserByUsername(username string) (out User, err error) + GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error) + GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err error) + GetUserInboxMessagesCount(userID UserID) (count int64) + GetUserInvitations(userID UserID) (out []Invitation, err error) + GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage, err error) + GetUserNotifications(userID UserID) (msgs []Notification, err error) + GetUserNotificationsCount(userID UserID) (count int64) + GetUserPrivateNotes(userID UserID) (out UserPrivateNote, err error) + GetUserPublicNotes(userID UserID) (out UserPublicNote, err error) + GetUserReadMarker(userID UserID, roomID RoomID) (out ChatReadMarker, err error) + GetUserRoomGroups(userID UserID, roomID RoomID) (out []ChatRoomUserGroup, err error) + GetUserRoomSubscriptions(userID UserID) (out []ChatRoomAug, err error) + GetUserSessionNotifications(sessionToken string) (msgs []SessionNotification, err error) + GetUserSessionNotificationsCount(sessionToken string) (count int64) + GetUserSnippets(userID UserID) (out []Snippet, err error) + GetUserTotalUploadSize(userID UserID) int64 + GetUserUnusedInvitations(userID UserID) (out []Invitation, err error) + GetUserUploads(userID UserID) (out []Upload, err error) + GetUsersBadges() (out []UserBadge, err error) + GetUsersByID(ids []UserID) (out []User, err error) + GetUsersByUsername(usernames []string) (out []User, err error) + GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error) + GetVerifiedUserBySessionID(token string) (out User, err error) + GetVerifiedUserByUsername(username string) (out User, err error) + GetWhitelistedUsers(roomID RoomID) (out []ChatRoomWhitelistedUser, err error) + GetXmrInvoiceByAddress(address string) (out XmrInvoice, err error) + IgnoreMessage(userID UserID, messageID int64) + IgnoreUser(userID, ignoredUserID UserID) + IsPasswordProhibited(password string) bool + IsUserInGroupByID(userID UserID, groupID GroupID) bool + IsUserPmBlacklisted(fromUserID, toUserID UserID) bool + IsUserPmWhitelisted(fromUserID, toUserID UserID) bool + IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool + IsUserSubscribedToRoom(userID UserID, roomID RoomID) bool + IsUserWhitelistedInRoom(userID UserID, roomID RoomID) bool + IsUsernameAlreadyTaken(username string) bool + NewAudit(authUser User, log string) + RmBlacklistedUser(userID, blacklistedUserID UserID) + RmUserFromRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (err error) + RmWhitelistedUser(userID, whitelistedUserID UserID) + SetUserPrivateNotes(userID UserID, notes string) error + SetUserPublicNotes(userID UserID, notes string) error + SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error) + SubscribeToRoom(userID UserID, roomID RoomID) (err error) + ToggleBlacklistedUser(userID, blacklistedUserID UserID) bool + ToggleWhitelistedUser(userID, whitelistedUserID UserID) bool + UnIgnoreMessage(userID UserID, messageID int64) + UnIgnoreUser(userID, ignoredUserID UserID) + UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error) + UnsubscribeFromRoom(userID UserID, roomID RoomID) (err error) + UpdateChatReadMarker(userID UserID, roomID RoomID) + UpdateChatReadRecord(userID UserID, roomID RoomID) + UpdateForumReadRecord(userID UserID, threadID ForumThreadID) + UserNbDownloaded(userID UserID, filename string) (out int64) + WhitelistUser(roomID RoomID, userID UserID) (out ChatRoomWhitelistedUser, err error) +} + +var once sync.Once +var inst *DkfDB + +func DB() *DkfDB { + once.Do(func() { + dbPath := filepath.Join(config.Global.ProjectPath(), config.DbFileName) + db, err := gorm.Open("sqlite3", dbPath) + if err != nil { + logrus.Fatal("Failed to open sqlite3 db : " + err.Error()) + } + db.DB().SetMaxIdleConns(1) // 10 + db.DB().SetMaxOpenConns(1) // 25 + db.LogMode(false) + db.Exec("PRAGMA foreign_keys=ON") + inst = &DkfDB{db: db} + }) + return inst } // DB2 is the SQL database. diff --git a/pkg/database/tableAuditLog.go b/pkg/database/tableAuditLog.go @@ -14,14 +14,14 @@ type AuditLog struct { User User } -func NewAudit(authUser User, log string) { - if err := DB.Create(&AuditLog{UserID: authUser.ID, Log: log}).Error; err != nil { +func (d *DkfDB) NewAudit(authUser User, log string) { + if err := d.db.Create(&AuditLog{UserID: authUser.ID, Log: log}).Error; err != nil { logrus.Error(err) } } -func DeleteOldAuditLogs() { - if err := DB.Delete(AuditLog{}, "created_at < date('now', '-90 Day')").Error; err != nil { +func (d *DkfDB) DeleteOldAuditLogs() { + if err := d.db.Delete(AuditLog{}, "created_at < date('now', '-90 Day')").Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableBadges.go b/pkg/database/tableBadges.go @@ -16,12 +16,12 @@ type UserBadge struct { Badge Badge } -func CreateUserBadge(userID UserID, badgeID int64) error { +func (d *DkfDB) CreateUserBadge(userID UserID, badgeID int64) error { ub := UserBadge{UserID: userID, BadgeID: badgeID} - return DB.Create(&ub).Error + return d.db.Create(&ub).Error } -func GetUsersBadges() (out []UserBadge, err error) { - err = DB.Preload("User").Preload("Badge").Order("created_at").Find(&out).Error +func (d *DkfDB) GetUsersBadges() (out []UserBadge, err error) { + err = d.db.Preload("User").Preload("Badge").Order("created_at").Find(&out).Error return } diff --git a/pkg/database/tableCaptchaRequests.go b/pkg/database/tableCaptchaRequests.go @@ -19,8 +19,8 @@ type CaptchaRequest struct { // return base64.StdEncoding.EncodeToString(r.CaptchaImg) //} -func DeleteOldCaptchaRequests() { - if err := DB.Delete(CaptchaRequest{}, "created_at < date('now', '-90 Day')").Error; err != nil { +func (d *DkfDB) DeleteOldCaptchaRequests() { + if err := d.db.Delete(CaptchaRequest{}, "created_at < date('now', '-90 Day')").Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableChatInbox.go b/pkg/database/tableChatInbox.go @@ -22,8 +22,8 @@ type ChatInboxMessage struct { Room ChatRoom } -func GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error) { - err = DB.Order("id DESC"). +func (d *DkfDB) GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error) { + err = d.db.Order("id DESC"). Limit(50). Preload("User"). Preload("ToUser"). @@ -33,14 +33,14 @@ func GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error for _, msg := range msgs { ids = append(ids, msg.ID) } - if err := DB.Model(&ChatInboxMessage{}).Where("id IN (?)", ids).UpdateColumn("is_read", true).Error; err != nil { + if err := d.db.Model(&ChatInboxMessage{}).Where("id IN (?)", ids).UpdateColumn("is_read", true).Error; err != nil { logrus.Error(err) } return } -func GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err error) { - err = DB.Order("id DESC"). +func (d *DkfDB) GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err error) { + err = d.db.Order("id DESC"). Limit(50). Preload("User"). Preload("ToUser"). @@ -49,30 +49,30 @@ func GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err e return } -func DeleteChatInboxMessageByID(messageID int64) error { - return DB.Where("id = ?", messageID).Delete(&ChatInboxMessage{}).Error +func (d *DkfDB) DeleteChatInboxMessageByID(messageID int64) error { + return d.db.Where("id = ?", messageID).Delete(&ChatInboxMessage{}).Error } -func DeleteChatInboxMessageByChatMessageID(chatMessageID int64) error { - return DB.Where("chat_message_id = ?", chatMessageID).Delete(&ChatInboxMessage{}).Error +func (d *DkfDB) DeleteChatInboxMessageByChatMessageID(chatMessageID int64) error { + return d.db.Where("chat_message_id = ?", chatMessageID).Delete(&ChatInboxMessage{}).Error } -func DeleteAllChatInbox(userID UserID) error { - return DB.Where("to_user_id = ?", userID).Delete(&ChatInboxMessage{}).Error +func (d *DkfDB) DeleteAllChatInbox(userID UserID) error { + return d.db.Where("to_user_id = ?", userID).Delete(&ChatInboxMessage{}).Error } -func DeleteUserChatInboxMessages(userID UserID) error { - return DB.Where("user_id = ?", userID).Delete(&ChatInboxMessage{}).Error +func (d *DkfDB) DeleteUserChatInboxMessages(userID UserID) error { + return d.db.Where("user_id = ?", userID).Delete(&ChatInboxMessage{}).Error } -func CreateInboxMessage(msg string, roomID RoomID, fromUserID, toUserID UserID, isPm, moderators bool, msgID *int64) { +func (d *DkfDB) CreateInboxMessage(msg string, roomID RoomID, fromUserID, toUserID UserID, isPm, moderators bool, msgID *int64) { inbox := ChatInboxMessage{Message: msg, RoomID: roomID, UserID: fromUserID, ToUserID: toUserID, IsPm: isPm, Moderators: moderators, ChatMessageID: msgID} - if err := DB.Create(&inbox).Error; err != nil { + if err := d.db.Create(&inbox).Error; err != nil { logrus.Error(err) } } -func GetUserInboxMessagesCount(userID UserID) (count int64) { - DB.Table("chat_inbox_messages").Where("to_user_id = ? AND is_read = ?", userID, false).Count(&count) +func (d *DkfDB) GetUserInboxMessagesCount(userID UserID) (count int64) { + d.db.Table("chat_inbox_messages").Where("to_user_id = ? AND is_read = ?", userID, false).Count(&count) return } diff --git a/pkg/database/tableChatMessages.go b/pkg/database/tableChatMessages.go @@ -212,14 +212,14 @@ func (m *ChatMessage) TrimMe() string { return "<p>" + strings.TrimPrefix(m.Message, "<p>/me ") } -func (m *ChatMessage) DoSave() { - if err := DB.Save(m).Error; err != nil { +func (m *ChatMessage) DoSave(db *DkfDB) { + if err := db.db.Save(m).Error; err != nil { logrus.Error(err) } } -func GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage, err error) { - err = DB. +func (d *DkfDB) GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage, err error) { + err = d.db. Where("user_id = ? AND room_id = ?", userID, roomID). Order("id DESC"). Preload("User"). @@ -230,8 +230,8 @@ func GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage return } -func GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error) { - err = DB. +func (d *DkfDB) GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error) { + err = d.db. Where("room_id = ?", roomID). Preload("User"). Preload("ToUser"). @@ -241,8 +241,8 @@ func GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error) { return } -func GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, err error) { - err = DB. +func (d *DkfDB) GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, err error) { + err = d.db. Where("room_id = ? AND uuid = ?", roomID, msgUUID). Preload("User"). Preload("ToUser"). @@ -252,8 +252,8 @@ func GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, e return } -func GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out ChatMessage, err error) { - err = DB. +func (d *DkfDB) GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out ChatMessage, err error) { + err = d.db. Select("*, strftime('%Y-%m-%d %H:%M:%S', created_at) as created_at1"). Where("room_id = ? AND user_id = ? AND created_at1 = ?", roomID, userID, dt.Format("2006-01-02 15:04:05")). Preload("User"). @@ -263,8 +263,8 @@ func GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out C return } -func GetRoomChatMessagesByDate(roomID RoomID, dt time.Time) (out []ChatMessage, err error) { - err = DB. +func (d *DkfDB) GetRoomChatMessagesByDate(roomID RoomID, dt time.Time) (out []ChatMessage, err error) { + err = d.db. Select("*, strftime('%m-%d %H:%M:%S', created_at) as created_at1"). Where("room_id = ? AND created_at1 = ?", roomID, dt.Format("01-02 15:04:05")). Preload("User"). @@ -283,11 +283,11 @@ const ( PmNone ) -func GetChatMessages(roomID RoomID, username string, userID UserID, displayPms PmDisplayMode, mentionsOnly, displayHellbanned, displayIgnored, displayModerators, displayIgnoredMessages bool) (out ChatMessages, err error) { +func (d *DkfDB) GetChatMessages(roomID RoomID, username string, userID UserID, displayPms PmDisplayMode, mentionsOnly, displayHellbanned, displayIgnored, displayModerators, displayIgnoredMessages bool) (out ChatMessages, err error) { cmp := func(t, t2 ChatMessage) bool { return t.ID > t2.ID } - q := DB. + q := d.db. Preload("User"). Preload("ToUser"). Preload("Room"). @@ -343,7 +343,7 @@ func GetChatMessages(roomID RoomID, username string, userID UserID, displayPms P //----------- - qg := DB. + qg := d.db. Preload("User"). Preload("ToUser"). Preload("Room"). @@ -393,22 +393,22 @@ func sortedMerge[T any](a, b []T, less func(T, T) bool) []T { return out } -func DeleteChatRoomMessages(roomID RoomID) error { - return DB.Delete(&ChatMessage{}, "room_id = ?", roomID).Error +func (d *DkfDB) DeleteChatRoomMessages(roomID RoomID) error { + return d.db.Delete(&ChatMessage{}, "room_id = ?", roomID).Error } -func DeleteChatMessageByUUID(messageUUID string) error { - return DB.Where("uuid = ?", messageUUID).Delete(&ChatMessage{}).Error +func (d *DkfDB) DeleteChatMessageByUUID(messageUUID string) error { + return d.db.Where("uuid = ?", messageUUID).Delete(&ChatMessage{}).Error } -func DeleteUserChatMessages(userID UserID) error { - return DB.Where("user_id = ?", userID).Delete(&ChatMessage{}).Error +func (d *DkfDB) DeleteUserChatMessages(userID UserID) error { + return d.db.Where("user_id = ?", userID).Delete(&ChatMessage{}).Error } -func DeleteOldChatMessages() { - rooms, _ := GetOfficialChatRooms() +func (d *DkfDB) DeleteOldChatMessages() { + rooms, _ := d.GetOfficialChatRooms() for _, room := range rooms { - DB.Exec(` + d.db.Exec(` DELETE FROM chat_messages -- Don't delete the last 500 "non PM" and "not hellbanned" messages WHERE id NOT IN ( @@ -447,7 +447,7 @@ func makeMsg(raw, txt string, roomID RoomID, userID UserID) ChatMessage { return msg } -func CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID *UserID) (out ChatMessage, err error) { +func (d *DkfDB) CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID *UserID) (out ChatMessage, err error) { if roomKey != "" { var err error txt, raw, err = encryptMessages(txt, raw, roomKey) @@ -460,11 +460,11 @@ func CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID if toUserID != nil { out.ToUserID = toUserID } - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error { +func (d *DkfDB) CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error { if roomKey != "" { var err error txt, raw, err = encryptMessages(txt, raw, roomKey) @@ -474,30 +474,30 @@ func CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error } msg := makeMsg(raw, txt, roomID, userID) msg.System = true - return DB.Create(&msg).Error + return d.db.Create(&msg).Error } -func CreateKickMsg(kickedUser, kickedByUser User) { +func (d *DkfDB) CreateKickMsg(kickedUser, kickedByUser User) { // Display kick message styledUsername := fmt.Sprintf(`<span %s>%s</span>`, kickedUser.GenerateChatStyle(), kickedUser.Username) rawTxt := fmt.Sprintf("%s has been kicked. (%s)", kickedUser.Username, kickedByUser.Username) txt := fmt.Sprintf("%s has been kicked. (%s)", styledUsername, kickedByUser.Username) - if err := CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil { + if err := d.CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil { logrus.Error(err) } } -func CreateUnkickMsg(kickedUser, kickedByUser User) { +func (d *DkfDB) CreateUnkickMsg(kickedUser, kickedByUser User) { // Display unkick message styledUsername := fmt.Sprintf(`<span %s>%s</span>`, kickedUser.GenerateChatStyle(), kickedUser.Username) rawTxt := fmt.Sprintf("%s has been unkicked. (%s)", kickedUser.Username, kickedByUser.Username) txt := fmt.Sprintf("%s has been unkicked. (%s)", styledUsername, kickedByUser.Username) - if err := CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil { + if err := d.CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil { logrus.Error(err) } } -func CreateOrEditMessage( +func (d *DkfDB) CreateOrEditMessage( editMsg *ChatMessage, message, raw, roomKey string, roomID RoomID, @@ -519,7 +519,7 @@ func CreateOrEditMessage( editMsg.Message = message editMsg.RawMessage = raw // Delete inboxes, we'll create new ones bellow - _ = DeleteChatInboxMessageByChatMessageID(editMsg.ID) + _ = d.DeleteChatInboxMessageByChatMessageID(editMsg.ID) } else { msg := makeMsg(raw, message, roomID, fromUserID) editMsg = &msg @@ -532,7 +532,7 @@ func CreateOrEditMessage( editMsg.UploadID = &upload.ID } } - editMsg.DoSave() + editMsg.DoSave(d) return editMsg.ID, nil } diff --git a/pkg/database/tableChatReactions.go b/pkg/database/tableChatReactions.go @@ -12,15 +12,15 @@ type ChatReaction struct { CreatedAt time.Time } -func CreateChatReaction(userID UserID, messageID, reaction int64) error { +func (d *DkfDB) CreateChatReaction(userID UserID, messageID, reaction int64) error { out := ChatReaction{ UserID: userID, MessageID: messageID, Reaction: reaction, } - return DB.Create(&out).Error + return d.db.Create(&out).Error } -func DeleteReaction(userID UserID, messageID, reaction int64) error { - return DB.Delete(ChatReaction{}, "user_id = ? AND message_id = ? AND reaction = ?", userID, messageID, reaction).Error +func (d *DkfDB) DeleteReaction(userID UserID, messageID, reaction int64) error { + return d.db.Delete(ChatReaction{}, "user_id = ? AND message_id = ? AND reaction = ?", userID, messageID, reaction).Error } diff --git a/pkg/database/tableChatReadMarkers.go b/pkg/database/tableChatReadMarkers.go @@ -13,15 +13,15 @@ type ChatReadMarker struct { ReadAt time.Time } -func GetUserReadMarker(userID UserID, roomID RoomID) (out ChatReadMarker, err error) { - err = DB.First(&out, "user_id = ? AND room_id = ?", userID, roomID).Error +func (d *DkfDB) GetUserReadMarker(userID UserID, roomID RoomID) (out ChatReadMarker, err error) { + err = d.db.First(&out, "user_id = ? AND room_id = ?", userID, roomID).Error return } -func UpdateChatReadMarker(userID UserID, roomID RoomID) { +func (d *DkfDB) UpdateChatReadMarker(userID UserID, roomID RoomID) { now := time.Now() - res := DB.Table("chat_read_markers").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now) + res := d.db.Table("chat_read_markers").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now) if res.RowsAffected == 0 { - DB.Create(ChatReadMarker{UserID: userID, RoomID: roomID, ReadAt: now}) + d.db.Create(ChatReadMarker{UserID: userID, RoomID: roomID, ReadAt: now}) } } diff --git a/pkg/database/tableChatRoomGroups.go b/pkg/database/tableChatRoomGroups.go @@ -16,8 +16,8 @@ type ChatRoomGroup struct { CreatedAt time.Time } -func (g *ChatRoomGroup) DoSave() { - if err := DB.Save(g).Error; err != nil { +func (g *ChatRoomGroup) DoSave(db *DkfDB) { + if err := db.db.Save(g).Error; err != nil { logrus.Error(err) } } @@ -29,60 +29,60 @@ type ChatRoomUserGroup struct { User User } -func GetUserRoomGroups(userID UserID, roomID RoomID) (out []ChatRoomUserGroup, err error) { - err = DB.Find(&out, "user_id = ? AND room_id = ?", userID, roomID).Error +func (d *DkfDB) GetUserRoomGroups(userID UserID, roomID RoomID) (out []ChatRoomUserGroup, err error) { + err = d.db.Find(&out, "user_id = ? AND room_id = ?", userID, roomID).Error return } -func GetRoomGroupByName(roomID RoomID, groupName string) (out ChatRoomGroup, err error) { - err = DB.First(&out, "room_id = ? AND name = ?", roomID, groupName).Error +func (d *DkfDB) GetRoomGroupByName(roomID RoomID, groupName string) (out ChatRoomGroup, err error) { + err = d.db.First(&out, "room_id = ? AND name = ?", roomID, groupName).Error return } -func IsUserInGroupByID(userID UserID, groupID GroupID) bool { +func (d *DkfDB) IsUserInGroupByID(userID UserID, groupID GroupID) bool { var count int64 - DB.Model(ChatRoomUserGroup{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count) + d.db.Model(ChatRoomUserGroup{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count) return count == 1 } -func DeleteChatRoomGroup(roomID RoomID, name string) (err error) { - err = DB.Delete(&ChatRoomGroup{}, "room_id = ? AND name = ?", roomID, name).Error +func (d *DkfDB) DeleteChatRoomGroup(roomID RoomID, name string) (err error) { + err = d.db.Delete(&ChatRoomGroup{}, "room_id = ? AND name = ?", roomID, name).Error return } -func DeleteChatRoomGroups(roomID RoomID) (err error) { - err = DB.Delete(&ChatRoomGroup{}, "room_id = ?", roomID).Error +func (d *DkfDB) DeleteChatRoomGroups(roomID RoomID) (err error) { + err = d.db.Delete(&ChatRoomGroup{}, "room_id = ?", roomID).Error return } -func CreateChatRoomGroup(roomID RoomID, name, color string) (out ChatRoomGroup, err error) { +func (d *DkfDB) CreateChatRoomGroup(roomID RoomID, name, color string) (out ChatRoomGroup, err error) { out = ChatRoomGroup{Name: name, Color: color, RoomID: roomID} - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func AddUserToRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (out ChatRoomUserGroup, err error) { +func (d *DkfDB) AddUserToRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (out ChatRoomUserGroup, err error) { out = ChatRoomUserGroup{GroupID: groupID, RoomID: roomID, UserID: userID} - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func RmUserFromRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (err error) { - err = DB.Delete(&ChatRoomUserGroup{}, "user_id = ? AND group_id = ? AND room_id = ?", userID, groupID, roomID).Error +func (d *DkfDB) RmUserFromRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (err error) { + err = d.db.Delete(&ChatRoomUserGroup{}, "user_id = ? AND group_id = ? AND room_id = ?", userID, groupID, roomID).Error return } -func ClearRoomGroup(roomID RoomID, groupID GroupID) (err error) { - err = DB.Delete(&ChatRoomUserGroup{}, "group_id = ? AND room_id = ?", groupID, roomID).Error +func (d *DkfDB) ClearRoomGroup(roomID RoomID, groupID GroupID) (err error) { + err = d.db.Delete(&ChatRoomUserGroup{}, "group_id = ? AND room_id = ?", groupID, roomID).Error return } -func GetRoomGroups(roomID RoomID) (out []ChatRoomGroup, err error) { - err = DB.Find(&out, "room_id = ?", roomID).Error +func (d *DkfDB) GetRoomGroups(roomID RoomID) (out []ChatRoomGroup, err error) { + err = d.db.Find(&out, "room_id = ?", roomID).Error return } -func GetRoomGroupUsers(roomID RoomID, groupID GroupID) (out []ChatRoomUserGroup, err error) { - err = DB.Where("room_id = ? AND group_id = ?", roomID, groupID).Preload("User").Find(&out).Error +func (d *DkfDB) GetRoomGroupUsers(roomID RoomID, groupID GroupID) (out []ChatRoomUserGroup, err error) { + err = d.db.Where("room_id = ? AND group_id = ?", roomID, groupID).Preload("User").Find(&out).Error return } diff --git a/pkg/database/tableChatRoomWhitelistedUsers.go b/pkg/database/tableChatRoomWhitelistedUsers.go @@ -13,30 +13,30 @@ type ChatRoomWhitelistedUser struct { User User } -func (r *ChatRoomWhitelistedUser) DoSave() { - if err := DB.Save(r).Error; err != nil { +func (r *ChatRoomWhitelistedUser) DoSave(db *DkfDB) { + if err := db.db.Save(r).Error; err != nil { logrus.Error(err) } } -func IsUserWhitelistedInRoom(userID UserID, roomID RoomID) bool { +func (d *DkfDB) IsUserWhitelistedInRoom(userID UserID, roomID RoomID) bool { var count int64 - DB.Table("chat_room_whitelisted_users").Where("user_id = ? and room_id = ?", userID, roomID).Count(&count) + d.db.Table("chat_room_whitelisted_users").Where("user_id = ? and room_id = ?", userID, roomID).Count(&count) return count == 1 } -func GetWhitelistedUsers(roomID RoomID) (out []ChatRoomWhitelistedUser, err error) { - err = DB.Preload("User").Find(&out, "room_id = ?", roomID).Error +func (d *DkfDB) GetWhitelistedUsers(roomID RoomID) (out []ChatRoomWhitelistedUser, err error) { + err = d.db.Preload("User").Find(&out, "room_id = ?", roomID).Error return } -func WhitelistUser(roomID RoomID, userID UserID) (out ChatRoomWhitelistedUser, err error) { +func (d *DkfDB) WhitelistUser(roomID RoomID, userID UserID) (out ChatRoomWhitelistedUser, err error) { out = ChatRoomWhitelistedUser{UserID: userID, RoomID: roomID} - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func DeWhitelistUser(roomID RoomID, userID UserID) (err error) { - err = DB.Delete(ChatRoomWhitelistedUser{}, "user_id = ? and room_id = ?", userID, roomID).Error +func (d *DkfDB) DeWhitelistUser(roomID RoomID, userID UserID) (err error) { + err = d.db.Delete(ChatRoomWhitelistedUser{}, "user_id = ? and room_id = ?", userID, roomID).Error return } diff --git a/pkg/database/tableChatRooms.go b/pkg/database/tableChatRooms.go @@ -31,7 +31,7 @@ const ( UserWhitelistRoomMode = 1 ) -func CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool) (out ChatRoom, err error) { +func (d *DkfDB) CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool) (out ChatRoom, err error) { out = ChatRoom{ Name: name, Password: passwordHash, @@ -39,7 +39,7 @@ func CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool) IsListed: isListed, IsEphemeral: true, } - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } @@ -72,8 +72,8 @@ func (r *ChatRoom) IsProtected() bool { return r.Password != "" } -func (r *ChatRoom) DoSave() { - if err := DB.Save(r).Error; err != nil { +func (r *ChatRoom) DoSave(db *DkfDB) { + if err := db.db.Save(r).Error; err != nil { logrus.Error(err) } } @@ -88,6 +88,7 @@ func (r *ChatRoom) IsOfficialRoom() bool { func (r *ChatRoom) HasAccess(c echo.Context) bool { authUser := c.Get("authUser").(*User) + db := c.Get("database").(*DkfDB) if authUser == nil { return false } @@ -99,7 +100,7 @@ func (r *ChatRoom) HasAccess(c echo.Context) bool { } if r.Mode == UserWhitelistRoomMode { if r.OwnerUserID != nil && *r.OwnerUserID != authUser.ID { - if !IsUserWhitelistedInRoom(authUser.ID, r.ID) { + if !db.IsUserWhitelistedInRoom(authUser.ID, r.ID) { return false } } @@ -118,18 +119,18 @@ func (r *ChatRoom) HasAccess(c echo.Context) bool { return true } -func GetChatRoomByID(roomID RoomID) (out ChatRoom, err error) { - err = DB.Where("id = ?", roomID).First(&out).Error +func (d *DkfDB) GetChatRoomByID(roomID RoomID) (out ChatRoom, err error) { + err = d.db.Where("id = ?", roomID).First(&out).Error return } -func GetChatRoomByName(roomName string) (out ChatRoom, err error) { - err = DB.Where("name = ?", roomName).First(&out).Error +func (d *DkfDB) GetChatRoomByName(roomName string) (out ChatRoom, err error) { + err = d.db.Where("name = ?", roomName).First(&out).Error return } -func DeleteChatRoomByID(id RoomID) { - if err := DB.Delete(ChatRoom{}, "id = ?", id).Error; err != nil { +func (d *DkfDB) DeleteChatRoomByID(id RoomID) { + if err := d.db.Delete(ChatRoom{}, "id = ?", id).Error; err != nil { logrus.Error(err) } } @@ -141,8 +142,8 @@ type ChatRoomAug struct { } // GetOfficialChatRooms1 returns official chat rooms with additional information such as "IsUnread" -func GetOfficialChatRooms1(userID UserID) (out []ChatRoomAug, err error) { - err = DB.Raw(`SELECT r.*, +func (d *DkfDB) GetOfficialChatRooms1(userID UserID) (out []ChatRoomAug, err error) { + err = d.db.Raw(`SELECT r.*, COALESCE((rr.read_at < m.created_at), 1) as is_unread FROM chat_rooms r -- Find last message for room @@ -154,8 +155,8 @@ ORDER BY r.id ASC`, userID, userID).Scan(&out).Error return } -func GetUserRoomSubscriptions(userID UserID) (out []ChatRoomAug, err error) { - err = DB.Raw(`SELECT r.*, +func (d *DkfDB) GetUserRoomSubscriptions(userID UserID) (out []ChatRoomAug, err error) { + err = d.db.Raw(`SELECT r.*, COALESCE((rr.read_at < m.created_at), 1) as is_unread FROM user_room_subscriptions s INNER JOIN chat_rooms r ON r.id = s.room_id @@ -168,8 +169,8 @@ ORDER BY r.id ASC`, userID, userID, userID).Scan(&out).Error return } -func GetListedChatRooms(userID UserID) (out []ChatRoomAug, err error) { - err = DB.Raw(`SELECT r.*, +func (d *DkfDB) GetListedChatRooms(userID UserID) (out []ChatRoomAug, err error) { + err = d.db.Raw(`SELECT r.*, u.*, COALESCE((rr.read_at < m.created_at), 1) as is_unread FROM chat_rooms r @@ -184,13 +185,13 @@ ORDER BY r.id ASC`, userID, userID).Scan(&out).Error return } -func GetOfficialChatRooms() (out []ChatRoom, err error) { - err = DB.Where("id IN (1, 2, 3, 4, 14)").Preload("ReadRecord").Find(&out).Error +func (d *DkfDB) GetOfficialChatRooms() (out []ChatRoom, err error) { + err = d.db.Where("id IN (1, 2, 3, 4, 14)").Preload("ReadRecord").Find(&out).Error return } -func DeleteOldPrivateChatRooms() { - DB.Exec(`DELETE FROM chat_rooms +func (d *DkfDB) DeleteOldPrivateChatRooms() { + d.db.Exec(`DELETE FROM chat_rooms WHERE owner_user_id IS NOT NULL AND is_ephemeral = 1 AND ((SELECT chat_messages.created_at FROM chat_messages WHERE chat_messages.room_id = chat_rooms.id ORDER BY chat_messages.ID DESC) < date('now', '-1 Day') @@ -206,16 +207,16 @@ type ChatReadRecord struct { ReadAt time.Time } -func (r *ChatReadRecord) DoSave() { - if err := DB.Save(r).Error; err != nil { +func (r *ChatReadRecord) DoSave(db *DkfDB) { + if err := db.db.Save(r).Error; err != nil { logrus.Error(err) } } -func UpdateChatReadRecord(userID UserID, roomID RoomID) { +func (d *DkfDB) UpdateChatReadRecord(userID UserID, roomID RoomID) { now := time.Now() - res := DB.Table("chat_read_records").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now) + res := d.db.Table("chat_read_records").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now) if res.RowsAffected == 0 { - DB.Create(ChatReadRecord{UserID: userID, RoomID: roomID, ReadAt: now}) + d.db.Create(ChatReadRecord{UserID: userID, RoomID: roomID, ReadAt: now}) } } diff --git a/pkg/database/tableDownloads.go b/pkg/database/tableDownloads.go @@ -14,19 +14,19 @@ type Download struct { } // CreateDownload ... -func CreateDownload(userID UserID, filename string) (out Download, err error) { +func (d *DkfDB) CreateDownload(userID UserID, filename string) (out Download, err error) { out = Download{UserID: userID, Filename: filename} - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } // UserNbDownloaded returns how many times a user downloaded a file -func UserNbDownloaded(userID UserID, filename string) (out int64) { - DB.Table("downloads").Where("user_id = ? AND filename = ?", userID, filename).Count(&out) +func (d *DkfDB) UserNbDownloaded(userID UserID, filename string) (out int64) { + d.db.Table("downloads").Where("user_id = ? AND filename = ?", userID, filename).Count(&out) return } -func DeleteDownloadByID(downloadID int64) (err error) { - err = DB.Unscoped().Delete(Download{}, "id = ?", downloadID).Error +func (d *DkfDB) DeleteDownloadByID(downloadID int64) (err error) { + err = d.db.Unscoped().Delete(Download{}, "id = ?", downloadID).Error return } diff --git a/pkg/database/tableFiledrops.go b/pkg/database/tableFiledrops.go @@ -23,25 +23,25 @@ type Filedrop struct { UpdatedAt *time.Time } -func GetFiledropByUUID(uuid string) (out Filedrop, err error) { - err = DB.First(&out, "uuid = ?", uuid).Error +func (d *DkfDB) GetFiledropByUUID(uuid string) (out Filedrop, err error) { + err = d.db.First(&out, "uuid = ?", uuid).Error return } -func GetFiledropByFileName(fileName string) (out Filedrop, err error) { - err = DB.First(&out, "file_name = ?", fileName).Error +func (d *DkfDB) GetFiledropByFileName(fileName string) (out Filedrop, err error) { + err = d.db.First(&out, "file_name = ?", fileName).Error return } -func GetFiledrops() (out []Filedrop, err error) { - err = DB.Find(&out).Error +func (d *DkfDB) GetFiledrops() (out []Filedrop, err error) { + err = d.db.Find(&out).Error return } -func CreateFiledrop() (out Filedrop, err error) { +func (d *DkfDB) CreateFiledrop() (out Filedrop, err error) { out.UUID = uuid.New().String() out.FileName = utils.MD5([]byte(utils.GenerateToken32())) - err = DB.Save(&out).Error + err = d.db.Save(&out).Error return } @@ -65,20 +65,20 @@ func (d *Filedrop) GetContent() (*os.File, *ucrypto.StreamDecrypter, error) { return f, decrypter, nil } -func (d *Filedrop) Delete() error { +func (d *Filedrop) Delete(db *DkfDB) error { if d.FileName != "" { if err := os.Remove(filepath.Join(config.Global.ProjectFiledropPath(), d.FileName)); err != nil { logrus.Error(err) } } - if err := DB.Delete(&d).Error; err != nil { + if err := db.db.Delete(&d).Error; err != nil { return err } return nil } -func (d *Filedrop) DoSave() { - if err := DB.Save(d).Error; err != nil { +func (d *Filedrop) DoSave(db *DkfDB) { + if err := db.db.Save(d).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableGists.go b/pkg/database/tableGists.go @@ -21,8 +21,8 @@ type Gist struct { CreatedAt time.Time } -func GetGistByUUID(uuid string) (out Gist, err error) { - err = DB.First(&out, "uuid = ?", uuid).Error +func (d *DkfDB) GetGistByUUID(uuid string) (out Gist, err error) { + err = d.db.First(&out, "uuid = ?", uuid).Error return } @@ -50,8 +50,8 @@ func (g *Gist) HasAccess(c echo.Context) bool { } // DoSave user in the database, ignore error -func (g *Gist) DoSave() { - if err := DB.Save(g).Error; err != nil { +func (g *Gist) DoSave(db *DkfDB) { + if err := db.db.Save(g).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableIgnoredMessages.go b/pkg/database/tableIgnoredMessages.go @@ -9,15 +9,15 @@ type IgnoredMessage struct { MessageID int64 } -func IgnoreMessage(userID UserID, messageID int64) { +func (d *DkfDB) IgnoreMessage(userID UserID, messageID int64) { ignore := IgnoredMessage{UserID: userID, MessageID: messageID} - if err := DB.Create(&ignore).Error; err != nil { + if err := d.db.Create(&ignore).Error; err != nil { logrus.Error(err) } } -func UnIgnoreMessage(userID UserID, messageID int64) { - if err := DB.Delete(&IgnoredMessage{}, "user_id = ? AND message_id = ?", userID, messageID).Error; err != nil { +func (d *DkfDB) UnIgnoreMessage(userID UserID, messageID int64) { + if err := d.db.Delete(&IgnoredMessage{}, "user_id = ? AND message_id = ?", userID, messageID).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableIgnoredUsers.go b/pkg/database/tableIgnoredUsers.go @@ -14,26 +14,26 @@ type IgnoredUser struct { IgnoredUser User } -func GetIgnoredUsers(userID UserID) (out []IgnoredUser, err error) { - err = DB.Where("user_id = ?", userID).Preload("IgnoredUser").Find(&out).Error +func (d *DkfDB) GetIgnoredUsers(userID UserID) (out []IgnoredUser, err error) { + err = d.db.Where("user_id = ?", userID).Preload("IgnoredUser").Find(&out).Error return } // GetIgnoredByUsers get a list of people who ignore userID -func GetIgnoredByUsers(userID UserID) (out []IgnoredUser, err error) { - err = DB.Where("ignored_user_id = ?", userID).Find(&out).Error +func (d *DkfDB) GetIgnoredByUsers(userID UserID) (out []IgnoredUser, err error) { + err = d.db.Where("ignored_user_id = ?", userID).Find(&out).Error return } -func IgnoreUser(userID, ignoredUserID UserID) { +func (d *DkfDB) IgnoreUser(userID, ignoredUserID UserID) { ignore := IgnoredUser{UserID: userID, IgnoredUserID: ignoredUserID} - if err := DB.Create(&ignore).Error; err != nil { + if err := d.db.Create(&ignore).Error; err != nil { logrus.Error(err) } } -func UnIgnoreUser(userID, ignoredUserID UserID) { - if err := DB.Delete(IgnoredUser{}, "user_id = ? AND ignored_user_id = ?", userID, ignoredUserID).Error; err != nil { +func (d *DkfDB) UnIgnoreUser(userID, ignoredUserID UserID) { + if err := d.db.Delete(IgnoredUser{}, "user_id = ? AND ignored_user_id = ?", userID, ignoredUserID).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableInvitations.go b/pkg/database/tableInvitations.go @@ -16,38 +16,38 @@ type Invitation struct { } // Save user in the database -func (i *Invitation) Save() error { - return DB.Save(i).Error +func (i *Invitation) Save(db *DkfDB) error { + return db.db.Save(i).Error } // DoSave user in the database, ignore error -func (i *Invitation) DoSave() { - if err := DB.Save(i).Error; err != nil { +func (i *Invitation) DoSave(db *DkfDB) { + if err := db.db.Save(i).Error; err != nil { logrus.Error(err) } } -func CreateInvitation(userID UserID) (out Invitation, err error) { +func (d *DkfDB) CreateInvitation(userID UserID) (out Invitation, err error) { out = Invitation{ Token: utils.GenerateToken32(), OwnerUserID: userID, InviteeUserID: 1, } - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func GetUnusedInvitationByToken(token string) (out Invitation, err error) { - err = DB.First(&out, "token = ? AND invitee_user_id == 1", token).Error +func (d *DkfDB) GetUnusedInvitationByToken(token string) (out Invitation, err error) { + err = d.db.First(&out, "token = ? AND invitee_user_id == 1", token).Error return } -func GetUserInvitations(userID UserID) (out []Invitation, err error) { - err = DB.Find(&out, "owner_user_id = ?", userID).Error +func (d *DkfDB) GetUserInvitations(userID UserID) (out []Invitation, err error) { + err = d.db.Find(&out, "owner_user_id = ?", userID).Error return } -func GetUserUnusedInvitations(userID UserID) (out []Invitation, err error) { - err = DB.Find(&out, "owner_user_id = ? AND invitee_user_id == 1", userID).Error +func (d *DkfDB) GetUserUnusedInvitations(userID UserID) (out []Invitation, err error) { + err = d.db.Find(&out, "owner_user_id = ? AND invitee_user_id == 1", userID).Error return } diff --git a/pkg/database/tableKarmaHistory.go b/pkg/database/tableKarmaHistory.go @@ -11,14 +11,14 @@ type KarmaHistory struct { CreatedAt time.Time } -func CreateKarmaHistory(karma int64, description string, userID UserID, fromUserID *int64) (out KarmaHistory, err error) { +func (d *DkfDB) CreateKarmaHistory(karma int64, description string, userID UserID, fromUserID *int64) (out KarmaHistory, err error) { out = KarmaHistory{ Karma: karma, Description: description, UserID: userID, FromUserID: fromUserID, } - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } diff --git a/pkg/database/tableLinks.go b/pkg/database/tableLinks.go @@ -49,51 +49,51 @@ func (l Link) DescriptionSafe() string { return html.EscapeString(l.Description) } -func (l *Link) Save() error { - return DB.Save(l).Error +func (l *Link) Save(db *DkfDB) error { + return db.db.Save(l).Error } -func (l *Link) DoSave() { - if err := DB.Save(l).Error; err != nil { +func (l *Link) DoSave(db *DkfDB) { + if err := l.Save(db); err != nil { logrus.Error(err) } } -func CreateLink(url, title, description, shorthand string) (out Link, err error) { +func (d *DkfDB) CreateLink(url, title, description, shorthand string) (out Link, err error) { out = Link{UUID: uuid.New().String(), URL: url, Title: title, Description: description} if shorthand != "" { out.Shorthand = &shorthand } - err = DB.FirstOrCreate(&out, "url = ?", url).Error + err = d.db.FirstOrCreate(&out, "url = ?", url).Error return } -func DeleteLinkByID(id int64) error { - return DB.Where("id = ?", id).Delete(&Link{}).Error +func (d *DkfDB) DeleteLinkByID(id int64) error { + return d.db.Where("id = ?", id).Delete(&Link{}).Error } -func GetLinks() (out []Link, err error) { - err = DB.Find(&out).Error +func (d *DkfDB) GetLinks() (out []Link, err error) { + err = d.db.Find(&out).Error return } -func GetRecentLinks() (out []Link, err error) { - err = DB.Order("id DESC").Limit(100).Find(&out).Error +func (d *DkfDB) GetRecentLinks() (out []Link, err error) { + err = d.db.Order("id DESC").Limit(100).Find(&out).Error return } -func GetLinkByShorthand(shorthand string) (out Link, err error) { - err = DB.Preload("OwnerUser").First(&out, "shorthand = ?", shorthand).Error +func (d *DkfDB) GetLinkByShorthand(shorthand string) (out Link, err error) { + err = d.db.Preload("OwnerUser").First(&out, "shorthand = ?", shorthand).Error return } -func GetLinkByUUID(linkUUID string) (out Link, err error) { - err = DB.Preload("OwnerUser").First(&out, "uuid = ?", linkUUID).Error +func (d *DkfDB) GetLinkByUUID(linkUUID string) (out Link, err error) { + err = d.db.Preload("OwnerUser").First(&out, "uuid = ?", linkUUID).Error return } -func GetLinkByID(linkID int64) (out Link, err error) { - err = DB.First(&out, "id = ?", linkID).Error +func (d *DkfDB) GetLinkByID(linkID int64) (out Link, err error) { + err = d.db.First(&out, "id = ?", linkID).Error return } @@ -102,9 +102,9 @@ type LinksCategory struct { Name string } -func CreateLinksCategory(category string) (out LinksCategory, err error) { +func (d *DkfDB) CreateLinksCategory(category string) (out LinksCategory, err error) { out = LinksCategory{Name: category} - err = DB.FirstOrCreate(&out, "name = ?", category).Error + err = d.db.FirstOrCreate(&out, "name = ?", category).Error return } @@ -113,9 +113,9 @@ type LinksTag struct { Name string } -func CreateLinksTag(tag string) (out LinksTag, err error) { +func (d *DkfDB) CreateLinksTag(tag string) (out LinksTag, err error) { out = LinksTag{Name: tag} - err = DB.FirstOrCreate(&out, "name = ?", tag).Error + err = d.db.FirstOrCreate(&out, "name = ?", tag).Error return } @@ -124,8 +124,8 @@ type LinksTagsLink struct { TagID int64 } -func AddLinkTag(linkID, tagID int64) (err error) { - return DB.Create(&LinksTagsLink{LinkID: linkID, TagID: tagID}).Error +func (d *DkfDB) AddLinkTag(linkID, tagID int64) (err error) { + return d.db.Create(&LinksTagsLink{LinkID: linkID, TagID: tagID}).Error } type LinksCategoriesLink struct { @@ -133,8 +133,8 @@ type LinksCategoriesLink struct { LinkID int64 } -func AddLinkCategory(linkID, categoryID int64) (err error) { - return DB.Create(&LinksCategoriesLink{CategoryID: categoryID, LinkID: linkID}).Error +func (d *DkfDB) AddLinkCategory(linkID, categoryID int64) (err error) { + return d.db.Create(&LinksCategoriesLink{CategoryID: categoryID, LinkID: linkID}).Error } type CategoriesResult struct { @@ -142,8 +142,8 @@ type CategoriesResult struct { Count int64 } -func GetCategories() (out []CategoriesResult, err error) { - err = DB.Raw(`SELECT +func (d *DkfDB) GetCategories() (out []CategoriesResult, err error) { + err = d.db.Raw(`SELECT c.name, count(cl.link_id) as count FROM links_categories_links cl INNER JOIN links_categories c ON c.id = cl.category_id @@ -153,8 +153,8 @@ ORDER BY c.name`).Scan(&out).Error return } -func GetLinkCategories(linkID int64) (out []LinksCategory, err error) { - err = DB.Raw(`SELECT +func (d *DkfDB) GetLinkCategories(linkID int64) (out []LinksCategory, err error) { + err = d.db.Raw(`SELECT c.id, c.name FROM links_categories_links cl INNER JOIN links_categories c ON c.id = cl.category_id @@ -163,8 +163,8 @@ ORDER BY c.name`, linkID).Scan(&out).Error return } -func GetLinkTags(linkID int64) (out []LinksTag, err error) { - err = DB.Raw(`SELECT +func (d *DkfDB) GetLinkTags(linkID int64) (out []LinksTag, err error) { + err = d.db.Raw(`SELECT t.id, t.name FROM links_tags_links tl INNER JOIN links_tags t ON t.id = tl.tag_id @@ -173,12 +173,12 @@ ORDER BY t.name`, linkID).Scan(&out).Error return } -func DeleteLinkCategories(linkID int64) error { - return DB.Delete(&LinksCategoriesLink{}, "link_id = ?", linkID).Error +func (d *DkfDB) DeleteLinkCategories(linkID int64) error { + return d.db.Delete(&LinksCategoriesLink{}, "link_id = ?", linkID).Error } -func DeleteLinkTags(linkID int64) error { - return DB.Delete(&LinksTagsLink{}, "link_id = ?", linkID).Error +func (d *DkfDB) DeleteLinkTags(linkID int64) error { + return d.db.Delete(&LinksTagsLink{}, "link_id = ?", linkID).Error } type LinksMirror struct { @@ -212,50 +212,50 @@ func (l LinksPgp) GetKeyFingerprint() string { return out } -func CreateLinkPgp(linkID int64, title, description, publicKey string) (out LinksPgp, err error) { +func (d *DkfDB) CreateLinkPgp(linkID int64, title, description, publicKey string) (out LinksPgp, err error) { out = LinksPgp{ LinkID: linkID, Title: title, Description: description, PgpPublicKey: publicKey, } - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func CreateLinkMirror(linkID int64, link string) (out LinksMirror, err error) { +func (d *DkfDB) CreateLinkMirror(linkID int64, link string) (out LinksMirror, err error) { out = LinksMirror{ LinkID: linkID, MirrorURL: link, } - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func GetLinkPgps(linkID int64) (out []LinksPgp, err error) { - err = DB.Find(&out, "link_id = ?", linkID).Error +func (d *DkfDB) GetLinkPgps(linkID int64) (out []LinksPgp, err error) { + err = d.db.Find(&out, "link_id = ?", linkID).Error return } -func GetLinkMirrors(linkID int64) (out []LinksMirror, err error) { - err = DB.Find(&out, "link_id = ?", linkID).Error +func (d *DkfDB) GetLinkMirrors(linkID int64) (out []LinksMirror, err error) { + err = d.db.Find(&out, "link_id = ?", linkID).Error return } -func GetLinkPgpByID(id int64) (out LinksPgp, err error) { - err = DB.First(&out, "id = ?", id).Error +func (d *DkfDB) GetLinkPgpByID(id int64) (out LinksPgp, err error) { + err = d.db.First(&out, "id = ?", id).Error return } -func GetLinkMirrorByID(id int64) (out LinksMirror, err error) { - err = DB.First(&out, "id = ?", id).Error +func (d *DkfDB) GetLinkMirrorByID(id int64) (out LinksMirror, err error) { + err = d.db.First(&out, "id = ?", id).Error return } -func DeleteLinkPgpByID(id int64) error { - return DB.Where("id = ?", id).Delete(&LinksPgp{}).Error +func (d *DkfDB) DeleteLinkPgpByID(id int64) error { + return d.db.Where("id = ?", id).Delete(&LinksPgp{}).Error } -func DeleteLinkMirrorByID(id int64) error { - return DB.Where("id = ?", id).Delete(&LinksMirror{}).Error +func (d *DkfDB) DeleteLinkMirrorByID(id int64) error { + return d.db.Where("id = ?", id).Delete(&LinksMirror{}).Error } diff --git a/pkg/database/tableNotifications.go b/pkg/database/tableNotifications.go @@ -26,8 +26,8 @@ type SessionNotification struct { User User } -func GetUserNotifications(userID UserID) (msgs []Notification, err error) { - err = DB.Order("id DESC"). +func (d *DkfDB) GetUserNotifications(userID UserID) (msgs []Notification, err error) { + err = d.db.Order("id DESC"). Limit(50). Preload("User"). Find(&msgs, "user_id = ?", userID).Error @@ -36,15 +36,15 @@ func GetUserNotifications(userID UserID) (msgs []Notification, err error) { ids = append(ids, msg.ID) } now := time.Now() - if err := DB.Model(&Notification{}).Where("id IN (?)", ids). + if err := d.db.Model(&Notification{}).Where("id IN (?)", ids). UpdateColumn("is_read", true, "read_at", &now).Error; err != nil { logrus.Error(err) } return } -func GetUserSessionNotifications(sessionToken string) (msgs []SessionNotification, err error) { - err = DB.Order("id DESC"). +func (d *DkfDB) GetUserSessionNotifications(sessionToken string) (msgs []SessionNotification, err error) { + err = d.db.Order("id DESC"). Limit(50). Joins("INNER JOIN sessions s ON s.token = session_token"). Joins("INNER JOIN users u ON u.id = s.user_id"). @@ -54,49 +54,49 @@ func GetUserSessionNotifications(sessionToken string) (msgs []SessionNotificatio ids = append(ids, msg.ID) } now := time.Now() - if err := DB.Table("session_notifications").Where("id IN (?)", ids). + if err := d.db.Table("session_notifications").Where("id IN (?)", ids). UpdateColumn("is_read", true, "read_at", &now).Error; err != nil { logrus.Error(err) } return } -func DeleteNotificationByID(notificationID int64) error { - return DB.Where("id = ?", notificationID).Delete(&Notification{}).Error +func (d *DkfDB) DeleteNotificationByID(notificationID int64) error { + return d.db.Where("id = ?", notificationID).Delete(&Notification{}).Error } -func DeleteSessionNotificationByID(sessionNotificationID int64) error { - return DB.Where("id = ?", sessionNotificationID).Delete(&SessionNotification{}).Error +func (d *DkfDB) DeleteSessionNotificationByID(sessionNotificationID int64) error { + return d.db.Where("id = ?", sessionNotificationID).Delete(&SessionNotification{}).Error } -func DeleteAllNotifications(userID UserID) error { - return DB.Where("user_id = ?", userID).Delete(&Notification{}).Error +func (d *DkfDB) DeleteAllNotifications(userID UserID) error { + return d.db.Where("user_id = ?", userID).Delete(&Notification{}).Error } -func CreateNotification(msg string, userID UserID) { +func (d *DkfDB) CreateNotification(msg string, userID UserID) { inbox := Notification{Message: msg, UserID: userID, IsRead: false} - if err := DB.Create(&inbox).Error; err != nil { + if err := d.db.Create(&inbox).Error; err != nil { logrus.Error(err) } } -func GetUserNotificationsCount(userID UserID) (count int64) { - DB.Table("notifications").Where("user_id = ? AND is_read = ?", userID, false).Count(&count) +func (d *DkfDB) GetUserNotificationsCount(userID UserID) (count int64) { + d.db.Table("notifications").Where("user_id = ? AND is_read = ?", userID, false).Count(&count) return } -func GetUserSessionNotificationsCount(sessionToken string) (count int64) { - DB.Table("session_notifications").Where("session_token = ? AND is_read = ?", sessionToken, false).Count(&count) +func (d *DkfDB) GetUserSessionNotificationsCount(sessionToken string) (count int64) { + d.db.Table("session_notifications").Where("session_token = ? AND is_read = ?", sessionToken, false).Count(&count) return } -func CreateSessionNotification(msg string, sessionToken string) { +func (d *DkfDB) CreateSessionNotification(msg string, sessionToken string) { inbox := SessionNotification{Message: msg, SessionToken: sessionToken, IsRead: false} - if err := DB.Create(&inbox).Error; err != nil { + if err := d.db.Create(&inbox).Error; err != nil { logrus.Error(err) } } -func DeleteAllSessionNotifications(sessionToken string) error { - return DB.Where("session_token = ?", sessionToken).Delete(&SessionNotification{}).Error +func (d *DkfDB) DeleteAllSessionNotifications(sessionToken string) error { + return d.db.Where("session_token = ?", sessionToken).Delete(&SessionNotification{}).Error } diff --git a/pkg/database/tableOnionBlacklist.go b/pkg/database/tableOnionBlacklist.go @@ -7,7 +7,7 @@ type OnionBlacklist struct { CreatedAt time.Time } -func GetOnionBlacklist(hash string) (out OnionBlacklist, err error) { - err = DB.First(&out, "md5 = ?", hash).Error +func (d *DkfDB) GetOnionBlacklist(hash string) (out OnionBlacklist, err error) { + err = d.db.First(&out, "md5 = ?", hash).Error return } diff --git a/pkg/database/tablePmBlacklistedUsers.go b/pkg/database/tablePmBlacklistedUsers.go @@ -11,43 +11,43 @@ type PmBlacklistedUsers struct { } // IsUserPmBlacklisted returns either or not toUserID blacklisted fromUserID -func IsUserPmBlacklisted(fromUserID, toUserID UserID) bool { +func (d *DkfDB) IsUserPmBlacklisted(fromUserID, toUserID UserID) bool { var count int64 - DB.Model(&PmBlacklistedUsers{}).Where("blacklisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count) + d.db.Model(&PmBlacklistedUsers{}).Where("blacklisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count) return count == 1 } // GetPmBlacklistedUsers returns a list of userID blacklisted users -func GetPmBlacklistedUsers(userID UserID) (out []PmBlacklistedUsers, err error) { - err = DB.Where("user_id = ?", userID).Preload("BlacklistedUser").Find(&out).Error +func (d *DkfDB) GetPmBlacklistedUsers(userID UserID) (out []PmBlacklistedUsers, err error) { + err = d.db.Where("user_id = ?", userID).Preload("BlacklistedUser").Find(&out).Error return } // GetPmBlacklistedByUsers returns a list of users that are blacklisting userID -func GetPmBlacklistedByUsers(userID UserID) (out []PmBlacklistedUsers, err error) { - err = DB.Where("blacklisted_user_id = ?", userID).Find(&out).Error +func (d *DkfDB) GetPmBlacklistedByUsers(userID UserID) (out []PmBlacklistedUsers, err error) { + err = d.db.Where("blacklisted_user_id = ?", userID).Find(&out).Error return } // ToggleBlacklistedUser returns true if the user was added to the blacklist -func ToggleBlacklistedUser(userID, blacklistedUserID UserID) bool { - if IsUserPmBlacklisted(blacklistedUserID, userID) { - RmBlacklistedUser(userID, blacklistedUserID) +func (d *DkfDB) ToggleBlacklistedUser(userID, blacklistedUserID UserID) bool { + if d.IsUserPmBlacklisted(blacklistedUserID, userID) { + d.RmBlacklistedUser(userID, blacklistedUserID) return false } - AddBlacklistedUser(userID, blacklistedUserID) + d.AddBlacklistedUser(userID, blacklistedUserID) return true } -func AddBlacklistedUser(userID, blacklistedUserID UserID) { +func (d *DkfDB) AddBlacklistedUser(userID, blacklistedUserID UserID) { ignore := PmBlacklistedUsers{UserID: userID, BlacklistedUserID: blacklistedUserID} - if err := DB.Create(&ignore).Error; err != nil { + if err := d.db.Create(&ignore).Error; err != nil { logrus.Error(err) } } -func RmBlacklistedUser(userID, blacklistedUserID UserID) { - if err := DB.Delete(PmBlacklistedUsers{}, "user_id = ? AND blacklisted_user_id = ?", userID, blacklistedUserID).Error; err != nil { +func (d *DkfDB) RmBlacklistedUser(userID, blacklistedUserID UserID) { + if err := d.db.Delete(PmBlacklistedUsers{}, "user_id = ? AND blacklisted_user_id = ?", userID, blacklistedUserID).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tablePmWhitelistedUsers.go b/pkg/database/tablePmWhitelistedUsers.go @@ -10,36 +10,36 @@ type PmWhitelistedUsers struct { WhitelistedUser User } -func IsUserPmWhitelisted(fromUserID, toUserID UserID) bool { +func (d *DkfDB) IsUserPmWhitelisted(fromUserID, toUserID UserID) bool { var count int64 - DB.Model(&PmWhitelistedUsers{}).Where("whitelisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count) + d.db.Model(&PmWhitelistedUsers{}).Where("whitelisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count) return count == 1 } -func GetPmWhitelistedUsers(userID UserID) (out []PmWhitelistedUsers, err error) { - err = DB.Where("user_id = ?", userID).Preload("WhitelistedUser").Find(&out).Error +func (d *DkfDB) GetPmWhitelistedUsers(userID UserID) (out []PmWhitelistedUsers, err error) { + err = d.db.Where("user_id = ?", userID).Preload("WhitelistedUser").Find(&out).Error return } // ToggleWhitelistedUser returns true if the user was added to the whitelist -func ToggleWhitelistedUser(userID, whitelistedUserID UserID) bool { - if IsUserPmWhitelisted(whitelistedUserID, userID) { - RmWhitelistedUser(userID, whitelistedUserID) +func (d *DkfDB) ToggleWhitelistedUser(userID, whitelistedUserID UserID) bool { + if d.IsUserPmWhitelisted(whitelistedUserID, userID) { + d.RmWhitelistedUser(userID, whitelistedUserID) return false } - AddWhitelistedUser(userID, whitelistedUserID) + d.AddWhitelistedUser(userID, whitelistedUserID) return true } -func AddWhitelistedUser(userID, whitelistedUserID UserID) { +func (d *DkfDB) AddWhitelistedUser(userID, whitelistedUserID UserID) { ignore := PmWhitelistedUsers{UserID: userID, WhitelistedUserID: whitelistedUserID} - if err := DB.Create(&ignore).Error; err != nil { + if err := d.db.Create(&ignore).Error; err != nil { logrus.Error(err) } } -func RmWhitelistedUser(userID, whitelistedUserID UserID) { - if err := DB.Delete(PmWhitelistedUsers{}, "user_id = ? AND whitelisted_user_id = ?", userID, whitelistedUserID).Error; err != nil { +func (d *DkfDB) RmWhitelistedUser(userID, whitelistedUserID UserID) { + if err := d.db.Delete(PmWhitelistedUsers{}, "user_id = ? AND whitelisted_user_id = ?", userID, whitelistedUserID).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableProhibitedPasswords.go b/pkg/database/tableProhibitedPasswords.go @@ -4,8 +4,8 @@ type ProhibitedPassword struct { Password string } -func IsPasswordProhibited(password string) bool { +func (d *DkfDB) IsPasswordProhibited(password string) bool { var count int - DB.Table("prohibited_passwords").Where("password = ?", password).Count(&count) + d.db.Table("prohibited_passwords").Where("password = ?", password).Count(&count) return count > 0 } diff --git a/pkg/database/tableSecurityLogs.go b/pkg/database/tableSecurityLogs.go @@ -56,24 +56,24 @@ func getMessageForType(typ int64) string { return "" } -func CreateSecurityLog(userID UserID, typ int64) { +func (d *DkfDB) CreateSecurityLog(userID UserID, typ int64) { log := SecurityLog{ Message: getMessageForType(typ), UserID: userID, Typ: typ, } - if err := DB.Create(&log).Error; err != nil { + if err := d.db.Create(&log).Error; err != nil { logrus.Error(err) } } -func GetSecurityLogs(userID UserID) (out []SecurityLog, err error) { - err = DB.Order("id DESC").Find(&out, "user_id = ?", userID).Error +func (d *DkfDB) GetSecurityLogs(userID UserID) (out []SecurityLog, err error) { + err = d.db.Order("id DESC").Find(&out, "user_id = ?", userID).Error return } -func DeleteOldSecurityLogs() { - if err := DB.Delete(SecurityLog{}, "created_at < date('now', '-7 Day')").Error; err != nil { +func (d *DkfDB) DeleteOldSecurityLogs() { + if err := d.db.Delete(SecurityLog{}, "created_at < date('now', '-7 Day')").Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableSessions.go b/pkg/database/tableSessions.go @@ -20,15 +20,15 @@ type Session struct { } // GetActiveUserSessions gets all user sessions -func GetActiveUserSessions(userID UserID) (out []Session) { - DB.Order("created_at DESC").Find(&out, "user_id = ? AND expires_at > DATETIME('now') AND deleted_at IS NULL", userID) +func (d *DkfDB) GetActiveUserSessions(userID UserID) (out []Session) { + d.db.Order("created_at DESC").Find(&out, "user_id = ? AND expires_at > DATETIME('now') AND deleted_at IS NULL", userID) return } // CreateSession creates a session for a user -func CreateSession(userID UserID, userAgent string) (Session, error) { +func (d *DkfDB) CreateSession(userID UserID, userAgent string) (Session, error) { // Delete all sessions except the last 4 - if err := DB.Exec(`DELETE FROM sessions WHERE user_id = ? AND token NOT IN (SELECT s2.token FROM sessions s2 WHERE s2.user_id = ? ORDER BY s2.created_at DESC LIMIT 4)`, userID, userID).Error; err != nil { + if err := d.db.Exec(`DELETE FROM sessions WHERE user_id = ? AND token NOT IN (SELECT s2.token FROM sessions s2 WHERE s2.user_id = ? ORDER BY s2.created_at DESC LIMIT 4)`, userID, userID).Error; err != nil { logrus.Error(err) } session := Session{ @@ -38,13 +38,13 @@ func CreateSession(userID UserID, userAgent string) (Session, error) { UserAgent: userAgent, ExpiresAt: time.Now().Add(time.Duration(utils.OneMonthSecs) * time.Second), } - err := DB.Create(&session).Error + err := d.db.Create(&session).Error return session, err } // DoCreateSession same as CreateSession but log the error instead of returning it -func DoCreateSession(userID UserID, userAgent string) Session { - session, err := CreateSession(userID, userAgent) +func (d *DkfDB) DoCreateSession(userID UserID, userAgent string) Session { + session, err := d.CreateSession(userID, userAgent) if err != nil { logrus.Error("Failed to create session : ", err) } @@ -52,25 +52,25 @@ func DoCreateSession(userID UserID, userAgent string) Session { } // DeleteUserSessions all sessions of the user. -func DeleteUserSessions(userID UserID) error { - return DB.Unscoped().Where("user_id = ?", userID).Delete(&Session{}).Error +func (d *DkfDB) DeleteUserSessions(userID UserID) error { + return d.db.Unscoped().Where("user_id = ?", userID).Delete(&Session{}).Error } // DeleteSessionByToken a session by its token -func DeleteSessionByToken(token string) error { - return DB.Unscoped().Where("token = ?", token).Delete(&Session{}).Error +func (d *DkfDB) DeleteSessionByToken(token string) error { + return d.db.Unscoped().Where("token = ?", token).Delete(&Session{}).Error } -func DeleteUserSessionByToken(userID UserID, token string) error { - return DB.Unscoped().Where("user_id = ? AND token = ?", userID, token).Delete(&Session{}).Error +func (d *DkfDB) DeleteUserSessionByToken(userID UserID, token string) error { + return d.db.Unscoped().Where("user_id = ? AND token = ?", userID, token).Delete(&Session{}).Error } -func DeleteUserOtherSessions(userID UserID, currentToken string) error { - return DB.Unscoped().Where("user_id = ? AND token != ?", userID, currentToken).Delete(&Session{}).Error +func (d *DkfDB) DeleteUserOtherSessions(userID UserID, currentToken string) error { + return d.db.Unscoped().Where("user_id = ? AND token != ?", userID, currentToken).Delete(&Session{}).Error } -func DeleteOldSessions() { - if err := DB.Unscoped().Delete(Session{}, "expires_at < date('now', '-32 Day') OR (expires_at < date('now', '-32 Day') AND deleted_at IS NOT NULL)").Error; err != nil { +func (d *DkfDB) DeleteOldSessions() { + if err := d.db.Unscoped().Delete(Session{}, "expires_at < date('now', '-32 Day') OR (expires_at < date('now', '-32 Day') AND deleted_at IS NOT NULL)").Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableSettings.go b/pkg/database/tableSettings.go @@ -18,27 +18,27 @@ type Settings struct { } // GetSettings get the saved settings from the DB -func GetSettings() (out Settings) { - if err := DB.Model(Settings{}).First(&out).Error; err != nil { +func (d *DkfDB) GetSettings() (out Settings) { + if err := d.db.Model(Settings{}).First(&out).Error; err != nil { out.SignupEnabled = true out.SilentSelfKick = true out.ForumEnabled = true out.MaybeAuthEnabled = true out.DownloadsEnabled = true out.CaptchaDifficulty = 2 - DB.Create(&out) + d.db.Create(&out) } return } // Save the settings to DB -func (s *Settings) Save() error { - return DB.Save(s).Error +func (s *Settings) Save(db *DkfDB) error { + return db.db.Save(s).Error } // DoSave settings in the database, ignore error -func (s *Settings) DoSave() { - if err := s.Save(); err != nil { +func (s *Settings) DoSave(db *DkfDB) { + if err := s.Save(db); err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableSnippets.go b/pkg/database/tableSnippets.go @@ -8,23 +8,23 @@ type Snippet struct { Text string } -func GetUserSnippets(userID UserID) (out []Snippet, err error) { - err = DB.Find(&out, "user_id = ?", userID).Error +func (d *DkfDB) GetUserSnippets(userID UserID) (out []Snippet, err error) { + err = d.db.Find(&out, "user_id = ?", userID).Error return } -func CreateSnippet(userID UserID, name, text string) (out Snippet, err error) { +func (d *DkfDB) CreateSnippet(userID UserID, name, text string) (out Snippet, err error) { out = Snippet{ Name: name, UserID: userID, Text: text, } - err = DB.Create(&out).Error + err = d.db.Create(&out).Error return } -func DeleteSnippet(userID UserID, name string) { - if err := DB.Delete(Snippet{}, "user_id = ? AND name = ?", userID, name).Error; err != nil { +func (d *DkfDB) DeleteSnippet(userID UserID, name string) { + if err := d.db.Delete(Snippet{}, "user_id = ? AND name = ?", userID, name).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/tableUploads.go b/pkg/database/tableUploads.go @@ -56,30 +56,30 @@ func (u *Upload) Exists() bool { return utils.FileExists(filePath1) } -func (u *Upload) Delete() error { +func (u *Upload) Delete(db *DkfDB) error { if err := os.Remove(filepath.Join(config.Global.ProjectUploadsPath(), u.FileName)); err != nil { return err } - if err := DB.Delete(&u).Error; err != nil { + if err := db.db.Delete(&u).Error; err != nil { return err } return nil } // CreateUpload create file on disk in "uploads" folder, and save upload in database as well. -func CreateUpload(fileName string, content []byte, userID UserID) (*Upload, error) { - return createUploadWithSize(fileName, content, userID, int64(len(content))) +func (d *DkfDB) CreateUpload(fileName string, content []byte, userID UserID) (*Upload, error) { + return d.createUploadWithSize(fileName, content, userID, int64(len(content))) } -func CreateEncryptedUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) { +func (d *DkfDB) CreateEncryptedUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) { encryptedContent, err := utils.EncryptAESMaster(content) if err != nil { return nil, err } - return createUploadWithSize(fileName, encryptedContent, userID, size) + return d.createUploadWithSize(fileName, encryptedContent, userID, size) } -func createUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) { +func (d *DkfDB) createUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) { newFileName := utils.MD5([]byte(utils.GenerateToken32())) if err := ioutil.WriteFile(filepath.Join(config.Global.ProjectUploadsPath(), newFileName), content, 0644); err != nil { return nil, err @@ -90,42 +90,42 @@ func createUploadWithSize(fileName string, content []byte, userID UserID, size i OrigFileName: fileName, FileSize: size, } - if err := DB.Create(&upload).Error; err != nil { + if err := d.db.Create(&upload).Error; err != nil { logrus.Error(err) } return &upload, nil } -func GetUploadByFileName(filename string) (out Upload, err error) { - err = DB.First(&out, "file_name = ?", filename).Error +func (d *DkfDB) GetUploadByFileName(filename string) (out Upload, err error) { + err = d.db.First(&out, "file_name = ?", filename).Error return } -func GetUploadByID(uploadID UploadID) (out Upload, err error) { - err = DB.First(&out, "id = ?", uploadID).Error +func (d *DkfDB) GetUploadByID(uploadID UploadID) (out Upload, err error) { + err = d.db.First(&out, "id = ?", uploadID).Error return } -func GetUploads() (out []Upload, err error) { - err = DB.Preload("User").Order("id DESC").Find(&out).Error +func (d *DkfDB) GetUploads() (out []Upload, err error) { + err = d.db.Preload("User").Order("id DESC").Find(&out).Error return } -func GetUserUploads(userID UserID) (out []Upload, err error) { - err = DB.Order("id DESC").Find(&out, "user_id = ?", userID).Error +func (d *DkfDB) GetUserUploads(userID UserID) (out []Upload, err error) { + err = d.db.Order("id DESC").Find(&out, "user_id = ?", userID).Error return } -func GetUserTotalUploadSize(userID UserID) int64 { +func (d *DkfDB) GetUserTotalUploadSize(userID UserID) int64 { var out struct{ TotalSize int64 } - if err := DB.Raw(`SELECT SUM(file_size) as total_size FROM uploads WHERE user_id = ?`, userID).Scan(&out).Error; err != nil { + if err := d.db.Raw(`SELECT SUM(file_size) as total_size FROM uploads WHERE user_id = ?`, userID).Scan(&out).Error; err != nil { logrus.Error(err) } return out.TotalSize } -func DeleteOldUploads() { - if err := DB.Exec(`DELETE FROM uploads WHERE created_at < date('now', '-1 Day')`).Error; err != nil { +func (d *DkfDB) DeleteOldUploads() { + if err := d.db.Exec(`DELETE FROM uploads WHERE created_at < date('now', '-1 Day')`).Error; err != nil { logrus.Error(err.Error()) } fileInfo, err := ioutil.ReadDir(config.Global.ProjectUploadsPath()) diff --git a/pkg/database/tableUserForumThreadSubscriptions.go b/pkg/database/tableUserForumThreadSubscriptions.go @@ -13,27 +13,27 @@ type UserForumThreadSubscription struct { User User } -func (s *UserForumThreadSubscription) DoSave() { - if err := DB.Save(s).Error; err != nil { +func (s *UserForumThreadSubscription) DoSave(db *DkfDB) { + if err := db.db.Save(s).Error; err != nil { logrus.Error(err) } } -func SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error) { - return DB.Create(&UserForumThreadSubscription{UserID: userID, ThreadID: threadID}).Error +func (d *DkfDB) SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error) { + return d.db.Create(&UserForumThreadSubscription{UserID: userID, ThreadID: threadID}).Error } -func UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error) { - return DB.Delete(&UserForumThreadSubscription{}, "user_id = ? AND thread_id = ?", userID, threadID).Error +func (d *DkfDB) UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error) { + return d.db.Delete(&UserForumThreadSubscription{}, "user_id = ? AND thread_id = ?", userID, threadID).Error } -func IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool { +func (d *DkfDB) IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool { var count int64 - DB.Model(UserForumThreadSubscription{}).Where("user_id = ? AND thread_id = ?", userID, threadID).Count(&count) + d.db.Model(UserForumThreadSubscription{}).Where("user_id = ? AND thread_id = ?", userID, threadID).Count(&count) return count == 1 } -func GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error) { - err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error +func (d *DkfDB) GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error) { + err = d.db.Preload("User").Find(&out, "thread_id = ?", threadID).Error return } diff --git a/pkg/database/tableUserPrivateNotes.go b/pkg/database/tableUserPrivateNotes.go @@ -14,19 +14,19 @@ type UserPrivateNote struct { UpdatedAt time.Time } -func GetUserPrivateNotes(userID UserID) (out UserPrivateNote, err error) { - err = DB.First(&out, "user_id = ?", userID).Error +func (d *DkfDB) GetUserPrivateNotes(userID UserID) (out UserPrivateNote, err error) { + err = d.db.First(&out, "user_id = ?", userID).Error return } -func SetUserPrivateNotes(userID UserID, notes string) error { +func (d *DkfDB) SetUserPrivateNotes(userID UserID, notes string) error { if !govalidator.RuneLength(notes, "0", "10000") { return errors.New("notes must have 10000 characters maximum") } n := UserPrivateNote{UserID: userID} - if err := DB.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil { + if err := d.db.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil { return err } n.Notes = EncryptedString(notes) - return DB.Save(&n).Error + return d.db.Save(&n).Error } diff --git a/pkg/database/tableUserPublicNotes.go b/pkg/database/tableUserPublicNotes.go @@ -14,19 +14,19 @@ type UserPublicNote struct { UpdatedAt time.Time } -func GetUserPublicNotes(userID UserID) (out UserPublicNote, err error) { - err = DB.First(&out, "user_id = ?", userID).Error +func (d *DkfDB) GetUserPublicNotes(userID UserID) (out UserPublicNote, err error) { + err = d.db.First(&out, "user_id = ?", userID).Error return } -func SetUserPublicNotes(userID UserID, notes string) error { +func (d *DkfDB) SetUserPublicNotes(userID UserID, notes string) error { if !govalidator.RuneLength(notes, "0", "10000") { return errors.New("notes must have 10000 characters maximum") } n := UserPublicNote{UserID: userID} - if err := DB.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil { + if err := d.db.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil { return err } n.Notes = notes - return DB.Save(&n).Error + return d.db.Save(&n).Error } diff --git a/pkg/database/tableUserRoomSubscriptions.go b/pkg/database/tableUserRoomSubscriptions.go @@ -13,22 +13,22 @@ type UserRoomSubscription struct { Room ChatRoom } -func (s *UserRoomSubscription) DoSave() { - if err := DB.Save(s).Error; err != nil { +func (s *UserRoomSubscription) DoSave(db *DkfDB) { + if err := db.db.Save(s).Error; err != nil { logrus.Error(err) } } -func SubscribeToRoom(userID UserID, roomID RoomID) (err error) { - return DB.Create(&UserRoomSubscription{UserID: userID, RoomID: roomID}).Error +func (d *DkfDB) SubscribeToRoom(userID UserID, roomID RoomID) (err error) { + return d.db.Create(&UserRoomSubscription{UserID: userID, RoomID: roomID}).Error } -func UnsubscribeFromRoom(userID UserID, roomID RoomID) (err error) { - return DB.Delete(&UserRoomSubscription{}, "user_id = ? AND room_id = ?", userID, roomID).Error +func (d *DkfDB) UnsubscribeFromRoom(userID UserID, roomID RoomID) (err error) { + return d.db.Delete(&UserRoomSubscription{}, "user_id = ? AND room_id = ?", userID, roomID).Error } -func IsUserSubscribedToRoom(userID UserID, roomID RoomID) bool { +func (d *DkfDB) IsUserSubscribedToRoom(userID UserID, roomID RoomID) bool { var count int64 - DB.Model(UserRoomSubscription{}).Where("user_id = ? AND room_id = ?", userID, roomID).Count(&count) + d.db.Model(UserRoomSubscription{}).Where("user_id = ? AND room_id = ?", userID, roomID).Count(&count) return count == 1 } diff --git a/pkg/database/tableUsers.go b/pkg/database/tableUsers.go @@ -125,8 +125,8 @@ func UserPtrID(user *User) *UserID { return nil } -func GetChessSubscribers() (out []User, err error) { - err = DB.Find(&out, "notify_chess_games == 1").Error +func (d *DkfDB) GetChessSubscribers() (out []User, err error) { + err = d.db.Find(&out, "notify_chess_games == 1").Error return } @@ -250,84 +250,84 @@ func (u *User) GetHellbanOpacityF64() float64 { } // Save user in the database -func (u *User) Save() error { - return DB.Save(u).Error +func (u *User) Save(db *DkfDB) error { + return db.db.Save(u).Error } // DoSave user in the database, ignore error -func (u *User) DoSave() { - if err := u.Save(); err != nil { +func (u *User) DoSave(db *DkfDB) { + if err := u.Save(db); err != nil { logrus.Error(err) } } -func (u *User) HellBan() { +func (u *User) HellBan(db *DkfDB) { u.IsHellbanned = true - u.DoSave() - if err := DB.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", true).Error; err != nil { + u.DoSave(db) + if err := db.db.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", true).Error; err != nil { logrus.Error(err) } } -func (u *User) UnHellBan() { +func (u *User) UnHellBan(db *DkfDB) { u.IsHellbanned = false - u.DoSave() - if err := DB.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", false).Error; err != nil { + u.DoSave(db) + if err := db.db.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", false).Error; err != nil { logrus.Error(err) } } // GetUserBySessionKey ... -func GetUserBySessionKey(user *User, sessionKey string) error { - return DB.Joins("INNER JOIN sessions s ON s.token = ? AND s.expires_at > DATETIME('now') and s.deleted_at IS NULL AND s.user_id = users.id"). +func (d *DkfDB) GetUserBySessionKey(user *User, sessionKey string) error { + return d.db.Joins("INNER JOIN sessions s ON s.token = ? AND s.expires_at > DATETIME('now') and s.deleted_at IS NULL AND s.user_id = users.id"). Where("users.verified = 1", sessionKey). First(user).Error } // GetUserByApiKey ... -func GetUserByApiKey(user *User, apiKey string) error { - return DB.First(user, "api_key = ?", apiKey).Error +func (d *DkfDB) GetUserByApiKey(user *User, apiKey string) error { + return d.db.First(user, "api_key = ?", apiKey).Error } // GetUserByID ... -func GetUserByID(userID UserID) (out User, err error) { - err = DB.First(&out, "id = ?", userID).Error +func (d *DkfDB) GetUserByID(userID UserID) (out User, err error) { + err = d.db.First(&out, "id = ?", userID).Error return } // GetUserByUsername ... -func GetUserByUsername(username string) (out User, err error) { - err = DB.First(&out, "username = ? COLLATE NOCASE", username).Error +func (d *DkfDB) GetUserByUsername(username string) (out User, err error) { + err = d.db.First(&out, "username = ? COLLATE NOCASE", username).Error return } -func GetVerifiedUserByUsername(username string) (out User, err error) { - err = DB.First(&out, "username = ? COLLATE NOCASE AND verified = 1", username).Error +func (d *DkfDB) GetVerifiedUserByUsername(username string) (out User, err error) { + err = d.db.First(&out, "username = ? COLLATE NOCASE AND verified = 1", username).Error return } -func GetUsersByID(ids []UserID) (out []User, err error) { - err = DB.Find(&out, "id IN (?)", ids).Error +func (d *DkfDB) GetUsersByID(ids []UserID) (out []User, err error) { + err = d.db.Find(&out, "id IN (?)", ids).Error return } -func GetUsersByUsername(usernames []string) (out []User, err error) { - err = DB.Find(&out, "username IN (?)", usernames).Error +func (d *DkfDB) GetUsersByUsername(usernames []string) (out []User, err error) { + err = d.db.Find(&out, "username IN (?)", usernames).Error return } -func DeleteUserByID(userID UserID) (err error) { - err = DB.Unscoped().Delete(User{}, "id = ?", userID).Error +func (d *DkfDB) DeleteUserByID(userID UserID) (err error) { + err = d.db.Unscoped().Delete(User{}, "id = ?", userID).Error return } -func GetModeratorsUsers() (out []User, err error) { - err = DB.Order("username ASC").Find(&out, "role = ? OR is_admin = 1", "moderator").Error +func (d *DkfDB) GetModeratorsUsers() (out []User, err error) { + err = d.db.Order("username ASC").Find(&out, "role = ? OR is_admin = 1", "moderator").Error return } -func GetClubMembers() (out []User, err error) { - err = DB.Find(&out, "is_club_member = ?", true).Error +func (d *DkfDB) GetClubMembers() (out []User, err error) { + err = d.db.Find(&out, "is_club_member = ?", true).Error return } @@ -336,40 +336,40 @@ func GetClubMembers() (out []User, err error) { // Assume I realize I left myself logged into a shared computer. // I change my password to protect myself. // The session on the public computer needs to be invalidated. -func (u *User) ChangePassword(hashedPassword string) error { +func (u *User) ChangePassword(db *DkfDB, hashedPassword string) error { u.Password = hashedPassword - if err := DB.Save(u).Error; err != nil { + if err := db.db.Save(u).Error; err != nil { return err } // Delete active user sessions - if err := DeleteUserSessions(u.ID); err != nil { + if err := db.DeleteUserSessions(u.ID); err != nil { return err } return nil } -func (u *User) ChangeDuressPassword(hashedDuressPassword string) error { +func (u *User) ChangeDuressPassword(db *DkfDB, hashedDuressPassword string) error { u.DuressPassword = hashedDuressPassword - if err := DB.Save(u).Error; err != nil { + if err := db.db.Save(u).Error; err != nil { return err } // Delete active user sessions - if err := DeleteUserSessions(u.ID); err != nil { + if err := db.DeleteUserSessions(u.ID); err != nil { return err } return nil } -func (u *User) CheckPassword(password string) bool { +func (u *User) CheckPassword(db *DkfDB, password string) bool { if err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(u.DuressPassword), []byte(password)); err != nil { return false } u.IsUnderDuress = true - u.DoSave() + u.DoSave(db) } else { u.IsUnderDuress = false - u.DoSave() + u.DoSave(db) } return true } @@ -420,22 +420,22 @@ func isUsernameReserved(username string) bool { } // GetVerifiedUserBySessionID ... -func GetVerifiedUserBySessionID(token string) (out User, err error) { - err = DB.First(&out, "token = ? and verified = 1", token).Error +func (d *DkfDB) GetVerifiedUserBySessionID(token string) (out User, err error) { + err = d.db.First(&out, "token = ? and verified = 1", token).Error return } // GetRecentUsersCount ... -func GetRecentUsersCount() int64 { +func (d *DkfDB) GetRecentUsersCount() int64 { var count int64 - DB.Table("users").Where("created_at > datetime('now', '-1 Minute')").Count(&count) + d.db.Table("users").Where("created_at > datetime('now', '-1 Minute')").Count(&count) return count } // IsUsernameAlreadyTaken ... -func IsUsernameAlreadyTaken(username string) bool { +func (d *DkfDB) IsUsernameAlreadyTaken(username string) bool { var count int64 - DB.Table("users").Where("username = ? COLLATE NOCASE", username).Count(&count) + d.db.Table("users").Where("username = ? COLLATE NOCASE", username).Count(&count) return count > 0 || isUsernameReserved(username) } @@ -446,7 +446,7 @@ type PasswordValidator struct { } // NewPasswordValidator ... -func NewPasswordValidator(password string) *PasswordValidator { +func NewPasswordValidator(db *DkfDB, password string) *PasswordValidator { p := new(PasswordValidator) p.password = password if len(password) < 8 { @@ -455,7 +455,7 @@ func NewPasswordValidator(password string) *PasswordValidator { if len(password) > 128 { p.error = errors.New("password must be at most 128 characters") } - if IsPasswordProhibited(password) { + if db.IsPasswordProhibited(password) { p.error = errors.New("this password is too weak") } return p @@ -482,21 +482,21 @@ func (p *PasswordValidator) Hash() (string, error) { return string(h), p.error } -func CanUseUsername(username string, isFirstUser bool) error { +func (d *DkfDB) CanUseUsername(username string, isFirstUser bool) error { if _, err := ValidateUsername(username, isFirstUser); err != nil { return err - } else if IsUsernameAlreadyTaken(username) { + } else if d.IsUsernameAlreadyTaken(username) { return errors.New("username already taken") } return nil } -func CanRenameTo(oldUsername, newUsername string) error { +func (d *DkfDB) CanRenameTo(oldUsername, newUsername string) error { if _, err := ValidateUsername(newUsername, false); err != nil { return err } if strings.ToLower(oldUsername) != strings.ToLower(newUsername) { - if IsUsernameAlreadyTaken(newUsername) { + if d.IsUsernameAlreadyTaken(newUsername) { return errors.New("username already taken") } } @@ -504,34 +504,34 @@ func CanRenameTo(oldUsername, newUsername string) error { } // CreateUser ... -func CreateUser(username, password, repassword string, registrationDuration int64, signupInfoEnc string) (User, UserErrors) { - return createUser(username, password, repassword, "", false, true, false, false, false, registrationDuration, signupInfoEnc) +func (d *DkfDB) CreateUser(username, password, repassword string, registrationDuration int64, signupInfoEnc string) (User, UserErrors) { + return d.createUser(username, password, repassword, "", false, true, false, false, false, registrationDuration, signupInfoEnc) } -func CreateGuestUser(username, password string) (User, UserErrors) { - return createUser(username, password, password, "", false, true, true, false, false, 0, "signupInfoEnc") +func (d *DkfDB) CreateGuestUser(username, password string) (User, UserErrors) { + return d.createUser(username, password, password, "", false, true, true, false, false, 0, "signupInfoEnc") } -func CreateFirstUser(username, password, repassword string) (User, UserErrors) { - return createUser(username, password, repassword, "", true, true, false, true, false, 12000, "") +func (d *DkfDB) CreateFirstUser(username, password, repassword string) (User, UserErrors) { + return d.createUser(username, password, repassword, "", true, true, false, true, false, 12000, "") } -func CreateZeroUser() (User, UserErrors) { +func (d *DkfDB) CreateZeroUser() (User, UserErrors) { password := utils.GenerateToken10() - return createUser("0", password, password, config.NullUserPublicKey, false, true, false, false, true, 12000, "") + return d.createUser("0", password, password, config.NullUserPublicKey, false, true, false, false, true, 12000, "") } // skipUsernameValidation: entirely skip username validation (for "0" user) // isFirstUser: less strict username validation; can use "admin"/"n0tr1v" usernames -func createUser(username, password, repassword, gpgPublicKey string, isAdmin, verified, temp, isFirstUser, skipUsernameValidation bool, registrationDuration int64, signupInfoEnc string) (User, UserErrors) { +func (d *DkfDB) createUser(username, password, repassword, gpgPublicKey string, isAdmin, verified, temp, isFirstUser, skipUsernameValidation bool, registrationDuration int64, signupInfoEnc string) (User, UserErrors) { username = strings.TrimSpace(username) var errs UserErrors if !skipUsernameValidation { - if err := CanUseUsername(username, isFirstUser); err != nil { + if err := d.CanUseUsername(username, isFirstUser); err != nil { errs.Username = err.Error() } } - hashedPassword, err := NewPasswordValidator(password).CompareWith(repassword).Hash() + hashedPassword, err := NewPasswordValidator(d, password).CompareWith(repassword).Hash() if err != nil { errs.Password = err.Error() } @@ -569,7 +569,7 @@ func createUser(username, password, repassword, gpgPublicKey string, isAdmin, ve token := utils.GenerateToken32() newUser.Token = &token } - if err := DB.Create(&newUser).Error; err != nil { + if err := d.db.Create(&newUser).Error; err != nil { logrus.Error(err) } @@ -582,8 +582,8 @@ func (u *User) SetAvatar(b []byte) { u.Avatar = b } -func (u *User) IncrKarma(karma int64, description string) { - if _, err := CreateKarmaHistory(karma, description, u.ID, nil); err != nil { +func (u *User) IncrKarma(db *DkfDB, karma int64, description string) { + if _, err := db.CreateKarmaHistory(karma, description, u.ID, nil); err != nil { logrus.Error(err) return } diff --git a/pkg/database/tableXmrInvoices.go b/pkg/database/tableXmrInvoices.go @@ -22,8 +22,8 @@ type XmrInvoice struct { CreatedAt time.Time } -func CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error) { - err = DB.Where("user_id = ? AND product_id = ? AND amount_received IS NULL", userID, productID).First(&out).Error +func (d *DkfDB) CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error) { + err = d.db.Where("user_id = ? AND product_id = ? AND amount_received IS NULL", userID, productID).First(&out).Error if err == nil { return } @@ -37,14 +37,14 @@ func CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error Address: resp.Address, AmountRequested: 10, } - if err = DB.Create(&out).Error; err != nil { + if err = d.db.Create(&out).Error; err != nil { return } return } -func GetXmrInvoiceByAddress(address string) (out XmrInvoice, err error) { - err = DB.Where("address = ?", address).First(&out).Error +func (d *DkfDB) GetXmrInvoiceByAddress(address string) (out XmrInvoice, err error) { + err = d.db.Where("address = ?", address).First(&out).Error return } @@ -68,8 +68,8 @@ func (i XmrInvoice) GetImage() (image.Image, error) { return b, nil } -func (i *XmrInvoice) DoSave() { - if err := DB.Save(i).Error; err != nil { +func (i *XmrInvoice) DoSave(db *DkfDB) { + if err := db.db.Save(i).Error; err != nil { logrus.Error(err) } } diff --git a/pkg/database/table_forum_threads.go b/pkg/database/table_forum_threads.go @@ -41,19 +41,19 @@ func MakeForumThread(threadName string, userID UserID, categoryID ForumCategoryI return ForumThread{UUID: ForumThreadUUID(uuid.New().String()), Name: threadName, UserID: userID, CategoryID: categoryID} } -func (u *ForumThread) DoSave() { - if err := DB.Save(u).Error; err != nil { +func (u *ForumThread) DoSave(db *DkfDB) { + if err := db.db.Save(u).Error; err != nil { logrus.Error(err) } } -func GetForumCategories() (out []ForumCategory, err error) { - err = DB.Find(&out).Order("idx ASC, name ASC").Error +func (d *DkfDB) GetForumCategories() (out []ForumCategory, err error) { + err = d.db.Find(&out).Order("idx ASC, name ASC").Error return } -func GetForumCategoryBySlug(slug string) (out ForumCategory, err error) { - err = DB.First(&out, "slug = ?", slug).Error +func (d *DkfDB) GetForumCategoryBySlug(slug string) (out ForumCategory, err error) { + err = d.db.First(&out, "slug = ?", slug).Error return } @@ -82,29 +82,29 @@ type ForumReadRecord struct { ReadAt time.Time } -func UpdateForumReadRecord(userID UserID, threadID ForumThreadID) { +func (d *DkfDB) UpdateForumReadRecord(userID UserID, threadID ForumThreadID) { now := time.Now() - res := DB.Table("forum_read_records").Where("user_id = ? AND thread_id = ?", userID, threadID).Update("read_at", now) + res := d.db.Table("forum_read_records").Where("user_id = ? AND thread_id = ?", userID, threadID).Update("read_at", now) if res.RowsAffected == 0 { - DB.Create(ForumReadRecord{UserID: userID, ThreadID: threadID, ReadAt: now}) + d.db.Create(ForumReadRecord{UserID: userID, ThreadID: threadID, ReadAt: now}) } } // DoSave user in the database, ignore error -func (u *ForumReadRecord) DoSave() { - if err := DB.Save(u).Error; err != nil { +func (u *ForumReadRecord) DoSave(db *DkfDB) { + if err := db.db.Save(u).Error; err != nil { logrus.Error(err) } } // DoSave user in the database, ignore error -func (u *ForumMessage) DoSave() { - if err := DB.Save(u).Error; err != nil { +func (u *ForumMessage) DoSave(db *DkfDB) { + if err := db.db.Save(u).Error; err != nil { logrus.Error(err) } } -func (m *ForumMessage) Escape() string { +func (m *ForumMessage) Escape(db *DkfDB) string { msg := m.Message if m.IsSigned { if b, _ := clearsign.Decode([]byte(msg)); b != nil { @@ -120,7 +120,7 @@ func (m *ForumMessage) Escape() string { var tagRgx = regexp.MustCompile(`@(\w{3,20})`) if tagRgx.MatchString(res) { res = tagRgx.ReplaceAllStringFunc(res, func(s string) string { - if user, err := GetUserByUsername(strings.TrimPrefix(s, "@")); err == nil { + if user, err := db.GetUserByUsername(strings.TrimPrefix(s, "@")); err == nil { return `<span style="color: ` + user.ChatColor + `;">` + s + `</span>` } return s @@ -129,22 +129,22 @@ func (m *ForumMessage) Escape() string { return res } -func GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error) { - err = DB.First(&out, "id = ?", messageID).Error +func (d *DkfDB) GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error) { + err = d.db.First(&out, "id = ?", messageID).Error return } -func GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error) { - err = DB.First(&out, "uuid = ?", messageUUID).Error +func (d *DkfDB) GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error) { + err = d.db.First(&out, "uuid = ?", messageUUID).Error return } -func DeleteForumMessageByID(messageID ForumMessageID) error { - return DB.Where("id = ?", messageID).Delete(&ForumMessage{}).Error +func (d *DkfDB) DeleteForumMessageByID(messageID ForumMessageID) error { + return d.db.Where("id = ?", messageID).Delete(&ForumMessage{}).Error } -func DeleteForumThreadByID(threadID ForumThreadID) error { - return DB.Where("id = ?", threadID).Delete(&ForumThread{}).Error +func (d *DkfDB) DeleteForumThreadByID(threadID ForumThreadID) error { + return d.db.Where("id = ?", threadID).Delete(&ForumThread{}).Error } func (m *ForumMessage) CanEdit() bool { @@ -159,23 +159,23 @@ func (m *ForumMessage) ValidateSignature(pkey string) bool { return utils.PgpCheckClearSignMessage(pkey, m.Message) } -func GetForumThread(threadID ForumThreadID) (out ForumThread, err error) { - err = DB.First(&out, "id = ? AND is_club = 1", threadID).Error +func (d *DkfDB) GetForumThread(threadID ForumThreadID) (out ForumThread, err error) { + err = d.db.First(&out, "id = ? AND is_club = 1", threadID).Error return } -func GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error) { - err = DB.First(&out, "id = ? AND is_club = 0", threadID).Error +func (d *DkfDB) GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error) { + err = d.db.First(&out, "id = ? AND is_club = 0", threadID).Error return } -func GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error) { - err = DB.First(&out, "uuid = ? AND is_club = 0", threadUUID).Error +func (d *DkfDB) GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error) { + err = d.db.First(&out, "uuid = ? AND is_club = 0", threadUUID).Error return } -func GetForumThreads() (out []ForumThread, err error) { - err = DB.Order("id DESC").Find(&out).Error +func (d *DkfDB) GetForumThreads() (out []ForumThread, err error) { + err = d.db.Order("id DESC").Find(&out).Error return } @@ -191,8 +191,8 @@ type ForumThreadAug struct { RepliesCount int64 } -func GetClubForumThreads(userID UserID) (out []ForumThreadAug, err error) { - err = DB.Raw(`SELECT t.*, +func (d *DkfDB) GetClubForumThreads(userID UserID) (out []ForumThreadAug, err error) { + err = d.db.Raw(`SELECT t.*, u.username as author, u.chat_color as author_chat_color, lu.username as last_msg_author, @@ -210,8 +210,8 @@ ORDER BY t.id DESC`, userID).Scan(&out).Error return } -func GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error) { - err = DB.Raw(`SELECT t.*, +func (d *DkfDB) GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error) { + err = d.db.Raw(`SELECT t.*, u.username as author, u.chat_color as author_chat_color, lu.username as last_msg_author, @@ -236,8 +236,8 @@ ORDER BY m.created_at DESC, t.id DESC`, userID, categoryID).Scan(&out).Error return } -func GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error) { - err = DB.Raw(`SELECT t.*, +func (d *DkfDB) GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error) { + err = d.db.Raw(`SELECT t.*, u.username as author, u.chat_color as author_chat_color, lu.username as last_msg_author, @@ -262,7 +262,7 @@ ORDER BY m.created_at DESC, t.id DESC`, userID).Scan(&out).Error return } -func GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error) { - err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error +func (d *DkfDB) GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error) { + err = d.db.Preload("User").Find(&out, "thread_id = ?", threadID).Error return } diff --git a/pkg/database/utils/utils.go b/pkg/database/utils/utils.go @@ -13,15 +13,15 @@ import ( "net/http" ) -func GetZeroUser() database.User { - zeroUser, err := database.GetUserByUsername(config.NullUsername) +func GetZeroUser(db *database.DkfDB) database.User { + zeroUser, err := db.GetUserByUsername(config.NullUsername) if err != nil { logrus.Fatal(err) } return zeroUser } -func SendNewChessGameMessages(key, roomKey string, roomID database.RoomID, zeroUser, player1, player2 database.User) { +func SendNewChessGameMessages(db *database.DkfDB, key, roomKey string, roomID database.RoomID, zeroUser, player1, player2 database.User) { // Send game link to players getPlayerMsg := func(opponent database.User) (raw string, msg string) { raw = `Chess game against ` + opponent.Username @@ -29,19 +29,19 @@ func SendNewChessGameMessages(key, roomKey string, roomID database.RoomID, zeroU return } raw, msg := getPlayerMsg(player2) - _, _ = database.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player1.ID) + _, _ = db.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player1.ID) raw, msg = getPlayerMsg(player1) - _, _ = database.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player2.ID) + _, _ = db.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player2.ID) // Send notifications to chess games subscribers raw = `Chess game: ` + player1.Username + ` VS ` + player2.Username msg = `<a href="/chess/` + key + `" rel="noopener noreferrer" target="_blank">Chess game: ` + player1.Username + ` VS ` + player2.Username + `</a>` - users, _ := database.GetChessSubscribers() + users, _ := db.GetChessSubscribers() for _, user := range users { if user.ID == player1.ID || user.ID == player2.ID { continue } - _, _ = database.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &user.ID) + _, _ = db.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &user.ID) } } @@ -89,22 +89,22 @@ func DoParseRoomID(v string) (out database.RoomID) { return DoParse[database.RoomID](v) } -func Kick(kicked, kickedBy database.User, purge, silent bool) error { +func Kick(db *database.DkfDB, kicked, kickedBy database.User, purge, silent bool) error { if kicked.IsHellbanned { silent = true } - return kick(kicked, kickedBy, silent, purge) + return kick(db, kicked, kickedBy, silent, purge) } -func SilentKick(kicked, kickedBy database.User) error { - return kick(kicked, kickedBy, true, true) +func SilentKick(db *database.DkfDB, kicked, kickedBy database.User) error { + return kick(db, kicked, kickedBy, true, true) } -func SelfKick(kicked database.User, silent bool) error { - return kick(kicked, kicked, silent, true) +func SelfKick(db *database.DkfDB, kicked database.User, silent bool) error { + return kick(db, kicked, kicked, silent, true) } -func kick(kicked, kickedBy database.User, silent, purge bool) error { +func kick(db *database.DkfDB, kicked, kickedBy database.User, silent, purge bool) error { if !kicked.Verified { return errors.New("user already kicked") } @@ -117,16 +117,16 @@ func kick(kicked, kickedBy database.User, silent, purge bool) error { return errors.New("cannot kick another moderator") } - database.NewAudit(kickedBy, fmt.Sprintf("kick %s #%d", kicked.Username, kicked.ID)) + db.NewAudit(kickedBy, fmt.Sprintf("kick %s #%d", kicked.Username, kicked.ID)) kicked.Verified = false - kicked.DoSave() + kicked.DoSave(db) // Remove user from the user cache managers.ActiveUsers.RemoveUser(kicked.ID) if purge { // Purge user messages - if err := database.DeleteUserChatMessages(kicked.ID); err != nil { + if err := db.DeleteUserChatMessages(kicked.ID); err != nil { logrus.Error(err) } } @@ -134,15 +134,15 @@ func kick(kicked, kickedBy database.User, silent, purge bool) error { // If user is HB, do not display system message if !silent { // Display kick message - database.CreateKickMsg(kicked, kickedBy) + db.CreateKickMsg(kicked, kickedBy) } return nil } -func GetRoomAndKey(c echo.Context, roomName string) (database.ChatRoom, string, error) { +func GetRoomAndKey(db *database.DkfDB, c echo.Context, roomName string) (database.ChatRoom, string, error) { roomKey := "" - room, err := database.GetChatRoomByName(roomName) + room, err := db.GetChatRoomByName(roomName) if err != nil { return room, roomKey, c.NoContent(http.StatusNotFound) } diff --git a/pkg/global/global.go b/pkg/global/global.go @@ -20,16 +20,16 @@ func DeleteUserNotificationCount(userID database.UserID, sessionToken string) { notifCountCache.Delete(cacheKey(userID, sessionToken)) } -func GetUserNotificationCount(userID database.UserID, sessionToken string) int64 { +func GetUserNotificationCount(db *database.DkfDB, userID database.UserID, sessionToken string) int64 { count, found := notifCountCache.Get(cacheKey(userID, sessionToken)) if found { return count } - count = database.GetUserInboxMessagesCount(userID) - count += database.GetUserNotificationsCount(userID) + count = db.GetUserInboxMessagesCount(userID) + count += db.GetUserNotificationsCount(userID) // sessionToken can be empty when using the API if sessionToken != "" { - count += database.GetUserSessionNotificationsCount(sessionToken) + count += db.GetUserSessionNotificationsCount(sessionToken) } notifCountCache.SetD(cacheKey(userID, sessionToken), count) return count diff --git a/pkg/template/templates.go b/pkg/template/templates.go @@ -50,6 +50,7 @@ type templateDataStruct struct { ShaHTML template.HTML LogoASCII template.HTML NullUsername string + DB *database.DkfDB CSRF string Master bool Development bool @@ -64,6 +65,8 @@ type templateDataStruct struct { func (t *Templates) Render(w io.Writer, name string, data any, c echo.Context) error { tmpl := t.Templates[name] + db := c.Get("database").(*database.DkfDB) + d := templateDataStruct{} d.TmplName = name d.Data = data @@ -77,6 +80,7 @@ func (t *Templates) Render(w io.Writer, name string, data any, c echo.Context) e d.ShaHTML = template.HTML(fmt.Sprintf("<!-- SHA: %s -->", config.Global.Sha())) d.NullUsername = config.NullUsername d.CSRF, _ = c.Get("csrf").(string) + d.DB = db d.AcceptLanguage = c.Get("accept-language").(string) d.Lang = c.Get("lang").(string) d.BaseKeywords = strings.Join(getBaseKeywords(), ", ") @@ -94,7 +98,7 @@ func (t *Templates) Render(w io.Writer, name string, data any, c echo.Context) e if authCookie, err := c.Cookie(hutils.AuthCookieName); err == nil { sessionToken = authCookie.Value } - d.InboxCount = global.GetUserNotificationCount(d.AuthUser.ID, sessionToken) + d.InboxCount = global.GetUserNotificationCount(db, d.AuthUser.ID, sessionToken) } return tmpl.ExecuteTemplate(w, "base", d) diff --git a/pkg/web/handlers/admin.go b/pkg/web/handlers/admin.go @@ -19,6 +19,7 @@ import ( func AdminNewGistHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data adminCreateGistData data.ActiveTab = "gists" if c.Request().Method == http.MethodPost { @@ -33,9 +34,8 @@ func AdminNewGistHandler(c echo.Context) error { if data.Password != "" { passwordHash = database.GetGistPasswordHash(data.Password) } - gist := database.Gist{Name: data.Name, Password: passwordHash, UserID: authUser.ID, Content: data.Content} - gist.UUID = uuid.New().String() - if err := database.DB.Create(&gist).Error; err != nil { + gist := database.Gist{UUID: uuid.New().String(), Name: data.Name, Password: passwordHash, UserID: authUser.ID, Content: data.Content} + if err := db.DB().Create(&gist).Error; err != nil { data.Error = err.Error() return c.Render(http.StatusOK, "admin.gist-create", data) } @@ -46,8 +46,9 @@ func AdminNewGistHandler(c echo.Context) error { func AdminEditGistHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) gistUUID := c.Param("gistUUID") - gist, err := database.GetGistByUUID(gistUUID) + gist, err := db.GetGistByUUID(gistUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -78,19 +79,20 @@ func AdminEditGistHandler(c echo.Context) error { } gist.Name = data.Name gist.Content = data.Content - gist.DoSave() + gist.DoSave(db) return c.Redirect(http.StatusFound, "/gists/"+gist.UUID) } return c.Render(http.StatusOK, "admin.gist-create", data) } func AdminGistsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminGistsData data.ActiveTab = "gists" userQuery := c.QueryParam("u") - if err := database.DB.Table("gists"). + if err := db.DB().Table("gists"). Scopes(func(query *gorm.DB) *gorm.DB { if userQuery != "" { query = query.Where("user_id = ?", userQuery) @@ -108,15 +110,16 @@ func AdminGistsHandler(c echo.Context) error { } func AdminUploadsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) if c.Request().Method == http.MethodPost { formName := c.FormValue("formName") if formName == "deleteUpload" { fileName := c.Request().PostFormValue("file_name") - file, err := database.GetUploadByFileName(fileName) + file, err := db.GetUploadByFileName(fileName) if err != nil { return c.Redirect(http.StatusFound, "/") } - if err := file.Delete(); err != nil { + if err := file.Delete(db); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -125,7 +128,7 @@ func AdminUploadsHandler(c echo.Context) error { var data adminUploadsData data.ActiveTab = "uploads" - data.Uploads, _ = database.GetUploads() + data.Uploads, _ = db.GetUploads() for _, f := range data.Uploads { data.TotalSize += f.FileSize } @@ -133,21 +136,22 @@ func AdminUploadsHandler(c echo.Context) error { } func AdminFiledropsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) if c.Request().Method == http.MethodPost { formName := c.FormValue("formName") if formName == "createFiledrop" { - if _, err := database.CreateFiledrop(); err != nil { + if _, err := db.CreateFiledrop(); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) } else if formName == "deleteFiledrop" { fileName := c.Request().PostFormValue("file_name") - file, err := database.GetFiledropByFileName(fileName) + file, err := db.GetFiledropByFileName(fileName) if err != nil { return c.Redirect(http.StatusFound, "/") } - if err := file.Delete(); err != nil { + if err := file.Delete(db); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -156,7 +160,7 @@ func AdminFiledropsHandler(c echo.Context) error { var data adminFiledropsData data.ActiveTab = "filedrops" - data.Filedrops, _ = database.GetFiledrops() + data.Filedrops, _ = db.GetFiledrops() for _, f := range data.Filedrops { data.TotalSize += f.FileSize } @@ -164,12 +168,13 @@ func AdminFiledropsHandler(c echo.Context) error { } func AdminDownloadsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminDownloadsData data.ActiveTab = "downloads" userQuery := c.QueryParam("u") - database.DB.Model(&database.Download{}). + db.DB().Model(&database.Download{}). Scopes(func(query *gorm.DB) *gorm.DB { if userQuery != "" { query = query.Where("user_id = ?", userQuery) @@ -185,13 +190,14 @@ func AdminDownloadsHandler(c echo.Context) error { } func AdminDeleteDownloadHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) downloadID, err := utils.ParseInt64(c.Param("downloadID")) if err != nil { return c.Render(http.StatusOK, "flash", FlashResponse{"download id not found", c.Request().Referer(), "alert-danger"}) } - if err := database.DeleteDownloadByID(downloadID); err != nil { + if err := db.DeleteDownloadByID(downloadID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -199,9 +205,10 @@ func AdminDeleteDownloadHandler(c echo.Context) error { // AdminSettingsHandler ... func AdminSettingsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminSettingsData data.ActiveTab = "settings" - settings := database.GetSettings() + settings := db.GetSettings() data.ProtectHome = settings.ProtectHome data.HomeUsersList = settings.HomeUsersList data.ForceLoginCaptcha = settings.ForceLoginCaptcha @@ -221,7 +228,7 @@ func AdminSettingsHandler(c echo.Context) error { return c.Redirect(http.StatusFound, c.Request().Referer()) } else if formName == "saveSettings" { - settings := database.GetSettings() + settings := db.GetSettings() settings.ProtectHome = utils.DoParseBool(c.Request().PostFormValue("protectHome")) settings.HomeUsersList = utils.DoParseBool(c.Request().PostFormValue("homeUsersList")) settings.ForceLoginCaptcha = utils.DoParseBool(c.Request().PostFormValue("forceLoginCaptcha")) @@ -231,7 +238,7 @@ func AdminSettingsHandler(c echo.Context) error { settings.ForumEnabled = utils.DoParseBool(c.Request().PostFormValue("forumEnabled")) settings.MaybeAuthEnabled = utils.DoParseBool(c.Request().PostFormValue("maybeAuthEnabled")) settings.CaptchaDifficulty = utils.DoParseInt64(c.Request().PostFormValue("captchaDifficulty")) - settings.DoSave() + settings.DoSave(db) config.ProtectHome.Store(settings.ProtectHome) config.HomeUsersList.Store(settings.HomeUsersList) config.ForceLoginCaptcha.Store(settings.ForceLoginCaptcha) @@ -250,12 +257,13 @@ func AdminSettingsHandler(c echo.Context) error { } func AdminHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminData data.ActiveTab = "users" data.Query = strings.TrimSpace(c.QueryParam("q")) likeQuery := "%" + data.Query + "%" - if err := database.DB. + if err := db.DB(). Table("users"). Scopes(func(query *gorm.DB) *gorm.DB { if data.Query != "" { @@ -273,12 +281,13 @@ func AdminHandler(c echo.Context) error { } func SessionsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminSessionsData data.ActiveTab = "sessions" data.Query = c.QueryParam("q") likeQuery := "%" + data.Query + "%" - if err := database.DB. + if err := db.DB(). Table("sessions"). Where("deleted_at IS NULL"). Scopes(func(query *gorm.DB) *gorm.DB { @@ -297,12 +306,13 @@ func SessionsHandler(c echo.Context) error { } func IgnoredHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminIgnoredData data.ActiveTab = "ignored" data.Query = c.QueryParam("q") likeQuery := "%" + data.Query + "%" - if err := database.DB. + if err := db.DB(). Table("ignored_users"). Scopes(func(query *gorm.DB) *gorm.DB { if data.Query != "" { @@ -367,10 +377,11 @@ func BackupHandler(c echo.Context) error { } func AdminAuditsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminAuditsData data.ActiveTab = "audits" - if err := database.DB. + if err := db.DB(). Table("audit_logs"). Scopes(func(query *gorm.DB) *gorm.DB { data.CurrentPage, data.MaxPage, data.AuditLogsCount, query = NewPaginator().Paginate(c, query) @@ -385,12 +396,13 @@ func AdminAuditsHandler(c echo.Context) error { } func AdminRoomsHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminRoomsData data.ActiveTab = "rooms" data.Query = c.QueryParam("q") likeQuery := "%" + data.Query + "%" - if err := database.DB. + if err := db.DB(). Table("chat_rooms"). Scopes(func(query *gorm.DB) *gorm.DB { if data.Query != "" { @@ -408,10 +420,11 @@ func AdminRoomsHandler(c echo.Context) error { } func AdminCaptchaHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data adminCaptchaData data.ActiveTab = "captcha" - if err := database.DB.Table("captcha_requests"). + if err := db.DB().Table("captcha_requests"). Scopes(func(query *gorm.DB) *gorm.DB { data.CurrentPage, data.MaxPage, data.CaptchasCount, query = NewPaginator().Paginate(c, query) return query @@ -426,6 +439,7 @@ func AdminCaptchaHandler(c echo.Context) error { // AdminDeleteUserHandler ... func AdminDeleteUserHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) userID, err := dutils.ParseUserID(c.Param("userID")) if err != nil { return c.Render(http.StatusOK, "flash", @@ -436,41 +450,45 @@ func AdminDeleteUserHandler(c echo.Context) error { FlashResponse{"Root admin cannot be deleted", c.Request().Referer(), "alert-danger"}) } - if err := database.DeleteUserByID(userID); err != nil { + if err := db.DeleteUserByID(userID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) } func IgnoredDeleteHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) userID := dutils.DoParseUserID(c.Request().PostFormValue("user_id")) ignoredUserID := dutils.DoParseUserID(c.Request().PostFormValue("ignored_user_id")) - database.UnIgnoreUser(userID, ignoredUserID) + db.UnIgnoreUser(userID, ignoredUserID) return c.Redirect(http.StatusFound, c.Request().Referer()) } // AdminDeleteRoomHandler ... func AdminDeleteRoomHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) id := dutils.DoParseRoomID(c.Param("roomID")) - database.DeleteChatRoomByID(id) + db.DeleteChatRoomByID(id) return c.Redirect(http.StatusFound, c.Request().Referer()) } func AdminUserSecurityLogsHandler(c echo.Context) error { //authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) userID, err := dutils.ParseUserID(c.Param("userID")) if err != nil { return c.Redirect(http.StatusFound, "/admin/users") } var data settingsSecurityData data.ActiveTab = "security" - data.Logs, _ = database.GetSecurityLogs(userID) + data.Logs, _ = db.GetSecurityLogs(userID) return c.Render(http.StatusOK, "admin.user-security-logs", data) } // AdminEditUserHandler ... func AdminEditUserHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) userID, err := dutils.ParseUserID(c.Param("userID")) if err != nil { return c.Redirect(http.StatusFound, "/admin/users") @@ -479,7 +497,7 @@ func AdminEditUserHandler(c echo.Context) error { if userID == config.RootAdminID && authUser.ID != config.RootAdminID { return c.Redirect(http.StatusFound, "/admin/users") } - user, err := database.GetUserByID(userID) + user, err := db.GetUserByID(userID) if err != nil { return c.Redirect(http.StatusFound, "/admin/users") } @@ -516,7 +534,7 @@ func AdminEditUserHandler(c echo.Context) error { formName := c.Request().PostFormValue("formName") if formName == "reset_tutorial" { user.ChatTutorial = 0 - user.DoSave() + user.DoSave(db) return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -539,7 +557,7 @@ func AdminEditUserHandler(c echo.Context) error { data.ChatColor = c.FormValue("chat_color") data.ChatFont = utils.DoParseInt64(c.FormValue("chat_font")) if data.Username != user.Username { - if err := database.CanRenameTo(user.Username, data.Username); err != nil { + if err := db.CanRenameTo(user.Username, data.Username); err != nil { data.Errors.Username = err.Error() } } @@ -549,7 +567,7 @@ func AdminEditUserHandler(c echo.Context) error { data.RePassword = c.Request().PostFormValue("repassword") data.ApiKey = c.Request().PostFormValue("api_key") if data.Password != "" || data.RePassword != "" { - hashedPassword, err = database.NewPasswordValidator(data.Password).CompareWith(data.RePassword).Hash() + hashedPassword, err = database.NewPasswordValidator(db, data.Password).CompareWith(data.RePassword).Hash() if err != nil { data.Errors.Password = err.Error() } @@ -560,7 +578,7 @@ func AdminEditUserHandler(c echo.Context) error { if hashedPassword != "" { user.LoginAttempts = 0 - if err := user.ChangePassword(hashedPassword); err != nil { + if err := user.ChangePassword(db, hashedPassword); err != nil { data.Errors.Password = err.Error() return c.Render(http.StatusOK, "admin.user-edit", data) } @@ -569,7 +587,7 @@ func AdminEditUserHandler(c echo.Context) error { user.Username = data.Username user.IsAdmin = data.IsAdmin if data.IsHellbanned { - user.HellBan() + user.HellBan(db) managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil)) } user.ApiKey = data.ApiKey @@ -589,18 +607,19 @@ func AdminEditUserHandler(c echo.Context) error { user.Role = data.Role user.ChatColor = data.ChatColor user.ChatFont = data.ChatFont - user.DoSave() + user.DoSave(db) return c.Redirect(http.StatusFound, c.Request().Referer()) } // AdminEditRoomHandler ... func AdminEditRoomHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) roomID, err := dutils.ParseRoomID(c.Param("roomID")) if err != nil { return c.Redirect(http.StatusFound, "/admin/rooms") } - room, err := database.GetChatRoomByID(roomID) + room, err := db.GetChatRoomByID(roomID) if err != nil { return c.Redirect(http.StatusFound, "/admin/rooms") } @@ -618,7 +637,7 @@ func AdminEditRoomHandler(c echo.Context) error { room.IsEphemeral = data.IsEphemeral room.IsListed = data.IsListed - room.DoSave() + room.DoSave(db) return c.Redirect(http.StatusFound, "/admin/rooms") } diff --git a/pkg/web/handlers/api/v1/bangInterceptor.go b/pkg/web/handlers/api/v1/bangInterceptor.go @@ -20,13 +20,13 @@ Chats: Black Hat Chat: ` + config.BhcOnion + ` Forums: CryptBB: ` + config.CryptbbOnion - msg, _ := ProcessRawMessage(message, "", cmd.authUser.ID, cmd.room.ID, nil) + msg, _ := ProcessRawMessage(cmd.db, message, "", cmd.authUser.ID, cmd.room.ID, nil) cmd.zeroMsg(msg) cmd.err = ErrRedirect } func handleRtutoBangCmd(cmd *Command) { cmd.authUser.ChatTutorial = 0 - cmd.authUser.DoSave() + cmd.authUser.DoSave(cmd.db) cmd.err = ErrRedirect } diff --git a/pkg/web/handlers/api/v1/battleship.go b/pkg/web/handlers/api/v1/battleship.go @@ -213,13 +213,14 @@ func (g *BSGame) Shot(pos string) (shipStr string, shipDead, gameEnded bool, err type Battleship struct { sync.Mutex + db *database.DkfDB zeroID database.UserID games map[string]*BSGame } -func NewBattleship() *Battleship { - zeroUser, _ := database.GetUserByUsername(config.NullUsername) - b := &Battleship{zeroID: zeroUser.ID} +func NewBattleship(db *database.DkfDB) *Battleship { + zeroUser, _ := db.GetUserByUsername(config.NullUsername) + b := &Battleship{db: db, zeroID: zeroUser.ID} b.games = make(map[string]*BSGame) // Thread that cleanup inactive games @@ -553,7 +554,7 @@ func (b *Battleship) playMove(roomName string, roomID database.RoomID, roomKey s b.Lock() defer b.Unlock() - user, err := database.GetUserByUsername(enemyUsername) + user, err := b.db.GetUserByUsername(enemyUsername) if err != nil { return errors.New("invalid username") } @@ -586,17 +587,17 @@ func (b *Battleship) playMove(roomName string, roomID database.RoomID, roomKey s } // Delete old messages sent by "0" to the players - if err := database.DB. + if err := b.db.DB(). Where("room_id = ? AND user_id = ? AND (to_user_id = ? OR to_user_id = ?)", roomID, b.zeroID, g.player1.id, g.player2.id). Delete(&database.ChatMessage{}).Error; err != nil { logrus.Error(err) } card1 := g.drawCardFor(0, roomName, isNewGame, shipDead, gameEnded, shipStr, pos) - _, _ = database.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.player1.id) + _, _ = b.db.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.player1.id) card2 := g.drawCardFor(1, roomName, isNewGame, shipDead, gameEnded, shipStr, pos) - _, _ = database.CreateMsg(card2, card2, roomKey, roomID, b.zeroID, &g.player2.id) + _, _ = b.db.CreateMsg(card2, card2, roomKey, roomID, b.zeroID, &g.player2.id) if gameEnded { delete(b.games, gameKey) diff --git a/pkg/web/handlers/api/v1/chess.go b/pkg/web/handlers/api/v1/chess.go @@ -64,13 +64,14 @@ func newChessGame(gameKey string, player1, player2 database.User) *ChessGame { type Chess struct { sync.Mutex + db *database.DkfDB zeroID database.UserID games map[string]*ChessGame } -func NewChess() *Chess { - zeroUser, _ := database.GetUserByUsername(config.NullUsername) - c := &Chess{zeroID: zeroUser.ID} +func NewChess(db *database.DkfDB) *Chess { + zeroUser, _ := db.GetUserByUsername(config.NullUsername) + c := &Chess{db: db, zeroID: zeroUser.ID} c.games = make(map[string]*ChessGame) // Thread that cleanup inactive games @@ -401,8 +402,8 @@ func (b *Chess) NewGame1(roomKey string, roomID database.RoomID, player1, player key := uuid.New().String() g := b.NewGame(key, player1, player2) - zeroUser := dutils.GetZeroUser() - dutils.SendNewChessGameMessages(key, roomKey, roomID, zeroUser, player1, player2) + zeroUser := dutils.GetZeroUser(b.db) + dutils.SendNewChessGameMessages(b.db, key, roomKey, roomID, zeroUser, player1, player2) return g, nil } @@ -471,11 +472,11 @@ func (b *Chess) SendMove(gameKey string, userID database.UserID, g *ChessGame, c // Notify (pm) the opponent that you made a move if opponent.NotifyChessMove { msg := fmt.Sprintf("@%s played %s", you.Username, moveStr) - msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername) - chatMsg, _ := database.CreateMsg(msg, msg, "", config.GeneralRoomID, b.zeroID, &opponent.ID) + msg, _ = colorifyTaggedUsers(msg, b.db.GetUsersByUsername) + chatMsg, _ := b.db.CreateMsg(msg, msg, "", config.GeneralRoomID, b.zeroID, &opponent.ID) go func() { time.Sleep(30 * time.Second) - _ = database.DeleteChatMessageByUUID(chatMsg.UUID) + _ = b.db.DeleteChatMessageByUUID(chatMsg.UUID) }() } @@ -500,7 +501,7 @@ func (b *Chess) playMove(enemyUsername, pos string, authUser database.User, c ec b.Lock() defer b.Unlock() - user, err := database.GetUserByUsername(enemyUsername) + user, err := b.db.GetUserByUsername(enemyUsername) if err != nil { return errors.New("invalid username") } @@ -525,17 +526,17 @@ func (b *Chess) playMove(enemyUsername, pos string, authUser database.User, c ec } // Delete old messages sent by "0" to the players - if err := database.DB. + if err := b.db.DB(). Where("room_id = ? AND user_id = ? AND (to_user_id = ? OR to_user_id = ?)", roomID, b.zeroID, g.Player1.ID, g.Player2.ID). Delete(&database.ChatMessage{}).Error; err != nil { logrus.Error(err) } card1 := g.DrawPlayerCard(roomName, true, false, true) - _, _ = database.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player1.ID) + _, _ = b.db.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player1.ID) card1 = g.DrawPlayerCard(roomName, true, true, true) - _, _ = database.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player2.ID) + _, _ = b.db.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player2.ID) return nil } diff --git a/pkg/web/handlers/api/v1/handlers.go b/pkg/web/handlers/api/v1/handlers.go @@ -83,12 +83,13 @@ var unhideRgx = regexp.MustCompile(`^/unhide (\d{2}:\d{2}:\d{2})$`) func ChatMessagesHandler(c echo.Context) error { authCookie, _ := c.Cookie(hutils.AuthCookieName) authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) roomName := c.Param("roomName") pmOnlyQuery := dutils.DoParsePmDisplayMode(c.QueryParam("pmonly")) mentionsOnlyQuery := utils.DoParseBool(c.QueryParam("mentionsOnly")) - room, err := database.GetChatRoomByName(roomName) + room, err := db.GetChatRoomByName(roomName) if err != nil { return c.NoContent(http.StatusNotFound) } @@ -102,7 +103,7 @@ func ChatMessagesHandler(c echo.Context) error { // Only fill the ignored set if the user does not display the ignored users ("Toggle ignored" chat setting) // and if the user has "Hide ignored users from users lists" enabled (user setting) if !authUser.DisplayIgnored && authUser.HideIgnoredUsersFromList { - ignoredUsers, _ := database.GetIgnoredUsers(authUser.ID) + ignoredUsers, _ := db.GetIgnoredUsers(authUser.ID) for _, ignoredUser := range ignoredUsers { ignoredSet.Insert(ignoredUser.IgnoredUser.Username) } @@ -111,7 +112,7 @@ func ChatMessagesHandler(c echo.Context) error { membersInRoom, membersInChat := managers.ActiveUsers.GetRoomUsers(room, ignoredSet) displayHellbanned := authUser.DisplayHellbanned || authUser.IsHellbanned displayIgnoredMessages := false - msgs, _ := database.GetChatMessages(room.ID, authUser.Username, authUser.ID, pmOnlyQuery, mentionsOnlyQuery, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages) + msgs, _ := db.GetChatMessages(room.ID, authUser.Username, authUser.ID, pmOnlyQuery, mentionsOnlyQuery, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages) if room.IsProtected() { key, err := hutils.GetRoomKeyCookie(c, int64(room.ID)) if err != nil { @@ -123,7 +124,7 @@ func ChatMessagesHandler(c echo.Context) error { } // Update read record - database.UpdateChatReadRecord(authUser.ID, room.ID) + db.UpdateChatReadRecord(authUser.ID, room.ID) var data chatMessagesData @@ -171,11 +172,11 @@ func ChatMessagesHandler(c echo.Context) error { if authCookie != nil { sessionToken = authCookie.Value } - data.InboxCount = global.GetUserNotificationCount(authUser.ID, sessionToken) + data.InboxCount = global.GetUserNotificationCount(db, authUser.ID, sessionToken) - data.ReadMarker, _ = database.GetUserReadMarker(authUser.ID, room.ID) - data.OfficialRooms, _ = database.GetOfficialChatRooms1(authUser.ID) - data.SubscribedRooms, _ = database.GetUserRoomSubscriptions(authUser.ID) + data.ReadMarker, _ = db.GetUserReadMarker(authUser.ID, room.ID) + data.OfficialRooms, _ = db.GetOfficialChatRooms1(authUser.ID) + data.SubscribedRooms, _ = db.GetUserRoomSubscriptions(authUser.ID) bools := []bool{authUser.DisplayDeleteButton} if authUser.IsModerator() { @@ -204,10 +205,11 @@ func ChatMessagesHandler(c echo.Context) error { func RoomNotifierHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) roomID := dutils.DoParseRoomID(c.Param("roomID")) lastKnownDate := c.Request().PostFormValue("last_known_date") - room, err := database.GetChatRoomByID(roomID) + room, err := db.GetChatRoomByID(roomID) if err != nil { return c.NoContent(http.StatusNotFound) } @@ -219,7 +221,7 @@ func RoomNotifierHandler(c echo.Context) error { displayHellbanned := authUser.DisplayHellbanned || authUser.IsHellbanned displayIgnoredMessages := false - msgs, _ := database.GetChatMessages(roomID, authUser.Username, authUser.ID, database.PmNoFilter, false, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages) + msgs, _ := db.GetChatMessages(roomID, authUser.Username, authUser.ID, database.PmNoFilter, false, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages) if room.IsProtected() { key, err := hutils.GetRoomKeyCookie(c, int64(room.ID)) if err != nil { @@ -233,7 +235,7 @@ func RoomNotifierHandler(c echo.Context) error { var data testData data.NewMessageSound, data.PmSound, data.TaggedSound, data.LastMessageCreatedAt = shouldPlaySound(authUser, lastKnownDate, msgs) - data.InboxCount = database.GetUserInboxMessagesCount(authUser.ID) + data.InboxCount = db.GetUserInboxMessagesCount(authUser.ID) return c.JSON(http.StatusOK, data) } @@ -269,15 +271,16 @@ func shouldPlaySound(authUser *database.User, lastKnownDate string, msgs []datab func UserHellbanHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) userID := dutils.DoParseUserID(c.Param("userID")) - user, err := database.GetUserByID(userID) + user, err := db.GetUserByID(userID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } if !user.IsHellbanned { if authUser.IsAdmin || !user.IsModerator() { - database.NewAudit(*authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID)) - user.HellBan() + db.NewAudit(*authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID)) + user.HellBan(db) managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil)) } } @@ -286,14 +289,15 @@ func UserHellbanHandler(c echo.Context) error { func UserUnHellbanHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) userID := dutils.DoParseUserID(c.Param("userID")) - user, err := database.GetUserByID(userID) + user, err := db.GetUserByID(userID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } if user.IsHellbanned { - database.NewAudit(*authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID)) - user.UnHellBan() + db.NewAudit(*authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID)) + user.UnHellBan(db) managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil)) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -301,67 +305,73 @@ func UserUnHellbanHandler(c echo.Context) error { func KickHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) userID := dutils.DoParseUserID(c.Param("userID")) - user, err := database.GetUserByID(userID) + user, err := db.GetUserByID(userID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } if user.IsModerator() { return c.Redirect(http.StatusFound, c.Request().Referer()) } - _ = dutils.SilentKick(user, *authUser) + _ = dutils.SilentKick(db, user, *authUser) return c.Redirect(http.StatusFound, c.Request().Referer()) } func SubscribeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) roomName := c.Param("roomName") - room, err := database.GetChatRoomByName(roomName) + room, err := db.GetChatRoomByName(roomName) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } - _ = database.SubscribeToRoom(authUser.ID, room.ID) + _ = db.SubscribeToRoom(authUser.ID, room.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } func UnsubscribeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) roomName := c.Param("roomName") - room, err := database.GetChatRoomByName(roomName) + room, err := db.GetChatRoomByName(roomName) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } - _ = database.UnsubscribeFromRoom(authUser.ID, room.ID) + _ = db.UnsubscribeFromRoom(authUser.ID, room.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } func ThreadSubscribeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) - thread, err := database.GetForumThreadByUUID(threadUUID) + thread, err := db.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } - _ = database.SubscribeToForumThread(authUser.ID, thread.ID) + _ = db.SubscribeToForumThread(authUser.ID, thread.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } func ThreadUnsubscribeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) - thread, err := database.GetForumThreadByUUID(threadUUID) + thread, err := db.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } - _ = database.UnsubscribeFromForumThread(authUser.ID, thread.ID) + _ = db.UnsubscribeFromForumThread(authUser.ID, thread.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } func ChatMessageReactionHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) messageUUID := c.Request().PostFormValue("message_uuid") var msg database.ChatMessage - if err := database.DB.Where("uuid = ?", messageUUID).Preload("User").Preload("Room").First(&msg).Error; err != nil { + if err := db.DB().Where("uuid = ?", messageUUID).Preload("User").Preload("Room").First(&msg).Error; err != nil { return err } reaction := utils.DoParseInt64(c.Request().PostFormValue("reaction_id")) @@ -369,8 +379,8 @@ func ChatMessageReactionHandler(c echo.Context) error { return errors.New("invalid reaction") } - if err := database.CreateChatReaction(authUser.ID, msg.ID, reaction); err != nil { - _ = database.DeleteReaction(authUser.ID, msg.ID, reaction) + if err := db.CreateChatReaction(authUser.ID, msg.ID, reaction); err != nil { + _ = db.DeleteReaction(authUser.ID, msg.ID, reaction) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -378,10 +388,11 @@ func ChatMessageReactionHandler(c echo.Context) error { func ChatDeleteMessageHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) messageUUID := c.Param("messageUUID") var msg database.ChatMessage - if err := database.DB.Where("uuid = ?", messageUUID). + if err := db.DB().Where("uuid = ?", messageUUID). Preload("User"). Preload("Room"). First(&msg).Error; err != nil { @@ -404,7 +415,7 @@ func ChatDeleteMessageHandler(c echo.Context) error { msg.User.Username, msg.User.ID, utils.TruncStr(msg.RawMessage, 75, "…")) - database.NewAudit(*authUser, auditMsg) + db.NewAudit(*authUser, auditMsg) } } } else if msg.Room.OwnerUserID != nil && authUser.ID == *msg.Room.OwnerUserID { // Room owner can delete messages in its room @@ -414,12 +425,12 @@ func ChatDeleteMessageHandler(c echo.Context) error { if msg.RoomID == config.GeneralRoomID && msg.ToUserID == nil { authUser.GeneralMessagesCount-- - authUser.DoSave() + authUser.DoSave(db) } // If we delete message manually, also delete linked inbox if any - _ = database.DeleteChatInboxMessageByChatMessageID(msg.ID) - if err := database.DeleteChatMessageByUUID(messageUUID); err != nil { + _ = db.DeleteChatInboxMessageByChatMessageID(msg.ID) + if err := db.DeleteChatMessageByUUID(messageUUID); err != nil { logrus.Error(err) } @@ -428,8 +439,9 @@ func ChatDeleteMessageHandler(c echo.Context) error { func ClubDeleteMessageHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID"))) - msg, err := database.GetForumMessage(messageID) + msg, err := db.GetForumMessage(messageID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -442,7 +454,7 @@ func ClubDeleteMessageHandler(c echo.Context) error { return c.Redirect(http.StatusFound, c.Request().Referer()) } - if err := database.DeleteForumMessageByID(messageID); err != nil { + if err := db.DeleteForumMessageByID(messageID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -450,27 +462,29 @@ func ClubDeleteMessageHandler(c echo.Context) error { func DeleteNotificationHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) notificationID := utils.DoParseInt64(c.Param("notificationID")) var msg database.Notification - if err := database.DB.Where("ID = ? AND user_id = ?", notificationID, authUser.ID).First(&msg).Error; err != nil { + if err := db.DB().Where("ID = ? AND user_id = ?", notificationID, authUser.ID).First(&msg).Error; err != nil { logrus.Error(err) return c.Redirect(http.StatusFound, c.Request().Referer()) } - if err := database.DeleteNotificationByID(notificationID); err != nil { + if err := db.DeleteNotificationByID(notificationID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/settings/inbox") } func DeleteSessionNotificationHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) authCookie, _ := c.Cookie(hutils.AuthCookieName) sessionNotificationID := utils.DoParseInt64(c.Param("sessionNotificationID")) var msg database.SessionNotification - if err := database.DB.Where("ID = ? AND session_token = ?", sessionNotificationID, authCookie.Value).First(&msg).Error; err != nil { + if err := db.DB().Where("ID = ? AND session_token = ?", sessionNotificationID, authCookie.Value).First(&msg).Error; err != nil { logrus.Error(err) return c.Redirect(http.StatusFound, c.Request().Referer()) } - if err := database.DeleteSessionNotificationByID(sessionNotificationID); err != nil { + if err := db.DeleteSessionNotificationByID(sessionNotificationID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/settings/inbox") @@ -478,13 +492,14 @@ func DeleteSessionNotificationHandler(c echo.Context) error { func ChatInboxDeleteMessageHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) messageID := utils.DoParseInt64(c.Param("messageID")) var msg database.ChatInboxMessage - if err := database.DB.Where("ID = ? AND to_user_id = ?", messageID, authUser.ID).First(&msg).Error; err != nil { + if err := db.DB().Where("ID = ? AND to_user_id = ?", messageID, authUser.ID).First(&msg).Error; err != nil { logrus.Error(err) return c.Redirect(http.StatusFound, "/settings/inbox") } - if err := database.DeleteChatInboxMessageByID(messageID); err != nil { + if err := db.DeleteChatInboxMessageByID(messageID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/settings/inbox") @@ -493,13 +508,14 @@ func ChatInboxDeleteMessageHandler(c echo.Context) error { func ChatInboxDeleteAllMessageHandler(c echo.Context) error { authCookie, _ := c.Cookie(hutils.AuthCookieName) authUser := c.Get("authUser").(*database.User) - if err := database.DeleteAllChatInbox(authUser.ID); err != nil { + db := c.Get("database").(*database.DkfDB) + if err := db.DeleteAllChatInbox(authUser.ID); err != nil { logrus.Error(err) } - if err := database.DeleteAllNotifications(authUser.ID); err != nil { + if err := db.DeleteAllNotifications(authUser.ID); err != nil { logrus.Error(err) } - if err := database.DeleteAllSessionNotifications(authCookie.Value); err != nil { + if err := db.DeleteAllSessionNotifications(authCookie.Value); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/settings/inbox") @@ -513,6 +529,7 @@ func GetCaptchaHandler(c echo.Context) error { func CaptchaSolverHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) captchaB64 := c.Request().PostFormValue("captcha") answer, err := captcha.SolveBase64(captchaB64) if err != nil { @@ -524,7 +541,7 @@ func CaptchaSolverHandler(c echo.Context) error { CaptchaImg: captchaB64, Answer: answer, } - if err := database.DB.Create(&captchaReq).Error; err != nil { + if err := db.DB().Create(&captchaReq).Error; err != nil { logrus.Error(err.Error()) } return c.JSON(http.StatusOK, map[string]any{"answer": answer}) @@ -532,11 +549,12 @@ func CaptchaSolverHandler(c echo.Context) error { func ChessHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) roomName := c.Request().PostFormValue("room") enemyUsername := c.Request().PostFormValue("enemyUsername") pos := c.Request().PostFormValue("move") redirectURL := "/api/v1/chat/messages/" + roomName - room, roomKey, err := dutils.GetRoomAndKey(c, roomName) + room, roomKey, err := dutils.GetRoomAndKey(db, c, roomName) if err != nil { return c.Redirect(http.StatusFound, redirectURL+"?error="+err.Error()+"&errorTs="+utils.FormatInt64(time.Now().Unix())) } @@ -547,10 +565,11 @@ func ChessHandler(c echo.Context) error { } func WerewolfHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) roomName := "werewolf" origMessage := c.Request().PostFormValue("message") redirectURL := "/api/v1/chat/messages/" + roomName - room, roomKey, err := dutils.GetRoomAndKey(c, roomName) + room, roomKey, err := dutils.GetRoomAndKey(db, c, roomName) if err != nil { return c.Redirect(http.StatusFound, redirectURL+"?error="+err.Error()+"&errorTs="+utils.FormatInt64(time.Now().Unix())) } @@ -564,11 +583,12 @@ func WerewolfHandler(c echo.Context) error { func BattleshipHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) roomName := c.Request().PostFormValue("room") enemyUsername := c.Request().PostFormValue("enemyUsername") pos := c.Request().PostFormValue("move") redirectURL := "/api/v1/chat/messages/" + roomName - room, roomKey, err := dutils.GetRoomAndKey(c, roomName) + room, roomKey, err := dutils.GetRoomAndKey(db, c, roomName) if err != nil { return c.Redirect(http.StatusFound, redirectURL+"?error="+err.Error()+"&errorTs="+utils.FormatInt64(time.Now().Unix())) } diff --git a/pkg/web/handlers/api/v1/msgInterceptor.go b/pkg/web/handlers/api/v1/msgInterceptor.go @@ -23,32 +23,32 @@ func (i MsgInterceptor) InterceptMsg(cmd *Command) { return } - html, taggedUsersIDsMap := ProcessRawMessage(cmd.message, cmd.roomKey, cmd.authUser.ID, cmd.room.ID, cmd.upload) + html, taggedUsersIDsMap := ProcessRawMessage(cmd.db, cmd.message, cmd.roomKey, cmd.authUser.ID, cmd.room.ID, cmd.upload) toUserID := database.UserPtrID(cmd.toUser) - msgID, _ := database.CreateOrEditMessage(cmd.editMsg, html, cmd.origMessage, cmd.roomKey, cmd.room.ID, cmd.fromUserID, toUserID, cmd.upload, cmd.groupID, cmd.hellbanMsg, cmd.modMsg, cmd.systemMsg) + msgID, _ := cmd.db.CreateOrEditMessage(cmd.editMsg, html, cmd.origMessage, cmd.roomKey, cmd.room.ID, cmd.fromUserID, toUserID, cmd.upload, cmd.groupID, cmd.hellbanMsg, cmd.modMsg, cmd.systemMsg) if !cmd.skipInboxes { - sendInboxes(cmd.room, cmd.authUser, cmd.toUser, msgID, cmd.groupID, html, cmd.modMsg, taggedUsersIDsMap) + sendInboxes(cmd.db, cmd.room, cmd.authUser, cmd.toUser, msgID, cmd.groupID, html, cmd.modMsg, taggedUsersIDsMap) } // Count public messages in #general room if cmd.room.ID == config.GeneralRoomID && cmd.toUser == nil { cmd.authUser.GeneralMessagesCount++ - generalRoomKarma(cmd.authUser) - cmd.authUser.DoSave() + generalRoomKarma(cmd.db, cmd.authUser) + cmd.authUser.DoSave(cmd.db) } // Update chat read marker - database.UpdateChatReadMarker(cmd.authUser.ID, cmd.room.ID) + cmd.db.UpdateChatReadMarker(cmd.authUser.ID, cmd.room.ID) // Update user activity isPM := cmd.toUser != nil updateUserActivity(isPM, cmd.modMsg, cmd.room, cmd.authUser) } -func generalRoomKarma(authUser *database.User) { +func generalRoomKarma(db *database.DkfDB, authUser *database.User) { // Hellban users ain't getting karma if authUser.IsHellbanned { return @@ -56,30 +56,30 @@ func generalRoomKarma(authUser *database.User) { messagesCount := authUser.GeneralMessagesCount if messagesCount%100 == 0 { description := fmt.Sprintf("sent %d messages", messagesCount) - authUser.IncrKarma(1, description) + authUser.IncrKarma(db, 1, description) } else if messagesCount == 20 { - authUser.IncrKarma(1, "first 20 messages sent") + authUser.IncrKarma(db, 1, "first 20 messages sent") } } // ProcessRawMessage return the new html, and a map of tagged users used for notifications // This function takes an "unsafe" user input "in", and return html which will be safe to render. -func ProcessRawMessage(in, roomKey string, authUserID database.UserID, roomID database.RoomID, upload *database.Upload) (string, map[database.UserID]database.User) { - html, quoted := convertQuote(in, roomKey, roomID) // Get raw quote text which is not safe to render - html = html2.EscapeString(html) // Makes user input safe to render +func ProcessRawMessage(db *database.DkfDB, in, roomKey string, authUserID database.UserID, roomID database.RoomID, upload *database.Upload) (string, map[database.UserID]database.User) { + html, quoted := convertQuote(db, in, roomKey, roomID) // Get raw quote text which is not safe to render + html = html2.EscapeString(html) // Makes user input safe to render // All html generated from this point on shall be safe to render. - html = convertPGPClearsignToFile(html, authUserID) - html = convertPGPMessageToFile(html, authUserID) - html = convertPGPPublicKeyToFile(html, authUserID) - html = convertAgeMessageToFile(html, authUserID) + html = convertPGPClearsignToFile(db, html, authUserID) + html = convertPGPMessageToFile(db, html, authUserID) + html = convertPGPPublicKeyToFile(db, html, authUserID) + html = convertAgeMessageToFile(db, html, authUserID) html = convertLinksWithoutScheme(html) html = convertMarkdown(html) html = convertBangShortcuts(html) - html = convertArchiveLinks(html, roomID, authUserID) - html = convertLinks(html, database.GetUserByUsername) + html = convertArchiveLinks(db, html, roomID, authUserID) + html = convertLinks(html, db.GetUserByUsername) html = linkDefaultRooms(html) - html, taggedUsersIDsMap := colorifyTaggedUsers(html, database.GetUsersByUsername) - html = linkRoomTags(html) + html, taggedUsersIDsMap := colorifyTaggedUsers(html, db.GetUsersByUsername) + html = linkRoomTags(db, html) html = emojiReplacer.Replace(html) html = styleQuote(html, quoted) html = appendUploadLink(html, upload) @@ -89,7 +89,7 @@ func ProcessRawMessage(in, roomKey string, authUserID database.UserID, roomID da return html, taggedUsersIDsMap } -func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID int64, groupID *database.GroupID, html string, modMsg bool, +func sendInboxes(db *database.DkfDB, room database.ChatRoom, authUser, toUser *database.User, msgID int64, groupID *database.GroupID, html string, modMsg bool, taggedUsersIDsMap map[database.UserID]database.User) { // Only have chat inbox for unencrypted messages if room.IsProtected() { @@ -104,13 +104,13 @@ func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID return } - blacklistedBy, _ := database.GetPmBlacklistedByUsers(authUser.ID) + blacklistedBy, _ := db.GetPmBlacklistedByUsers(authUser.ID) blacklistedByMap := make(map[database.UserID]struct{}) for _, b := range blacklistedBy { blacklistedByMap[b.UserID] = struct{}{} } - ignoredBy, _ := database.GetIgnoredByUsers(authUser.ID) + ignoredBy, _ := db.GetIgnoredByUsers(authUser.ID) ignoredByMap := make(map[database.UserID]struct{}) for _, b := range ignoredBy { ignoredByMap[b.UserID] = struct{}{} @@ -126,7 +126,7 @@ func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID if _, ok := ignoredByMap[user.ID]; ok { return } - database.CreateInboxMessage(html, room.ID, authUser.ID, user.ID, isPM, modCh, &msgID) + db.CreateInboxMessage(html, room.ID, authUser.ID, user.ID, isPM, modCh, &msgID) } } @@ -147,7 +147,7 @@ func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID } } else if groupID != nil { // Only tags other people in the group for _, user := range taggedUsersIDsMap { - if database.IsUserInGroupByID(user.ID, *groupID) { + if db.IsUserInGroupByID(user.ID, *groupID) { sendInbox(user, false, false) } } diff --git a/pkg/web/handlers/api/v1/slashInterceptor.go b/pkg/web/handlers/api/v1/slashInterceptor.go @@ -155,7 +155,7 @@ func handleModeratorGroupCmd(c *Command) (handled bool) { func handleListModeratorsCmd(c *Command) (handled bool) { if c.message == "/moderators" || c.message == "/mods" { - mods, err := database.GetModeratorsUsers() + mods, err := c.db.GetModeratorsUsers() if err != nil { c.err = err return true @@ -231,11 +231,11 @@ func handleKickKeepSilentCmd(c *Command) (handled bool) { } func kickCmd(c *Command, username string, purge, silent bool) error { - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { return ErrUsernameNotFound } - return dutils.Kick(user, *c.authUser, purge, silent) + return dutils.Kick(c.db, user, *c.authUser, purge, silent) } var ErrUsernameNotFound = errors.New("username not found") @@ -244,7 +244,7 @@ var ErrUnauthorized = errors.New("unauthorized") func handleUnkickCmd(c *Command) (handled bool) { if m := unkickRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrUsernameNotFound return true @@ -253,12 +253,12 @@ func handleUnkickCmd(c *Command) (handled bool) { c.err = errors.New("user already not kicked") return true } - database.NewAudit(*c.authUser, fmt.Sprintf("unkick %s #%d", user.Username, user.ID)) + c.db.NewAudit(*c.authUser, fmt.Sprintf("unkick %s #%d", user.Username, user.ID)) user.Verified = true - user.DoSave() + user.DoSave(c.db) // Display unkick message - database.CreateUnkickMsg(user, *c.authUser) + c.db.CreateUnkickMsg(user, *c.authUser) c.err = ErrRedirect return true @@ -269,15 +269,15 @@ func handleUnkickCmd(c *Command) (handled bool) { func handleForceCaptchaCmd(c *Command) (handled bool) { if m := forceCaptchaRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrUsernameNotFound return true } if c.authUser.IsAdmin || !user.IsModerator() || c.authUser.Username == username { - database.NewAudit(*c.authUser, fmt.Sprintf("force captcha %s #%d", user.Username, user.ID)) + c.db.NewAudit(*c.authUser, fmt.Sprintf("force captcha %s #%d", user.Username, user.ID)) user.CaptchaRequired = true - user.DoSave() + user.DoSave(c.db) } c.err = ErrRedirect return true @@ -288,7 +288,7 @@ func handleForceCaptchaCmd(c *Command) (handled bool) { func handleLogoutCmd(c *Command) (handled bool) { if m := logoutRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrUsernameNotFound return true @@ -298,9 +298,9 @@ func handleLogoutCmd(c *Command) (handled bool) { return true } if c.authUser.IsAdmin || !user.IsModerator() { - database.NewAudit(*c.authUser, fmt.Sprintf("logout %s #%d", user.Username, user.ID)) + c.db.NewAudit(*c.authUser, fmt.Sprintf("logout %s #%d", user.Username, user.ID)) - _ = database.DeleteUserSessions(user.ID) + _ = c.db.DeleteUserSessions(user.ID) // Remove user from the user cache managers.ActiveUsers.RemoveUser(user.ID) @@ -315,7 +315,7 @@ func handleLogoutCmd(c *Command) (handled bool) { func handleResetTutorialCmd(c *Command) (handled bool) { if m := rtutoRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrUsernameNotFound return true @@ -325,9 +325,9 @@ func handleResetTutorialCmd(c *Command) (handled bool) { return true } if c.authUser.IsAdmin || !user.IsModerator() { - database.NewAudit(*c.authUser, fmt.Sprintf("rtuto %s #%d", user.Username, user.ID)) + c.db.NewAudit(*c.authUser, fmt.Sprintf("rtuto %s #%d", user.Username, user.ID)) user.ChatTutorial = 0 - user.DoSave() + user.DoSave(c.db) } c.err = ErrRedirect return true @@ -338,7 +338,7 @@ func handleResetTutorialCmd(c *Command) (handled bool) { func handleHellbanCmd(c *Command) (handled bool) { if m := hellbanRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrUsernameNotFound return true @@ -347,8 +347,8 @@ func handleHellbanCmd(c *Command) (handled bool) { c.err = ErrUnauthorized return true } - database.NewAudit(*c.authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID)) - user.HellBan() + c.db.NewAudit(*c.authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID)) + user.HellBan(c.db) managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil)) c.err = ErrRedirect @@ -360,7 +360,7 @@ func handleHellbanCmd(c *Command) (handled bool) { func handleUnhellbanCmd(c *Command) (handled bool) { if m := unhellbanRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrUsernameNotFound return true @@ -369,8 +369,8 @@ func handleUnhellbanCmd(c *Command) (handled bool) { c.err = ErrUnauthorized return true } - database.NewAudit(*c.authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID)) - user.UnHellBan() + c.db.NewAudit(*c.authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID)) + user.UnHellBan(c.db) managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil)) c.err = ErrRedirect @@ -399,9 +399,9 @@ func handleHbmtCmd(c *Command) (handled bool) { if m := hbmtRgx.FindStringSubmatch(c.message); len(m) == 2 { date := m[1] if dt, err := utils.ParsePrevDatetimeAt(date, clockwork.NewRealClock()); err == nil { - if msg, err := database.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil { + if msg, err := c.db.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil { msg.IsHellbanned = !msg.IsHellbanned - msg.DoSave() + msg.DoSave(c.db) } else { c.err = errors.New("no message found at this timestamp") return true @@ -418,7 +418,7 @@ func handleDiceCmd(c *Command) (handled bool) { dice := utils.RandInt(1, 6) raw := fmt.Sprintf(`rolling dice for @%s ... "%d"`, c.authUser.Username, dice) msg := fmt.Sprintf(`rolling dice for @%s ... "<span style="color: white;">%d</span>"`, c.authUser.Username, dice) - msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername) + msg, _ = colorifyTaggedUsers(msg, c.db.GetUsersByUsername) go func() { time.Sleep(time.Second) c.zeroPublicMsg(raw, msg) @@ -456,7 +456,7 @@ func handleRandCmd(c *Command) (handled bool) { dice = utils.RandInt(min, max) raw := fmt.Sprintf(`rolling dice for @%s ... "%d"`, c.authUser.Username, dice) msg := fmt.Sprintf(`rolling dice for @%s ... "<span style="color: white;">%d</span>"`, c.authUser.Username, dice) - msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername) + msg, _ = colorifyTaggedUsers(msg, c.db.GetUsersByUsername) go func() { time.Sleep(time.Second) c.zeroPublicMsg(raw, msg) @@ -473,7 +473,7 @@ func handleChoiceCmd(c *Command) (handled bool) { answer := utils.RandChoice(words) raw := fmt.Sprintf(`@%s choice %s ... "%s"`, c.authUser.Username, words, answer) msg := fmt.Sprintf(`@%s choice %s ... "<span style="color: white;">%s</span>"`, c.authUser.Username, words, answer) - msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername) + msg, _ = colorifyTaggedUsers(msg, c.db.GetUsersByUsername) go func() { time.Sleep(time.Second) c.zeroPublicMsg(raw, msg) @@ -533,7 +533,7 @@ func handleHasherCmd(c *Command, prefix string, fn func([]byte) string) (handled func handleRmGroupCmd(c *Command) (handled bool) { if m := rmGroupRgx.FindStringSubmatch(c.message); len(m) == 2 { groupName := m[1] - if err := database.DeleteChatRoomGroup(c.room.ID, groupName); err != nil { + if err := c.db.DeleteChatRoomGroup(c.room.ID, groupName); err != nil { c.err = err return true } @@ -546,13 +546,13 @@ func handleRmGroupCmd(c *Command) (handled bool) { func handleLockGroupCmd(c *Command) (handled bool) { if m := lockGroupRgx.FindStringSubmatch(c.message); len(m) == 2 { groupName := m[1] - group, err := database.GetRoomGroupByName(c.room.ID, groupName) + group, err := c.db.GetRoomGroupByName(c.room.ID, groupName) if err != nil { c.err = err return true } group.Locked = true - group.DoSave() + group.DoSave(c.db) c.err = ErrRedirect return true } @@ -562,13 +562,13 @@ func handleLockGroupCmd(c *Command) (handled bool) { func handleUnlockGroupCmd(c *Command) (handled bool) { if m := unlockGroupRgx.FindStringSubmatch(c.message); len(m) == 2 { groupName := m[1] - group, err := database.GetRoomGroupByName(c.room.ID, groupName) + group, err := c.db.GetRoomGroupByName(c.room.ID, groupName) if err != nil { c.err = err return true } group.Locked = false - group.DoSave() + group.DoSave(c.db) c.err = ErrRedirect return true } @@ -578,12 +578,12 @@ func handleUnlockGroupCmd(c *Command) (handled bool) { func handleGroupUsersCmd(c *Command) (handled bool) { if m := groupUsersRgx.FindStringSubmatch(c.message); len(m) == 2 { groupName := m[1] - group, err := database.GetRoomGroupByName(c.room.ID, groupName) + group, err := c.db.GetRoomGroupByName(c.room.ID, groupName) if err != nil { c.err = err return true } - users, err := database.GetRoomGroupUsers(c.room.ID, group.ID) + users, err := c.db.GetRoomGroupUsers(c.room.ID, group.ID) sort.Slice(users, func(i, j int) bool { return users[i].User.Username < users[j].User.Username }) @@ -604,7 +604,7 @@ func handleGroupUsersCmd(c *Command) (handled bool) { func handleListGroupsCmd(c *Command) (handled bool) { if c.message == "/groups" { - groups, err := database.GetRoomGroups(c.room.ID) + groups, err := c.db.GetRoomGroups(c.room.ID) if err != nil { c.err = err return true @@ -628,17 +628,17 @@ func handleGroupAddUserCmd(c *Command) (handled bool) { if m := groupAddUserRgx.FindStringSubmatch(c.message); len(m) == 3 { groupName := m[1] username := m[2] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = err return true } - group, err := database.GetRoomGroupByName(c.room.ID, groupName) + group, err := c.db.GetRoomGroupByName(c.room.ID, groupName) if err != nil { c.err = err return true } - _, err = database.AddUserToRoomGroup(c.room.ID, group.ID, user.ID) + _, err = c.db.AddUserToRoomGroup(c.room.ID, group.ID, user.ID) if err != nil { c.err = err return true @@ -655,17 +655,17 @@ func handleGroupRmUserCmd(c *Command) (handled bool) { if m := groupRmUserRgx.FindStringSubmatch(c.message); len(m) == 3 { groupName := m[1] username := m[2] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = err return true } - group, err := database.GetRoomGroupByName(c.room.ID, groupName) + group, err := c.db.GetRoomGroupByName(c.room.ID, groupName) if err != nil { c.err = err return true } - err = database.RmUserFromRoomGroup(c.room.ID, group.ID, user.ID) + err = c.db.RmUserFromRoomGroup(c.room.ID, group.ID, user.ID) if err != nil { c.err = err return true @@ -681,7 +681,7 @@ func handleGroupRmUserCmd(c *Command) (handled bool) { func handleSetModeWhitelistCmd(c *Command) (handled bool) { if c.message == "/mode user-whitelist" { c.room.Mode = database.UserWhitelistRoomMode - c.room.DoSave() + c.room.DoSave(c.db) c.message = `room mode set to "user whitelist"` c.receivePM() return true @@ -692,7 +692,7 @@ func handleSetModeWhitelistCmd(c *Command) (handled bool) { func handleSetModeStandardCmd(c *Command) (handled bool) { if c.message == "/mode standard" { c.room.Mode = database.NormalRoomMode - c.room.DoSave() + c.room.DoSave(c.db) c.message = `room mode set to "standard"` c.receivePM() return true @@ -703,12 +703,12 @@ func handleSetModeStandardCmd(c *Command) (handled bool) { func handleGetRoomWhitelistCmd(c *Command) (handled bool) { if m := whitelistUserRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.message = fmt.Sprintf(`username "%s" not found`, username) } else { - if _, err := database.WhitelistUser(c.room.ID, user.ID); err != nil { - if err := database.DeWhitelistUser(c.room.ID, user.ID); err != nil { + if _, err := c.db.WhitelistUser(c.room.ID, user.ID); err != nil { + if err := c.db.DeWhitelistUser(c.room.ID, user.ID); err != nil { c.message = fmt.Sprintf("failed to toggle @%s in whitelist", user.Username) } else { c.message = fmt.Sprintf("@%s removed from whitelist", user.Username) @@ -726,7 +726,7 @@ func handleGetRoomWhitelistCmd(c *Command) (handled bool) { func handleAddGroupCmd(c *Command) (handled bool) { if m := addGroupRgx.FindStringSubmatch(c.message); len(m) == 2 { name := m[1] - _, err := database.CreateChatRoomGroup(c.room.ID, name, "#fff") + _, err := c.db.CreateChatRoomGroup(c.room.ID, name, "#fff") if err != nil { c.err = err return true @@ -739,7 +739,7 @@ func handleAddGroupCmd(c *Command) (handled bool) { func handleWhitelistCmd(c *Command) (handled bool) { if c.message == "/whitelist" || c.message == "/wl" { - whitelistedUsers, _ := database.GetWhitelistedUsers(c.room.ID) + whitelistedUsers, _ := c.db.GetWhitelistedUsers(c.room.ID) if len(whitelistedUsers) > 0 { usernames := make([]string, 0) for _, whitelistedUser := range whitelistedUsers { @@ -786,7 +786,7 @@ func handleEditCmd(c *Command) (handled bool) { newMsg := m[2] if dt, err := utils.ParsePrevDatetimeAt(date, clockwork.NewRealClock()); err == nil { if time.Since(dt) <= config.EditMessageTimeLimit { - if msg, err := database.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil { + if msg, err := c.db.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil { c.editMsg = &msg c.origMessage = newMsg c.message = newMsg @@ -794,7 +794,7 @@ func handleEditCmd(c *Command) (handled bool) { // If we're editing a message which contains a link to an uploaded file, // we need to re-add the link to the html. if msg.UploadID != nil { - if newUpload, err := database.GetUploadByID(*msg.UploadID); err == nil { + if newUpload, err := c.db.GetUploadByID(*msg.UploadID); err == nil { c.upload = &newUpload } } @@ -820,7 +820,7 @@ func handleEditCmd(c *Command) (handled bool) { func handleEditLastCmd(c *Command) (handled bool) { if c.message == "/e" { - msg, err := database.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID) + msg, err := c.db.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID) if err != nil { return true } @@ -834,16 +834,16 @@ func handleEditLastCmd(c *Command) (handled bool) { var ErrPMDenied = errors.New("you cannot pm/inbox this user") var Err20Msgs = errors.New("you need 20 public messages to unlock PMs/Inbox; or be whitelisted") -func canUserInboxOther(user, other database.User) error { +func canUserInboxOther(db *database.DkfDB, user, other database.User) error { doesNotMatter := false - _, err := canUserPmOther(user, other, doesNotMatter) + _, err := canUserPmOther(db, user, other, doesNotMatter) return err } -func canUserPmOther(user, other database.User, roomIsPrivate bool) (skipInbox bool, err error) { +func canUserPmOther(db *database.DkfDB, user, other database.User, roomIsPrivate bool) (skipInbox bool, err error) { errPMDenied := ErrPMDenied - if database.IsUserPmWhitelisted(user.ID, other.ID) { + if db.IsUserPmWhitelisted(user.ID, other.ID) { return false, nil } @@ -863,7 +863,7 @@ func canUserPmOther(user, other database.User, roomIsPrivate bool) (skipInbox bo } // User on blacklist cannot PM/Inbox - if database.IsUserPmBlacklisted(user.ID, other.ID) { + if db.IsUserPmBlacklisted(user.ID, other.ID) { return false, errPMDenied } // Other doesn't want PM from new users @@ -888,13 +888,13 @@ func handlePMCmd(c *Command) (handled bool) { return handlePm0(c, newMsg) } - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = errors.New("invalid username") return true } - c.skipInboxes, c.err = canUserPmOther(*c.authUser, user, c.room.IsOwned()) + c.skipInboxes, c.err = canUserPmOther(c.db, *c.authUser, user, c.room.IsOwned()) if c.err != nil { return true } @@ -988,7 +988,7 @@ func handlePm0(c *Command, msg string) (handled bool) { // If we sent a clearsign file to @0, the bot will reply with information about the signature if c.upload.FileSize < config.MaxFileSizeBeforeDownload { - if file, err := database.GetUploadByFileName(c.upload.FileName); err == nil { + if file, err := c.db.GetUploadByFileName(c.upload.FileName); err == nil { if _, by, err := file.GetContent(); err == nil { if b, _ := clearsign.Decode(by); b != nil { if p, err := packet.Read(b.ArmoredSignature.Body); err == nil { @@ -1027,7 +1027,7 @@ func handlePm0(c *Command, msg string) (handled bool) { func handleSubscribeCmd(c *Command) (handled bool) { if c.message == "/subscribe" { - _ = database.SubscribeToRoom(c.authUser.ID, c.room.ID) + _ = c.db.SubscribeToRoom(c.authUser.ID, c.room.ID) c.err = ErrRedirect return true } @@ -1036,17 +1036,17 @@ func handleSubscribeCmd(c *Command) (handled bool) { func handleUnsubscribeCmd(c *Command) (handled bool) { if m := unsubscribeRgx.FindStringSubmatch(c.message); len(m) == 2 { - room, err := database.GetChatRoomByName(m[1]) + room, err := c.db.GetChatRoomByName(m[1]) if err != nil { c.err = err return true } - _ = database.UnsubscribeFromRoom(c.authUser.ID, room.ID) + _ = c.db.UnsubscribeFromRoom(c.authUser.ID, room.ID) c.err = ErrRedirect return true } else if c.message == "/unsubscribe" { - _ = database.UnsubscribeFromRoom(c.authUser.ID, c.room.ID) + _ = c.db.UnsubscribeFromRoom(c.authUser.ID, c.room.ID) c.err = ErrRedirect return true } @@ -1057,7 +1057,7 @@ func handleGroupChatCmd(c *Command) (handled bool) { if m := groupRgx.FindStringSubmatch(c.message); len(m) == 3 { groupName := m[1] c.message = m[2] - group, err := database.GetRoomGroupByName(c.room.ID, groupName) + group, err := c.db.GetRoomGroupByName(c.room.ID, groupName) if err != nil { c.err = err return true @@ -1078,7 +1078,7 @@ func handleGroupChatCmd(c *Command) (handled bool) { func handleListIgnoredCmd(c *Command) (handled bool) { if c.message == "/i" || c.message == "/ignore" { - ignoredUsers, _ := database.GetIgnoredUsers(c.authUser.ID) + ignoredUsers, _ := c.db.GetIgnoredUsers(c.authUser.ID) sort.Slice(ignoredUsers, func(i, j int) bool { return ignoredUsers[i].IgnoredUser.Username < ignoredUsers[j].IgnoredUser.Username }) @@ -1099,7 +1099,7 @@ func handleListIgnoredCmd(c *Command) (handled bool) { func handleListPmWhitelistCmd(c *Command) (handled bool) { if c.message == "/pmwhitelist" { - pmWhitelistUsers, _ := database.GetPmWhitelistedUsers(c.authUser.ID) + pmWhitelistUsers, _ := c.db.GetPmWhitelistedUsers(c.authUser.ID) sort.Slice(pmWhitelistUsers, func(i, j int) bool { return pmWhitelistUsers[i].WhitelistedUser.Username < pmWhitelistUsers[j].WhitelistedUser.Username }) @@ -1121,7 +1121,7 @@ func handleListPmWhitelistCmd(c *Command) (handled bool) { func handleSetPmModeWhitelistCmd(c *Command) (handled bool) { if c.message == "/setpmmode whitelist" { c.authUser.PmMode = database.PmModeWhitelist - c.authUser.DoSave() + c.authUser.DoSave(c.db) c.message = `pm mode set to "whitelist"` c.receivePM() return true @@ -1132,7 +1132,7 @@ func handleSetPmModeWhitelistCmd(c *Command) (handled bool) { func handleSetPmModeStandardCmd(c *Command) (handled bool) { if c.message == "/setpmmode standard" { c.authUser.PmMode = database.PmModeStandard - c.authUser.DoSave() + c.authUser.DoSave(c.db) c.message = `pm mode set to "standard"` c.receivePM() return true @@ -1143,12 +1143,12 @@ func handleSetPmModeStandardCmd(c *Command) (handled bool) { func handleTogglePmBlacklistedUser(c *Command) (handled bool) { if m := pmToggleBlacklistUserRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrRedirect return true } - if database.ToggleBlacklistedUser(c.authUser.ID, user.ID) { + if c.db.ToggleBlacklistedUser(c.authUser.ID, user.ID) { c.err = NewErrSuccess("added to blacklist") } else { c.err = NewErrSuccess("removed from blacklist") @@ -1161,12 +1161,12 @@ func handleTogglePmBlacklistedUser(c *Command) (handled bool) { func handleTogglePmWhitelistedUser(c *Command) (handled bool) { if m := pmToggleWhitelistUserRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrRedirect return true } - if database.ToggleWhitelistedUser(c.authUser.ID, user.ID) { + if c.db.ToggleWhitelistedUser(c.authUser.ID, user.ID) { c.err = NewErrSuccess("added to whitelist") } else { c.err = NewErrSuccess("removed from whitelist") @@ -1180,7 +1180,7 @@ func handleChessCmd(c *Command) (handled bool) { if m := chessRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] player1 := *c.authUser - player2, err := database.GetUserByUsername(username) + player2, err := c.db.GetUserByUsername(username) if err != nil { c.err = errors.New("invalid username") return true @@ -1204,13 +1204,13 @@ func handleInboxCmd(c *Command) (handled bool) { if encryptRaw == " -e" { tryEncrypt = true } - toUser, err := database.GetUserByUsername(username) + toUser, err := c.db.GetUserByUsername(username) if err != nil { c.err = errors.New("invalid username") return true } - if err := canUserInboxOther(*c.authUser, toUser); err != nil { + if err := canUserInboxOther(c.db, *c.authUser, toUser); err != nil { c.err = err return true } @@ -1229,8 +1229,8 @@ func handleInboxCmd(c *Command) (handled bool) { html = strings.Join(strings.Split(html, "\n"), " ") } - html, _ = ProcessRawMessage(html, c.roomKey, c.authUser.ID, c.room.ID, nil) - database.CreateInboxMessage(html, c.room.ID, c.authUser.ID, toUser.ID, true, false, nil) + html, _ = ProcessRawMessage(c.db, html, c.roomKey, c.authUser.ID, c.room.ID, nil) + c.db.CreateInboxMessage(html, c.room.ID, c.authUser.ID, toUser.ID, true, false, nil) c.dataMessage = "/inbox " + username + " " c.err = NewErrSuccess("inbox sent") @@ -1246,7 +1246,7 @@ func handleInboxCmd(c *Command) (handled bool) { func handleProfileCmd(c *Command) (handled bool) { if m := profileRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrUsernameNotFound return true @@ -1275,7 +1275,7 @@ type tutorialSteps struct { func handleTutorialCmd(c *Command) (handled bool) { if c.message == "/tuto" && false { name := "tuto_" + utils.GenerateToken10() - room, _ := database.CreateRoom(name, "", c.authUser.ID, false) + room, _ := c.db.CreateRoom(name, "", c.authUser.ID, false) c.err = ErrRedirect c.zeroProcMsg("Tutorial here -> #" + room.Name) c.zeroPublicProcMsgRoom("Welcome to the tutorial", "", room.ID) @@ -1288,12 +1288,12 @@ func handleDeleteMsgCmd(c *Command) (handled bool) { delMsgFn := func(msg database.ChatMessage) { if msg.RoomID == config.GeneralRoomID && msg.ToUserID == nil { msg.User.GeneralMessagesCount-- - msg.User.DoSave() + msg.User.DoSave(c.db) } - _ = database.DeleteChatMessageByUUID(msg.UUID) + _ = c.db.DeleteChatMessageByUUID(msg.UUID) } if c.message == "/d" { - if msg, err := database.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID); err != nil { + if msg, err := c.db.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID); err != nil { c.err = errors.New("unable to find last message") return true } else if msg.TooOldToDelete() { @@ -1315,7 +1315,7 @@ func handleDeleteMsgCmd(c *Command) (handled bool) { c.err = err return true } - msgs, err := database.GetRoomChatMessagesByDate(c.room.ID, dt.UTC()) + msgs, err := c.db.GetRoomChatMessagesByDate(c.room.ID, dt.UTC()) if err != nil { c.err = err return true @@ -1340,7 +1340,7 @@ func handleDeleteMsgCmd(c *Command) (handled bool) { return true } // Moderator - _ = database.DeleteChatMessageByUUID(msg.UUID) + _ = c.db.DeleteChatMessageByUUID(msg.UUID) c.err = ErrRedirect return true @@ -1383,7 +1383,7 @@ func handleDeleteMsgCmd(c *Command) (handled bool) { c.err = errors.New("failed to find msg") return true } - _ = database.DeleteChatMessageByUUID(msg.UUID) + _ = c.db.DeleteChatMessageByUUID(msg.UUID) c.err = ErrRedirect return true } @@ -1406,13 +1406,13 @@ func handleHideMsgCmd(c *Command) (handled bool) { c.err = err return true } - msgs, err := database.GetRoomChatMessagesByDate(c.room.ID, dt.UTC()) + msgs, err := c.db.GetRoomChatMessagesByDate(c.room.ID, dt.UTC()) if err != nil { c.err = err return true } if len(msgs) == 1 { - database.IgnoreMessage(c.authUser.ID, msgs[0].ID) + c.db.IgnoreMessage(c.authUser.ID, msgs[0].ID) c.err = ErrRedirect } else { c.err = errors.New("more than 1 message") @@ -1431,13 +1431,13 @@ func handleUnHideMsgCmd(c *Command) (handled bool) { c.err = err return true } - msgs, err := database.GetRoomChatMessagesByDate(c.room.ID, dt.UTC()) + msgs, err := c.db.GetRoomChatMessagesByDate(c.room.ID, dt.UTC()) if err != nil { c.err = err return true } if len(msgs) == 1 { - database.UnIgnoreMessage(c.authUser.ID, msgs[0].ID) + c.db.UnIgnoreMessage(c.authUser.ID, msgs[0].ID) c.err = ErrRedirect } else { c.err = errors.New("more than 1 message") @@ -1450,12 +1450,12 @@ func handleUnHideMsgCmd(c *Command) (handled bool) { func handleIgnoreCmd(c *Command) (handled bool) { if m := ignoreRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrRedirect return true } - database.IgnoreUser(c.authUser.ID, user.ID) + c.db.IgnoreUser(c.authUser.ID, user.ID) c.err = ErrRedirect return true } else if strings.HasPrefix(c.message, "/ignore ") || strings.HasPrefix(c.message, "/i ") { @@ -1468,12 +1468,12 @@ func handleIgnoreCmd(c *Command) (handled bool) { func handleUnIgnoreCmd(c *Command) (handled bool) { if m := unIgnoreRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = ErrRedirect return true } - database.UnIgnoreUser(c.authUser.ID, user.ID) + c.db.UnIgnoreUser(c.authUser.ID, user.ID) c.err = ErrRedirect return true } else if strings.HasPrefix(c.message, "/unignore ") || strings.HasPrefix(c.message, "/ui ") { @@ -1486,7 +1486,7 @@ func handleUnIgnoreCmd(c *Command) (handled bool) { func handleToggleAutocomplete(c *Command) (handled bool) { if c.message == "/toggle-autocomplete" { c.authUser.AutocompleteCommandsEnabled = !c.authUser.AutocompleteCommandsEnabled - c.authUser.DoSave() + c.authUser.DoSave(c.db) c.err = ErrRedirect return true } @@ -1527,13 +1527,13 @@ func handleSetChatRoomExternalLink(c *Command) (handled bool) { if !govalidator.IsURL(externalURL) { externalURL = "" } - room, err := database.GetChatRoomByID(c.room.ID) + room, err := c.db.GetChatRoomByID(c.room.ID) if err != nil { c.err = err return true } room.ExternalLink = externalURL - room.DoSave() + room.DoSave(c.db) c.err = ErrRedirect return true } @@ -1543,13 +1543,13 @@ func handleSetChatRoomExternalLink(c *Command) (handled bool) { func handlePurge(c *Command) (handled bool) { if m := purgeRgx.FindStringSubmatch(c.message); len(m) == 2 { username := m[1] - user, err := database.GetUserByUsername(username) + user, err := c.db.GetUserByUsername(username) if err != nil { c.err = err return true } - database.NewAudit(*c.authUser, fmt.Sprintf("purge %s #%d", user.Username, user.ID)) - _ = database.DeleteUserChatMessages(user.ID) + c.db.NewAudit(*c.authUser, fmt.Sprintf("purge %s #%d", user.Username, user.ID)) + _ = c.db.DeleteUserChatMessages(user.ID) c.err = ErrRedirect return true } @@ -1560,21 +1560,21 @@ func handleRename(c *Command) (handled bool) { if m := renameRgx.FindStringSubmatch(c.message); len(m) == 3 { oldUsername := m[1] newUsername := m[2] - user, err := database.GetUserByUsername(oldUsername) + user, err := c.db.GetUserByUsername(oldUsername) if err != nil { c.err = err return true } - database.NewAudit(*c.authUser, fmt.Sprintf("rename %s -> %s #%d", user.Username, newUsername, user.ID)) + c.db.NewAudit(*c.authUser, fmt.Sprintf("rename %s -> %s #%d", user.Username, newUsername, user.ID)) - if err := database.CanRenameTo(oldUsername, newUsername); err != nil { + if err := c.db.CanRenameTo(oldUsername, newUsername); err != nil { c.err = err return true } managers.ActiveUsers.RemoveUser(user.ID) user.Username = newUsername - user.DoSave() + user.DoSave(c.db) c.err = ErrRedirect return true diff --git a/pkg/web/handlers/api/v1/snippetInterceptor.go b/pkg/web/handlers/api/v1/snippetInterceptor.go @@ -11,16 +11,16 @@ type SnippetInterceptor struct{} func (i SnippetInterceptor) InterceptMsg(cmd *Command) { // Snippets actually mutate the original message, // to simulate that the user actually typed the text - cmd.origMessage = snippets(cmd.authUser.ID, cmd.origMessage) + cmd.origMessage = snippets(cmd.db, cmd.authUser.ID, cmd.origMessage) cmd.origMessage = autocompleteTags(cmd.origMessage) cmd.message = cmd.origMessage } -func snippets(authUserID database.UserID, html string) string { +func snippets(db *database.DkfDB, authUserID database.UserID, html string) string { if snippetRgx.MatchString(html) { - userSnippets, _ := database.GetUserSnippets(authUserID) + userSnippets, _ := db.GetUserSnippets(authUserID) if len(userSnippets) > 0 { // Build hashmap for fast lookup m := make(map[string]string) diff --git a/pkg/web/handlers/api/v1/spamInterceptor.go b/pkg/web/handlers/api/v1/spamInterceptor.go @@ -14,13 +14,13 @@ import ( type SpamInterceptor struct{} func (i SpamInterceptor) InterceptMsg(c *Command) { - if err := checkSpam(c.origMessage, c.authUser); err != nil { + if err := checkSpam(c.db, c.origMessage, c.authUser); err != nil { c.err = err return } // Check CP links - if checkCPLinks(c.message) { + if checkCPLinks(c.db, c.message) { c.err = errors.New("forbidden url") return } @@ -32,7 +32,7 @@ func (i SpamInterceptor) InterceptMsg(c *Command) { var ErrSpamFilterTriggered = errors.New("spam filter triggered") -func checkSpam(origMessage string, authUser *database.User) error { +func checkSpam(db *database.DkfDB, origMessage string, authUser *database.User) error { lowerCaseMessage := strings.ToLower(origMessage) silentSelfKick := config.SilentSelfKick.Load() @@ -42,13 +42,13 @@ func checkSpam(origMessage string, authUser *database.User) error { strings.Contains(lowerCaseMessage, "i wanna see gore") || strings.Contains(lowerCaseMessage, "how can i make money") || strings.Contains(lowerCaseMessage, "any links for scary stuff") { - _ = dutils.SelfKick(*authUser, silentSelfKick) + _ = dutils.SelfKick(db, *authUser, silentSelfKick) return ErrSpamFilterTriggered } } if authUser.GeneralMessagesCount < 20 || time.Since(authUser.CreatedAt) < 5*time.Hour { if strings.Contains(lowerCaseMessage, "cp link") { - _ = dutils.SelfKick(*authUser, silentSelfKick) + _ = dutils.SelfKick(db, *authUser, silentSelfKick) return ErrSpamFilterTriggered } } @@ -57,7 +57,7 @@ func checkSpam(origMessage string, authUser *database.User) error { if authUser.IsModerator() { return ErrSpamFilterTriggered } - _ = dutils.SelfKick(*authUser, silentSelfKick) + _ = dutils.SelfKick(db, *authUser, silentSelfKick) return ErrSpamFilterTriggered } @@ -66,7 +66,7 @@ func checkSpam(origMessage string, authUser *database.User) error { count, total := utils.CountUppercase(origMessage) pct := float64(count) / float64(total) if total > 5 && pct > 0.8 { - _ = dutils.SelfKick(*authUser, silentSelfKick) + _ = dutils.SelfKick(db, *authUser, silentSelfKick) return ErrSpamFilterTriggered } } @@ -103,14 +103,14 @@ func checkSpam(origMessage string, authUser *database.User) error { if authUser.GeneralMessagesCount < 10 { if wordsMap["porn"] > 0 && (wordsMap["link"] > 0 || wordsMap["links"] > 0) { - _ = dutils.SelfKick(*authUser, silentSelfKick) + _ = dutils.SelfKick(db, *authUser, silentSelfKick) return ErrSpamFilterTriggered } } if authUser.GeneralMessagesCount < 20 || time.Since(authUser.CreatedAt) < 5*time.Hour { if wordsMap["cp"] > 0 && (wordsMap["link"] > 0 || wordsMap["links"] > 0) { - _ = dutils.SelfKick(*authUser, silentSelfKick) + _ = dutils.SelfKick(db, *authUser, silentSelfKick) return ErrSpamFilterTriggered } } diff --git a/pkg/web/handlers/api/v1/topBarHandler.go b/pkg/web/handlers/api/v1/topBarHandler.go @@ -99,7 +99,7 @@ const ( redirectMultilineQP = "ml" ) -func getDataMessagePrefix(c echo.Context, roomKey string, room database.ChatRoom, authUser *database.User) (out string, err error) { +func getDataMessagePrefix(db *database.DkfDB, c echo.Context, roomKey string, room database.ChatRoom, authUser *database.User) (out string, err error) { pm := c.QueryParam(redirectPmQP) edit := c.QueryParam(redirectEditQP) group := c.QueryParam(redirectGroupQP) @@ -125,12 +125,12 @@ func getDataMessagePrefix(c echo.Context, roomKey string, room database.ChatRoom } else if mtag != "" { out = "/m @" + mtag + " " } else if edit != "" { - out, err = handleGetEdit(edit, roomKey, room, authUser) + out, err = handleGetEdit(db, edit, roomKey, room, authUser) if err != nil { return } } else if quote != "" { - out, err = handleGetQuote(quote, roomKey, room, authUser) + out, err = handleGetQuote(db, quote, roomKey, room, authUser) if err != nil { return } @@ -210,6 +210,7 @@ func buildCommandsList(authUser *database.User, room database.ChatRoom) (command func ChatTopBarHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data chatTopBarData data.RoomName = c.Param("roomName") @@ -229,7 +230,7 @@ func ChatTopBarHandler(c echo.Context) error { } } - room, roomKey, err := dutils.GetRoomAndKey(c, data.RoomName) + room, roomKey, err := dutils.GetRoomAndKey(db, c, data.RoomName) if err != nil { return err } @@ -239,7 +240,7 @@ func ChatTopBarHandler(c echo.Context) error { return c.Render(http.StatusOK, "chat-top-bar", data) } - data.Message, err = getDataMessagePrefix(c, roomKey, room, authUser) + data.Message, err = getDataMessagePrefix(db, c, roomKey, room, authUser) if err != nil { return c.Redirect(http.StatusFound, "/api/v1/chat/top-bar/"+room.Name) } @@ -331,8 +332,8 @@ func replTextPrefixSuffix(msg, prefix, suffix, repl string) (out string) { return } -func handleGetQuote(msgUUID, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) { - quoted, err := database.GetRoomChatMessageByUUID(room.ID, msgUUID) +func handleGetQuote(db *database.DkfDB, msgUUID, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) { + quoted, err := db.GetRoomChatMessageByUUID(room.ID, msgUUID) if err != nil { return } @@ -354,14 +355,14 @@ func handleGetQuote(msgUUID, roomKey string, room database.ChatRoom, authUser *d } // Append the actual quoted text - dataMessage = prefix + getQuoteTxt(roomKey, quoted) + " " + dataMessage = prefix + getQuoteTxt(db, roomKey, quoted) + " " return } -func handleGetEdit(hourMinSec, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) { +func handleGetEdit(db *database.DkfDB, hourMinSec, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) { if dt, err := utils.ParsePrevDatetimeAt(hourMinSec, clockwork.NewRealClock()); err == nil { if time.Since(dt) <= config.EditMessageTimeLimit { - if msg, err := database.GetRoomChatMessageByDate(room.ID, authUser.ID, dt.UTC()); err == nil { + if msg, err := db.GetRoomChatMessageByDate(room.ID, authUser.ID, dt.UTC()); err == nil { decrypted, err := msg.GetRawMessage(roomKey) if err != nil { return "", err @@ -384,6 +385,7 @@ type Command struct { room database.ChatRoom // Room the user is in roomKey string // Room password (if any) authUser *database.User // Authenticated user + db *database.DkfDB // Database instance fromUserID database.UserID // Sender of message toUser *database.User // If not nil, will be a PM upload *database.Upload // If the message contains an uploaded file @@ -399,9 +401,11 @@ type Command struct { func NewCommand(c echo.Context, origMessage string, room database.ChatRoom, roomKey string) *Command { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) return &Command{ c: c, authUser: authUser, + db: db, fromUserID: authUser.ID, hellbanMsg: authUser.IsHellbanned, redirectQP: url.Values{}, @@ -425,7 +429,7 @@ func (c *Command) receivePM() { // Lazy loading and cache of the zero user func (c *Command) getZeroUser() database.User { if c.zeroUser == nil { - zeroUser := dutils.GetZeroUser() + zeroUser := dutils.GetZeroUser(c.db) c.zeroUser = &zeroUser } return *c.zeroUser @@ -444,14 +448,14 @@ func (c *Command) zeroProcMsg(rawMsg string) { func (c *Command) zeroProcMsgRoom(rawMsg, roomKey string, roomID database.RoomID) { zeroUser := c.getZeroUser() - procMsg, _ := ProcessRawMessage(rawMsg, roomKey, c.authUser.ID, roomID, nil) - rawMsgRoom(zeroUser, c.authUser, rawMsg, procMsg, roomKey, roomID) + procMsg, _ := ProcessRawMessage(c.db, rawMsg, roomKey, c.authUser.ID, roomID, nil) + rawMsgRoom(c.db, zeroUser, c.authUser, rawMsg, procMsg, roomKey, roomID) } func (c *Command) zeroPublicProcMsgRoom(rawMsg, roomKey string, roomID database.RoomID) { zeroUser := c.getZeroUser() - procMsg, _ := ProcessRawMessage(rawMsg, roomKey, c.authUser.ID, roomID, nil) - rawMsgRoom(zeroUser, nil, rawMsg, procMsg, roomKey, roomID) + procMsg, _ := ProcessRawMessage(c.db, rawMsg, roomKey, c.authUser.ID, roomID, nil) + rawMsgRoom(c.db, zeroUser, nil, rawMsg, procMsg, roomKey, roomID) } func (c *Command) zeroPublicMsg(raw, msg string) { @@ -460,15 +464,15 @@ func (c *Command) zeroPublicMsg(raw, msg string) { } func (c *Command) rawMsg(user1 database.User, user2 *database.User, raw, msg string) { - rawMsgRoom(user1, user2, raw, msg, c.roomKey, c.room.ID) + rawMsgRoom(c.db, user1, user2, raw, msg, c.roomKey, c.room.ID) } -func rawMsgRoom(user1 database.User, user2 *database.User, raw, msg, roomKey string, roomID database.RoomID) { +func rawMsgRoom(db *database.DkfDB, user1 database.User, user2 *database.User, raw, msg, roomKey string, roomID database.RoomID) { var toUserID *database.UserID if user2 != nil { toUserID = &user2.ID } - _, _ = database.CreateMsg(raw, msg, roomKey, roomID, user1.ID, toUserID) + _, _ = db.CreateMsg(raw, msg, roomKey, roomID, user1.ID, toUserID) } type ErrSuccess struct { @@ -493,12 +497,12 @@ func appendUploadLink(html string, upload *database.Upload) string { return html } -func checkCPLinks(html string) bool { +func checkCPLinks(db *database.DkfDB, html string) bool { m1 := onionV3Rgx.FindAllStringSubmatch(html, -1) m2 := onionV2Rgx.FindAllStringSubmatch(html, -1) for _, m := range append(m1, m2...) { hash := utils.MD5([]byte(m[0])) - if _, err := database.GetOnionBlacklist(hash); err == nil { + if _, err := db.GetOnionBlacklist(hash); err == nil { return true } } @@ -525,7 +529,7 @@ func sanitizeUserInput(html string) string { // Convert timestamps such as 01:23:45 to an archive link if a message with that timestamp exists. // eg: "Some text 14:31:46 some more text" -func convertArchiveLinks(html string, roomID database.RoomID, authUserID database.UserID) string { +func convertArchiveLinks(db *database.DkfDB, html string, roomID database.RoomID, authUserID database.UserID) string { start, rest := "", html // Do not replace timestamps that are inside a quote text @@ -548,7 +552,7 @@ func convertArchiveLinks(html string, roomID database.RoomID, authUserID databas if err != nil { return s } - if msgs, err := database.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 { + if msgs, err := db.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 { msg := msgs[0] if len(msgs) > 1 { for _, msgTmp := range msgs { @@ -613,10 +617,10 @@ func colorifyTaggedUsers(html string, getUsersByUsername getUsersByUsernameFn) ( return html, taggedUsersIDsMap } -func linkRoomTags(html string) string { +func linkRoomTags(db *database.DkfDB, html string) string { if roomTagRgx.MatchString(html) { html = roomTagRgx.ReplaceAllStringFunc(html, func(s string) string { - if room, err := database.GetChatRoomByName(strings.TrimPrefix(s, "#")); err == nil { + if room, err := db.GetChatRoomByName(strings.TrimPrefix(s, "#")); err == nil { return `<a href="/chat/` + room.Name + `" target="_top">` + s + `</a>` } return s @@ -626,9 +630,9 @@ func linkRoomTags(html string) string { } // Given a roomID and hourMinSec (01:23:45) and a username, retrieve the message from database that fits the predicates. -func getQuotedChatMessage(hourMinSec, username string, roomID database.RoomID) (quoted *database.ChatMessage) { +func getQuotedChatMessage(db *database.DkfDB, hourMinSec, username string, roomID database.RoomID) (quoted *database.ChatMessage) { if dt, err := utils.ParsePrevDatetimeAt(hourMinSec, clockwork.NewRealClock()); err == nil { - if msgs, err := database.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 { + if msgs, err := db.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 { msg := msgs[0] if len(msgs) > 1 { for _, msgTmp := range msgs { @@ -645,7 +649,7 @@ func getQuotedChatMessage(hourMinSec, username string, roomID database.RoomID) ( } // Given a chat message, return the text to be used as a quote. -func getQuoteTxt(roomKey string, quoted database.ChatMessage) (out string) { +func getQuoteTxt(db *database.DkfDB, roomKey string, quoted database.ChatMessage) (out string) { var err error decrypted, err := quoted.GetRawMessage(roomKey) if err != nil { @@ -685,7 +689,7 @@ func getQuoteTxt(roomKey string, quoted database.ChatMessage) (out string) { remaining += fmt.Sprintf(`%s `, quoted.User.Username) } if quoted.UploadID != nil { - if upload, err := database.GetUploadByID(*quoted.UploadID); err == nil { + if upload, err := db.GetUploadByID(*quoted.UploadID); err == nil { if decrypted != "" { decrypted += " " } @@ -714,7 +718,7 @@ func getQuoteTxt(roomKey string, quoted database.ChatMessage) (out string) { // eg: we received altered quote, and return original quote -> // “[01:23:45] username - Some maliciously altered quote” Some text // “[01:23:45] username - The original text” Some text -func convertQuote(origHtml string, roomKey string, roomID database.RoomID) (html string, quoted *database.ChatMessage) { +func convertQuote(db *database.DkfDB, origHtml string, roomKey string, roomID database.RoomID) (html string, quoted *database.ChatMessage) { const quotePrefix = `“[` const quoteSuffix = `”` html = origHtml @@ -725,8 +729,8 @@ func convertQuote(origHtml string, roomKey string, roomID database.RoomID) (html if len(origHtml) > prefixLen+9 { hourMinSec := origHtml[prefixLen : prefixLen+8] username := origHtml[prefixLen+10 : strings.Index(origHtml[prefixLen+10:], " ")+prefixLen+10] - if quoted = getQuotedChatMessage(hourMinSec, username, roomID); quoted != nil { - html = getQuoteTxt(roomKey, *quoted) + if quoted = getQuotedChatMessage(db, hourMinSec, username, roomID); quoted != nil { + html = getQuoteTxt(db, roomKey, *quoted) html += origHtml[idx+suffixLen:] } } @@ -1036,7 +1040,7 @@ func extractPGPMessage(html string) (out string) { } // Auto convert pasted pgp message into uploaded file -func convertPGPMessageToFile(html string, authUserID database.UserID) string { +func convertPGPMessageToFile(db *database.DkfDB, html string, authUserID database.UserID) string { startIdx := strings.Index(html, pgpPrefix) endIdx := strings.Index(html, pgpSuffix) if startIdx != -1 && endIdx != -1 { @@ -1047,7 +1051,7 @@ func convertPGPMessageToFile(html string, authUserID database.UserID) string { tmp = strings.Join(strings.Split(tmp, " "), "\n") tmp = pgpPrefix + tmp tmp += pgpSuffix - upload, _ := database.CreateUpload("pgp.txt", []byte(tmp), authUserID) + upload, _ := db.CreateUpload("pgp.txt", []byte(tmp), authUserID) msgBefore := html[0:startIdx] msgAfter := html[endIdx+len(pgpSuffix):] html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter @@ -1057,14 +1061,14 @@ func convertPGPMessageToFile(html string, authUserID database.UserID) string { } // Auto convert pasted pgp public key into uploaded file -func convertPGPPublicKeyToFile(html string, authUserID database.UserID) string { +func convertPGPPublicKeyToFile(db *database.DkfDB, html string, authUserID database.UserID) string { startIdx := strings.Index(html, pgpPKeyPrefix) endIdx := strings.Index(html, pgpPKeySuffix) if startIdx != -1 && endIdx != -1 { pkeySubSlice := html[startIdx : endIdx+len(pgpPKeySuffix)] unescapedPkey := html2.UnescapeString(pkeySubSlice) tmp := convertInlinePGPPublicKey(unescapedPkey) - upload, _ := database.CreateUpload("pgp_pkey.txt", []byte(tmp), authUserID) + upload, _ := db.CreateUpload("pgp_pkey.txt", []byte(tmp), authUserID) msgBefore := html[0:startIdx] msgAfter := html[endIdx+len(pgpPKeySuffix):] html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter @@ -1073,12 +1077,12 @@ func convertPGPPublicKeyToFile(html string, authUserID database.UserID) string { return html } -func convertPGPClearsignToFile(html string, authUserID database.UserID) string { +func convertPGPClearsignToFile(db *database.DkfDB, html string, authUserID database.UserID) string { if b, _ := clearsign.Decode([]byte(html)); b != nil { startIdx := strings.Index(html, pgpSignedPrefix) endIdx := strings.Index(html, pgpSignedSuffix) tmp := html[startIdx : endIdx+len(pgpSignedSuffix)] - upload, _ := database.CreateUpload("pgp_clearsign.txt", []byte(tmp), authUserID) + upload, _ := db.CreateUpload("pgp_clearsign.txt", []byte(tmp), authUserID) msgBefore := html[0:startIdx] msgAfter := html[endIdx+len(pgpSignedSuffix):] html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter @@ -1126,7 +1130,7 @@ func convertInlinePGPPublicKey(inlinePKey string) string { } // Auto convert pasted age message into uploaded file -func convertAgeMessageToFile(html string, authUserID database.UserID) string { +func convertAgeMessageToFile(db *database.DkfDB, html string, authUserID database.UserID) string { startIdx := strings.Index(html, agePrefix) endIdx := strings.Index(html, ageSuffix) if startIdx != -1 && endIdx != -1 { @@ -1137,7 +1141,7 @@ func convertAgeMessageToFile(html string, authUserID database.UserID) string { tmp = strings.Join(strings.Split(tmp, " "), "\n") tmp = agePrefix + tmp tmp += ageSuffix - upload, _ := database.CreateUpload("age.txt", []byte(tmp), authUserID) + upload, _ := db.CreateUpload("age.txt", []byte(tmp), authUserID) msgBefore := html[0:startIdx] msgAfter := html[endIdx+len(ageSuffix):] html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter diff --git a/pkg/web/handlers/api/v1/uploadInterceptor.go b/pkg/web/handlers/api/v1/uploadInterceptor.go @@ -20,7 +20,7 @@ func (i UploadInterceptor) InterceptMsg(cmd *Command) { if file, handler, uploadErr := cmd.c.Request().FormFile("file"); uploadErr == nil { // Save file on disk & database & append file link to html var err error - cmd.upload, err = handleUploadedFile(file, handler, cmd.authUser) + cmd.upload, err = handleUploadedFile(cmd.db, file, handler, cmd.authUser) if err != nil { cmd.err = err return @@ -28,12 +28,12 @@ func (i UploadInterceptor) InterceptMsg(cmd *Command) { } } -func handleUploadedFile(file multipart.File, handler *multipart.FileHeader, authUser *database.User) (*database.Upload, error) { +func handleUploadedFile(db *database.DkfDB, file multipart.File, handler *multipart.FileHeader, authUser *database.User) (*database.Upload, error) { defer file.Close() if !authUser.CanUpload() { return nil, hutils.AccountTooYoungErr } - userSizeUploaded := database.GetUserTotalUploadSize(authUser.ID) + userSizeUploaded := db.GetUserTotalUploadSize(authUser.ID) if handler.Size+userSizeUploaded > config.MaxUserTotalUploadSize { return nil, fmt.Errorf("user upload limit reached (%s)", humanize.Bytes(config.MaxUserTotalUploadSize)) } @@ -66,7 +66,7 @@ func handleUploadedFile(file multipart.File, handler *multipart.FileHeader, auth } // Uploaded files are encrypted on disk - upload, err := database.CreateEncryptedUploadWithSize(origFileName, fileBytes, authUser.ID, handler.Size) + upload, err := db.CreateEncryptedUploadWithSize(origFileName, fileBytes, authUser.ID, handler.Size) if err != nil { logrus.Error(err) return nil, err diff --git a/pkg/web/handlers/api/v1/werewolf.go b/pkg/web/handlers/api/v1/werewolf.go @@ -37,6 +37,7 @@ const ( var ErrInvalidPlayerName = errors.New("unknown player name, please send a valid name") type Werewolf struct { + db *database.DkfDB ctx context.Context cancel context.CancelFunc readyCh chan bool @@ -105,7 +106,7 @@ func (b *Werewolf) InterceptPreGameMsg(cmd *Command) { b.cancel() time.Sleep(time.Second) utils.SGo(func() { - b.StartGame() + b.StartGame(cmd.db) }) cmd.err = ErrRedirect return @@ -212,7 +213,7 @@ func (b *Werewolf) InterceptMsg(cmd *Command) { cmd.err = ErrRedirect return } else if cmd.authUser.IsModerator() && cmd.message == "/clear" { - _ = database.DeleteChatRoomMessages(b.roomID) + _ = cmd.db.DeleteChatRoomMessages(b.roomID) b.Narrate(tuto, nil, nil) cmd.err = ErrRedirect return @@ -328,12 +329,12 @@ func (b *Werewolf) isValidPlayerName(name string) bool { // Narrate register a chat message on behalf of the narrator user func (b *Werewolf) Narrate(msg string, toUserID *database.UserID, groupID *database.GroupID) { - html, _ := ProcessRawMessage(msg, "", b.narratorID, b.roomID, nil) + html, _ := ProcessRawMessage(b.db, msg, "", b.narratorID, b.roomID, nil) b.NarrateRaw(html, toUserID, groupID) } func (b *Werewolf) NarrateRaw(msg string, toUserID *database.UserID, groupID *database.GroupID) { - _, _ = database.CreateOrEditMessage(nil, msg, msg, "", b.roomID, b.narratorID, toUserID, nil, groupID, false, false, false) + _, _ = b.db.CreateOrEditMessage(nil, msg, msg, "", b.roomID, b.narratorID, toUserID, nil, groupID, false, false, false) } // Display roles assigned at beginning of the Game @@ -345,7 +346,7 @@ func (b *Werewolf) displayRoles() { b.Narrate(msg, nil, nil) } -func (b *Werewolf) StartGame() { +func (b *Werewolf) StartGame(db *database.DkfDB) { defer func() { b.displayRoles() b.reset() @@ -360,7 +361,7 @@ func (b *Werewolf) StartGame() { for idx, player := range playersArr { if idx == 0 { b.werewolfSet.Insert(player.UserID) - _, _ = database.AddUserToRoomGroup(b.roomID, b.werewolfGroupID, player.UserID) + _, _ = db.AddUserToRoomGroup(b.roomID, b.werewolfGroupID, player.UserID) player.Role = WerewolfRole werewolfMsg := "During the day you seem to be a regular Townsperson.\n" + "However, you’ve been kissed by the Night and transform into a Werewolf when the sun sets.\n" + @@ -420,7 +421,7 @@ func (b *Werewolf) StartGame() { "There you find @"+playerNameToKill+"’s mangled remains by the Great Oak.\n"+ "Curiously, there are deep claw marks in the bark of the surrounding trees.\n"+ "It looks like @"+playerNameToKill+" put up a fight.", nil, nil) - b.kill(playerNameToKill) + b.kill(db, playerNameToKill) } b.Narrate("Players still alive: "+b.alivePlayersStr(), nil, nil) @@ -450,7 +451,7 @@ func (b *Werewolf) StartGame() { b.Narrate("Townspeople do not want to execute anyone", nil, nil) } else { b.Narrate("Townspeople execute @"+killName, nil, nil) - b.kill(killName) + b.kill(db, killName) } b.Narrate("Players still alive: "+b.alivePlayersStr(), nil, nil) @@ -478,7 +479,7 @@ func (b *Werewolf) alivePlayersStr() (out string) { } // Kill a player -func (b *Werewolf) kill(playerName string) { +func (b *Werewolf) kill(db *database.DkfDB, playerName string) { player, found := b.playersAlive[playerName] if !found { return @@ -487,7 +488,7 @@ func (b *Werewolf) kill(playerName string) { switch player.Role { case WerewolfRole: b.werewolfSet.Remove(player.UserID) - _ = database.RmUserFromRoomGroup(b.roomID, b.werewolfGroupID, player.UserID) + _ = db.RmUserFromRoomGroup(b.roomID, b.werewolfGroupID, player.UserID) case TownspeopleRole: b.townspersonSet.Remove(player.UserID) case HealerRole: @@ -497,7 +498,7 @@ func (b *Werewolf) kill(playerName string) { b.townspersonSet.Remove(player.UserID) b.seerID = nil } - _, _ = database.AddUserToRoomGroup(b.roomID, b.deadGroupID, player.UserID) + _, _ = db.AddUserToRoomGroup(b.roomID, b.deadGroupID, player.UserID) } // Return the name of the player name that receive the most vote @@ -637,15 +638,15 @@ func (b *Werewolf) LockGroups() { } func (b *Werewolf) LockGroup(groupName string) { - group, _ := database.GetRoomGroupByName(b.roomID, groupName) + group, _ := b.db.GetRoomGroupByName(b.roomID, groupName) group.Locked = true - group.DoSave() + group.DoSave(b.db) } func (b *Werewolf) UnlockGroup(groupName string) { - group, _ := database.GetRoomGroupByName(b.roomID, groupName) + group, _ := b.db.GetRoomGroupByName(b.roomID, groupName) group.Locked = false - group.DoSave() + group.DoSave(b.db) } type Player struct { @@ -668,27 +669,28 @@ func (b *Werewolf) reset() { b.healerCh = make(chan string) b.votesCh = make(chan string) b.readyCh = make(chan bool) - _ = database.ClearRoomGroup(b.roomID, b.werewolfGroupID) - _ = database.ClearRoomGroup(b.roomID, b.spectatorGroupID) - _ = database.ClearRoomGroup(b.roomID, b.deadGroupID) + _ = b.db.ClearRoomGroup(b.roomID, b.werewolfGroupID) + _ = b.db.ClearRoomGroup(b.roomID, b.spectatorGroupID) + _ = b.db.ClearRoomGroup(b.roomID, b.deadGroupID) } -func NewWerewolf() *Werewolf { +func NewWerewolf(db *database.DkfDB) *Werewolf { // Prepare room - room, err := database.GetChatRoomByName("werewolf") + room, err := db.GetChatRoomByName("werewolf") if err != nil { logrus.Error("#werewolf room not found") return nil } - zeroUser, _ := database.GetUserByUsername(config.NullUsername) - _ = database.DeleteChatRoomGroups(room.ID) - werewolfGroup, _ := database.CreateChatRoomGroup(room.ID, "werewolf", "#ffffff") + zeroUser, _ := db.GetUserByUsername(config.NullUsername) + _ = db.DeleteChatRoomGroups(room.ID) + werewolfGroup, _ := db.CreateChatRoomGroup(room.ID, "werewolf", "#ffffff") werewolfGroup.Locked = true - werewolfGroup.DoSave() - spectatorGroup, _ := database.CreateChatRoomGroup(room.ID, "spectator", "#ffffff") - deadGroup, _ := database.CreateChatRoomGroup(room.ID, "dead", "#ffffff") + werewolfGroup.DoSave(db) + spectatorGroup, _ := db.CreateChatRoomGroup(room.ID, "spectator", "#ffffff") + deadGroup, _ := db.CreateChatRoomGroup(room.ID, "dead", "#ffffff") b := new(Werewolf) + b.db = db b.werewolfGroupID = werewolfGroup.ID b.spectatorGroupID = spectatorGroup.ID b.deadGroupID = deadGroup.ID diff --git a/pkg/web/handlers/chat.go b/pkg/web/handlers/chat.go @@ -14,6 +14,7 @@ import ( func chatHandler(c echo.Context, redRoom bool) error { const chatPasswordTmplName = "standalone.chat-password" authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data chatData data.RedRoom = redRoom preventRefresh := utils.DoParseBool(c.QueryParam("r")) @@ -44,7 +45,7 @@ func chatHandler(c echo.Context, redRoom bool) error { data.CaptchaID, data.CaptchaImg = captcha.New() } - room, err := database.GetChatRoomByName(getRoomName(c)) + room, err := db.GetChatRoomByName(getRoomName(c)) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -59,13 +60,13 @@ func chatHandler(c echo.Context, redRoom bool) error { data.TutoFrames = generateCssFrames(data.TutoSecs, nil, true) if c.Request().Method == http.MethodGet { authUser.ChatTutorialTime = time.Now() - authUser.DoSave() + authUser.DoSave(db) } } } if c.Request().Method == http.MethodPost { - return handlePost(c, data, authUser) + return handlePost(db, c, data, authUser) } // If you don't have access to the room (room is protected and user is nil or no cookie with the password) @@ -77,7 +78,7 @@ func chatHandler(c echo.Context, redRoom bool) error { return c.Render(http.StatusOK, chatPasswordTmplName, data) } - data.IsSubscribed = database.IsUserSubscribedToRoom(authUser.ID, room.ID) + data.IsSubscribed = db.IsUserSubscribedToRoom(authUser.ID, room.ID) data.IsOfficialRoom = room.IsOfficialRoom() return c.Render(http.StatusOK, "chat", data) } @@ -90,25 +91,25 @@ func getRoomName(c echo.Context) string { return roomName } -func handlePost(c echo.Context, data chatData, authUser *database.User) error { +func handlePost(db *database.DkfDB, c echo.Context, data chatData, authUser *database.User) error { formName := c.Request().PostFormValue("formName") switch formName { case "logout": return handleLogoutPost(c, data.Room) case "toggle-hb": - return handleToggleHBPost(c, authUser) + return handleToggleHBPost(db, c, authUser) case "toggle-m": - return handleToggleMPost(c, authUser) + return handleToggleMPost(db, c, authUser) case "toggle-ignored": - return handleToggleIgnoredPost(c, authUser) + return handleToggleIgnoredPost(db, c, authUser) case "afk": - return handleAfkPost(c, authUser) + return handleAfkPost(db, c, authUser) case "update-read-marker": - return handleUpdateReadMarkerPost(c, data.Room, authUser) + return handleUpdateReadMarkerPost(db, c, data.Room, authUser) case "tutorialP1", "tutorialP2", "tutorialP3": - return handleTutorialPost(c, data, authUser) + return handleTutorialPost(db, c, data, authUser) case "chat-password": - return handleChatPasswordPost(c, data, authUser) + return handleChatPasswordPost(db, c, data, authUser) } return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -119,49 +120,49 @@ func handleLogoutPost(c echo.Context, room database.ChatRoom) error { return c.Redirect(http.StatusFound, "/chat") } -func handleToggleHBPost(c echo.Context, authUser *database.User) error { +func handleToggleHBPost(db *database.DkfDB, c echo.Context, authUser *database.User) error { if authUser.CanSeeHB() { authUser.DisplayHellbanned = !authUser.DisplayHellbanned - authUser.DoSave() + authUser.DoSave(db) } return c.Redirect(http.StatusFound, c.Request().Referer()) } -func handleToggleMPost(c echo.Context, authUser *database.User) error { +func handleToggleMPost(db *database.DkfDB, c echo.Context, authUser *database.User) error { if authUser.IsModerator() { authUser.DisplayModerators = !authUser.DisplayModerators - authUser.DoSave() + authUser.DoSave(db) } return c.Redirect(http.StatusFound, c.Request().Referer()) } -func handleToggleIgnoredPost(c echo.Context, authUser *database.User) error { +func handleToggleIgnoredPost(db *database.DkfDB, c echo.Context, authUser *database.User) error { authUser.DisplayIgnored = !authUser.DisplayIgnored - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, c.Request().Referer()) } -func handleAfkPost(c echo.Context, authUser *database.User) error { +func handleAfkPost(db *database.DkfDB, c echo.Context, authUser *database.User) error { authUser.AFK = !authUser.AFK - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, c.Request().Referer()) } -func handleUpdateReadMarkerPost(c echo.Context, room database.ChatRoom, authUser *database.User) error { - database.UpdateChatReadMarker(authUser.ID, room.ID) +func handleUpdateReadMarkerPost(db *database.DkfDB, c echo.Context, room database.ChatRoom, authUser *database.User) error { + db.UpdateChatReadMarker(authUser.ID, room.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } -func handleTutorialPost(c echo.Context, data chatData, authUser *database.User) error { +func handleTutorialPost(db *database.DkfDB, c echo.Context, data chatData, authUser *database.User) error { if authUser.ChatTutorial < 3 && time.Since(authUser.ChatTutorialTime) >= time.Duration(data.TutoSecs)*time.Second { authUser.ChatTutorial++ - authUser.DoSave() + authUser.DoSave(db) } return c.Redirect(http.StatusFound, c.Request().Referer()) } // Handle POST requests for chat-password, when someone tries to authenticate in a protected room providing a password. -func handleChatPasswordPost(c echo.Context, data chatData, authUser *database.User) error { +func handleChatPasswordPost(db *database.DkfDB, c echo.Context, data chatData, authUser *database.User) error { const chatPasswordTmplName = "standalone.chat-password" data.RoomPassword = c.Request().PostFormValue("password") @@ -175,7 +176,7 @@ func handleChatPasswordPost(c echo.Context, data chatData, authUser *database.Us return c.Render(http.StatusOK, chatPasswordTmplName, data) } - if err := database.CanUseUsername(data.GuestUsername, false); err != nil { + if err := db.CanUseUsername(data.GuestUsername, false); err != nil { data.ErrGuestUsername = err.Error() return c.Render(http.StatusOK, chatPasswordTmplName, data) } @@ -193,13 +194,13 @@ func handleChatPasswordPost(c echo.Context, data chatData, authUser *database.Us // TODO: maybe add "_guest" suffix to guest accounts? if authUser == nil { password := utils.GenerateToken32() - newUser, errs := database.CreateGuestUser(data.GuestUsername, password) + newUser, errs := db.CreateGuestUser(data.GuestUsername, password) if errs.HasError() { data.ErrGuestUsername = errs.Username return c.Render(http.StatusOK, chatPasswordTmplName, data) } - session := database.DoCreateSession(newUser.ID, c.Request().UserAgent()) + session := db.DoCreateSession(newUser.ID, c.Request().UserAgent()) c.SetCookie(createSessionCookie(session.Token)) } diff --git a/pkg/web/handlers/club.go b/pkg/web/handlers/club.go @@ -10,14 +10,16 @@ import ( func ClubHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data clubData data.ActiveTab = "home" - data.ForumThreads, _ = database.GetClubForumThreads(authUser.ID) + data.ForumThreads, _ = db.GetClubForumThreads(authUser.ID) return c.Render(http.StatusOK, "club.home", data) } func ClubNewThreadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data clubNewThreadData data.ActiveTab = "home" @@ -33,9 +35,9 @@ func ClubNewThreadHandler(c echo.Context) error { return c.Render(http.StatusOK, "club.new-thread", data) } thread := database.MakeForumThread(data.ThreadName, authUser.ID, 0) - database.DB.Create(&thread) + db.DB().Create(&thread) message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID) - database.DB.Create(&message) + db.DB().Create(&message) return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID))) } @@ -44,8 +46,9 @@ func ClubNewThreadHandler(c echo.Context) error { func ClubThreadReplyHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID"))) - thread, err := database.GetForumThread(threadID) + thread, err := db.GetForumThread(threadID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -60,7 +63,7 @@ func ClubThreadReplyHandler(c echo.Context) error { return c.Render(http.StatusOK, "club.new-thread", data) } message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID) - database.DB.Create(&message) + db.DB().Create(&message) return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID))) } @@ -69,13 +72,14 @@ func ClubThreadReplyHandler(c echo.Context) error { func ClubThreadEditMessageHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID"))) messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID"))) - thread, err := database.GetForumThread(threadID) + thread, err := db.GetForumThread(threadID) if err != nil { return c.Redirect(http.StatusFound, "/") } - msg, err := database.GetForumMessage(messageID) + msg, err := db.GetForumMessage(messageID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -95,7 +99,7 @@ func ClubThreadEditMessageHandler(c echo.Context) error { return c.Render(http.StatusOK, "club.new-thread", data) } msg.Message = data.Message - msg.DoSave() + msg.DoSave(db) return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID))) } @@ -103,26 +107,28 @@ func ClubThreadEditMessageHandler(c echo.Context) error { } func ClubMembersHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data clubMembersData data.ActiveTab = "members" - data.Members, _ = database.GetClubMembers() + data.Members, _ = db.GetClubMembers() return c.Render(http.StatusOK, "club.members", data) } func ClubThreadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID"))) - thread, err := database.GetForumThread(threadID) + thread, err := db.GetForumThread(threadID) if err != nil { return c.Redirect(http.StatusFound, "/") } var data clubThreadData data.ActiveTab = "home" data.Thread = thread - data.Messages, _ = database.GetThreadMessages(threadID) + data.Messages, _ = db.GetThreadMessages(threadID) // Update read record - database.UpdateForumReadRecord(authUser.ID, threadID) + db.UpdateForumReadRecord(authUser.ID, threadID) return c.Render(http.StatusOK, "club.thread", data) } diff --git a/pkg/web/handlers/handlers.go b/pkg/web/handlers/handlers.go @@ -64,6 +64,7 @@ import ( func firstUseHandler(c echo.Context) error { user := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data firstUseData if user != nil { return c.Redirect(http.StatusFound, "/") @@ -80,17 +81,17 @@ func firstUseHandler(c echo.Context) error { data.Username = c.Request().PostFormValue("username") data.Password = c.Request().PostFormValue("password") data.RePassword = c.Request().PostFormValue("repassword") - newUser, errs := database.CreateFirstUser(data.Username, data.Password, data.RePassword) + newUser, errs := db.CreateFirstUser(data.Username, data.Password, data.RePassword) data.Errors = errs if errs.HasError() { return c.Render(http.StatusOK, "standalone.first-use", data) } - _, errs = database.CreateZeroUser() + _, errs = db.CreateZeroUser() config.IsFirstUse.SetFalse() - session := database.DoCreateSession(newUser.ID, c.Request().UserAgent()) + session := db.DoCreateSession(newUser.ID, c.Request().UserAgent()) c.SetCookie(createSessionCookie(session.Token)) return c.Redirect(http.StatusFound, "/") @@ -276,6 +277,7 @@ func LoginAttackHandler(c echo.Context) error { func loginHandler(c echo.Context) error { formName := c.Request().PostFormValue("formName") + db := c.Get("database").(*database.DkfDB) if formName == "" { var data loginData data.Autofocus = 0 @@ -287,7 +289,7 @@ func loginHandler(c echo.Context) error { actualLogin := func(username, password string, captchaSolved bool) error { username = strings.TrimSpace(username) - user, err := database.GetVerifiedUserByUsername(username) + user, err := db.GetVerifiedUserByUsername(username) if err != nil { time.Sleep(utils.RandMs(50, 200)) data.Error = "Invalid username/password" @@ -295,7 +297,7 @@ func loginHandler(c echo.Context) error { } user.LoginAttempts++ - user.DoSave() + user.DoSave(db) if user.LoginAttempts > 4 && !captchaSolved { data.CaptchaRequired = true @@ -315,7 +317,7 @@ func loginHandler(c echo.Context) error { } } - if !user.CheckPassword(password) { + if !user.CheckPassword(db, password) { data.Password = "" data.Autofocus = 1 data.Error = "Invalid username/password" @@ -390,16 +392,17 @@ func loginHandler(c echo.Context) error { } func completeLogin(c echo.Context, user database.User) error { + db := c.Get("database").(*database.DkfDB) user.LoginAttempts = 0 - user.DoSave() + user.DoSave(db) - for _, session := range database.GetActiveUserSessions(user.ID) { + for _, session := range db.GetActiveUserSessions(user.ID) { msg := fmt.Sprintf(`New login`) - database.CreateSessionNotification(msg, session.Token) + db.CreateSessionNotification(msg, session.Token) } - session := database.DoCreateSession(user.ID, c.Request().UserAgent()) - database.CreateSecurityLog(user.ID, database.LoginSecurityLog) + session := db.DoCreateSession(user.ID, c.Request().UserAgent()) + db.CreateSecurityLog(user.ID, database.LoginSecurityLog) c.SetCookie(createSessionCookie(session.Token)) redirectURL := "/" @@ -424,12 +427,13 @@ func LoginCompletedHandler(c echo.Context) error { // SessionsGpgTwoFactorHandler ... func SessionsGpgTwoFactorHandler(c echo.Context, step1 bool, token string) error { + db := c.Get("database").(*database.DkfDB) item, found := partialAuthCache.Get(token) if !found || item.Step != PgpStep { return c.Redirect(http.StatusFound, "/") } - user, err := database.GetUserByID(item.UserID) + user, err := db.GetUserByID(item.UserID) if err != nil { logrus.Errorf("failed to get user %d", item.UserID) return c.Redirect(http.StatusFound, "/") @@ -472,12 +476,13 @@ func SessionsGpgTwoFactorHandler(c echo.Context, step1 bool, token string) error // SessionsGpgSignTwoFactorHandler ... func SessionsGpgSignTwoFactorHandler(c echo.Context, step1 bool, token string) error { + db := c.Get("database").(*database.DkfDB) item, found := partialAuthCache.Get(token) if !found || item.Step != PgpSignStep { return c.Redirect(http.StatusFound, "/") } - user, err := database.GetUserByID(item.UserID) + user, err := db.GetUserByID(item.UserID) if err != nil { logrus.Errorf("failed to get user %d", item.UserID) return c.Redirect(http.StatusFound, "/") @@ -516,6 +521,7 @@ func SessionsGpgSignTwoFactorHandler(c echo.Context, step1 bool, token string) e // SessionsTwoFactorHandler ... func SessionsTwoFactorHandler(c echo.Context, step1 bool, token string) error { + db := c.Get("database").(*database.DkfDB) item, found := partialAuthCache.Get(token) if !found || item.Step != TwoFactorStep { return c.Redirect(http.StatusFound, "/") @@ -525,7 +531,7 @@ func SessionsTwoFactorHandler(c echo.Context, step1 bool, token string) error { data.Token = token if !step1 { code := c.Request().PostFormValue("code") - user, err := database.GetUserByID(item.UserID) + user, err := db.GetUserByID(item.UserID) if err != nil { logrus.Errorf("failed to get user %d", item.UserID) return c.Redirect(http.StatusFound, "/") @@ -545,6 +551,7 @@ func SessionsTwoFactorHandler(c echo.Context, step1 bool, token string) error { // SessionsTwoFactorRecoveryHandler ... func SessionsTwoFactorRecoveryHandler(c echo.Context, token string) error { + db := c.Get("database").(*database.DkfDB) item, found := partialAuthCache.Get(token) if !found { return c.Redirect(http.StatusFound, "/") @@ -554,7 +561,7 @@ func SessionsTwoFactorRecoveryHandler(c echo.Context, token string) error { data.Token = token recoveryCode := c.Request().PostFormValue("code") if recoveryCode != "" { - user, err := database.GetUserByID(item.UserID) + user, err := db.GetUserByID(item.UserID) if err != nil { logrus.Errorf("failed to get user %d", item.UserID) return c.Redirect(http.StatusFound, "/") @@ -574,21 +581,22 @@ func SessionsTwoFactorRecoveryHandler(c echo.Context, token string) error { // LogoutHandler for logout route func LogoutHandler(ctx echo.Context) error { authUser := ctx.Get("authUser").(*database.User) + db := ctx.Get("database").(*database.DkfDB) c, _ := ctx.Cookie(hutils.AuthCookieName) - if err := database.DeleteSessionByToken(c.Value); err != nil { - logrus.Error("Failed to remove session from DB : ", err) + if err := db.DeleteSessionByToken(c.Value); err != nil { + logrus.Error("Failed to remove session from db : ", err) } if authUser.TerminateAllSessionsOnLogout { // Delete active user sessions - if err := database.DeleteUserSessions(authUser.ID); err != nil { + if err := db.DeleteUserSessions(authUser.ID); err != nil { logrus.Error("failed to delete user sessions : ", err) } } - database.CreateSecurityLog(authUser.ID, database.LogoutSecurityLog) + db.CreateSecurityLog(authUser.ID, database.LogoutSecurityLog) ctx.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName)) managers.ActiveUsers.RemoveUser(authUser.ID) if authUser.Temp { - if err := database.DB.Where("id = ?", authUser.ID).Unscoped().Delete(&database.User{}).Error; err != nil { + if err := db.DB().Where("id = ?", authUser.ID).Unscoped().Delete(&database.User{}).Error; err != nil { logrus.Error(err) } } @@ -621,12 +629,13 @@ func SignupAttackHandler(c echo.Context) error { // SignupInvitationHandler ... func SignupInvitationHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) invitationToken := c.Param("invitationToken") invitationTokenQuery := c.QueryParam("invitationToken") if invitationTokenQuery != "" { invitationToken = invitationTokenQuery } - if _, err := database.GetUnusedInvitationByToken(invitationToken); err != nil { + if _, err := db.GetUnusedInvitationByToken(invitationToken); err != nil { return c.Redirect(http.StatusFound, "/") } return waitPageWrapper(c, signupHandler, hutils.WaitCookieName) @@ -699,6 +708,7 @@ type SignupInfo struct { func SignalCss1(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) data := c.Param("data") data = strings.TrimRight(data, ".png") data = strings.TrimRight(data, ".ttf") @@ -716,7 +726,7 @@ func SignalCss1(c echo.Context) error { info.UpdatedAt = time.Now().Format(time.RFC3339) signupInfoEnc, _ := json.Marshal(info) authUser.SignupMetadata = string(signupInfoEnc) - authUser.DoSave() + authUser.DoSave(db) return c.NoContent(http.StatusOK) } @@ -821,6 +831,7 @@ func waitPageWrapper(c echo.Context, clb echo.HandlerFunc, cookieName string) er // Not all requests to the signup endpoint will get the captcha at the same time, // so you cannot just refresh the page until you get a captcha that is easier to crack. func signupHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) start := c.Get("start").(int64) signupToken := c.Get("signupToken").(string) var data signupData @@ -877,7 +888,7 @@ func signupHandler(c echo.Context) error { signupInfoEnc, _ := json.Marshal(signupInfo) registrationDuration := time.Now().UnixMilli() - start - newUser, errs := database.CreateUser(data.Username, data.Password, data.RePassword, registrationDuration, string(signupInfoEnc)) + newUser, errs := db.CreateUser(data.Username, data.Password, data.RePassword, registrationDuration, string(signupInfoEnc)) if errs.HasError() { data.Errors = errs return c.Render(http.StatusOK, "standalone.signup", data) @@ -886,29 +897,29 @@ func signupHandler(c echo.Context) error { // Fuck with hellbanned users. New account also hellbanned if hasHBCookie { newUser.IsHellbanned = true - newUser.DoSave() + newUser.DoSave(db) } invitationToken := c.Param("invitationToken") if invitationToken != "" { - if invitation, err := database.GetUnusedInvitationByToken(invitationToken); err == nil { + if invitation, err := db.GetUnusedInvitationByToken(invitationToken); err == nil { invitation.InviteeUserID = newUser.ID - invitation.DoSave() + invitation.DoSave(db) } } // If more than 10 users were created in the past minute, auto disable signup for the website - if database.GetRecentUsersCount() > 10 { - settings := database.GetSettings() + if db.GetRecentUsersCount() > 10 { + settings := db.GetSettings() settings.SignupEnabled = false - settings.DoSave() + settings.DoSave(db) config.SignupEnabled.SetFalse() - if userNull, err := database.GetUserByUsername(config.NullUsername); err == nil { - database.NewAudit(userNull, fmt.Sprintf("auto turn off signup")) + if userNull, err := db.GetUserByUsername(config.NullUsername); err == nil { + db.NewAudit(userNull, fmt.Sprintf("auto turn off signup")) // Display message in chat txt := fmt.Sprintf("auto turn off registrations") - if err := database.CreateSysMsg(txt, txt, "", config.GeneralRoomID, userNull.ID); err != nil { + if err := db.CreateSysMsg(txt, txt, "", config.GeneralRoomID, userNull.ID); err != nil { logrus.Error(err) } } @@ -977,6 +988,7 @@ func ForgotPasswordHandler(c echo.Context) error { } func forgotPasswordHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data forgotPasswordData const ( usernameCaptchaStep = iota + 1 @@ -1012,7 +1024,7 @@ func forgotPasswordHandler(c echo.Context) error { data.ErrCaptcha = err.Error() return c.Render(http.StatusOK, forgotPasswordTmplName, data) } - user, err := database.GetUserByUsername(data.Username) + user, err := db.GetUserByUsername(data.Username) if err != nil { data.UsernameError = "no such user" return c.Render(http.StatusOK, forgotPasswordTmplName, data) @@ -1098,7 +1110,7 @@ func forgotPasswordHandler(c echo.Context) error { return c.Redirect(http.StatusFound, "/") } userID := item.UserID - user, err := database.GetUserByID(userID) + user, err := db.GetUserByID(userID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -1108,16 +1120,16 @@ func forgotPasswordHandler(c echo.Context) error { data.NewPassword = newPassword data.RePassword = rePassword - hashedPassword, err := database.NewPasswordValidator(newPassword).CompareWith(rePassword).Hash() + hashedPassword, err := database.NewPasswordValidator(db, newPassword).CompareWith(rePassword).Hash() if err != nil { data.ErrorNewPassword = err.Error() return c.Render(http.StatusOK, forgotPasswordTmplName, data) } - if err := user.ChangePassword(hashedPassword); err != nil { + if err := user.ChangePassword(db, hashedPassword); err != nil { logrus.Error(err) } - database.CreateSecurityLog(user.ID, database.PasswordRecoverySecurityLog) + db.CreateSecurityLog(user.ID, database.PasswordRecoverySecurityLog) partialRecoveryCache.Delete(token) c.SetCookie(hutils.DeleteCookie(hutils.WaitCookieName)) @@ -1134,12 +1146,13 @@ func NewsHandler(c echo.Context) error { } func ForumSearchHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data forumSearchData data.Search = c.QueryParam("search") data.AuthorFilter = c.QueryParam("author") if data.AuthorFilter != "" { - if err := database.DB.Raw(`select + if err := db.DB().Raw(`select t.*, u.username as author, u.chat_color as author_chat_color, @@ -1164,7 +1177,7 @@ where u.username = ? and t.is_club = 0 order by id desc limit 100`, data.AuthorF return c.Render(http.StatusOK, "forum-search", data) } - if err := database.DB.Raw(`select m.uuid, snippet(fts5_forum_messages,-1, '[', ']', '...', 10) as snippet, t.uuid as thread_uuid, t.name as thread_name, + if err := db.DB().Raw(`select m.uuid, snippet(fts5_forum_messages,-1, '[', ']', '...', 10) as snippet, t.uuid as thread_uuid, t.name as thread_name, u.username as author, u.chat_color as author_chat_color, u.chat_font as author_chat_font, @@ -1179,7 +1192,7 @@ where fts5_forum_messages match ? and t.is_club = 0 order by rank limit 100`, da logrus.Error(err) } - if err := database.DB.Raw(`select + if err := db.DB().Raw(`select t.*, u.username as author, u.chat_color as author_chat_color, @@ -1207,23 +1220,24 @@ where fts5_forum_threads match ? and t.is_club = 0 order by rank limit 100`, dat func LinksHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data linksData - data.Categories, _ = database.GetCategories() + data.Categories, _ = db.GetCategories() data.Search = c.QueryParam("search") filterCategory := c.QueryParam("category") if filterCategory != "" { if filterCategory == "uncategorized" { - database.DB.Raw(`SELECT l.* + db.DB().Raw(`SELECT l.* FROM links l LEFT JOIN links_categories_links cl ON cl.link_id = l.id WHERE cl.link_id IS NULL AND l.deleted_at IS NULL ORDER BY l.title COLLATE NOCASE ASC`).Scan(&data.Links) data.LinksCount = int64(len(data.Links)) } else { - database.DB.Raw(`SELECT l.* + db.DB().Raw(`SELECT l.* FROM links_categories_links cl INNER JOIN links l ON l.id = cl.link_id WHERE cl.category_id = (SELECT id FROM links_categories WHERE name = ?) AND l.deleted_at IS NULL @@ -1235,7 +1249,7 @@ ORDER BY l.title COLLATE NOCASE ASC`, filterCategory).Scan(&data.Links) if searchedURL, err := url.Parse(data.Search); err == nil { h := searchedURL.Scheme + "://" + searchedURL.Hostname() var l database.Link - query := database.DB + query := db.DB() if authUser.IsModerator() { query = query.Unscoped() } @@ -1245,7 +1259,7 @@ ORDER BY l.title COLLATE NOCASE ASC`, filterCategory).Scan(&data.Links) data.LinksCount = int64(len(data.Links)) } } else { - if err := database.DB.Raw(`select l.id, l.uuid, l.url, l.title, l.description + if err := db.DB().Raw(`select l.id, l.uuid, l.url, l.title, l.description from fts5_links l where fts5_links match ? ORDER BY rank, l.title COLLATE NOCASE ASC @@ -1255,7 +1269,7 @@ LIMIT 100`, data.Search).Scan(&data.Links).Error; err != nil { data.LinksCount = int64(len(data.Links)) } } else { - if err := database.DB.Table("links"). + if err := db.DB().Table("links"). Scopes(func(query *gorm.DB) *gorm.DB { data.CurrentPage, data.MaxPage, data.LinksCount, query = NewPaginator().Paginate(c, query) return query @@ -1276,7 +1290,7 @@ LIMIT 100`, data.Search).Scan(&data.Links).Error; err != nil { } // Get all mirrors for all links that we have var mirrors []database.LinksMirror - database.DB.Raw(`select * from links_mirrors where link_id in (?)`, linksIDs).Scan(&mirrors) + db.DB().Raw(`select * from links_mirrors where link_id in (?)`, linksIDs).Scan(&mirrors) // Put mirrors in links for _, m := range mirrors { if l, ok := linksCache[m.LinkID]; ok { @@ -1289,6 +1303,7 @@ LIMIT 100`, data.Search).Scan(&data.Links).Error; err != nil { func LinksDownloadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) fileName := "dkf_links.csv" // Captcha for bigger files @@ -1307,11 +1322,11 @@ func LinksDownloadHandler(c echo.Context) error { } // Keep track of user downloads - if _, err := database.CreateDownload(authUser.ID, fileName); err != nil { + if _, err := db.CreateDownload(authUser.ID, fileName); err != nil { logrus.Error(err) } - links, _ := database.GetLinks() + links, _ := db.GetLinks() by := make([]byte, 0) buf := bytes.NewBuffer(by) w := csv.NewWriter(buf) @@ -1326,9 +1341,10 @@ func LinksDownloadHandler(c echo.Context) error { func LinkPgpDownloadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) pgpID := utils.DoParseInt64(c.Param("linkPgpID")) - linkPgp, err := database.GetLinkPgpByID(pgpID) + linkPgp, err := db.GetLinkPgpByID(pgpID) if err != nil { return c.NoContent(http.StatusNotFound) } @@ -1336,7 +1352,7 @@ func LinkPgpDownloadHandler(c echo.Context) error { fileName := linkPgp.Title + ".asc" // Keep track of user downloads - if _, err := database.CreateDownload(authUser.ID, fileName); err != nil { + if _, err := db.CreateDownload(authUser.ID, fileName); err != nil { logrus.Error(err) } @@ -1349,20 +1365,21 @@ func LinksClaimInstructionsHandler(c echo.Context) error { } func LinkHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) shorthand := c.Param("shorthand") linkUUID := c.Param("linkUUID") var data linkData var err error if shorthand != "" { - data.Link, err = database.GetLinkByShorthand(shorthand) + data.Link, err = db.GetLinkByShorthand(shorthand) } else { - data.Link, err = database.GetLinkByUUID(linkUUID) + data.Link, err = db.GetLinkByUUID(linkUUID) } if err != nil { return c.Redirect(http.StatusFound, "/links") } - data.PgpKeys, _ = database.GetLinkPgps(data.Link.ID) - data.Mirrors, _ = database.GetLinkMirrors(data.Link.ID) + data.PgpKeys, _ = db.GetLinkPgps(data.Link.ID) + data.Mirrors, _ = db.GetLinkMirrors(data.Link.ID) return c.Render(http.StatusOK, "link", data) } @@ -1372,6 +1389,7 @@ type CsvLink struct { } func LinksUploadHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data linksUploadData if c.Request().Method == http.MethodPost { data.CsvStr = c.Request().PostFormValue("csv") @@ -1404,7 +1422,7 @@ func LinksUploadHandler(c echo.Context) error { return c.Render(http.StatusOK, "links-upload", data) } for _, csvLink := range csvLinks { - _, err := database.CreateLink(csvLink.URL, csvLink.Title, "", "") + _, err := db.CreateLink(csvLink.URL, csvLink.Title, "", "") if err != nil { logrus.Error(err) } @@ -1415,18 +1433,20 @@ func LinksUploadHandler(c echo.Context) error { } func LinksReindexHandler(c echo.Context) error { - if err := database.DB.Exec(`INSERT INTO fts5_links(fts5_links) VALUES('rebuild')`).Error; err != nil { + db := c.Get("database").(*database.DkfDB) + if err := db.DB().Exec(`INSERT INTO fts5_links(fts5_links) VALUES('rebuild')`).Error; err != nil { logrus.Error(err) } - database.DB.Exec(`delete from fts5_links where rowid in (select id from links where deleted_at is not null)`) + db.DB().Exec(`delete from fts5_links where rowid in (select id from links where deleted_at is not null)`) return c.Redirect(http.StatusFound, c.Request().Referer()) } func ForumReindexHandler(c echo.Context) error { - if err := database.DB.Exec(`INSERT INTO fts5_forum_threads(fts5_forum_threads) VALUES('rebuild')`).Error; err != nil { + db := c.Get("database").(*database.DkfDB) + if err := db.DB().Exec(`INSERT INTO fts5_forum_threads(fts5_forum_threads) VALUES('rebuild')`).Error; err != nil { logrus.Error(err) } - if err := database.DB.Exec(`INSERT INTO fts5_forum_messages(fts5_forum_messages) VALUES('rebuild')`).Error; err != nil { + if err := db.DB().Exec(`INSERT INTO fts5_forum_messages(fts5_forum_messages) VALUES('rebuild')`).Error; err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -1434,6 +1454,7 @@ func ForumReindexHandler(c echo.Context) error { func NewLinkHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.IsModerator() { return c.Redirect(http.StatusFound, "/") } @@ -1489,26 +1510,26 @@ func NewLinkHandler(c echo.Context) error { var categories []database.LinksCategory var tags []database.LinksTag for _, categoryStr := range categoriesStr { - category, _ := database.CreateLinksCategory(categoryStr) + category, _ := db.CreateLinksCategory(categoryStr) categories = append(categories, category) } for _, tagStr := range tagsStr { - tag, _ := database.CreateLinksTag(tagStr) + tag, _ := db.CreateLinksTag(tagStr) tags = append(tags, tag) } - link, err := database.CreateLink(data.Link, data.Title, data.Description, data.Shorthand) + link, err := db.CreateLink(data.Link, data.Title, data.Description, data.Shorthand) if err != nil { logrus.Error(err) data.ErrorLink = "failed to create link" return c.Render(http.StatusOK, "new-link", data) } for _, category := range categories { - _ = database.AddLinkCategory(link.ID, category.ID) + _ = db.AddLinkCategory(link.ID, category.ID) } for _, tag := range tags { - _ = database.AddLinkTag(link.ID, tag.ID) + _ = db.AddLinkTag(link.ID, tag.ID) } - database.NewAudit(*authUser, fmt.Sprintf("create link %s", link.URL)) + db.NewAudit(*authUser, fmt.Sprintf("create link %s", link.URL)) return c.Redirect(http.StatusFound, "/links") } return c.Render(http.StatusOK, "new-link", data) @@ -1516,35 +1537,37 @@ func NewLinkHandler(c echo.Context) error { func RestoreLinkHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.IsModerator() { return c.Redirect(http.StatusFound, "/") } linkUUID := c.Param("linkUUID") var link database.Link - if err := database.DB.Unscoped().First(&link, "uuid = ?", linkUUID).Error; err != nil { + if err := db.DB().Unscoped().First(&link, "uuid = ?", linkUUID).Error; err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } - database.NewAudit(*authUser, fmt.Sprintf("restore link %s", link.URL)) - database.DB.Unscoped().Model(&database.Link{}).Where("id = ?", link.ID).Update("deleted_at", nil) + db.NewAudit(*authUser, fmt.Sprintf("restore link %s", link.URL)) + db.DB().Unscoped().Model(&database.Link{}).Where("id = ?", link.ID).Update("deleted_at", nil) return c.Redirect(http.StatusFound, c.Request().Referer()) } func EditLinkHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.IsModerator() { return c.Redirect(http.StatusFound, "/") } linkUUID := c.Param("linkUUID") - link, err := database.GetLinkByUUID(linkUUID) + link, err := db.GetLinkByUUID(linkUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } - out, _ := database.GetLinkCategories(link.ID) + out, _ := db.GetLinkCategories(link.ID) categories := make([]string, 0) for _, el := range out { categories = append(categories, el.Name) } - out1, err := database.GetLinkTags(link.ID) + out1, err := db.GetLinkTags(link.ID) tags := make([]string, 0) for _, el := range out1 { tags = append(tags, el.Name) @@ -1559,15 +1582,15 @@ func EditLinkHandler(c echo.Context) error { } data.Categories = strings.Join(categories, ",") data.Tags = strings.Join(tags, ",") - data.Mirrors, _ = database.GetLinkMirrors(link.ID) - data.LinkPgps, _ = database.GetLinkPgps(link.ID) + data.Mirrors, _ = db.GetLinkMirrors(link.ID) + data.LinkPgps, _ = db.GetLinkPgps(link.ID) //data.Categories = link if c.Request().Method == http.MethodPost { formName := c.Request().PostFormValue("formName") if formName == "createLink" { - _ = database.DeleteLinkCategories(link.ID) - _ = database.DeleteLinkTags(link.ID) + _ = db.DeleteLinkCategories(link.ID) + _ = db.DeleteLinkTags(link.ID) // If link is signed, we can no longer edit the link URL if link.SignedCertificate == "" { @@ -1622,11 +1645,11 @@ func EditLinkHandler(c echo.Context) error { var categories []database.LinksCategory var tags []database.LinksTag for _, categoryStr := range categoriesStr { - category, _ := database.CreateLinksCategory(categoryStr) + category, _ := db.CreateLinksCategory(categoryStr) categories = append(categories, category) } for _, tagStr := range tagsStr { - tag, _ := database.CreateLinksTag(tagStr) + tag, _ := db.CreateLinksTag(tagStr) tags = append(tags, tag) } link.URL = data.Link @@ -1635,7 +1658,7 @@ func EditLinkHandler(c echo.Context) error { if data.Shorthand != "" { link.Shorthand = &data.Shorthand } - if err := database.DB.Save(&link).Error; err != nil { + if err := db.DB().Save(&link).Error; err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed: links.shorthand") { data.ErrorShorthand = "shorthand already used" } else { @@ -1644,12 +1667,12 @@ func EditLinkHandler(c echo.Context) error { return c.Render(http.StatusOK, "new-link", data) } for _, category := range categories { - _ = database.AddLinkCategory(link.ID, category.ID) + _ = db.AddLinkCategory(link.ID, category.ID) } for _, tag := range tags { - _ = database.AddLinkTag(link.ID, tag.ID) + _ = db.AddLinkTag(link.ID, tag.ID) } - database.NewAudit(*authUser, fmt.Sprintf("updated link %s", link.URL)) + db.NewAudit(*authUser, fmt.Sprintf("updated link %s", link.URL)) return c.Redirect(http.StatusFound, "/links") } else if formName == "createPgp" { @@ -1660,10 +1683,10 @@ func EditLinkHandler(c echo.Context) error { } data.PGPDescription = c.Request().PostFormValue("pgp_description") data.PGPPublicKey = c.Request().PostFormValue("pgp_public_key") - if _, err = database.CreateLinkPgp(link.ID, data.PGPTitle, data.PGPDescription, data.PGPPublicKey); err != nil { + if _, err = db.CreateLinkPgp(link.ID, data.PGPTitle, data.PGPDescription, data.PGPPublicKey); err != nil { logrus.Error(err) } - database.NewAudit(*authUser, fmt.Sprintf("create gpg for link %s", link.URL)) + db.NewAudit(*authUser, fmt.Sprintf("create gpg for link %s", link.URL)) return c.Redirect(http.StatusFound, c.Request().Referer()) } else if formName == "createMirror" { @@ -1672,10 +1695,10 @@ func EditLinkHandler(c echo.Context) error { data.ErrorMirrorLink = "invalid link" return c.Render(http.StatusOK, "new-link", data) } - if _, err = database.CreateLinkMirror(link.ID, data.MirrorLink); err != nil { + if _, err = db.CreateLinkMirror(link.ID, data.MirrorLink); err != nil { logrus.Error(err) } - database.NewAudit(*authUser, fmt.Sprintf("create mirror for link %s", link.URL)) + db.NewAudit(*authUser, fmt.Sprintf("create mirror for link %s", link.URL)) return c.Redirect(http.StatusFound, c.Request().Referer()) } return c.Redirect(http.StatusFound, "/links") @@ -1686,8 +1709,9 @@ func EditLinkHandler(c echo.Context) error { func ClaimLinkHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) linkUUID := c.Param("linkUUID") - link, err := database.GetLinkByUUID(linkUUID) + link, err := db.GetLinkByUUID(linkUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -1720,16 +1744,17 @@ func ClaimLinkHandler(c echo.Context) error { link.SignedCertificate = signedCert link.OwnerUserID = &authUser.ID - link.DoSave() + link.DoSave(db) return c.Redirect(http.StatusFound, "/links/"+link.UUID) } func ClaimDownloadCertificateLinkHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) linkUUID := c.Param("linkUUID") - link, err := database.GetLinkByUUID(linkUUID) + link, err := db.GetLinkByUUID(linkUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -1737,7 +1762,7 @@ func ClaimDownloadCertificateLinkHandler(c echo.Context) error { fileName := "certificate.txt" // Keep track of user downloads - if _, err := database.CreateDownload(authUser.ID, fileName); err != nil { + if _, err := db.CreateDownload(authUser.ID, fileName); err != nil { logrus.Error(err) } @@ -1747,7 +1772,8 @@ func ClaimDownloadCertificateLinkHandler(c echo.Context) error { func ClaimCertificateLinkHandler(c echo.Context) error { linkUUID := c.Param("linkUUID") - link, err := database.GetLinkByUUID(linkUUID) + db := c.Get("database").(*database.DkfDB) + link, err := db.GetLinkByUUID(linkUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -1756,35 +1782,38 @@ func ClaimCertificateLinkHandler(c echo.Context) error { func ForumHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data forumData - data.ForumCategories, _ = database.GetForumCategories() - data.ForumThreads, _ = database.GetPublicForumCategoryThreads(authUser.ID, 1) + data.ForumCategories, _ = db.GetForumCategories() + data.ForumThreads, _ = db.GetPublicForumCategoryThreads(authUser.ID, 1) return c.Render(http.StatusOK, "forum", data) } func ForumCategoryHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) categorySlug := c.Param("categorySlug") var data forumCategoryData - category, err := database.GetForumCategoryBySlug(categorySlug) + category, err := db.GetForumCategoryBySlug(categorySlug) if err != nil { return c.Redirect(http.StatusFound, "/forum") } - data.ForumThreads, _ = database.GetPublicForumCategoryThreads(authUser.ID, category.ID) + data.ForumThreads, _ = db.GetPublicForumCategoryThreads(authUser.ID, category.ID) return c.Render(http.StatusOK, "forum", data) } func ThreadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) - thread, err := database.GetForumThreadByUUID(threadUUID) + thread, err := db.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } var data threadData data.Thread = thread - if err := database.DB. + if err := db.DB(). Table("forum_messages"). Where("thread_id = ?", thread.ID). Scopes(func(query *gorm.DB) *gorm.DB { @@ -1799,9 +1828,9 @@ func ThreadHandler(c echo.Context) error { } if authUser != nil { - data.IsSubscribed = database.IsUserSubscribedToForumThread(authUser.ID, thread.ID) + data.IsSubscribed = db.IsUserSubscribedToForumThread(authUser.ID, thread.ID) // Update read record - database.UpdateForumReadRecord(authUser.ID, thread.ID) + db.UpdateForumReadRecord(authUser.ID, thread.ID) } return c.Render(http.StatusOK, "thread", data) @@ -1809,8 +1838,9 @@ func ThreadHandler(c echo.Context) error { func GistHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) gistUUID := c.Param("gistUUID") - gist, err := database.GetGistByUUID(gistUUID) + gist, err := db.GetGistByUUID(gistUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -1829,7 +1859,7 @@ func GistHandler(c echo.Context) error { if gist.Password != "" { hutils.DeleteGistCookie(c, gist.UUID) } - if err := database.DB.Delete(&gist).Error; err != nil { + if err := db.DB().Delete(&gist).Error; err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/") @@ -1885,12 +1915,13 @@ func ThreadReplyHandler(c echo.Context) error { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"}) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"}) } threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) - thread, err := database.GetForumThreadByUUID(threadUUID) + thread, err := db.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -1904,22 +1935,22 @@ func ThreadReplyHandler(c echo.Context) error { return c.Render(http.StatusOK, "thread-reply", data) } if isForumSpam(data.Message) { - database.NewAudit(*authUser, fmt.Sprintf("spam forum thread reply %s (#%d)", authUser.Username, authUser.ID)) + db.NewAudit(*authUser, fmt.Sprintf("spam forum thread reply %s (#%d)", authUser.Username, authUser.ID)) authUser.CanUseForum = false - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, "/") } message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID) message.IsSigned = message.ValidateSignature(authUser.GPGPublicKey) - if err := database.DB.Create(&message).Error; err != nil { + if err := db.DB().Create(&message).Error; err != nil { logrus.Error(err) } // Send notifications - subs, _ := database.GetUsersSubscribedToForumThread(thread.ID) + subs, _ := db.GetUsersSubscribedToForumThread(thread.ID) for _, sub := range subs { if sub.UserID != authUser.ID { msg := fmt.Sprintf(`New reply in thread &quot;<a href="/t/%s#%s">%s</a>&quot;`, thread.UUID, message.UUID, thread.Name) - database.CreateNotification(msg, sub.UserID) + db.CreateNotification(msg, sub.UserID) } } return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID)) @@ -1933,11 +1964,12 @@ func ThreadDeleteMessageHandler(c echo.Context) error { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"}) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"}) } messageUUID := database.ForumMessageUUID(c.Param("messageUUID")) - msg, err := database.GetForumMessageByUUID(messageUUID) + msg, err := db.GetForumMessageByUUID(messageUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -1951,14 +1983,14 @@ func ThreadDeleteMessageHandler(c echo.Context) error { } var data deleteForumMessageData - data.Thread, err = database.GetForumThreadByID(msg.ThreadID) + data.Thread, err = db.GetForumThreadByID(msg.ThreadID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } data.Message = msg if c.Request().Method == http.MethodPost { - if err := database.DeleteForumMessageByID(msg.ID); err != nil { + if err := db.DeleteForumMessageByID(msg.ID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/t/"+string(data.Thread.UUID)) @@ -1969,8 +2001,9 @@ func ThreadDeleteMessageHandler(c echo.Context) error { func LinkDeleteHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) linkUUID := c.Param("linkUUID") - link, err := database.GetLinkByUUID(linkUUID) + link, err := db.GetLinkByUUID(linkUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -1983,8 +2016,8 @@ func LinkDeleteHandler(c echo.Context) error { data.Link = link if c.Request().Method == http.MethodPost { - database.NewAudit(*authUser, fmt.Sprintf("deleted link %s", link.URL)) - if err := database.DeleteLinkByID(link.ID); err != nil { + db.NewAudit(*authUser, fmt.Sprintf("deleted link %s", link.URL)) + if err := db.DeleteLinkByID(link.ID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/links") @@ -1995,12 +2028,13 @@ func LinkDeleteHandler(c echo.Context) error { func LinkPgpDeleteHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) linkPgpID := utils.DoParseInt64(c.Param("linkPgpID")) - linkPgp, err := database.GetLinkPgpByID(linkPgpID) + linkPgp, err := db.GetLinkPgpByID(linkPgpID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } - link, err := database.GetLinkByID(linkPgp.LinkID) + link, err := db.GetLinkByID(linkPgp.LinkID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -2014,7 +2048,7 @@ func LinkPgpDeleteHandler(c echo.Context) error { data.LinkPgp = linkPgp if c.Request().Method == http.MethodPost { - if err := database.DeleteLinkPgpByID(linkPgp.ID); err != nil { + if err := db.DeleteLinkPgpByID(linkPgp.ID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/links/"+link.UUID+"/edit") @@ -2025,12 +2059,13 @@ func LinkPgpDeleteHandler(c echo.Context) error { func LinkMirrorDeleteHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) linkMirrorID := utils.DoParseInt64(c.Param("linkMirrorID")) - linkMirror, err := database.GetLinkMirrorByID(linkMirrorID) + linkMirror, err := db.GetLinkMirrorByID(linkMirrorID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } - link, err := database.GetLinkByID(linkMirror.LinkID) + link, err := db.GetLinkByID(linkMirror.LinkID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -2044,7 +2079,7 @@ func LinkMirrorDeleteHandler(c echo.Context) error { data.LinkMirror = linkMirror if c.Request().Method == http.MethodPost { - if err := database.DeleteLinkMirrorByID(linkMirror.ID); err != nil { + if err := db.DeleteLinkMirrorByID(linkMirror.ID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/links/"+link.UUID+"/edit") @@ -2054,6 +2089,7 @@ func LinkMirrorDeleteHandler(c echo.Context) error { } func ThreadEditHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) if config.ForumEnabled.IsFalse() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"}) } @@ -2065,7 +2101,7 @@ func ThreadEditHandler(c echo.Context) error { return c.Redirect(http.StatusFound, c.Request().Referer()) } threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) - thread, err := database.GetForumThreadByUUID(threadUUID) + thread, err := db.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -2074,7 +2110,7 @@ func ThreadEditHandler(c echo.Context) error { if c.Request().Method == http.MethodPost { thread.CategoryID = database.ForumCategoryID(utils.DoParseInt64(c.Request().PostFormValue("category_id"))) - thread.DoSave() + thread.DoSave(db) return c.Redirect(http.StatusFound, "/forum") } @@ -2086,11 +2122,12 @@ func ThreadDeleteHandler(c echo.Context) error { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"}) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"}) } threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) - thread, err := database.GetForumThreadByUUID(threadUUID) + thread, err := db.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -2103,7 +2140,7 @@ func ThreadDeleteHandler(c echo.Context) error { data.Thread = thread if c.Request().Method == http.MethodPost { - if err := database.DeleteForumThreadByID(thread.ID); err != nil { + if err := db.DeleteForumThreadByID(thread.ID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, "/forum") @@ -2117,16 +2154,17 @@ func ThreadEditMessageHandler(c echo.Context) error { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"}) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"}) } threadUUID := database.ForumThreadUUID(c.Param("threadUUID")) messageUUID := database.ForumMessageUUID(c.Param("messageUUID")) - thread, err := database.GetForumThreadByUUID(threadUUID) + thread, err := db.GetForumThreadByUUID(threadUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } - msg, err := database.GetForumMessageByUUID(messageUUID) + msg, err := db.GetForumMessageByUUID(messageUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -2145,14 +2183,14 @@ func ThreadEditMessageHandler(c echo.Context) error { return c.Render(http.StatusOK, "thread-reply", data) } if isForumSpam(data.Message) { - database.NewAudit(*authUser, fmt.Sprintf("spam forum edit msg %s (#%d)", authUser.Username, authUser.ID)) + db.NewAudit(*authUser, fmt.Sprintf("spam forum edit msg %s (#%d)", authUser.Username, authUser.ID)) authUser.CanUseForum = false - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, "/") } msg.Message = data.Message msg.IsSigned = msg.ValidateSignature(authUser.GPGPublicKey) - msg.DoSave() + msg.DoSave(db) return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID)) } @@ -2163,8 +2201,9 @@ func ThreadRawMessageHandler(c echo.Context) error { if config.ForumEnabled.IsFalse() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"}) } + db := c.Get("database").(*database.DkfDB) messageUUID := database.ForumMessageUUID(c.Param("messageUUID")) - msg, err := database.GetForumMessageByUUID(messageUUID) + msg, err := db.GetForumMessageByUUID(messageUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -2183,6 +2222,7 @@ func NewThreadHandler(c echo.Context) error { return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"}) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"}) } @@ -2200,17 +2240,17 @@ func NewThreadHandler(c echo.Context) error { return c.Render(http.StatusOK, "new-thread", data) } if isForumSpam(data.Message) { - database.NewAudit(*authUser, fmt.Sprintf("spam forum new thread %s (#%d)", authUser.Username, authUser.ID)) + db.NewAudit(*authUser, fmt.Sprintf("spam forum new thread %s (#%d)", authUser.Username, authUser.ID)) authUser.CanUseForum = false - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, "/") } thread := database.MakeForumThread(data.ThreadName, authUser.ID, 1) - database.DB.Create(&thread) + db.DB().Create(&thread) message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID) message.IsSigned = message.ValidateSignature(authUser.GPGPublicKey) - database.DB.Create(&message) - _ = database.SubscribeToForumThread(authUser.ID, thread.ID) + db.DB().Create(&message) + _ = db.SubscribeToForumThread(authUser.ID, thread.ID) return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID)) } @@ -2218,9 +2258,10 @@ func NewThreadHandler(c echo.Context) error { } func VipHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) var data vipData data.ActiveTab = "home" - data.UsersBadges, _ = database.GetUsersBadges() + data.UsersBadges, _ = db.GetUsersBadges() return c.Render(http.StatusOK, "vip.home", data) } @@ -2250,8 +2291,9 @@ func VipProjectsMalwareDropperHandler(c echo.Context) error { func RoomsHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data roomsData - data.Rooms, _ = database.GetListedChatRooms(authUser.ID) + data.Rooms, _ = db.GetListedChatRooms(authUser.ID) return c.Render(http.StatusOK, "rooms", data) } @@ -2286,9 +2328,10 @@ func ChatHelpHandler(c echo.Context) error { func RoomChatSettingsHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data roomChatSettingsData roomName := c.Param("roomName") - room, err := database.GetChatRoomByName(roomName) + room, err := db.GetChatRoomByName(roomName) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -2306,6 +2349,7 @@ func RoomChatSettingsHandler(c echo.Context) error { func ChatCreateRoomHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data chatCreateRoomData data.CaptchaID, data.CaptchaImg = captcha.New() data.IsEphemeral = true @@ -2328,7 +2372,7 @@ func ChatCreateRoomHandler(c echo.Context) error { if data.Password != "" { passwordHash = database.GetRoomPasswordHash(data.Password) } - if _, err := database.CreateRoom(data.RoomName, passwordHash, authUser.ID, data.IsListed); err != nil { + if _, err := db.CreateRoom(data.RoomName, passwordHash, authUser.ID, data.IsListed); err != nil { data.Error = err.Error() return c.Render(http.StatusOK, "chat-create-room", data) } @@ -2348,8 +2392,9 @@ func ShopHandler(c echo.Context) error { return base64.StdEncoding.EncodeToString(buf.Bytes()) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data shopData - invoice, err := database.CreateXmrInvoice(authUser.ID, 1) + invoice, err := db.CreateXmrInvoice(authUser.ID, 1) if err != nil { logrus.Error(err) } @@ -2411,12 +2456,13 @@ func SettingsChatHandler(c echo.Context) error { func SettingsChatPMHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsChatPMData data.ActiveTab = "chat" data.PmMode = authUser.PmMode data.BlockNewUsersPm = authUser.BlockNewUsersPm - data.WhitelistedUsers, _ = database.GetPmWhitelistedUsers(authUser.ID) - data.BlacklistedUsers, _ = database.GetPmBlacklistedUsers(authUser.ID) + data.WhitelistedUsers, _ = db.GetPmWhitelistedUsers(authUser.ID) + data.BlacklistedUsers, _ = db.GetPmBlacklistedUsers(authUser.ID) if c.Request().Method == http.MethodGet { return c.Render(http.StatusOK, "settings.chat-pm", data) @@ -2427,48 +2473,49 @@ func SettingsChatPMHandler(c echo.Context) error { if formName == "addWhitelist" { data.AddWhitelist = strings.TrimSpace(c.Request().PostFormValue("username")) - user, err := database.GetUserByUsername(data.AddWhitelist) + user, err := db.GetUserByUsername(data.AddWhitelist) if err != nil { data.Error = "username not found" return c.Render(http.StatusOK, "settings.chat-pm", data) } - database.AddWhitelistedUser(authUser.ID, user.ID) + db.AddWhitelistedUser(authUser.ID, user.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } else if formName == "rmWhitelist" { userID := dutils.DoParseUserID(c.Request().PostFormValue("userID")) - database.RmWhitelistedUser(authUser.ID, userID) + db.RmWhitelistedUser(authUser.ID, userID) return c.Redirect(http.StatusFound, c.Request().Referer()) } else if formName == "addBlacklist" { data.AddBlacklist = strings.TrimSpace(c.Request().PostFormValue("username")) - user, err := database.GetUserByUsername(data.AddBlacklist) + user, err := db.GetUserByUsername(data.AddBlacklist) if err != nil { data.Error = "username not found" return c.Render(http.StatusOK, "settings.chat-pm", data) } - database.AddBlacklistedUser(authUser.ID, user.ID) + db.AddBlacklistedUser(authUser.ID, user.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } else if formName == "rmBlacklist" { userID := dutils.DoParseUserID(c.Request().PostFormValue("userID")) - database.RmBlacklistedUser(authUser.ID, userID) + db.RmBlacklistedUser(authUser.ID, userID) return c.Redirect(http.StatusFound, c.Request().Referer()) } data.PmMode = utils.Clamp(utils.DoParseInt64(c.Request().PostFormValue("pm_mode")), 0, 1) authUser.BlockNewUsersPm = utils.DoParseBool(c.Request().PostFormValue("block_new_users_pm")) authUser.PmMode = data.PmMode - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, c.Request().Referer()) } func SettingsChatIgnoreHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsChatIgnoreData data.ActiveTab = "chat" data.PmMode = authUser.PmMode - data.IgnoredUsers, _ = database.GetIgnoredUsers(authUser.ID) + data.IgnoredUsers, _ = db.GetIgnoredUsers(authUser.ID) if c.Request().Method == http.MethodGet { return c.Render(http.StatusOK, "settings.chat-ignore", data) @@ -2479,31 +2526,32 @@ func SettingsChatIgnoreHandler(c echo.Context) error { if formName == "addIgnored" { data.AddIgnored = strings.TrimSpace(c.Request().PostFormValue("username")) - user, err := database.GetUserByUsername(data.AddIgnored) + user, err := db.GetUserByUsername(data.AddIgnored) if err != nil { data.Error = "username not found" return c.Render(http.StatusOK, "settings.chat-ignore", data) } - database.IgnoreUser(authUser.ID, user.ID) + db.IgnoreUser(authUser.ID, user.ID) return c.Redirect(http.StatusFound, c.Request().Referer()) } else if formName == "rmIgnored" { userID := dutils.DoParseUserID(c.Request().PostFormValue("userID")) - database.UnIgnoreUser(authUser.ID, userID) + db.UnIgnoreUser(authUser.ID, userID) return c.Redirect(http.StatusFound, c.Request().Referer()) } data.PmMode = utils.Clamp(utils.DoParseInt64(c.Request().PostFormValue("pm_mode")), 0, 1) authUser.PmMode = data.PmMode - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, c.Request().Referer()) } func SettingsChatSnippetsHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsChatSnippetsData data.ActiveTab = "snippets" - data.Snippets, _ = database.GetUserSnippets(authUser.ID) + data.Snippets, _ = db.GetUserSnippets(authUser.ID) if c.Request().Method == http.MethodGet { return c.Render(http.StatusOK, "settings.chat-snippets", data) @@ -2531,7 +2579,7 @@ func SettingsChatSnippetsHandler(c echo.Context) error { data.Error = "text must be 1-1000 characters" return c.Render(http.StatusOK, "settings.chat-snippets", data) } - if _, err := database.CreateSnippet(authUser.ID, data.Name, data.Text); err != nil { + if _, err := db.CreateSnippet(authUser.ID, data.Name, data.Text); err != nil { data.Error = err.Error() return c.Render(http.StatusOK, "settings.chat-snippets", data) } @@ -2539,7 +2587,7 @@ func SettingsChatSnippetsHandler(c echo.Context) error { } else if formName == "rmSnippet" { snippetName := c.Request().PostFormValue("snippetName") - database.DeleteSnippet(authUser.ID, snippetName) + db.DeleteSnippet(authUser.ID, snippetName) return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -2548,9 +2596,10 @@ func SettingsChatSnippetsHandler(c echo.Context) error { func SettingsUploadsHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsUploadsData data.ActiveTab = "uploads" - data.Files, _ = database.GetUserUploads(authUser.ID) + data.Files, _ = db.GetUserUploads(authUser.ID) for _, f := range data.Files { data.TotalSize += f.FileSize } @@ -2563,14 +2612,14 @@ func SettingsUploadsHandler(c echo.Context) error { formName := c.FormValue("formName") if formName == "deleteUpload" { fileName := c.Request().PostFormValue("file_name") - file, err := database.GetUploadByFileName(fileName) + file, err := db.GetUploadByFileName(fileName) if err != nil { return c.Redirect(http.StatusFound, "/") } if authUser.ID != file.UserID { return c.Redirect(http.StatusFound, "/") } - if err := file.Delete(); err != nil { + if err := file.Delete(db); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) @@ -2580,16 +2629,17 @@ func SettingsUploadsHandler(c echo.Context) error { func SettingsPublicNotesHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"}) } var data settingsPublicNotesData data.ActiveTab = "notes" - data.Notes, _ = database.GetUserPublicNotes(authUser.ID) + data.Notes, _ = db.GetUserPublicNotes(authUser.ID) if c.Request().Method == http.MethodPost { notes := c.Request().PostFormValue("public_notes") - if err := database.SetUserPublicNotes(authUser.ID, notes); err != nil { + if err := db.SetUserPublicNotes(authUser.ID, notes); err != nil { data.Error = err.Error() return c.Render(http.StatusOK, "settings.public-notes", data) } @@ -2601,18 +2651,19 @@ func SettingsPublicNotesHandler(c echo.Context) error { func SettingsPrivateNotesHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUseForumFn() { return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"}) } var data settingsPrivateNotesData data.ActiveTab = "notes" if !authUser.IsUnderDuress { - data.Notes, _ = database.GetUserPrivateNotes(authUser.ID) + data.Notes, _ = db.GetUserPrivateNotes(authUser.ID) } if c.Request().Method == http.MethodPost { notes := c.Request().PostFormValue("private_notes") - if err := database.SetUserPrivateNotes(authUser.ID, notes); err != nil { + if err := db.SetUserPrivateNotes(authUser.ID, notes); err != nil { data.Error = err.Error() return c.Render(http.StatusOK, "settings.private-notes", data) } @@ -2625,14 +2676,15 @@ func SettingsPrivateNotesHandler(c echo.Context) error { func SettingsInboxHandler(c echo.Context) error { authCookie, _ := c.Cookie(hutils.AuthCookieName) authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsInboxData data.ActiveTab = "inbox" // Do not fetch inboxes & notifications if logged in under duress if !authUser.IsUnderDuress { global.DeleteUserNotificationCount(authUser.ID, authCookie.Value) - data.ChatMessages, _ = database.GetUserChatInboxMessages(authUser.ID) - data.Notifications, _ = database.GetUserNotifications(authUser.ID) - data.SessionNotifications, _ = database.GetUserSessionNotifications(authCookie.Value) + data.ChatMessages, _ = db.GetUserChatInboxMessages(authUser.ID) + data.Notifications, _ = db.GetUserNotifications(authUser.ID) + data.SessionNotifications, _ = db.GetUserSessionNotifications(authCookie.Value) } for _, m := range data.ChatMessages { data.Notifs = append(data.Notifs, InboxTmp{IsNotif: false, ChatInboxMessage: m}) @@ -2669,28 +2721,31 @@ func SettingsInboxHandler(c echo.Context) error { func SettingsInboxSentHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsInboxSentData data.ActiveTab = "inbox" // Do not fetch inboxes & notifications if logged in under duress if !authUser.IsUnderDuress { - data.ChatInboxSent, _ = database.GetUserChatInboxMessagesSent(authUser.ID) + data.ChatInboxSent, _ = db.GetUserChatInboxMessagesSent(authUser.ID) } return c.Render(http.StatusOK, "settings.inbox-sent", data) } func SettingsSecurityHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsSecurityData data.ActiveTab = "security" - data.Logs, _ = database.GetSecurityLogs(authUser.ID) + data.Logs, _ = db.GetSecurityLogs(authUser.ID) return c.Render(http.StatusOK, "settings.security", data) } func SettingsSessionsHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsSessionsData data.ActiveTab = "sessions" - sessions := database.GetActiveUserSessions(authUser.ID) + sessions := db.GetActiveUserSessions(authUser.ID) authCookie, _ := c.Cookie(hutils.AuthCookieName) for _, session := range sessions { s := WrapperSession{Session: session} @@ -2703,10 +2758,10 @@ func SettingsSessionsHandler(c echo.Context) error { if c.Request().Method == http.MethodPost { formName := c.Request().PostFormValue("formName") if formName == "revoke_all_other_sessions" { - _ = database.DeleteUserOtherSessions(authUser.ID, authCookie.Value) + _ = db.DeleteUserOtherSessions(authUser.ID, authCookie.Value) } else { sessionToken := c.Request().PostFormValue("sessionToken") - _ = database.DeleteUserSessionByToken(authUser.ID, sessionToken) + _ = db.DeleteUserSessionByToken(authUser.ID, sessionToken) } return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -2766,6 +2821,7 @@ func SettingsPasswordHandler(c echo.Context) error { func SettingsSecretPhraseHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsSecretPhraseData data.ActiveTab = "secretPhrase" data.SecretPhrase = string(authUser.SecretPhrase) @@ -2790,42 +2846,44 @@ func SettingsSecretPhraseHandler(c echo.Context) error { return c.Render(http.StatusOK, "settings.secret-phrase", data) } - if !authUser.CheckPassword(currentPassword) { + if !authUser.CheckPassword(db, currentPassword) { data.ErrorCurrentPassword = "Invalid password" return c.Render(http.StatusOK, "settings.secret-phrase", data) } authUser.SecretPhrase = database.EncryptedString(secretPhrase) - authUser.DoSave() + authUser.DoSave(db) - database.CreateSecurityLog(authUser.ID, database.ChangeSecretPhraseSecurityLog) + db.CreateSecurityLog(authUser.ID, database.ChangeSecretPhraseSecurityLog) return c.Render(http.StatusFound, "flash", FlashResponse{Message: "Secret phrase changed successfully", Redirect: c.Request().Referer()}) } // SettingsInvitationsHandler ... func SettingsInvitationsHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsInvitationsData data.ActiveTab = "invitations" data.DkfOnion = config.DkfOnion if c.Request().Method == http.MethodPost { - if _, err := database.CreateInvitation(authUser.ID); err != nil { + if _, err := db.CreateInvitation(authUser.ID); err != nil { logrus.Error(err) } return c.Redirect(http.StatusFound, c.Request().Referer()) } - data.Invitations, _ = database.GetUserUnusedInvitations(authUser.ID) + data.Invitations, _ = db.GetUserUnusedInvitations(authUser.ID) return c.Render(http.StatusOK, "settings.invitations", data) } // SettingsWebsiteHandler ... func SettingsWebsiteHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data settingsWebsiteData data.ActiveTab = "website" - settings := database.GetSettings() + settings := db.GetSettings() data.SignupEnabled = settings.SignupEnabled data.ForumEnabled = settings.ForumEnabled data.SilentSelfKick = settings.SilentSelfKick @@ -2833,11 +2891,11 @@ func SettingsWebsiteHandler(c echo.Context) error { settings.SignupEnabled = utils.DoParseBool(c.Request().PostFormValue("signupEnabled")) settings.ForumEnabled = utils.DoParseBool(c.Request().PostFormValue("forumEnabled")) settings.SilentSelfKick = utils.DoParseBool(c.Request().PostFormValue("silentSelfKick")) - settings.DoSave() + settings.DoSave(db) config.SignupEnabled.Store(settings.SignupEnabled) config.ForumEnabled.Store(settings.ForumEnabled) config.SilentSelfKick.Store(settings.SilentSelfKick) - database.NewAudit(*authUser, fmt.Sprintf("website settings, signup: %t, forum: %t, sk: %t", + db.NewAudit(*authUser, fmt.Sprintf("website settings, signup: %t, forum: %t, sk: %t", settings.SignupEnabled, settings.ForumEnabled, settings.SilentSelfKick)) return c.Redirect(http.StatusFound, c.Request().Referer()) } @@ -2847,6 +2905,7 @@ func SettingsWebsiteHandler(c echo.Context) error { func editProfileForm(c echo.Context, data settingsAccountData) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) email := c.Request().PostFormValue("email") website := c.Request().PostFormValue("website") @@ -2871,13 +2930,14 @@ func editProfileForm(c echo.Context, data settingsAccountData) error { authUser.Email = data.Email authUser.LastSeenPublic = data.LastSeenPublic authUser.TerminateAllSessionsOnLogout = data.TerminateAllSessionsOnLogout - authUser.DoSave() + authUser.DoSave(db) return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Profile changed successfully", Redirect: c.Request().Referer()}) } func changeAvatarForm(c echo.Context, data settingsAccountData) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanUpload() { data.ErrorAvatar = hutils.AccountTooYoungErr.Error() return c.Render(http.StatusOK, "settings.account", data) @@ -2935,12 +2995,13 @@ func changeAvatarForm(c echo.Context, data settingsAccountData) error { } authUser.SetAvatar(fileBytes) - authUser.DoSave() + authUser.DoSave(db) return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Avatar changed successfully", Redirect: c.Request().Referer()}) } func changeUsernameForm(c echo.Context, data settingsAccountData) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !authUser.CanChangeUsername { data.ErrorUsername = "Not allowed to change your username" return c.Render(http.StatusOK, "settings.account", data) @@ -2954,25 +3015,26 @@ func changeUsernameForm(c echo.Context, data settingsAccountData) error { return c.Render(http.StatusOK, "settings.account", data) } - if err := database.CanRenameTo(authUser.Username, username); err != nil { + if err := db.CanRenameTo(authUser.Username, username); err != nil { data.ErrorUsername = err.Error() return c.Render(http.StatusOK, "settings.account", data) } managers.ActiveUsers.RemoveUser(authUser.ID) authUser.Username = username - if err := database.DB.Save(authUser).Error; err != nil { + if err := db.DB().Save(authUser).Error; err != nil { logrus.Error(err) data.ErrorUsername = err.Error() return c.Render(http.StatusOK, "settings.account", data) } - database.CreateSecurityLog(authUser.ID, database.UsernameChangedSecurityLog) + db.CreateSecurityLog(authUser.ID, database.UsernameChangedSecurityLog) return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Username changed successfully", Redirect: c.Request().Referer()}) } func changeSettingsForm(c echo.Context, data settingsChatData) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) data.RefreshRate = utils.Clamp(utils.DoParseInt64(c.Request().PostFormValue("refresh_rate")), 5, 60) data.ChatColor = c.Request().PostFormValue("chat_color") @@ -3051,7 +3113,7 @@ func changeSettingsForm(c echo.Context, data settingsChatData) error { authUser.DisplayHellbanButton = data.DisplayHellbanButton } - if err := database.DB.Save(authUser).Error; err != nil { + if err := db.DB().Save(authUser).Error; err != nil { logrus.Error(err) data.Error = err.Error() return c.Render(http.StatusOK, "settings.chat", data) @@ -3062,6 +3124,7 @@ func changeSettingsForm(c echo.Context, data settingsChatData) error { func changePasswordForm(c echo.Context, data settingsPasswordData) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) oldPassword := c.Request().PostFormValue("oldPassword") newPassword := c.Request().PostFormValue("newPassword") rePassword := c.Request().PostFormValue("rePassword") @@ -3075,22 +3138,22 @@ func changePasswordForm(c echo.Context, data settingsPasswordData) error { } if len(newPassword) > 0 || len(rePassword) > 0 { - hashedPassword, err := database.NewPasswordValidator(newPassword).CompareWith(rePassword).Hash() + hashedPassword, err := database.NewPasswordValidator(db, newPassword).CompareWith(rePassword).Hash() if err != nil { data.ErrorNewPassword = err.Error() return c.Render(http.StatusOK, "settings.password", data) } - if !authUser.CheckPassword(oldPassword) { + if !authUser.CheckPassword(db, oldPassword) { data.ErrorOldPassword = "Invalid password" return c.Render(http.StatusOK, "settings.password", data) } - if err := authUser.ChangePassword(hashedPassword); err != nil { + if err := authUser.ChangePassword(db, hashedPassword); err != nil { logrus.Error(err) } c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName)) - database.CreateSecurityLog(authUser.ID, database.ChangePasswordSecurityLog) + db.CreateSecurityLog(authUser.ID, database.ChangePasswordSecurityLog) return c.Render(http.StatusFound, "flash", FlashResponse{Message: "Password changed successfully", Redirect: "/login"}) } @@ -3099,6 +3162,7 @@ func changePasswordForm(c echo.Context, data settingsPasswordData) error { func changeDuressPasswordForm(c echo.Context, data settingsPasswordData) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) oldDuressPassword := c.Request().PostFormValue("oldDuressPassword") newDuressPassword := c.Request().PostFormValue("newDuressPassword") reDuressPassword := c.Request().PostFormValue("reDuressPassword") @@ -3112,22 +3176,22 @@ func changeDuressPasswordForm(c echo.Context, data settingsPasswordData) error { } if len(newDuressPassword) > 0 || len(reDuressPassword) > 0 { - hashedPassword, err := database.NewPasswordValidator(newDuressPassword).CompareWith(reDuressPassword).Hash() + hashedPassword, err := database.NewPasswordValidator(db, newDuressPassword).CompareWith(reDuressPassword).Hash() if err != nil { data.ErrorNewDuressPassword = err.Error() return c.Render(http.StatusOK, "settings.password", data) } - if !authUser.CheckPassword(oldDuressPassword) { + if !authUser.CheckPassword(db, oldDuressPassword) { data.ErrorOldDuressPassword = "Invalid password" return c.Render(http.StatusOK, "settings.password", data) } - if err := authUser.ChangeDuressPassword(hashedPassword); err != nil { + if err := authUser.ChangeDuressPassword(db, hashedPassword); err != nil { logrus.Error(err) } c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName)) - database.CreateSecurityLog(authUser.ID, database.ChangeDuressPasswordSecurityLog) + db.CreateSecurityLog(authUser.ID, database.ChangeDuressPasswordSecurityLog) return c.Render(http.StatusFound, "flash", FlashResponse{Message: "Password changed successfully", Redirect: "/login"}) } @@ -3136,9 +3200,10 @@ func changeDuressPasswordForm(c echo.Context, data settingsPasswordData) error { func ChatDeleteHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data chatDeleteData roomName := c.Param("roomName") - room, err := database.GetChatRoomByName(roomName) + room, err := db.GetChatRoomByName(roomName) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -3151,7 +3216,7 @@ func ChatDeleteHandler(c echo.Context) error { if room.IsProtected() { hutils.DeleteRoomCookie(c, int64(room.ID)) } - database.DeleteChatRoomByID(room.ID) + db.DeleteChatRoomByID(room.ID) return c.Redirect(http.StatusFound, "/chat") } @@ -3160,10 +3225,11 @@ func ChatDeleteHandler(c echo.Context) error { func ChatArchiveHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data chatArchiveData data.DateFormat = authUser.GetDateFormat() roomName := c.Param("roomName") - room, err := database.GetChatRoomByName(roomName) + room, err := db.GetChatRoomByName(roomName) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -3176,7 +3242,7 @@ func ChatArchiveHandler(c echo.Context) error { data.Room = room if data.UUID != "" { - msg, err := database.GetRoomChatMessageByUUID(room.ID, data.UUID) + msg, err := db.GetRoomChatMessageByUUID(room.ID, data.UUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -3207,7 +3273,7 @@ func ChatArchiveHandler(c echo.Context) error { ) ORDER BY id DESC` args = append(args, msg.ID, nbMsg) args = append(args, args...) - database.DB.Raw(raw, args...).Scan(&data.Messages) + db.DB().Raw(raw, args...).Scan(&data.Messages) // Manually do Preload("Room") for _, m := range data.Messages { @@ -3222,7 +3288,7 @@ func ChatArchiveHandler(c echo.Context) error { usersIDs.Insert(*m.ToUserID) } } - users, _ := database.GetUsersByID(usersIDs.ToArray()) + users, _ := db.GetUsersByID(usersIDs.ToArray()) usersMap := make(map[database.UserID]database.User) for _, u := range users { usersMap[u.ID] = u @@ -3240,7 +3306,7 @@ func ChatArchiveHandler(c echo.Context) error { //--- </ Manually do a Preload("User") Preload("ToUser") > --- } else { - if err := database.DB.Table("chat_messages"). + if err := db.DB().Table("chat_messages"). Where("room_id = ? AND group_id IS NULL AND (to_user_id is null OR to_user_id = ? OR user_id = ?)", room.ID, authUser.ID, authUser.ID). Scopes(func(query *gorm.DB) *gorm.DB { if !authUser.DisplayIgnored { @@ -3349,6 +3415,7 @@ func generatePgpToBeSignedTokenMessage(userID database.UserID, pkey string) stri func AddPGPHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data addPGPData data.PGPPublicKey = authUser.GPGPublicKey if c.Request().Method == http.MethodPost { @@ -3399,7 +3466,7 @@ func AddPGPHandler(c echo.Context) error { pgpTokenCache.Delete(authUser.ID) authUser.GPGPublicKey = token.PKey - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, "/settings/pgp") } } @@ -3408,6 +3475,7 @@ func AddPGPHandler(c echo.Context) error { func AddAgeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data addAgeData data.AgePublicKey = authUser.AgePublicKey if c.Request().Method == http.MethodPost { @@ -3436,7 +3504,7 @@ func AddAgeHandler(c echo.Context) error { } ageTokenCache.Delete(authUser.ID) authUser.AgePublicKey = token.PKey - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, "/settings/age") } } @@ -3453,6 +3521,7 @@ type twoFactorObj struct { func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data gpgTwoFactorAuthenticationVerifyData data.IsEnabled = authUser.GpgTwoFactorEnabled @@ -3463,7 +3532,7 @@ func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error { } password := c.Request().PostFormValue("password") - if !authUser.CheckPassword(password) { + if !authUser.CheckPassword(db, password) { data.ErrorPassword = "Invalid password" return c.Render(http.StatusOK, "two-factor-authentication-gpg", data) } @@ -3471,8 +3540,8 @@ func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error { // Disable if authUser.GpgTwoFactorEnabled { authUser.GpgTwoFactorEnabled = false - authUser.DoSave() - database.CreateSecurityLog(authUser.ID, database.Gpg2faDisabledSecurityLog) + authUser.DoSave(db) + db.CreateSecurityLog(authUser.ID, database.Gpg2faDisabledSecurityLog) return c.Render(http.StatusOK, "flash", FlashResponse{"GPG Two-factor authentication disabled", "/settings/account", "alert-success"}) } @@ -3481,14 +3550,14 @@ func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error { return c.Render(http.StatusOK, "flash", FlashResponse{"You need to setup your PGP key first", "/settings/pgp", "alert-danger"}) } // Delete active user sessions - if err := database.DeleteUserSessions(authUser.ID); err != nil { + if err := db.DeleteUserSessions(authUser.ID); err != nil { logrus.Error(err) } c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName)) authUser.GpgTwoFactorEnabled = true authUser.GpgTwoFactorMode = utils.DoParseBool(c.Request().PostFormValue("gpg_two_factor_mode")) - authUser.DoSave() - database.CreateSecurityLog(authUser.ID, database.Gpg2faEnabledSecurityLog) + authUser.DoSave(db) + db.CreateSecurityLog(authUser.ID, database.Gpg2faEnabledSecurityLog) return c.Render(http.StatusOK, "flash", FlashResponse{"GPG Two-factor authentication enabled", "/settings/account", "alert-success"}) } @@ -3500,6 +3569,7 @@ func TwoFactorAuthenticationVerifyHandler(c echo.Context) error { return base64.StdEncoding.EncodeToString(buf.Bytes()) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if authUser.TwoFactorSecret != "" { return c.Redirect(http.StatusFound, "/settings/account") } @@ -3510,7 +3580,7 @@ func TwoFactorAuthenticationVerifyHandler(c echo.Context) error { return c.Redirect(http.StatusFound, "/two-factor-authentication/verify") } password := c.Request().PostFormValue("password") - if !authUser.CheckPassword(password) { + if !authUser.CheckPassword(db, password) { img, _ := twoFactor.key.Image(150, 150) data.QRCode = getImgStr(img) data.Secret = twoFactor.key.Secret() @@ -3534,14 +3604,14 @@ func TwoFactorAuthenticationVerifyHandler(c echo.Context) error { return c.Render(http.StatusOK, "two-factor-authentication-verify", data) } // Delete active user sessions - if err := database.DeleteUserSessions(authUser.ID); err != nil { + if err := db.DeleteUserSessions(authUser.ID); err != nil { logrus.Error(err) } c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName)) authUser.TwoFactorSecret = database.EncryptedString(twoFactor.key.Secret()) authUser.TwoFactorRecovery = string(h) - authUser.DoSave() - database.CreateSecurityLog(authUser.ID, database.TotpEnabledSecurityLog) + authUser.DoSave(db) + db.CreateSecurityLog(authUser.ID, database.TotpEnabledSecurityLog) return c.Render(http.StatusOK, "flash", FlashResponse{"Two-factor authentication enabled", "/", "alert-success"}) } key, _ := totp.Generate(totp.GenerateOpts{Issuer: "DarkForest", AccountName: authUser.Username}) @@ -3561,15 +3631,16 @@ func TwoFactorAuthenticationDisableHandler(c echo.Context) error { return c.Render(http.StatusOK, "disable-totp", data) } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) password := c.Request().PostFormValue("password") - if !authUser.CheckPassword(password) { + if !authUser.CheckPassword(db, password) { data.ErrorPassword = "Invalid password" return c.Render(http.StatusOK, "disable-totp", data) } authUser.TwoFactorSecret = "" authUser.TwoFactorRecovery = "" - authUser.DoSave() - database.CreateSecurityLog(authUser.ID, database.TotpDisabledSecurityLog) + authUser.DoSave(db) + db.CreateSecurityLog(authUser.ID, database.TotpDisabledSecurityLog) return c.Render(http.StatusOK, "flash", FlashResponse{"Two-factor authentication disabled", "/settings/account", "alert-success"}) } @@ -3694,6 +3765,7 @@ var flagValidationCache = cache.NewWithKey[database.UserID, bool](time.Minute, t func VipDownloadsHandler(c echo.Context) error { const flagHash = "fefc9d5db52b51aeefd4b098f0178a8bcb7f0816dcadaf1714604f01ef63a621" authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data vipDownloadsHandlerData data.ActiveTab = "home" data.Files = getDownloadsFiles() @@ -3709,7 +3781,7 @@ func VipDownloadsHandler(c echo.Context) error { } if utils.Sha256([]byte(flag)) == flagHash { data.FlagMessage = "You found the flag!" - _ = database.CreateUserBadge(authUser.ID, 1) + _ = db.CreateUserBadge(authUser.ID, 1) } else { data.FlagMessage = "Invalid flag" } @@ -3725,6 +3797,7 @@ func downloadFile(c echo.Context, folder, redirect string) error { } authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if authUser == nil { return c.Redirect(http.StatusFound, "/login?redirect="+redirect) } @@ -3737,7 +3810,7 @@ func downloadFile(c echo.Context, folder, redirect string) error { } // Keep track of user downloads - if _, err := database.CreateDownload(authUser.ID, filename); err != nil { + if _, err := db.CreateDownload(authUser.ID, filename); err != nil { logrus.Error(err) } @@ -3758,6 +3831,7 @@ func VipDownloadFileHandler(c echo.Context) error { func CaptchaRequiredHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data captchaRequiredData data.CaptchaDescription = "Captcha required" @@ -3778,7 +3852,7 @@ func CaptchaRequiredHandler(c echo.Context) error { } config.CaptchaRequiredSuccess.Inc() authUser.CaptchaRequired = false - authUser.DoSave() + authUser.DoSave(db) return c.Redirect(http.StatusFound, "/chat") } @@ -3819,20 +3893,22 @@ func CaptchaHandler(c echo.Context) error { func PublicUserProfileHandler(c echo.Context) error { username := c.Param("username") - user, err := database.GetUserByUsername(username) + db := c.Get("database").(*database.DkfDB) + user, err := db.GetUserByUsername(username) if err != nil { return c.Redirect(http.StatusFound, "/") } var data publicProfileData data.User = user data.UserStyle = user.GenerateChatStyle() - data.PublicNotes, _ = database.GetUserPublicNotes(user.ID) + data.PublicNotes, _ = db.GetUserPublicNotes(user.ID) return c.Render(http.StatusOK, "public-profile", data) } func PublicUserProfilePGPHandler(c echo.Context) error { username := c.Param("username") - user, err := database.GetUserByUsername(username) + db := c.Get("database").(*database.DkfDB) + user, err := db.GetUserByUsername(username) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -3877,8 +3953,9 @@ func isAttachmentMimeType(mimeType string) bool { func UploadsDownloadHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) filename := c.Param("filename") - file, err := database.GetUploadByFileName(filename) + file, err := db.GetUploadByFileName(filename) if err != nil { return c.Render(http.StatusOK, "standalone.upload404", nil) } @@ -3914,7 +3991,7 @@ func UploadsDownloadHandler(c echo.Context) error { // MimeType that always trigger a file "download" if isAttachmentMimeType(mimeType) { // Keep track of user downloads - if _, err := database.CreateDownload(authUser.ID, filename); err != nil { + if _, err := db.CreateDownload(authUser.ID, filename); err != nil { logrus.Error(err) } c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", "attachment", file.OrigFileName)) @@ -3929,7 +4006,7 @@ func UploadsDownloadHandler(c echo.Context) error { return nil } - userNbDownloaded := database.UserNbDownloaded(authUser.ID, filename) + userNbDownloaded := db.UserNbDownloaded(authUser.ID, filename) // Display captcha to new users, or old users if they already downloaded the file. if !authUser.AccountOldEnough() || userNbDownloaded >= 1 { @@ -3955,7 +4032,7 @@ func UploadsDownloadHandler(c echo.Context) error { } // Keep track of user downloads - if _, err := database.CreateDownload(authUser.ID, filename); err != nil { + if _, err := db.CreateDownload(authUser.ID, filename); err != nil { logrus.Error(err) } @@ -3972,7 +4049,8 @@ func UploadsDownloadHandler(c echo.Context) error { func FiledropDownloadHandler(c echo.Context) error { filename := c.Param("filename") - filedrop, err := database.GetFiledropByFileName(filename) + db := c.Get("database").(*database.DkfDB) + filedrop, err := db.GetFiledropByFileName(filename) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -4023,6 +4101,7 @@ type ByteRoadPayload struct { func ByteRoadChallengeHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) const byteRoadChallengeTmplName = "vip.byte-road-challenge" var data byteRoadChallengeData data.ActiveTab = "home" @@ -4103,7 +4182,7 @@ func ByteRoadChallengeHandler(c echo.Context) error { _ = byteRoadUsersCountCache.Update(authUser.ID, payload) if payload.Count >= 100 { data.FlagFound = true - _ = database.CreateUserBadge(authUser.ID, 2) + _ = db.CreateUserBadge(authUser.ID, 2) } return c.Render(http.StatusOK, byteRoadChallengeTmplName, data) } @@ -4166,12 +4245,13 @@ func BHCHandler(c echo.Context) error { func FileDropHandler(c echo.Context) error { const filedropTmplName = "standalone.filedrop" uuidParam := c.Param("uuid") + db := c.Get("database").(*database.DkfDB) //if c.Request().ContentLength > config.MaxUserFileUploadSize { // data.Error = fmt.Sprintf("The maximum file size is %s", humanize.Bytes(config.MaxUserFileUploadSize)) // return c.Render(http.StatusOK, "chat-top-bar", data) //} - filedrop, err := database.GetFiledropByUUID(uuidParam) + filedrop, err := db.GetFiledropByUUID(uuidParam) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -4227,17 +4307,18 @@ func FileDropHandler(c echo.Context) error { filedrop.IV = encrypter.Meta().IV filedrop.OrigFileName = origFileName filedrop.FileSize = written - filedrop.DoSave() + filedrop.DoSave(db) data.Success = "File uploaded successfully" return c.Render(http.StatusOK, filedropTmplName, data) } func FileDropDkfUploadHandler(c echo.Context) error { + db := c.Get("database").(*database.DkfDB) // Init if c.Request().PostFormValue("init") != "" { filedropUUID := c.Param("uuid") - _, err := database.GetFiledropByUUID(filedropUUID) + _, err := db.GetFiledropByUUID(filedropUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -4277,7 +4358,7 @@ func FileDropDkfUploadHandler(c echo.Context) error { if c.Request().PostFormValue("completed") != "" { filedropUUID := c.Param("uuid") - filedrop, err := database.GetFiledropByUUID(filedropUUID) + filedrop, err := db.GetFiledropByUUID(filedropUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -4356,7 +4437,7 @@ func FileDropDkfUploadHandler(c echo.Context) error { filedrop.IV = iv filedrop.OrigFileName = origFileName filedrop.FileSize = written - filedrop.DoSave() + filedrop.DoSave(db) return c.NoContent(http.StatusOK) } @@ -4374,7 +4455,7 @@ func FileDropDkfUploadHandler(c echo.Context) error { } } - _, err := database.GetFiledropByUUID(filedropUUID) + _, err := db.GetFiledropByUUID(filedropUUID) if err != nil { return c.Redirect(http.StatusFound, "/") } @@ -4397,7 +4478,8 @@ func FileDropDkfUploadHandler(c echo.Context) error { func FileDropDkfDownloadHandler(c echo.Context) error { filedropUUID := c.Param("uuid") - filedrop, err := database.GetFiledropByUUID(filedropUUID) + db := c.Get("database").(*database.DkfDB) + filedrop, err := db.GetFiledropByUUID(filedropUUID) if err != nil { return c.NoContent(http.StatusNotFound) } @@ -4456,6 +4538,7 @@ func FileDropDkfDownloadHandler(c echo.Context) error { func FileDropDownloadHandler(c echo.Context) error { authUser, ok := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if !ok { return c.Redirect(http.StatusFound, "/") } @@ -4466,7 +4549,7 @@ func FileDropDownloadHandler(c echo.Context) error { return c.Redirect(http.StatusFound, "/") } - userNbDownloaded := database.UserNbDownloaded(authUser.ID, fileName) + userNbDownloaded := db.UserNbDownloaded(authUser.ID, fileName) // Display captcha to new users, or old users if they already downloaded the file. if !authUser.AccountOldEnough() || userNbDownloaded >= 1 { @@ -4492,7 +4575,7 @@ func FileDropDownloadHandler(c echo.Context) error { } // Keep track of user downloads - if _, err := database.CreateDownload(authUser.ID, fileName); err != nil { + if _, err := db.CreateDownload(authUser.ID, fileName); err != nil { logrus.Error(err) } @@ -4513,6 +4596,7 @@ func FileDropDownloadHandler(c echo.Context) error { func Stego1ChallengeHandler(c echo.Context) error { const flagHash = "05b456689a9f8de69416d21cbb97157588b8491d07551167a95b93a1c7d61e7b" authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data stego1RoadChallengeData data.ActiveTab = "home" @@ -4528,7 +4612,7 @@ func Stego1ChallengeHandler(c echo.Context) error { } if utils.Sha256([]byte(flag)) == flagHash { data.FlagMessage = "You found the flag!" - _ = database.CreateUserBadge(authUser.ID, 3) + _ = db.CreateUserBadge(authUser.ID, 3) } else { data.FlagMessage = "Invalid flag" } @@ -4540,12 +4624,13 @@ func Stego1ChallengeHandler(c echo.Context) error { func ChessHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) var data chessData data.Games = v1.ChessInstance.GetGames() if c.Request().Method == http.MethodPost { data.Username = c.Request().PostFormValue("username") - player2, err := database.GetUserByUsername(data.Username) + player2, err := db.GetUserByUsername(data.Username) if err != nil { data.Error = "invalid username" return c.Render(http.StatusOK, "chess", data) @@ -4612,13 +4697,14 @@ html, body { func ChessGameHandler(c echo.Context) error { authUser := c.Get("authUser").(*database.User) + //db := c.Get("database").(*database.DkfDB) key := c.Param("key") g := v1.ChessInstance.GetGame(key) if g == nil { // Chess debug - //user1, _ := database.GetUserByID(1) - //user2, _ := database.GetUserByID(24132) + //user1, _ := db.GetUserByID(1) + //user2, _ := db.GetUserByID(24132) //v1.ChessInstance.NewGame(key, user1, user2) //g = v1.ChessInstance.GetGame(key) return c.Redirect(http.StatusFound, "/") diff --git a/pkg/web/middlewares/middlewares.go b/pkg/web/middlewares/middlewares.go @@ -213,6 +213,15 @@ func I18nMiddleware(bundle *i18n.Bundle, defaultLang string) echo.MiddlewareFunc } } +func SetDatabaseMiddleware(db *database.DkfDB) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(ctx echo.Context) error { + ctx.Set("database", db) + return next(ctx) + } + } +} + // SetUserMiddleware Get user and put it into echo context. // - Get auth-token from cookie // - If exists, get user from database @@ -220,18 +229,19 @@ func I18nMiddleware(bundle *i18n.Bundle, defaultLang string) echo.MiddlewareFunc // - Otherwise, empty user will be put in context func SetUserMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(ctx echo.Context) error { + db := ctx.Get("database").(*database.DkfDB) var nilUser *database.User var user database.User if apiKey := ctx.Request().Header.Get("DKF_API_KEY"); apiKey != "" { // Login using DKF_API_KEY - if err := database.GetUserByApiKey(&user, apiKey); err == nil { + if err := db.GetUserByApiKey(&user, apiKey); err == nil { ctx.Set("authUser", &user) return next(ctx) } } else if authCookie, err := ctx.Cookie(hutils.AuthCookieName); err == nil { // Login using auth cookie - if err := database.GetUserBySessionKey(&user, authCookie.Value); err == nil { + if err := db.GetUserBySessionKey(&user, authCookie.Value); err == nil { ctx.Set("authUser", &user) return next(ctx) } @@ -248,6 +258,7 @@ func SetUserMiddleware(next echo.HandlerFunc) echo.HandlerFunc { func IsAuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { user := c.Get("authUser").(*database.User) + db := c.Get("database").(*database.DkfDB) if user == nil { if strings.HasPrefix(c.Path(), "/api/") { return c.String(http.StatusUnauthorized, "unauthorized") @@ -258,7 +269,7 @@ func IsAuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { c.Response().Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") user.LastSeenAt = time.Now() - user.DoSave() + user.DoSave(db) // Prevent clickjacking by setting the header on every logged in page if !strings.Contains(c.Path(), "/api/v1/chat/messages") && diff --git a/pkg/web/public/views/pages/club/thread.gohtml b/pkg/web/public/views/pages/club/thread.gohtml @@ -42,7 +42,7 @@ </form> {{ end }} </div> - {{ .Escape | safe }} + {{ .Escape $.DB | safe }} </td> </tr> </table> diff --git a/pkg/web/public/views/pages/thread.gohtml b/pkg/web/public/views/pages/thread.gohtml @@ -88,7 +88,7 @@ </div> </div> <div style="padding: 5px 5px 10px 10px;"> - {{ .Escape | safe }} + {{ .Escape $.DB | safe }} </div> </div> {{ end }} diff --git a/pkg/web/web.go b/pkg/web/web.go @@ -2,6 +2,7 @@ package web import ( "context" + "dkforest/pkg/database" "fmt" "github.com/ulule/limiter" "github.com/ulule/limiter/drivers/store/memory" @@ -27,7 +28,7 @@ import ( yaml "gopkg.in/yaml.v1" ) -func getMainServer(i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.HandlerFunc { +func getMainServer(db *database.DkfDB, i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.HandlerFunc { e := echo.New() e.HideBanner = true e.HidePort = true @@ -39,6 +40,7 @@ func getMainServer(i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.Handle e.Use(staticbin.Static(bindata.Asset, staticbin.Options{Dir: "/public", SkipLogging: true})) e.Renderer = renderer + e.Use(middlewares.SetDatabaseMiddleware(db)) e.Use(middlewares.FirstUseMiddleware) e.Use(middlewares.DdosMiddleware) e.Use(middlewares.MaintenanceMiddleware) @@ -290,7 +292,7 @@ func getMainServer(i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.Handle } } -func getBaseServer() *echo.Echo { +func getBaseServer(db *database.DkfDB) *echo.Echo { e := echo.New() renderer := tmp.GetRenderer(e) i18nBundle := getI18nBundle() @@ -299,31 +301,32 @@ func getBaseServer() *echo.Echo { e.Debug = true e.Renderer = renderer e.Use(middlewares.SetUselessHeadersMiddleware) + e.Use(middlewares.SetDatabaseMiddleware(db)) e.Use(middlewares.SetUserMiddleware) e.Use(middlewares.I18nMiddleware(i18nBundle, "en")) e.GET("/file-drop/:uuid", handlers.FileDropHandler) e.POST("/file-drop/:uuid", handlers.FileDropHandler) e.POST("/file-drop/:uuid/dkfupload", handlers.FileDropDkfUploadHandler) - e.POST("/api/v1/file-drop/:uuid/dkfdownload", handlers.FileDropDkfDownloadHandler, middlewares.SetUserMiddleware, middlewares.IsAuthMiddleware) + e.POST("/api/v1/file-drop/:uuid/dkfdownload", handlers.FileDropDkfDownloadHandler, middlewares.IsAuthMiddleware) e.GET("/downloads/:fileName", handlers.FileDropDownloadHandler) e.POST("/downloads/:fileName", handlers.FileDropDownloadHandler) - e.Any("*", getMainServer(i18nBundle, renderer)) + e.Any("*", getMainServer(db, i18nBundle, renderer)) return e } -func startI2pServer(host string, port int) { +func startI2pServer(db *database.DkfDB, host string, port int) { if config.Development.IsTrue() { return } address := host + ":" + strconv.Itoa(port) - e := getBaseServer() + e := getBaseServer(db) logrus.Info("start i2p server on " + address) startServer(e, address) } -func startTorServer(host string, port int) { +func startTorServer(db *database.DkfDB, host string, port int) { address := host + ":" + strconv.Itoa(port) - e := getBaseServer() + e := getBaseServer(db) configTorProdServer(e) logrus.Info("start tor server on " + address) startServer(e, address) @@ -338,11 +341,11 @@ func startServer(e *echo.Echo, address string) { } // Start ... -func Start(host string, port int) { +func Start(db *database.DkfDB, host string, port int) { // Start server for I2P - go startI2pServer(host, port+1) + go startI2pServer(db, host, port+1) // Server for Tor/dev - startTorServer(host, port) + startTorServer(db, host, port) } func extractGlobalCircuitIdentifier(m string) int64 {