commit 2601958547eb0535e0b84140ca373c6bb3537774
parent 050ae0756f75da6590184e1b2fbfb1ca353126e7
Author: n0tr1v <n0tr1v@protonmail.com>
Date: Sat, 30 Dec 2023 14:26:43 -0500
improve rate limiter
Diffstat:
1 file changed, 30 insertions(+), 18 deletions(-)
diff --git a/pkg/web/middlewares/middlewares.go b/pkg/web/middlewares/middlewares.go
@@ -3,25 +3,22 @@ package middlewares
import (
"dkforest/bindata"
"dkforest/pkg/cache"
- "dkforest/pkg/web/clientFrontends"
- hutils "dkforest/pkg/web/handlers/utils"
- "net"
- "net/http"
- "strings"
- "time"
-
- "dkforest/pkg/web/handlers"
-
- "github.com/labstack/echo/middleware"
-
"dkforest/pkg/captcha"
"dkforest/pkg/config"
"dkforest/pkg/database"
"dkforest/pkg/utils"
+ "dkforest/pkg/web/clientFrontends"
+ "dkforest/pkg/web/handlers"
+ hutils "dkforest/pkg/web/handlers/utils"
"github.com/labstack/echo"
+ "github.com/labstack/echo/middleware"
"github.com/nicksnyder/go-i18n/v2/i18n"
"github.com/ulule/limiter"
"github.com/ulule/limiter/drivers/store/memory"
+ "net"
+ "net/http"
+ "strings"
+ "time"
)
// GzipMiddleware ...
@@ -274,21 +271,36 @@ func SetUserMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
}
}
-type RateLimit[K comparable] struct {
- cache *cache.Cache[K, struct{}]
+type RateLimit[K comparable, V any] struct {
+ cache *cache.Cache[K, V]
+ value V
}
-func NewRateLimit[K comparable](defaultExpiration time.Duration) *RateLimit[K] {
- return &RateLimit[K]{
- cache: cache.NewWithKey[K, struct{}](defaultExpiration, time.Minute),
+func NewRateLimit[K comparable](defaultExpiration time.Duration) *RateLimit[K, struct{}] {
+ return NewRateLimitV[K, struct{}](defaultExpiration)
+}
+
+func NewRateLimitV[K comparable, V any](defaultExpiration time.Duration) *RateLimit[K, V] {
+ return &RateLimit[K, V]{
+ cache: cache.NewWithKey[K, V](defaultExpiration, time.Minute),
}
}
-func (l *RateLimit[K]) Clb(k K, clb func()) {
+func (l *RateLimit[K, V]) Clb(k K, clb func()) {
if !l.cache.Has(k) {
clb()
- l.cache.SetD(k, struct{}{})
+ l.cache.SetD(k, l.value)
+ }
+}
+
+func (l *RateLimit[K, V]) ClbV(k K, clb func() (V, error)) (V, bool, error) {
+ var err error
+ if !l.cache.Has(k) {
+ l.value, err = clb()
+ l.cache.SetD(k, l.value)
+ return l.value, true, err
}
+ return l.value, false, err
}
var lastSeenRL = NewRateLimit[database.UserID](time.Second)