diff --git a/x/oauth2cors/cors.go b/x/oauth2cors/cors.go index efde2b8cfb2..99ef6bd3b13 100644 --- a/x/oauth2cors/cors.go +++ b/x/oauth2cors/cors.go @@ -91,6 +91,12 @@ func Middleware(reg interface { } } + // pre-flight requests do not contain credentials (cookies, HTTP authorization) + // so we return true in all cases here. + if r.Method == http.MethodOptions { + return true + } + username, _, ok := r.BasicAuth() if !ok || username == "" { token := fosite.AccessTokenFromRequest(r) diff --git a/x/oauth2cors/cors_test.go b/x/oauth2cors/cors_test.go index c9e04d2db67..83e513653df 100644 --- a/x/oauth2cors/cors_test.go +++ b/x/oauth2cors/cors_test.go @@ -52,6 +52,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { code int header http.Header expectHeader http.Header + method string }{ { d: "should ignore when disabled", @@ -160,6 +161,17 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-7", "bar"))}}, expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}}, }, + { + d: "should succeed on pre-flight request when token introspection fails", + prep: func(t *testing.T, r driver.Registry) { + r.Config().MustSet("serve.public.cors.enabled", true) + r.Config().MustSet("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + }, + code: http.StatusNotImplemented, + header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Bearer 1234"}}, + expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}}, + method: "OPTIONS", + }, { d: "should fail when token introspection fails", prep: func(t *testing.T, r driver.Registry) { @@ -237,7 +249,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { tc.prep(t, r) } - req, err := http.NewRequest("GET", "http://foobar.com/", nil) + method := "GET" + if tc.method != "" { + method = tc.method + } + req, err := http.NewRequest(method, "http://foobar.com/", nil) require.NoError(t, err) for k := range tc.header { req.Header.Set(k, tc.header.Get(k))