Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce cors.Config.AllowOriginRequestFunc #60

Merged
merged 1 commit into from
Aug 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ handler = c.Handler(handler)

* **AllowedOrigins** `[]string`: A list of origins a cross-domain request can be executed from. If the special `*` value is present in the list, all origins will be allowed. An origin may contain a wildcard (`*`) to replace 0 or more characters (i.e.: `http://*.domain.com`). Usage of wildcards implies a small performance penality. Only one wildcard can be used per origin. The default value is `*`.
* **AllowOriginFunc** `func (origin string) bool`: A custom function to validate the origin. It takes the origin as an argument and returns true if allowed, or false otherwise. If this option is set, the content of `AllowedOrigins` is ignored.
* **AllowOriginRequestFunc** `func (r *http.Request origin string) bool`: A custom function to validate the origin. It takes the HTTP Request object and the origin as argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins` and `AllowOriginFunc` is ignored
* **AllowedMethods** `[]string`: A list of methods the client is allowed to use with cross-domain requests. Default value is simple methods (`GET` and `POST`).
* **AllowedHeaders** `[]string`: A list of non simple headers the client is allowed to use with cross-domain requests.
* **ExposedHeaders** `[]string`: Indicates which headers are safe to expose to the API of a CORS API specification
Expand Down
28 changes: 19 additions & 9 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ type Options struct {
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowedOrigins is ignored.
AllowOriginFunc func(origin string) bool
// AllowOriginFunc is a custom function to validate the origin. It takes the HTTP Request object and the origin as
// argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins`
// and `AllowOriginFunc` is ignored.
AllowOriginRequestFunc func(r *http.Request, origin string) bool
// AllowedMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (HEAD, GET and POST).
AllowedMethods []string
Expand Down Expand Up @@ -75,6 +79,8 @@ type Cors struct {
allowedWOrigins []wildcard
// Optional origin validator function
allowOriginFunc func(origin string) bool
// Optional origin validator (with request) function
allowOriginRequestFunc func(r *http.Request, origin string) bool
// Normalized list of allowed headers
allowedHeaders []string
// Normalized list of allowed methods
Expand All @@ -93,11 +99,12 @@ type Cors struct {
// New creates a new Cors handler with the provided options.
func New(options Options) *Cors {
c := &Cors{
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
allowOriginFunc: options.AllowOriginFunc,
allowCredentials: options.AllowCredentials,
maxAge: options.MaxAge,
optionPassthrough: options.OptionsPassthrough,
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
allowOriginFunc: options.AllowOriginFunc,
allowOriginRequestFunc: options.AllowOriginRequestFunc,
allowCredentials: options.AllowCredentials,
maxAge: options.MaxAge,
optionPassthrough: options.OptionsPassthrough,
}
if options.Debug {
c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
Expand All @@ -109,7 +116,7 @@ func New(options Options) *Cors {

// Allowed Origins
if len(options.AllowedOrigins) == 0 {
if options.AllowOriginFunc == nil {
if options.AllowOriginFunc == nil && options.AllowOriginRequestFunc == nil {
// Default is all origins
c.allowedOriginsAll = true
}
Expand Down Expand Up @@ -254,7 +261,7 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
c.logf(" Preflight aborted: empty origin")
return
}
if !c.isOriginAllowed(origin) {
if !c.isOriginAllowed(r, origin) {
c.logf(" Preflight aborted: origin '%s' not allowed", origin)
return
}
Expand Down Expand Up @@ -307,7 +314,7 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
c.logf(" Actual request no headers added: missing origin")
return
}
if !c.isOriginAllowed(origin) {
if !c.isOriginAllowed(r, origin) {
c.logf(" Actual request no headers added: origin '%s' not allowed", origin)
return
}
Expand Down Expand Up @@ -344,7 +351,10 @@ func (c *Cors) logf(format string, a ...interface{}) {

// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
// on the endpoint
func (c *Cors) isOriginAllowed(origin string) bool {
func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
if c.allowOriginRequestFunc != nil {
return c.allowOriginRequestFunc(r, origin)
}
if c.allowOriginFunc != nil {
return c.allowOriginFunc(origin)
}
Expand Down
34 changes: 26 additions & 8 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestSpec(t *testing.T) {
{
"NoConfig",
Options{
// Intentionally left blank.
// Intentionally left blank.
},
"GET",
map[string]string{},
Expand Down Expand Up @@ -158,15 +158,33 @@ func TestSpec(t *testing.T) {
},
},
{
"AllowedOriginFuncNotMatch",
"AllowOriginRequestFuncMatch",
Options{
AllowOriginFunc: func(o string) bool {
return regexp.MustCompile("^http://foo").MatchString(o)
AllowOriginRequestFunc: func(r *http.Request, o string) bool {
return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret"
},
},
"GET",
map[string]string{
"Origin": "http://foobar.com",
"Authorization": "secret",
},
map[string]string{
"Vary": "Origin",
"Access-Control-Allow-Origin": "http://foobar.com",
},
},
{
"AllowOriginRequestFuncNotMatch",
Options{
AllowOriginRequestFunc: func(r *http.Request, o string) bool {
return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret"
},
},
"GET",
map[string]string{
"Origin": "http://barfoo.com",
"Origin": "http://foobar.com",
"Authorization": "not-secret",
},
map[string]string{
"Vary": "Origin",
Expand Down Expand Up @@ -447,7 +465,7 @@ func TestHandlePreflightInvalidOriginAbortion(t *testing.T) {

func TestHandlePreflightNoOptionsAbortion(t *testing.T) {
s := New(Options{
// Intentionally left blank.
// Intentionally left blank.
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
Expand Down Expand Up @@ -503,7 +521,7 @@ func TestHandleActualRequestInvalidMethodAbortion(t *testing.T) {

func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) {
s := New(Options{
// Intentionally left blank.
// Intentionally left blank.
})
s.allowedMethods = []string{}
if s.isMethodAllowed("") {
Expand All @@ -513,7 +531,7 @@ func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) {

func TestIsMethodAllowedReturnsTrueWithOptions(t *testing.T) {
s := New(Options{
// Intentionally left blank.
// Intentionally left blank.
})
if !s.isMethodAllowed("OPTIONS") {
t.Error("IsMethodAllowed should return true when c.allowedMethods is nil.")
Expand Down