commit 79e79afef727561361b6944139c09d693fea8944
parent c07c542c3830bc897a3289eaf143e77e93e21a29
Author: n0tr1v <n0tr1v@protonmail.com>
Date: Fri, 3 Mar 2023 03:10:02 -0800
get rid of database global variable and all side effects
Diffstat:
59 files changed, 1529 insertions(+), 1174 deletions(-)
diff --git a/pkg/actions/actions.go b/pkg/actions/actions.go
@@ -47,22 +47,21 @@ func Start(c *cli.Context) error {
ensureProjectHome()
- initDB()
- defer database.DB.Close()
+ db := database.DB()
- runMigrations()
+ runMigrations(db)
- config.IsFirstUse.Store(isFirstUse())
+ config.IsFirstUse.Store(isFirstUse(db))
captcha.SetCustomStore(captcha.NewMemoryStore(captcha.CollectNum, 120*time.Second))
if false {
- err := database.DB.Debug().Exec(`
+ err := db.DB().Debug().Exec(`
`).Error
logrus.Error(err)
}
- settings := database.GetSettings()
+ settings := db.GetSettings()
config.ProtectHome.Store(settings.ProtectHome)
config.HomeUsersList.Store(settings.HomeUsersList)
config.ForceLoginCaptcha.Store(settings.ForceLoginCaptcha)
@@ -76,22 +75,22 @@ func Start(c *cli.Context) error {
config.Xmr()
- utils.SGo(func() { cleanupDatabase() })
+ utils.SGo(func() { cleanupDatabase(db) })
utils.SGo(func() { managers.ActiveUsers.CleanupUsersCache() })
- utils.SGo(func() { xmrWatch() })
+ utils.SGo(func() { xmrWatch(db) })
utils.SGo(func() { openBrowser(noBrowser, int64(port)) })
- v1.ChessInstance = v1.NewChess()
- v1.BattleshipInstance = v1.NewBattleship()
- v1.WWInstance = v1.NewWerewolf()
+ v1.ChessInstance = v1.NewChess(db)
+ v1.BattleshipInstance = v1.NewBattleship(db)
+ v1.WWInstance = v1.NewWerewolf(db)
- web.Start(host, port)
+ web.Start(db, host, port)
return nil
}
-func isFirstUse() bool {
+func isFirstUse(db *database.DkfDB) bool {
var count int64
- database.DB.Model(database.User{}).Count(&count)
+ db.DB().Model(database.User{}).Count(&count)
return count <= 0
}
@@ -105,7 +104,7 @@ func openBrowser(noBrowser bool, port int64) {
}
}
-func runMigrations() {
+func runMigrations(db *database.DkfDB) {
logrus.Info("running migrations")
migrations := &migrate.AssetMigrationSource{
Asset: config.MigrationsFs.ReadFile,
@@ -119,25 +118,15 @@ func runMigrations() {
},
Dir: "migrations",
}
- database.DB.Exec("PRAGMA foreign_keys=OFF")
- n, err := migrate.Exec(database.DB.DB(), "sqlite3", migrations, migrate.Up)
+ db.DB().Exec("PRAGMA foreign_keys=OFF")
+ n, err := migrate.Exec(db.DB().DB(), "sqlite3", migrations, migrate.Up)
if err != nil {
panic(err)
}
- database.DB.Exec("PRAGMA foreign_keys=ON")
+ db.DB().Exec("PRAGMA foreign_keys=ON")
logrus.Infof("applied %d migrations", n)
}
-func initDB() {
- dbPath := filepath.Join(config.Global.ProjectPath(), config.DbFileName)
- db, err := database.OpenSqlite3DB(dbPath)
- if err != nil {
- logrus.Fatal("Failed to open sqlite3 db : " + err.Error())
- return
- }
- database.DB = db
-}
-
// Ensure the project folder is created properly
func ensureProjectHome() {
config.Global.SetProjectPath(utils.MustGetDefaultProjectPath())
@@ -215,7 +204,7 @@ func (f LogFormatter) Format(e *logrus.Entry) ([]byte, error) {
return buffer.Bytes(), nil
}
-func xmrWatch() {
+func xmrWatch(db *database.DkfDB) {
var once utils.Once
for {
select {
@@ -227,7 +216,7 @@ func xmrWatch() {
continue
}
for _, transfer := range transfers.In {
- invoice, err := database.GetXmrInvoiceByAddress(transfer.Address)
+ invoice, err := db.GetXmrInvoiceByAddress(transfer.Address)
if err != nil {
logrus.Error(err, transfer.TxID)
continue
@@ -236,7 +225,7 @@ func xmrWatch() {
invoice.Confirmations = int64(transfer.Confirmations)
amount := int64(transfer.Amount)
invoice.AmountReceived = &amount
- invoice.DoSave()
+ invoice.DoSave(db)
if origConfirmations >= 10 {
continue
} else if transfer.Confirmations < 10 {
@@ -249,7 +238,7 @@ func xmrWatch() {
}
}
-func cleanupDatabase() {
+func cleanupDatabase(db *database.DkfDB) {
var once utils.Once
for {
select {
@@ -257,13 +246,13 @@ func cleanupDatabase() {
case <-time.After(1 * time.Hour):
}
start := time.Now()
- database.DeleteOldSessions()
- database.DeleteOldUploads()
- database.DeleteOldChatMessages()
- database.DeleteOldPrivateChatRooms()
- database.DeleteOldCaptchaRequests()
- database.DeleteOldAuditLogs()
- database.DeleteOldSecurityLogs()
+ db.DeleteOldSessions()
+ db.DeleteOldUploads()
+ db.DeleteOldChatMessages()
+ db.DeleteOldPrivateChatRooms()
+ db.DeleteOldCaptchaRequests()
+ db.DeleteOldAuditLogs()
+ db.DeleteOldSecurityLogs()
logrus.Debugf("done cleaning database, took %s", time.Since(start))
}
}
@@ -276,9 +265,6 @@ func BuildProhibitedPasswords(c *cli.Context) error {
ensureProjectHome()
- initDB()
- defer database.DB.Close()
-
readFile, _ := os.Open("rockyou.txt")
fileScanner := bufio.NewScanner(readFile)
fileScanner.Split(bufio.ScanLines)
@@ -287,7 +273,7 @@ func BuildProhibitedPasswords(c *cli.Context) error {
rows = append(rows, database.ProhibitedPassword{Password: fileScanner.Text()})
}
readFile.Close()
- if err := gormbulk.BulkInsert(database.DB, rows, 10000); err != nil {
+ if err := gormbulk.BulkInsert(database.DB().DB(), rows, 10000); err != nil {
logrus.Error(err)
}
return nil
diff --git a/pkg/database/database.go b/pkg/database/database.go
@@ -5,26 +5,236 @@ import (
"fmt"
"github.com/jinzhu/gorm"
"github.com/mattn/go-sqlite3"
+ "github.com/sirupsen/logrus"
"net/url"
"path/filepath"
"strings"
+ "sync"
"time"
)
-// DB ...
-var DB *gorm.DB
+type DkfDB struct {
+ db *gorm.DB
+}
-// OpenSqlite3DB ...
-func OpenSqlite3DB(path string) (*gorm.DB, error) {
- db, err := gorm.Open("sqlite3", path)
- if err != nil {
- return nil, err
- }
- db.DB().SetMaxIdleConns(1) // 10
- db.DB().SetMaxOpenConns(1) // 25
- db.LogMode(false)
- db.Exec("PRAGMA foreign_keys=ON")
- return db, nil
+func (d *DkfDB) DB() *gorm.DB {
+ return d.db
+}
+
+// Compile time checks to ensure type satisfies IDkfDB interface
+var _ IDkfDB = (*DkfDB)(nil)
+
+type IDkfDB interface {
+ AddBlacklistedUser(userID, blacklistedUserID UserID)
+ AddLinkCategory(linkID, categoryID int64) (err error)
+ AddLinkTag(linkID, tagID int64) (err error)
+ AddUserToRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (out ChatRoomUserGroup, err error)
+ AddWhitelistedUser(userID, whitelistedUserID UserID)
+ CanRenameTo(oldUsername, newUsername string) error
+ CanUseUsername(username string, isFirstUser bool) error
+ ClearRoomGroup(roomID RoomID, groupID GroupID) (err error)
+ CreateChatReaction(userID UserID, messageID, reaction int64) error
+ CreateChatRoomGroup(roomID RoomID, name, color string) (out ChatRoomGroup, err error)
+ CreateDownload(userID UserID, filename string) (out Download, err error)
+ CreateEncryptedUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error)
+ CreateFiledrop() (out Filedrop, err error)
+ CreateInboxMessage(msg string, roomID RoomID, fromUserID, toUserID UserID, isPm, moderators bool, msgID *int64)
+ CreateInvitation(userID UserID) (out Invitation, err error)
+ CreateKarmaHistory(karma int64, description string, userID UserID, fromUserID *int64) (out KarmaHistory, err error)
+ CreateKickMsg(kickedUser, kickedByUser User)
+ CreateLink(url, title, description, shorthand string) (out Link, err error)
+ CreateLinkMirror(linkID int64, link string) (out LinksMirror, err error)
+ CreateLinkPgp(linkID int64, title, description, publicKey string) (out LinksPgp, err error)
+ CreateLinksCategory(category string) (out LinksCategory, err error)
+ CreateLinksTag(tag string) (out LinksTag, err error)
+ CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID *UserID) (out ChatMessage, err error)
+ CreateNotification(msg string, userID UserID)
+ CreateOrEditMessage(editMsg *ChatMessage, message, raw, roomKey string, roomID RoomID, fromUserID UserID, toUserID *UserID, upload *Upload, groupID *GroupID, hellbanMsg, modMsg, systemMsg bool) (int64, error)
+ CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool) (out ChatRoom, err error)
+ CreateSecurityLog(userID UserID, typ int64)
+ CreateSession(userID UserID, userAgent string) (Session, error)
+ CreateSessionNotification(msg string, sessionToken string)
+ CreateSnippet(userID UserID, name, text string) (out Snippet, err error)
+ CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error
+ CreateUnkickMsg(kickedUser, kickedByUser User)
+ CreateUpload(fileName string, content []byte, userID UserID) (*Upload, error)
+ CreateUser(username, password, repassword string, registrationDuration int64, signupInfoEnc string) (User, UserErrors)
+ CreateUserBadge(userID UserID, badgeID int64) error
+ CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error)
+ DeWhitelistUser(roomID RoomID, userID UserID) (err error)
+ DeleteAllChatInbox(userID UserID) error
+ DeleteAllNotifications(userID UserID) error
+ DeleteAllSessionNotifications(sessionToken string) error
+ DeleteChatInboxMessageByChatMessageID(chatMessageID int64) error
+ DeleteChatInboxMessageByID(messageID int64) error
+ DeleteChatMessageByUUID(messageUUID string) error
+ DeleteChatRoomByID(id RoomID)
+ DeleteChatRoomGroup(roomID RoomID, name string) (err error)
+ DeleteChatRoomGroups(roomID RoomID) (err error)
+ DeleteChatRoomMessages(roomID RoomID) error
+ DeleteDownloadByID(downloadID int64) (err error)
+ DeleteForumMessageByID(messageID ForumMessageID) error
+ DeleteForumThreadByID(threadID ForumThreadID) error
+ DeleteLinkByID(id int64) error
+ DeleteLinkCategories(linkID int64) error
+ DeleteLinkMirrorByID(id int64) error
+ DeleteLinkPgpByID(id int64) error
+ DeleteLinkTags(linkID int64) error
+ DeleteNotificationByID(notificationID int64) error
+ DeleteOldAuditLogs()
+ DeleteOldCaptchaRequests()
+ DeleteOldChatMessages()
+ DeleteOldPrivateChatRooms()
+ DeleteOldSecurityLogs()
+ DeleteOldSessions()
+ DeleteOldUploads()
+ DeleteReaction(userID UserID, messageID, reaction int64) error
+ DeleteSessionByToken(token string) error
+ DeleteSessionNotificationByID(sessionNotificationID int64) error
+ DeleteSnippet(userID UserID, name string)
+ DeleteUserByID(userID UserID) (err error)
+ DeleteUserChatInboxMessages(userID UserID) error
+ DeleteUserChatMessages(userID UserID) error
+ DeleteUserOtherSessions(userID UserID, currentToken string) error
+ DeleteUserSessionByToken(userID UserID, token string) error
+ DeleteUserSessions(userID UserID) error
+ DoCreateSession(userID UserID, userAgent string) Session
+ GetActiveUserSessions(userID UserID) (out []Session)
+ GetCategories() (out []CategoriesResult, err error)
+ GetChatMessages(roomID RoomID, username string, userID UserID, displayPms PmDisplayMode, mentionsOnly, displayHellbanned, displayIgnored, displayModerators, displayIgnoredMessages bool) (out ChatMessages, err error)
+ GetChatRoomByID(roomID RoomID) (out ChatRoom, err error)
+ GetChatRoomByName(roomName string) (out ChatRoom, err error)
+ GetChessSubscribers() (out []User, err error)
+ GetClubForumThreads(userID UserID) (out []ForumThreadAug, err error)
+ GetClubMembers() (out []User, err error)
+ GetFiledropByFileName(fileName string) (out Filedrop, err error)
+ GetFiledropByUUID(uuid string) (out Filedrop, err error)
+ GetFiledrops() (out []Filedrop, err error)
+ GetForumCategories() (out []ForumCategory, err error)
+ GetForumCategoryBySlug(slug string) (out ForumCategory, err error)
+ GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error)
+ GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error)
+ GetForumThread(threadID ForumThreadID) (out ForumThread, err error)
+ GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error)
+ GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error)
+ GetForumThreads() (out []ForumThread, err error)
+ GetGistByUUID(uuid string) (out Gist, err error)
+ GetIgnoredByUsers(userID UserID) (out []IgnoredUser, err error)
+ GetIgnoredUsers(userID UserID) (out []IgnoredUser, err error)
+ GetLinkByID(linkID int64) (out Link, err error)
+ GetLinkByShorthand(shorthand string) (out Link, err error)
+ GetLinkByUUID(linkUUID string) (out Link, err error)
+ GetLinkCategories(linkID int64) (out []LinksCategory, err error)
+ GetLinkMirrorByID(id int64) (out LinksMirror, err error)
+ GetLinkMirrors(linkID int64) (out []LinksMirror, err error)
+ GetLinkPgpByID(id int64) (out LinksPgp, err error)
+ GetLinkPgps(linkID int64) (out []LinksPgp, err error)
+ GetLinkTags(linkID int64) (out []LinksTag, err error)
+ GetLinks() (out []Link, err error)
+ GetListedChatRooms(userID UserID) (out []ChatRoomAug, err error)
+ GetModeratorsUsers() (out []User, err error)
+ GetOfficialChatRooms() (out []ChatRoom, err error)
+ GetOfficialChatRooms1(userID UserID) (out []ChatRoomAug, err error)
+ GetOnionBlacklist(hash string) (out OnionBlacklist, err error)
+ GetPmBlacklistedByUsers(userID UserID) (out []PmBlacklistedUsers, err error)
+ GetPmBlacklistedUsers(userID UserID) (out []PmBlacklistedUsers, err error)
+ GetPmWhitelistedUsers(userID UserID) (out []PmWhitelistedUsers, err error)
+ GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error)
+ GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error)
+ GetRecentLinks() (out []Link, err error)
+ GetRecentUsersCount() int64
+ GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out ChatMessage, err error)
+ GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, err error)
+ GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error)
+ GetRoomChatMessagesByDate(roomID RoomID, dt time.Time) (out []ChatMessage, err error)
+ GetRoomGroupByName(roomID RoomID, groupName string) (out ChatRoomGroup, err error)
+ GetRoomGroupUsers(roomID RoomID, groupID GroupID) (out []ChatRoomUserGroup, err error)
+ GetRoomGroups(roomID RoomID) (out []ChatRoomGroup, err error)
+ GetSecurityLogs(userID UserID) (out []SecurityLog, err error)
+ GetSettings() (out Settings)
+ GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error)
+ GetUnusedInvitationByToken(token string) (out Invitation, err error)
+ GetUploadByFileName(filename string) (out Upload, err error)
+ GetUploadByID(uploadID UploadID) (out Upload, err error)
+ GetUploads() (out []Upload, err error)
+ GetUserByApiKey(user *User, apiKey string) error
+ GetUserByID(userID UserID) (out User, err error)
+ GetUserBySessionKey(user *User, sessionKey string) error
+ GetUserByUsername(username string) (out User, err error)
+ GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error)
+ GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err error)
+ GetUserInboxMessagesCount(userID UserID) (count int64)
+ GetUserInvitations(userID UserID) (out []Invitation, err error)
+ GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage, err error)
+ GetUserNotifications(userID UserID) (msgs []Notification, err error)
+ GetUserNotificationsCount(userID UserID) (count int64)
+ GetUserPrivateNotes(userID UserID) (out UserPrivateNote, err error)
+ GetUserPublicNotes(userID UserID) (out UserPublicNote, err error)
+ GetUserReadMarker(userID UserID, roomID RoomID) (out ChatReadMarker, err error)
+ GetUserRoomGroups(userID UserID, roomID RoomID) (out []ChatRoomUserGroup, err error)
+ GetUserRoomSubscriptions(userID UserID) (out []ChatRoomAug, err error)
+ GetUserSessionNotifications(sessionToken string) (msgs []SessionNotification, err error)
+ GetUserSessionNotificationsCount(sessionToken string) (count int64)
+ GetUserSnippets(userID UserID) (out []Snippet, err error)
+ GetUserTotalUploadSize(userID UserID) int64
+ GetUserUnusedInvitations(userID UserID) (out []Invitation, err error)
+ GetUserUploads(userID UserID) (out []Upload, err error)
+ GetUsersBadges() (out []UserBadge, err error)
+ GetUsersByID(ids []UserID) (out []User, err error)
+ GetUsersByUsername(usernames []string) (out []User, err error)
+ GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error)
+ GetVerifiedUserBySessionID(token string) (out User, err error)
+ GetVerifiedUserByUsername(username string) (out User, err error)
+ GetWhitelistedUsers(roomID RoomID) (out []ChatRoomWhitelistedUser, err error)
+ GetXmrInvoiceByAddress(address string) (out XmrInvoice, err error)
+ IgnoreMessage(userID UserID, messageID int64)
+ IgnoreUser(userID, ignoredUserID UserID)
+ IsPasswordProhibited(password string) bool
+ IsUserInGroupByID(userID UserID, groupID GroupID) bool
+ IsUserPmBlacklisted(fromUserID, toUserID UserID) bool
+ IsUserPmWhitelisted(fromUserID, toUserID UserID) bool
+ IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool
+ IsUserSubscribedToRoom(userID UserID, roomID RoomID) bool
+ IsUserWhitelistedInRoom(userID UserID, roomID RoomID) bool
+ IsUsernameAlreadyTaken(username string) bool
+ NewAudit(authUser User, log string)
+ RmBlacklistedUser(userID, blacklistedUserID UserID)
+ RmUserFromRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (err error)
+ RmWhitelistedUser(userID, whitelistedUserID UserID)
+ SetUserPrivateNotes(userID UserID, notes string) error
+ SetUserPublicNotes(userID UserID, notes string) error
+ SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error)
+ SubscribeToRoom(userID UserID, roomID RoomID) (err error)
+ ToggleBlacklistedUser(userID, blacklistedUserID UserID) bool
+ ToggleWhitelistedUser(userID, whitelistedUserID UserID) bool
+ UnIgnoreMessage(userID UserID, messageID int64)
+ UnIgnoreUser(userID, ignoredUserID UserID)
+ UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error)
+ UnsubscribeFromRoom(userID UserID, roomID RoomID) (err error)
+ UpdateChatReadMarker(userID UserID, roomID RoomID)
+ UpdateChatReadRecord(userID UserID, roomID RoomID)
+ UpdateForumReadRecord(userID UserID, threadID ForumThreadID)
+ UserNbDownloaded(userID UserID, filename string) (out int64)
+ WhitelistUser(roomID RoomID, userID UserID) (out ChatRoomWhitelistedUser, err error)
+}
+
+var once sync.Once
+var inst *DkfDB
+
+func DB() *DkfDB {
+ once.Do(func() {
+ dbPath := filepath.Join(config.Global.ProjectPath(), config.DbFileName)
+ db, err := gorm.Open("sqlite3", dbPath)
+ if err != nil {
+ logrus.Fatal("Failed to open sqlite3 db : " + err.Error())
+ }
+ db.DB().SetMaxIdleConns(1) // 10
+ db.DB().SetMaxOpenConns(1) // 25
+ db.LogMode(false)
+ db.Exec("PRAGMA foreign_keys=ON")
+ inst = &DkfDB{db: db}
+ })
+ return inst
}
// DB2 is the SQL database.
diff --git a/pkg/database/tableAuditLog.go b/pkg/database/tableAuditLog.go
@@ -14,14 +14,14 @@ type AuditLog struct {
User User
}
-func NewAudit(authUser User, log string) {
- if err := DB.Create(&AuditLog{UserID: authUser.ID, Log: log}).Error; err != nil {
+func (d *DkfDB) NewAudit(authUser User, log string) {
+ if err := d.db.Create(&AuditLog{UserID: authUser.ID, Log: log}).Error; err != nil {
logrus.Error(err)
}
}
-func DeleteOldAuditLogs() {
- if err := DB.Delete(AuditLog{}, "created_at < date('now', '-90 Day')").Error; err != nil {
+func (d *DkfDB) DeleteOldAuditLogs() {
+ if err := d.db.Delete(AuditLog{}, "created_at < date('now', '-90 Day')").Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableBadges.go b/pkg/database/tableBadges.go
@@ -16,12 +16,12 @@ type UserBadge struct {
Badge Badge
}
-func CreateUserBadge(userID UserID, badgeID int64) error {
+func (d *DkfDB) CreateUserBadge(userID UserID, badgeID int64) error {
ub := UserBadge{UserID: userID, BadgeID: badgeID}
- return DB.Create(&ub).Error
+ return d.db.Create(&ub).Error
}
-func GetUsersBadges() (out []UserBadge, err error) {
- err = DB.Preload("User").Preload("Badge").Order("created_at").Find(&out).Error
+func (d *DkfDB) GetUsersBadges() (out []UserBadge, err error) {
+ err = d.db.Preload("User").Preload("Badge").Order("created_at").Find(&out).Error
return
}
diff --git a/pkg/database/tableCaptchaRequests.go b/pkg/database/tableCaptchaRequests.go
@@ -19,8 +19,8 @@ type CaptchaRequest struct {
// return base64.StdEncoding.EncodeToString(r.CaptchaImg)
//}
-func DeleteOldCaptchaRequests() {
- if err := DB.Delete(CaptchaRequest{}, "created_at < date('now', '-90 Day')").Error; err != nil {
+func (d *DkfDB) DeleteOldCaptchaRequests() {
+ if err := d.db.Delete(CaptchaRequest{}, "created_at < date('now', '-90 Day')").Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableChatInbox.go b/pkg/database/tableChatInbox.go
@@ -22,8 +22,8 @@ type ChatInboxMessage struct {
Room ChatRoom
}
-func GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error) {
- err = DB.Order("id DESC").
+func (d *DkfDB) GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error) {
+ err = d.db.Order("id DESC").
Limit(50).
Preload("User").
Preload("ToUser").
@@ -33,14 +33,14 @@ func GetUserChatInboxMessages(userID UserID) (msgs []ChatInboxMessage, err error
for _, msg := range msgs {
ids = append(ids, msg.ID)
}
- if err := DB.Model(&ChatInboxMessage{}).Where("id IN (?)", ids).UpdateColumn("is_read", true).Error; err != nil {
+ if err := d.db.Model(&ChatInboxMessage{}).Where("id IN (?)", ids).UpdateColumn("is_read", true).Error; err != nil {
logrus.Error(err)
}
return
}
-func GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err error) {
- err = DB.Order("id DESC").
+func (d *DkfDB) GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err error) {
+ err = d.db.Order("id DESC").
Limit(50).
Preload("User").
Preload("ToUser").
@@ -49,30 +49,30 @@ func GetUserChatInboxMessagesSent(userID UserID) (msgs []ChatInboxMessage, err e
return
}
-func DeleteChatInboxMessageByID(messageID int64) error {
- return DB.Where("id = ?", messageID).Delete(&ChatInboxMessage{}).Error
+func (d *DkfDB) DeleteChatInboxMessageByID(messageID int64) error {
+ return d.db.Where("id = ?", messageID).Delete(&ChatInboxMessage{}).Error
}
-func DeleteChatInboxMessageByChatMessageID(chatMessageID int64) error {
- return DB.Where("chat_message_id = ?", chatMessageID).Delete(&ChatInboxMessage{}).Error
+func (d *DkfDB) DeleteChatInboxMessageByChatMessageID(chatMessageID int64) error {
+ return d.db.Where("chat_message_id = ?", chatMessageID).Delete(&ChatInboxMessage{}).Error
}
-func DeleteAllChatInbox(userID UserID) error {
- return DB.Where("to_user_id = ?", userID).Delete(&ChatInboxMessage{}).Error
+func (d *DkfDB) DeleteAllChatInbox(userID UserID) error {
+ return d.db.Where("to_user_id = ?", userID).Delete(&ChatInboxMessage{}).Error
}
-func DeleteUserChatInboxMessages(userID UserID) error {
- return DB.Where("user_id = ?", userID).Delete(&ChatInboxMessage{}).Error
+func (d *DkfDB) DeleteUserChatInboxMessages(userID UserID) error {
+ return d.db.Where("user_id = ?", userID).Delete(&ChatInboxMessage{}).Error
}
-func CreateInboxMessage(msg string, roomID RoomID, fromUserID, toUserID UserID, isPm, moderators bool, msgID *int64) {
+func (d *DkfDB) CreateInboxMessage(msg string, roomID RoomID, fromUserID, toUserID UserID, isPm, moderators bool, msgID *int64) {
inbox := ChatInboxMessage{Message: msg, RoomID: roomID, UserID: fromUserID, ToUserID: toUserID, IsPm: isPm, Moderators: moderators, ChatMessageID: msgID}
- if err := DB.Create(&inbox).Error; err != nil {
+ if err := d.db.Create(&inbox).Error; err != nil {
logrus.Error(err)
}
}
-func GetUserInboxMessagesCount(userID UserID) (count int64) {
- DB.Table("chat_inbox_messages").Where("to_user_id = ? AND is_read = ?", userID, false).Count(&count)
+func (d *DkfDB) GetUserInboxMessagesCount(userID UserID) (count int64) {
+ d.db.Table("chat_inbox_messages").Where("to_user_id = ? AND is_read = ?", userID, false).Count(&count)
return
}
diff --git a/pkg/database/tableChatMessages.go b/pkg/database/tableChatMessages.go
@@ -212,14 +212,14 @@ func (m *ChatMessage) TrimMe() string {
return "<p>" + strings.TrimPrefix(m.Message, "<p>/me ")
}
-func (m *ChatMessage) DoSave() {
- if err := DB.Save(m).Error; err != nil {
+func (m *ChatMessage) DoSave(db *DkfDB) {
+ if err := db.db.Save(m).Error; err != nil {
logrus.Error(err)
}
}
-func GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage, err error) {
- err = DB.
+func (d *DkfDB) GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage, err error) {
+ err = d.db.
Where("user_id = ? AND room_id = ?", userID, roomID).
Order("id DESC").
Preload("User").
@@ -230,8 +230,8 @@ func GetUserLastChatMessageInRoom(userID UserID, roomID RoomID) (out ChatMessage
return
}
-func GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error) {
- err = DB.
+func (d *DkfDB) GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error) {
+ err = d.db.
Where("room_id = ?", roomID).
Preload("User").
Preload("ToUser").
@@ -241,8 +241,8 @@ func GetRoomChatMessages(roomID RoomID) (out ChatMessages, err error) {
return
}
-func GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, err error) {
- err = DB.
+func (d *DkfDB) GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, err error) {
+ err = d.db.
Where("room_id = ? AND uuid = ?", roomID, msgUUID).
Preload("User").
Preload("ToUser").
@@ -252,8 +252,8 @@ func GetRoomChatMessageByUUID(roomID RoomID, msgUUID string) (out ChatMessage, e
return
}
-func GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out ChatMessage, err error) {
- err = DB.
+func (d *DkfDB) GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out ChatMessage, err error) {
+ err = d.db.
Select("*, strftime('%Y-%m-%d %H:%M:%S', created_at) as created_at1").
Where("room_id = ? AND user_id = ? AND created_at1 = ?", roomID, userID, dt.Format("2006-01-02 15:04:05")).
Preload("User").
@@ -263,8 +263,8 @@ func GetRoomChatMessageByDate(roomID RoomID, userID UserID, dt time.Time) (out C
return
}
-func GetRoomChatMessagesByDate(roomID RoomID, dt time.Time) (out []ChatMessage, err error) {
- err = DB.
+func (d *DkfDB) GetRoomChatMessagesByDate(roomID RoomID, dt time.Time) (out []ChatMessage, err error) {
+ err = d.db.
Select("*, strftime('%m-%d %H:%M:%S', created_at) as created_at1").
Where("room_id = ? AND created_at1 = ?", roomID, dt.Format("01-02 15:04:05")).
Preload("User").
@@ -283,11 +283,11 @@ const (
PmNone
)
-func GetChatMessages(roomID RoomID, username string, userID UserID, displayPms PmDisplayMode, mentionsOnly, displayHellbanned, displayIgnored, displayModerators, displayIgnoredMessages bool) (out ChatMessages, err error) {
+func (d *DkfDB) GetChatMessages(roomID RoomID, username string, userID UserID, displayPms PmDisplayMode, mentionsOnly, displayHellbanned, displayIgnored, displayModerators, displayIgnoredMessages bool) (out ChatMessages, err error) {
cmp := func(t, t2 ChatMessage) bool { return t.ID > t2.ID }
- q := DB.
+ q := d.db.
Preload("User").
Preload("ToUser").
Preload("Room").
@@ -343,7 +343,7 @@ func GetChatMessages(roomID RoomID, username string, userID UserID, displayPms P
//-----------
- qg := DB.
+ qg := d.db.
Preload("User").
Preload("ToUser").
Preload("Room").
@@ -393,22 +393,22 @@ func sortedMerge[T any](a, b []T, less func(T, T) bool) []T {
return out
}
-func DeleteChatRoomMessages(roomID RoomID) error {
- return DB.Delete(&ChatMessage{}, "room_id = ?", roomID).Error
+func (d *DkfDB) DeleteChatRoomMessages(roomID RoomID) error {
+ return d.db.Delete(&ChatMessage{}, "room_id = ?", roomID).Error
}
-func DeleteChatMessageByUUID(messageUUID string) error {
- return DB.Where("uuid = ?", messageUUID).Delete(&ChatMessage{}).Error
+func (d *DkfDB) DeleteChatMessageByUUID(messageUUID string) error {
+ return d.db.Where("uuid = ?", messageUUID).Delete(&ChatMessage{}).Error
}
-func DeleteUserChatMessages(userID UserID) error {
- return DB.Where("user_id = ?", userID).Delete(&ChatMessage{}).Error
+func (d *DkfDB) DeleteUserChatMessages(userID UserID) error {
+ return d.db.Where("user_id = ?", userID).Delete(&ChatMessage{}).Error
}
-func DeleteOldChatMessages() {
- rooms, _ := GetOfficialChatRooms()
+func (d *DkfDB) DeleteOldChatMessages() {
+ rooms, _ := d.GetOfficialChatRooms()
for _, room := range rooms {
- DB.Exec(`
+ d.db.Exec(`
DELETE FROM chat_messages
-- Don't delete the last 500 "non PM" and "not hellbanned" messages
WHERE id NOT IN (
@@ -447,7 +447,7 @@ func makeMsg(raw, txt string, roomID RoomID, userID UserID) ChatMessage {
return msg
}
-func CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID *UserID) (out ChatMessage, err error) {
+func (d *DkfDB) CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID *UserID) (out ChatMessage, err error) {
if roomKey != "" {
var err error
txt, raw, err = encryptMessages(txt, raw, roomKey)
@@ -460,11 +460,11 @@ func CreateMsg(raw, txt, roomKey string, roomID RoomID, userID UserID, toUserID
if toUserID != nil {
out.ToUserID = toUserID
}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error {
+func (d *DkfDB) CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error {
if roomKey != "" {
var err error
txt, raw, err = encryptMessages(txt, raw, roomKey)
@@ -474,30 +474,30 @@ func CreateSysMsg(raw, txt, roomKey string, roomID RoomID, userID UserID) error
}
msg := makeMsg(raw, txt, roomID, userID)
msg.System = true
- return DB.Create(&msg).Error
+ return d.db.Create(&msg).Error
}
-func CreateKickMsg(kickedUser, kickedByUser User) {
+func (d *DkfDB) CreateKickMsg(kickedUser, kickedByUser User) {
// Display kick message
styledUsername := fmt.Sprintf(`<span %s>%s</span>`, kickedUser.GenerateChatStyle(), kickedUser.Username)
rawTxt := fmt.Sprintf("%s has been kicked. (%s)", kickedUser.Username, kickedByUser.Username)
txt := fmt.Sprintf("%s has been kicked. (%s)", styledUsername, kickedByUser.Username)
- if err := CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil {
+ if err := d.CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil {
logrus.Error(err)
}
}
-func CreateUnkickMsg(kickedUser, kickedByUser User) {
+func (d *DkfDB) CreateUnkickMsg(kickedUser, kickedByUser User) {
// Display unkick message
styledUsername := fmt.Sprintf(`<span %s>%s</span>`, kickedUser.GenerateChatStyle(), kickedUser.Username)
rawTxt := fmt.Sprintf("%s has been unkicked. (%s)", kickedUser.Username, kickedByUser.Username)
txt := fmt.Sprintf("%s has been unkicked. (%s)", styledUsername, kickedByUser.Username)
- if err := CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil {
+ if err := d.CreateSysMsg(rawTxt, txt, "", config.GeneralRoomID, kickedByUser.ID); err != nil {
logrus.Error(err)
}
}
-func CreateOrEditMessage(
+func (d *DkfDB) CreateOrEditMessage(
editMsg *ChatMessage,
message, raw, roomKey string,
roomID RoomID,
@@ -519,7 +519,7 @@ func CreateOrEditMessage(
editMsg.Message = message
editMsg.RawMessage = raw
// Delete inboxes, we'll create new ones bellow
- _ = DeleteChatInboxMessageByChatMessageID(editMsg.ID)
+ _ = d.DeleteChatInboxMessageByChatMessageID(editMsg.ID)
} else {
msg := makeMsg(raw, message, roomID, fromUserID)
editMsg = &msg
@@ -532,7 +532,7 @@ func CreateOrEditMessage(
editMsg.UploadID = &upload.ID
}
}
- editMsg.DoSave()
+ editMsg.DoSave(d)
return editMsg.ID, nil
}
diff --git a/pkg/database/tableChatReactions.go b/pkg/database/tableChatReactions.go
@@ -12,15 +12,15 @@ type ChatReaction struct {
CreatedAt time.Time
}
-func CreateChatReaction(userID UserID, messageID, reaction int64) error {
+func (d *DkfDB) CreateChatReaction(userID UserID, messageID, reaction int64) error {
out := ChatReaction{
UserID: userID,
MessageID: messageID,
Reaction: reaction,
}
- return DB.Create(&out).Error
+ return d.db.Create(&out).Error
}
-func DeleteReaction(userID UserID, messageID, reaction int64) error {
- return DB.Delete(ChatReaction{}, "user_id = ? AND message_id = ? AND reaction = ?", userID, messageID, reaction).Error
+func (d *DkfDB) DeleteReaction(userID UserID, messageID, reaction int64) error {
+ return d.db.Delete(ChatReaction{}, "user_id = ? AND message_id = ? AND reaction = ?", userID, messageID, reaction).Error
}
diff --git a/pkg/database/tableChatReadMarkers.go b/pkg/database/tableChatReadMarkers.go
@@ -13,15 +13,15 @@ type ChatReadMarker struct {
ReadAt time.Time
}
-func GetUserReadMarker(userID UserID, roomID RoomID) (out ChatReadMarker, err error) {
- err = DB.First(&out, "user_id = ? AND room_id = ?", userID, roomID).Error
+func (d *DkfDB) GetUserReadMarker(userID UserID, roomID RoomID) (out ChatReadMarker, err error) {
+ err = d.db.First(&out, "user_id = ? AND room_id = ?", userID, roomID).Error
return
}
-func UpdateChatReadMarker(userID UserID, roomID RoomID) {
+func (d *DkfDB) UpdateChatReadMarker(userID UserID, roomID RoomID) {
now := time.Now()
- res := DB.Table("chat_read_markers").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now)
+ res := d.db.Table("chat_read_markers").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now)
if res.RowsAffected == 0 {
- DB.Create(ChatReadMarker{UserID: userID, RoomID: roomID, ReadAt: now})
+ d.db.Create(ChatReadMarker{UserID: userID, RoomID: roomID, ReadAt: now})
}
}
diff --git a/pkg/database/tableChatRoomGroups.go b/pkg/database/tableChatRoomGroups.go
@@ -16,8 +16,8 @@ type ChatRoomGroup struct {
CreatedAt time.Time
}
-func (g *ChatRoomGroup) DoSave() {
- if err := DB.Save(g).Error; err != nil {
+func (g *ChatRoomGroup) DoSave(db *DkfDB) {
+ if err := db.db.Save(g).Error; err != nil {
logrus.Error(err)
}
}
@@ -29,60 +29,60 @@ type ChatRoomUserGroup struct {
User User
}
-func GetUserRoomGroups(userID UserID, roomID RoomID) (out []ChatRoomUserGroup, err error) {
- err = DB.Find(&out, "user_id = ? AND room_id = ?", userID, roomID).Error
+func (d *DkfDB) GetUserRoomGroups(userID UserID, roomID RoomID) (out []ChatRoomUserGroup, err error) {
+ err = d.db.Find(&out, "user_id = ? AND room_id = ?", userID, roomID).Error
return
}
-func GetRoomGroupByName(roomID RoomID, groupName string) (out ChatRoomGroup, err error) {
- err = DB.First(&out, "room_id = ? AND name = ?", roomID, groupName).Error
+func (d *DkfDB) GetRoomGroupByName(roomID RoomID, groupName string) (out ChatRoomGroup, err error) {
+ err = d.db.First(&out, "room_id = ? AND name = ?", roomID, groupName).Error
return
}
-func IsUserInGroupByID(userID UserID, groupID GroupID) bool {
+func (d *DkfDB) IsUserInGroupByID(userID UserID, groupID GroupID) bool {
var count int64
- DB.Model(ChatRoomUserGroup{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count)
+ d.db.Model(ChatRoomUserGroup{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count)
return count == 1
}
-func DeleteChatRoomGroup(roomID RoomID, name string) (err error) {
- err = DB.Delete(&ChatRoomGroup{}, "room_id = ? AND name = ?", roomID, name).Error
+func (d *DkfDB) DeleteChatRoomGroup(roomID RoomID, name string) (err error) {
+ err = d.db.Delete(&ChatRoomGroup{}, "room_id = ? AND name = ?", roomID, name).Error
return
}
-func DeleteChatRoomGroups(roomID RoomID) (err error) {
- err = DB.Delete(&ChatRoomGroup{}, "room_id = ?", roomID).Error
+func (d *DkfDB) DeleteChatRoomGroups(roomID RoomID) (err error) {
+ err = d.db.Delete(&ChatRoomGroup{}, "room_id = ?", roomID).Error
return
}
-func CreateChatRoomGroup(roomID RoomID, name, color string) (out ChatRoomGroup, err error) {
+func (d *DkfDB) CreateChatRoomGroup(roomID RoomID, name, color string) (out ChatRoomGroup, err error) {
out = ChatRoomGroup{Name: name, Color: color, RoomID: roomID}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func AddUserToRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (out ChatRoomUserGroup, err error) {
+func (d *DkfDB) AddUserToRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (out ChatRoomUserGroup, err error) {
out = ChatRoomUserGroup{GroupID: groupID, RoomID: roomID, UserID: userID}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func RmUserFromRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (err error) {
- err = DB.Delete(&ChatRoomUserGroup{}, "user_id = ? AND group_id = ? AND room_id = ?", userID, groupID, roomID).Error
+func (d *DkfDB) RmUserFromRoomGroup(roomID RoomID, groupID GroupID, userID UserID) (err error) {
+ err = d.db.Delete(&ChatRoomUserGroup{}, "user_id = ? AND group_id = ? AND room_id = ?", userID, groupID, roomID).Error
return
}
-func ClearRoomGroup(roomID RoomID, groupID GroupID) (err error) {
- err = DB.Delete(&ChatRoomUserGroup{}, "group_id = ? AND room_id = ?", groupID, roomID).Error
+func (d *DkfDB) ClearRoomGroup(roomID RoomID, groupID GroupID) (err error) {
+ err = d.db.Delete(&ChatRoomUserGroup{}, "group_id = ? AND room_id = ?", groupID, roomID).Error
return
}
-func GetRoomGroups(roomID RoomID) (out []ChatRoomGroup, err error) {
- err = DB.Find(&out, "room_id = ?", roomID).Error
+func (d *DkfDB) GetRoomGroups(roomID RoomID) (out []ChatRoomGroup, err error) {
+ err = d.db.Find(&out, "room_id = ?", roomID).Error
return
}
-func GetRoomGroupUsers(roomID RoomID, groupID GroupID) (out []ChatRoomUserGroup, err error) {
- err = DB.Where("room_id = ? AND group_id = ?", roomID, groupID).Preload("User").Find(&out).Error
+func (d *DkfDB) GetRoomGroupUsers(roomID RoomID, groupID GroupID) (out []ChatRoomUserGroup, err error) {
+ err = d.db.Where("room_id = ? AND group_id = ?", roomID, groupID).Preload("User").Find(&out).Error
return
}
diff --git a/pkg/database/tableChatRoomWhitelistedUsers.go b/pkg/database/tableChatRoomWhitelistedUsers.go
@@ -13,30 +13,30 @@ type ChatRoomWhitelistedUser struct {
User User
}
-func (r *ChatRoomWhitelistedUser) DoSave() {
- if err := DB.Save(r).Error; err != nil {
+func (r *ChatRoomWhitelistedUser) DoSave(db *DkfDB) {
+ if err := db.db.Save(r).Error; err != nil {
logrus.Error(err)
}
}
-func IsUserWhitelistedInRoom(userID UserID, roomID RoomID) bool {
+func (d *DkfDB) IsUserWhitelistedInRoom(userID UserID, roomID RoomID) bool {
var count int64
- DB.Table("chat_room_whitelisted_users").Where("user_id = ? and room_id = ?", userID, roomID).Count(&count)
+ d.db.Table("chat_room_whitelisted_users").Where("user_id = ? and room_id = ?", userID, roomID).Count(&count)
return count == 1
}
-func GetWhitelistedUsers(roomID RoomID) (out []ChatRoomWhitelistedUser, err error) {
- err = DB.Preload("User").Find(&out, "room_id = ?", roomID).Error
+func (d *DkfDB) GetWhitelistedUsers(roomID RoomID) (out []ChatRoomWhitelistedUser, err error) {
+ err = d.db.Preload("User").Find(&out, "room_id = ?", roomID).Error
return
}
-func WhitelistUser(roomID RoomID, userID UserID) (out ChatRoomWhitelistedUser, err error) {
+func (d *DkfDB) WhitelistUser(roomID RoomID, userID UserID) (out ChatRoomWhitelistedUser, err error) {
out = ChatRoomWhitelistedUser{UserID: userID, RoomID: roomID}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func DeWhitelistUser(roomID RoomID, userID UserID) (err error) {
- err = DB.Delete(ChatRoomWhitelistedUser{}, "user_id = ? and room_id = ?", userID, roomID).Error
+func (d *DkfDB) DeWhitelistUser(roomID RoomID, userID UserID) (err error) {
+ err = d.db.Delete(ChatRoomWhitelistedUser{}, "user_id = ? and room_id = ?", userID, roomID).Error
return
}
diff --git a/pkg/database/tableChatRooms.go b/pkg/database/tableChatRooms.go
@@ -31,7 +31,7 @@ const (
UserWhitelistRoomMode = 1
)
-func CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool) (out ChatRoom, err error) {
+func (d *DkfDB) CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool) (out ChatRoom, err error) {
out = ChatRoom{
Name: name,
Password: passwordHash,
@@ -39,7 +39,7 @@ func CreateRoom(name string, passwordHash string, ownerID UserID, isListed bool)
IsListed: isListed,
IsEphemeral: true,
}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
@@ -72,8 +72,8 @@ func (r *ChatRoom) IsProtected() bool {
return r.Password != ""
}
-func (r *ChatRoom) DoSave() {
- if err := DB.Save(r).Error; err != nil {
+func (r *ChatRoom) DoSave(db *DkfDB) {
+ if err := db.db.Save(r).Error; err != nil {
logrus.Error(err)
}
}
@@ -88,6 +88,7 @@ func (r *ChatRoom) IsOfficialRoom() bool {
func (r *ChatRoom) HasAccess(c echo.Context) bool {
authUser := c.Get("authUser").(*User)
+ db := c.Get("database").(*DkfDB)
if authUser == nil {
return false
}
@@ -99,7 +100,7 @@ func (r *ChatRoom) HasAccess(c echo.Context) bool {
}
if r.Mode == UserWhitelistRoomMode {
if r.OwnerUserID != nil && *r.OwnerUserID != authUser.ID {
- if !IsUserWhitelistedInRoom(authUser.ID, r.ID) {
+ if !db.IsUserWhitelistedInRoom(authUser.ID, r.ID) {
return false
}
}
@@ -118,18 +119,18 @@ func (r *ChatRoom) HasAccess(c echo.Context) bool {
return true
}
-func GetChatRoomByID(roomID RoomID) (out ChatRoom, err error) {
- err = DB.Where("id = ?", roomID).First(&out).Error
+func (d *DkfDB) GetChatRoomByID(roomID RoomID) (out ChatRoom, err error) {
+ err = d.db.Where("id = ?", roomID).First(&out).Error
return
}
-func GetChatRoomByName(roomName string) (out ChatRoom, err error) {
- err = DB.Where("name = ?", roomName).First(&out).Error
+func (d *DkfDB) GetChatRoomByName(roomName string) (out ChatRoom, err error) {
+ err = d.db.Where("name = ?", roomName).First(&out).Error
return
}
-func DeleteChatRoomByID(id RoomID) {
- if err := DB.Delete(ChatRoom{}, "id = ?", id).Error; err != nil {
+func (d *DkfDB) DeleteChatRoomByID(id RoomID) {
+ if err := d.db.Delete(ChatRoom{}, "id = ?", id).Error; err != nil {
logrus.Error(err)
}
}
@@ -141,8 +142,8 @@ type ChatRoomAug struct {
}
// GetOfficialChatRooms1 returns official chat rooms with additional information such as "IsUnread"
-func GetOfficialChatRooms1(userID UserID) (out []ChatRoomAug, err error) {
- err = DB.Raw(`SELECT r.*,
+func (d *DkfDB) GetOfficialChatRooms1(userID UserID) (out []ChatRoomAug, err error) {
+ err = d.db.Raw(`SELECT r.*,
COALESCE((rr.read_at < m.created_at), 1) as is_unread
FROM chat_rooms r
-- Find last message for room
@@ -154,8 +155,8 @@ ORDER BY r.id ASC`, userID, userID).Scan(&out).Error
return
}
-func GetUserRoomSubscriptions(userID UserID) (out []ChatRoomAug, err error) {
- err = DB.Raw(`SELECT r.*,
+func (d *DkfDB) GetUserRoomSubscriptions(userID UserID) (out []ChatRoomAug, err error) {
+ err = d.db.Raw(`SELECT r.*,
COALESCE((rr.read_at < m.created_at), 1) as is_unread
FROM user_room_subscriptions s
INNER JOIN chat_rooms r ON r.id = s.room_id
@@ -168,8 +169,8 @@ ORDER BY r.id ASC`, userID, userID, userID).Scan(&out).Error
return
}
-func GetListedChatRooms(userID UserID) (out []ChatRoomAug, err error) {
- err = DB.Raw(`SELECT r.*,
+func (d *DkfDB) GetListedChatRooms(userID UserID) (out []ChatRoomAug, err error) {
+ err = d.db.Raw(`SELECT r.*,
u.*,
COALESCE((rr.read_at < m.created_at), 1) as is_unread
FROM chat_rooms r
@@ -184,13 +185,13 @@ ORDER BY r.id ASC`, userID, userID).Scan(&out).Error
return
}
-func GetOfficialChatRooms() (out []ChatRoom, err error) {
- err = DB.Where("id IN (1, 2, 3, 4, 14)").Preload("ReadRecord").Find(&out).Error
+func (d *DkfDB) GetOfficialChatRooms() (out []ChatRoom, err error) {
+ err = d.db.Where("id IN (1, 2, 3, 4, 14)").Preload("ReadRecord").Find(&out).Error
return
}
-func DeleteOldPrivateChatRooms() {
- DB.Exec(`DELETE FROM chat_rooms
+func (d *DkfDB) DeleteOldPrivateChatRooms() {
+ d.db.Exec(`DELETE FROM chat_rooms
WHERE owner_user_id IS NOT NULL
AND is_ephemeral = 1
AND ((SELECT chat_messages.created_at FROM chat_messages WHERE chat_messages.room_id = chat_rooms.id ORDER BY chat_messages.ID DESC) < date('now', '-1 Day')
@@ -206,16 +207,16 @@ type ChatReadRecord struct {
ReadAt time.Time
}
-func (r *ChatReadRecord) DoSave() {
- if err := DB.Save(r).Error; err != nil {
+func (r *ChatReadRecord) DoSave(db *DkfDB) {
+ if err := db.db.Save(r).Error; err != nil {
logrus.Error(err)
}
}
-func UpdateChatReadRecord(userID UserID, roomID RoomID) {
+func (d *DkfDB) UpdateChatReadRecord(userID UserID, roomID RoomID) {
now := time.Now()
- res := DB.Table("chat_read_records").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now)
+ res := d.db.Table("chat_read_records").Where("user_id = ? AND room_id = ?", userID, roomID).Update("read_at", now)
if res.RowsAffected == 0 {
- DB.Create(ChatReadRecord{UserID: userID, RoomID: roomID, ReadAt: now})
+ d.db.Create(ChatReadRecord{UserID: userID, RoomID: roomID, ReadAt: now})
}
}
diff --git a/pkg/database/tableDownloads.go b/pkg/database/tableDownloads.go
@@ -14,19 +14,19 @@ type Download struct {
}
// CreateDownload ...
-func CreateDownload(userID UserID, filename string) (out Download, err error) {
+func (d *DkfDB) CreateDownload(userID UserID, filename string) (out Download, err error) {
out = Download{UserID: userID, Filename: filename}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
// UserNbDownloaded returns how many times a user downloaded a file
-func UserNbDownloaded(userID UserID, filename string) (out int64) {
- DB.Table("downloads").Where("user_id = ? AND filename = ?", userID, filename).Count(&out)
+func (d *DkfDB) UserNbDownloaded(userID UserID, filename string) (out int64) {
+ d.db.Table("downloads").Where("user_id = ? AND filename = ?", userID, filename).Count(&out)
return
}
-func DeleteDownloadByID(downloadID int64) (err error) {
- err = DB.Unscoped().Delete(Download{}, "id = ?", downloadID).Error
+func (d *DkfDB) DeleteDownloadByID(downloadID int64) (err error) {
+ err = d.db.Unscoped().Delete(Download{}, "id = ?", downloadID).Error
return
}
diff --git a/pkg/database/tableFiledrops.go b/pkg/database/tableFiledrops.go
@@ -23,25 +23,25 @@ type Filedrop struct {
UpdatedAt *time.Time
}
-func GetFiledropByUUID(uuid string) (out Filedrop, err error) {
- err = DB.First(&out, "uuid = ?", uuid).Error
+func (d *DkfDB) GetFiledropByUUID(uuid string) (out Filedrop, err error) {
+ err = d.db.First(&out, "uuid = ?", uuid).Error
return
}
-func GetFiledropByFileName(fileName string) (out Filedrop, err error) {
- err = DB.First(&out, "file_name = ?", fileName).Error
+func (d *DkfDB) GetFiledropByFileName(fileName string) (out Filedrop, err error) {
+ err = d.db.First(&out, "file_name = ?", fileName).Error
return
}
-func GetFiledrops() (out []Filedrop, err error) {
- err = DB.Find(&out).Error
+func (d *DkfDB) GetFiledrops() (out []Filedrop, err error) {
+ err = d.db.Find(&out).Error
return
}
-func CreateFiledrop() (out Filedrop, err error) {
+func (d *DkfDB) CreateFiledrop() (out Filedrop, err error) {
out.UUID = uuid.New().String()
out.FileName = utils.MD5([]byte(utils.GenerateToken32()))
- err = DB.Save(&out).Error
+ err = d.db.Save(&out).Error
return
}
@@ -65,20 +65,20 @@ func (d *Filedrop) GetContent() (*os.File, *ucrypto.StreamDecrypter, error) {
return f, decrypter, nil
}
-func (d *Filedrop) Delete() error {
+func (d *Filedrop) Delete(db *DkfDB) error {
if d.FileName != "" {
if err := os.Remove(filepath.Join(config.Global.ProjectFiledropPath(), d.FileName)); err != nil {
logrus.Error(err)
}
}
- if err := DB.Delete(&d).Error; err != nil {
+ if err := db.db.Delete(&d).Error; err != nil {
return err
}
return nil
}
-func (d *Filedrop) DoSave() {
- if err := DB.Save(d).Error; err != nil {
+func (d *Filedrop) DoSave(db *DkfDB) {
+ if err := db.db.Save(d).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableGists.go b/pkg/database/tableGists.go
@@ -21,8 +21,8 @@ type Gist struct {
CreatedAt time.Time
}
-func GetGistByUUID(uuid string) (out Gist, err error) {
- err = DB.First(&out, "uuid = ?", uuid).Error
+func (d *DkfDB) GetGistByUUID(uuid string) (out Gist, err error) {
+ err = d.db.First(&out, "uuid = ?", uuid).Error
return
}
@@ -50,8 +50,8 @@ func (g *Gist) HasAccess(c echo.Context) bool {
}
// DoSave user in the database, ignore error
-func (g *Gist) DoSave() {
- if err := DB.Save(g).Error; err != nil {
+func (g *Gist) DoSave(db *DkfDB) {
+ if err := db.db.Save(g).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableIgnoredMessages.go b/pkg/database/tableIgnoredMessages.go
@@ -9,15 +9,15 @@ type IgnoredMessage struct {
MessageID int64
}
-func IgnoreMessage(userID UserID, messageID int64) {
+func (d *DkfDB) IgnoreMessage(userID UserID, messageID int64) {
ignore := IgnoredMessage{UserID: userID, MessageID: messageID}
- if err := DB.Create(&ignore).Error; err != nil {
+ if err := d.db.Create(&ignore).Error; err != nil {
logrus.Error(err)
}
}
-func UnIgnoreMessage(userID UserID, messageID int64) {
- if err := DB.Delete(&IgnoredMessage{}, "user_id = ? AND message_id = ?", userID, messageID).Error; err != nil {
+func (d *DkfDB) UnIgnoreMessage(userID UserID, messageID int64) {
+ if err := d.db.Delete(&IgnoredMessage{}, "user_id = ? AND message_id = ?", userID, messageID).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableIgnoredUsers.go b/pkg/database/tableIgnoredUsers.go
@@ -14,26 +14,26 @@ type IgnoredUser struct {
IgnoredUser User
}
-func GetIgnoredUsers(userID UserID) (out []IgnoredUser, err error) {
- err = DB.Where("user_id = ?", userID).Preload("IgnoredUser").Find(&out).Error
+func (d *DkfDB) GetIgnoredUsers(userID UserID) (out []IgnoredUser, err error) {
+ err = d.db.Where("user_id = ?", userID).Preload("IgnoredUser").Find(&out).Error
return
}
// GetIgnoredByUsers get a list of people who ignore userID
-func GetIgnoredByUsers(userID UserID) (out []IgnoredUser, err error) {
- err = DB.Where("ignored_user_id = ?", userID).Find(&out).Error
+func (d *DkfDB) GetIgnoredByUsers(userID UserID) (out []IgnoredUser, err error) {
+ err = d.db.Where("ignored_user_id = ?", userID).Find(&out).Error
return
}
-func IgnoreUser(userID, ignoredUserID UserID) {
+func (d *DkfDB) IgnoreUser(userID, ignoredUserID UserID) {
ignore := IgnoredUser{UserID: userID, IgnoredUserID: ignoredUserID}
- if err := DB.Create(&ignore).Error; err != nil {
+ if err := d.db.Create(&ignore).Error; err != nil {
logrus.Error(err)
}
}
-func UnIgnoreUser(userID, ignoredUserID UserID) {
- if err := DB.Delete(IgnoredUser{}, "user_id = ? AND ignored_user_id = ?", userID, ignoredUserID).Error; err != nil {
+func (d *DkfDB) UnIgnoreUser(userID, ignoredUserID UserID) {
+ if err := d.db.Delete(IgnoredUser{}, "user_id = ? AND ignored_user_id = ?", userID, ignoredUserID).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableInvitations.go b/pkg/database/tableInvitations.go
@@ -16,38 +16,38 @@ type Invitation struct {
}
// Save user in the database
-func (i *Invitation) Save() error {
- return DB.Save(i).Error
+func (i *Invitation) Save(db *DkfDB) error {
+ return db.db.Save(i).Error
}
// DoSave user in the database, ignore error
-func (i *Invitation) DoSave() {
- if err := DB.Save(i).Error; err != nil {
+func (i *Invitation) DoSave(db *DkfDB) {
+ if err := db.db.Save(i).Error; err != nil {
logrus.Error(err)
}
}
-func CreateInvitation(userID UserID) (out Invitation, err error) {
+func (d *DkfDB) CreateInvitation(userID UserID) (out Invitation, err error) {
out = Invitation{
Token: utils.GenerateToken32(),
OwnerUserID: userID,
InviteeUserID: 1,
}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func GetUnusedInvitationByToken(token string) (out Invitation, err error) {
- err = DB.First(&out, "token = ? AND invitee_user_id == 1", token).Error
+func (d *DkfDB) GetUnusedInvitationByToken(token string) (out Invitation, err error) {
+ err = d.db.First(&out, "token = ? AND invitee_user_id == 1", token).Error
return
}
-func GetUserInvitations(userID UserID) (out []Invitation, err error) {
- err = DB.Find(&out, "owner_user_id = ?", userID).Error
+func (d *DkfDB) GetUserInvitations(userID UserID) (out []Invitation, err error) {
+ err = d.db.Find(&out, "owner_user_id = ?", userID).Error
return
}
-func GetUserUnusedInvitations(userID UserID) (out []Invitation, err error) {
- err = DB.Find(&out, "owner_user_id = ? AND invitee_user_id == 1", userID).Error
+func (d *DkfDB) GetUserUnusedInvitations(userID UserID) (out []Invitation, err error) {
+ err = d.db.Find(&out, "owner_user_id = ? AND invitee_user_id == 1", userID).Error
return
}
diff --git a/pkg/database/tableKarmaHistory.go b/pkg/database/tableKarmaHistory.go
@@ -11,14 +11,14 @@ type KarmaHistory struct {
CreatedAt time.Time
}
-func CreateKarmaHistory(karma int64, description string, userID UserID, fromUserID *int64) (out KarmaHistory, err error) {
+func (d *DkfDB) CreateKarmaHistory(karma int64, description string, userID UserID, fromUserID *int64) (out KarmaHistory, err error) {
out = KarmaHistory{
Karma: karma,
Description: description,
UserID: userID,
FromUserID: fromUserID,
}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
diff --git a/pkg/database/tableLinks.go b/pkg/database/tableLinks.go
@@ -49,51 +49,51 @@ func (l Link) DescriptionSafe() string {
return html.EscapeString(l.Description)
}
-func (l *Link) Save() error {
- return DB.Save(l).Error
+func (l *Link) Save(db *DkfDB) error {
+ return db.db.Save(l).Error
}
-func (l *Link) DoSave() {
- if err := DB.Save(l).Error; err != nil {
+func (l *Link) DoSave(db *DkfDB) {
+ if err := l.Save(db); err != nil {
logrus.Error(err)
}
}
-func CreateLink(url, title, description, shorthand string) (out Link, err error) {
+func (d *DkfDB) CreateLink(url, title, description, shorthand string) (out Link, err error) {
out = Link{UUID: uuid.New().String(), URL: url, Title: title, Description: description}
if shorthand != "" {
out.Shorthand = &shorthand
}
- err = DB.FirstOrCreate(&out, "url = ?", url).Error
+ err = d.db.FirstOrCreate(&out, "url = ?", url).Error
return
}
-func DeleteLinkByID(id int64) error {
- return DB.Where("id = ?", id).Delete(&Link{}).Error
+func (d *DkfDB) DeleteLinkByID(id int64) error {
+ return d.db.Where("id = ?", id).Delete(&Link{}).Error
}
-func GetLinks() (out []Link, err error) {
- err = DB.Find(&out).Error
+func (d *DkfDB) GetLinks() (out []Link, err error) {
+ err = d.db.Find(&out).Error
return
}
-func GetRecentLinks() (out []Link, err error) {
- err = DB.Order("id DESC").Limit(100).Find(&out).Error
+func (d *DkfDB) GetRecentLinks() (out []Link, err error) {
+ err = d.db.Order("id DESC").Limit(100).Find(&out).Error
return
}
-func GetLinkByShorthand(shorthand string) (out Link, err error) {
- err = DB.Preload("OwnerUser").First(&out, "shorthand = ?", shorthand).Error
+func (d *DkfDB) GetLinkByShorthand(shorthand string) (out Link, err error) {
+ err = d.db.Preload("OwnerUser").First(&out, "shorthand = ?", shorthand).Error
return
}
-func GetLinkByUUID(linkUUID string) (out Link, err error) {
- err = DB.Preload("OwnerUser").First(&out, "uuid = ?", linkUUID).Error
+func (d *DkfDB) GetLinkByUUID(linkUUID string) (out Link, err error) {
+ err = d.db.Preload("OwnerUser").First(&out, "uuid = ?", linkUUID).Error
return
}
-func GetLinkByID(linkID int64) (out Link, err error) {
- err = DB.First(&out, "id = ?", linkID).Error
+func (d *DkfDB) GetLinkByID(linkID int64) (out Link, err error) {
+ err = d.db.First(&out, "id = ?", linkID).Error
return
}
@@ -102,9 +102,9 @@ type LinksCategory struct {
Name string
}
-func CreateLinksCategory(category string) (out LinksCategory, err error) {
+func (d *DkfDB) CreateLinksCategory(category string) (out LinksCategory, err error) {
out = LinksCategory{Name: category}
- err = DB.FirstOrCreate(&out, "name = ?", category).Error
+ err = d.db.FirstOrCreate(&out, "name = ?", category).Error
return
}
@@ -113,9 +113,9 @@ type LinksTag struct {
Name string
}
-func CreateLinksTag(tag string) (out LinksTag, err error) {
+func (d *DkfDB) CreateLinksTag(tag string) (out LinksTag, err error) {
out = LinksTag{Name: tag}
- err = DB.FirstOrCreate(&out, "name = ?", tag).Error
+ err = d.db.FirstOrCreate(&out, "name = ?", tag).Error
return
}
@@ -124,8 +124,8 @@ type LinksTagsLink struct {
TagID int64
}
-func AddLinkTag(linkID, tagID int64) (err error) {
- return DB.Create(&LinksTagsLink{LinkID: linkID, TagID: tagID}).Error
+func (d *DkfDB) AddLinkTag(linkID, tagID int64) (err error) {
+ return d.db.Create(&LinksTagsLink{LinkID: linkID, TagID: tagID}).Error
}
type LinksCategoriesLink struct {
@@ -133,8 +133,8 @@ type LinksCategoriesLink struct {
LinkID int64
}
-func AddLinkCategory(linkID, categoryID int64) (err error) {
- return DB.Create(&LinksCategoriesLink{CategoryID: categoryID, LinkID: linkID}).Error
+func (d *DkfDB) AddLinkCategory(linkID, categoryID int64) (err error) {
+ return d.db.Create(&LinksCategoriesLink{CategoryID: categoryID, LinkID: linkID}).Error
}
type CategoriesResult struct {
@@ -142,8 +142,8 @@ type CategoriesResult struct {
Count int64
}
-func GetCategories() (out []CategoriesResult, err error) {
- err = DB.Raw(`SELECT
+func (d *DkfDB) GetCategories() (out []CategoriesResult, err error) {
+ err = d.db.Raw(`SELECT
c.name, count(cl.link_id) as count
FROM links_categories_links cl
INNER JOIN links_categories c ON c.id = cl.category_id
@@ -153,8 +153,8 @@ ORDER BY c.name`).Scan(&out).Error
return
}
-func GetLinkCategories(linkID int64) (out []LinksCategory, err error) {
- err = DB.Raw(`SELECT
+func (d *DkfDB) GetLinkCategories(linkID int64) (out []LinksCategory, err error) {
+ err = d.db.Raw(`SELECT
c.id, c.name
FROM links_categories_links cl
INNER JOIN links_categories c ON c.id = cl.category_id
@@ -163,8 +163,8 @@ ORDER BY c.name`, linkID).Scan(&out).Error
return
}
-func GetLinkTags(linkID int64) (out []LinksTag, err error) {
- err = DB.Raw(`SELECT
+func (d *DkfDB) GetLinkTags(linkID int64) (out []LinksTag, err error) {
+ err = d.db.Raw(`SELECT
t.id, t.name
FROM links_tags_links tl
INNER JOIN links_tags t ON t.id = tl.tag_id
@@ -173,12 +173,12 @@ ORDER BY t.name`, linkID).Scan(&out).Error
return
}
-func DeleteLinkCategories(linkID int64) error {
- return DB.Delete(&LinksCategoriesLink{}, "link_id = ?", linkID).Error
+func (d *DkfDB) DeleteLinkCategories(linkID int64) error {
+ return d.db.Delete(&LinksCategoriesLink{}, "link_id = ?", linkID).Error
}
-func DeleteLinkTags(linkID int64) error {
- return DB.Delete(&LinksTagsLink{}, "link_id = ?", linkID).Error
+func (d *DkfDB) DeleteLinkTags(linkID int64) error {
+ return d.db.Delete(&LinksTagsLink{}, "link_id = ?", linkID).Error
}
type LinksMirror struct {
@@ -212,50 +212,50 @@ func (l LinksPgp) GetKeyFingerprint() string {
return out
}
-func CreateLinkPgp(linkID int64, title, description, publicKey string) (out LinksPgp, err error) {
+func (d *DkfDB) CreateLinkPgp(linkID int64, title, description, publicKey string) (out LinksPgp, err error) {
out = LinksPgp{
LinkID: linkID,
Title: title,
Description: description,
PgpPublicKey: publicKey,
}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func CreateLinkMirror(linkID int64, link string) (out LinksMirror, err error) {
+func (d *DkfDB) CreateLinkMirror(linkID int64, link string) (out LinksMirror, err error) {
out = LinksMirror{
LinkID: linkID,
MirrorURL: link,
}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func GetLinkPgps(linkID int64) (out []LinksPgp, err error) {
- err = DB.Find(&out, "link_id = ?", linkID).Error
+func (d *DkfDB) GetLinkPgps(linkID int64) (out []LinksPgp, err error) {
+ err = d.db.Find(&out, "link_id = ?", linkID).Error
return
}
-func GetLinkMirrors(linkID int64) (out []LinksMirror, err error) {
- err = DB.Find(&out, "link_id = ?", linkID).Error
+func (d *DkfDB) GetLinkMirrors(linkID int64) (out []LinksMirror, err error) {
+ err = d.db.Find(&out, "link_id = ?", linkID).Error
return
}
-func GetLinkPgpByID(id int64) (out LinksPgp, err error) {
- err = DB.First(&out, "id = ?", id).Error
+func (d *DkfDB) GetLinkPgpByID(id int64) (out LinksPgp, err error) {
+ err = d.db.First(&out, "id = ?", id).Error
return
}
-func GetLinkMirrorByID(id int64) (out LinksMirror, err error) {
- err = DB.First(&out, "id = ?", id).Error
+func (d *DkfDB) GetLinkMirrorByID(id int64) (out LinksMirror, err error) {
+ err = d.db.First(&out, "id = ?", id).Error
return
}
-func DeleteLinkPgpByID(id int64) error {
- return DB.Where("id = ?", id).Delete(&LinksPgp{}).Error
+func (d *DkfDB) DeleteLinkPgpByID(id int64) error {
+ return d.db.Where("id = ?", id).Delete(&LinksPgp{}).Error
}
-func DeleteLinkMirrorByID(id int64) error {
- return DB.Where("id = ?", id).Delete(&LinksMirror{}).Error
+func (d *DkfDB) DeleteLinkMirrorByID(id int64) error {
+ return d.db.Where("id = ?", id).Delete(&LinksMirror{}).Error
}
diff --git a/pkg/database/tableNotifications.go b/pkg/database/tableNotifications.go
@@ -26,8 +26,8 @@ type SessionNotification struct {
User User
}
-func GetUserNotifications(userID UserID) (msgs []Notification, err error) {
- err = DB.Order("id DESC").
+func (d *DkfDB) GetUserNotifications(userID UserID) (msgs []Notification, err error) {
+ err = d.db.Order("id DESC").
Limit(50).
Preload("User").
Find(&msgs, "user_id = ?", userID).Error
@@ -36,15 +36,15 @@ func GetUserNotifications(userID UserID) (msgs []Notification, err error) {
ids = append(ids, msg.ID)
}
now := time.Now()
- if err := DB.Model(&Notification{}).Where("id IN (?)", ids).
+ if err := d.db.Model(&Notification{}).Where("id IN (?)", ids).
UpdateColumn("is_read", true, "read_at", &now).Error; err != nil {
logrus.Error(err)
}
return
}
-func GetUserSessionNotifications(sessionToken string) (msgs []SessionNotification, err error) {
- err = DB.Order("id DESC").
+func (d *DkfDB) GetUserSessionNotifications(sessionToken string) (msgs []SessionNotification, err error) {
+ err = d.db.Order("id DESC").
Limit(50).
Joins("INNER JOIN sessions s ON s.token = session_token").
Joins("INNER JOIN users u ON u.id = s.user_id").
@@ -54,49 +54,49 @@ func GetUserSessionNotifications(sessionToken string) (msgs []SessionNotificatio
ids = append(ids, msg.ID)
}
now := time.Now()
- if err := DB.Table("session_notifications").Where("id IN (?)", ids).
+ if err := d.db.Table("session_notifications").Where("id IN (?)", ids).
UpdateColumn("is_read", true, "read_at", &now).Error; err != nil {
logrus.Error(err)
}
return
}
-func DeleteNotificationByID(notificationID int64) error {
- return DB.Where("id = ?", notificationID).Delete(&Notification{}).Error
+func (d *DkfDB) DeleteNotificationByID(notificationID int64) error {
+ return d.db.Where("id = ?", notificationID).Delete(&Notification{}).Error
}
-func DeleteSessionNotificationByID(sessionNotificationID int64) error {
- return DB.Where("id = ?", sessionNotificationID).Delete(&SessionNotification{}).Error
+func (d *DkfDB) DeleteSessionNotificationByID(sessionNotificationID int64) error {
+ return d.db.Where("id = ?", sessionNotificationID).Delete(&SessionNotification{}).Error
}
-func DeleteAllNotifications(userID UserID) error {
- return DB.Where("user_id = ?", userID).Delete(&Notification{}).Error
+func (d *DkfDB) DeleteAllNotifications(userID UserID) error {
+ return d.db.Where("user_id = ?", userID).Delete(&Notification{}).Error
}
-func CreateNotification(msg string, userID UserID) {
+func (d *DkfDB) CreateNotification(msg string, userID UserID) {
inbox := Notification{Message: msg, UserID: userID, IsRead: false}
- if err := DB.Create(&inbox).Error; err != nil {
+ if err := d.db.Create(&inbox).Error; err != nil {
logrus.Error(err)
}
}
-func GetUserNotificationsCount(userID UserID) (count int64) {
- DB.Table("notifications").Where("user_id = ? AND is_read = ?", userID, false).Count(&count)
+func (d *DkfDB) GetUserNotificationsCount(userID UserID) (count int64) {
+ d.db.Table("notifications").Where("user_id = ? AND is_read = ?", userID, false).Count(&count)
return
}
-func GetUserSessionNotificationsCount(sessionToken string) (count int64) {
- DB.Table("session_notifications").Where("session_token = ? AND is_read = ?", sessionToken, false).Count(&count)
+func (d *DkfDB) GetUserSessionNotificationsCount(sessionToken string) (count int64) {
+ d.db.Table("session_notifications").Where("session_token = ? AND is_read = ?", sessionToken, false).Count(&count)
return
}
-func CreateSessionNotification(msg string, sessionToken string) {
+func (d *DkfDB) CreateSessionNotification(msg string, sessionToken string) {
inbox := SessionNotification{Message: msg, SessionToken: sessionToken, IsRead: false}
- if err := DB.Create(&inbox).Error; err != nil {
+ if err := d.db.Create(&inbox).Error; err != nil {
logrus.Error(err)
}
}
-func DeleteAllSessionNotifications(sessionToken string) error {
- return DB.Where("session_token = ?", sessionToken).Delete(&SessionNotification{}).Error
+func (d *DkfDB) DeleteAllSessionNotifications(sessionToken string) error {
+ return d.db.Where("session_token = ?", sessionToken).Delete(&SessionNotification{}).Error
}
diff --git a/pkg/database/tableOnionBlacklist.go b/pkg/database/tableOnionBlacklist.go
@@ -7,7 +7,7 @@ type OnionBlacklist struct {
CreatedAt time.Time
}
-func GetOnionBlacklist(hash string) (out OnionBlacklist, err error) {
- err = DB.First(&out, "md5 = ?", hash).Error
+func (d *DkfDB) GetOnionBlacklist(hash string) (out OnionBlacklist, err error) {
+ err = d.db.First(&out, "md5 = ?", hash).Error
return
}
diff --git a/pkg/database/tablePmBlacklistedUsers.go b/pkg/database/tablePmBlacklistedUsers.go
@@ -11,43 +11,43 @@ type PmBlacklistedUsers struct {
}
// IsUserPmBlacklisted returns either or not toUserID blacklisted fromUserID
-func IsUserPmBlacklisted(fromUserID, toUserID UserID) bool {
+func (d *DkfDB) IsUserPmBlacklisted(fromUserID, toUserID UserID) bool {
var count int64
- DB.Model(&PmBlacklistedUsers{}).Where("blacklisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count)
+ d.db.Model(&PmBlacklistedUsers{}).Where("blacklisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count)
return count == 1
}
// GetPmBlacklistedUsers returns a list of userID blacklisted users
-func GetPmBlacklistedUsers(userID UserID) (out []PmBlacklistedUsers, err error) {
- err = DB.Where("user_id = ?", userID).Preload("BlacklistedUser").Find(&out).Error
+func (d *DkfDB) GetPmBlacklistedUsers(userID UserID) (out []PmBlacklistedUsers, err error) {
+ err = d.db.Where("user_id = ?", userID).Preload("BlacklistedUser").Find(&out).Error
return
}
// GetPmBlacklistedByUsers returns a list of users that are blacklisting userID
-func GetPmBlacklistedByUsers(userID UserID) (out []PmBlacklistedUsers, err error) {
- err = DB.Where("blacklisted_user_id = ?", userID).Find(&out).Error
+func (d *DkfDB) GetPmBlacklistedByUsers(userID UserID) (out []PmBlacklistedUsers, err error) {
+ err = d.db.Where("blacklisted_user_id = ?", userID).Find(&out).Error
return
}
// ToggleBlacklistedUser returns true if the user was added to the blacklist
-func ToggleBlacklistedUser(userID, blacklistedUserID UserID) bool {
- if IsUserPmBlacklisted(blacklistedUserID, userID) {
- RmBlacklistedUser(userID, blacklistedUserID)
+func (d *DkfDB) ToggleBlacklistedUser(userID, blacklistedUserID UserID) bool {
+ if d.IsUserPmBlacklisted(blacklistedUserID, userID) {
+ d.RmBlacklistedUser(userID, blacklistedUserID)
return false
}
- AddBlacklistedUser(userID, blacklistedUserID)
+ d.AddBlacklistedUser(userID, blacklistedUserID)
return true
}
-func AddBlacklistedUser(userID, blacklistedUserID UserID) {
+func (d *DkfDB) AddBlacklistedUser(userID, blacklistedUserID UserID) {
ignore := PmBlacklistedUsers{UserID: userID, BlacklistedUserID: blacklistedUserID}
- if err := DB.Create(&ignore).Error; err != nil {
+ if err := d.db.Create(&ignore).Error; err != nil {
logrus.Error(err)
}
}
-func RmBlacklistedUser(userID, blacklistedUserID UserID) {
- if err := DB.Delete(PmBlacklistedUsers{}, "user_id = ? AND blacklisted_user_id = ?", userID, blacklistedUserID).Error; err != nil {
+func (d *DkfDB) RmBlacklistedUser(userID, blacklistedUserID UserID) {
+ if err := d.db.Delete(PmBlacklistedUsers{}, "user_id = ? AND blacklisted_user_id = ?", userID, blacklistedUserID).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tablePmWhitelistedUsers.go b/pkg/database/tablePmWhitelistedUsers.go
@@ -10,36 +10,36 @@ type PmWhitelistedUsers struct {
WhitelistedUser User
}
-func IsUserPmWhitelisted(fromUserID, toUserID UserID) bool {
+func (d *DkfDB) IsUserPmWhitelisted(fromUserID, toUserID UserID) bool {
var count int64
- DB.Model(&PmWhitelistedUsers{}).Where("whitelisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count)
+ d.db.Model(&PmWhitelistedUsers{}).Where("whitelisted_user_id = ? AND user_id = ?", fromUserID, toUserID).Count(&count)
return count == 1
}
-func GetPmWhitelistedUsers(userID UserID) (out []PmWhitelistedUsers, err error) {
- err = DB.Where("user_id = ?", userID).Preload("WhitelistedUser").Find(&out).Error
+func (d *DkfDB) GetPmWhitelistedUsers(userID UserID) (out []PmWhitelistedUsers, err error) {
+ err = d.db.Where("user_id = ?", userID).Preload("WhitelistedUser").Find(&out).Error
return
}
// ToggleWhitelistedUser returns true if the user was added to the whitelist
-func ToggleWhitelistedUser(userID, whitelistedUserID UserID) bool {
- if IsUserPmWhitelisted(whitelistedUserID, userID) {
- RmWhitelistedUser(userID, whitelistedUserID)
+func (d *DkfDB) ToggleWhitelistedUser(userID, whitelistedUserID UserID) bool {
+ if d.IsUserPmWhitelisted(whitelistedUserID, userID) {
+ d.RmWhitelistedUser(userID, whitelistedUserID)
return false
}
- AddWhitelistedUser(userID, whitelistedUserID)
+ d.AddWhitelistedUser(userID, whitelistedUserID)
return true
}
-func AddWhitelistedUser(userID, whitelistedUserID UserID) {
+func (d *DkfDB) AddWhitelistedUser(userID, whitelistedUserID UserID) {
ignore := PmWhitelistedUsers{UserID: userID, WhitelistedUserID: whitelistedUserID}
- if err := DB.Create(&ignore).Error; err != nil {
+ if err := d.db.Create(&ignore).Error; err != nil {
logrus.Error(err)
}
}
-func RmWhitelistedUser(userID, whitelistedUserID UserID) {
- if err := DB.Delete(PmWhitelistedUsers{}, "user_id = ? AND whitelisted_user_id = ?", userID, whitelistedUserID).Error; err != nil {
+func (d *DkfDB) RmWhitelistedUser(userID, whitelistedUserID UserID) {
+ if err := d.db.Delete(PmWhitelistedUsers{}, "user_id = ? AND whitelisted_user_id = ?", userID, whitelistedUserID).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableProhibitedPasswords.go b/pkg/database/tableProhibitedPasswords.go
@@ -4,8 +4,8 @@ type ProhibitedPassword struct {
Password string
}
-func IsPasswordProhibited(password string) bool {
+func (d *DkfDB) IsPasswordProhibited(password string) bool {
var count int
- DB.Table("prohibited_passwords").Where("password = ?", password).Count(&count)
+ d.db.Table("prohibited_passwords").Where("password = ?", password).Count(&count)
return count > 0
}
diff --git a/pkg/database/tableSecurityLogs.go b/pkg/database/tableSecurityLogs.go
@@ -56,24 +56,24 @@ func getMessageForType(typ int64) string {
return ""
}
-func CreateSecurityLog(userID UserID, typ int64) {
+func (d *DkfDB) CreateSecurityLog(userID UserID, typ int64) {
log := SecurityLog{
Message: getMessageForType(typ),
UserID: userID,
Typ: typ,
}
- if err := DB.Create(&log).Error; err != nil {
+ if err := d.db.Create(&log).Error; err != nil {
logrus.Error(err)
}
}
-func GetSecurityLogs(userID UserID) (out []SecurityLog, err error) {
- err = DB.Order("id DESC").Find(&out, "user_id = ?", userID).Error
+func (d *DkfDB) GetSecurityLogs(userID UserID) (out []SecurityLog, err error) {
+ err = d.db.Order("id DESC").Find(&out, "user_id = ?", userID).Error
return
}
-func DeleteOldSecurityLogs() {
- if err := DB.Delete(SecurityLog{}, "created_at < date('now', '-7 Day')").Error; err != nil {
+func (d *DkfDB) DeleteOldSecurityLogs() {
+ if err := d.db.Delete(SecurityLog{}, "created_at < date('now', '-7 Day')").Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableSessions.go b/pkg/database/tableSessions.go
@@ -20,15 +20,15 @@ type Session struct {
}
// GetActiveUserSessions gets all user sessions
-func GetActiveUserSessions(userID UserID) (out []Session) {
- DB.Order("created_at DESC").Find(&out, "user_id = ? AND expires_at > DATETIME('now') AND deleted_at IS NULL", userID)
+func (d *DkfDB) GetActiveUserSessions(userID UserID) (out []Session) {
+ d.db.Order("created_at DESC").Find(&out, "user_id = ? AND expires_at > DATETIME('now') AND deleted_at IS NULL", userID)
return
}
// CreateSession creates a session for a user
-func CreateSession(userID UserID, userAgent string) (Session, error) {
+func (d *DkfDB) CreateSession(userID UserID, userAgent string) (Session, error) {
// Delete all sessions except the last 4
- if err := DB.Exec(`DELETE FROM sessions WHERE user_id = ? AND token NOT IN (SELECT s2.token FROM sessions s2 WHERE s2.user_id = ? ORDER BY s2.created_at DESC LIMIT 4)`, userID, userID).Error; err != nil {
+ if err := d.db.Exec(`DELETE FROM sessions WHERE user_id = ? AND token NOT IN (SELECT s2.token FROM sessions s2 WHERE s2.user_id = ? ORDER BY s2.created_at DESC LIMIT 4)`, userID, userID).Error; err != nil {
logrus.Error(err)
}
session := Session{
@@ -38,13 +38,13 @@ func CreateSession(userID UserID, userAgent string) (Session, error) {
UserAgent: userAgent,
ExpiresAt: time.Now().Add(time.Duration(utils.OneMonthSecs) * time.Second),
}
- err := DB.Create(&session).Error
+ err := d.db.Create(&session).Error
return session, err
}
// DoCreateSession same as CreateSession but log the error instead of returning it
-func DoCreateSession(userID UserID, userAgent string) Session {
- session, err := CreateSession(userID, userAgent)
+func (d *DkfDB) DoCreateSession(userID UserID, userAgent string) Session {
+ session, err := d.CreateSession(userID, userAgent)
if err != nil {
logrus.Error("Failed to create session : ", err)
}
@@ -52,25 +52,25 @@ func DoCreateSession(userID UserID, userAgent string) Session {
}
// DeleteUserSessions all sessions of the user.
-func DeleteUserSessions(userID UserID) error {
- return DB.Unscoped().Where("user_id = ?", userID).Delete(&Session{}).Error
+func (d *DkfDB) DeleteUserSessions(userID UserID) error {
+ return d.db.Unscoped().Where("user_id = ?", userID).Delete(&Session{}).Error
}
// DeleteSessionByToken a session by its token
-func DeleteSessionByToken(token string) error {
- return DB.Unscoped().Where("token = ?", token).Delete(&Session{}).Error
+func (d *DkfDB) DeleteSessionByToken(token string) error {
+ return d.db.Unscoped().Where("token = ?", token).Delete(&Session{}).Error
}
-func DeleteUserSessionByToken(userID UserID, token string) error {
- return DB.Unscoped().Where("user_id = ? AND token = ?", userID, token).Delete(&Session{}).Error
+func (d *DkfDB) DeleteUserSessionByToken(userID UserID, token string) error {
+ return d.db.Unscoped().Where("user_id = ? AND token = ?", userID, token).Delete(&Session{}).Error
}
-func DeleteUserOtherSessions(userID UserID, currentToken string) error {
- return DB.Unscoped().Where("user_id = ? AND token != ?", userID, currentToken).Delete(&Session{}).Error
+func (d *DkfDB) DeleteUserOtherSessions(userID UserID, currentToken string) error {
+ return d.db.Unscoped().Where("user_id = ? AND token != ?", userID, currentToken).Delete(&Session{}).Error
}
-func DeleteOldSessions() {
- if err := DB.Unscoped().Delete(Session{}, "expires_at < date('now', '-32 Day') OR (expires_at < date('now', '-32 Day') AND deleted_at IS NOT NULL)").Error; err != nil {
+func (d *DkfDB) DeleteOldSessions() {
+ if err := d.db.Unscoped().Delete(Session{}, "expires_at < date('now', '-32 Day') OR (expires_at < date('now', '-32 Day') AND deleted_at IS NOT NULL)").Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableSettings.go b/pkg/database/tableSettings.go
@@ -18,27 +18,27 @@ type Settings struct {
}
// GetSettings get the saved settings from the DB
-func GetSettings() (out Settings) {
- if err := DB.Model(Settings{}).First(&out).Error; err != nil {
+func (d *DkfDB) GetSettings() (out Settings) {
+ if err := d.db.Model(Settings{}).First(&out).Error; err != nil {
out.SignupEnabled = true
out.SilentSelfKick = true
out.ForumEnabled = true
out.MaybeAuthEnabled = true
out.DownloadsEnabled = true
out.CaptchaDifficulty = 2
- DB.Create(&out)
+ d.db.Create(&out)
}
return
}
// Save the settings to DB
-func (s *Settings) Save() error {
- return DB.Save(s).Error
+func (s *Settings) Save(db *DkfDB) error {
+ return db.db.Save(s).Error
}
// DoSave settings in the database, ignore error
-func (s *Settings) DoSave() {
- if err := s.Save(); err != nil {
+func (s *Settings) DoSave(db *DkfDB) {
+ if err := s.Save(db); err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableSnippets.go b/pkg/database/tableSnippets.go
@@ -8,23 +8,23 @@ type Snippet struct {
Text string
}
-func GetUserSnippets(userID UserID) (out []Snippet, err error) {
- err = DB.Find(&out, "user_id = ?", userID).Error
+func (d *DkfDB) GetUserSnippets(userID UserID) (out []Snippet, err error) {
+ err = d.db.Find(&out, "user_id = ?", userID).Error
return
}
-func CreateSnippet(userID UserID, name, text string) (out Snippet, err error) {
+func (d *DkfDB) CreateSnippet(userID UserID, name, text string) (out Snippet, err error) {
out = Snippet{
Name: name,
UserID: userID,
Text: text,
}
- err = DB.Create(&out).Error
+ err = d.db.Create(&out).Error
return
}
-func DeleteSnippet(userID UserID, name string) {
- if err := DB.Delete(Snippet{}, "user_id = ? AND name = ?", userID, name).Error; err != nil {
+func (d *DkfDB) DeleteSnippet(userID UserID, name string) {
+ if err := d.db.Delete(Snippet{}, "user_id = ? AND name = ?", userID, name).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/tableUploads.go b/pkg/database/tableUploads.go
@@ -56,30 +56,30 @@ func (u *Upload) Exists() bool {
return utils.FileExists(filePath1)
}
-func (u *Upload) Delete() error {
+func (u *Upload) Delete(db *DkfDB) error {
if err := os.Remove(filepath.Join(config.Global.ProjectUploadsPath(), u.FileName)); err != nil {
return err
}
- if err := DB.Delete(&u).Error; err != nil {
+ if err := db.db.Delete(&u).Error; err != nil {
return err
}
return nil
}
// CreateUpload create file on disk in "uploads" folder, and save upload in database as well.
-func CreateUpload(fileName string, content []byte, userID UserID) (*Upload, error) {
- return createUploadWithSize(fileName, content, userID, int64(len(content)))
+func (d *DkfDB) CreateUpload(fileName string, content []byte, userID UserID) (*Upload, error) {
+ return d.createUploadWithSize(fileName, content, userID, int64(len(content)))
}
-func CreateEncryptedUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) {
+func (d *DkfDB) CreateEncryptedUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) {
encryptedContent, err := utils.EncryptAESMaster(content)
if err != nil {
return nil, err
}
- return createUploadWithSize(fileName, encryptedContent, userID, size)
+ return d.createUploadWithSize(fileName, encryptedContent, userID, size)
}
-func createUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) {
+func (d *DkfDB) createUploadWithSize(fileName string, content []byte, userID UserID, size int64) (*Upload, error) {
newFileName := utils.MD5([]byte(utils.GenerateToken32()))
if err := ioutil.WriteFile(filepath.Join(config.Global.ProjectUploadsPath(), newFileName), content, 0644); err != nil {
return nil, err
@@ -90,42 +90,42 @@ func createUploadWithSize(fileName string, content []byte, userID UserID, size i
OrigFileName: fileName,
FileSize: size,
}
- if err := DB.Create(&upload).Error; err != nil {
+ if err := d.db.Create(&upload).Error; err != nil {
logrus.Error(err)
}
return &upload, nil
}
-func GetUploadByFileName(filename string) (out Upload, err error) {
- err = DB.First(&out, "file_name = ?", filename).Error
+func (d *DkfDB) GetUploadByFileName(filename string) (out Upload, err error) {
+ err = d.db.First(&out, "file_name = ?", filename).Error
return
}
-func GetUploadByID(uploadID UploadID) (out Upload, err error) {
- err = DB.First(&out, "id = ?", uploadID).Error
+func (d *DkfDB) GetUploadByID(uploadID UploadID) (out Upload, err error) {
+ err = d.db.First(&out, "id = ?", uploadID).Error
return
}
-func GetUploads() (out []Upload, err error) {
- err = DB.Preload("User").Order("id DESC").Find(&out).Error
+func (d *DkfDB) GetUploads() (out []Upload, err error) {
+ err = d.db.Preload("User").Order("id DESC").Find(&out).Error
return
}
-func GetUserUploads(userID UserID) (out []Upload, err error) {
- err = DB.Order("id DESC").Find(&out, "user_id = ?", userID).Error
+func (d *DkfDB) GetUserUploads(userID UserID) (out []Upload, err error) {
+ err = d.db.Order("id DESC").Find(&out, "user_id = ?", userID).Error
return
}
-func GetUserTotalUploadSize(userID UserID) int64 {
+func (d *DkfDB) GetUserTotalUploadSize(userID UserID) int64 {
var out struct{ TotalSize int64 }
- if err := DB.Raw(`SELECT SUM(file_size) as total_size FROM uploads WHERE user_id = ?`, userID).Scan(&out).Error; err != nil {
+ if err := d.db.Raw(`SELECT SUM(file_size) as total_size FROM uploads WHERE user_id = ?`, userID).Scan(&out).Error; err != nil {
logrus.Error(err)
}
return out.TotalSize
}
-func DeleteOldUploads() {
- if err := DB.Exec(`DELETE FROM uploads WHERE created_at < date('now', '-1 Day')`).Error; err != nil {
+func (d *DkfDB) DeleteOldUploads() {
+ if err := d.db.Exec(`DELETE FROM uploads WHERE created_at < date('now', '-1 Day')`).Error; err != nil {
logrus.Error(err.Error())
}
fileInfo, err := ioutil.ReadDir(config.Global.ProjectUploadsPath())
diff --git a/pkg/database/tableUserForumThreadSubscriptions.go b/pkg/database/tableUserForumThreadSubscriptions.go
@@ -13,27 +13,27 @@ type UserForumThreadSubscription struct {
User User
}
-func (s *UserForumThreadSubscription) DoSave() {
- if err := DB.Save(s).Error; err != nil {
+func (s *UserForumThreadSubscription) DoSave(db *DkfDB) {
+ if err := db.db.Save(s).Error; err != nil {
logrus.Error(err)
}
}
-func SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error) {
- return DB.Create(&UserForumThreadSubscription{UserID: userID, ThreadID: threadID}).Error
+func (d *DkfDB) SubscribeToForumThread(userID UserID, threadID ForumThreadID) (err error) {
+ return d.db.Create(&UserForumThreadSubscription{UserID: userID, ThreadID: threadID}).Error
}
-func UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error) {
- return DB.Delete(&UserForumThreadSubscription{}, "user_id = ? AND thread_id = ?", userID, threadID).Error
+func (d *DkfDB) UnsubscribeFromForumThread(userID UserID, threadID ForumThreadID) (err error) {
+ return d.db.Delete(&UserForumThreadSubscription{}, "user_id = ? AND thread_id = ?", userID, threadID).Error
}
-func IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool {
+func (d *DkfDB) IsUserSubscribedToForumThread(userID UserID, threadID ForumThreadID) bool {
var count int64
- DB.Model(UserForumThreadSubscription{}).Where("user_id = ? AND thread_id = ?", userID, threadID).Count(&count)
+ d.db.Model(UserForumThreadSubscription{}).Where("user_id = ? AND thread_id = ?", userID, threadID).Count(&count)
return count == 1
}
-func GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error) {
- err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error
+func (d *DkfDB) GetUsersSubscribedToForumThread(threadID ForumThreadID) (out []UserForumThreadSubscription, err error) {
+ err = d.db.Preload("User").Find(&out, "thread_id = ?", threadID).Error
return
}
diff --git a/pkg/database/tableUserPrivateNotes.go b/pkg/database/tableUserPrivateNotes.go
@@ -14,19 +14,19 @@ type UserPrivateNote struct {
UpdatedAt time.Time
}
-func GetUserPrivateNotes(userID UserID) (out UserPrivateNote, err error) {
- err = DB.First(&out, "user_id = ?", userID).Error
+func (d *DkfDB) GetUserPrivateNotes(userID UserID) (out UserPrivateNote, err error) {
+ err = d.db.First(&out, "user_id = ?", userID).Error
return
}
-func SetUserPrivateNotes(userID UserID, notes string) error {
+func (d *DkfDB) SetUserPrivateNotes(userID UserID, notes string) error {
if !govalidator.RuneLength(notes, "0", "10000") {
return errors.New("notes must have 10000 characters maximum")
}
n := UserPrivateNote{UserID: userID}
- if err := DB.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil {
+ if err := d.db.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil {
return err
}
n.Notes = EncryptedString(notes)
- return DB.Save(&n).Error
+ return d.db.Save(&n).Error
}
diff --git a/pkg/database/tableUserPublicNotes.go b/pkg/database/tableUserPublicNotes.go
@@ -14,19 +14,19 @@ type UserPublicNote struct {
UpdatedAt time.Time
}
-func GetUserPublicNotes(userID UserID) (out UserPublicNote, err error) {
- err = DB.First(&out, "user_id = ?", userID).Error
+func (d *DkfDB) GetUserPublicNotes(userID UserID) (out UserPublicNote, err error) {
+ err = d.db.First(&out, "user_id = ?", userID).Error
return
}
-func SetUserPublicNotes(userID UserID, notes string) error {
+func (d *DkfDB) SetUserPublicNotes(userID UserID, notes string) error {
if !govalidator.RuneLength(notes, "0", "10000") {
return errors.New("notes must have 10000 characters maximum")
}
n := UserPublicNote{UserID: userID}
- if err := DB.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil {
+ if err := d.db.FirstOrCreate(&n, "user_id = ?", userID).Error; err != nil {
return err
}
n.Notes = notes
- return DB.Save(&n).Error
+ return d.db.Save(&n).Error
}
diff --git a/pkg/database/tableUserRoomSubscriptions.go b/pkg/database/tableUserRoomSubscriptions.go
@@ -13,22 +13,22 @@ type UserRoomSubscription struct {
Room ChatRoom
}
-func (s *UserRoomSubscription) DoSave() {
- if err := DB.Save(s).Error; err != nil {
+func (s *UserRoomSubscription) DoSave(db *DkfDB) {
+ if err := db.db.Save(s).Error; err != nil {
logrus.Error(err)
}
}
-func SubscribeToRoom(userID UserID, roomID RoomID) (err error) {
- return DB.Create(&UserRoomSubscription{UserID: userID, RoomID: roomID}).Error
+func (d *DkfDB) SubscribeToRoom(userID UserID, roomID RoomID) (err error) {
+ return d.db.Create(&UserRoomSubscription{UserID: userID, RoomID: roomID}).Error
}
-func UnsubscribeFromRoom(userID UserID, roomID RoomID) (err error) {
- return DB.Delete(&UserRoomSubscription{}, "user_id = ? AND room_id = ?", userID, roomID).Error
+func (d *DkfDB) UnsubscribeFromRoom(userID UserID, roomID RoomID) (err error) {
+ return d.db.Delete(&UserRoomSubscription{}, "user_id = ? AND room_id = ?", userID, roomID).Error
}
-func IsUserSubscribedToRoom(userID UserID, roomID RoomID) bool {
+func (d *DkfDB) IsUserSubscribedToRoom(userID UserID, roomID RoomID) bool {
var count int64
- DB.Model(UserRoomSubscription{}).Where("user_id = ? AND room_id = ?", userID, roomID).Count(&count)
+ d.db.Model(UserRoomSubscription{}).Where("user_id = ? AND room_id = ?", userID, roomID).Count(&count)
return count == 1
}
diff --git a/pkg/database/tableUsers.go b/pkg/database/tableUsers.go
@@ -125,8 +125,8 @@ func UserPtrID(user *User) *UserID {
return nil
}
-func GetChessSubscribers() (out []User, err error) {
- err = DB.Find(&out, "notify_chess_games == 1").Error
+func (d *DkfDB) GetChessSubscribers() (out []User, err error) {
+ err = d.db.Find(&out, "notify_chess_games == 1").Error
return
}
@@ -250,84 +250,84 @@ func (u *User) GetHellbanOpacityF64() float64 {
}
// Save user in the database
-func (u *User) Save() error {
- return DB.Save(u).Error
+func (u *User) Save(db *DkfDB) error {
+ return db.db.Save(u).Error
}
// DoSave user in the database, ignore error
-func (u *User) DoSave() {
- if err := u.Save(); err != nil {
+func (u *User) DoSave(db *DkfDB) {
+ if err := u.Save(db); err != nil {
logrus.Error(err)
}
}
-func (u *User) HellBan() {
+func (u *User) HellBan(db *DkfDB) {
u.IsHellbanned = true
- u.DoSave()
- if err := DB.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", true).Error; err != nil {
+ u.DoSave(db)
+ if err := db.db.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", true).Error; err != nil {
logrus.Error(err)
}
}
-func (u *User) UnHellBan() {
+func (u *User) UnHellBan(db *DkfDB) {
u.IsHellbanned = false
- u.DoSave()
- if err := DB.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", false).Error; err != nil {
+ u.DoSave(db)
+ if err := db.db.Model(&ChatMessage{}).Where("user_id = ?", u.ID).Update("is_hellbanned", false).Error; err != nil {
logrus.Error(err)
}
}
// GetUserBySessionKey ...
-func GetUserBySessionKey(user *User, sessionKey string) error {
- return DB.Joins("INNER JOIN sessions s ON s.token = ? AND s.expires_at > DATETIME('now') and s.deleted_at IS NULL AND s.user_id = users.id").
+func (d *DkfDB) GetUserBySessionKey(user *User, sessionKey string) error {
+ return d.db.Joins("INNER JOIN sessions s ON s.token = ? AND s.expires_at > DATETIME('now') and s.deleted_at IS NULL AND s.user_id = users.id").
Where("users.verified = 1", sessionKey).
First(user).Error
}
// GetUserByApiKey ...
-func GetUserByApiKey(user *User, apiKey string) error {
- return DB.First(user, "api_key = ?", apiKey).Error
+func (d *DkfDB) GetUserByApiKey(user *User, apiKey string) error {
+ return d.db.First(user, "api_key = ?", apiKey).Error
}
// GetUserByID ...
-func GetUserByID(userID UserID) (out User, err error) {
- err = DB.First(&out, "id = ?", userID).Error
+func (d *DkfDB) GetUserByID(userID UserID) (out User, err error) {
+ err = d.db.First(&out, "id = ?", userID).Error
return
}
// GetUserByUsername ...
-func GetUserByUsername(username string) (out User, err error) {
- err = DB.First(&out, "username = ? COLLATE NOCASE", username).Error
+func (d *DkfDB) GetUserByUsername(username string) (out User, err error) {
+ err = d.db.First(&out, "username = ? COLLATE NOCASE", username).Error
return
}
-func GetVerifiedUserByUsername(username string) (out User, err error) {
- err = DB.First(&out, "username = ? COLLATE NOCASE AND verified = 1", username).Error
+func (d *DkfDB) GetVerifiedUserByUsername(username string) (out User, err error) {
+ err = d.db.First(&out, "username = ? COLLATE NOCASE AND verified = 1", username).Error
return
}
-func GetUsersByID(ids []UserID) (out []User, err error) {
- err = DB.Find(&out, "id IN (?)", ids).Error
+func (d *DkfDB) GetUsersByID(ids []UserID) (out []User, err error) {
+ err = d.db.Find(&out, "id IN (?)", ids).Error
return
}
-func GetUsersByUsername(usernames []string) (out []User, err error) {
- err = DB.Find(&out, "username IN (?)", usernames).Error
+func (d *DkfDB) GetUsersByUsername(usernames []string) (out []User, err error) {
+ err = d.db.Find(&out, "username IN (?)", usernames).Error
return
}
-func DeleteUserByID(userID UserID) (err error) {
- err = DB.Unscoped().Delete(User{}, "id = ?", userID).Error
+func (d *DkfDB) DeleteUserByID(userID UserID) (err error) {
+ err = d.db.Unscoped().Delete(User{}, "id = ?", userID).Error
return
}
-func GetModeratorsUsers() (out []User, err error) {
- err = DB.Order("username ASC").Find(&out, "role = ? OR is_admin = 1", "moderator").Error
+func (d *DkfDB) GetModeratorsUsers() (out []User, err error) {
+ err = d.db.Order("username ASC").Find(&out, "role = ? OR is_admin = 1", "moderator").Error
return
}
-func GetClubMembers() (out []User, err error) {
- err = DB.Find(&out, "is_club_member = ?", true).Error
+func (d *DkfDB) GetClubMembers() (out []User, err error) {
+ err = d.db.Find(&out, "is_club_member = ?", true).Error
return
}
@@ -336,40 +336,40 @@ func GetClubMembers() (out []User, err error) {
// Assume I realize I left myself logged into a shared computer.
// I change my password to protect myself.
// The session on the public computer needs to be invalidated.
-func (u *User) ChangePassword(hashedPassword string) error {
+func (u *User) ChangePassword(db *DkfDB, hashedPassword string) error {
u.Password = hashedPassword
- if err := DB.Save(u).Error; err != nil {
+ if err := db.db.Save(u).Error; err != nil {
return err
}
// Delete active user sessions
- if err := DeleteUserSessions(u.ID); err != nil {
+ if err := db.DeleteUserSessions(u.ID); err != nil {
return err
}
return nil
}
-func (u *User) ChangeDuressPassword(hashedDuressPassword string) error {
+func (u *User) ChangeDuressPassword(db *DkfDB, hashedDuressPassword string) error {
u.DuressPassword = hashedDuressPassword
- if err := DB.Save(u).Error; err != nil {
+ if err := db.db.Save(u).Error; err != nil {
return err
}
// Delete active user sessions
- if err := DeleteUserSessions(u.ID); err != nil {
+ if err := db.DeleteUserSessions(u.ID); err != nil {
return err
}
return nil
}
-func (u *User) CheckPassword(password string) bool {
+func (u *User) CheckPassword(db *DkfDB, password string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)); err != nil {
if err := bcrypt.CompareHashAndPassword([]byte(u.DuressPassword), []byte(password)); err != nil {
return false
}
u.IsUnderDuress = true
- u.DoSave()
+ u.DoSave(db)
} else {
u.IsUnderDuress = false
- u.DoSave()
+ u.DoSave(db)
}
return true
}
@@ -420,22 +420,22 @@ func isUsernameReserved(username string) bool {
}
// GetVerifiedUserBySessionID ...
-func GetVerifiedUserBySessionID(token string) (out User, err error) {
- err = DB.First(&out, "token = ? and verified = 1", token).Error
+func (d *DkfDB) GetVerifiedUserBySessionID(token string) (out User, err error) {
+ err = d.db.First(&out, "token = ? and verified = 1", token).Error
return
}
// GetRecentUsersCount ...
-func GetRecentUsersCount() int64 {
+func (d *DkfDB) GetRecentUsersCount() int64 {
var count int64
- DB.Table("users").Where("created_at > datetime('now', '-1 Minute')").Count(&count)
+ d.db.Table("users").Where("created_at > datetime('now', '-1 Minute')").Count(&count)
return count
}
// IsUsernameAlreadyTaken ...
-func IsUsernameAlreadyTaken(username string) bool {
+func (d *DkfDB) IsUsernameAlreadyTaken(username string) bool {
var count int64
- DB.Table("users").Where("username = ? COLLATE NOCASE", username).Count(&count)
+ d.db.Table("users").Where("username = ? COLLATE NOCASE", username).Count(&count)
return count > 0 || isUsernameReserved(username)
}
@@ -446,7 +446,7 @@ type PasswordValidator struct {
}
// NewPasswordValidator ...
-func NewPasswordValidator(password string) *PasswordValidator {
+func NewPasswordValidator(db *DkfDB, password string) *PasswordValidator {
p := new(PasswordValidator)
p.password = password
if len(password) < 8 {
@@ -455,7 +455,7 @@ func NewPasswordValidator(password string) *PasswordValidator {
if len(password) > 128 {
p.error = errors.New("password must be at most 128 characters")
}
- if IsPasswordProhibited(password) {
+ if db.IsPasswordProhibited(password) {
p.error = errors.New("this password is too weak")
}
return p
@@ -482,21 +482,21 @@ func (p *PasswordValidator) Hash() (string, error) {
return string(h), p.error
}
-func CanUseUsername(username string, isFirstUser bool) error {
+func (d *DkfDB) CanUseUsername(username string, isFirstUser bool) error {
if _, err := ValidateUsername(username, isFirstUser); err != nil {
return err
- } else if IsUsernameAlreadyTaken(username) {
+ } else if d.IsUsernameAlreadyTaken(username) {
return errors.New("username already taken")
}
return nil
}
-func CanRenameTo(oldUsername, newUsername string) error {
+func (d *DkfDB) CanRenameTo(oldUsername, newUsername string) error {
if _, err := ValidateUsername(newUsername, false); err != nil {
return err
}
if strings.ToLower(oldUsername) != strings.ToLower(newUsername) {
- if IsUsernameAlreadyTaken(newUsername) {
+ if d.IsUsernameAlreadyTaken(newUsername) {
return errors.New("username already taken")
}
}
@@ -504,34 +504,34 @@ func CanRenameTo(oldUsername, newUsername string) error {
}
// CreateUser ...
-func CreateUser(username, password, repassword string, registrationDuration int64, signupInfoEnc string) (User, UserErrors) {
- return createUser(username, password, repassword, "", false, true, false, false, false, registrationDuration, signupInfoEnc)
+func (d *DkfDB) CreateUser(username, password, repassword string, registrationDuration int64, signupInfoEnc string) (User, UserErrors) {
+ return d.createUser(username, password, repassword, "", false, true, false, false, false, registrationDuration, signupInfoEnc)
}
-func CreateGuestUser(username, password string) (User, UserErrors) {
- return createUser(username, password, password, "", false, true, true, false, false, 0, "signupInfoEnc")
+func (d *DkfDB) CreateGuestUser(username, password string) (User, UserErrors) {
+ return d.createUser(username, password, password, "", false, true, true, false, false, 0, "signupInfoEnc")
}
-func CreateFirstUser(username, password, repassword string) (User, UserErrors) {
- return createUser(username, password, repassword, "", true, true, false, true, false, 12000, "")
+func (d *DkfDB) CreateFirstUser(username, password, repassword string) (User, UserErrors) {
+ return d.createUser(username, password, repassword, "", true, true, false, true, false, 12000, "")
}
-func CreateZeroUser() (User, UserErrors) {
+func (d *DkfDB) CreateZeroUser() (User, UserErrors) {
password := utils.GenerateToken10()
- return createUser("0", password, password, config.NullUserPublicKey, false, true, false, false, true, 12000, "")
+ return d.createUser("0", password, password, config.NullUserPublicKey, false, true, false, false, true, 12000, "")
}
// skipUsernameValidation: entirely skip username validation (for "0" user)
// isFirstUser: less strict username validation; can use "admin"/"n0tr1v" usernames
-func createUser(username, password, repassword, gpgPublicKey string, isAdmin, verified, temp, isFirstUser, skipUsernameValidation bool, registrationDuration int64, signupInfoEnc string) (User, UserErrors) {
+func (d *DkfDB) createUser(username, password, repassword, gpgPublicKey string, isAdmin, verified, temp, isFirstUser, skipUsernameValidation bool, registrationDuration int64, signupInfoEnc string) (User, UserErrors) {
username = strings.TrimSpace(username)
var errs UserErrors
if !skipUsernameValidation {
- if err := CanUseUsername(username, isFirstUser); err != nil {
+ if err := d.CanUseUsername(username, isFirstUser); err != nil {
errs.Username = err.Error()
}
}
- hashedPassword, err := NewPasswordValidator(password).CompareWith(repassword).Hash()
+ hashedPassword, err := NewPasswordValidator(d, password).CompareWith(repassword).Hash()
if err != nil {
errs.Password = err.Error()
}
@@ -569,7 +569,7 @@ func createUser(username, password, repassword, gpgPublicKey string, isAdmin, ve
token := utils.GenerateToken32()
newUser.Token = &token
}
- if err := DB.Create(&newUser).Error; err != nil {
+ if err := d.db.Create(&newUser).Error; err != nil {
logrus.Error(err)
}
@@ -582,8 +582,8 @@ func (u *User) SetAvatar(b []byte) {
u.Avatar = b
}
-func (u *User) IncrKarma(karma int64, description string) {
- if _, err := CreateKarmaHistory(karma, description, u.ID, nil); err != nil {
+func (u *User) IncrKarma(db *DkfDB, karma int64, description string) {
+ if _, err := db.CreateKarmaHistory(karma, description, u.ID, nil); err != nil {
logrus.Error(err)
return
}
diff --git a/pkg/database/tableXmrInvoices.go b/pkg/database/tableXmrInvoices.go
@@ -22,8 +22,8 @@ type XmrInvoice struct {
CreatedAt time.Time
}
-func CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error) {
- err = DB.Where("user_id = ? AND product_id = ? AND amount_received IS NULL", userID, productID).First(&out).Error
+func (d *DkfDB) CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error) {
+ err = d.db.Where("user_id = ? AND product_id = ? AND amount_received IS NULL", userID, productID).First(&out).Error
if err == nil {
return
}
@@ -37,14 +37,14 @@ func CreateXmrInvoice(userID UserID, productID int64) (out XmrInvoice, err error
Address: resp.Address,
AmountRequested: 10,
}
- if err = DB.Create(&out).Error; err != nil {
+ if err = d.db.Create(&out).Error; err != nil {
return
}
return
}
-func GetXmrInvoiceByAddress(address string) (out XmrInvoice, err error) {
- err = DB.Where("address = ?", address).First(&out).Error
+func (d *DkfDB) GetXmrInvoiceByAddress(address string) (out XmrInvoice, err error) {
+ err = d.db.Where("address = ?", address).First(&out).Error
return
}
@@ -68,8 +68,8 @@ func (i XmrInvoice) GetImage() (image.Image, error) {
return b, nil
}
-func (i *XmrInvoice) DoSave() {
- if err := DB.Save(i).Error; err != nil {
+func (i *XmrInvoice) DoSave(db *DkfDB) {
+ if err := db.db.Save(i).Error; err != nil {
logrus.Error(err)
}
}
diff --git a/pkg/database/table_forum_threads.go b/pkg/database/table_forum_threads.go
@@ -41,19 +41,19 @@ func MakeForumThread(threadName string, userID UserID, categoryID ForumCategoryI
return ForumThread{UUID: ForumThreadUUID(uuid.New().String()), Name: threadName, UserID: userID, CategoryID: categoryID}
}
-func (u *ForumThread) DoSave() {
- if err := DB.Save(u).Error; err != nil {
+func (u *ForumThread) DoSave(db *DkfDB) {
+ if err := db.db.Save(u).Error; err != nil {
logrus.Error(err)
}
}
-func GetForumCategories() (out []ForumCategory, err error) {
- err = DB.Find(&out).Order("idx ASC, name ASC").Error
+func (d *DkfDB) GetForumCategories() (out []ForumCategory, err error) {
+ err = d.db.Find(&out).Order("idx ASC, name ASC").Error
return
}
-func GetForumCategoryBySlug(slug string) (out ForumCategory, err error) {
- err = DB.First(&out, "slug = ?", slug).Error
+func (d *DkfDB) GetForumCategoryBySlug(slug string) (out ForumCategory, err error) {
+ err = d.db.First(&out, "slug = ?", slug).Error
return
}
@@ -82,29 +82,29 @@ type ForumReadRecord struct {
ReadAt time.Time
}
-func UpdateForumReadRecord(userID UserID, threadID ForumThreadID) {
+func (d *DkfDB) UpdateForumReadRecord(userID UserID, threadID ForumThreadID) {
now := time.Now()
- res := DB.Table("forum_read_records").Where("user_id = ? AND thread_id = ?", userID, threadID).Update("read_at", now)
+ res := d.db.Table("forum_read_records").Where("user_id = ? AND thread_id = ?", userID, threadID).Update("read_at", now)
if res.RowsAffected == 0 {
- DB.Create(ForumReadRecord{UserID: userID, ThreadID: threadID, ReadAt: now})
+ d.db.Create(ForumReadRecord{UserID: userID, ThreadID: threadID, ReadAt: now})
}
}
// DoSave user in the database, ignore error
-func (u *ForumReadRecord) DoSave() {
- if err := DB.Save(u).Error; err != nil {
+func (u *ForumReadRecord) DoSave(db *DkfDB) {
+ if err := db.db.Save(u).Error; err != nil {
logrus.Error(err)
}
}
// DoSave user in the database, ignore error
-func (u *ForumMessage) DoSave() {
- if err := DB.Save(u).Error; err != nil {
+func (u *ForumMessage) DoSave(db *DkfDB) {
+ if err := db.db.Save(u).Error; err != nil {
logrus.Error(err)
}
}
-func (m *ForumMessage) Escape() string {
+func (m *ForumMessage) Escape(db *DkfDB) string {
msg := m.Message
if m.IsSigned {
if b, _ := clearsign.Decode([]byte(msg)); b != nil {
@@ -120,7 +120,7 @@ func (m *ForumMessage) Escape() string {
var tagRgx = regexp.MustCompile(`@(\w{3,20})`)
if tagRgx.MatchString(res) {
res = tagRgx.ReplaceAllStringFunc(res, func(s string) string {
- if user, err := GetUserByUsername(strings.TrimPrefix(s, "@")); err == nil {
+ if user, err := db.GetUserByUsername(strings.TrimPrefix(s, "@")); err == nil {
return `<span style="color: ` + user.ChatColor + `;">` + s + `</span>`
}
return s
@@ -129,22 +129,22 @@ func (m *ForumMessage) Escape() string {
return res
}
-func GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error) {
- err = DB.First(&out, "id = ?", messageID).Error
+func (d *DkfDB) GetForumMessage(messageID ForumMessageID) (out ForumMessage, err error) {
+ err = d.db.First(&out, "id = ?", messageID).Error
return
}
-func GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error) {
- err = DB.First(&out, "uuid = ?", messageUUID).Error
+func (d *DkfDB) GetForumMessageByUUID(messageUUID ForumMessageUUID) (out ForumMessage, err error) {
+ err = d.db.First(&out, "uuid = ?", messageUUID).Error
return
}
-func DeleteForumMessageByID(messageID ForumMessageID) error {
- return DB.Where("id = ?", messageID).Delete(&ForumMessage{}).Error
+func (d *DkfDB) DeleteForumMessageByID(messageID ForumMessageID) error {
+ return d.db.Where("id = ?", messageID).Delete(&ForumMessage{}).Error
}
-func DeleteForumThreadByID(threadID ForumThreadID) error {
- return DB.Where("id = ?", threadID).Delete(&ForumThread{}).Error
+func (d *DkfDB) DeleteForumThreadByID(threadID ForumThreadID) error {
+ return d.db.Where("id = ?", threadID).Delete(&ForumThread{}).Error
}
func (m *ForumMessage) CanEdit() bool {
@@ -159,23 +159,23 @@ func (m *ForumMessage) ValidateSignature(pkey string) bool {
return utils.PgpCheckClearSignMessage(pkey, m.Message)
}
-func GetForumThread(threadID ForumThreadID) (out ForumThread, err error) {
- err = DB.First(&out, "id = ? AND is_club = 1", threadID).Error
+func (d *DkfDB) GetForumThread(threadID ForumThreadID) (out ForumThread, err error) {
+ err = d.db.First(&out, "id = ? AND is_club = 1", threadID).Error
return
}
-func GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error) {
- err = DB.First(&out, "id = ? AND is_club = 0", threadID).Error
+func (d *DkfDB) GetForumThreadByID(threadID ForumThreadID) (out ForumThread, err error) {
+ err = d.db.First(&out, "id = ? AND is_club = 0", threadID).Error
return
}
-func GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error) {
- err = DB.First(&out, "uuid = ? AND is_club = 0", threadUUID).Error
+func (d *DkfDB) GetForumThreadByUUID(threadUUID ForumThreadUUID) (out ForumThread, err error) {
+ err = d.db.First(&out, "uuid = ? AND is_club = 0", threadUUID).Error
return
}
-func GetForumThreads() (out []ForumThread, err error) {
- err = DB.Order("id DESC").Find(&out).Error
+func (d *DkfDB) GetForumThreads() (out []ForumThread, err error) {
+ err = d.db.Order("id DESC").Find(&out).Error
return
}
@@ -191,8 +191,8 @@ type ForumThreadAug struct {
RepliesCount int64
}
-func GetClubForumThreads(userID UserID) (out []ForumThreadAug, err error) {
- err = DB.Raw(`SELECT t.*,
+func (d *DkfDB) GetClubForumThreads(userID UserID) (out []ForumThreadAug, err error) {
+ err = d.db.Raw(`SELECT t.*,
u.username as author,
u.chat_color as author_chat_color,
lu.username as last_msg_author,
@@ -210,8 +210,8 @@ ORDER BY t.id DESC`, userID).Scan(&out).Error
return
}
-func GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error) {
- err = DB.Raw(`SELECT t.*,
+func (d *DkfDB) GetPublicForumCategoryThreads(userID UserID, categoryID ForumCategoryID) (out []ForumThreadAug, err error) {
+ err = d.db.Raw(`SELECT t.*,
u.username as author,
u.chat_color as author_chat_color,
lu.username as last_msg_author,
@@ -236,8 +236,8 @@ ORDER BY m.created_at DESC, t.id DESC`, userID, categoryID).Scan(&out).Error
return
}
-func GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error) {
- err = DB.Raw(`SELECT t.*,
+func (d *DkfDB) GetPublicForumThreadsSearch(userID UserID) (out []ForumThreadAug, err error) {
+ err = d.db.Raw(`SELECT t.*,
u.username as author,
u.chat_color as author_chat_color,
lu.username as last_msg_author,
@@ -262,7 +262,7 @@ ORDER BY m.created_at DESC, t.id DESC`, userID).Scan(&out).Error
return
}
-func GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error) {
- err = DB.Preload("User").Find(&out, "thread_id = ?", threadID).Error
+func (d *DkfDB) GetThreadMessages(threadID ForumThreadID) (out []ForumMessage, err error) {
+ err = d.db.Preload("User").Find(&out, "thread_id = ?", threadID).Error
return
}
diff --git a/pkg/database/utils/utils.go b/pkg/database/utils/utils.go
@@ -13,15 +13,15 @@ import (
"net/http"
)
-func GetZeroUser() database.User {
- zeroUser, err := database.GetUserByUsername(config.NullUsername)
+func GetZeroUser(db *database.DkfDB) database.User {
+ zeroUser, err := db.GetUserByUsername(config.NullUsername)
if err != nil {
logrus.Fatal(err)
}
return zeroUser
}
-func SendNewChessGameMessages(key, roomKey string, roomID database.RoomID, zeroUser, player1, player2 database.User) {
+func SendNewChessGameMessages(db *database.DkfDB, key, roomKey string, roomID database.RoomID, zeroUser, player1, player2 database.User) {
// Send game link to players
getPlayerMsg := func(opponent database.User) (raw string, msg string) {
raw = `Chess game against ` + opponent.Username
@@ -29,19 +29,19 @@ func SendNewChessGameMessages(key, roomKey string, roomID database.RoomID, zeroU
return
}
raw, msg := getPlayerMsg(player2)
- _, _ = database.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player1.ID)
+ _, _ = db.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player1.ID)
raw, msg = getPlayerMsg(player1)
- _, _ = database.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player2.ID)
+ _, _ = db.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &player2.ID)
// Send notifications to chess games subscribers
raw = `Chess game: ` + player1.Username + ` VS ` + player2.Username
msg = `<a href="/chess/` + key + `" rel="noopener noreferrer" target="_blank">Chess game: ` + player1.Username + ` VS ` + player2.Username + `</a>`
- users, _ := database.GetChessSubscribers()
+ users, _ := db.GetChessSubscribers()
for _, user := range users {
if user.ID == player1.ID || user.ID == player2.ID {
continue
}
- _, _ = database.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &user.ID)
+ _, _ = db.CreateMsg(raw, msg, roomKey, roomID, zeroUser.ID, &user.ID)
}
}
@@ -89,22 +89,22 @@ func DoParseRoomID(v string) (out database.RoomID) {
return DoParse[database.RoomID](v)
}
-func Kick(kicked, kickedBy database.User, purge, silent bool) error {
+func Kick(db *database.DkfDB, kicked, kickedBy database.User, purge, silent bool) error {
if kicked.IsHellbanned {
silent = true
}
- return kick(kicked, kickedBy, silent, purge)
+ return kick(db, kicked, kickedBy, silent, purge)
}
-func SilentKick(kicked, kickedBy database.User) error {
- return kick(kicked, kickedBy, true, true)
+func SilentKick(db *database.DkfDB, kicked, kickedBy database.User) error {
+ return kick(db, kicked, kickedBy, true, true)
}
-func SelfKick(kicked database.User, silent bool) error {
- return kick(kicked, kicked, silent, true)
+func SelfKick(db *database.DkfDB, kicked database.User, silent bool) error {
+ return kick(db, kicked, kicked, silent, true)
}
-func kick(kicked, kickedBy database.User, silent, purge bool) error {
+func kick(db *database.DkfDB, kicked, kickedBy database.User, silent, purge bool) error {
if !kicked.Verified {
return errors.New("user already kicked")
}
@@ -117,16 +117,16 @@ func kick(kicked, kickedBy database.User, silent, purge bool) error {
return errors.New("cannot kick another moderator")
}
- database.NewAudit(kickedBy, fmt.Sprintf("kick %s #%d", kicked.Username, kicked.ID))
+ db.NewAudit(kickedBy, fmt.Sprintf("kick %s #%d", kicked.Username, kicked.ID))
kicked.Verified = false
- kicked.DoSave()
+ kicked.DoSave(db)
// Remove user from the user cache
managers.ActiveUsers.RemoveUser(kicked.ID)
if purge {
// Purge user messages
- if err := database.DeleteUserChatMessages(kicked.ID); err != nil {
+ if err := db.DeleteUserChatMessages(kicked.ID); err != nil {
logrus.Error(err)
}
}
@@ -134,15 +134,15 @@ func kick(kicked, kickedBy database.User, silent, purge bool) error {
// If user is HB, do not display system message
if !silent {
// Display kick message
- database.CreateKickMsg(kicked, kickedBy)
+ db.CreateKickMsg(kicked, kickedBy)
}
return nil
}
-func GetRoomAndKey(c echo.Context, roomName string) (database.ChatRoom, string, error) {
+func GetRoomAndKey(db *database.DkfDB, c echo.Context, roomName string) (database.ChatRoom, string, error) {
roomKey := ""
- room, err := database.GetChatRoomByName(roomName)
+ room, err := db.GetChatRoomByName(roomName)
if err != nil {
return room, roomKey, c.NoContent(http.StatusNotFound)
}
diff --git a/pkg/global/global.go b/pkg/global/global.go
@@ -20,16 +20,16 @@ func DeleteUserNotificationCount(userID database.UserID, sessionToken string) {
notifCountCache.Delete(cacheKey(userID, sessionToken))
}
-func GetUserNotificationCount(userID database.UserID, sessionToken string) int64 {
+func GetUserNotificationCount(db *database.DkfDB, userID database.UserID, sessionToken string) int64 {
count, found := notifCountCache.Get(cacheKey(userID, sessionToken))
if found {
return count
}
- count = database.GetUserInboxMessagesCount(userID)
- count += database.GetUserNotificationsCount(userID)
+ count = db.GetUserInboxMessagesCount(userID)
+ count += db.GetUserNotificationsCount(userID)
// sessionToken can be empty when using the API
if sessionToken != "" {
- count += database.GetUserSessionNotificationsCount(sessionToken)
+ count += db.GetUserSessionNotificationsCount(sessionToken)
}
notifCountCache.SetD(cacheKey(userID, sessionToken), count)
return count
diff --git a/pkg/template/templates.go b/pkg/template/templates.go
@@ -50,6 +50,7 @@ type templateDataStruct struct {
ShaHTML template.HTML
LogoASCII template.HTML
NullUsername string
+ DB *database.DkfDB
CSRF string
Master bool
Development bool
@@ -64,6 +65,8 @@ type templateDataStruct struct {
func (t *Templates) Render(w io.Writer, name string, data any, c echo.Context) error {
tmpl := t.Templates[name]
+ db := c.Get("database").(*database.DkfDB)
+
d := templateDataStruct{}
d.TmplName = name
d.Data = data
@@ -77,6 +80,7 @@ func (t *Templates) Render(w io.Writer, name string, data any, c echo.Context) e
d.ShaHTML = template.HTML(fmt.Sprintf("<!-- SHA: %s -->", config.Global.Sha()))
d.NullUsername = config.NullUsername
d.CSRF, _ = c.Get("csrf").(string)
+ d.DB = db
d.AcceptLanguage = c.Get("accept-language").(string)
d.Lang = c.Get("lang").(string)
d.BaseKeywords = strings.Join(getBaseKeywords(), ", ")
@@ -94,7 +98,7 @@ func (t *Templates) Render(w io.Writer, name string, data any, c echo.Context) e
if authCookie, err := c.Cookie(hutils.AuthCookieName); err == nil {
sessionToken = authCookie.Value
}
- d.InboxCount = global.GetUserNotificationCount(d.AuthUser.ID, sessionToken)
+ d.InboxCount = global.GetUserNotificationCount(db, d.AuthUser.ID, sessionToken)
}
return tmpl.ExecuteTemplate(w, "base", d)
diff --git a/pkg/web/handlers/admin.go b/pkg/web/handlers/admin.go
@@ -19,6 +19,7 @@ import (
func AdminNewGistHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data adminCreateGistData
data.ActiveTab = "gists"
if c.Request().Method == http.MethodPost {
@@ -33,9 +34,8 @@ func AdminNewGistHandler(c echo.Context) error {
if data.Password != "" {
passwordHash = database.GetGistPasswordHash(data.Password)
}
- gist := database.Gist{Name: data.Name, Password: passwordHash, UserID: authUser.ID, Content: data.Content}
- gist.UUID = uuid.New().String()
- if err := database.DB.Create(&gist).Error; err != nil {
+ gist := database.Gist{UUID: uuid.New().String(), Name: data.Name, Password: passwordHash, UserID: authUser.ID, Content: data.Content}
+ if err := db.DB().Create(&gist).Error; err != nil {
data.Error = err.Error()
return c.Render(http.StatusOK, "admin.gist-create", data)
}
@@ -46,8 +46,9 @@ func AdminNewGistHandler(c echo.Context) error {
func AdminEditGistHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
gistUUID := c.Param("gistUUID")
- gist, err := database.GetGistByUUID(gistUUID)
+ gist, err := db.GetGistByUUID(gistUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -78,19 +79,20 @@ func AdminEditGistHandler(c echo.Context) error {
}
gist.Name = data.Name
gist.Content = data.Content
- gist.DoSave()
+ gist.DoSave(db)
return c.Redirect(http.StatusFound, "/gists/"+gist.UUID)
}
return c.Render(http.StatusOK, "admin.gist-create", data)
}
func AdminGistsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminGistsData
data.ActiveTab = "gists"
userQuery := c.QueryParam("u")
- if err := database.DB.Table("gists").
+ if err := db.DB().Table("gists").
Scopes(func(query *gorm.DB) *gorm.DB {
if userQuery != "" {
query = query.Where("user_id = ?", userQuery)
@@ -108,15 +110,16 @@ func AdminGistsHandler(c echo.Context) error {
}
func AdminUploadsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
if c.Request().Method == http.MethodPost {
formName := c.FormValue("formName")
if formName == "deleteUpload" {
fileName := c.Request().PostFormValue("file_name")
- file, err := database.GetUploadByFileName(fileName)
+ file, err := db.GetUploadByFileName(fileName)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
- if err := file.Delete(); err != nil {
+ if err := file.Delete(db); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -125,7 +128,7 @@ func AdminUploadsHandler(c echo.Context) error {
var data adminUploadsData
data.ActiveTab = "uploads"
- data.Uploads, _ = database.GetUploads()
+ data.Uploads, _ = db.GetUploads()
for _, f := range data.Uploads {
data.TotalSize += f.FileSize
}
@@ -133,21 +136,22 @@ func AdminUploadsHandler(c echo.Context) error {
}
func AdminFiledropsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
if c.Request().Method == http.MethodPost {
formName := c.FormValue("formName")
if formName == "createFiledrop" {
- if _, err := database.CreateFiledrop(); err != nil {
+ if _, err := db.CreateFiledrop(); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
} else if formName == "deleteFiledrop" {
fileName := c.Request().PostFormValue("file_name")
- file, err := database.GetFiledropByFileName(fileName)
+ file, err := db.GetFiledropByFileName(fileName)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
- if err := file.Delete(); err != nil {
+ if err := file.Delete(db); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -156,7 +160,7 @@ func AdminFiledropsHandler(c echo.Context) error {
var data adminFiledropsData
data.ActiveTab = "filedrops"
- data.Filedrops, _ = database.GetFiledrops()
+ data.Filedrops, _ = db.GetFiledrops()
for _, f := range data.Filedrops {
data.TotalSize += f.FileSize
}
@@ -164,12 +168,13 @@ func AdminFiledropsHandler(c echo.Context) error {
}
func AdminDownloadsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminDownloadsData
data.ActiveTab = "downloads"
userQuery := c.QueryParam("u")
- database.DB.Model(&database.Download{}).
+ db.DB().Model(&database.Download{}).
Scopes(func(query *gorm.DB) *gorm.DB {
if userQuery != "" {
query = query.Where("user_id = ?", userQuery)
@@ -185,13 +190,14 @@ func AdminDownloadsHandler(c echo.Context) error {
}
func AdminDeleteDownloadHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
downloadID, err := utils.ParseInt64(c.Param("downloadID"))
if err != nil {
return c.Render(http.StatusOK, "flash",
FlashResponse{"download id not found", c.Request().Referer(), "alert-danger"})
}
- if err := database.DeleteDownloadByID(downloadID); err != nil {
+ if err := db.DeleteDownloadByID(downloadID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -199,9 +205,10 @@ func AdminDeleteDownloadHandler(c echo.Context) error {
// AdminSettingsHandler ...
func AdminSettingsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminSettingsData
data.ActiveTab = "settings"
- settings := database.GetSettings()
+ settings := db.GetSettings()
data.ProtectHome = settings.ProtectHome
data.HomeUsersList = settings.HomeUsersList
data.ForceLoginCaptcha = settings.ForceLoginCaptcha
@@ -221,7 +228,7 @@ func AdminSettingsHandler(c echo.Context) error {
return c.Redirect(http.StatusFound, c.Request().Referer())
} else if formName == "saveSettings" {
- settings := database.GetSettings()
+ settings := db.GetSettings()
settings.ProtectHome = utils.DoParseBool(c.Request().PostFormValue("protectHome"))
settings.HomeUsersList = utils.DoParseBool(c.Request().PostFormValue("homeUsersList"))
settings.ForceLoginCaptcha = utils.DoParseBool(c.Request().PostFormValue("forceLoginCaptcha"))
@@ -231,7 +238,7 @@ func AdminSettingsHandler(c echo.Context) error {
settings.ForumEnabled = utils.DoParseBool(c.Request().PostFormValue("forumEnabled"))
settings.MaybeAuthEnabled = utils.DoParseBool(c.Request().PostFormValue("maybeAuthEnabled"))
settings.CaptchaDifficulty = utils.DoParseInt64(c.Request().PostFormValue("captchaDifficulty"))
- settings.DoSave()
+ settings.DoSave(db)
config.ProtectHome.Store(settings.ProtectHome)
config.HomeUsersList.Store(settings.HomeUsersList)
config.ForceLoginCaptcha.Store(settings.ForceLoginCaptcha)
@@ -250,12 +257,13 @@ func AdminSettingsHandler(c echo.Context) error {
}
func AdminHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminData
data.ActiveTab = "users"
data.Query = strings.TrimSpace(c.QueryParam("q"))
likeQuery := "%" + data.Query + "%"
- if err := database.DB.
+ if err := db.DB().
Table("users").
Scopes(func(query *gorm.DB) *gorm.DB {
if data.Query != "" {
@@ -273,12 +281,13 @@ func AdminHandler(c echo.Context) error {
}
func SessionsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminSessionsData
data.ActiveTab = "sessions"
data.Query = c.QueryParam("q")
likeQuery := "%" + data.Query + "%"
- if err := database.DB.
+ if err := db.DB().
Table("sessions").
Where("deleted_at IS NULL").
Scopes(func(query *gorm.DB) *gorm.DB {
@@ -297,12 +306,13 @@ func SessionsHandler(c echo.Context) error {
}
func IgnoredHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminIgnoredData
data.ActiveTab = "ignored"
data.Query = c.QueryParam("q")
likeQuery := "%" + data.Query + "%"
- if err := database.DB.
+ if err := db.DB().
Table("ignored_users").
Scopes(func(query *gorm.DB) *gorm.DB {
if data.Query != "" {
@@ -367,10 +377,11 @@ func BackupHandler(c echo.Context) error {
}
func AdminAuditsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminAuditsData
data.ActiveTab = "audits"
- if err := database.DB.
+ if err := db.DB().
Table("audit_logs").
Scopes(func(query *gorm.DB) *gorm.DB {
data.CurrentPage, data.MaxPage, data.AuditLogsCount, query = NewPaginator().Paginate(c, query)
@@ -385,12 +396,13 @@ func AdminAuditsHandler(c echo.Context) error {
}
func AdminRoomsHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminRoomsData
data.ActiveTab = "rooms"
data.Query = c.QueryParam("q")
likeQuery := "%" + data.Query + "%"
- if err := database.DB.
+ if err := db.DB().
Table("chat_rooms").
Scopes(func(query *gorm.DB) *gorm.DB {
if data.Query != "" {
@@ -408,10 +420,11 @@ func AdminRoomsHandler(c echo.Context) error {
}
func AdminCaptchaHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data adminCaptchaData
data.ActiveTab = "captcha"
- if err := database.DB.Table("captcha_requests").
+ if err := db.DB().Table("captcha_requests").
Scopes(func(query *gorm.DB) *gorm.DB {
data.CurrentPage, data.MaxPage, data.CaptchasCount, query = NewPaginator().Paginate(c, query)
return query
@@ -426,6 +439,7 @@ func AdminCaptchaHandler(c echo.Context) error {
// AdminDeleteUserHandler ...
func AdminDeleteUserHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
userID, err := dutils.ParseUserID(c.Param("userID"))
if err != nil {
return c.Render(http.StatusOK, "flash",
@@ -436,41 +450,45 @@ func AdminDeleteUserHandler(c echo.Context) error {
FlashResponse{"Root admin cannot be deleted", c.Request().Referer(), "alert-danger"})
}
- if err := database.DeleteUserByID(userID); err != nil {
+ if err := db.DeleteUserByID(userID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func IgnoredDeleteHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
userID := dutils.DoParseUserID(c.Request().PostFormValue("user_id"))
ignoredUserID := dutils.DoParseUserID(c.Request().PostFormValue("ignored_user_id"))
- database.UnIgnoreUser(userID, ignoredUserID)
+ db.UnIgnoreUser(userID, ignoredUserID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
// AdminDeleteRoomHandler ...
func AdminDeleteRoomHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
id := dutils.DoParseRoomID(c.Param("roomID"))
- database.DeleteChatRoomByID(id)
+ db.DeleteChatRoomByID(id)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func AdminUserSecurityLogsHandler(c echo.Context) error {
//authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
userID, err := dutils.ParseUserID(c.Param("userID"))
if err != nil {
return c.Redirect(http.StatusFound, "/admin/users")
}
var data settingsSecurityData
data.ActiveTab = "security"
- data.Logs, _ = database.GetSecurityLogs(userID)
+ data.Logs, _ = db.GetSecurityLogs(userID)
return c.Render(http.StatusOK, "admin.user-security-logs", data)
}
// AdminEditUserHandler ...
func AdminEditUserHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
userID, err := dutils.ParseUserID(c.Param("userID"))
if err != nil {
return c.Redirect(http.StatusFound, "/admin/users")
@@ -479,7 +497,7 @@ func AdminEditUserHandler(c echo.Context) error {
if userID == config.RootAdminID && authUser.ID != config.RootAdminID {
return c.Redirect(http.StatusFound, "/admin/users")
}
- user, err := database.GetUserByID(userID)
+ user, err := db.GetUserByID(userID)
if err != nil {
return c.Redirect(http.StatusFound, "/admin/users")
}
@@ -516,7 +534,7 @@ func AdminEditUserHandler(c echo.Context) error {
formName := c.Request().PostFormValue("formName")
if formName == "reset_tutorial" {
user.ChatTutorial = 0
- user.DoSave()
+ user.DoSave(db)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -539,7 +557,7 @@ func AdminEditUserHandler(c echo.Context) error {
data.ChatColor = c.FormValue("chat_color")
data.ChatFont = utils.DoParseInt64(c.FormValue("chat_font"))
if data.Username != user.Username {
- if err := database.CanRenameTo(user.Username, data.Username); err != nil {
+ if err := db.CanRenameTo(user.Username, data.Username); err != nil {
data.Errors.Username = err.Error()
}
}
@@ -549,7 +567,7 @@ func AdminEditUserHandler(c echo.Context) error {
data.RePassword = c.Request().PostFormValue("repassword")
data.ApiKey = c.Request().PostFormValue("api_key")
if data.Password != "" || data.RePassword != "" {
- hashedPassword, err = database.NewPasswordValidator(data.Password).CompareWith(data.RePassword).Hash()
+ hashedPassword, err = database.NewPasswordValidator(db, data.Password).CompareWith(data.RePassword).Hash()
if err != nil {
data.Errors.Password = err.Error()
}
@@ -560,7 +578,7 @@ func AdminEditUserHandler(c echo.Context) error {
if hashedPassword != "" {
user.LoginAttempts = 0
- if err := user.ChangePassword(hashedPassword); err != nil {
+ if err := user.ChangePassword(db, hashedPassword); err != nil {
data.Errors.Password = err.Error()
return c.Render(http.StatusOK, "admin.user-edit", data)
}
@@ -569,7 +587,7 @@ func AdminEditUserHandler(c echo.Context) error {
user.Username = data.Username
user.IsAdmin = data.IsAdmin
if data.IsHellbanned {
- user.HellBan()
+ user.HellBan(db)
managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil))
}
user.ApiKey = data.ApiKey
@@ -589,18 +607,19 @@ func AdminEditUserHandler(c echo.Context) error {
user.Role = data.Role
user.ChatColor = data.ChatColor
user.ChatFont = data.ChatFont
- user.DoSave()
+ user.DoSave(db)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
// AdminEditRoomHandler ...
func AdminEditRoomHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
roomID, err := dutils.ParseRoomID(c.Param("roomID"))
if err != nil {
return c.Redirect(http.StatusFound, "/admin/rooms")
}
- room, err := database.GetChatRoomByID(roomID)
+ room, err := db.GetChatRoomByID(roomID)
if err != nil {
return c.Redirect(http.StatusFound, "/admin/rooms")
}
@@ -618,7 +637,7 @@ func AdminEditRoomHandler(c echo.Context) error {
room.IsEphemeral = data.IsEphemeral
room.IsListed = data.IsListed
- room.DoSave()
+ room.DoSave(db)
return c.Redirect(http.StatusFound, "/admin/rooms")
}
diff --git a/pkg/web/handlers/api/v1/bangInterceptor.go b/pkg/web/handlers/api/v1/bangInterceptor.go
@@ -20,13 +20,13 @@ Chats:
Black Hat Chat: ` + config.BhcOnion + `
Forums:
CryptBB: ` + config.CryptbbOnion
- msg, _ := ProcessRawMessage(message, "", cmd.authUser.ID, cmd.room.ID, nil)
+ msg, _ := ProcessRawMessage(cmd.db, message, "", cmd.authUser.ID, cmd.room.ID, nil)
cmd.zeroMsg(msg)
cmd.err = ErrRedirect
}
func handleRtutoBangCmd(cmd *Command) {
cmd.authUser.ChatTutorial = 0
- cmd.authUser.DoSave()
+ cmd.authUser.DoSave(cmd.db)
cmd.err = ErrRedirect
}
diff --git a/pkg/web/handlers/api/v1/battleship.go b/pkg/web/handlers/api/v1/battleship.go
@@ -213,13 +213,14 @@ func (g *BSGame) Shot(pos string) (shipStr string, shipDead, gameEnded bool, err
type Battleship struct {
sync.Mutex
+ db *database.DkfDB
zeroID database.UserID
games map[string]*BSGame
}
-func NewBattleship() *Battleship {
- zeroUser, _ := database.GetUserByUsername(config.NullUsername)
- b := &Battleship{zeroID: zeroUser.ID}
+func NewBattleship(db *database.DkfDB) *Battleship {
+ zeroUser, _ := db.GetUserByUsername(config.NullUsername)
+ b := &Battleship{db: db, zeroID: zeroUser.ID}
b.games = make(map[string]*BSGame)
// Thread that cleanup inactive games
@@ -553,7 +554,7 @@ func (b *Battleship) playMove(roomName string, roomID database.RoomID, roomKey s
b.Lock()
defer b.Unlock()
- user, err := database.GetUserByUsername(enemyUsername)
+ user, err := b.db.GetUserByUsername(enemyUsername)
if err != nil {
return errors.New("invalid username")
}
@@ -586,17 +587,17 @@ func (b *Battleship) playMove(roomName string, roomID database.RoomID, roomKey s
}
// Delete old messages sent by "0" to the players
- if err := database.DB.
+ if err := b.db.DB().
Where("room_id = ? AND user_id = ? AND (to_user_id = ? OR to_user_id = ?)", roomID, b.zeroID, g.player1.id, g.player2.id).
Delete(&database.ChatMessage{}).Error; err != nil {
logrus.Error(err)
}
card1 := g.drawCardFor(0, roomName, isNewGame, shipDead, gameEnded, shipStr, pos)
- _, _ = database.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.player1.id)
+ _, _ = b.db.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.player1.id)
card2 := g.drawCardFor(1, roomName, isNewGame, shipDead, gameEnded, shipStr, pos)
- _, _ = database.CreateMsg(card2, card2, roomKey, roomID, b.zeroID, &g.player2.id)
+ _, _ = b.db.CreateMsg(card2, card2, roomKey, roomID, b.zeroID, &g.player2.id)
if gameEnded {
delete(b.games, gameKey)
diff --git a/pkg/web/handlers/api/v1/chess.go b/pkg/web/handlers/api/v1/chess.go
@@ -64,13 +64,14 @@ func newChessGame(gameKey string, player1, player2 database.User) *ChessGame {
type Chess struct {
sync.Mutex
+ db *database.DkfDB
zeroID database.UserID
games map[string]*ChessGame
}
-func NewChess() *Chess {
- zeroUser, _ := database.GetUserByUsername(config.NullUsername)
- c := &Chess{zeroID: zeroUser.ID}
+func NewChess(db *database.DkfDB) *Chess {
+ zeroUser, _ := db.GetUserByUsername(config.NullUsername)
+ c := &Chess{db: db, zeroID: zeroUser.ID}
c.games = make(map[string]*ChessGame)
// Thread that cleanup inactive games
@@ -401,8 +402,8 @@ func (b *Chess) NewGame1(roomKey string, roomID database.RoomID, player1, player
key := uuid.New().String()
g := b.NewGame(key, player1, player2)
- zeroUser := dutils.GetZeroUser()
- dutils.SendNewChessGameMessages(key, roomKey, roomID, zeroUser, player1, player2)
+ zeroUser := dutils.GetZeroUser(b.db)
+ dutils.SendNewChessGameMessages(b.db, key, roomKey, roomID, zeroUser, player1, player2)
return g, nil
}
@@ -471,11 +472,11 @@ func (b *Chess) SendMove(gameKey string, userID database.UserID, g *ChessGame, c
// Notify (pm) the opponent that you made a move
if opponent.NotifyChessMove {
msg := fmt.Sprintf("@%s played %s", you.Username, moveStr)
- msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername)
- chatMsg, _ := database.CreateMsg(msg, msg, "", config.GeneralRoomID, b.zeroID, &opponent.ID)
+ msg, _ = colorifyTaggedUsers(msg, b.db.GetUsersByUsername)
+ chatMsg, _ := b.db.CreateMsg(msg, msg, "", config.GeneralRoomID, b.zeroID, &opponent.ID)
go func() {
time.Sleep(30 * time.Second)
- _ = database.DeleteChatMessageByUUID(chatMsg.UUID)
+ _ = b.db.DeleteChatMessageByUUID(chatMsg.UUID)
}()
}
@@ -500,7 +501,7 @@ func (b *Chess) playMove(enemyUsername, pos string, authUser database.User, c ec
b.Lock()
defer b.Unlock()
- user, err := database.GetUserByUsername(enemyUsername)
+ user, err := b.db.GetUserByUsername(enemyUsername)
if err != nil {
return errors.New("invalid username")
}
@@ -525,17 +526,17 @@ func (b *Chess) playMove(enemyUsername, pos string, authUser database.User, c ec
}
// Delete old messages sent by "0" to the players
- if err := database.DB.
+ if err := b.db.DB().
Where("room_id = ? AND user_id = ? AND (to_user_id = ? OR to_user_id = ?)", roomID, b.zeroID, g.Player1.ID, g.Player2.ID).
Delete(&database.ChatMessage{}).Error; err != nil {
logrus.Error(err)
}
card1 := g.DrawPlayerCard(roomName, true, false, true)
- _, _ = database.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player1.ID)
+ _, _ = b.db.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player1.ID)
card1 = g.DrawPlayerCard(roomName, true, true, true)
- _, _ = database.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player2.ID)
+ _, _ = b.db.CreateMsg(card1, card1, roomKey, roomID, b.zeroID, &g.Player2.ID)
return nil
}
diff --git a/pkg/web/handlers/api/v1/handlers.go b/pkg/web/handlers/api/v1/handlers.go
@@ -83,12 +83,13 @@ var unhideRgx = regexp.MustCompile(`^/unhide (\d{2}:\d{2}:\d{2})$`)
func ChatMessagesHandler(c echo.Context) error {
authCookie, _ := c.Cookie(hutils.AuthCookieName)
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
roomName := c.Param("roomName")
pmOnlyQuery := dutils.DoParsePmDisplayMode(c.QueryParam("pmonly"))
mentionsOnlyQuery := utils.DoParseBool(c.QueryParam("mentionsOnly"))
- room, err := database.GetChatRoomByName(roomName)
+ room, err := db.GetChatRoomByName(roomName)
if err != nil {
return c.NoContent(http.StatusNotFound)
}
@@ -102,7 +103,7 @@ func ChatMessagesHandler(c echo.Context) error {
// Only fill the ignored set if the user does not display the ignored users ("Toggle ignored" chat setting)
// and if the user has "Hide ignored users from users lists" enabled (user setting)
if !authUser.DisplayIgnored && authUser.HideIgnoredUsersFromList {
- ignoredUsers, _ := database.GetIgnoredUsers(authUser.ID)
+ ignoredUsers, _ := db.GetIgnoredUsers(authUser.ID)
for _, ignoredUser := range ignoredUsers {
ignoredSet.Insert(ignoredUser.IgnoredUser.Username)
}
@@ -111,7 +112,7 @@ func ChatMessagesHandler(c echo.Context) error {
membersInRoom, membersInChat := managers.ActiveUsers.GetRoomUsers(room, ignoredSet)
displayHellbanned := authUser.DisplayHellbanned || authUser.IsHellbanned
displayIgnoredMessages := false
- msgs, _ := database.GetChatMessages(room.ID, authUser.Username, authUser.ID, pmOnlyQuery, mentionsOnlyQuery, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages)
+ msgs, _ := db.GetChatMessages(room.ID, authUser.Username, authUser.ID, pmOnlyQuery, mentionsOnlyQuery, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages)
if room.IsProtected() {
key, err := hutils.GetRoomKeyCookie(c, int64(room.ID))
if err != nil {
@@ -123,7 +124,7 @@ func ChatMessagesHandler(c echo.Context) error {
}
// Update read record
- database.UpdateChatReadRecord(authUser.ID, room.ID)
+ db.UpdateChatReadRecord(authUser.ID, room.ID)
var data chatMessagesData
@@ -171,11 +172,11 @@ func ChatMessagesHandler(c echo.Context) error {
if authCookie != nil {
sessionToken = authCookie.Value
}
- data.InboxCount = global.GetUserNotificationCount(authUser.ID, sessionToken)
+ data.InboxCount = global.GetUserNotificationCount(db, authUser.ID, sessionToken)
- data.ReadMarker, _ = database.GetUserReadMarker(authUser.ID, room.ID)
- data.OfficialRooms, _ = database.GetOfficialChatRooms1(authUser.ID)
- data.SubscribedRooms, _ = database.GetUserRoomSubscriptions(authUser.ID)
+ data.ReadMarker, _ = db.GetUserReadMarker(authUser.ID, room.ID)
+ data.OfficialRooms, _ = db.GetOfficialChatRooms1(authUser.ID)
+ data.SubscribedRooms, _ = db.GetUserRoomSubscriptions(authUser.ID)
bools := []bool{authUser.DisplayDeleteButton}
if authUser.IsModerator() {
@@ -204,10 +205,11 @@ func ChatMessagesHandler(c echo.Context) error {
func RoomNotifierHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
roomID := dutils.DoParseRoomID(c.Param("roomID"))
lastKnownDate := c.Request().PostFormValue("last_known_date")
- room, err := database.GetChatRoomByID(roomID)
+ room, err := db.GetChatRoomByID(roomID)
if err != nil {
return c.NoContent(http.StatusNotFound)
}
@@ -219,7 +221,7 @@ func RoomNotifierHandler(c echo.Context) error {
displayHellbanned := authUser.DisplayHellbanned || authUser.IsHellbanned
displayIgnoredMessages := false
- msgs, _ := database.GetChatMessages(roomID, authUser.Username, authUser.ID, database.PmNoFilter, false, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages)
+ msgs, _ := db.GetChatMessages(roomID, authUser.Username, authUser.ID, database.PmNoFilter, false, displayHellbanned, authUser.DisplayIgnored, authUser.DisplayModerators, displayIgnoredMessages)
if room.IsProtected() {
key, err := hutils.GetRoomKeyCookie(c, int64(room.ID))
if err != nil {
@@ -233,7 +235,7 @@ func RoomNotifierHandler(c echo.Context) error {
var data testData
data.NewMessageSound, data.PmSound, data.TaggedSound, data.LastMessageCreatedAt = shouldPlaySound(authUser, lastKnownDate, msgs)
- data.InboxCount = database.GetUserInboxMessagesCount(authUser.ID)
+ data.InboxCount = db.GetUserInboxMessagesCount(authUser.ID)
return c.JSON(http.StatusOK, data)
}
@@ -269,15 +271,16 @@ func shouldPlaySound(authUser *database.User, lastKnownDate string, msgs []datab
func UserHellbanHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
userID := dutils.DoParseUserID(c.Param("userID"))
- user, err := database.GetUserByID(userID)
+ user, err := db.GetUserByID(userID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
if !user.IsHellbanned {
if authUser.IsAdmin || !user.IsModerator() {
- database.NewAudit(*authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID))
- user.HellBan()
+ db.NewAudit(*authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID))
+ user.HellBan(db)
managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil))
}
}
@@ -286,14 +289,15 @@ func UserHellbanHandler(c echo.Context) error {
func UserUnHellbanHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
userID := dutils.DoParseUserID(c.Param("userID"))
- user, err := database.GetUserByID(userID)
+ user, err := db.GetUserByID(userID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
if user.IsHellbanned {
- database.NewAudit(*authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID))
- user.UnHellBan()
+ db.NewAudit(*authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID))
+ user.UnHellBan(db)
managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil))
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -301,67 +305,73 @@ func UserUnHellbanHandler(c echo.Context) error {
func KickHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
userID := dutils.DoParseUserID(c.Param("userID"))
- user, err := database.GetUserByID(userID)
+ user, err := db.GetUserByID(userID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
if user.IsModerator() {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- _ = dutils.SilentKick(user, *authUser)
+ _ = dutils.SilentKick(db, user, *authUser)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func SubscribeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
roomName := c.Param("roomName")
- room, err := database.GetChatRoomByName(roomName)
+ room, err := db.GetChatRoomByName(roomName)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- _ = database.SubscribeToRoom(authUser.ID, room.ID)
+ _ = db.SubscribeToRoom(authUser.ID, room.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func UnsubscribeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
roomName := c.Param("roomName")
- room, err := database.GetChatRoomByName(roomName)
+ room, err := db.GetChatRoomByName(roomName)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- _ = database.UnsubscribeFromRoom(authUser.ID, room.ID)
+ _ = db.UnsubscribeFromRoom(authUser.ID, room.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func ThreadSubscribeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
- thread, err := database.GetForumThreadByUUID(threadUUID)
+ thread, err := db.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- _ = database.SubscribeToForumThread(authUser.ID, thread.ID)
+ _ = db.SubscribeToForumThread(authUser.ID, thread.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func ThreadUnsubscribeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
- thread, err := database.GetForumThreadByUUID(threadUUID)
+ thread, err := db.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- _ = database.UnsubscribeFromForumThread(authUser.ID, thread.ID)
+ _ = db.UnsubscribeFromForumThread(authUser.ID, thread.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func ChatMessageReactionHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
messageUUID := c.Request().PostFormValue("message_uuid")
var msg database.ChatMessage
- if err := database.DB.Where("uuid = ?", messageUUID).Preload("User").Preload("Room").First(&msg).Error; err != nil {
+ if err := db.DB().Where("uuid = ?", messageUUID).Preload("User").Preload("Room").First(&msg).Error; err != nil {
return err
}
reaction := utils.DoParseInt64(c.Request().PostFormValue("reaction_id"))
@@ -369,8 +379,8 @@ func ChatMessageReactionHandler(c echo.Context) error {
return errors.New("invalid reaction")
}
- if err := database.CreateChatReaction(authUser.ID, msg.ID, reaction); err != nil {
- _ = database.DeleteReaction(authUser.ID, msg.ID, reaction)
+ if err := db.CreateChatReaction(authUser.ID, msg.ID, reaction); err != nil {
+ _ = db.DeleteReaction(authUser.ID, msg.ID, reaction)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -378,10 +388,11 @@ func ChatMessageReactionHandler(c echo.Context) error {
func ChatDeleteMessageHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
messageUUID := c.Param("messageUUID")
var msg database.ChatMessage
- if err := database.DB.Where("uuid = ?", messageUUID).
+ if err := db.DB().Where("uuid = ?", messageUUID).
Preload("User").
Preload("Room").
First(&msg).Error; err != nil {
@@ -404,7 +415,7 @@ func ChatDeleteMessageHandler(c echo.Context) error {
msg.User.Username,
msg.User.ID,
utils.TruncStr(msg.RawMessage, 75, "…"))
- database.NewAudit(*authUser, auditMsg)
+ db.NewAudit(*authUser, auditMsg)
}
}
} else if msg.Room.OwnerUserID != nil && authUser.ID == *msg.Room.OwnerUserID { // Room owner can delete messages in its room
@@ -414,12 +425,12 @@ func ChatDeleteMessageHandler(c echo.Context) error {
if msg.RoomID == config.GeneralRoomID && msg.ToUserID == nil {
authUser.GeneralMessagesCount--
- authUser.DoSave()
+ authUser.DoSave(db)
}
// If we delete message manually, also delete linked inbox if any
- _ = database.DeleteChatInboxMessageByChatMessageID(msg.ID)
- if err := database.DeleteChatMessageByUUID(messageUUID); err != nil {
+ _ = db.DeleteChatInboxMessageByChatMessageID(msg.ID)
+ if err := db.DeleteChatMessageByUUID(messageUUID); err != nil {
logrus.Error(err)
}
@@ -428,8 +439,9 @@ func ChatDeleteMessageHandler(c echo.Context) error {
func ClubDeleteMessageHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID")))
- msg, err := database.GetForumMessage(messageID)
+ msg, err := db.GetForumMessage(messageID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -442,7 +454,7 @@ func ClubDeleteMessageHandler(c echo.Context) error {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- if err := database.DeleteForumMessageByID(messageID); err != nil {
+ if err := db.DeleteForumMessageByID(messageID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -450,27 +462,29 @@ func ClubDeleteMessageHandler(c echo.Context) error {
func DeleteNotificationHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
notificationID := utils.DoParseInt64(c.Param("notificationID"))
var msg database.Notification
- if err := database.DB.Where("ID = ? AND user_id = ?", notificationID, authUser.ID).First(&msg).Error; err != nil {
+ if err := db.DB().Where("ID = ? AND user_id = ?", notificationID, authUser.ID).First(&msg).Error; err != nil {
logrus.Error(err)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- if err := database.DeleteNotificationByID(notificationID); err != nil {
+ if err := db.DeleteNotificationByID(notificationID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/settings/inbox")
}
func DeleteSessionNotificationHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
authCookie, _ := c.Cookie(hutils.AuthCookieName)
sessionNotificationID := utils.DoParseInt64(c.Param("sessionNotificationID"))
var msg database.SessionNotification
- if err := database.DB.Where("ID = ? AND session_token = ?", sessionNotificationID, authCookie.Value).First(&msg).Error; err != nil {
+ if err := db.DB().Where("ID = ? AND session_token = ?", sessionNotificationID, authCookie.Value).First(&msg).Error; err != nil {
logrus.Error(err)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- if err := database.DeleteSessionNotificationByID(sessionNotificationID); err != nil {
+ if err := db.DeleteSessionNotificationByID(sessionNotificationID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/settings/inbox")
@@ -478,13 +492,14 @@ func DeleteSessionNotificationHandler(c echo.Context) error {
func ChatInboxDeleteMessageHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
messageID := utils.DoParseInt64(c.Param("messageID"))
var msg database.ChatInboxMessage
- if err := database.DB.Where("ID = ? AND to_user_id = ?", messageID, authUser.ID).First(&msg).Error; err != nil {
+ if err := db.DB().Where("ID = ? AND to_user_id = ?", messageID, authUser.ID).First(&msg).Error; err != nil {
logrus.Error(err)
return c.Redirect(http.StatusFound, "/settings/inbox")
}
- if err := database.DeleteChatInboxMessageByID(messageID); err != nil {
+ if err := db.DeleteChatInboxMessageByID(messageID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/settings/inbox")
@@ -493,13 +508,14 @@ func ChatInboxDeleteMessageHandler(c echo.Context) error {
func ChatInboxDeleteAllMessageHandler(c echo.Context) error {
authCookie, _ := c.Cookie(hutils.AuthCookieName)
authUser := c.Get("authUser").(*database.User)
- if err := database.DeleteAllChatInbox(authUser.ID); err != nil {
+ db := c.Get("database").(*database.DkfDB)
+ if err := db.DeleteAllChatInbox(authUser.ID); err != nil {
logrus.Error(err)
}
- if err := database.DeleteAllNotifications(authUser.ID); err != nil {
+ if err := db.DeleteAllNotifications(authUser.ID); err != nil {
logrus.Error(err)
}
- if err := database.DeleteAllSessionNotifications(authCookie.Value); err != nil {
+ if err := db.DeleteAllSessionNotifications(authCookie.Value); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/settings/inbox")
@@ -513,6 +529,7 @@ func GetCaptchaHandler(c echo.Context) error {
func CaptchaSolverHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
captchaB64 := c.Request().PostFormValue("captcha")
answer, err := captcha.SolveBase64(captchaB64)
if err != nil {
@@ -524,7 +541,7 @@ func CaptchaSolverHandler(c echo.Context) error {
CaptchaImg: captchaB64,
Answer: answer,
}
- if err := database.DB.Create(&captchaReq).Error; err != nil {
+ if err := db.DB().Create(&captchaReq).Error; err != nil {
logrus.Error(err.Error())
}
return c.JSON(http.StatusOK, map[string]any{"answer": answer})
@@ -532,11 +549,12 @@ func CaptchaSolverHandler(c echo.Context) error {
func ChessHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
roomName := c.Request().PostFormValue("room")
enemyUsername := c.Request().PostFormValue("enemyUsername")
pos := c.Request().PostFormValue("move")
redirectURL := "/api/v1/chat/messages/" + roomName
- room, roomKey, err := dutils.GetRoomAndKey(c, roomName)
+ room, roomKey, err := dutils.GetRoomAndKey(db, c, roomName)
if err != nil {
return c.Redirect(http.StatusFound, redirectURL+"?error="+err.Error()+"&errorTs="+utils.FormatInt64(time.Now().Unix()))
}
@@ -547,10 +565,11 @@ func ChessHandler(c echo.Context) error {
}
func WerewolfHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
roomName := "werewolf"
origMessage := c.Request().PostFormValue("message")
redirectURL := "/api/v1/chat/messages/" + roomName
- room, roomKey, err := dutils.GetRoomAndKey(c, roomName)
+ room, roomKey, err := dutils.GetRoomAndKey(db, c, roomName)
if err != nil {
return c.Redirect(http.StatusFound, redirectURL+"?error="+err.Error()+"&errorTs="+utils.FormatInt64(time.Now().Unix()))
}
@@ -564,11 +583,12 @@ func WerewolfHandler(c echo.Context) error {
func BattleshipHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
roomName := c.Request().PostFormValue("room")
enemyUsername := c.Request().PostFormValue("enemyUsername")
pos := c.Request().PostFormValue("move")
redirectURL := "/api/v1/chat/messages/" + roomName
- room, roomKey, err := dutils.GetRoomAndKey(c, roomName)
+ room, roomKey, err := dutils.GetRoomAndKey(db, c, roomName)
if err != nil {
return c.Redirect(http.StatusFound, redirectURL+"?error="+err.Error()+"&errorTs="+utils.FormatInt64(time.Now().Unix()))
}
diff --git a/pkg/web/handlers/api/v1/msgInterceptor.go b/pkg/web/handlers/api/v1/msgInterceptor.go
@@ -23,32 +23,32 @@ func (i MsgInterceptor) InterceptMsg(cmd *Command) {
return
}
- html, taggedUsersIDsMap := ProcessRawMessage(cmd.message, cmd.roomKey, cmd.authUser.ID, cmd.room.ID, cmd.upload)
+ html, taggedUsersIDsMap := ProcessRawMessage(cmd.db, cmd.message, cmd.roomKey, cmd.authUser.ID, cmd.room.ID, cmd.upload)
toUserID := database.UserPtrID(cmd.toUser)
- msgID, _ := database.CreateOrEditMessage(cmd.editMsg, html, cmd.origMessage, cmd.roomKey, cmd.room.ID, cmd.fromUserID, toUserID, cmd.upload, cmd.groupID, cmd.hellbanMsg, cmd.modMsg, cmd.systemMsg)
+ msgID, _ := cmd.db.CreateOrEditMessage(cmd.editMsg, html, cmd.origMessage, cmd.roomKey, cmd.room.ID, cmd.fromUserID, toUserID, cmd.upload, cmd.groupID, cmd.hellbanMsg, cmd.modMsg, cmd.systemMsg)
if !cmd.skipInboxes {
- sendInboxes(cmd.room, cmd.authUser, cmd.toUser, msgID, cmd.groupID, html, cmd.modMsg, taggedUsersIDsMap)
+ sendInboxes(cmd.db, cmd.room, cmd.authUser, cmd.toUser, msgID, cmd.groupID, html, cmd.modMsg, taggedUsersIDsMap)
}
// Count public messages in #general room
if cmd.room.ID == config.GeneralRoomID && cmd.toUser == nil {
cmd.authUser.GeneralMessagesCount++
- generalRoomKarma(cmd.authUser)
- cmd.authUser.DoSave()
+ generalRoomKarma(cmd.db, cmd.authUser)
+ cmd.authUser.DoSave(cmd.db)
}
// Update chat read marker
- database.UpdateChatReadMarker(cmd.authUser.ID, cmd.room.ID)
+ cmd.db.UpdateChatReadMarker(cmd.authUser.ID, cmd.room.ID)
// Update user activity
isPM := cmd.toUser != nil
updateUserActivity(isPM, cmd.modMsg, cmd.room, cmd.authUser)
}
-func generalRoomKarma(authUser *database.User) {
+func generalRoomKarma(db *database.DkfDB, authUser *database.User) {
// Hellban users ain't getting karma
if authUser.IsHellbanned {
return
@@ -56,30 +56,30 @@ func generalRoomKarma(authUser *database.User) {
messagesCount := authUser.GeneralMessagesCount
if messagesCount%100 == 0 {
description := fmt.Sprintf("sent %d messages", messagesCount)
- authUser.IncrKarma(1, description)
+ authUser.IncrKarma(db, 1, description)
} else if messagesCount == 20 {
- authUser.IncrKarma(1, "first 20 messages sent")
+ authUser.IncrKarma(db, 1, "first 20 messages sent")
}
}
// ProcessRawMessage return the new html, and a map of tagged users used for notifications
// This function takes an "unsafe" user input "in", and return html which will be safe to render.
-func ProcessRawMessage(in, roomKey string, authUserID database.UserID, roomID database.RoomID, upload *database.Upload) (string, map[database.UserID]database.User) {
- html, quoted := convertQuote(in, roomKey, roomID) // Get raw quote text which is not safe to render
- html = html2.EscapeString(html) // Makes user input safe to render
+func ProcessRawMessage(db *database.DkfDB, in, roomKey string, authUserID database.UserID, roomID database.RoomID, upload *database.Upload) (string, map[database.UserID]database.User) {
+ html, quoted := convertQuote(db, in, roomKey, roomID) // Get raw quote text which is not safe to render
+ html = html2.EscapeString(html) // Makes user input safe to render
// All html generated from this point on shall be safe to render.
- html = convertPGPClearsignToFile(html, authUserID)
- html = convertPGPMessageToFile(html, authUserID)
- html = convertPGPPublicKeyToFile(html, authUserID)
- html = convertAgeMessageToFile(html, authUserID)
+ html = convertPGPClearsignToFile(db, html, authUserID)
+ html = convertPGPMessageToFile(db, html, authUserID)
+ html = convertPGPPublicKeyToFile(db, html, authUserID)
+ html = convertAgeMessageToFile(db, html, authUserID)
html = convertLinksWithoutScheme(html)
html = convertMarkdown(html)
html = convertBangShortcuts(html)
- html = convertArchiveLinks(html, roomID, authUserID)
- html = convertLinks(html, database.GetUserByUsername)
+ html = convertArchiveLinks(db, html, roomID, authUserID)
+ html = convertLinks(html, db.GetUserByUsername)
html = linkDefaultRooms(html)
- html, taggedUsersIDsMap := colorifyTaggedUsers(html, database.GetUsersByUsername)
- html = linkRoomTags(html)
+ html, taggedUsersIDsMap := colorifyTaggedUsers(html, db.GetUsersByUsername)
+ html = linkRoomTags(db, html)
html = emojiReplacer.Replace(html)
html = styleQuote(html, quoted)
html = appendUploadLink(html, upload)
@@ -89,7 +89,7 @@ func ProcessRawMessage(in, roomKey string, authUserID database.UserID, roomID da
return html, taggedUsersIDsMap
}
-func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID int64, groupID *database.GroupID, html string, modMsg bool,
+func sendInboxes(db *database.DkfDB, room database.ChatRoom, authUser, toUser *database.User, msgID int64, groupID *database.GroupID, html string, modMsg bool,
taggedUsersIDsMap map[database.UserID]database.User) {
// Only have chat inbox for unencrypted messages
if room.IsProtected() {
@@ -104,13 +104,13 @@ func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID
return
}
- blacklistedBy, _ := database.GetPmBlacklistedByUsers(authUser.ID)
+ blacklistedBy, _ := db.GetPmBlacklistedByUsers(authUser.ID)
blacklistedByMap := make(map[database.UserID]struct{})
for _, b := range blacklistedBy {
blacklistedByMap[b.UserID] = struct{}{}
}
- ignoredBy, _ := database.GetIgnoredByUsers(authUser.ID)
+ ignoredBy, _ := db.GetIgnoredByUsers(authUser.ID)
ignoredByMap := make(map[database.UserID]struct{})
for _, b := range ignoredBy {
ignoredByMap[b.UserID] = struct{}{}
@@ -126,7 +126,7 @@ func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID
if _, ok := ignoredByMap[user.ID]; ok {
return
}
- database.CreateInboxMessage(html, room.ID, authUser.ID, user.ID, isPM, modCh, &msgID)
+ db.CreateInboxMessage(html, room.ID, authUser.ID, user.ID, isPM, modCh, &msgID)
}
}
@@ -147,7 +147,7 @@ func sendInboxes(room database.ChatRoom, authUser, toUser *database.User, msgID
}
} else if groupID != nil { // Only tags other people in the group
for _, user := range taggedUsersIDsMap {
- if database.IsUserInGroupByID(user.ID, *groupID) {
+ if db.IsUserInGroupByID(user.ID, *groupID) {
sendInbox(user, false, false)
}
}
diff --git a/pkg/web/handlers/api/v1/slashInterceptor.go b/pkg/web/handlers/api/v1/slashInterceptor.go
@@ -155,7 +155,7 @@ func handleModeratorGroupCmd(c *Command) (handled bool) {
func handleListModeratorsCmd(c *Command) (handled bool) {
if c.message == "/moderators" || c.message == "/mods" {
- mods, err := database.GetModeratorsUsers()
+ mods, err := c.db.GetModeratorsUsers()
if err != nil {
c.err = err
return true
@@ -231,11 +231,11 @@ func handleKickKeepSilentCmd(c *Command) (handled bool) {
}
func kickCmd(c *Command, username string, purge, silent bool) error {
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
return ErrUsernameNotFound
}
- return dutils.Kick(user, *c.authUser, purge, silent)
+ return dutils.Kick(c.db, user, *c.authUser, purge, silent)
}
var ErrUsernameNotFound = errors.New("username not found")
@@ -244,7 +244,7 @@ var ErrUnauthorized = errors.New("unauthorized")
func handleUnkickCmd(c *Command) (handled bool) {
if m := unkickRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrUsernameNotFound
return true
@@ -253,12 +253,12 @@ func handleUnkickCmd(c *Command) (handled bool) {
c.err = errors.New("user already not kicked")
return true
}
- database.NewAudit(*c.authUser, fmt.Sprintf("unkick %s #%d", user.Username, user.ID))
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("unkick %s #%d", user.Username, user.ID))
user.Verified = true
- user.DoSave()
+ user.DoSave(c.db)
// Display unkick message
- database.CreateUnkickMsg(user, *c.authUser)
+ c.db.CreateUnkickMsg(user, *c.authUser)
c.err = ErrRedirect
return true
@@ -269,15 +269,15 @@ func handleUnkickCmd(c *Command) (handled bool) {
func handleForceCaptchaCmd(c *Command) (handled bool) {
if m := forceCaptchaRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrUsernameNotFound
return true
}
if c.authUser.IsAdmin || !user.IsModerator() || c.authUser.Username == username {
- database.NewAudit(*c.authUser, fmt.Sprintf("force captcha %s #%d", user.Username, user.ID))
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("force captcha %s #%d", user.Username, user.ID))
user.CaptchaRequired = true
- user.DoSave()
+ user.DoSave(c.db)
}
c.err = ErrRedirect
return true
@@ -288,7 +288,7 @@ func handleForceCaptchaCmd(c *Command) (handled bool) {
func handleLogoutCmd(c *Command) (handled bool) {
if m := logoutRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrUsernameNotFound
return true
@@ -298,9 +298,9 @@ func handleLogoutCmd(c *Command) (handled bool) {
return true
}
if c.authUser.IsAdmin || !user.IsModerator() {
- database.NewAudit(*c.authUser, fmt.Sprintf("logout %s #%d", user.Username, user.ID))
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("logout %s #%d", user.Username, user.ID))
- _ = database.DeleteUserSessions(user.ID)
+ _ = c.db.DeleteUserSessions(user.ID)
// Remove user from the user cache
managers.ActiveUsers.RemoveUser(user.ID)
@@ -315,7 +315,7 @@ func handleLogoutCmd(c *Command) (handled bool) {
func handleResetTutorialCmd(c *Command) (handled bool) {
if m := rtutoRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrUsernameNotFound
return true
@@ -325,9 +325,9 @@ func handleResetTutorialCmd(c *Command) (handled bool) {
return true
}
if c.authUser.IsAdmin || !user.IsModerator() {
- database.NewAudit(*c.authUser, fmt.Sprintf("rtuto %s #%d", user.Username, user.ID))
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("rtuto %s #%d", user.Username, user.ID))
user.ChatTutorial = 0
- user.DoSave()
+ user.DoSave(c.db)
}
c.err = ErrRedirect
return true
@@ -338,7 +338,7 @@ func handleResetTutorialCmd(c *Command) (handled bool) {
func handleHellbanCmd(c *Command) (handled bool) {
if m := hellbanRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrUsernameNotFound
return true
@@ -347,8 +347,8 @@ func handleHellbanCmd(c *Command) (handled bool) {
c.err = ErrUnauthorized
return true
}
- database.NewAudit(*c.authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID))
- user.HellBan()
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("hellban %s #%d", user.Username, user.ID))
+ user.HellBan(c.db)
managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil))
c.err = ErrRedirect
@@ -360,7 +360,7 @@ func handleHellbanCmd(c *Command) (handled bool) {
func handleUnhellbanCmd(c *Command) (handled bool) {
if m := unhellbanRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrUsernameNotFound
return true
@@ -369,8 +369,8 @@ func handleUnhellbanCmd(c *Command) (handled bool) {
c.err = ErrUnauthorized
return true
}
- database.NewAudit(*c.authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID))
- user.UnHellBan()
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("unhellban %s #%d", user.Username, user.ID))
+ user.UnHellBan(c.db)
managers.ActiveUsers.UpdateUserHBInRooms(managers.NewUserInfo(user, nil))
c.err = ErrRedirect
@@ -399,9 +399,9 @@ func handleHbmtCmd(c *Command) (handled bool) {
if m := hbmtRgx.FindStringSubmatch(c.message); len(m) == 2 {
date := m[1]
if dt, err := utils.ParsePrevDatetimeAt(date, clockwork.NewRealClock()); err == nil {
- if msg, err := database.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil {
+ if msg, err := c.db.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil {
msg.IsHellbanned = !msg.IsHellbanned
- msg.DoSave()
+ msg.DoSave(c.db)
} else {
c.err = errors.New("no message found at this timestamp")
return true
@@ -418,7 +418,7 @@ func handleDiceCmd(c *Command) (handled bool) {
dice := utils.RandInt(1, 6)
raw := fmt.Sprintf(`rolling dice for @%s ... "%d"`, c.authUser.Username, dice)
msg := fmt.Sprintf(`rolling dice for @%s ... "<span style="color: white;">%d</span>"`, c.authUser.Username, dice)
- msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername)
+ msg, _ = colorifyTaggedUsers(msg, c.db.GetUsersByUsername)
go func() {
time.Sleep(time.Second)
c.zeroPublicMsg(raw, msg)
@@ -456,7 +456,7 @@ func handleRandCmd(c *Command) (handled bool) {
dice = utils.RandInt(min, max)
raw := fmt.Sprintf(`rolling dice for @%s ... "%d"`, c.authUser.Username, dice)
msg := fmt.Sprintf(`rolling dice for @%s ... "<span style="color: white;">%d</span>"`, c.authUser.Username, dice)
- msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername)
+ msg, _ = colorifyTaggedUsers(msg, c.db.GetUsersByUsername)
go func() {
time.Sleep(time.Second)
c.zeroPublicMsg(raw, msg)
@@ -473,7 +473,7 @@ func handleChoiceCmd(c *Command) (handled bool) {
answer := utils.RandChoice(words)
raw := fmt.Sprintf(`@%s choice %s ... "%s"`, c.authUser.Username, words, answer)
msg := fmt.Sprintf(`@%s choice %s ... "<span style="color: white;">%s</span>"`, c.authUser.Username, words, answer)
- msg, _ = colorifyTaggedUsers(msg, database.GetUsersByUsername)
+ msg, _ = colorifyTaggedUsers(msg, c.db.GetUsersByUsername)
go func() {
time.Sleep(time.Second)
c.zeroPublicMsg(raw, msg)
@@ -533,7 +533,7 @@ func handleHasherCmd(c *Command, prefix string, fn func([]byte) string) (handled
func handleRmGroupCmd(c *Command) (handled bool) {
if m := rmGroupRgx.FindStringSubmatch(c.message); len(m) == 2 {
groupName := m[1]
- if err := database.DeleteChatRoomGroup(c.room.ID, groupName); err != nil {
+ if err := c.db.DeleteChatRoomGroup(c.room.ID, groupName); err != nil {
c.err = err
return true
}
@@ -546,13 +546,13 @@ func handleRmGroupCmd(c *Command) (handled bool) {
func handleLockGroupCmd(c *Command) (handled bool) {
if m := lockGroupRgx.FindStringSubmatch(c.message); len(m) == 2 {
groupName := m[1]
- group, err := database.GetRoomGroupByName(c.room.ID, groupName)
+ group, err := c.db.GetRoomGroupByName(c.room.ID, groupName)
if err != nil {
c.err = err
return true
}
group.Locked = true
- group.DoSave()
+ group.DoSave(c.db)
c.err = ErrRedirect
return true
}
@@ -562,13 +562,13 @@ func handleLockGroupCmd(c *Command) (handled bool) {
func handleUnlockGroupCmd(c *Command) (handled bool) {
if m := unlockGroupRgx.FindStringSubmatch(c.message); len(m) == 2 {
groupName := m[1]
- group, err := database.GetRoomGroupByName(c.room.ID, groupName)
+ group, err := c.db.GetRoomGroupByName(c.room.ID, groupName)
if err != nil {
c.err = err
return true
}
group.Locked = false
- group.DoSave()
+ group.DoSave(c.db)
c.err = ErrRedirect
return true
}
@@ -578,12 +578,12 @@ func handleUnlockGroupCmd(c *Command) (handled bool) {
func handleGroupUsersCmd(c *Command) (handled bool) {
if m := groupUsersRgx.FindStringSubmatch(c.message); len(m) == 2 {
groupName := m[1]
- group, err := database.GetRoomGroupByName(c.room.ID, groupName)
+ group, err := c.db.GetRoomGroupByName(c.room.ID, groupName)
if err != nil {
c.err = err
return true
}
- users, err := database.GetRoomGroupUsers(c.room.ID, group.ID)
+ users, err := c.db.GetRoomGroupUsers(c.room.ID, group.ID)
sort.Slice(users, func(i, j int) bool {
return users[i].User.Username < users[j].User.Username
})
@@ -604,7 +604,7 @@ func handleGroupUsersCmd(c *Command) (handled bool) {
func handleListGroupsCmd(c *Command) (handled bool) {
if c.message == "/groups" {
- groups, err := database.GetRoomGroups(c.room.ID)
+ groups, err := c.db.GetRoomGroups(c.room.ID)
if err != nil {
c.err = err
return true
@@ -628,17 +628,17 @@ func handleGroupAddUserCmd(c *Command) (handled bool) {
if m := groupAddUserRgx.FindStringSubmatch(c.message); len(m) == 3 {
groupName := m[1]
username := m[2]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = err
return true
}
- group, err := database.GetRoomGroupByName(c.room.ID, groupName)
+ group, err := c.db.GetRoomGroupByName(c.room.ID, groupName)
if err != nil {
c.err = err
return true
}
- _, err = database.AddUserToRoomGroup(c.room.ID, group.ID, user.ID)
+ _, err = c.db.AddUserToRoomGroup(c.room.ID, group.ID, user.ID)
if err != nil {
c.err = err
return true
@@ -655,17 +655,17 @@ func handleGroupRmUserCmd(c *Command) (handled bool) {
if m := groupRmUserRgx.FindStringSubmatch(c.message); len(m) == 3 {
groupName := m[1]
username := m[2]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = err
return true
}
- group, err := database.GetRoomGroupByName(c.room.ID, groupName)
+ group, err := c.db.GetRoomGroupByName(c.room.ID, groupName)
if err != nil {
c.err = err
return true
}
- err = database.RmUserFromRoomGroup(c.room.ID, group.ID, user.ID)
+ err = c.db.RmUserFromRoomGroup(c.room.ID, group.ID, user.ID)
if err != nil {
c.err = err
return true
@@ -681,7 +681,7 @@ func handleGroupRmUserCmd(c *Command) (handled bool) {
func handleSetModeWhitelistCmd(c *Command) (handled bool) {
if c.message == "/mode user-whitelist" {
c.room.Mode = database.UserWhitelistRoomMode
- c.room.DoSave()
+ c.room.DoSave(c.db)
c.message = `room mode set to "user whitelist"`
c.receivePM()
return true
@@ -692,7 +692,7 @@ func handleSetModeWhitelistCmd(c *Command) (handled bool) {
func handleSetModeStandardCmd(c *Command) (handled bool) {
if c.message == "/mode standard" {
c.room.Mode = database.NormalRoomMode
- c.room.DoSave()
+ c.room.DoSave(c.db)
c.message = `room mode set to "standard"`
c.receivePM()
return true
@@ -703,12 +703,12 @@ func handleSetModeStandardCmd(c *Command) (handled bool) {
func handleGetRoomWhitelistCmd(c *Command) (handled bool) {
if m := whitelistUserRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.message = fmt.Sprintf(`username "%s" not found`, username)
} else {
- if _, err := database.WhitelistUser(c.room.ID, user.ID); err != nil {
- if err := database.DeWhitelistUser(c.room.ID, user.ID); err != nil {
+ if _, err := c.db.WhitelistUser(c.room.ID, user.ID); err != nil {
+ if err := c.db.DeWhitelistUser(c.room.ID, user.ID); err != nil {
c.message = fmt.Sprintf("failed to toggle @%s in whitelist", user.Username)
} else {
c.message = fmt.Sprintf("@%s removed from whitelist", user.Username)
@@ -726,7 +726,7 @@ func handleGetRoomWhitelistCmd(c *Command) (handled bool) {
func handleAddGroupCmd(c *Command) (handled bool) {
if m := addGroupRgx.FindStringSubmatch(c.message); len(m) == 2 {
name := m[1]
- _, err := database.CreateChatRoomGroup(c.room.ID, name, "#fff")
+ _, err := c.db.CreateChatRoomGroup(c.room.ID, name, "#fff")
if err != nil {
c.err = err
return true
@@ -739,7 +739,7 @@ func handleAddGroupCmd(c *Command) (handled bool) {
func handleWhitelistCmd(c *Command) (handled bool) {
if c.message == "/whitelist" || c.message == "/wl" {
- whitelistedUsers, _ := database.GetWhitelistedUsers(c.room.ID)
+ whitelistedUsers, _ := c.db.GetWhitelistedUsers(c.room.ID)
if len(whitelistedUsers) > 0 {
usernames := make([]string, 0)
for _, whitelistedUser := range whitelistedUsers {
@@ -786,7 +786,7 @@ func handleEditCmd(c *Command) (handled bool) {
newMsg := m[2]
if dt, err := utils.ParsePrevDatetimeAt(date, clockwork.NewRealClock()); err == nil {
if time.Since(dt) <= config.EditMessageTimeLimit {
- if msg, err := database.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil {
+ if msg, err := c.db.GetRoomChatMessageByDate(c.room.ID, c.authUser.ID, dt.UTC()); err == nil {
c.editMsg = &msg
c.origMessage = newMsg
c.message = newMsg
@@ -794,7 +794,7 @@ func handleEditCmd(c *Command) (handled bool) {
// If we're editing a message which contains a link to an uploaded file,
// we need to re-add the link to the html.
if msg.UploadID != nil {
- if newUpload, err := database.GetUploadByID(*msg.UploadID); err == nil {
+ if newUpload, err := c.db.GetUploadByID(*msg.UploadID); err == nil {
c.upload = &newUpload
}
}
@@ -820,7 +820,7 @@ func handleEditCmd(c *Command) (handled bool) {
func handleEditLastCmd(c *Command) (handled bool) {
if c.message == "/e" {
- msg, err := database.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID)
+ msg, err := c.db.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID)
if err != nil {
return true
}
@@ -834,16 +834,16 @@ func handleEditLastCmd(c *Command) (handled bool) {
var ErrPMDenied = errors.New("you cannot pm/inbox this user")
var Err20Msgs = errors.New("you need 20 public messages to unlock PMs/Inbox; or be whitelisted")
-func canUserInboxOther(user, other database.User) error {
+func canUserInboxOther(db *database.DkfDB, user, other database.User) error {
doesNotMatter := false
- _, err := canUserPmOther(user, other, doesNotMatter)
+ _, err := canUserPmOther(db, user, other, doesNotMatter)
return err
}
-func canUserPmOther(user, other database.User, roomIsPrivate bool) (skipInbox bool, err error) {
+func canUserPmOther(db *database.DkfDB, user, other database.User, roomIsPrivate bool) (skipInbox bool, err error) {
errPMDenied := ErrPMDenied
- if database.IsUserPmWhitelisted(user.ID, other.ID) {
+ if db.IsUserPmWhitelisted(user.ID, other.ID) {
return false, nil
}
@@ -863,7 +863,7 @@ func canUserPmOther(user, other database.User, roomIsPrivate bool) (skipInbox bo
}
// User on blacklist cannot PM/Inbox
- if database.IsUserPmBlacklisted(user.ID, other.ID) {
+ if db.IsUserPmBlacklisted(user.ID, other.ID) {
return false, errPMDenied
}
// Other doesn't want PM from new users
@@ -888,13 +888,13 @@ func handlePMCmd(c *Command) (handled bool) {
return handlePm0(c, newMsg)
}
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = errors.New("invalid username")
return true
}
- c.skipInboxes, c.err = canUserPmOther(*c.authUser, user, c.room.IsOwned())
+ c.skipInboxes, c.err = canUserPmOther(c.db, *c.authUser, user, c.room.IsOwned())
if c.err != nil {
return true
}
@@ -988,7 +988,7 @@ func handlePm0(c *Command, msg string) (handled bool) {
// If we sent a clearsign file to @0, the bot will reply with information about the signature
if c.upload.FileSize < config.MaxFileSizeBeforeDownload {
- if file, err := database.GetUploadByFileName(c.upload.FileName); err == nil {
+ if file, err := c.db.GetUploadByFileName(c.upload.FileName); err == nil {
if _, by, err := file.GetContent(); err == nil {
if b, _ := clearsign.Decode(by); b != nil {
if p, err := packet.Read(b.ArmoredSignature.Body); err == nil {
@@ -1027,7 +1027,7 @@ func handlePm0(c *Command, msg string) (handled bool) {
func handleSubscribeCmd(c *Command) (handled bool) {
if c.message == "/subscribe" {
- _ = database.SubscribeToRoom(c.authUser.ID, c.room.ID)
+ _ = c.db.SubscribeToRoom(c.authUser.ID, c.room.ID)
c.err = ErrRedirect
return true
}
@@ -1036,17 +1036,17 @@ func handleSubscribeCmd(c *Command) (handled bool) {
func handleUnsubscribeCmd(c *Command) (handled bool) {
if m := unsubscribeRgx.FindStringSubmatch(c.message); len(m) == 2 {
- room, err := database.GetChatRoomByName(m[1])
+ room, err := c.db.GetChatRoomByName(m[1])
if err != nil {
c.err = err
return true
}
- _ = database.UnsubscribeFromRoom(c.authUser.ID, room.ID)
+ _ = c.db.UnsubscribeFromRoom(c.authUser.ID, room.ID)
c.err = ErrRedirect
return true
} else if c.message == "/unsubscribe" {
- _ = database.UnsubscribeFromRoom(c.authUser.ID, c.room.ID)
+ _ = c.db.UnsubscribeFromRoom(c.authUser.ID, c.room.ID)
c.err = ErrRedirect
return true
}
@@ -1057,7 +1057,7 @@ func handleGroupChatCmd(c *Command) (handled bool) {
if m := groupRgx.FindStringSubmatch(c.message); len(m) == 3 {
groupName := m[1]
c.message = m[2]
- group, err := database.GetRoomGroupByName(c.room.ID, groupName)
+ group, err := c.db.GetRoomGroupByName(c.room.ID, groupName)
if err != nil {
c.err = err
return true
@@ -1078,7 +1078,7 @@ func handleGroupChatCmd(c *Command) (handled bool) {
func handleListIgnoredCmd(c *Command) (handled bool) {
if c.message == "/i" || c.message == "/ignore" {
- ignoredUsers, _ := database.GetIgnoredUsers(c.authUser.ID)
+ ignoredUsers, _ := c.db.GetIgnoredUsers(c.authUser.ID)
sort.Slice(ignoredUsers, func(i, j int) bool {
return ignoredUsers[i].IgnoredUser.Username < ignoredUsers[j].IgnoredUser.Username
})
@@ -1099,7 +1099,7 @@ func handleListIgnoredCmd(c *Command) (handled bool) {
func handleListPmWhitelistCmd(c *Command) (handled bool) {
if c.message == "/pmwhitelist" {
- pmWhitelistUsers, _ := database.GetPmWhitelistedUsers(c.authUser.ID)
+ pmWhitelistUsers, _ := c.db.GetPmWhitelistedUsers(c.authUser.ID)
sort.Slice(pmWhitelistUsers, func(i, j int) bool {
return pmWhitelistUsers[i].WhitelistedUser.Username < pmWhitelistUsers[j].WhitelistedUser.Username
})
@@ -1121,7 +1121,7 @@ func handleListPmWhitelistCmd(c *Command) (handled bool) {
func handleSetPmModeWhitelistCmd(c *Command) (handled bool) {
if c.message == "/setpmmode whitelist" {
c.authUser.PmMode = database.PmModeWhitelist
- c.authUser.DoSave()
+ c.authUser.DoSave(c.db)
c.message = `pm mode set to "whitelist"`
c.receivePM()
return true
@@ -1132,7 +1132,7 @@ func handleSetPmModeWhitelistCmd(c *Command) (handled bool) {
func handleSetPmModeStandardCmd(c *Command) (handled bool) {
if c.message == "/setpmmode standard" {
c.authUser.PmMode = database.PmModeStandard
- c.authUser.DoSave()
+ c.authUser.DoSave(c.db)
c.message = `pm mode set to "standard"`
c.receivePM()
return true
@@ -1143,12 +1143,12 @@ func handleSetPmModeStandardCmd(c *Command) (handled bool) {
func handleTogglePmBlacklistedUser(c *Command) (handled bool) {
if m := pmToggleBlacklistUserRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrRedirect
return true
}
- if database.ToggleBlacklistedUser(c.authUser.ID, user.ID) {
+ if c.db.ToggleBlacklistedUser(c.authUser.ID, user.ID) {
c.err = NewErrSuccess("added to blacklist")
} else {
c.err = NewErrSuccess("removed from blacklist")
@@ -1161,12 +1161,12 @@ func handleTogglePmBlacklistedUser(c *Command) (handled bool) {
func handleTogglePmWhitelistedUser(c *Command) (handled bool) {
if m := pmToggleWhitelistUserRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrRedirect
return true
}
- if database.ToggleWhitelistedUser(c.authUser.ID, user.ID) {
+ if c.db.ToggleWhitelistedUser(c.authUser.ID, user.ID) {
c.err = NewErrSuccess("added to whitelist")
} else {
c.err = NewErrSuccess("removed from whitelist")
@@ -1180,7 +1180,7 @@ func handleChessCmd(c *Command) (handled bool) {
if m := chessRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
player1 := *c.authUser
- player2, err := database.GetUserByUsername(username)
+ player2, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = errors.New("invalid username")
return true
@@ -1204,13 +1204,13 @@ func handleInboxCmd(c *Command) (handled bool) {
if encryptRaw == " -e" {
tryEncrypt = true
}
- toUser, err := database.GetUserByUsername(username)
+ toUser, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = errors.New("invalid username")
return true
}
- if err := canUserInboxOther(*c.authUser, toUser); err != nil {
+ if err := canUserInboxOther(c.db, *c.authUser, toUser); err != nil {
c.err = err
return true
}
@@ -1229,8 +1229,8 @@ func handleInboxCmd(c *Command) (handled bool) {
html = strings.Join(strings.Split(html, "\n"), " ")
}
- html, _ = ProcessRawMessage(html, c.roomKey, c.authUser.ID, c.room.ID, nil)
- database.CreateInboxMessage(html, c.room.ID, c.authUser.ID, toUser.ID, true, false, nil)
+ html, _ = ProcessRawMessage(c.db, html, c.roomKey, c.authUser.ID, c.room.ID, nil)
+ c.db.CreateInboxMessage(html, c.room.ID, c.authUser.ID, toUser.ID, true, false, nil)
c.dataMessage = "/inbox " + username + " "
c.err = NewErrSuccess("inbox sent")
@@ -1246,7 +1246,7 @@ func handleInboxCmd(c *Command) (handled bool) {
func handleProfileCmd(c *Command) (handled bool) {
if m := profileRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrUsernameNotFound
return true
@@ -1275,7 +1275,7 @@ type tutorialSteps struct {
func handleTutorialCmd(c *Command) (handled bool) {
if c.message == "/tuto" && false {
name := "tuto_" + utils.GenerateToken10()
- room, _ := database.CreateRoom(name, "", c.authUser.ID, false)
+ room, _ := c.db.CreateRoom(name, "", c.authUser.ID, false)
c.err = ErrRedirect
c.zeroProcMsg("Tutorial here -> #" + room.Name)
c.zeroPublicProcMsgRoom("Welcome to the tutorial", "", room.ID)
@@ -1288,12 +1288,12 @@ func handleDeleteMsgCmd(c *Command) (handled bool) {
delMsgFn := func(msg database.ChatMessage) {
if msg.RoomID == config.GeneralRoomID && msg.ToUserID == nil {
msg.User.GeneralMessagesCount--
- msg.User.DoSave()
+ msg.User.DoSave(c.db)
}
- _ = database.DeleteChatMessageByUUID(msg.UUID)
+ _ = c.db.DeleteChatMessageByUUID(msg.UUID)
}
if c.message == "/d" {
- if msg, err := database.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID); err != nil {
+ if msg, err := c.db.GetUserLastChatMessageInRoom(c.authUser.ID, c.room.ID); err != nil {
c.err = errors.New("unable to find last message")
return true
} else if msg.TooOldToDelete() {
@@ -1315,7 +1315,7 @@ func handleDeleteMsgCmd(c *Command) (handled bool) {
c.err = err
return true
}
- msgs, err := database.GetRoomChatMessagesByDate(c.room.ID, dt.UTC())
+ msgs, err := c.db.GetRoomChatMessagesByDate(c.room.ID, dt.UTC())
if err != nil {
c.err = err
return true
@@ -1340,7 +1340,7 @@ func handleDeleteMsgCmd(c *Command) (handled bool) {
return true
}
// Moderator
- _ = database.DeleteChatMessageByUUID(msg.UUID)
+ _ = c.db.DeleteChatMessageByUUID(msg.UUID)
c.err = ErrRedirect
return true
@@ -1383,7 +1383,7 @@ func handleDeleteMsgCmd(c *Command) (handled bool) {
c.err = errors.New("failed to find msg")
return true
}
- _ = database.DeleteChatMessageByUUID(msg.UUID)
+ _ = c.db.DeleteChatMessageByUUID(msg.UUID)
c.err = ErrRedirect
return true
}
@@ -1406,13 +1406,13 @@ func handleHideMsgCmd(c *Command) (handled bool) {
c.err = err
return true
}
- msgs, err := database.GetRoomChatMessagesByDate(c.room.ID, dt.UTC())
+ msgs, err := c.db.GetRoomChatMessagesByDate(c.room.ID, dt.UTC())
if err != nil {
c.err = err
return true
}
if len(msgs) == 1 {
- database.IgnoreMessage(c.authUser.ID, msgs[0].ID)
+ c.db.IgnoreMessage(c.authUser.ID, msgs[0].ID)
c.err = ErrRedirect
} else {
c.err = errors.New("more than 1 message")
@@ -1431,13 +1431,13 @@ func handleUnHideMsgCmd(c *Command) (handled bool) {
c.err = err
return true
}
- msgs, err := database.GetRoomChatMessagesByDate(c.room.ID, dt.UTC())
+ msgs, err := c.db.GetRoomChatMessagesByDate(c.room.ID, dt.UTC())
if err != nil {
c.err = err
return true
}
if len(msgs) == 1 {
- database.UnIgnoreMessage(c.authUser.ID, msgs[0].ID)
+ c.db.UnIgnoreMessage(c.authUser.ID, msgs[0].ID)
c.err = ErrRedirect
} else {
c.err = errors.New("more than 1 message")
@@ -1450,12 +1450,12 @@ func handleUnHideMsgCmd(c *Command) (handled bool) {
func handleIgnoreCmd(c *Command) (handled bool) {
if m := ignoreRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrRedirect
return true
}
- database.IgnoreUser(c.authUser.ID, user.ID)
+ c.db.IgnoreUser(c.authUser.ID, user.ID)
c.err = ErrRedirect
return true
} else if strings.HasPrefix(c.message, "/ignore ") || strings.HasPrefix(c.message, "/i ") {
@@ -1468,12 +1468,12 @@ func handleIgnoreCmd(c *Command) (handled bool) {
func handleUnIgnoreCmd(c *Command) (handled bool) {
if m := unIgnoreRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = ErrRedirect
return true
}
- database.UnIgnoreUser(c.authUser.ID, user.ID)
+ c.db.UnIgnoreUser(c.authUser.ID, user.ID)
c.err = ErrRedirect
return true
} else if strings.HasPrefix(c.message, "/unignore ") || strings.HasPrefix(c.message, "/ui ") {
@@ -1486,7 +1486,7 @@ func handleUnIgnoreCmd(c *Command) (handled bool) {
func handleToggleAutocomplete(c *Command) (handled bool) {
if c.message == "/toggle-autocomplete" {
c.authUser.AutocompleteCommandsEnabled = !c.authUser.AutocompleteCommandsEnabled
- c.authUser.DoSave()
+ c.authUser.DoSave(c.db)
c.err = ErrRedirect
return true
}
@@ -1527,13 +1527,13 @@ func handleSetChatRoomExternalLink(c *Command) (handled bool) {
if !govalidator.IsURL(externalURL) {
externalURL = ""
}
- room, err := database.GetChatRoomByID(c.room.ID)
+ room, err := c.db.GetChatRoomByID(c.room.ID)
if err != nil {
c.err = err
return true
}
room.ExternalLink = externalURL
- room.DoSave()
+ room.DoSave(c.db)
c.err = ErrRedirect
return true
}
@@ -1543,13 +1543,13 @@ func handleSetChatRoomExternalLink(c *Command) (handled bool) {
func handlePurge(c *Command) (handled bool) {
if m := purgeRgx.FindStringSubmatch(c.message); len(m) == 2 {
username := m[1]
- user, err := database.GetUserByUsername(username)
+ user, err := c.db.GetUserByUsername(username)
if err != nil {
c.err = err
return true
}
- database.NewAudit(*c.authUser, fmt.Sprintf("purge %s #%d", user.Username, user.ID))
- _ = database.DeleteUserChatMessages(user.ID)
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("purge %s #%d", user.Username, user.ID))
+ _ = c.db.DeleteUserChatMessages(user.ID)
c.err = ErrRedirect
return true
}
@@ -1560,21 +1560,21 @@ func handleRename(c *Command) (handled bool) {
if m := renameRgx.FindStringSubmatch(c.message); len(m) == 3 {
oldUsername := m[1]
newUsername := m[2]
- user, err := database.GetUserByUsername(oldUsername)
+ user, err := c.db.GetUserByUsername(oldUsername)
if err != nil {
c.err = err
return true
}
- database.NewAudit(*c.authUser, fmt.Sprintf("rename %s -> %s #%d", user.Username, newUsername, user.ID))
+ c.db.NewAudit(*c.authUser, fmt.Sprintf("rename %s -> %s #%d", user.Username, newUsername, user.ID))
- if err := database.CanRenameTo(oldUsername, newUsername); err != nil {
+ if err := c.db.CanRenameTo(oldUsername, newUsername); err != nil {
c.err = err
return true
}
managers.ActiveUsers.RemoveUser(user.ID)
user.Username = newUsername
- user.DoSave()
+ user.DoSave(c.db)
c.err = ErrRedirect
return true
diff --git a/pkg/web/handlers/api/v1/snippetInterceptor.go b/pkg/web/handlers/api/v1/snippetInterceptor.go
@@ -11,16 +11,16 @@ type SnippetInterceptor struct{}
func (i SnippetInterceptor) InterceptMsg(cmd *Command) {
// Snippets actually mutate the original message,
// to simulate that the user actually typed the text
- cmd.origMessage = snippets(cmd.authUser.ID, cmd.origMessage)
+ cmd.origMessage = snippets(cmd.db, cmd.authUser.ID, cmd.origMessage)
cmd.origMessage = autocompleteTags(cmd.origMessage)
cmd.message = cmd.origMessage
}
-func snippets(authUserID database.UserID, html string) string {
+func snippets(db *database.DkfDB, authUserID database.UserID, html string) string {
if snippetRgx.MatchString(html) {
- userSnippets, _ := database.GetUserSnippets(authUserID)
+ userSnippets, _ := db.GetUserSnippets(authUserID)
if len(userSnippets) > 0 {
// Build hashmap for fast lookup
m := make(map[string]string)
diff --git a/pkg/web/handlers/api/v1/spamInterceptor.go b/pkg/web/handlers/api/v1/spamInterceptor.go
@@ -14,13 +14,13 @@ import (
type SpamInterceptor struct{}
func (i SpamInterceptor) InterceptMsg(c *Command) {
- if err := checkSpam(c.origMessage, c.authUser); err != nil {
+ if err := checkSpam(c.db, c.origMessage, c.authUser); err != nil {
c.err = err
return
}
// Check CP links
- if checkCPLinks(c.message) {
+ if checkCPLinks(c.db, c.message) {
c.err = errors.New("forbidden url")
return
}
@@ -32,7 +32,7 @@ func (i SpamInterceptor) InterceptMsg(c *Command) {
var ErrSpamFilterTriggered = errors.New("spam filter triggered")
-func checkSpam(origMessage string, authUser *database.User) error {
+func checkSpam(db *database.DkfDB, origMessage string, authUser *database.User) error {
lowerCaseMessage := strings.ToLower(origMessage)
silentSelfKick := config.SilentSelfKick.Load()
@@ -42,13 +42,13 @@ func checkSpam(origMessage string, authUser *database.User) error {
strings.Contains(lowerCaseMessage, "i wanna see gore") ||
strings.Contains(lowerCaseMessage, "how can i make money") ||
strings.Contains(lowerCaseMessage, "any links for scary stuff") {
- _ = dutils.SelfKick(*authUser, silentSelfKick)
+ _ = dutils.SelfKick(db, *authUser, silentSelfKick)
return ErrSpamFilterTriggered
}
}
if authUser.GeneralMessagesCount < 20 || time.Since(authUser.CreatedAt) < 5*time.Hour {
if strings.Contains(lowerCaseMessage, "cp link") {
- _ = dutils.SelfKick(*authUser, silentSelfKick)
+ _ = dutils.SelfKick(db, *authUser, silentSelfKick)
return ErrSpamFilterTriggered
}
}
@@ -57,7 +57,7 @@ func checkSpam(origMessage string, authUser *database.User) error {
if authUser.IsModerator() {
return ErrSpamFilterTriggered
}
- _ = dutils.SelfKick(*authUser, silentSelfKick)
+ _ = dutils.SelfKick(db, *authUser, silentSelfKick)
return ErrSpamFilterTriggered
}
@@ -66,7 +66,7 @@ func checkSpam(origMessage string, authUser *database.User) error {
count, total := utils.CountUppercase(origMessage)
pct := float64(count) / float64(total)
if total > 5 && pct > 0.8 {
- _ = dutils.SelfKick(*authUser, silentSelfKick)
+ _ = dutils.SelfKick(db, *authUser, silentSelfKick)
return ErrSpamFilterTriggered
}
}
@@ -103,14 +103,14 @@ func checkSpam(origMessage string, authUser *database.User) error {
if authUser.GeneralMessagesCount < 10 {
if wordsMap["porn"] > 0 && (wordsMap["link"] > 0 || wordsMap["links"] > 0) {
- _ = dutils.SelfKick(*authUser, silentSelfKick)
+ _ = dutils.SelfKick(db, *authUser, silentSelfKick)
return ErrSpamFilterTriggered
}
}
if authUser.GeneralMessagesCount < 20 || time.Since(authUser.CreatedAt) < 5*time.Hour {
if wordsMap["cp"] > 0 && (wordsMap["link"] > 0 || wordsMap["links"] > 0) {
- _ = dutils.SelfKick(*authUser, silentSelfKick)
+ _ = dutils.SelfKick(db, *authUser, silentSelfKick)
return ErrSpamFilterTriggered
}
}
diff --git a/pkg/web/handlers/api/v1/topBarHandler.go b/pkg/web/handlers/api/v1/topBarHandler.go
@@ -99,7 +99,7 @@ const (
redirectMultilineQP = "ml"
)
-func getDataMessagePrefix(c echo.Context, roomKey string, room database.ChatRoom, authUser *database.User) (out string, err error) {
+func getDataMessagePrefix(db *database.DkfDB, c echo.Context, roomKey string, room database.ChatRoom, authUser *database.User) (out string, err error) {
pm := c.QueryParam(redirectPmQP)
edit := c.QueryParam(redirectEditQP)
group := c.QueryParam(redirectGroupQP)
@@ -125,12 +125,12 @@ func getDataMessagePrefix(c echo.Context, roomKey string, room database.ChatRoom
} else if mtag != "" {
out = "/m @" + mtag + " "
} else if edit != "" {
- out, err = handleGetEdit(edit, roomKey, room, authUser)
+ out, err = handleGetEdit(db, edit, roomKey, room, authUser)
if err != nil {
return
}
} else if quote != "" {
- out, err = handleGetQuote(quote, roomKey, room, authUser)
+ out, err = handleGetQuote(db, quote, roomKey, room, authUser)
if err != nil {
return
}
@@ -210,6 +210,7 @@ func buildCommandsList(authUser *database.User, room database.ChatRoom) (command
func ChatTopBarHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data chatTopBarData
data.RoomName = c.Param("roomName")
@@ -229,7 +230,7 @@ func ChatTopBarHandler(c echo.Context) error {
}
}
- room, roomKey, err := dutils.GetRoomAndKey(c, data.RoomName)
+ room, roomKey, err := dutils.GetRoomAndKey(db, c, data.RoomName)
if err != nil {
return err
}
@@ -239,7 +240,7 @@ func ChatTopBarHandler(c echo.Context) error {
return c.Render(http.StatusOK, "chat-top-bar", data)
}
- data.Message, err = getDataMessagePrefix(c, roomKey, room, authUser)
+ data.Message, err = getDataMessagePrefix(db, c, roomKey, room, authUser)
if err != nil {
return c.Redirect(http.StatusFound, "/api/v1/chat/top-bar/"+room.Name)
}
@@ -331,8 +332,8 @@ func replTextPrefixSuffix(msg, prefix, suffix, repl string) (out string) {
return
}
-func handleGetQuote(msgUUID, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) {
- quoted, err := database.GetRoomChatMessageByUUID(room.ID, msgUUID)
+func handleGetQuote(db *database.DkfDB, msgUUID, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) {
+ quoted, err := db.GetRoomChatMessageByUUID(room.ID, msgUUID)
if err != nil {
return
}
@@ -354,14 +355,14 @@ func handleGetQuote(msgUUID, roomKey string, room database.ChatRoom, authUser *d
}
// Append the actual quoted text
- dataMessage = prefix + getQuoteTxt(roomKey, quoted) + " "
+ dataMessage = prefix + getQuoteTxt(db, roomKey, quoted) + " "
return
}
-func handleGetEdit(hourMinSec, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) {
+func handleGetEdit(db *database.DkfDB, hourMinSec, roomKey string, room database.ChatRoom, authUser *database.User) (dataMessage string, err error) {
if dt, err := utils.ParsePrevDatetimeAt(hourMinSec, clockwork.NewRealClock()); err == nil {
if time.Since(dt) <= config.EditMessageTimeLimit {
- if msg, err := database.GetRoomChatMessageByDate(room.ID, authUser.ID, dt.UTC()); err == nil {
+ if msg, err := db.GetRoomChatMessageByDate(room.ID, authUser.ID, dt.UTC()); err == nil {
decrypted, err := msg.GetRawMessage(roomKey)
if err != nil {
return "", err
@@ -384,6 +385,7 @@ type Command struct {
room database.ChatRoom // Room the user is in
roomKey string // Room password (if any)
authUser *database.User // Authenticated user
+ db *database.DkfDB // Database instance
fromUserID database.UserID // Sender of message
toUser *database.User // If not nil, will be a PM
upload *database.Upload // If the message contains an uploaded file
@@ -399,9 +401,11 @@ type Command struct {
func NewCommand(c echo.Context, origMessage string, room database.ChatRoom, roomKey string) *Command {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
return &Command{
c: c,
authUser: authUser,
+ db: db,
fromUserID: authUser.ID,
hellbanMsg: authUser.IsHellbanned,
redirectQP: url.Values{},
@@ -425,7 +429,7 @@ func (c *Command) receivePM() {
// Lazy loading and cache of the zero user
func (c *Command) getZeroUser() database.User {
if c.zeroUser == nil {
- zeroUser := dutils.GetZeroUser()
+ zeroUser := dutils.GetZeroUser(c.db)
c.zeroUser = &zeroUser
}
return *c.zeroUser
@@ -444,14 +448,14 @@ func (c *Command) zeroProcMsg(rawMsg string) {
func (c *Command) zeroProcMsgRoom(rawMsg, roomKey string, roomID database.RoomID) {
zeroUser := c.getZeroUser()
- procMsg, _ := ProcessRawMessage(rawMsg, roomKey, c.authUser.ID, roomID, nil)
- rawMsgRoom(zeroUser, c.authUser, rawMsg, procMsg, roomKey, roomID)
+ procMsg, _ := ProcessRawMessage(c.db, rawMsg, roomKey, c.authUser.ID, roomID, nil)
+ rawMsgRoom(c.db, zeroUser, c.authUser, rawMsg, procMsg, roomKey, roomID)
}
func (c *Command) zeroPublicProcMsgRoom(rawMsg, roomKey string, roomID database.RoomID) {
zeroUser := c.getZeroUser()
- procMsg, _ := ProcessRawMessage(rawMsg, roomKey, c.authUser.ID, roomID, nil)
- rawMsgRoom(zeroUser, nil, rawMsg, procMsg, roomKey, roomID)
+ procMsg, _ := ProcessRawMessage(c.db, rawMsg, roomKey, c.authUser.ID, roomID, nil)
+ rawMsgRoom(c.db, zeroUser, nil, rawMsg, procMsg, roomKey, roomID)
}
func (c *Command) zeroPublicMsg(raw, msg string) {
@@ -460,15 +464,15 @@ func (c *Command) zeroPublicMsg(raw, msg string) {
}
func (c *Command) rawMsg(user1 database.User, user2 *database.User, raw, msg string) {
- rawMsgRoom(user1, user2, raw, msg, c.roomKey, c.room.ID)
+ rawMsgRoom(c.db, user1, user2, raw, msg, c.roomKey, c.room.ID)
}
-func rawMsgRoom(user1 database.User, user2 *database.User, raw, msg, roomKey string, roomID database.RoomID) {
+func rawMsgRoom(db *database.DkfDB, user1 database.User, user2 *database.User, raw, msg, roomKey string, roomID database.RoomID) {
var toUserID *database.UserID
if user2 != nil {
toUserID = &user2.ID
}
- _, _ = database.CreateMsg(raw, msg, roomKey, roomID, user1.ID, toUserID)
+ _, _ = db.CreateMsg(raw, msg, roomKey, roomID, user1.ID, toUserID)
}
type ErrSuccess struct {
@@ -493,12 +497,12 @@ func appendUploadLink(html string, upload *database.Upload) string {
return html
}
-func checkCPLinks(html string) bool {
+func checkCPLinks(db *database.DkfDB, html string) bool {
m1 := onionV3Rgx.FindAllStringSubmatch(html, -1)
m2 := onionV2Rgx.FindAllStringSubmatch(html, -1)
for _, m := range append(m1, m2...) {
hash := utils.MD5([]byte(m[0]))
- if _, err := database.GetOnionBlacklist(hash); err == nil {
+ if _, err := db.GetOnionBlacklist(hash); err == nil {
return true
}
}
@@ -525,7 +529,7 @@ func sanitizeUserInput(html string) string {
// Convert timestamps such as 01:23:45 to an archive link if a message with that timestamp exists.
// eg: "Some text 14:31:46 some more text"
-func convertArchiveLinks(html string, roomID database.RoomID, authUserID database.UserID) string {
+func convertArchiveLinks(db *database.DkfDB, html string, roomID database.RoomID, authUserID database.UserID) string {
start, rest := "", html
// Do not replace timestamps that are inside a quote text
@@ -548,7 +552,7 @@ func convertArchiveLinks(html string, roomID database.RoomID, authUserID databas
if err != nil {
return s
}
- if msgs, err := database.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 {
+ if msgs, err := db.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 {
msg := msgs[0]
if len(msgs) > 1 {
for _, msgTmp := range msgs {
@@ -613,10 +617,10 @@ func colorifyTaggedUsers(html string, getUsersByUsername getUsersByUsernameFn) (
return html, taggedUsersIDsMap
}
-func linkRoomTags(html string) string {
+func linkRoomTags(db *database.DkfDB, html string) string {
if roomTagRgx.MatchString(html) {
html = roomTagRgx.ReplaceAllStringFunc(html, func(s string) string {
- if room, err := database.GetChatRoomByName(strings.TrimPrefix(s, "#")); err == nil {
+ if room, err := db.GetChatRoomByName(strings.TrimPrefix(s, "#")); err == nil {
return `<a href="/chat/` + room.Name + `" target="_top">` + s + `</a>`
}
return s
@@ -626,9 +630,9 @@ func linkRoomTags(html string) string {
}
// Given a roomID and hourMinSec (01:23:45) and a username, retrieve the message from database that fits the predicates.
-func getQuotedChatMessage(hourMinSec, username string, roomID database.RoomID) (quoted *database.ChatMessage) {
+func getQuotedChatMessage(db *database.DkfDB, hourMinSec, username string, roomID database.RoomID) (quoted *database.ChatMessage) {
if dt, err := utils.ParsePrevDatetimeAt(hourMinSec, clockwork.NewRealClock()); err == nil {
- if msgs, err := database.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 {
+ if msgs, err := db.GetRoomChatMessagesByDate(roomID, dt.UTC()); err == nil && len(msgs) > 0 {
msg := msgs[0]
if len(msgs) > 1 {
for _, msgTmp := range msgs {
@@ -645,7 +649,7 @@ func getQuotedChatMessage(hourMinSec, username string, roomID database.RoomID) (
}
// Given a chat message, return the text to be used as a quote.
-func getQuoteTxt(roomKey string, quoted database.ChatMessage) (out string) {
+func getQuoteTxt(db *database.DkfDB, roomKey string, quoted database.ChatMessage) (out string) {
var err error
decrypted, err := quoted.GetRawMessage(roomKey)
if err != nil {
@@ -685,7 +689,7 @@ func getQuoteTxt(roomKey string, quoted database.ChatMessage) (out string) {
remaining += fmt.Sprintf(`%s `, quoted.User.Username)
}
if quoted.UploadID != nil {
- if upload, err := database.GetUploadByID(*quoted.UploadID); err == nil {
+ if upload, err := db.GetUploadByID(*quoted.UploadID); err == nil {
if decrypted != "" {
decrypted += " "
}
@@ -714,7 +718,7 @@ func getQuoteTxt(roomKey string, quoted database.ChatMessage) (out string) {
// eg: we received altered quote, and return original quote ->
// “[01:23:45] username - Some maliciously altered quote” Some text
// “[01:23:45] username - The original text” Some text
-func convertQuote(origHtml string, roomKey string, roomID database.RoomID) (html string, quoted *database.ChatMessage) {
+func convertQuote(db *database.DkfDB, origHtml string, roomKey string, roomID database.RoomID) (html string, quoted *database.ChatMessage) {
const quotePrefix = `“[`
const quoteSuffix = `”`
html = origHtml
@@ -725,8 +729,8 @@ func convertQuote(origHtml string, roomKey string, roomID database.RoomID) (html
if len(origHtml) > prefixLen+9 {
hourMinSec := origHtml[prefixLen : prefixLen+8]
username := origHtml[prefixLen+10 : strings.Index(origHtml[prefixLen+10:], " ")+prefixLen+10]
- if quoted = getQuotedChatMessage(hourMinSec, username, roomID); quoted != nil {
- html = getQuoteTxt(roomKey, *quoted)
+ if quoted = getQuotedChatMessage(db, hourMinSec, username, roomID); quoted != nil {
+ html = getQuoteTxt(db, roomKey, *quoted)
html += origHtml[idx+suffixLen:]
}
}
@@ -1036,7 +1040,7 @@ func extractPGPMessage(html string) (out string) {
}
// Auto convert pasted pgp message into uploaded file
-func convertPGPMessageToFile(html string, authUserID database.UserID) string {
+func convertPGPMessageToFile(db *database.DkfDB, html string, authUserID database.UserID) string {
startIdx := strings.Index(html, pgpPrefix)
endIdx := strings.Index(html, pgpSuffix)
if startIdx != -1 && endIdx != -1 {
@@ -1047,7 +1051,7 @@ func convertPGPMessageToFile(html string, authUserID database.UserID) string {
tmp = strings.Join(strings.Split(tmp, " "), "\n")
tmp = pgpPrefix + tmp
tmp += pgpSuffix
- upload, _ := database.CreateUpload("pgp.txt", []byte(tmp), authUserID)
+ upload, _ := db.CreateUpload("pgp.txt", []byte(tmp), authUserID)
msgBefore := html[0:startIdx]
msgAfter := html[endIdx+len(pgpSuffix):]
html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter
@@ -1057,14 +1061,14 @@ func convertPGPMessageToFile(html string, authUserID database.UserID) string {
}
// Auto convert pasted pgp public key into uploaded file
-func convertPGPPublicKeyToFile(html string, authUserID database.UserID) string {
+func convertPGPPublicKeyToFile(db *database.DkfDB, html string, authUserID database.UserID) string {
startIdx := strings.Index(html, pgpPKeyPrefix)
endIdx := strings.Index(html, pgpPKeySuffix)
if startIdx != -1 && endIdx != -1 {
pkeySubSlice := html[startIdx : endIdx+len(pgpPKeySuffix)]
unescapedPkey := html2.UnescapeString(pkeySubSlice)
tmp := convertInlinePGPPublicKey(unescapedPkey)
- upload, _ := database.CreateUpload("pgp_pkey.txt", []byte(tmp), authUserID)
+ upload, _ := db.CreateUpload("pgp_pkey.txt", []byte(tmp), authUserID)
msgBefore := html[0:startIdx]
msgAfter := html[endIdx+len(pgpPKeySuffix):]
html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter
@@ -1073,12 +1077,12 @@ func convertPGPPublicKeyToFile(html string, authUserID database.UserID) string {
return html
}
-func convertPGPClearsignToFile(html string, authUserID database.UserID) string {
+func convertPGPClearsignToFile(db *database.DkfDB, html string, authUserID database.UserID) string {
if b, _ := clearsign.Decode([]byte(html)); b != nil {
startIdx := strings.Index(html, pgpSignedPrefix)
endIdx := strings.Index(html, pgpSignedSuffix)
tmp := html[startIdx : endIdx+len(pgpSignedSuffix)]
- upload, _ := database.CreateUpload("pgp_clearsign.txt", []byte(tmp), authUserID)
+ upload, _ := db.CreateUpload("pgp_clearsign.txt", []byte(tmp), authUserID)
msgBefore := html[0:startIdx]
msgAfter := html[endIdx+len(pgpSignedSuffix):]
html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter
@@ -1126,7 +1130,7 @@ func convertInlinePGPPublicKey(inlinePKey string) string {
}
// Auto convert pasted age message into uploaded file
-func convertAgeMessageToFile(html string, authUserID database.UserID) string {
+func convertAgeMessageToFile(db *database.DkfDB, html string, authUserID database.UserID) string {
startIdx := strings.Index(html, agePrefix)
endIdx := strings.Index(html, ageSuffix)
if startIdx != -1 && endIdx != -1 {
@@ -1137,7 +1141,7 @@ func convertAgeMessageToFile(html string, authUserID database.UserID) string {
tmp = strings.Join(strings.Split(tmp, " "), "\n")
tmp = agePrefix + tmp
tmp += ageSuffix
- upload, _ := database.CreateUpload("age.txt", []byte(tmp), authUserID)
+ upload, _ := db.CreateUpload("age.txt", []byte(tmp), authUserID)
msgBefore := html[0:startIdx]
msgAfter := html[endIdx+len(ageSuffix):]
html = msgBefore + ` [` + upload.GetHTMLLink() + `] ` + msgAfter
diff --git a/pkg/web/handlers/api/v1/uploadInterceptor.go b/pkg/web/handlers/api/v1/uploadInterceptor.go
@@ -20,7 +20,7 @@ func (i UploadInterceptor) InterceptMsg(cmd *Command) {
if file, handler, uploadErr := cmd.c.Request().FormFile("file"); uploadErr == nil {
// Save file on disk & database & append file link to html
var err error
- cmd.upload, err = handleUploadedFile(file, handler, cmd.authUser)
+ cmd.upload, err = handleUploadedFile(cmd.db, file, handler, cmd.authUser)
if err != nil {
cmd.err = err
return
@@ -28,12 +28,12 @@ func (i UploadInterceptor) InterceptMsg(cmd *Command) {
}
}
-func handleUploadedFile(file multipart.File, handler *multipart.FileHeader, authUser *database.User) (*database.Upload, error) {
+func handleUploadedFile(db *database.DkfDB, file multipart.File, handler *multipart.FileHeader, authUser *database.User) (*database.Upload, error) {
defer file.Close()
if !authUser.CanUpload() {
return nil, hutils.AccountTooYoungErr
}
- userSizeUploaded := database.GetUserTotalUploadSize(authUser.ID)
+ userSizeUploaded := db.GetUserTotalUploadSize(authUser.ID)
if handler.Size+userSizeUploaded > config.MaxUserTotalUploadSize {
return nil, fmt.Errorf("user upload limit reached (%s)", humanize.Bytes(config.MaxUserTotalUploadSize))
}
@@ -66,7 +66,7 @@ func handleUploadedFile(file multipart.File, handler *multipart.FileHeader, auth
}
// Uploaded files are encrypted on disk
- upload, err := database.CreateEncryptedUploadWithSize(origFileName, fileBytes, authUser.ID, handler.Size)
+ upload, err := db.CreateEncryptedUploadWithSize(origFileName, fileBytes, authUser.ID, handler.Size)
if err != nil {
logrus.Error(err)
return nil, err
diff --git a/pkg/web/handlers/api/v1/werewolf.go b/pkg/web/handlers/api/v1/werewolf.go
@@ -37,6 +37,7 @@ const (
var ErrInvalidPlayerName = errors.New("unknown player name, please send a valid name")
type Werewolf struct {
+ db *database.DkfDB
ctx context.Context
cancel context.CancelFunc
readyCh chan bool
@@ -105,7 +106,7 @@ func (b *Werewolf) InterceptPreGameMsg(cmd *Command) {
b.cancel()
time.Sleep(time.Second)
utils.SGo(func() {
- b.StartGame()
+ b.StartGame(cmd.db)
})
cmd.err = ErrRedirect
return
@@ -212,7 +213,7 @@ func (b *Werewolf) InterceptMsg(cmd *Command) {
cmd.err = ErrRedirect
return
} else if cmd.authUser.IsModerator() && cmd.message == "/clear" {
- _ = database.DeleteChatRoomMessages(b.roomID)
+ _ = cmd.db.DeleteChatRoomMessages(b.roomID)
b.Narrate(tuto, nil, nil)
cmd.err = ErrRedirect
return
@@ -328,12 +329,12 @@ func (b *Werewolf) isValidPlayerName(name string) bool {
// Narrate register a chat message on behalf of the narrator user
func (b *Werewolf) Narrate(msg string, toUserID *database.UserID, groupID *database.GroupID) {
- html, _ := ProcessRawMessage(msg, "", b.narratorID, b.roomID, nil)
+ html, _ := ProcessRawMessage(b.db, msg, "", b.narratorID, b.roomID, nil)
b.NarrateRaw(html, toUserID, groupID)
}
func (b *Werewolf) NarrateRaw(msg string, toUserID *database.UserID, groupID *database.GroupID) {
- _, _ = database.CreateOrEditMessage(nil, msg, msg, "", b.roomID, b.narratorID, toUserID, nil, groupID, false, false, false)
+ _, _ = b.db.CreateOrEditMessage(nil, msg, msg, "", b.roomID, b.narratorID, toUserID, nil, groupID, false, false, false)
}
// Display roles assigned at beginning of the Game
@@ -345,7 +346,7 @@ func (b *Werewolf) displayRoles() {
b.Narrate(msg, nil, nil)
}
-func (b *Werewolf) StartGame() {
+func (b *Werewolf) StartGame(db *database.DkfDB) {
defer func() {
b.displayRoles()
b.reset()
@@ -360,7 +361,7 @@ func (b *Werewolf) StartGame() {
for idx, player := range playersArr {
if idx == 0 {
b.werewolfSet.Insert(player.UserID)
- _, _ = database.AddUserToRoomGroup(b.roomID, b.werewolfGroupID, player.UserID)
+ _, _ = db.AddUserToRoomGroup(b.roomID, b.werewolfGroupID, player.UserID)
player.Role = WerewolfRole
werewolfMsg := "During the day you seem to be a regular Townsperson.\n" +
"However, you’ve been kissed by the Night and transform into a Werewolf when the sun sets.\n" +
@@ -420,7 +421,7 @@ func (b *Werewolf) StartGame() {
"There you find @"+playerNameToKill+"’s mangled remains by the Great Oak.\n"+
"Curiously, there are deep claw marks in the bark of the surrounding trees.\n"+
"It looks like @"+playerNameToKill+" put up a fight.", nil, nil)
- b.kill(playerNameToKill)
+ b.kill(db, playerNameToKill)
}
b.Narrate("Players still alive: "+b.alivePlayersStr(), nil, nil)
@@ -450,7 +451,7 @@ func (b *Werewolf) StartGame() {
b.Narrate("Townspeople do not want to execute anyone", nil, nil)
} else {
b.Narrate("Townspeople execute @"+killName, nil, nil)
- b.kill(killName)
+ b.kill(db, killName)
}
b.Narrate("Players still alive: "+b.alivePlayersStr(), nil, nil)
@@ -478,7 +479,7 @@ func (b *Werewolf) alivePlayersStr() (out string) {
}
// Kill a player
-func (b *Werewolf) kill(playerName string) {
+func (b *Werewolf) kill(db *database.DkfDB, playerName string) {
player, found := b.playersAlive[playerName]
if !found {
return
@@ -487,7 +488,7 @@ func (b *Werewolf) kill(playerName string) {
switch player.Role {
case WerewolfRole:
b.werewolfSet.Remove(player.UserID)
- _ = database.RmUserFromRoomGroup(b.roomID, b.werewolfGroupID, player.UserID)
+ _ = db.RmUserFromRoomGroup(b.roomID, b.werewolfGroupID, player.UserID)
case TownspeopleRole:
b.townspersonSet.Remove(player.UserID)
case HealerRole:
@@ -497,7 +498,7 @@ func (b *Werewolf) kill(playerName string) {
b.townspersonSet.Remove(player.UserID)
b.seerID = nil
}
- _, _ = database.AddUserToRoomGroup(b.roomID, b.deadGroupID, player.UserID)
+ _, _ = db.AddUserToRoomGroup(b.roomID, b.deadGroupID, player.UserID)
}
// Return the name of the player name that receive the most vote
@@ -637,15 +638,15 @@ func (b *Werewolf) LockGroups() {
}
func (b *Werewolf) LockGroup(groupName string) {
- group, _ := database.GetRoomGroupByName(b.roomID, groupName)
+ group, _ := b.db.GetRoomGroupByName(b.roomID, groupName)
group.Locked = true
- group.DoSave()
+ group.DoSave(b.db)
}
func (b *Werewolf) UnlockGroup(groupName string) {
- group, _ := database.GetRoomGroupByName(b.roomID, groupName)
+ group, _ := b.db.GetRoomGroupByName(b.roomID, groupName)
group.Locked = false
- group.DoSave()
+ group.DoSave(b.db)
}
type Player struct {
@@ -668,27 +669,28 @@ func (b *Werewolf) reset() {
b.healerCh = make(chan string)
b.votesCh = make(chan string)
b.readyCh = make(chan bool)
- _ = database.ClearRoomGroup(b.roomID, b.werewolfGroupID)
- _ = database.ClearRoomGroup(b.roomID, b.spectatorGroupID)
- _ = database.ClearRoomGroup(b.roomID, b.deadGroupID)
+ _ = b.db.ClearRoomGroup(b.roomID, b.werewolfGroupID)
+ _ = b.db.ClearRoomGroup(b.roomID, b.spectatorGroupID)
+ _ = b.db.ClearRoomGroup(b.roomID, b.deadGroupID)
}
-func NewWerewolf() *Werewolf {
+func NewWerewolf(db *database.DkfDB) *Werewolf {
// Prepare room
- room, err := database.GetChatRoomByName("werewolf")
+ room, err := db.GetChatRoomByName("werewolf")
if err != nil {
logrus.Error("#werewolf room not found")
return nil
}
- zeroUser, _ := database.GetUserByUsername(config.NullUsername)
- _ = database.DeleteChatRoomGroups(room.ID)
- werewolfGroup, _ := database.CreateChatRoomGroup(room.ID, "werewolf", "#ffffff")
+ zeroUser, _ := db.GetUserByUsername(config.NullUsername)
+ _ = db.DeleteChatRoomGroups(room.ID)
+ werewolfGroup, _ := db.CreateChatRoomGroup(room.ID, "werewolf", "#ffffff")
werewolfGroup.Locked = true
- werewolfGroup.DoSave()
- spectatorGroup, _ := database.CreateChatRoomGroup(room.ID, "spectator", "#ffffff")
- deadGroup, _ := database.CreateChatRoomGroup(room.ID, "dead", "#ffffff")
+ werewolfGroup.DoSave(db)
+ spectatorGroup, _ := db.CreateChatRoomGroup(room.ID, "spectator", "#ffffff")
+ deadGroup, _ := db.CreateChatRoomGroup(room.ID, "dead", "#ffffff")
b := new(Werewolf)
+ b.db = db
b.werewolfGroupID = werewolfGroup.ID
b.spectatorGroupID = spectatorGroup.ID
b.deadGroupID = deadGroup.ID
diff --git a/pkg/web/handlers/chat.go b/pkg/web/handlers/chat.go
@@ -14,6 +14,7 @@ import (
func chatHandler(c echo.Context, redRoom bool) error {
const chatPasswordTmplName = "standalone.chat-password"
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data chatData
data.RedRoom = redRoom
preventRefresh := utils.DoParseBool(c.QueryParam("r"))
@@ -44,7 +45,7 @@ func chatHandler(c echo.Context, redRoom bool) error {
data.CaptchaID, data.CaptchaImg = captcha.New()
}
- room, err := database.GetChatRoomByName(getRoomName(c))
+ room, err := db.GetChatRoomByName(getRoomName(c))
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -59,13 +60,13 @@ func chatHandler(c echo.Context, redRoom bool) error {
data.TutoFrames = generateCssFrames(data.TutoSecs, nil, true)
if c.Request().Method == http.MethodGet {
authUser.ChatTutorialTime = time.Now()
- authUser.DoSave()
+ authUser.DoSave(db)
}
}
}
if c.Request().Method == http.MethodPost {
- return handlePost(c, data, authUser)
+ return handlePost(db, c, data, authUser)
}
// If you don't have access to the room (room is protected and user is nil or no cookie with the password)
@@ -77,7 +78,7 @@ func chatHandler(c echo.Context, redRoom bool) error {
return c.Render(http.StatusOK, chatPasswordTmplName, data)
}
- data.IsSubscribed = database.IsUserSubscribedToRoom(authUser.ID, room.ID)
+ data.IsSubscribed = db.IsUserSubscribedToRoom(authUser.ID, room.ID)
data.IsOfficialRoom = room.IsOfficialRoom()
return c.Render(http.StatusOK, "chat", data)
}
@@ -90,25 +91,25 @@ func getRoomName(c echo.Context) string {
return roomName
}
-func handlePost(c echo.Context, data chatData, authUser *database.User) error {
+func handlePost(db *database.DkfDB, c echo.Context, data chatData, authUser *database.User) error {
formName := c.Request().PostFormValue("formName")
switch formName {
case "logout":
return handleLogoutPost(c, data.Room)
case "toggle-hb":
- return handleToggleHBPost(c, authUser)
+ return handleToggleHBPost(db, c, authUser)
case "toggle-m":
- return handleToggleMPost(c, authUser)
+ return handleToggleMPost(db, c, authUser)
case "toggle-ignored":
- return handleToggleIgnoredPost(c, authUser)
+ return handleToggleIgnoredPost(db, c, authUser)
case "afk":
- return handleAfkPost(c, authUser)
+ return handleAfkPost(db, c, authUser)
case "update-read-marker":
- return handleUpdateReadMarkerPost(c, data.Room, authUser)
+ return handleUpdateReadMarkerPost(db, c, data.Room, authUser)
case "tutorialP1", "tutorialP2", "tutorialP3":
- return handleTutorialPost(c, data, authUser)
+ return handleTutorialPost(db, c, data, authUser)
case "chat-password":
- return handleChatPasswordPost(c, data, authUser)
+ return handleChatPasswordPost(db, c, data, authUser)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -119,49 +120,49 @@ func handleLogoutPost(c echo.Context, room database.ChatRoom) error {
return c.Redirect(http.StatusFound, "/chat")
}
-func handleToggleHBPost(c echo.Context, authUser *database.User) error {
+func handleToggleHBPost(db *database.DkfDB, c echo.Context, authUser *database.User) error {
if authUser.CanSeeHB() {
authUser.DisplayHellbanned = !authUser.DisplayHellbanned
- authUser.DoSave()
+ authUser.DoSave(db)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
}
-func handleToggleMPost(c echo.Context, authUser *database.User) error {
+func handleToggleMPost(db *database.DkfDB, c echo.Context, authUser *database.User) error {
if authUser.IsModerator() {
authUser.DisplayModerators = !authUser.DisplayModerators
- authUser.DoSave()
+ authUser.DoSave(db)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
}
-func handleToggleIgnoredPost(c echo.Context, authUser *database.User) error {
+func handleToggleIgnoredPost(db *database.DkfDB, c echo.Context, authUser *database.User) error {
authUser.DisplayIgnored = !authUser.DisplayIgnored
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
-func handleAfkPost(c echo.Context, authUser *database.User) error {
+func handleAfkPost(db *database.DkfDB, c echo.Context, authUser *database.User) error {
authUser.AFK = !authUser.AFK
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
-func handleUpdateReadMarkerPost(c echo.Context, room database.ChatRoom, authUser *database.User) error {
- database.UpdateChatReadMarker(authUser.ID, room.ID)
+func handleUpdateReadMarkerPost(db *database.DkfDB, c echo.Context, room database.ChatRoom, authUser *database.User) error {
+ db.UpdateChatReadMarker(authUser.ID, room.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
-func handleTutorialPost(c echo.Context, data chatData, authUser *database.User) error {
+func handleTutorialPost(db *database.DkfDB, c echo.Context, data chatData, authUser *database.User) error {
if authUser.ChatTutorial < 3 && time.Since(authUser.ChatTutorialTime) >= time.Duration(data.TutoSecs)*time.Second {
authUser.ChatTutorial++
- authUser.DoSave()
+ authUser.DoSave(db)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
}
// Handle POST requests for chat-password, when someone tries to authenticate in a protected room providing a password.
-func handleChatPasswordPost(c echo.Context, data chatData, authUser *database.User) error {
+func handleChatPasswordPost(db *database.DkfDB, c echo.Context, data chatData, authUser *database.User) error {
const chatPasswordTmplName = "standalone.chat-password"
data.RoomPassword = c.Request().PostFormValue("password")
@@ -175,7 +176,7 @@ func handleChatPasswordPost(c echo.Context, data chatData, authUser *database.Us
return c.Render(http.StatusOK, chatPasswordTmplName, data)
}
- if err := database.CanUseUsername(data.GuestUsername, false); err != nil {
+ if err := db.CanUseUsername(data.GuestUsername, false); err != nil {
data.ErrGuestUsername = err.Error()
return c.Render(http.StatusOK, chatPasswordTmplName, data)
}
@@ -193,13 +194,13 @@ func handleChatPasswordPost(c echo.Context, data chatData, authUser *database.Us
// TODO: maybe add "_guest" suffix to guest accounts?
if authUser == nil {
password := utils.GenerateToken32()
- newUser, errs := database.CreateGuestUser(data.GuestUsername, password)
+ newUser, errs := db.CreateGuestUser(data.GuestUsername, password)
if errs.HasError() {
data.ErrGuestUsername = errs.Username
return c.Render(http.StatusOK, chatPasswordTmplName, data)
}
- session := database.DoCreateSession(newUser.ID, c.Request().UserAgent())
+ session := db.DoCreateSession(newUser.ID, c.Request().UserAgent())
c.SetCookie(createSessionCookie(session.Token))
}
diff --git a/pkg/web/handlers/club.go b/pkg/web/handlers/club.go
@@ -10,14 +10,16 @@ import (
func ClubHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data clubData
data.ActiveTab = "home"
- data.ForumThreads, _ = database.GetClubForumThreads(authUser.ID)
+ data.ForumThreads, _ = db.GetClubForumThreads(authUser.ID)
return c.Render(http.StatusOK, "club.home", data)
}
func ClubNewThreadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data clubNewThreadData
data.ActiveTab = "home"
@@ -33,9 +35,9 @@ func ClubNewThreadHandler(c echo.Context) error {
return c.Render(http.StatusOK, "club.new-thread", data)
}
thread := database.MakeForumThread(data.ThreadName, authUser.ID, 0)
- database.DB.Create(&thread)
+ db.DB().Create(&thread)
message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID)
- database.DB.Create(&message)
+ db.DB().Create(&message)
return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID)))
}
@@ -44,8 +46,9 @@ func ClubNewThreadHandler(c echo.Context) error {
func ClubThreadReplyHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID")))
- thread, err := database.GetForumThread(threadID)
+ thread, err := db.GetForumThread(threadID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -60,7 +63,7 @@ func ClubThreadReplyHandler(c echo.Context) error {
return c.Render(http.StatusOK, "club.new-thread", data)
}
message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID)
- database.DB.Create(&message)
+ db.DB().Create(&message)
return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID)))
}
@@ -69,13 +72,14 @@ func ClubThreadReplyHandler(c echo.Context) error {
func ClubThreadEditMessageHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID")))
messageID := database.ForumMessageID(utils.DoParseInt64(c.Param("messageID")))
- thread, err := database.GetForumThread(threadID)
+ thread, err := db.GetForumThread(threadID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
- msg, err := database.GetForumMessage(messageID)
+ msg, err := db.GetForumMessage(messageID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -95,7 +99,7 @@ func ClubThreadEditMessageHandler(c echo.Context) error {
return c.Render(http.StatusOK, "club.new-thread", data)
}
msg.Message = data.Message
- msg.DoSave()
+ msg.DoSave(db)
return c.Redirect(http.StatusFound, "/club/threads/"+utils.FormatInt64(int64(thread.ID)))
}
@@ -103,26 +107,28 @@ func ClubThreadEditMessageHandler(c echo.Context) error {
}
func ClubMembersHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data clubMembersData
data.ActiveTab = "members"
- data.Members, _ = database.GetClubMembers()
+ data.Members, _ = db.GetClubMembers()
return c.Render(http.StatusOK, "club.members", data)
}
func ClubThreadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
threadID := database.ForumThreadID(utils.DoParseInt64(c.Param("threadID")))
- thread, err := database.GetForumThread(threadID)
+ thread, err := db.GetForumThread(threadID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
var data clubThreadData
data.ActiveTab = "home"
data.Thread = thread
- data.Messages, _ = database.GetThreadMessages(threadID)
+ data.Messages, _ = db.GetThreadMessages(threadID)
// Update read record
- database.UpdateForumReadRecord(authUser.ID, threadID)
+ db.UpdateForumReadRecord(authUser.ID, threadID)
return c.Render(http.StatusOK, "club.thread", data)
}
diff --git a/pkg/web/handlers/handlers.go b/pkg/web/handlers/handlers.go
@@ -64,6 +64,7 @@ import (
func firstUseHandler(c echo.Context) error {
user := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data firstUseData
if user != nil {
return c.Redirect(http.StatusFound, "/")
@@ -80,17 +81,17 @@ func firstUseHandler(c echo.Context) error {
data.Username = c.Request().PostFormValue("username")
data.Password = c.Request().PostFormValue("password")
data.RePassword = c.Request().PostFormValue("repassword")
- newUser, errs := database.CreateFirstUser(data.Username, data.Password, data.RePassword)
+ newUser, errs := db.CreateFirstUser(data.Username, data.Password, data.RePassword)
data.Errors = errs
if errs.HasError() {
return c.Render(http.StatusOK, "standalone.first-use", data)
}
- _, errs = database.CreateZeroUser()
+ _, errs = db.CreateZeroUser()
config.IsFirstUse.SetFalse()
- session := database.DoCreateSession(newUser.ID, c.Request().UserAgent())
+ session := db.DoCreateSession(newUser.ID, c.Request().UserAgent())
c.SetCookie(createSessionCookie(session.Token))
return c.Redirect(http.StatusFound, "/")
@@ -276,6 +277,7 @@ func LoginAttackHandler(c echo.Context) error {
func loginHandler(c echo.Context) error {
formName := c.Request().PostFormValue("formName")
+ db := c.Get("database").(*database.DkfDB)
if formName == "" {
var data loginData
data.Autofocus = 0
@@ -287,7 +289,7 @@ func loginHandler(c echo.Context) error {
actualLogin := func(username, password string, captchaSolved bool) error {
username = strings.TrimSpace(username)
- user, err := database.GetVerifiedUserByUsername(username)
+ user, err := db.GetVerifiedUserByUsername(username)
if err != nil {
time.Sleep(utils.RandMs(50, 200))
data.Error = "Invalid username/password"
@@ -295,7 +297,7 @@ func loginHandler(c echo.Context) error {
}
user.LoginAttempts++
- user.DoSave()
+ user.DoSave(db)
if user.LoginAttempts > 4 && !captchaSolved {
data.CaptchaRequired = true
@@ -315,7 +317,7 @@ func loginHandler(c echo.Context) error {
}
}
- if !user.CheckPassword(password) {
+ if !user.CheckPassword(db, password) {
data.Password = ""
data.Autofocus = 1
data.Error = "Invalid username/password"
@@ -390,16 +392,17 @@ func loginHandler(c echo.Context) error {
}
func completeLogin(c echo.Context, user database.User) error {
+ db := c.Get("database").(*database.DkfDB)
user.LoginAttempts = 0
- user.DoSave()
+ user.DoSave(db)
- for _, session := range database.GetActiveUserSessions(user.ID) {
+ for _, session := range db.GetActiveUserSessions(user.ID) {
msg := fmt.Sprintf(`New login`)
- database.CreateSessionNotification(msg, session.Token)
+ db.CreateSessionNotification(msg, session.Token)
}
- session := database.DoCreateSession(user.ID, c.Request().UserAgent())
- database.CreateSecurityLog(user.ID, database.LoginSecurityLog)
+ session := db.DoCreateSession(user.ID, c.Request().UserAgent())
+ db.CreateSecurityLog(user.ID, database.LoginSecurityLog)
c.SetCookie(createSessionCookie(session.Token))
redirectURL := "/"
@@ -424,12 +427,13 @@ func LoginCompletedHandler(c echo.Context) error {
// SessionsGpgTwoFactorHandler ...
func SessionsGpgTwoFactorHandler(c echo.Context, step1 bool, token string) error {
+ db := c.Get("database").(*database.DkfDB)
item, found := partialAuthCache.Get(token)
if !found || item.Step != PgpStep {
return c.Redirect(http.StatusFound, "/")
}
- user, err := database.GetUserByID(item.UserID)
+ user, err := db.GetUserByID(item.UserID)
if err != nil {
logrus.Errorf("failed to get user %d", item.UserID)
return c.Redirect(http.StatusFound, "/")
@@ -472,12 +476,13 @@ func SessionsGpgTwoFactorHandler(c echo.Context, step1 bool, token string) error
// SessionsGpgSignTwoFactorHandler ...
func SessionsGpgSignTwoFactorHandler(c echo.Context, step1 bool, token string) error {
+ db := c.Get("database").(*database.DkfDB)
item, found := partialAuthCache.Get(token)
if !found || item.Step != PgpSignStep {
return c.Redirect(http.StatusFound, "/")
}
- user, err := database.GetUserByID(item.UserID)
+ user, err := db.GetUserByID(item.UserID)
if err != nil {
logrus.Errorf("failed to get user %d", item.UserID)
return c.Redirect(http.StatusFound, "/")
@@ -516,6 +521,7 @@ func SessionsGpgSignTwoFactorHandler(c echo.Context, step1 bool, token string) e
// SessionsTwoFactorHandler ...
func SessionsTwoFactorHandler(c echo.Context, step1 bool, token string) error {
+ db := c.Get("database").(*database.DkfDB)
item, found := partialAuthCache.Get(token)
if !found || item.Step != TwoFactorStep {
return c.Redirect(http.StatusFound, "/")
@@ -525,7 +531,7 @@ func SessionsTwoFactorHandler(c echo.Context, step1 bool, token string) error {
data.Token = token
if !step1 {
code := c.Request().PostFormValue("code")
- user, err := database.GetUserByID(item.UserID)
+ user, err := db.GetUserByID(item.UserID)
if err != nil {
logrus.Errorf("failed to get user %d", item.UserID)
return c.Redirect(http.StatusFound, "/")
@@ -545,6 +551,7 @@ func SessionsTwoFactorHandler(c echo.Context, step1 bool, token string) error {
// SessionsTwoFactorRecoveryHandler ...
func SessionsTwoFactorRecoveryHandler(c echo.Context, token string) error {
+ db := c.Get("database").(*database.DkfDB)
item, found := partialAuthCache.Get(token)
if !found {
return c.Redirect(http.StatusFound, "/")
@@ -554,7 +561,7 @@ func SessionsTwoFactorRecoveryHandler(c echo.Context, token string) error {
data.Token = token
recoveryCode := c.Request().PostFormValue("code")
if recoveryCode != "" {
- user, err := database.GetUserByID(item.UserID)
+ user, err := db.GetUserByID(item.UserID)
if err != nil {
logrus.Errorf("failed to get user %d", item.UserID)
return c.Redirect(http.StatusFound, "/")
@@ -574,21 +581,22 @@ func SessionsTwoFactorRecoveryHandler(c echo.Context, token string) error {
// LogoutHandler for logout route
func LogoutHandler(ctx echo.Context) error {
authUser := ctx.Get("authUser").(*database.User)
+ db := ctx.Get("database").(*database.DkfDB)
c, _ := ctx.Cookie(hutils.AuthCookieName)
- if err := database.DeleteSessionByToken(c.Value); err != nil {
- logrus.Error("Failed to remove session from DB : ", err)
+ if err := db.DeleteSessionByToken(c.Value); err != nil {
+ logrus.Error("Failed to remove session from db : ", err)
}
if authUser.TerminateAllSessionsOnLogout {
// Delete active user sessions
- if err := database.DeleteUserSessions(authUser.ID); err != nil {
+ if err := db.DeleteUserSessions(authUser.ID); err != nil {
logrus.Error("failed to delete user sessions : ", err)
}
}
- database.CreateSecurityLog(authUser.ID, database.LogoutSecurityLog)
+ db.CreateSecurityLog(authUser.ID, database.LogoutSecurityLog)
ctx.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName))
managers.ActiveUsers.RemoveUser(authUser.ID)
if authUser.Temp {
- if err := database.DB.Where("id = ?", authUser.ID).Unscoped().Delete(&database.User{}).Error; err != nil {
+ if err := db.DB().Where("id = ?", authUser.ID).Unscoped().Delete(&database.User{}).Error; err != nil {
logrus.Error(err)
}
}
@@ -621,12 +629,13 @@ func SignupAttackHandler(c echo.Context) error {
// SignupInvitationHandler ...
func SignupInvitationHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
invitationToken := c.Param("invitationToken")
invitationTokenQuery := c.QueryParam("invitationToken")
if invitationTokenQuery != "" {
invitationToken = invitationTokenQuery
}
- if _, err := database.GetUnusedInvitationByToken(invitationToken); err != nil {
+ if _, err := db.GetUnusedInvitationByToken(invitationToken); err != nil {
return c.Redirect(http.StatusFound, "/")
}
return waitPageWrapper(c, signupHandler, hutils.WaitCookieName)
@@ -699,6 +708,7 @@ type SignupInfo struct {
func SignalCss1(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
data := c.Param("data")
data = strings.TrimRight(data, ".png")
data = strings.TrimRight(data, ".ttf")
@@ -716,7 +726,7 @@ func SignalCss1(c echo.Context) error {
info.UpdatedAt = time.Now().Format(time.RFC3339)
signupInfoEnc, _ := json.Marshal(info)
authUser.SignupMetadata = string(signupInfoEnc)
- authUser.DoSave()
+ authUser.DoSave(db)
return c.NoContent(http.StatusOK)
}
@@ -821,6 +831,7 @@ func waitPageWrapper(c echo.Context, clb echo.HandlerFunc, cookieName string) er
// Not all requests to the signup endpoint will get the captcha at the same time,
// so you cannot just refresh the page until you get a captcha that is easier to crack.
func signupHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
start := c.Get("start").(int64)
signupToken := c.Get("signupToken").(string)
var data signupData
@@ -877,7 +888,7 @@ func signupHandler(c echo.Context) error {
signupInfoEnc, _ := json.Marshal(signupInfo)
registrationDuration := time.Now().UnixMilli() - start
- newUser, errs := database.CreateUser(data.Username, data.Password, data.RePassword, registrationDuration, string(signupInfoEnc))
+ newUser, errs := db.CreateUser(data.Username, data.Password, data.RePassword, registrationDuration, string(signupInfoEnc))
if errs.HasError() {
data.Errors = errs
return c.Render(http.StatusOK, "standalone.signup", data)
@@ -886,29 +897,29 @@ func signupHandler(c echo.Context) error {
// Fuck with hellbanned users. New account also hellbanned
if hasHBCookie {
newUser.IsHellbanned = true
- newUser.DoSave()
+ newUser.DoSave(db)
}
invitationToken := c.Param("invitationToken")
if invitationToken != "" {
- if invitation, err := database.GetUnusedInvitationByToken(invitationToken); err == nil {
+ if invitation, err := db.GetUnusedInvitationByToken(invitationToken); err == nil {
invitation.InviteeUserID = newUser.ID
- invitation.DoSave()
+ invitation.DoSave(db)
}
}
// If more than 10 users were created in the past minute, auto disable signup for the website
- if database.GetRecentUsersCount() > 10 {
- settings := database.GetSettings()
+ if db.GetRecentUsersCount() > 10 {
+ settings := db.GetSettings()
settings.SignupEnabled = false
- settings.DoSave()
+ settings.DoSave(db)
config.SignupEnabled.SetFalse()
- if userNull, err := database.GetUserByUsername(config.NullUsername); err == nil {
- database.NewAudit(userNull, fmt.Sprintf("auto turn off signup"))
+ if userNull, err := db.GetUserByUsername(config.NullUsername); err == nil {
+ db.NewAudit(userNull, fmt.Sprintf("auto turn off signup"))
// Display message in chat
txt := fmt.Sprintf("auto turn off registrations")
- if err := database.CreateSysMsg(txt, txt, "", config.GeneralRoomID, userNull.ID); err != nil {
+ if err := db.CreateSysMsg(txt, txt, "", config.GeneralRoomID, userNull.ID); err != nil {
logrus.Error(err)
}
}
@@ -977,6 +988,7 @@ func ForgotPasswordHandler(c echo.Context) error {
}
func forgotPasswordHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data forgotPasswordData
const (
usernameCaptchaStep = iota + 1
@@ -1012,7 +1024,7 @@ func forgotPasswordHandler(c echo.Context) error {
data.ErrCaptcha = err.Error()
return c.Render(http.StatusOK, forgotPasswordTmplName, data)
}
- user, err := database.GetUserByUsername(data.Username)
+ user, err := db.GetUserByUsername(data.Username)
if err != nil {
data.UsernameError = "no such user"
return c.Render(http.StatusOK, forgotPasswordTmplName, data)
@@ -1098,7 +1110,7 @@ func forgotPasswordHandler(c echo.Context) error {
return c.Redirect(http.StatusFound, "/")
}
userID := item.UserID
- user, err := database.GetUserByID(userID)
+ user, err := db.GetUserByID(userID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -1108,16 +1120,16 @@ func forgotPasswordHandler(c echo.Context) error {
data.NewPassword = newPassword
data.RePassword = rePassword
- hashedPassword, err := database.NewPasswordValidator(newPassword).CompareWith(rePassword).Hash()
+ hashedPassword, err := database.NewPasswordValidator(db, newPassword).CompareWith(rePassword).Hash()
if err != nil {
data.ErrorNewPassword = err.Error()
return c.Render(http.StatusOK, forgotPasswordTmplName, data)
}
- if err := user.ChangePassword(hashedPassword); err != nil {
+ if err := user.ChangePassword(db, hashedPassword); err != nil {
logrus.Error(err)
}
- database.CreateSecurityLog(user.ID, database.PasswordRecoverySecurityLog)
+ db.CreateSecurityLog(user.ID, database.PasswordRecoverySecurityLog)
partialRecoveryCache.Delete(token)
c.SetCookie(hutils.DeleteCookie(hutils.WaitCookieName))
@@ -1134,12 +1146,13 @@ func NewsHandler(c echo.Context) error {
}
func ForumSearchHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data forumSearchData
data.Search = c.QueryParam("search")
data.AuthorFilter = c.QueryParam("author")
if data.AuthorFilter != "" {
- if err := database.DB.Raw(`select
+ if err := db.DB().Raw(`select
t.*,
u.username as author,
u.chat_color as author_chat_color,
@@ -1164,7 +1177,7 @@ where u.username = ? and t.is_club = 0 order by id desc limit 100`, data.AuthorF
return c.Render(http.StatusOK, "forum-search", data)
}
- if err := database.DB.Raw(`select m.uuid, snippet(fts5_forum_messages,-1, '[', ']', '...', 10) as snippet, t.uuid as thread_uuid, t.name as thread_name,
+ if err := db.DB().Raw(`select m.uuid, snippet(fts5_forum_messages,-1, '[', ']', '...', 10) as snippet, t.uuid as thread_uuid, t.name as thread_name,
u.username as author,
u.chat_color as author_chat_color,
u.chat_font as author_chat_font,
@@ -1179,7 +1192,7 @@ where fts5_forum_messages match ? and t.is_club = 0 order by rank limit 100`, da
logrus.Error(err)
}
- if err := database.DB.Raw(`select
+ if err := db.DB().Raw(`select
t.*,
u.username as author,
u.chat_color as author_chat_color,
@@ -1207,23 +1220,24 @@ where fts5_forum_threads match ? and t.is_club = 0 order by rank limit 100`, dat
func LinksHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data linksData
- data.Categories, _ = database.GetCategories()
+ data.Categories, _ = db.GetCategories()
data.Search = c.QueryParam("search")
filterCategory := c.QueryParam("category")
if filterCategory != "" {
if filterCategory == "uncategorized" {
- database.DB.Raw(`SELECT l.*
+ db.DB().Raw(`SELECT l.*
FROM links l
LEFT JOIN links_categories_links cl ON cl.link_id = l.id
WHERE cl.link_id IS NULL AND l.deleted_at IS NULL
ORDER BY l.title COLLATE NOCASE ASC`).Scan(&data.Links)
data.LinksCount = int64(len(data.Links))
} else {
- database.DB.Raw(`SELECT l.*
+ db.DB().Raw(`SELECT l.*
FROM links_categories_links cl
INNER JOIN links l ON l.id = cl.link_id
WHERE cl.category_id = (SELECT id FROM links_categories WHERE name = ?) AND l.deleted_at IS NULL
@@ -1235,7 +1249,7 @@ ORDER BY l.title COLLATE NOCASE ASC`, filterCategory).Scan(&data.Links)
if searchedURL, err := url.Parse(data.Search); err == nil {
h := searchedURL.Scheme + "://" + searchedURL.Hostname()
var l database.Link
- query := database.DB
+ query := db.DB()
if authUser.IsModerator() {
query = query.Unscoped()
}
@@ -1245,7 +1259,7 @@ ORDER BY l.title COLLATE NOCASE ASC`, filterCategory).Scan(&data.Links)
data.LinksCount = int64(len(data.Links))
}
} else {
- if err := database.DB.Raw(`select l.id, l.uuid, l.url, l.title, l.description
+ if err := db.DB().Raw(`select l.id, l.uuid, l.url, l.title, l.description
from fts5_links l
where fts5_links match ?
ORDER BY rank, l.title COLLATE NOCASE ASC
@@ -1255,7 +1269,7 @@ LIMIT 100`, data.Search).Scan(&data.Links).Error; err != nil {
data.LinksCount = int64(len(data.Links))
}
} else {
- if err := database.DB.Table("links").
+ if err := db.DB().Table("links").
Scopes(func(query *gorm.DB) *gorm.DB {
data.CurrentPage, data.MaxPage, data.LinksCount, query = NewPaginator().Paginate(c, query)
return query
@@ -1276,7 +1290,7 @@ LIMIT 100`, data.Search).Scan(&data.Links).Error; err != nil {
}
// Get all mirrors for all links that we have
var mirrors []database.LinksMirror
- database.DB.Raw(`select * from links_mirrors where link_id in (?)`, linksIDs).Scan(&mirrors)
+ db.DB().Raw(`select * from links_mirrors where link_id in (?)`, linksIDs).Scan(&mirrors)
// Put mirrors in links
for _, m := range mirrors {
if l, ok := linksCache[m.LinkID]; ok {
@@ -1289,6 +1303,7 @@ LIMIT 100`, data.Search).Scan(&data.Links).Error; err != nil {
func LinksDownloadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
fileName := "dkf_links.csv"
// Captcha for bigger files
@@ -1307,11 +1322,11 @@ func LinksDownloadHandler(c echo.Context) error {
}
// Keep track of user downloads
- if _, err := database.CreateDownload(authUser.ID, fileName); err != nil {
+ if _, err := db.CreateDownload(authUser.ID, fileName); err != nil {
logrus.Error(err)
}
- links, _ := database.GetLinks()
+ links, _ := db.GetLinks()
by := make([]byte, 0)
buf := bytes.NewBuffer(by)
w := csv.NewWriter(buf)
@@ -1326,9 +1341,10 @@ func LinksDownloadHandler(c echo.Context) error {
func LinkPgpDownloadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
pgpID := utils.DoParseInt64(c.Param("linkPgpID"))
- linkPgp, err := database.GetLinkPgpByID(pgpID)
+ linkPgp, err := db.GetLinkPgpByID(pgpID)
if err != nil {
return c.NoContent(http.StatusNotFound)
}
@@ -1336,7 +1352,7 @@ func LinkPgpDownloadHandler(c echo.Context) error {
fileName := linkPgp.Title + ".asc"
// Keep track of user downloads
- if _, err := database.CreateDownload(authUser.ID, fileName); err != nil {
+ if _, err := db.CreateDownload(authUser.ID, fileName); err != nil {
logrus.Error(err)
}
@@ -1349,20 +1365,21 @@ func LinksClaimInstructionsHandler(c echo.Context) error {
}
func LinkHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
shorthand := c.Param("shorthand")
linkUUID := c.Param("linkUUID")
var data linkData
var err error
if shorthand != "" {
- data.Link, err = database.GetLinkByShorthand(shorthand)
+ data.Link, err = db.GetLinkByShorthand(shorthand)
} else {
- data.Link, err = database.GetLinkByUUID(linkUUID)
+ data.Link, err = db.GetLinkByUUID(linkUUID)
}
if err != nil {
return c.Redirect(http.StatusFound, "/links")
}
- data.PgpKeys, _ = database.GetLinkPgps(data.Link.ID)
- data.Mirrors, _ = database.GetLinkMirrors(data.Link.ID)
+ data.PgpKeys, _ = db.GetLinkPgps(data.Link.ID)
+ data.Mirrors, _ = db.GetLinkMirrors(data.Link.ID)
return c.Render(http.StatusOK, "link", data)
}
@@ -1372,6 +1389,7 @@ type CsvLink struct {
}
func LinksUploadHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data linksUploadData
if c.Request().Method == http.MethodPost {
data.CsvStr = c.Request().PostFormValue("csv")
@@ -1404,7 +1422,7 @@ func LinksUploadHandler(c echo.Context) error {
return c.Render(http.StatusOK, "links-upload", data)
}
for _, csvLink := range csvLinks {
- _, err := database.CreateLink(csvLink.URL, csvLink.Title, "", "")
+ _, err := db.CreateLink(csvLink.URL, csvLink.Title, "", "")
if err != nil {
logrus.Error(err)
}
@@ -1415,18 +1433,20 @@ func LinksUploadHandler(c echo.Context) error {
}
func LinksReindexHandler(c echo.Context) error {
- if err := database.DB.Exec(`INSERT INTO fts5_links(fts5_links) VALUES('rebuild')`).Error; err != nil {
+ db := c.Get("database").(*database.DkfDB)
+ if err := db.DB().Exec(`INSERT INTO fts5_links(fts5_links) VALUES('rebuild')`).Error; err != nil {
logrus.Error(err)
}
- database.DB.Exec(`delete from fts5_links where rowid in (select id from links where deleted_at is not null)`)
+ db.DB().Exec(`delete from fts5_links where rowid in (select id from links where deleted_at is not null)`)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func ForumReindexHandler(c echo.Context) error {
- if err := database.DB.Exec(`INSERT INTO fts5_forum_threads(fts5_forum_threads) VALUES('rebuild')`).Error; err != nil {
+ db := c.Get("database").(*database.DkfDB)
+ if err := db.DB().Exec(`INSERT INTO fts5_forum_threads(fts5_forum_threads) VALUES('rebuild')`).Error; err != nil {
logrus.Error(err)
}
- if err := database.DB.Exec(`INSERT INTO fts5_forum_messages(fts5_forum_messages) VALUES('rebuild')`).Error; err != nil {
+ if err := db.DB().Exec(`INSERT INTO fts5_forum_messages(fts5_forum_messages) VALUES('rebuild')`).Error; err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -1434,6 +1454,7 @@ func ForumReindexHandler(c echo.Context) error {
func NewLinkHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.IsModerator() {
return c.Redirect(http.StatusFound, "/")
}
@@ -1489,26 +1510,26 @@ func NewLinkHandler(c echo.Context) error {
var categories []database.LinksCategory
var tags []database.LinksTag
for _, categoryStr := range categoriesStr {
- category, _ := database.CreateLinksCategory(categoryStr)
+ category, _ := db.CreateLinksCategory(categoryStr)
categories = append(categories, category)
}
for _, tagStr := range tagsStr {
- tag, _ := database.CreateLinksTag(tagStr)
+ tag, _ := db.CreateLinksTag(tagStr)
tags = append(tags, tag)
}
- link, err := database.CreateLink(data.Link, data.Title, data.Description, data.Shorthand)
+ link, err := db.CreateLink(data.Link, data.Title, data.Description, data.Shorthand)
if err != nil {
logrus.Error(err)
data.ErrorLink = "failed to create link"
return c.Render(http.StatusOK, "new-link", data)
}
for _, category := range categories {
- _ = database.AddLinkCategory(link.ID, category.ID)
+ _ = db.AddLinkCategory(link.ID, category.ID)
}
for _, tag := range tags {
- _ = database.AddLinkTag(link.ID, tag.ID)
+ _ = db.AddLinkTag(link.ID, tag.ID)
}
- database.NewAudit(*authUser, fmt.Sprintf("create link %s", link.URL))
+ db.NewAudit(*authUser, fmt.Sprintf("create link %s", link.URL))
return c.Redirect(http.StatusFound, "/links")
}
return c.Render(http.StatusOK, "new-link", data)
@@ -1516,35 +1537,37 @@ func NewLinkHandler(c echo.Context) error {
func RestoreLinkHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.IsModerator() {
return c.Redirect(http.StatusFound, "/")
}
linkUUID := c.Param("linkUUID")
var link database.Link
- if err := database.DB.Unscoped().First(&link, "uuid = ?", linkUUID).Error; err != nil {
+ if err := db.DB().Unscoped().First(&link, "uuid = ?", linkUUID).Error; err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- database.NewAudit(*authUser, fmt.Sprintf("restore link %s", link.URL))
- database.DB.Unscoped().Model(&database.Link{}).Where("id = ?", link.ID).Update("deleted_at", nil)
+ db.NewAudit(*authUser, fmt.Sprintf("restore link %s", link.URL))
+ db.DB().Unscoped().Model(&database.Link{}).Where("id = ?", link.ID).Update("deleted_at", nil)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func EditLinkHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.IsModerator() {
return c.Redirect(http.StatusFound, "/")
}
linkUUID := c.Param("linkUUID")
- link, err := database.GetLinkByUUID(linkUUID)
+ link, err := db.GetLinkByUUID(linkUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
- out, _ := database.GetLinkCategories(link.ID)
+ out, _ := db.GetLinkCategories(link.ID)
categories := make([]string, 0)
for _, el := range out {
categories = append(categories, el.Name)
}
- out1, err := database.GetLinkTags(link.ID)
+ out1, err := db.GetLinkTags(link.ID)
tags := make([]string, 0)
for _, el := range out1 {
tags = append(tags, el.Name)
@@ -1559,15 +1582,15 @@ func EditLinkHandler(c echo.Context) error {
}
data.Categories = strings.Join(categories, ",")
data.Tags = strings.Join(tags, ",")
- data.Mirrors, _ = database.GetLinkMirrors(link.ID)
- data.LinkPgps, _ = database.GetLinkPgps(link.ID)
+ data.Mirrors, _ = db.GetLinkMirrors(link.ID)
+ data.LinkPgps, _ = db.GetLinkPgps(link.ID)
//data.Categories = link
if c.Request().Method == http.MethodPost {
formName := c.Request().PostFormValue("formName")
if formName == "createLink" {
- _ = database.DeleteLinkCategories(link.ID)
- _ = database.DeleteLinkTags(link.ID)
+ _ = db.DeleteLinkCategories(link.ID)
+ _ = db.DeleteLinkTags(link.ID)
// If link is signed, we can no longer edit the link URL
if link.SignedCertificate == "" {
@@ -1622,11 +1645,11 @@ func EditLinkHandler(c echo.Context) error {
var categories []database.LinksCategory
var tags []database.LinksTag
for _, categoryStr := range categoriesStr {
- category, _ := database.CreateLinksCategory(categoryStr)
+ category, _ := db.CreateLinksCategory(categoryStr)
categories = append(categories, category)
}
for _, tagStr := range tagsStr {
- tag, _ := database.CreateLinksTag(tagStr)
+ tag, _ := db.CreateLinksTag(tagStr)
tags = append(tags, tag)
}
link.URL = data.Link
@@ -1635,7 +1658,7 @@ func EditLinkHandler(c echo.Context) error {
if data.Shorthand != "" {
link.Shorthand = &data.Shorthand
}
- if err := database.DB.Save(&link).Error; err != nil {
+ if err := db.DB().Save(&link).Error; err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed: links.shorthand") {
data.ErrorShorthand = "shorthand already used"
} else {
@@ -1644,12 +1667,12 @@ func EditLinkHandler(c echo.Context) error {
return c.Render(http.StatusOK, "new-link", data)
}
for _, category := range categories {
- _ = database.AddLinkCategory(link.ID, category.ID)
+ _ = db.AddLinkCategory(link.ID, category.ID)
}
for _, tag := range tags {
- _ = database.AddLinkTag(link.ID, tag.ID)
+ _ = db.AddLinkTag(link.ID, tag.ID)
}
- database.NewAudit(*authUser, fmt.Sprintf("updated link %s", link.URL))
+ db.NewAudit(*authUser, fmt.Sprintf("updated link %s", link.URL))
return c.Redirect(http.StatusFound, "/links")
} else if formName == "createPgp" {
@@ -1660,10 +1683,10 @@ func EditLinkHandler(c echo.Context) error {
}
data.PGPDescription = c.Request().PostFormValue("pgp_description")
data.PGPPublicKey = c.Request().PostFormValue("pgp_public_key")
- if _, err = database.CreateLinkPgp(link.ID, data.PGPTitle, data.PGPDescription, data.PGPPublicKey); err != nil {
+ if _, err = db.CreateLinkPgp(link.ID, data.PGPTitle, data.PGPDescription, data.PGPPublicKey); err != nil {
logrus.Error(err)
}
- database.NewAudit(*authUser, fmt.Sprintf("create gpg for link %s", link.URL))
+ db.NewAudit(*authUser, fmt.Sprintf("create gpg for link %s", link.URL))
return c.Redirect(http.StatusFound, c.Request().Referer())
} else if formName == "createMirror" {
@@ -1672,10 +1695,10 @@ func EditLinkHandler(c echo.Context) error {
data.ErrorMirrorLink = "invalid link"
return c.Render(http.StatusOK, "new-link", data)
}
- if _, err = database.CreateLinkMirror(link.ID, data.MirrorLink); err != nil {
+ if _, err = db.CreateLinkMirror(link.ID, data.MirrorLink); err != nil {
logrus.Error(err)
}
- database.NewAudit(*authUser, fmt.Sprintf("create mirror for link %s", link.URL))
+ db.NewAudit(*authUser, fmt.Sprintf("create mirror for link %s", link.URL))
return c.Redirect(http.StatusFound, c.Request().Referer())
}
return c.Redirect(http.StatusFound, "/links")
@@ -1686,8 +1709,9 @@ func EditLinkHandler(c echo.Context) error {
func ClaimLinkHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
linkUUID := c.Param("linkUUID")
- link, err := database.GetLinkByUUID(linkUUID)
+ link, err := db.GetLinkByUUID(linkUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -1720,16 +1744,17 @@ func ClaimLinkHandler(c echo.Context) error {
link.SignedCertificate = signedCert
link.OwnerUserID = &authUser.ID
- link.DoSave()
+ link.DoSave(db)
return c.Redirect(http.StatusFound, "/links/"+link.UUID)
}
func ClaimDownloadCertificateLinkHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
linkUUID := c.Param("linkUUID")
- link, err := database.GetLinkByUUID(linkUUID)
+ link, err := db.GetLinkByUUID(linkUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -1737,7 +1762,7 @@ func ClaimDownloadCertificateLinkHandler(c echo.Context) error {
fileName := "certificate.txt"
// Keep track of user downloads
- if _, err := database.CreateDownload(authUser.ID, fileName); err != nil {
+ if _, err := db.CreateDownload(authUser.ID, fileName); err != nil {
logrus.Error(err)
}
@@ -1747,7 +1772,8 @@ func ClaimDownloadCertificateLinkHandler(c echo.Context) error {
func ClaimCertificateLinkHandler(c echo.Context) error {
linkUUID := c.Param("linkUUID")
- link, err := database.GetLinkByUUID(linkUUID)
+ db := c.Get("database").(*database.DkfDB)
+ link, err := db.GetLinkByUUID(linkUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -1756,35 +1782,38 @@ func ClaimCertificateLinkHandler(c echo.Context) error {
func ForumHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data forumData
- data.ForumCategories, _ = database.GetForumCategories()
- data.ForumThreads, _ = database.GetPublicForumCategoryThreads(authUser.ID, 1)
+ data.ForumCategories, _ = db.GetForumCategories()
+ data.ForumThreads, _ = db.GetPublicForumCategoryThreads(authUser.ID, 1)
return c.Render(http.StatusOK, "forum", data)
}
func ForumCategoryHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
categorySlug := c.Param("categorySlug")
var data forumCategoryData
- category, err := database.GetForumCategoryBySlug(categorySlug)
+ category, err := db.GetForumCategoryBySlug(categorySlug)
if err != nil {
return c.Redirect(http.StatusFound, "/forum")
}
- data.ForumThreads, _ = database.GetPublicForumCategoryThreads(authUser.ID, category.ID)
+ data.ForumThreads, _ = db.GetPublicForumCategoryThreads(authUser.ID, category.ID)
return c.Render(http.StatusOK, "forum", data)
}
func ThreadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
- thread, err := database.GetForumThreadByUUID(threadUUID)
+ thread, err := db.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
var data threadData
data.Thread = thread
- if err := database.DB.
+ if err := db.DB().
Table("forum_messages").
Where("thread_id = ?", thread.ID).
Scopes(func(query *gorm.DB) *gorm.DB {
@@ -1799,9 +1828,9 @@ func ThreadHandler(c echo.Context) error {
}
if authUser != nil {
- data.IsSubscribed = database.IsUserSubscribedToForumThread(authUser.ID, thread.ID)
+ data.IsSubscribed = db.IsUserSubscribedToForumThread(authUser.ID, thread.ID)
// Update read record
- database.UpdateForumReadRecord(authUser.ID, thread.ID)
+ db.UpdateForumReadRecord(authUser.ID, thread.ID)
}
return c.Render(http.StatusOK, "thread", data)
@@ -1809,8 +1838,9 @@ func ThreadHandler(c echo.Context) error {
func GistHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
gistUUID := c.Param("gistUUID")
- gist, err := database.GetGistByUUID(gistUUID)
+ gist, err := db.GetGistByUUID(gistUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -1829,7 +1859,7 @@ func GistHandler(c echo.Context) error {
if gist.Password != "" {
hutils.DeleteGistCookie(c, gist.UUID)
}
- if err := database.DB.Delete(&gist).Error; err != nil {
+ if err := db.DB().Delete(&gist).Error; err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/")
@@ -1885,12 +1915,13 @@ func ThreadReplyHandler(c echo.Context) error {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"})
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"})
}
threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
- thread, err := database.GetForumThreadByUUID(threadUUID)
+ thread, err := db.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -1904,22 +1935,22 @@ func ThreadReplyHandler(c echo.Context) error {
return c.Render(http.StatusOK, "thread-reply", data)
}
if isForumSpam(data.Message) {
- database.NewAudit(*authUser, fmt.Sprintf("spam forum thread reply %s (#%d)", authUser.Username, authUser.ID))
+ db.NewAudit(*authUser, fmt.Sprintf("spam forum thread reply %s (#%d)", authUser.Username, authUser.ID))
authUser.CanUseForum = false
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, "/")
}
message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID)
message.IsSigned = message.ValidateSignature(authUser.GPGPublicKey)
- if err := database.DB.Create(&message).Error; err != nil {
+ if err := db.DB().Create(&message).Error; err != nil {
logrus.Error(err)
}
// Send notifications
- subs, _ := database.GetUsersSubscribedToForumThread(thread.ID)
+ subs, _ := db.GetUsersSubscribedToForumThread(thread.ID)
for _, sub := range subs {
if sub.UserID != authUser.ID {
msg := fmt.Sprintf(`New reply in thread "<a href="/t/%s#%s">%s</a>"`, thread.UUID, message.UUID, thread.Name)
- database.CreateNotification(msg, sub.UserID)
+ db.CreateNotification(msg, sub.UserID)
}
}
return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID))
@@ -1933,11 +1964,12 @@ func ThreadDeleteMessageHandler(c echo.Context) error {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"})
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"})
}
messageUUID := database.ForumMessageUUID(c.Param("messageUUID"))
- msg, err := database.GetForumMessageByUUID(messageUUID)
+ msg, err := db.GetForumMessageByUUID(messageUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -1951,14 +1983,14 @@ func ThreadDeleteMessageHandler(c echo.Context) error {
}
var data deleteForumMessageData
- data.Thread, err = database.GetForumThreadByID(msg.ThreadID)
+ data.Thread, err = db.GetForumThreadByID(msg.ThreadID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
data.Message = msg
if c.Request().Method == http.MethodPost {
- if err := database.DeleteForumMessageByID(msg.ID); err != nil {
+ if err := db.DeleteForumMessageByID(msg.ID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/t/"+string(data.Thread.UUID))
@@ -1969,8 +2001,9 @@ func ThreadDeleteMessageHandler(c echo.Context) error {
func LinkDeleteHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
linkUUID := c.Param("linkUUID")
- link, err := database.GetLinkByUUID(linkUUID)
+ link, err := db.GetLinkByUUID(linkUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -1983,8 +2016,8 @@ func LinkDeleteHandler(c echo.Context) error {
data.Link = link
if c.Request().Method == http.MethodPost {
- database.NewAudit(*authUser, fmt.Sprintf("deleted link %s", link.URL))
- if err := database.DeleteLinkByID(link.ID); err != nil {
+ db.NewAudit(*authUser, fmt.Sprintf("deleted link %s", link.URL))
+ if err := db.DeleteLinkByID(link.ID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/links")
@@ -1995,12 +2028,13 @@ func LinkDeleteHandler(c echo.Context) error {
func LinkPgpDeleteHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
linkPgpID := utils.DoParseInt64(c.Param("linkPgpID"))
- linkPgp, err := database.GetLinkPgpByID(linkPgpID)
+ linkPgp, err := db.GetLinkPgpByID(linkPgpID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- link, err := database.GetLinkByID(linkPgp.LinkID)
+ link, err := db.GetLinkByID(linkPgp.LinkID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -2014,7 +2048,7 @@ func LinkPgpDeleteHandler(c echo.Context) error {
data.LinkPgp = linkPgp
if c.Request().Method == http.MethodPost {
- if err := database.DeleteLinkPgpByID(linkPgp.ID); err != nil {
+ if err := db.DeleteLinkPgpByID(linkPgp.ID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/links/"+link.UUID+"/edit")
@@ -2025,12 +2059,13 @@ func LinkPgpDeleteHandler(c echo.Context) error {
func LinkMirrorDeleteHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
linkMirrorID := utils.DoParseInt64(c.Param("linkMirrorID"))
- linkMirror, err := database.GetLinkMirrorByID(linkMirrorID)
+ linkMirror, err := db.GetLinkMirrorByID(linkMirrorID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- link, err := database.GetLinkByID(linkMirror.LinkID)
+ link, err := db.GetLinkByID(linkMirror.LinkID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -2044,7 +2079,7 @@ func LinkMirrorDeleteHandler(c echo.Context) error {
data.LinkMirror = linkMirror
if c.Request().Method == http.MethodPost {
- if err := database.DeleteLinkMirrorByID(linkMirror.ID); err != nil {
+ if err := db.DeleteLinkMirrorByID(linkMirror.ID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/links/"+link.UUID+"/edit")
@@ -2054,6 +2089,7 @@ func LinkMirrorDeleteHandler(c echo.Context) error {
}
func ThreadEditHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
if config.ForumEnabled.IsFalse() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"})
}
@@ -2065,7 +2101,7 @@ func ThreadEditHandler(c echo.Context) error {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
- thread, err := database.GetForumThreadByUUID(threadUUID)
+ thread, err := db.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -2074,7 +2110,7 @@ func ThreadEditHandler(c echo.Context) error {
if c.Request().Method == http.MethodPost {
thread.CategoryID = database.ForumCategoryID(utils.DoParseInt64(c.Request().PostFormValue("category_id")))
- thread.DoSave()
+ thread.DoSave(db)
return c.Redirect(http.StatusFound, "/forum")
}
@@ -2086,11 +2122,12 @@ func ThreadDeleteHandler(c echo.Context) error {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"})
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"})
}
threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
- thread, err := database.GetForumThreadByUUID(threadUUID)
+ thread, err := db.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -2103,7 +2140,7 @@ func ThreadDeleteHandler(c echo.Context) error {
data.Thread = thread
if c.Request().Method == http.MethodPost {
- if err := database.DeleteForumThreadByID(thread.ID); err != nil {
+ if err := db.DeleteForumThreadByID(thread.ID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, "/forum")
@@ -2117,16 +2154,17 @@ func ThreadEditMessageHandler(c echo.Context) error {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"})
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"})
}
threadUUID := database.ForumThreadUUID(c.Param("threadUUID"))
messageUUID := database.ForumMessageUUID(c.Param("messageUUID"))
- thread, err := database.GetForumThreadByUUID(threadUUID)
+ thread, err := db.GetForumThreadByUUID(threadUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
- msg, err := database.GetForumMessageByUUID(messageUUID)
+ msg, err := db.GetForumMessageByUUID(messageUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -2145,14 +2183,14 @@ func ThreadEditMessageHandler(c echo.Context) error {
return c.Render(http.StatusOK, "thread-reply", data)
}
if isForumSpam(data.Message) {
- database.NewAudit(*authUser, fmt.Sprintf("spam forum edit msg %s (#%d)", authUser.Username, authUser.ID))
+ db.NewAudit(*authUser, fmt.Sprintf("spam forum edit msg %s (#%d)", authUser.Username, authUser.ID))
authUser.CanUseForum = false
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, "/")
}
msg.Message = data.Message
msg.IsSigned = msg.ValidateSignature(authUser.GPGPublicKey)
- msg.DoSave()
+ msg.DoSave(db)
return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID))
}
@@ -2163,8 +2201,9 @@ func ThreadRawMessageHandler(c echo.Context) error {
if config.ForumEnabled.IsFalse() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"})
}
+ db := c.Get("database").(*database.DkfDB)
messageUUID := database.ForumMessageUUID(c.Param("messageUUID"))
- msg, err := database.GetForumMessageByUUID(messageUUID)
+ msg, err := db.GetForumMessageByUUID(messageUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -2183,6 +2222,7 @@ func NewThreadHandler(c echo.Context) error {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Forum is temporarily disabled", Redirect: "/", Type: "alert-danger"})
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"})
}
@@ -2200,17 +2240,17 @@ func NewThreadHandler(c echo.Context) error {
return c.Render(http.StatusOK, "new-thread", data)
}
if isForumSpam(data.Message) {
- database.NewAudit(*authUser, fmt.Sprintf("spam forum new thread %s (#%d)", authUser.Username, authUser.ID))
+ db.NewAudit(*authUser, fmt.Sprintf("spam forum new thread %s (#%d)", authUser.Username, authUser.ID))
authUser.CanUseForum = false
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, "/")
}
thread := database.MakeForumThread(data.ThreadName, authUser.ID, 1)
- database.DB.Create(&thread)
+ db.DB().Create(&thread)
message := database.MakeForumMessage(data.Message, authUser.ID, thread.ID)
message.IsSigned = message.ValidateSignature(authUser.GPGPublicKey)
- database.DB.Create(&message)
- _ = database.SubscribeToForumThread(authUser.ID, thread.ID)
+ db.DB().Create(&message)
+ _ = db.SubscribeToForumThread(authUser.ID, thread.ID)
return c.Redirect(http.StatusFound, "/t/"+string(thread.UUID))
}
@@ -2218,9 +2258,10 @@ func NewThreadHandler(c echo.Context) error {
}
func VipHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
var data vipData
data.ActiveTab = "home"
- data.UsersBadges, _ = database.GetUsersBadges()
+ data.UsersBadges, _ = db.GetUsersBadges()
return c.Render(http.StatusOK, "vip.home", data)
}
@@ -2250,8 +2291,9 @@ func VipProjectsMalwareDropperHandler(c echo.Context) error {
func RoomsHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data roomsData
- data.Rooms, _ = database.GetListedChatRooms(authUser.ID)
+ data.Rooms, _ = db.GetListedChatRooms(authUser.ID)
return c.Render(http.StatusOK, "rooms", data)
}
@@ -2286,9 +2328,10 @@ func ChatHelpHandler(c echo.Context) error {
func RoomChatSettingsHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data roomChatSettingsData
roomName := c.Param("roomName")
- room, err := database.GetChatRoomByName(roomName)
+ room, err := db.GetChatRoomByName(roomName)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -2306,6 +2349,7 @@ func RoomChatSettingsHandler(c echo.Context) error {
func ChatCreateRoomHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data chatCreateRoomData
data.CaptchaID, data.CaptchaImg = captcha.New()
data.IsEphemeral = true
@@ -2328,7 +2372,7 @@ func ChatCreateRoomHandler(c echo.Context) error {
if data.Password != "" {
passwordHash = database.GetRoomPasswordHash(data.Password)
}
- if _, err := database.CreateRoom(data.RoomName, passwordHash, authUser.ID, data.IsListed); err != nil {
+ if _, err := db.CreateRoom(data.RoomName, passwordHash, authUser.ID, data.IsListed); err != nil {
data.Error = err.Error()
return c.Render(http.StatusOK, "chat-create-room", data)
}
@@ -2348,8 +2392,9 @@ func ShopHandler(c echo.Context) error {
return base64.StdEncoding.EncodeToString(buf.Bytes())
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data shopData
- invoice, err := database.CreateXmrInvoice(authUser.ID, 1)
+ invoice, err := db.CreateXmrInvoice(authUser.ID, 1)
if err != nil {
logrus.Error(err)
}
@@ -2411,12 +2456,13 @@ func SettingsChatHandler(c echo.Context) error {
func SettingsChatPMHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsChatPMData
data.ActiveTab = "chat"
data.PmMode = authUser.PmMode
data.BlockNewUsersPm = authUser.BlockNewUsersPm
- data.WhitelistedUsers, _ = database.GetPmWhitelistedUsers(authUser.ID)
- data.BlacklistedUsers, _ = database.GetPmBlacklistedUsers(authUser.ID)
+ data.WhitelistedUsers, _ = db.GetPmWhitelistedUsers(authUser.ID)
+ data.BlacklistedUsers, _ = db.GetPmBlacklistedUsers(authUser.ID)
if c.Request().Method == http.MethodGet {
return c.Render(http.StatusOK, "settings.chat-pm", data)
@@ -2427,48 +2473,49 @@ func SettingsChatPMHandler(c echo.Context) error {
if formName == "addWhitelist" {
data.AddWhitelist = strings.TrimSpace(c.Request().PostFormValue("username"))
- user, err := database.GetUserByUsername(data.AddWhitelist)
+ user, err := db.GetUserByUsername(data.AddWhitelist)
if err != nil {
data.Error = "username not found"
return c.Render(http.StatusOK, "settings.chat-pm", data)
}
- database.AddWhitelistedUser(authUser.ID, user.ID)
+ db.AddWhitelistedUser(authUser.ID, user.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
} else if formName == "rmWhitelist" {
userID := dutils.DoParseUserID(c.Request().PostFormValue("userID"))
- database.RmWhitelistedUser(authUser.ID, userID)
+ db.RmWhitelistedUser(authUser.ID, userID)
return c.Redirect(http.StatusFound, c.Request().Referer())
} else if formName == "addBlacklist" {
data.AddBlacklist = strings.TrimSpace(c.Request().PostFormValue("username"))
- user, err := database.GetUserByUsername(data.AddBlacklist)
+ user, err := db.GetUserByUsername(data.AddBlacklist)
if err != nil {
data.Error = "username not found"
return c.Render(http.StatusOK, "settings.chat-pm", data)
}
- database.AddBlacklistedUser(authUser.ID, user.ID)
+ db.AddBlacklistedUser(authUser.ID, user.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
} else if formName == "rmBlacklist" {
userID := dutils.DoParseUserID(c.Request().PostFormValue("userID"))
- database.RmBlacklistedUser(authUser.ID, userID)
+ db.RmBlacklistedUser(authUser.ID, userID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
data.PmMode = utils.Clamp(utils.DoParseInt64(c.Request().PostFormValue("pm_mode")), 0, 1)
authUser.BlockNewUsersPm = utils.DoParseBool(c.Request().PostFormValue("block_new_users_pm"))
authUser.PmMode = data.PmMode
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func SettingsChatIgnoreHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsChatIgnoreData
data.ActiveTab = "chat"
data.PmMode = authUser.PmMode
- data.IgnoredUsers, _ = database.GetIgnoredUsers(authUser.ID)
+ data.IgnoredUsers, _ = db.GetIgnoredUsers(authUser.ID)
if c.Request().Method == http.MethodGet {
return c.Render(http.StatusOK, "settings.chat-ignore", data)
@@ -2479,31 +2526,32 @@ func SettingsChatIgnoreHandler(c echo.Context) error {
if formName == "addIgnored" {
data.AddIgnored = strings.TrimSpace(c.Request().PostFormValue("username"))
- user, err := database.GetUserByUsername(data.AddIgnored)
+ user, err := db.GetUserByUsername(data.AddIgnored)
if err != nil {
data.Error = "username not found"
return c.Render(http.StatusOK, "settings.chat-ignore", data)
}
- database.IgnoreUser(authUser.ID, user.ID)
+ db.IgnoreUser(authUser.ID, user.ID)
return c.Redirect(http.StatusFound, c.Request().Referer())
} else if formName == "rmIgnored" {
userID := dutils.DoParseUserID(c.Request().PostFormValue("userID"))
- database.UnIgnoreUser(authUser.ID, userID)
+ db.UnIgnoreUser(authUser.ID, userID)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
data.PmMode = utils.Clamp(utils.DoParseInt64(c.Request().PostFormValue("pm_mode")), 0, 1)
authUser.PmMode = data.PmMode
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
func SettingsChatSnippetsHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsChatSnippetsData
data.ActiveTab = "snippets"
- data.Snippets, _ = database.GetUserSnippets(authUser.ID)
+ data.Snippets, _ = db.GetUserSnippets(authUser.ID)
if c.Request().Method == http.MethodGet {
return c.Render(http.StatusOK, "settings.chat-snippets", data)
@@ -2531,7 +2579,7 @@ func SettingsChatSnippetsHandler(c echo.Context) error {
data.Error = "text must be 1-1000 characters"
return c.Render(http.StatusOK, "settings.chat-snippets", data)
}
- if _, err := database.CreateSnippet(authUser.ID, data.Name, data.Text); err != nil {
+ if _, err := db.CreateSnippet(authUser.ID, data.Name, data.Text); err != nil {
data.Error = err.Error()
return c.Render(http.StatusOK, "settings.chat-snippets", data)
}
@@ -2539,7 +2587,7 @@ func SettingsChatSnippetsHandler(c echo.Context) error {
} else if formName == "rmSnippet" {
snippetName := c.Request().PostFormValue("snippetName")
- database.DeleteSnippet(authUser.ID, snippetName)
+ db.DeleteSnippet(authUser.ID, snippetName)
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -2548,9 +2596,10 @@ func SettingsChatSnippetsHandler(c echo.Context) error {
func SettingsUploadsHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsUploadsData
data.ActiveTab = "uploads"
- data.Files, _ = database.GetUserUploads(authUser.ID)
+ data.Files, _ = db.GetUserUploads(authUser.ID)
for _, f := range data.Files {
data.TotalSize += f.FileSize
}
@@ -2563,14 +2612,14 @@ func SettingsUploadsHandler(c echo.Context) error {
formName := c.FormValue("formName")
if formName == "deleteUpload" {
fileName := c.Request().PostFormValue("file_name")
- file, err := database.GetUploadByFileName(fileName)
+ file, err := db.GetUploadByFileName(fileName)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
if authUser.ID != file.UserID {
return c.Redirect(http.StatusFound, "/")
}
- if err := file.Delete(); err != nil {
+ if err := file.Delete(db); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
@@ -2580,16 +2629,17 @@ func SettingsUploadsHandler(c echo.Context) error {
func SettingsPublicNotesHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"})
}
var data settingsPublicNotesData
data.ActiveTab = "notes"
- data.Notes, _ = database.GetUserPublicNotes(authUser.ID)
+ data.Notes, _ = db.GetUserPublicNotes(authUser.ID)
if c.Request().Method == http.MethodPost {
notes := c.Request().PostFormValue("public_notes")
- if err := database.SetUserPublicNotes(authUser.ID, notes); err != nil {
+ if err := db.SetUserPublicNotes(authUser.ID, notes); err != nil {
data.Error = err.Error()
return c.Render(http.StatusOK, "settings.public-notes", data)
}
@@ -2601,18 +2651,19 @@ func SettingsPublicNotesHandler(c echo.Context) error {
func SettingsPrivateNotesHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUseForumFn() {
return c.Render(http.StatusOK, "flash", FlashResponse{Message: hutils.AccountTooYoungErr.Error(), Redirect: c.Request().Referer(), Type: "alert-danger"})
}
var data settingsPrivateNotesData
data.ActiveTab = "notes"
if !authUser.IsUnderDuress {
- data.Notes, _ = database.GetUserPrivateNotes(authUser.ID)
+ data.Notes, _ = db.GetUserPrivateNotes(authUser.ID)
}
if c.Request().Method == http.MethodPost {
notes := c.Request().PostFormValue("private_notes")
- if err := database.SetUserPrivateNotes(authUser.ID, notes); err != nil {
+ if err := db.SetUserPrivateNotes(authUser.ID, notes); err != nil {
data.Error = err.Error()
return c.Render(http.StatusOK, "settings.private-notes", data)
}
@@ -2625,14 +2676,15 @@ func SettingsPrivateNotesHandler(c echo.Context) error {
func SettingsInboxHandler(c echo.Context) error {
authCookie, _ := c.Cookie(hutils.AuthCookieName)
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsInboxData
data.ActiveTab = "inbox"
// Do not fetch inboxes & notifications if logged in under duress
if !authUser.IsUnderDuress {
global.DeleteUserNotificationCount(authUser.ID, authCookie.Value)
- data.ChatMessages, _ = database.GetUserChatInboxMessages(authUser.ID)
- data.Notifications, _ = database.GetUserNotifications(authUser.ID)
- data.SessionNotifications, _ = database.GetUserSessionNotifications(authCookie.Value)
+ data.ChatMessages, _ = db.GetUserChatInboxMessages(authUser.ID)
+ data.Notifications, _ = db.GetUserNotifications(authUser.ID)
+ data.SessionNotifications, _ = db.GetUserSessionNotifications(authCookie.Value)
}
for _, m := range data.ChatMessages {
data.Notifs = append(data.Notifs, InboxTmp{IsNotif: false, ChatInboxMessage: m})
@@ -2669,28 +2721,31 @@ func SettingsInboxHandler(c echo.Context) error {
func SettingsInboxSentHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsInboxSentData
data.ActiveTab = "inbox"
// Do not fetch inboxes & notifications if logged in under duress
if !authUser.IsUnderDuress {
- data.ChatInboxSent, _ = database.GetUserChatInboxMessagesSent(authUser.ID)
+ data.ChatInboxSent, _ = db.GetUserChatInboxMessagesSent(authUser.ID)
}
return c.Render(http.StatusOK, "settings.inbox-sent", data)
}
func SettingsSecurityHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsSecurityData
data.ActiveTab = "security"
- data.Logs, _ = database.GetSecurityLogs(authUser.ID)
+ data.Logs, _ = db.GetSecurityLogs(authUser.ID)
return c.Render(http.StatusOK, "settings.security", data)
}
func SettingsSessionsHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsSessionsData
data.ActiveTab = "sessions"
- sessions := database.GetActiveUserSessions(authUser.ID)
+ sessions := db.GetActiveUserSessions(authUser.ID)
authCookie, _ := c.Cookie(hutils.AuthCookieName)
for _, session := range sessions {
s := WrapperSession{Session: session}
@@ -2703,10 +2758,10 @@ func SettingsSessionsHandler(c echo.Context) error {
if c.Request().Method == http.MethodPost {
formName := c.Request().PostFormValue("formName")
if formName == "revoke_all_other_sessions" {
- _ = database.DeleteUserOtherSessions(authUser.ID, authCookie.Value)
+ _ = db.DeleteUserOtherSessions(authUser.ID, authCookie.Value)
} else {
sessionToken := c.Request().PostFormValue("sessionToken")
- _ = database.DeleteUserSessionByToken(authUser.ID, sessionToken)
+ _ = db.DeleteUserSessionByToken(authUser.ID, sessionToken)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -2766,6 +2821,7 @@ func SettingsPasswordHandler(c echo.Context) error {
func SettingsSecretPhraseHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsSecretPhraseData
data.ActiveTab = "secretPhrase"
data.SecretPhrase = string(authUser.SecretPhrase)
@@ -2790,42 +2846,44 @@ func SettingsSecretPhraseHandler(c echo.Context) error {
return c.Render(http.StatusOK, "settings.secret-phrase", data)
}
- if !authUser.CheckPassword(currentPassword) {
+ if !authUser.CheckPassword(db, currentPassword) {
data.ErrorCurrentPassword = "Invalid password"
return c.Render(http.StatusOK, "settings.secret-phrase", data)
}
authUser.SecretPhrase = database.EncryptedString(secretPhrase)
- authUser.DoSave()
+ authUser.DoSave(db)
- database.CreateSecurityLog(authUser.ID, database.ChangeSecretPhraseSecurityLog)
+ db.CreateSecurityLog(authUser.ID, database.ChangeSecretPhraseSecurityLog)
return c.Render(http.StatusFound, "flash", FlashResponse{Message: "Secret phrase changed successfully", Redirect: c.Request().Referer()})
}
// SettingsInvitationsHandler ...
func SettingsInvitationsHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsInvitationsData
data.ActiveTab = "invitations"
data.DkfOnion = config.DkfOnion
if c.Request().Method == http.MethodPost {
- if _, err := database.CreateInvitation(authUser.ID); err != nil {
+ if _, err := db.CreateInvitation(authUser.ID); err != nil {
logrus.Error(err)
}
return c.Redirect(http.StatusFound, c.Request().Referer())
}
- data.Invitations, _ = database.GetUserUnusedInvitations(authUser.ID)
+ data.Invitations, _ = db.GetUserUnusedInvitations(authUser.ID)
return c.Render(http.StatusOK, "settings.invitations", data)
}
// SettingsWebsiteHandler ...
func SettingsWebsiteHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data settingsWebsiteData
data.ActiveTab = "website"
- settings := database.GetSettings()
+ settings := db.GetSettings()
data.SignupEnabled = settings.SignupEnabled
data.ForumEnabled = settings.ForumEnabled
data.SilentSelfKick = settings.SilentSelfKick
@@ -2833,11 +2891,11 @@ func SettingsWebsiteHandler(c echo.Context) error {
settings.SignupEnabled = utils.DoParseBool(c.Request().PostFormValue("signupEnabled"))
settings.ForumEnabled = utils.DoParseBool(c.Request().PostFormValue("forumEnabled"))
settings.SilentSelfKick = utils.DoParseBool(c.Request().PostFormValue("silentSelfKick"))
- settings.DoSave()
+ settings.DoSave(db)
config.SignupEnabled.Store(settings.SignupEnabled)
config.ForumEnabled.Store(settings.ForumEnabled)
config.SilentSelfKick.Store(settings.SilentSelfKick)
- database.NewAudit(*authUser, fmt.Sprintf("website settings, signup: %t, forum: %t, sk: %t",
+ db.NewAudit(*authUser, fmt.Sprintf("website settings, signup: %t, forum: %t, sk: %t",
settings.SignupEnabled, settings.ForumEnabled, settings.SilentSelfKick))
return c.Redirect(http.StatusFound, c.Request().Referer())
}
@@ -2847,6 +2905,7 @@ func SettingsWebsiteHandler(c echo.Context) error {
func editProfileForm(c echo.Context, data settingsAccountData) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
email := c.Request().PostFormValue("email")
website := c.Request().PostFormValue("website")
@@ -2871,13 +2930,14 @@ func editProfileForm(c echo.Context, data settingsAccountData) error {
authUser.Email = data.Email
authUser.LastSeenPublic = data.LastSeenPublic
authUser.TerminateAllSessionsOnLogout = data.TerminateAllSessionsOnLogout
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Profile changed successfully", Redirect: c.Request().Referer()})
}
func changeAvatarForm(c echo.Context, data settingsAccountData) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanUpload() {
data.ErrorAvatar = hutils.AccountTooYoungErr.Error()
return c.Render(http.StatusOK, "settings.account", data)
@@ -2935,12 +2995,13 @@ func changeAvatarForm(c echo.Context, data settingsAccountData) error {
}
authUser.SetAvatar(fileBytes)
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Avatar changed successfully", Redirect: c.Request().Referer()})
}
func changeUsernameForm(c echo.Context, data settingsAccountData) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !authUser.CanChangeUsername {
data.ErrorUsername = "Not allowed to change your username"
return c.Render(http.StatusOK, "settings.account", data)
@@ -2954,25 +3015,26 @@ func changeUsernameForm(c echo.Context, data settingsAccountData) error {
return c.Render(http.StatusOK, "settings.account", data)
}
- if err := database.CanRenameTo(authUser.Username, username); err != nil {
+ if err := db.CanRenameTo(authUser.Username, username); err != nil {
data.ErrorUsername = err.Error()
return c.Render(http.StatusOK, "settings.account", data)
}
managers.ActiveUsers.RemoveUser(authUser.ID)
authUser.Username = username
- if err := database.DB.Save(authUser).Error; err != nil {
+ if err := db.DB().Save(authUser).Error; err != nil {
logrus.Error(err)
data.ErrorUsername = err.Error()
return c.Render(http.StatusOK, "settings.account", data)
}
- database.CreateSecurityLog(authUser.ID, database.UsernameChangedSecurityLog)
+ db.CreateSecurityLog(authUser.ID, database.UsernameChangedSecurityLog)
return c.Render(http.StatusOK, "flash", FlashResponse{Message: "Username changed successfully", Redirect: c.Request().Referer()})
}
func changeSettingsForm(c echo.Context, data settingsChatData) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
data.RefreshRate = utils.Clamp(utils.DoParseInt64(c.Request().PostFormValue("refresh_rate")), 5, 60)
data.ChatColor = c.Request().PostFormValue("chat_color")
@@ -3051,7 +3113,7 @@ func changeSettingsForm(c echo.Context, data settingsChatData) error {
authUser.DisplayHellbanButton = data.DisplayHellbanButton
}
- if err := database.DB.Save(authUser).Error; err != nil {
+ if err := db.DB().Save(authUser).Error; err != nil {
logrus.Error(err)
data.Error = err.Error()
return c.Render(http.StatusOK, "settings.chat", data)
@@ -3062,6 +3124,7 @@ func changeSettingsForm(c echo.Context, data settingsChatData) error {
func changePasswordForm(c echo.Context, data settingsPasswordData) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
oldPassword := c.Request().PostFormValue("oldPassword")
newPassword := c.Request().PostFormValue("newPassword")
rePassword := c.Request().PostFormValue("rePassword")
@@ -3075,22 +3138,22 @@ func changePasswordForm(c echo.Context, data settingsPasswordData) error {
}
if len(newPassword) > 0 || len(rePassword) > 0 {
- hashedPassword, err := database.NewPasswordValidator(newPassword).CompareWith(rePassword).Hash()
+ hashedPassword, err := database.NewPasswordValidator(db, newPassword).CompareWith(rePassword).Hash()
if err != nil {
data.ErrorNewPassword = err.Error()
return c.Render(http.StatusOK, "settings.password", data)
}
- if !authUser.CheckPassword(oldPassword) {
+ if !authUser.CheckPassword(db, oldPassword) {
data.ErrorOldPassword = "Invalid password"
return c.Render(http.StatusOK, "settings.password", data)
}
- if err := authUser.ChangePassword(hashedPassword); err != nil {
+ if err := authUser.ChangePassword(db, hashedPassword); err != nil {
logrus.Error(err)
}
c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName))
- database.CreateSecurityLog(authUser.ID, database.ChangePasswordSecurityLog)
+ db.CreateSecurityLog(authUser.ID, database.ChangePasswordSecurityLog)
return c.Render(http.StatusFound, "flash", FlashResponse{Message: "Password changed successfully", Redirect: "/login"})
}
@@ -3099,6 +3162,7 @@ func changePasswordForm(c echo.Context, data settingsPasswordData) error {
func changeDuressPasswordForm(c echo.Context, data settingsPasswordData) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
oldDuressPassword := c.Request().PostFormValue("oldDuressPassword")
newDuressPassword := c.Request().PostFormValue("newDuressPassword")
reDuressPassword := c.Request().PostFormValue("reDuressPassword")
@@ -3112,22 +3176,22 @@ func changeDuressPasswordForm(c echo.Context, data settingsPasswordData) error {
}
if len(newDuressPassword) > 0 || len(reDuressPassword) > 0 {
- hashedPassword, err := database.NewPasswordValidator(newDuressPassword).CompareWith(reDuressPassword).Hash()
+ hashedPassword, err := database.NewPasswordValidator(db, newDuressPassword).CompareWith(reDuressPassword).Hash()
if err != nil {
data.ErrorNewDuressPassword = err.Error()
return c.Render(http.StatusOK, "settings.password", data)
}
- if !authUser.CheckPassword(oldDuressPassword) {
+ if !authUser.CheckPassword(db, oldDuressPassword) {
data.ErrorOldDuressPassword = "Invalid password"
return c.Render(http.StatusOK, "settings.password", data)
}
- if err := authUser.ChangeDuressPassword(hashedPassword); err != nil {
+ if err := authUser.ChangeDuressPassword(db, hashedPassword); err != nil {
logrus.Error(err)
}
c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName))
- database.CreateSecurityLog(authUser.ID, database.ChangeDuressPasswordSecurityLog)
+ db.CreateSecurityLog(authUser.ID, database.ChangeDuressPasswordSecurityLog)
return c.Render(http.StatusFound, "flash", FlashResponse{Message: "Password changed successfully", Redirect: "/login"})
}
@@ -3136,9 +3200,10 @@ func changeDuressPasswordForm(c echo.Context, data settingsPasswordData) error {
func ChatDeleteHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data chatDeleteData
roomName := c.Param("roomName")
- room, err := database.GetChatRoomByName(roomName)
+ room, err := db.GetChatRoomByName(roomName)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -3151,7 +3216,7 @@ func ChatDeleteHandler(c echo.Context) error {
if room.IsProtected() {
hutils.DeleteRoomCookie(c, int64(room.ID))
}
- database.DeleteChatRoomByID(room.ID)
+ db.DeleteChatRoomByID(room.ID)
return c.Redirect(http.StatusFound, "/chat")
}
@@ -3160,10 +3225,11 @@ func ChatDeleteHandler(c echo.Context) error {
func ChatArchiveHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data chatArchiveData
data.DateFormat = authUser.GetDateFormat()
roomName := c.Param("roomName")
- room, err := database.GetChatRoomByName(roomName)
+ room, err := db.GetChatRoomByName(roomName)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -3176,7 +3242,7 @@ func ChatArchiveHandler(c echo.Context) error {
data.Room = room
if data.UUID != "" {
- msg, err := database.GetRoomChatMessageByUUID(room.ID, data.UUID)
+ msg, err := db.GetRoomChatMessageByUUID(room.ID, data.UUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -3207,7 +3273,7 @@ func ChatArchiveHandler(c echo.Context) error {
) ORDER BY id DESC`
args = append(args, msg.ID, nbMsg)
args = append(args, args...)
- database.DB.Raw(raw, args...).Scan(&data.Messages)
+ db.DB().Raw(raw, args...).Scan(&data.Messages)
// Manually do Preload("Room")
for _, m := range data.Messages {
@@ -3222,7 +3288,7 @@ func ChatArchiveHandler(c echo.Context) error {
usersIDs.Insert(*m.ToUserID)
}
}
- users, _ := database.GetUsersByID(usersIDs.ToArray())
+ users, _ := db.GetUsersByID(usersIDs.ToArray())
usersMap := make(map[database.UserID]database.User)
for _, u := range users {
usersMap[u.ID] = u
@@ -3240,7 +3306,7 @@ func ChatArchiveHandler(c echo.Context) error {
//--- </ Manually do a Preload("User") Preload("ToUser") > ---
} else {
- if err := database.DB.Table("chat_messages").
+ if err := db.DB().Table("chat_messages").
Where("room_id = ? AND group_id IS NULL AND (to_user_id is null OR to_user_id = ? OR user_id = ?)", room.ID, authUser.ID, authUser.ID).
Scopes(func(query *gorm.DB) *gorm.DB {
if !authUser.DisplayIgnored {
@@ -3349,6 +3415,7 @@ func generatePgpToBeSignedTokenMessage(userID database.UserID, pkey string) stri
func AddPGPHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data addPGPData
data.PGPPublicKey = authUser.GPGPublicKey
if c.Request().Method == http.MethodPost {
@@ -3399,7 +3466,7 @@ func AddPGPHandler(c echo.Context) error {
pgpTokenCache.Delete(authUser.ID)
authUser.GPGPublicKey = token.PKey
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, "/settings/pgp")
}
}
@@ -3408,6 +3475,7 @@ func AddPGPHandler(c echo.Context) error {
func AddAgeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data addAgeData
data.AgePublicKey = authUser.AgePublicKey
if c.Request().Method == http.MethodPost {
@@ -3436,7 +3504,7 @@ func AddAgeHandler(c echo.Context) error {
}
ageTokenCache.Delete(authUser.ID)
authUser.AgePublicKey = token.PKey
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, "/settings/age")
}
}
@@ -3453,6 +3521,7 @@ type twoFactorObj struct {
func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data gpgTwoFactorAuthenticationVerifyData
data.IsEnabled = authUser.GpgTwoFactorEnabled
@@ -3463,7 +3532,7 @@ func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error {
}
password := c.Request().PostFormValue("password")
- if !authUser.CheckPassword(password) {
+ if !authUser.CheckPassword(db, password) {
data.ErrorPassword = "Invalid password"
return c.Render(http.StatusOK, "two-factor-authentication-gpg", data)
}
@@ -3471,8 +3540,8 @@ func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error {
// Disable
if authUser.GpgTwoFactorEnabled {
authUser.GpgTwoFactorEnabled = false
- authUser.DoSave()
- database.CreateSecurityLog(authUser.ID, database.Gpg2faDisabledSecurityLog)
+ authUser.DoSave(db)
+ db.CreateSecurityLog(authUser.ID, database.Gpg2faDisabledSecurityLog)
return c.Render(http.StatusOK, "flash", FlashResponse{"GPG Two-factor authentication disabled", "/settings/account", "alert-success"})
}
@@ -3481,14 +3550,14 @@ func GpgTwoFactorAuthenticationToggleHandler(c echo.Context) error {
return c.Render(http.StatusOK, "flash", FlashResponse{"You need to setup your PGP key first", "/settings/pgp", "alert-danger"})
}
// Delete active user sessions
- if err := database.DeleteUserSessions(authUser.ID); err != nil {
+ if err := db.DeleteUserSessions(authUser.ID); err != nil {
logrus.Error(err)
}
c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName))
authUser.GpgTwoFactorEnabled = true
authUser.GpgTwoFactorMode = utils.DoParseBool(c.Request().PostFormValue("gpg_two_factor_mode"))
- authUser.DoSave()
- database.CreateSecurityLog(authUser.ID, database.Gpg2faEnabledSecurityLog)
+ authUser.DoSave(db)
+ db.CreateSecurityLog(authUser.ID, database.Gpg2faEnabledSecurityLog)
return c.Render(http.StatusOK, "flash", FlashResponse{"GPG Two-factor authentication enabled", "/settings/account", "alert-success"})
}
@@ -3500,6 +3569,7 @@ func TwoFactorAuthenticationVerifyHandler(c echo.Context) error {
return base64.StdEncoding.EncodeToString(buf.Bytes())
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if authUser.TwoFactorSecret != "" {
return c.Redirect(http.StatusFound, "/settings/account")
}
@@ -3510,7 +3580,7 @@ func TwoFactorAuthenticationVerifyHandler(c echo.Context) error {
return c.Redirect(http.StatusFound, "/two-factor-authentication/verify")
}
password := c.Request().PostFormValue("password")
- if !authUser.CheckPassword(password) {
+ if !authUser.CheckPassword(db, password) {
img, _ := twoFactor.key.Image(150, 150)
data.QRCode = getImgStr(img)
data.Secret = twoFactor.key.Secret()
@@ -3534,14 +3604,14 @@ func TwoFactorAuthenticationVerifyHandler(c echo.Context) error {
return c.Render(http.StatusOK, "two-factor-authentication-verify", data)
}
// Delete active user sessions
- if err := database.DeleteUserSessions(authUser.ID); err != nil {
+ if err := db.DeleteUserSessions(authUser.ID); err != nil {
logrus.Error(err)
}
c.SetCookie(hutils.DeleteCookie(hutils.AuthCookieName))
authUser.TwoFactorSecret = database.EncryptedString(twoFactor.key.Secret())
authUser.TwoFactorRecovery = string(h)
- authUser.DoSave()
- database.CreateSecurityLog(authUser.ID, database.TotpEnabledSecurityLog)
+ authUser.DoSave(db)
+ db.CreateSecurityLog(authUser.ID, database.TotpEnabledSecurityLog)
return c.Render(http.StatusOK, "flash", FlashResponse{"Two-factor authentication enabled", "/", "alert-success"})
}
key, _ := totp.Generate(totp.GenerateOpts{Issuer: "DarkForest", AccountName: authUser.Username})
@@ -3561,15 +3631,16 @@ func TwoFactorAuthenticationDisableHandler(c echo.Context) error {
return c.Render(http.StatusOK, "disable-totp", data)
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
password := c.Request().PostFormValue("password")
- if !authUser.CheckPassword(password) {
+ if !authUser.CheckPassword(db, password) {
data.ErrorPassword = "Invalid password"
return c.Render(http.StatusOK, "disable-totp", data)
}
authUser.TwoFactorSecret = ""
authUser.TwoFactorRecovery = ""
- authUser.DoSave()
- database.CreateSecurityLog(authUser.ID, database.TotpDisabledSecurityLog)
+ authUser.DoSave(db)
+ db.CreateSecurityLog(authUser.ID, database.TotpDisabledSecurityLog)
return c.Render(http.StatusOK, "flash", FlashResponse{"Two-factor authentication disabled", "/settings/account", "alert-success"})
}
@@ -3694,6 +3765,7 @@ var flagValidationCache = cache.NewWithKey[database.UserID, bool](time.Minute, t
func VipDownloadsHandler(c echo.Context) error {
const flagHash = "fefc9d5db52b51aeefd4b098f0178a8bcb7f0816dcadaf1714604f01ef63a621"
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data vipDownloadsHandlerData
data.ActiveTab = "home"
data.Files = getDownloadsFiles()
@@ -3709,7 +3781,7 @@ func VipDownloadsHandler(c echo.Context) error {
}
if utils.Sha256([]byte(flag)) == flagHash {
data.FlagMessage = "You found the flag!"
- _ = database.CreateUserBadge(authUser.ID, 1)
+ _ = db.CreateUserBadge(authUser.ID, 1)
} else {
data.FlagMessage = "Invalid flag"
}
@@ -3725,6 +3797,7 @@ func downloadFile(c echo.Context, folder, redirect string) error {
}
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if authUser == nil {
return c.Redirect(http.StatusFound, "/login?redirect="+redirect)
}
@@ -3737,7 +3810,7 @@ func downloadFile(c echo.Context, folder, redirect string) error {
}
// Keep track of user downloads
- if _, err := database.CreateDownload(authUser.ID, filename); err != nil {
+ if _, err := db.CreateDownload(authUser.ID, filename); err != nil {
logrus.Error(err)
}
@@ -3758,6 +3831,7 @@ func VipDownloadFileHandler(c echo.Context) error {
func CaptchaRequiredHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data captchaRequiredData
data.CaptchaDescription = "Captcha required"
@@ -3778,7 +3852,7 @@ func CaptchaRequiredHandler(c echo.Context) error {
}
config.CaptchaRequiredSuccess.Inc()
authUser.CaptchaRequired = false
- authUser.DoSave()
+ authUser.DoSave(db)
return c.Redirect(http.StatusFound, "/chat")
}
@@ -3819,20 +3893,22 @@ func CaptchaHandler(c echo.Context) error {
func PublicUserProfileHandler(c echo.Context) error {
username := c.Param("username")
- user, err := database.GetUserByUsername(username)
+ db := c.Get("database").(*database.DkfDB)
+ user, err := db.GetUserByUsername(username)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
var data publicProfileData
data.User = user
data.UserStyle = user.GenerateChatStyle()
- data.PublicNotes, _ = database.GetUserPublicNotes(user.ID)
+ data.PublicNotes, _ = db.GetUserPublicNotes(user.ID)
return c.Render(http.StatusOK, "public-profile", data)
}
func PublicUserProfilePGPHandler(c echo.Context) error {
username := c.Param("username")
- user, err := database.GetUserByUsername(username)
+ db := c.Get("database").(*database.DkfDB)
+ user, err := db.GetUserByUsername(username)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -3877,8 +3953,9 @@ func isAttachmentMimeType(mimeType string) bool {
func UploadsDownloadHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
filename := c.Param("filename")
- file, err := database.GetUploadByFileName(filename)
+ file, err := db.GetUploadByFileName(filename)
if err != nil {
return c.Render(http.StatusOK, "standalone.upload404", nil)
}
@@ -3914,7 +3991,7 @@ func UploadsDownloadHandler(c echo.Context) error {
// MimeType that always trigger a file "download"
if isAttachmentMimeType(mimeType) {
// Keep track of user downloads
- if _, err := database.CreateDownload(authUser.ID, filename); err != nil {
+ if _, err := db.CreateDownload(authUser.ID, filename); err != nil {
logrus.Error(err)
}
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", "attachment", file.OrigFileName))
@@ -3929,7 +4006,7 @@ func UploadsDownloadHandler(c echo.Context) error {
return nil
}
- userNbDownloaded := database.UserNbDownloaded(authUser.ID, filename)
+ userNbDownloaded := db.UserNbDownloaded(authUser.ID, filename)
// Display captcha to new users, or old users if they already downloaded the file.
if !authUser.AccountOldEnough() || userNbDownloaded >= 1 {
@@ -3955,7 +4032,7 @@ func UploadsDownloadHandler(c echo.Context) error {
}
// Keep track of user downloads
- if _, err := database.CreateDownload(authUser.ID, filename); err != nil {
+ if _, err := db.CreateDownload(authUser.ID, filename); err != nil {
logrus.Error(err)
}
@@ -3972,7 +4049,8 @@ func UploadsDownloadHandler(c echo.Context) error {
func FiledropDownloadHandler(c echo.Context) error {
filename := c.Param("filename")
- filedrop, err := database.GetFiledropByFileName(filename)
+ db := c.Get("database").(*database.DkfDB)
+ filedrop, err := db.GetFiledropByFileName(filename)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -4023,6 +4101,7 @@ type ByteRoadPayload struct {
func ByteRoadChallengeHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
const byteRoadChallengeTmplName = "vip.byte-road-challenge"
var data byteRoadChallengeData
data.ActiveTab = "home"
@@ -4103,7 +4182,7 @@ func ByteRoadChallengeHandler(c echo.Context) error {
_ = byteRoadUsersCountCache.Update(authUser.ID, payload)
if payload.Count >= 100 {
data.FlagFound = true
- _ = database.CreateUserBadge(authUser.ID, 2)
+ _ = db.CreateUserBadge(authUser.ID, 2)
}
return c.Render(http.StatusOK, byteRoadChallengeTmplName, data)
}
@@ -4166,12 +4245,13 @@ func BHCHandler(c echo.Context) error {
func FileDropHandler(c echo.Context) error {
const filedropTmplName = "standalone.filedrop"
uuidParam := c.Param("uuid")
+ db := c.Get("database").(*database.DkfDB)
//if c.Request().ContentLength > config.MaxUserFileUploadSize {
// data.Error = fmt.Sprintf("The maximum file size is %s", humanize.Bytes(config.MaxUserFileUploadSize))
// return c.Render(http.StatusOK, "chat-top-bar", data)
//}
- filedrop, err := database.GetFiledropByUUID(uuidParam)
+ filedrop, err := db.GetFiledropByUUID(uuidParam)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -4227,17 +4307,18 @@ func FileDropHandler(c echo.Context) error {
filedrop.IV = encrypter.Meta().IV
filedrop.OrigFileName = origFileName
filedrop.FileSize = written
- filedrop.DoSave()
+ filedrop.DoSave(db)
data.Success = "File uploaded successfully"
return c.Render(http.StatusOK, filedropTmplName, data)
}
func FileDropDkfUploadHandler(c echo.Context) error {
+ db := c.Get("database").(*database.DkfDB)
// Init
if c.Request().PostFormValue("init") != "" {
filedropUUID := c.Param("uuid")
- _, err := database.GetFiledropByUUID(filedropUUID)
+ _, err := db.GetFiledropByUUID(filedropUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -4277,7 +4358,7 @@ func FileDropDkfUploadHandler(c echo.Context) error {
if c.Request().PostFormValue("completed") != "" {
filedropUUID := c.Param("uuid")
- filedrop, err := database.GetFiledropByUUID(filedropUUID)
+ filedrop, err := db.GetFiledropByUUID(filedropUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -4356,7 +4437,7 @@ func FileDropDkfUploadHandler(c echo.Context) error {
filedrop.IV = iv
filedrop.OrigFileName = origFileName
filedrop.FileSize = written
- filedrop.DoSave()
+ filedrop.DoSave(db)
return c.NoContent(http.StatusOK)
}
@@ -4374,7 +4455,7 @@ func FileDropDkfUploadHandler(c echo.Context) error {
}
}
- _, err := database.GetFiledropByUUID(filedropUUID)
+ _, err := db.GetFiledropByUUID(filedropUUID)
if err != nil {
return c.Redirect(http.StatusFound, "/")
}
@@ -4397,7 +4478,8 @@ func FileDropDkfUploadHandler(c echo.Context) error {
func FileDropDkfDownloadHandler(c echo.Context) error {
filedropUUID := c.Param("uuid")
- filedrop, err := database.GetFiledropByUUID(filedropUUID)
+ db := c.Get("database").(*database.DkfDB)
+ filedrop, err := db.GetFiledropByUUID(filedropUUID)
if err != nil {
return c.NoContent(http.StatusNotFound)
}
@@ -4456,6 +4538,7 @@ func FileDropDkfDownloadHandler(c echo.Context) error {
func FileDropDownloadHandler(c echo.Context) error {
authUser, ok := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if !ok {
return c.Redirect(http.StatusFound, "/")
}
@@ -4466,7 +4549,7 @@ func FileDropDownloadHandler(c echo.Context) error {
return c.Redirect(http.StatusFound, "/")
}
- userNbDownloaded := database.UserNbDownloaded(authUser.ID, fileName)
+ userNbDownloaded := db.UserNbDownloaded(authUser.ID, fileName)
// Display captcha to new users, or old users if they already downloaded the file.
if !authUser.AccountOldEnough() || userNbDownloaded >= 1 {
@@ -4492,7 +4575,7 @@ func FileDropDownloadHandler(c echo.Context) error {
}
// Keep track of user downloads
- if _, err := database.CreateDownload(authUser.ID, fileName); err != nil {
+ if _, err := db.CreateDownload(authUser.ID, fileName); err != nil {
logrus.Error(err)
}
@@ -4513,6 +4596,7 @@ func FileDropDownloadHandler(c echo.Context) error {
func Stego1ChallengeHandler(c echo.Context) error {
const flagHash = "05b456689a9f8de69416d21cbb97157588b8491d07551167a95b93a1c7d61e7b"
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data stego1RoadChallengeData
data.ActiveTab = "home"
@@ -4528,7 +4612,7 @@ func Stego1ChallengeHandler(c echo.Context) error {
}
if utils.Sha256([]byte(flag)) == flagHash {
data.FlagMessage = "You found the flag!"
- _ = database.CreateUserBadge(authUser.ID, 3)
+ _ = db.CreateUserBadge(authUser.ID, 3)
} else {
data.FlagMessage = "Invalid flag"
}
@@ -4540,12 +4624,13 @@ func Stego1ChallengeHandler(c echo.Context) error {
func ChessHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
var data chessData
data.Games = v1.ChessInstance.GetGames()
if c.Request().Method == http.MethodPost {
data.Username = c.Request().PostFormValue("username")
- player2, err := database.GetUserByUsername(data.Username)
+ player2, err := db.GetUserByUsername(data.Username)
if err != nil {
data.Error = "invalid username"
return c.Render(http.StatusOK, "chess", data)
@@ -4612,13 +4697,14 @@ html, body {
func ChessGameHandler(c echo.Context) error {
authUser := c.Get("authUser").(*database.User)
+ //db := c.Get("database").(*database.DkfDB)
key := c.Param("key")
g := v1.ChessInstance.GetGame(key)
if g == nil {
// Chess debug
- //user1, _ := database.GetUserByID(1)
- //user2, _ := database.GetUserByID(24132)
+ //user1, _ := db.GetUserByID(1)
+ //user2, _ := db.GetUserByID(24132)
//v1.ChessInstance.NewGame(key, user1, user2)
//g = v1.ChessInstance.GetGame(key)
return c.Redirect(http.StatusFound, "/")
diff --git a/pkg/web/middlewares/middlewares.go b/pkg/web/middlewares/middlewares.go
@@ -213,6 +213,15 @@ func I18nMiddleware(bundle *i18n.Bundle, defaultLang string) echo.MiddlewareFunc
}
}
+func SetDatabaseMiddleware(db *database.DkfDB) echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(ctx echo.Context) error {
+ ctx.Set("database", db)
+ return next(ctx)
+ }
+ }
+}
+
// SetUserMiddleware Get user and put it into echo context.
// - Get auth-token from cookie
// - If exists, get user from database
@@ -220,18 +229,19 @@ func I18nMiddleware(bundle *i18n.Bundle, defaultLang string) echo.MiddlewareFunc
// - Otherwise, empty user will be put in context
func SetUserMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(ctx echo.Context) error {
+ db := ctx.Get("database").(*database.DkfDB)
var nilUser *database.User
var user database.User
if apiKey := ctx.Request().Header.Get("DKF_API_KEY"); apiKey != "" {
// Login using DKF_API_KEY
- if err := database.GetUserByApiKey(&user, apiKey); err == nil {
+ if err := db.GetUserByApiKey(&user, apiKey); err == nil {
ctx.Set("authUser", &user)
return next(ctx)
}
} else if authCookie, err := ctx.Cookie(hutils.AuthCookieName); err == nil {
// Login using auth cookie
- if err := database.GetUserBySessionKey(&user, authCookie.Value); err == nil {
+ if err := db.GetUserBySessionKey(&user, authCookie.Value); err == nil {
ctx.Set("authUser", &user)
return next(ctx)
}
@@ -248,6 +258,7 @@ func SetUserMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
func IsAuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user := c.Get("authUser").(*database.User)
+ db := c.Get("database").(*database.DkfDB)
if user == nil {
if strings.HasPrefix(c.Path(), "/api/") {
return c.String(http.StatusUnauthorized, "unauthorized")
@@ -258,7 +269,7 @@ func IsAuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
c.Response().Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
user.LastSeenAt = time.Now()
- user.DoSave()
+ user.DoSave(db)
// Prevent clickjacking by setting the header on every logged in page
if !strings.Contains(c.Path(), "/api/v1/chat/messages") &&
diff --git a/pkg/web/public/views/pages/club/thread.gohtml b/pkg/web/public/views/pages/club/thread.gohtml
@@ -42,7 +42,7 @@
</form>
{{ end }}
</div>
- {{ .Escape | safe }}
+ {{ .Escape $.DB | safe }}
</td>
</tr>
</table>
diff --git a/pkg/web/public/views/pages/thread.gohtml b/pkg/web/public/views/pages/thread.gohtml
@@ -88,7 +88,7 @@
</div>
</div>
<div style="padding: 5px 5px 10px 10px;">
- {{ .Escape | safe }}
+ {{ .Escape $.DB | safe }}
</div>
</div>
{{ end }}
diff --git a/pkg/web/web.go b/pkg/web/web.go
@@ -2,6 +2,7 @@ package web
import (
"context"
+ "dkforest/pkg/database"
"fmt"
"github.com/ulule/limiter"
"github.com/ulule/limiter/drivers/store/memory"
@@ -27,7 +28,7 @@ import (
yaml "gopkg.in/yaml.v1"
)
-func getMainServer(i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.HandlerFunc {
+func getMainServer(db *database.DkfDB, i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.HandlerFunc {
e := echo.New()
e.HideBanner = true
e.HidePort = true
@@ -39,6 +40,7 @@ func getMainServer(i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.Handle
e.Use(staticbin.Static(bindata.Asset, staticbin.Options{Dir: "/public", SkipLogging: true}))
e.Renderer = renderer
+ e.Use(middlewares.SetDatabaseMiddleware(db))
e.Use(middlewares.FirstUseMiddleware)
e.Use(middlewares.DdosMiddleware)
e.Use(middlewares.MaintenanceMiddleware)
@@ -290,7 +292,7 @@ func getMainServer(i18nBundle *i18n.Bundle, renderer *tmp.Templates) echo.Handle
}
}
-func getBaseServer() *echo.Echo {
+func getBaseServer(db *database.DkfDB) *echo.Echo {
e := echo.New()
renderer := tmp.GetRenderer(e)
i18nBundle := getI18nBundle()
@@ -299,31 +301,32 @@ func getBaseServer() *echo.Echo {
e.Debug = true
e.Renderer = renderer
e.Use(middlewares.SetUselessHeadersMiddleware)
+ e.Use(middlewares.SetDatabaseMiddleware(db))
e.Use(middlewares.SetUserMiddleware)
e.Use(middlewares.I18nMiddleware(i18nBundle, "en"))
e.GET("/file-drop/:uuid", handlers.FileDropHandler)
e.POST("/file-drop/:uuid", handlers.FileDropHandler)
e.POST("/file-drop/:uuid/dkfupload", handlers.FileDropDkfUploadHandler)
- e.POST("/api/v1/file-drop/:uuid/dkfdownload", handlers.FileDropDkfDownloadHandler, middlewares.SetUserMiddleware, middlewares.IsAuthMiddleware)
+ e.POST("/api/v1/file-drop/:uuid/dkfdownload", handlers.FileDropDkfDownloadHandler, middlewares.IsAuthMiddleware)
e.GET("/downloads/:fileName", handlers.FileDropDownloadHandler)
e.POST("/downloads/:fileName", handlers.FileDropDownloadHandler)
- e.Any("*", getMainServer(i18nBundle, renderer))
+ e.Any("*", getMainServer(db, i18nBundle, renderer))
return e
}
-func startI2pServer(host string, port int) {
+func startI2pServer(db *database.DkfDB, host string, port int) {
if config.Development.IsTrue() {
return
}
address := host + ":" + strconv.Itoa(port)
- e := getBaseServer()
+ e := getBaseServer(db)
logrus.Info("start i2p server on " + address)
startServer(e, address)
}
-func startTorServer(host string, port int) {
+func startTorServer(db *database.DkfDB, host string, port int) {
address := host + ":" + strconv.Itoa(port)
- e := getBaseServer()
+ e := getBaseServer(db)
configTorProdServer(e)
logrus.Info("start tor server on " + address)
startServer(e, address)
@@ -338,11 +341,11 @@ func startServer(e *echo.Echo, address string) {
}
// Start ...
-func Start(host string, port int) {
+func Start(db *database.DkfDB, host string, port int) {
// Start server for I2P
- go startI2pServer(host, port+1)
+ go startI2pServer(db, host, port+1)
// Server for Tor/dev
- startTorServer(host, port)
+ startTorServer(db, host, port)
}
func extractGlobalCircuitIdentifier(m string) int64 {