From c734f7aaef0b7775c745c964df0ef5b4c8823f05 Mon Sep 17 00:00:00 2001 From: arekkas Date: Sun, 26 Aug 2018 11:49:54 +0200 Subject: [PATCH] Introduce cors.Config.AllowOriginRequestFunc This patch introduces cors.Config.AllowOriginRequestFunc ( `func (r *http.Request origin string) bool`) which is a custom function to validate the origin. It takes the HTTP Request object and the origin as argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins` and `AllowOriginFunc` is ignored Closes #59 Signed-off-by: arekkas --- README.md | 1 + cors.go | 28 +++++++++++++++++++--------- cors_test.go | 34 ++++++++++++++++++++++++++-------- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 8934739..ecc83b2 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ handler = c.Handler(handler) * **AllowedOrigins** `[]string`: A list of origins a cross-domain request can be executed from. If the special `*` value is present in the list, all origins will be allowed. An origin may contain a wildcard (`*`) to replace 0 or more characters (i.e.: `http://*.domain.com`). Usage of wildcards implies a small performance penality. Only one wildcard can be used per origin. The default value is `*`. * **AllowOriginFunc** `func (origin string) bool`: A custom function to validate the origin. It takes the origin as an argument and returns true if allowed, or false otherwise. If this option is set, the content of `AllowedOrigins` is ignored. +* **AllowOriginRequestFunc** `func (r *http.Request origin string) bool`: A custom function to validate the origin. It takes the HTTP Request object and the origin as argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins` and `AllowOriginFunc` is ignored * **AllowedMethods** `[]string`: A list of methods the client is allowed to use with cross-domain requests. Default value is simple methods (`GET` and `POST`). * **AllowedHeaders** `[]string`: A list of non simple headers the client is allowed to use with cross-domain requests. * **ExposedHeaders** `[]string`: Indicates which headers are safe to expose to the API of a CORS API specification diff --git a/cors.go b/cors.go index 1518108..d301ca7 100644 --- a/cors.go +++ b/cors.go @@ -41,6 +41,10 @@ type Options struct { // as argument and returns true if allowed or false otherwise. If this option is // set, the content of AllowedOrigins is ignored. AllowOriginFunc func(origin string) bool + // AllowOriginFunc is a custom function to validate the origin. It takes the HTTP Request object and the origin as + // argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins` + // and `AllowOriginFunc` is ignored. + AllowOriginRequestFunc func(r *http.Request, origin string) bool // AllowedMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (HEAD, GET and POST). AllowedMethods []string @@ -75,6 +79,8 @@ type Cors struct { allowedWOrigins []wildcard // Optional origin validator function allowOriginFunc func(origin string) bool + // Optional origin validator (with request) function + allowOriginRequestFunc func(r *http.Request, origin string) bool // Normalized list of allowed headers allowedHeaders []string // Normalized list of allowed methods @@ -93,11 +99,12 @@ type Cors struct { // New creates a new Cors handler with the provided options. func New(options Options) *Cors { c := &Cors{ - exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), - allowOriginFunc: options.AllowOriginFunc, - allowCredentials: options.AllowCredentials, - maxAge: options.MaxAge, - optionPassthrough: options.OptionsPassthrough, + exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), + allowOriginFunc: options.AllowOriginFunc, + allowOriginRequestFunc: options.AllowOriginRequestFunc, + allowCredentials: options.AllowCredentials, + maxAge: options.MaxAge, + optionPassthrough: options.OptionsPassthrough, } if options.Debug { c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags) @@ -109,7 +116,7 @@ func New(options Options) *Cors { // Allowed Origins if len(options.AllowedOrigins) == 0 { - if options.AllowOriginFunc == nil { + if options.AllowOriginFunc == nil && options.AllowOriginRequestFunc == nil { // Default is all origins c.allowedOriginsAll = true } @@ -254,7 +261,7 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { c.logf(" Preflight aborted: empty origin") return } - if !c.isOriginAllowed(origin) { + if !c.isOriginAllowed(r, origin) { c.logf(" Preflight aborted: origin '%s' not allowed", origin) return } @@ -307,7 +314,7 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { c.logf(" Actual request no headers added: missing origin") return } - if !c.isOriginAllowed(origin) { + if !c.isOriginAllowed(r, origin) { c.logf(" Actual request no headers added: origin '%s' not allowed", origin) return } @@ -344,7 +351,10 @@ func (c *Cors) logf(format string, a ...interface{}) { // isOriginAllowed checks if a given origin is allowed to perform cross-domain requests // on the endpoint -func (c *Cors) isOriginAllowed(origin string) bool { +func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool { + if c.allowOriginRequestFunc != nil { + return c.allowOriginRequestFunc(r, origin) + } if c.allowOriginFunc != nil { return c.allowOriginFunc(origin) } diff --git a/cors_test.go b/cors_test.go index d3dbcba..68c12eb 100644 --- a/cors_test.go +++ b/cors_test.go @@ -49,7 +49,7 @@ func TestSpec(t *testing.T) { { "NoConfig", Options{ - // Intentionally left blank. + // Intentionally left blank. }, "GET", map[string]string{}, @@ -158,15 +158,33 @@ func TestSpec(t *testing.T) { }, }, { - "AllowedOriginFuncNotMatch", + "AllowOriginRequestFuncMatch", Options{ - AllowOriginFunc: func(o string) bool { - return regexp.MustCompile("^http://foo").MatchString(o) + AllowOriginRequestFunc: func(r *http.Request, o string) bool { + return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret" + }, + }, + "GET", + map[string]string{ + "Origin": "http://foobar.com", + "Authorization": "secret", + }, + map[string]string{ + "Vary": "Origin", + "Access-Control-Allow-Origin": "http://foobar.com", + }, + }, + { + "AllowOriginRequestFuncNotMatch", + Options{ + AllowOriginRequestFunc: func(r *http.Request, o string) bool { + return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret" }, }, "GET", map[string]string{ - "Origin": "http://barfoo.com", + "Origin": "http://foobar.com", + "Authorization": "not-secret", }, map[string]string{ "Vary": "Origin", @@ -447,7 +465,7 @@ func TestHandlePreflightInvalidOriginAbortion(t *testing.T) { func TestHandlePreflightNoOptionsAbortion(t *testing.T) { s := New(Options{ - // Intentionally left blank. + // Intentionally left blank. }) res := httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://example.com/foo", nil) @@ -503,7 +521,7 @@ func TestHandleActualRequestInvalidMethodAbortion(t *testing.T) { func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) { s := New(Options{ - // Intentionally left blank. + // Intentionally left blank. }) s.allowedMethods = []string{} if s.isMethodAllowed("") { @@ -513,7 +531,7 @@ func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) { func TestIsMethodAllowedReturnsTrueWithOptions(t *testing.T) { s := New(Options{ - // Intentionally left blank. + // Intentionally left blank. }) if !s.isMethodAllowed("OPTIONS") { t.Error("IsMethodAllowed should return true when c.allowedMethods is nil.")