diff --git a/README.md b/README.md index 7eac2e9..2c405be 100644 --- a/README.md +++ b/README.md @@ -62,8 +62,8 @@ Secure comes with a variety of configuration options (Note: these are not the de // ... s := secure.New(secure.Options{ AllowedHosts: []string{"ssl.example.com"}, // AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names. - AllowedHostsFunc: func() []string { return []string{"example.com", "www.example.com" } // AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. This can be used in combination with the above AllowedHosts. - AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. This does not apply to the `AllowedHostsFunc` values! Default is false. + AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. Default is false. + AllowRequestFunc: nil, // AllowRequestFunc is a custom function type that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil. HostsProxyHeaders: []string{"X-Forwarded-Hosts"}, // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. SSLRedirect: true, // If SSLRedirect is set to true, then only allow HTTPS requests. Default is false. SSLTemporaryRedirect: false, // If SSLTemporaryRedirect is true, the a 302 will be used while redirecting. Default is false (301). @@ -102,8 +102,8 @@ s := secure.New() l := secure.New(secure.Options{ AllowedHosts: []string, - AllowedHostsFunc: nil, AllowedHostsAreRegex: false, + AllowRequestFunc: nil, HostsProxyHeaders: []string, SSLRedirect: false, SSLTemporaryRedirect: false, @@ -127,11 +127,20 @@ l := secure.New(secure.Options{ IsDevelopment: false, }) ~~~ -Also note the default bad host handler returns an error: +The default bad host handler returns the following error: ~~~ go http.Error(w, "Bad Host", http.StatusInternalServerError) ~~~ -Call `secure.SetBadHostHandler` to change the bad host handler. +Call `secure.SetBadHostHandler` to set your own custom handler. + +The default bad request handler returns the following error: +~~~ go +http.Error(w, "Bad Request", http.StatusBadRequest) +~~~ +Call `secure.SetBadRequestHandler` to set your own custom handler. + +### Allow Request Function +Secure allows you to set a custom function (`func(r *http.Request) bool`) for the `AllowRequestFunc` option. You can use this function as a custom filter to allow the request to continue or simply reject it. This can be handy if you need to do any dynamic filtering on any of the request properties. It should be noted that this function will be called on every request, so be sure to make your checks quick and not relying on time consuming external calls (or you will be slowing down all requests). See above on how to set a custom handler for the rejected requests. ### Redirecting HTTP to HTTPS If you want to redirect all HTTP requests to HTTPS, you can use the following example. diff --git a/secure.go b/secure.go index a471150..237280a 100644 --- a/secure.go +++ b/secure.go @@ -36,13 +36,17 @@ const ( // SSLHostFunc is a custom function type that can be used to dynamically set the SSL host of a request. type SSLHostFunc func(host string) (newHost string) -// AllowedHostsFunc is a custom function type that can be used to dynamically return a slice of strings that will be used in the `AllowHosts` check. -type AllowedHostsFunc func() []string +// AllowRequestFunc is a custom function type that can be used to dynamically determine if a request should proceed or not. +type AllowRequestFunc func(r *http.Request) bool func defaultBadHostHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Bad Host", http.StatusInternalServerError) } +func defaultBadRequestHandler(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Bad Request", http.StatusBadRequest) +} + // Options is a struct for specifying configuration options for the secure.Secure middleware. type Options struct { // If BrowserXssFilter is true, adds the X-XSS-Protection header with the value `1; mode=block`. Default is false. @@ -95,10 +99,10 @@ type Options struct { SSLHost string // AllowedHosts is a slice of fully qualified domain names that are allowed. Default is an empty slice, which allows any and all host names. AllowedHosts []string - // AllowedHostsFunc is a custom function that returns a slice of fully qualified domain names that are allowed. If set, values will be used in combination with the above AllowedHosts. Default is nil. - AllowedHostsFunc AllowedHostsFunc - // AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. This does not apply to `AllowedHostsFunc`! If this flag is set to true, every request's host will be checked against these expressions. Default is false. + // AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. If this flag is set to true, every request's host will be checked against these expressions. Default is false. AllowedHostsAreRegex bool + // AllowRequestFunc is a custom function that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil. + AllowRequestFunc AllowRequestFunc // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. HostsProxyHeaders []string // SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil. @@ -123,6 +127,9 @@ type Secure struct { // badHostHandler is the handler used when an incorrect host is passed in. badHostHandler http.Handler + // badRequestHandler is the handler used when the AllowRequestFunc rejects a request. + badRequestHandler http.Handler + // cRegexAllowedHosts saves the compiled regular expressions of the AllowedHosts // option for subsequent use in processRequest cRegexAllowedHosts []*regexp.Regexp @@ -146,8 +153,9 @@ func New(options ...Options) *Secure { o.nonceEnabled = strings.Contains(o.ContentSecurityPolicy, "%[1]s") || strings.Contains(o.ContentSecurityPolicyReportOnly, "%[1]s") s := &Secure{ - opt: o, - badHostHandler: http.HandlerFunc(defaultBadHostHandler), + opt: o, + badHostHandler: http.HandlerFunc(defaultBadHostHandler), + badRequestHandler: http.HandlerFunc(defaultBadRequestHandler), } if s.opt.AllowedHostsAreRegex { @@ -174,6 +182,11 @@ func (s *Secure) SetBadHostHandler(handler http.Handler) { s.badHostHandler = handler } +// SetBadRequestHandler sets the handler to call when the AllowRequestFunc rejects a request. +func (s *Secure) SetBadRequestHandler(handler http.Handler) { + s.badRequestHandler = handler +} + // Handler implements the http.HandlerFunc for integration with the standard net/http lib. func (s *Secure) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -294,14 +307,7 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He } // Allowed hosts check. - combinedAllowedHosts := s.opt.AllowedHosts - var allowedFuncHosts []string - if s.opt.AllowedHostsFunc != nil { - allowedFuncHosts = s.opt.AllowedHostsFunc() - combinedAllowedHosts = append(combinedAllowedHosts, allowedFuncHosts...) - } - - if len(combinedAllowedHosts) > 0 && !s.opt.IsDevelopment { + if len(s.opt.AllowedHosts) > 0 && !s.opt.IsDevelopment { isGoodHost := false if s.opt.AllowedHostsAreRegex { for _, allowedHost := range s.cRegexAllowedHosts { @@ -310,14 +316,8 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He break } } - for _, allowedHost := range allowedFuncHosts { - if strings.EqualFold(allowedHost, host) { - isGoodHost = true - break - } - } } else { - for _, allowedHost := range combinedAllowedHosts { + for _, allowedHost := range s.opt.AllowedHosts { if strings.EqualFold(allowedHost, host) { isGoodHost = true break @@ -380,6 +380,12 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He } } + // If the AllowRequestFunc is set, call it and exit early if needed. + if s.opt.AllowRequestFunc != nil && !s.opt.AllowRequestFunc(r) { + s.badRequestHandler.ServeHTTP(w, r) + return nil, nil, fmt.Errorf("request not allowed") + } + // Create our header container. responseHeader := make(http.Header) diff --git a/secure_test.go b/secure_test.go index c99225b..5d14622 100644 --- a/secure_test.go +++ b/secure_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "testing" ) @@ -1448,14 +1449,14 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { expect(t, s2Headers.Get(featurePolicyHeader), s2.opt.FeaturePolicy) } -func TestAllowHostsFunc(t *testing.T) { +func TestAllowRequestFuncTrue(t *testing.T) { s := New(Options{ - AllowedHostsFunc: func() []string { return []string{"www.allow-func.com"} }, + AllowRequestFunc: func(r *http.Request) bool { return true }, }) res := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/foo", nil) - req.Host = "www.allow-func.com" + req.Host = "www.allow-request.com" s.Handler(myHandler).ServeHTTP(res, req) @@ -1463,42 +1464,37 @@ func TestAllowHostsFunc(t *testing.T) { expect(t, res.Body.String(), `bar`) } -func TestAllowHostsFuncWithAllowedHostsList(t *testing.T) { +func TestAllowRequestFuncFalse(t *testing.T) { s := New(Options{ - AllowedHosts: []string{"www.allow.com"}, - AllowedHostsFunc: func() []string { return []string{"www.allow-func.com"} }, + AllowRequestFunc: func(r *http.Request) bool { return false }, }) res := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/foo", nil) - req.Host = "www.allow.com" + req.Host = "www.deny-request.com" s.Handler(myHandler).ServeHTTP(res, req) - expect(t, res.Code, http.StatusOK) - expect(t, res.Body.String(), `bar`) + expect(t, res.Code, http.StatusBadRequest) } -func TestAllowHostsFuncWithAllowedHostsListWithRegex(t *testing.T) { +func TestBadRequestHandler(t *testing.T) { s := New(Options{ - AllowedHosts: []string{"*\\.allow\\.com"}, - AllowedHostsFunc: func() []string { return []string{"foo.bar.allow.com"} }, - AllowedHostsAreRegex: true, + AllowRequestFunc: func(r *http.Request) bool { return false }, }) + badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "custom error", http.StatusConflict) + }) + s.SetBadRequestHandler(badRequestFunc) res := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/foo", nil) - req.Host = "foo.bar.allow.com" - s.Handler(myHandler).ServeHTTP(res, req) - expect(t, res.Code, http.StatusOK) - expect(t, res.Body.String(), `bar`) + req.Host = "www.deny-request.com" - res = httptest.NewRecorder() - req, _ = http.NewRequest("GET", "/foo", nil) - req.Host = "bar.allow.com" s.Handler(myHandler).ServeHTTP(res, req) - expect(t, res.Code, http.StatusOK) - expect(t, res.Body.String(), `bar`) + + expect(t, res.Code, http.StatusConflict) + expect(t, strings.TrimSpace(res.Body.String()), `custom error`) } /* Test Helpers */