csrf.go (6017B)
1 package middlewares 2 3 import ( 4 "crypto/subtle" 5 "errors" 6 "net/http" 7 "strings" 8 "time" 9 10 "github.com/labstack/echo" 11 "github.com/labstack/echo/middleware" 12 "github.com/labstack/gommon/random" 13 ) 14 15 type ( 16 // CSRFConfig defines the config for CSRF middleware. 17 CSRFConfig struct { 18 // Skipper defines a function to skip middleware. 19 Skipper middleware.Skipper 20 21 // TokenLength is the length of the generated token. 22 TokenLength uint8 `yaml:"token_length"` 23 // Optional. Default value 32. 24 25 // TokenLookup is a string in the form of "<source>:<key>" that is used 26 // to extract token from the request. 27 // Optional. Default value "header:X-CSRF-Token". 28 // Possible values: 29 // - "header:<name>" 30 // - "form:<name>" 31 // - "query:<name>" 32 TokenLookup string `yaml:"token_lookup"` 33 34 // Context key to store generated CSRF token into context. 35 // Optional. Default value "csrf". 36 ContextKey string `yaml:"context_key"` 37 38 // Name of the CSRF cookie. This cookie will store CSRF token. 39 // Optional. Default value "csrf". 40 CookieName string `yaml:"cookie_name"` 41 42 // Domain of the CSRF cookie. 43 // Optional. Default value none. 44 CookieDomain string `yaml:"cookie_domain"` 45 46 // Path of the CSRF cookie. 47 // Optional. Default value none. 48 CookiePath string `yaml:"cookie_path"` 49 50 // Max age (in seconds) of the CSRF cookie. 51 // Optional. Default value 86400 (24hr). 52 CookieMaxAge int64 `yaml:"cookie_max_age"` 53 54 // Indicates if CSRF cookie is secure. 55 // Optional. Default value false. 56 CookieSecure bool `yaml:"cookie_secure"` 57 58 SameSite http.SameSite `yaml:"cookie_same_site"` 59 60 // Indicates if CSRF cookie is HTTP only. 61 // Optional. Default value false. 62 CookieHTTPOnly bool `yaml:"cookie_http_only"` 63 } 64 65 // csrfTokenExtractor defines a function that takes `echo.Context` and returns 66 // either a token or an error. 67 csrfTokenExtractor func(echo.Context) (string, error) 68 ) 69 70 var ( 71 // DefaultCSRFConfig is the default CSRF middleware config. 72 DefaultCSRFConfig = CSRFConfig{ 73 Skipper: middleware.DefaultSkipper, 74 TokenLength: 32, 75 TokenLookup: "header:" + echo.HeaderXCSRFToken, 76 ContextKey: "csrf", 77 CookieName: "_csrf", 78 CookieMaxAge: 86400, 79 } 80 ) 81 82 // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. 83 // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery 84 func CSRF() echo.MiddlewareFunc { 85 c := DefaultCSRFConfig 86 return CSRFWithConfig(c) 87 } 88 89 // CSRFWithConfig returns a CSRF middleware with config. 90 // See `CSRF()`. 91 func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { 92 // Defaults 93 if config.Skipper == nil { 94 config.Skipper = DefaultCSRFConfig.Skipper 95 } 96 if config.TokenLength == 0 { 97 config.TokenLength = DefaultCSRFConfig.TokenLength 98 } 99 if config.TokenLookup == "" { 100 config.TokenLookup = DefaultCSRFConfig.TokenLookup 101 } 102 if config.ContextKey == "" { 103 config.ContextKey = DefaultCSRFConfig.ContextKey 104 } 105 if config.CookieName == "" { 106 config.CookieName = DefaultCSRFConfig.CookieName 107 } 108 if config.CookieMaxAge == 0 { 109 config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge 110 } 111 112 // Initialize 113 parts := strings.Split(config.TokenLookup, ":") 114 extractor := csrfTokenFromHeader(parts[1]) 115 switch parts[0] { 116 case "form": 117 extractor = csrfTokenFromForm(parts[1]) 118 case "query": 119 extractor = csrfTokenFromQuery(parts[1]) 120 } 121 122 return func(next echo.HandlerFunc) echo.HandlerFunc { 123 return func(c echo.Context) error { 124 req := c.Request() 125 k, err := c.Cookie(config.CookieName) 126 token := "" 127 128 // Generate token 129 if err != nil { 130 token = random.String(config.TokenLength) 131 132 // Set CSRF cookie 133 cookie := new(http.Cookie) 134 cookie.Name = config.CookieName 135 cookie.Value = token 136 if config.CookiePath != "" { 137 cookie.Path = config.CookiePath 138 } 139 if config.CookieDomain != "" { 140 cookie.Domain = config.CookieDomain 141 } 142 cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) 143 cookie.Secure = config.CookieSecure 144 cookie.SameSite = config.SameSite 145 cookie.HttpOnly = config.CookieHTTPOnly 146 c.SetCookie(cookie) 147 } else { 148 // Reuse token 149 token = k.Value 150 } 151 152 // Store token in the context 153 c.Set(config.ContextKey, token) 154 155 // Protect clients from caching the response 156 c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) 157 158 if config.Skipper(c) { 159 return next(c) 160 } 161 162 switch req.Method { 163 case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: 164 default: 165 // Validate token only for requests which are not defined as 'safe' by RFC7231 166 clientToken, err := extractor(c) 167 if err != nil { 168 return echo.NewHTTPError(http.StatusBadRequest, err.Error()) 169 } 170 if !validateCSRFToken(token, clientToken) { 171 return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token, please try to resubmit the form") 172 } 173 } 174 175 return next(c) 176 } 177 } 178 } 179 180 // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the 181 // provided request header. 182 func csrfTokenFromHeader(header string) csrfTokenExtractor { 183 return func(c echo.Context) (string, error) { 184 return c.Request().Header.Get(header), nil 185 } 186 } 187 188 // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the 189 // provided form parameter. 190 func csrfTokenFromForm(param string) csrfTokenExtractor { 191 return func(c echo.Context) (string, error) { 192 token := c.FormValue(param) 193 if token == "" { 194 return "", errors.New("missing csrf token in the form parameter") 195 } 196 return token, nil 197 } 198 } 199 200 // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the 201 // provided query parameter. 202 func csrfTokenFromQuery(param string) csrfTokenExtractor { 203 return func(c echo.Context) (string, error) { 204 token := c.QueryParam(param) 205 if token == "" { 206 return "", errors.New("missing csrf token in the query string") 207 } 208 return token, nil 209 } 210 } 211 212 func validateCSRFToken(token, clientToken string) bool { 213 return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 214 }