dkforest

A forum and chat platform (onion)
git clone https://git.dasho.dev/n0tr1v/dkforest.git
Log | Files | Refs | LICENSE

usersStreamsManager.go (2047B)


      1 package usersStreamsManager
      2 
      3 import (
      4 	"dkforest/pkg/database"
      5 	"errors"
      6 	"sync"
      7 )
      8 
      9 const userMaxStream = 15
     10 
     11 var ErrTooManyStreams = errors.New("too many streams")
     12 
     13 type UserStreamsMap map[string]int64
     14 
     15 func (m *UserStreamsMap) count() (out int64) {
     16 	for _, v := range *m {
     17 		out += v
     18 	}
     19 	return
     20 }
     21 
     22 // UsersStreamsManager ensure that a user doesn't have more than userMaxStream
     23 // http long polling streams open at the same time.
     24 // If the limit is reached, the pages will then refuse to load.
     25 // This is to prevent a malicious user from opening unlimited amount of streams and wasting the server resources.
     26 type UsersStreamsManager struct {
     27 	sync.RWMutex
     28 	m map[database.UserID]UserStreamsMap
     29 }
     30 
     31 func NewUsersStreamsManager() *UsersStreamsManager {
     32 	return &UsersStreamsManager{m: make(map[database.UserID]UserStreamsMap)}
     33 }
     34 
     35 type Item struct {
     36 	m      *UsersStreamsManager
     37 	userID database.UserID
     38 	key    string
     39 }
     40 
     41 func (i *Item) Cleanup() {
     42 	i.m.Remove(i.userID, i.key)
     43 }
     44 
     45 func (m *UsersStreamsManager) Add(userID database.UserID, key string) (*Item, error) {
     46 	m.Lock()
     47 	defer m.Unlock()
     48 	userMap, found := m.m[userID]
     49 	if found && userMap.count() >= userMaxStream {
     50 		return nil, ErrTooManyStreams
     51 	}
     52 	if !found {
     53 		userMap = make(UserStreamsMap)
     54 	}
     55 	userMap[key]++
     56 	m.m[userID] = userMap
     57 	return &Item{m: m, userID: userID, key: key}, nil
     58 }
     59 
     60 func (m *UsersStreamsManager) Remove(userID database.UserID, key string) {
     61 	m.Lock()
     62 	defer m.Unlock()
     63 	if userMap, found := m.m[userID]; found {
     64 		userMap[key]--
     65 		m.m[userID] = userMap
     66 	}
     67 }
     68 
     69 func (m *UsersStreamsManager) GetUserStreamsCountFor(userID database.UserID, key string) (out int64) {
     70 	m.RLock()
     71 	defer m.RUnlock()
     72 	if userMap, found := m.m[userID]; found {
     73 		if nbStreams, found1 := userMap[key]; found1 {
     74 			return nbStreams
     75 		}
     76 	}
     77 	return
     78 }
     79 
     80 func (m *UsersStreamsManager) GetUsers() (out []database.UserID) {
     81 	m.RLock()
     82 	defer m.RUnlock()
     83 	for userID, userMap := range m.m {
     84 		if userMap.count() > 0 {
     85 			out = append(out, userID)
     86 		}
     87 	}
     88 	return
     89 }
     90 
     91 var Inst = NewUsersStreamsManager()