Skip to content

Commit

Permalink
Merge pull request #1651 from curvegrid/cors-allow-origin-func
Browse files Browse the repository at this point in the history
CORS: add an optional custom function to validate the origin
  • Loading branch information
lammel authored Nov 25, 2020
2 parents 17a5fca + e6f24aa commit 502cce2
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 24 deletions.
65 changes: 41 additions & 24 deletions middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`

// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`

// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
Expand Down Expand Up @@ -113,40 +120,50 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return c.NoContent(http.StatusNoContent)
}

// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if matchSubdomain(origin, o) {
if allowed {
allowOrigin = origin
break
}
}

// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
if o == "*" || o == origin {
allowOrigin = o
break
}

if match, _ := regexp.MatchString(re, origin); match {
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}

// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}

if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}

// Origin not allowed
Expand Down
47 changes: 47 additions & 0 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"errors"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) {
}
}
}

func Test_allowOriginFunc(t *testing.T) {
returnTrue := func(origin string) (bool, error) {
return true, nil
}
returnFalse := func(origin string) (bool, error) {
return false, nil
}
returnError := func(origin string) (bool, error) {
return true, errors.New("this is a test error")
}

allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
}

const origin = "http://example.com"

e := echo.New()
for _, allowOriginFunc := range allowOriginFuncs {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)

expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}

if expected {
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}
}
}

0 comments on commit 502cce2

Please sign in to comment.