Skip to content

Commit

Permalink
moved methodAllowed from CORS to CORSOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Koch committed Mar 21, 2022
1 parent 60ea2f6 commit 0cb44e5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
12 changes: 6 additions & 6 deletions config/runtime/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
return nil, err
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Spa))
corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Spa), nil)
if cerr != nil {
return nil, cerr
}

spaHandler = middleware.NewCORSHandler(corsOptions, nil, spaHandler)
spaHandler = middleware.NewCORSHandler(corsOptions, spaHandler)

spaBodies := bodiesWithACBodies(conf.Definitions, srvConf.Spa.AccessControl, srvConf.Spa.DisableAccessControl)
spaHandler = middleware.NewCustomLogsHandler(
Expand Down Expand Up @@ -210,12 +210,12 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
return nil, err
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Files))
corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Files), nil)
if cerr != nil {
return nil, cerr
}

fileHandler = middleware.NewCORSHandler(corsOptions, nil, fileHandler)
fileHandler = middleware.NewCORSHandler(corsOptions, fileHandler)

fileBodies := bodiesWithACBodies(conf.Definitions, srvConf.Files.AccessControl, srvConf.Files.DisableAccessControl)
fileHandler = middleware.NewCustomLogsHandler(
Expand Down Expand Up @@ -347,12 +347,12 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
return nil, err
}

corsOptions, err := middleware.NewCORSOptions(whichCORS(srvConf, parentAPI))
corsOptions, err := middleware.NewCORSOptions(whichCORS(srvConf, parentAPI), allowedMethodsHandler.MethodAllowed)
if err != nil {
return nil, err
}

epHandler = middleware.NewCORSHandler(corsOptions, allowedMethodsHandler.MethodAllowed, epHandler)
epHandler = middleware.NewCORSHandler(corsOptions, epHandler)

bodies := serverBodies
if parentAPI != nil {
Expand Down
10 changes: 5 additions & 5 deletions handler/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ var _ http.Handler = &CORS{}

type CORS struct {
options *CORSOptions
methodAllowed methodAllowedFunc
nextHandler http.Handler
}

type CORSOptions struct {
AllowedOrigins []string
AllowCredentials bool
MaxAge string
methodAllowed methodAllowedFunc
}

func NewCORSOptions(cors *config.CORS) (*CORSOptions, error) {
func NewCORSOptions(cors *config.CORS, methodAllowed methodAllowedFunc) (*CORSOptions, error) {
if cors == nil {
return nil, nil
}
Expand All @@ -50,6 +50,7 @@ func NewCORSOptions(cors *config.CORS) (*CORSOptions, error) {
AllowedOrigins: allowedOrigins,
AllowCredentials: cors.AllowCredentials,
MaxAge: corsMaxAge,
methodAllowed: methodAllowed,
}, nil
}

Expand All @@ -67,13 +68,12 @@ func (c *CORSOptions) AllowsOrigin(origin string) bool {
return false
}

func NewCORSHandler(opts *CORSOptions, methodAllowed methodAllowedFunc, nextHandler http.Handler) http.Handler {
func NewCORSHandler(opts *CORSOptions, nextHandler http.Handler) http.Handler {
if opts == nil {
return nextHandler
}
return &CORS{
options: opts,
methodAllowed: methodAllowed,
nextHandler: nextHandler,
}
}
Expand Down Expand Up @@ -130,7 +130,7 @@ func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
// Reflect request header value
acrm := req.Header.Get("Access-Control-Request-Method")
if acrm != "" {
if c.methodAllowed == nil || c.methodAllowed(acrm) {
if c.options.methodAllowed == nil || c.options.methodAllowed(acrm) {
headers.Set("Access-Control-Allow-Methods", acrm)
}
headers.Add("Vary", "Access-Control-Request-Method")
Expand Down
24 changes: 12 additions & 12 deletions handler/middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func TestCORS_ServeHTTP(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(subT *testing.T) {
corsHandler := NewCORSHandler(tt.corsOptions, nil, upstreamHandler)
corsHandler := NewCORSHandler(tt.corsOptions, upstreamHandler)

req := httptest.NewRequest(http.MethodPost, "http://1.2.3.4/", nil)
for name, value := range tt.requestHeaders {
Expand Down Expand Up @@ -378,7 +378,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
}{
{
"specific origin, with ACRM",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
Expand All @@ -394,7 +394,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"specific origin, with ACRM, method not allowed",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "PUT",
Expand All @@ -410,7 +410,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"specific origin, with ACRH",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Headers": "X-Foo, X-Bar",
Expand All @@ -426,7 +426,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"specific origin, with ACRM, ACRH",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
Expand All @@ -443,7 +443,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"specific origin, with ACRM, credentials",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
Expand All @@ -459,7 +459,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"specific origin, with ACRM, max-age",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, MaxAge: "3600"},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, MaxAge: "3600", methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
Expand All @@ -475,7 +475,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"any origin, with ACRM",
&CORSOptions{AllowedOrigins: []string{"*"}},
&CORSOptions{AllowedOrigins: []string{"*"}, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
Expand All @@ -491,7 +491,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"any origin, with ACRM, credentials",
&CORSOptions{AllowedOrigins: []string{"*"}, AllowCredentials: true},
&CORSOptions{AllowedOrigins: []string{"*"}, AllowCredentials: true, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
Expand All @@ -507,7 +507,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"origin mismatch",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.org",
"Access-Control-Request-Method": "POST",
Expand All @@ -523,7 +523,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
},
{
"origin mismatch, credentials",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true},
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true, methodAllowed: methodAllowed},
map[string]string{
"Origin": "https://www.example.org",
"Access-Control-Request-Method": "POST",
Expand All @@ -540,7 +540,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(subT *testing.T) {
corsHandler := NewCORSHandler(tt.corsOptions, methodAllowed, upstreamHandler)
corsHandler := NewCORSHandler(tt.corsOptions, upstreamHandler)

req := httptest.NewRequest(http.MethodOptions, "http://1.2.3.4/", nil)
for name, value := range tt.requestHeaders {
Expand Down

0 comments on commit 0cb44e5

Please sign in to comment.