Skip to content

Commit

Permalink
Adding AllowRequestFunc (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
unrolled authored Jul 16, 2022
1 parent 56ae1bd commit 0ce3852
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 49 deletions.
19 changes: 14 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ Secure comes with a variety of configuration options (Note: these are not the de
// ...
s := secure.New(secure.Options{
AllowedHosts: []string{"ssl.example.com"}, // AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names.
AllowedHostsFunc: func() []string { return []string{"example.com", "www.example.com" } // AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. This can be used in combination with the above AllowedHosts.
AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. This does not apply to the `AllowedHostsFunc` values! Default is false.
AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. Default is false.
AllowRequestFunc: nil, // AllowRequestFunc is a custom function type that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil.
HostsProxyHeaders: []string{"X-Forwarded-Hosts"}, // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request.
SSLRedirect: true, // If SSLRedirect is set to true, then only allow HTTPS requests. Default is false.
SSLTemporaryRedirect: false, // If SSLTemporaryRedirect is true, the a 302 will be used while redirecting. Default is false (301).
Expand Down Expand Up @@ -102,8 +102,8 @@ s := secure.New()

l := secure.New(secure.Options{
AllowedHosts: []string,
AllowedHostsFunc: nil,
AllowedHostsAreRegex: false,
AllowRequestFunc: nil,
HostsProxyHeaders: []string,
SSLRedirect: false,
SSLTemporaryRedirect: false,
Expand All @@ -127,11 +127,20 @@ l := secure.New(secure.Options{
IsDevelopment: false,
})
~~~
Also note the default bad host handler returns an error:
The default bad host handler returns the following error:
~~~ go
http.Error(w, "Bad Host", http.StatusInternalServerError)
~~~
Call `secure.SetBadHostHandler` to change the bad host handler.
Call `secure.SetBadHostHandler` to set your own custom handler.

The default bad request handler returns the following error:
~~~ go
http.Error(w, "Bad Request", http.StatusBadRequest)
~~~
Call `secure.SetBadRequestHandler` to set your own custom handler.

### Allow Request Function
Secure allows you to set a custom function (`func(r *http.Request) bool`) for the `AllowRequestFunc` option. You can use this function as a custom filter to allow the request to continue or simply reject it. This can be handy if you need to do any dynamic filtering on any of the request properties. It should be noted that this function will be called on every request, so be sure to make your checks quick and not relying on time consuming external calls (or you will be slowing down all requests). See above on how to set a custom handler for the rejected requests.

### Redirecting HTTP to HTTPS
If you want to redirect all HTTP requests to HTTPS, you can use the following example.
Expand Down
50 changes: 28 additions & 22 deletions secure.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@ const (
// SSLHostFunc is a custom function type that can be used to dynamically set the SSL host of a request.
type SSLHostFunc func(host string) (newHost string)

// AllowedHostsFunc is a custom function type that can be used to dynamically return a slice of strings that will be used in the `AllowHosts` check.
type AllowedHostsFunc func() []string
// AllowRequestFunc is a custom function type that can be used to dynamically determine if a request should proceed or not.
type AllowRequestFunc func(r *http.Request) bool

func defaultBadHostHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Bad Host", http.StatusInternalServerError)
}

func defaultBadRequestHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Bad Request", http.StatusBadRequest)
}

// Options is a struct for specifying configuration options for the secure.Secure middleware.
type Options struct {
// If BrowserXssFilter is true, adds the X-XSS-Protection header with the value `1; mode=block`. Default is false.
Expand Down Expand Up @@ -95,10 +99,10 @@ type Options struct {
SSLHost string
// AllowedHosts is a slice of fully qualified domain names that are allowed. Default is an empty slice, which allows any and all host names.
AllowedHosts []string
// AllowedHostsFunc is a custom function that returns a slice of fully qualified domain names that are allowed. If set, values will be used in combination with the above AllowedHosts. Default is nil.
AllowedHostsFunc AllowedHostsFunc
// AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. This does not apply to `AllowedHostsFunc`! If this flag is set to true, every request's host will be checked against these expressions. Default is false.
// AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. If this flag is set to true, every request's host will be checked against these expressions. Default is false.
AllowedHostsAreRegex bool
// AllowRequestFunc is a custom function that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil.
AllowRequestFunc AllowRequestFunc
// HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request.
HostsProxyHeaders []string
// SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil.
Expand All @@ -123,6 +127,9 @@ type Secure struct {
// badHostHandler is the handler used when an incorrect host is passed in.
badHostHandler http.Handler

// badRequestHandler is the handler used when the AllowRequestFunc rejects a request.
badRequestHandler http.Handler

// cRegexAllowedHosts saves the compiled regular expressions of the AllowedHosts
// option for subsequent use in processRequest
cRegexAllowedHosts []*regexp.Regexp
Expand All @@ -146,8 +153,9 @@ func New(options ...Options) *Secure {
o.nonceEnabled = strings.Contains(o.ContentSecurityPolicy, "%[1]s") || strings.Contains(o.ContentSecurityPolicyReportOnly, "%[1]s")

s := &Secure{
opt: o,
badHostHandler: http.HandlerFunc(defaultBadHostHandler),
opt: o,
badHostHandler: http.HandlerFunc(defaultBadHostHandler),
badRequestHandler: http.HandlerFunc(defaultBadRequestHandler),
}

if s.opt.AllowedHostsAreRegex {
Expand All @@ -174,6 +182,11 @@ func (s *Secure) SetBadHostHandler(handler http.Handler) {
s.badHostHandler = handler
}

// SetBadRequestHandler sets the handler to call when the AllowRequestFunc rejects a request.
func (s *Secure) SetBadRequestHandler(handler http.Handler) {
s.badRequestHandler = handler
}

// Handler implements the http.HandlerFunc for integration with the standard net/http lib.
func (s *Secure) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -294,14 +307,7 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
}

// Allowed hosts check.
combinedAllowedHosts := s.opt.AllowedHosts
var allowedFuncHosts []string
if s.opt.AllowedHostsFunc != nil {
allowedFuncHosts = s.opt.AllowedHostsFunc()
combinedAllowedHosts = append(combinedAllowedHosts, allowedFuncHosts...)
}

if len(combinedAllowedHosts) > 0 && !s.opt.IsDevelopment {
if len(s.opt.AllowedHosts) > 0 && !s.opt.IsDevelopment {
isGoodHost := false
if s.opt.AllowedHostsAreRegex {
for _, allowedHost := range s.cRegexAllowedHosts {
Expand All @@ -310,14 +316,8 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
break
}
}
for _, allowedHost := range allowedFuncHosts {
if strings.EqualFold(allowedHost, host) {
isGoodHost = true
break
}
}
} else {
for _, allowedHost := range combinedAllowedHosts {
for _, allowedHost := range s.opt.AllowedHosts {
if strings.EqualFold(allowedHost, host) {
isGoodHost = true
break
Expand Down Expand Up @@ -380,6 +380,12 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
}
}

// If the AllowRequestFunc is set, call it and exit early if needed.
if s.opt.AllowRequestFunc != nil && !s.opt.AllowRequestFunc(r) {
s.badRequestHandler.ServeHTTP(w, r)
return nil, nil, fmt.Errorf("request not allowed")
}

// Create our header container.
responseHeader := make(http.Header)

Expand Down
40 changes: 18 additions & 22 deletions secure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
)

Expand Down Expand Up @@ -1448,57 +1449,52 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) {
expect(t, s2Headers.Get(featurePolicyHeader), s2.opt.FeaturePolicy)
}

func TestAllowHostsFunc(t *testing.T) {
func TestAllowRequestFuncTrue(t *testing.T) {
s := New(Options{
AllowedHostsFunc: func() []string { return []string{"www.allow-func.com"} },
AllowRequestFunc: func(r *http.Request) bool { return true },
})

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "www.allow-func.com"
req.Host = "www.allow-request.com"

s.Handler(myHandler).ServeHTTP(res, req)

expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)
}

func TestAllowHostsFuncWithAllowedHostsList(t *testing.T) {
func TestAllowRequestFuncFalse(t *testing.T) {
s := New(Options{
AllowedHosts: []string{"www.allow.com"},
AllowedHostsFunc: func() []string { return []string{"www.allow-func.com"} },
AllowRequestFunc: func(r *http.Request) bool { return false },
})

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "www.allow.com"
req.Host = "www.deny-request.com"

s.Handler(myHandler).ServeHTTP(res, req)

expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)
expect(t, res.Code, http.StatusBadRequest)
}

func TestAllowHostsFuncWithAllowedHostsListWithRegex(t *testing.T) {
func TestBadRequestHandler(t *testing.T) {
s := New(Options{
AllowedHosts: []string{"*\\.allow\\.com"},
AllowedHostsFunc: func() []string { return []string{"foo.bar.allow.com"} },
AllowedHostsAreRegex: true,
AllowRequestFunc: func(r *http.Request) bool { return false },
})
badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "custom error", http.StatusConflict)
})
s.SetBadRequestHandler(badRequestFunc)

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "foo.bar.allow.com"
s.Handler(myHandler).ServeHTTP(res, req)
expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)
req.Host = "www.deny-request.com"

res = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/foo", nil)
req.Host = "bar.allow.com"
s.Handler(myHandler).ServeHTTP(res, req)
expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)

expect(t, res.Code, http.StatusConflict)
expect(t, strings.TrimSpace(res.Body.String()), `custom error`)
}

/* Test Helpers */
Expand Down

0 comments on commit 0ce3852

Please sign in to comment.