dkforest

A forum and chat platform (onion)
git clone https://git.dasho.dev/n0tr1v/dkforest.git
Log | Files | Refs | LICENSE

csrf.go (6017B)


      1 package middlewares
      2 
      3 import (
      4 	"crypto/subtle"
      5 	"errors"
      6 	"net/http"
      7 	"strings"
      8 	"time"
      9 
     10 	"github.com/labstack/echo"
     11 	"github.com/labstack/echo/middleware"
     12 	"github.com/labstack/gommon/random"
     13 )
     14 
     15 type (
     16 	// CSRFConfig defines the config for CSRF middleware.
     17 	CSRFConfig struct {
     18 		// Skipper defines a function to skip middleware.
     19 		Skipper middleware.Skipper
     20 
     21 		// TokenLength is the length of the generated token.
     22 		TokenLength uint8 `yaml:"token_length"`
     23 		// Optional. Default value 32.
     24 
     25 		// TokenLookup is a string in the form of "<source>:<key>" that is used
     26 		// to extract token from the request.
     27 		// Optional. Default value "header:X-CSRF-Token".
     28 		// Possible values:
     29 		// - "header:<name>"
     30 		// - "form:<name>"
     31 		// - "query:<name>"
     32 		TokenLookup string `yaml:"token_lookup"`
     33 
     34 		// Context key to store generated CSRF token into context.
     35 		// Optional. Default value "csrf".
     36 		ContextKey string `yaml:"context_key"`
     37 
     38 		// Name of the CSRF cookie. This cookie will store CSRF token.
     39 		// Optional. Default value "csrf".
     40 		CookieName string `yaml:"cookie_name"`
     41 
     42 		// Domain of the CSRF cookie.
     43 		// Optional. Default value none.
     44 		CookieDomain string `yaml:"cookie_domain"`
     45 
     46 		// Path of the CSRF cookie.
     47 		// Optional. Default value none.
     48 		CookiePath string `yaml:"cookie_path"`
     49 
     50 		// Max age (in seconds) of the CSRF cookie.
     51 		// Optional. Default value 86400 (24hr).
     52 		CookieMaxAge int64 `yaml:"cookie_max_age"`
     53 
     54 		// Indicates if CSRF cookie is secure.
     55 		// Optional. Default value false.
     56 		CookieSecure bool `yaml:"cookie_secure"`
     57 
     58 		SameSite http.SameSite `yaml:"cookie_same_site"`
     59 
     60 		// Indicates if CSRF cookie is HTTP only.
     61 		// Optional. Default value false.
     62 		CookieHTTPOnly bool `yaml:"cookie_http_only"`
     63 	}
     64 
     65 	// csrfTokenExtractor defines a function that takes `echo.Context` and returns
     66 	// either a token or an error.
     67 	csrfTokenExtractor func(echo.Context) (string, error)
     68 )
     69 
     70 var (
     71 	// DefaultCSRFConfig is the default CSRF middleware config.
     72 	DefaultCSRFConfig = CSRFConfig{
     73 		Skipper:      middleware.DefaultSkipper,
     74 		TokenLength:  32,
     75 		TokenLookup:  "header:" + echo.HeaderXCSRFToken,
     76 		ContextKey:   "csrf",
     77 		CookieName:   "_csrf",
     78 		CookieMaxAge: 86400,
     79 	}
     80 )
     81 
     82 // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
     83 // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
     84 func CSRF() echo.MiddlewareFunc {
     85 	c := DefaultCSRFConfig
     86 	return CSRFWithConfig(c)
     87 }
     88 
     89 // CSRFWithConfig returns a CSRF middleware with config.
     90 // See `CSRF()`.
     91 func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
     92 	// Defaults
     93 	if config.Skipper == nil {
     94 		config.Skipper = DefaultCSRFConfig.Skipper
     95 	}
     96 	if config.TokenLength == 0 {
     97 		config.TokenLength = DefaultCSRFConfig.TokenLength
     98 	}
     99 	if config.TokenLookup == "" {
    100 		config.TokenLookup = DefaultCSRFConfig.TokenLookup
    101 	}
    102 	if config.ContextKey == "" {
    103 		config.ContextKey = DefaultCSRFConfig.ContextKey
    104 	}
    105 	if config.CookieName == "" {
    106 		config.CookieName = DefaultCSRFConfig.CookieName
    107 	}
    108 	if config.CookieMaxAge == 0 {
    109 		config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
    110 	}
    111 
    112 	// Initialize
    113 	parts := strings.Split(config.TokenLookup, ":")
    114 	extractor := csrfTokenFromHeader(parts[1])
    115 	switch parts[0] {
    116 	case "form":
    117 		extractor = csrfTokenFromForm(parts[1])
    118 	case "query":
    119 		extractor = csrfTokenFromQuery(parts[1])
    120 	}
    121 
    122 	return func(next echo.HandlerFunc) echo.HandlerFunc {
    123 		return func(c echo.Context) error {
    124 			req := c.Request()
    125 			k, err := c.Cookie(config.CookieName)
    126 			token := ""
    127 
    128 			// Generate token
    129 			if err != nil {
    130 				token = random.String(config.TokenLength)
    131 
    132 				// Set CSRF cookie
    133 				cookie := new(http.Cookie)
    134 				cookie.Name = config.CookieName
    135 				cookie.Value = token
    136 				if config.CookiePath != "" {
    137 					cookie.Path = config.CookiePath
    138 				}
    139 				if config.CookieDomain != "" {
    140 					cookie.Domain = config.CookieDomain
    141 				}
    142 				cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
    143 				cookie.Secure = config.CookieSecure
    144 				cookie.SameSite = config.SameSite
    145 				cookie.HttpOnly = config.CookieHTTPOnly
    146 				c.SetCookie(cookie)
    147 			} else {
    148 				// Reuse token
    149 				token = k.Value
    150 			}
    151 
    152 			// Store token in the context
    153 			c.Set(config.ContextKey, token)
    154 
    155 			// Protect clients from caching the response
    156 			c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
    157 
    158 			if config.Skipper(c) {
    159 				return next(c)
    160 			}
    161 
    162 			switch req.Method {
    163 			case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
    164 			default:
    165 				// Validate token only for requests which are not defined as 'safe' by RFC7231
    166 				clientToken, err := extractor(c)
    167 				if err != nil {
    168 					return echo.NewHTTPError(http.StatusBadRequest, err.Error())
    169 				}
    170 				if !validateCSRFToken(token, clientToken) {
    171 					return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token, please try to resubmit the form")
    172 				}
    173 			}
    174 
    175 			return next(c)
    176 		}
    177 	}
    178 }
    179 
    180 // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
    181 // provided request header.
    182 func csrfTokenFromHeader(header string) csrfTokenExtractor {
    183 	return func(c echo.Context) (string, error) {
    184 		return c.Request().Header.Get(header), nil
    185 	}
    186 }
    187 
    188 // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
    189 // provided form parameter.
    190 func csrfTokenFromForm(param string) csrfTokenExtractor {
    191 	return func(c echo.Context) (string, error) {
    192 		token := c.FormValue(param)
    193 		if token == "" {
    194 			return "", errors.New("missing csrf token in the form parameter")
    195 		}
    196 		return token, nil
    197 	}
    198 }
    199 
    200 // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
    201 // provided query parameter.
    202 func csrfTokenFromQuery(param string) csrfTokenExtractor {
    203 	return func(c echo.Context) (string, error) {
    204 		token := c.QueryParam(param)
    205 		if token == "" {
    206 			return "", errors.New("missing csrf token in the query string")
    207 		}
    208 		return token, nil
    209 	}
    210 }
    211 
    212 func validateCSRFToken(token, clientToken string) bool {
    213 	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
    214 }