From a36d0af611582985de5d7e939d059425b1b30d45 Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Sun, 26 Aug 2018 14:28:05 +0200 Subject: [PATCH] oauth2: Enable client specific CORS settings (#1009) Field `allowed_cors_origins` was added to OAuth 2.0 Clients. It enables CORS for the whitelisted URLS for paths which clients interact with, such as /oauth2/token. Closes #975 Signed-off-by: arekkas --- Gopkg.lock | 8 +- client/client.go | 6 + client/manager_0_sql_migrations_test.go | 26 +++++ client/manager_sql.go | 35 ++++++ client/manager_test_helpers.go | 1 + client/validator.go | 22 ++++ cmd/server/handler.go | 17 +-- cmd/server/handler_oauth2_factory.go | 8 +- cmd/server/helper_cors.go | 92 +++++++++++++++ cmd/server/helper_cors_test.go | 129 ++++++++++++++++++++++ oauth2/handler.go | 16 +-- oauth2/handler_fallback_endpoints_test.go | 4 +- oauth2/handler_test.go | 12 +- oauth2/introspector_test.go | 4 +- oauth2/oauth2_auth_code_test.go | 8 +- oauth2/oauth2_client_credentials_test.go | 5 +- oauth2/revocator_test.go | 4 +- 17 files changed, 368 insertions(+), 29 deletions(-) create mode 100644 cmd/server/helper_cors.go create mode 100644 cmd/server/helper_cors_test.go diff --git a/Gopkg.lock b/Gopkg.lock index b423e2c9c6b..03ebcb47ea7 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -22,6 +22,12 @@ packages = ["."] revision = "cd527374f1e5bff4938207604a14f2e38a9cf512" +[[projects]] + name = "github.com/aeneasr/cors" + packages = ["."] + revision = "1bbd41dfc0e2d813965f77b0a8600571b3dc4f72" + version = "v1.5.0" + [[projects]] name = "github.com/asaskevich/govalidator" packages = ["."] @@ -637,6 +643,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "c4398344f70a0ba6d5db1b004bc09a354eef45f3c6fdf63227feb78b74fef415" + inputs-digest = "c16af2ef7a4de1a072384d6ddd8ca446693376c776ee866f7139027476b30084" solver-name = "gps-cdcl" solver-version = 1 diff --git a/client/client.go b/client/client.go index 44a032a7620..51d229d766b 100644 --- a/client/client.go +++ b/client/client.go @@ -74,6 +74,12 @@ type Client struct { // retains, and discloses personal data. PolicyURI string `json:"policy_uri"` + // AllowedCORSOrigins are one or more URLs (scheme://host[:port]) which are allowed to make CORS requests + // to the /oauth/token endpoint. If this array is empty, the sever's CORS origin configuration (`CORS_ALLOWED_ORIGINS`) + // will be used instead. If this array is set, the allowed origins are appended to the server's CORS origin configuration. + // Be aware that environment variable `CORS_ENABLED` MUST be set to `true` for this to work. + AllowedCORSOrigins []string `json:"allowed_cors_origins"` + // TermsOfServiceURI is a URL string that points to a human-readable terms of service // document for the client that describes a contractual relationship // between the end-user and the client that the end-user accepts when diff --git a/client/manager_0_sql_migrations_test.go b/client/manager_0_sql_migrations_test.go index 5dccf1aabb3..f4a1a85e463 100644 --- a/client/manager_0_sql_migrations_test.go +++ b/client/manager_0_sql_migrations_test.go @@ -90,6 +90,24 @@ var createClientMigrations = []*migrate.Migration{ `DELETE FROM hydra_client WHERE id='6-data'`, }, }, + { + Id: "7-data", + Up: []string{ + `INSERT INTO hydra_client (id, allowed_cors_origins, client_name, client_secret, redirect_uris, grant_types, response_types, scope, owner, policy_uri, tos_uri, client_uri, logo_uri, contacts, client_secret_expires_at, sector_identifier_uri, jwks, jwks_uri, token_endpoint_auth_method, request_uris, request_object_signing_alg, userinfo_signed_response_alg, subject_type) VALUES ('7-data', 'http://localhost|http://google', 'some-client', 'abcdef', 'http://localhost|http://google', 'authorize_code|implicit', 'token|id_token', 'foo|bar', 'aeneas', 'http://policy', 'http://tos', 'http://client', 'http://logo', 'aeneas|foo', 0, 'http://sector', '{"keys": []}', 'http://jwks', 'none', 'http://uri1|http://uri2', 'rs256', 'rs526', 'public')`, + }, + Down: []string{ + `DELETE FROM hydra_client WHERE id='7-data'`, + }, + }, + { + Id: "8-data", + Up: []string{ + `INSERT INTO hydra_client (id, allowed_cors_origins, client_name, client_secret, redirect_uris, grant_types, response_types, scope, owner, policy_uri, tos_uri, client_uri, logo_uri, contacts, client_secret_expires_at, sector_identifier_uri, jwks, jwks_uri, token_endpoint_auth_method, request_uris, request_object_signing_alg, userinfo_signed_response_alg, subject_type) VALUES ('8-data', 'http://localhost|http://google', 'some-client', 'abcdef', 'http://localhost|http://google', 'authorize_code|implicit', 'token|id_token', 'foo|bar', 'aeneas', 'http://policy', 'http://tos', 'http://client', 'http://logo', 'aeneas|foo', 0, 'http://sector', '{"keys": []}', 'http://jwks', 'none', 'http://uri1|http://uri2', 'rs256', 'rs526', 'public')`, + }, + Down: []string{ + `DELETE FROM hydra_client WHERE id='8-data'`, + }, + }, } var migrations = map[string]*migrate.MemoryMigrationSource{ @@ -108,6 +126,10 @@ var migrations = map[string]*migrate.MemoryMigrationSource{ createClientMigrations[4], client.Migrations["mysql"].Migrations[5], createClientMigrations[5], + client.Migrations["mysql"].Migrations[6], + createClientMigrations[6], + client.Migrations["mysql"].Migrations[7], + createClientMigrations[7], }, }, "postgres": { @@ -125,6 +147,10 @@ var migrations = map[string]*migrate.MemoryMigrationSource{ createClientMigrations[4], client.Migrations["postgres"].Migrations[5], createClientMigrations[5], + client.Migrations["postgres"].Migrations[6], + createClientMigrations[6], + client.Migrations["postgres"].Migrations[7], + createClientMigrations[7], }, }, } diff --git a/client/manager_sql.go b/client/manager_sql.go index cb28ee9aca4..6f581d5a173 100644 --- a/client/manager_sql.go +++ b/client/manager_sql.go @@ -108,6 +108,15 @@ var sharedMigrations = []*migrate.Migration{ `ALTER TABLE hydra_client DROP COLUMN subject_type`, }, }, + { + Id: "7", + Up: []string{ + `ALTER TABLE hydra_client ADD allowed_cors_origins TEXT`, + }, + Down: []string{ + `ALTER TABLE hydra_client DROP COLUMN allowed_cors_origins`, + }, + }, } var Migrations = map[string]*migrate.MemoryMigrationSource{ @@ -133,6 +142,17 @@ var Migrations = map[string]*migrate.MemoryMigrationSource{ }, sharedMigrations[3], sharedMigrations[4], + sharedMigrations[5], + { + Id: "8", + Up: []string{ + `UPDATE hydra_client SET allowed_cors_origins=''`, + `ALTER TABLE hydra_client MODIFY allowed_cors_origins TEXT NOT NULL`, + }, + Down: []string{ + `ALTER TABLE hydra_client MODIFY allowed_cors_origins TEXT`, + }, + }, }}, "postgres": {Migrations: []*migrate.Migration{ sharedMigrations[0], @@ -156,6 +176,17 @@ var Migrations = map[string]*migrate.MemoryMigrationSource{ }, sharedMigrations[3], sharedMigrations[4], + sharedMigrations[5], + { + Id: "8", + Up: []string{ + `UPDATE hydra_client SET allowed_cors_origins=''`, + `ALTER TABLE hydra_client ALTER COLUMN allowed_cors_origins SET NOT NULL`, + }, + Down: []string{ + `ALTER TABLE hydra_client ALTER COLUMN allowed_cors_origins DROP NOT NULL`, + }, + }, }}, } @@ -187,6 +218,7 @@ type sqlData struct { SubjectType string `db:"subject_type"` RequestObjectSigningAlgorithm string `db:"request_object_signing_alg"` UserinfoSignedResponseAlg string `db:"userinfo_signed_response_alg"` + AllowedCORSOrigins string `db:"allowed_cors_origins"` } var sqlParams = []string{ @@ -212,6 +244,7 @@ var sqlParams = []string{ "request_uris", "request_object_signing_alg", "userinfo_signed_response_alg", + "allowed_cors_origins", } func sqlDataFromClient(d *Client) (*sqlData, error) { @@ -248,6 +281,7 @@ func sqlDataFromClient(d *Client) (*sqlData, error) { RequestURIs: strings.Join(d.RequestURIs, "|"), UserinfoSignedResponseAlg: d.UserinfoSignedResponseAlg, SubjectType: d.SubjectType, + AllowedCORSOrigins: strings.Join(d.AllowedCORSOrigins, "|"), }, nil } @@ -274,6 +308,7 @@ func (d *sqlData) ToClient() (*Client, error) { RequestURIs: stringsx.Splitx(d.RequestURIs, "|"), UserinfoSignedResponseAlg: d.UserinfoSignedResponseAlg, SubjectType: d.SubjectType, + AllowedCORSOrigins: stringsx.Splitx(d.AllowedCORSOrigins, "|"), } if d.JSONWebKeys != "" { diff --git a/client/manager_test_helpers.go b/client/manager_test_helpers.go index 70f58c02a56..ec7b407d518 100644 --- a/client/manager_test_helpers.go +++ b/client/manager_test_helpers.go @@ -89,6 +89,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { JSONWebKeysURI: "https://...", TokenEndpointAuthMethod: "none", RequestURIs: []string{"foo", "bar"}, + AllowedCORSOrigins: []string{"foo", "bar"}, RequestObjectSigningAlgorithm: "rs256", UserinfoSignedResponseAlg: "RS256", } diff --git a/client/validator.go b/client/validator.go index 0301f388055..8ac3ae107ca 100644 --- a/client/validator.go +++ b/client/validator.go @@ -83,6 +83,28 @@ func (v *Validator) Validate(c *Client) error { c.Scope = strings.Join(v.DefaultClientScopes, " ") } + for k, origin := range c.AllowedCORSOrigins { + u, err := url.Parse(origin) + if err != nil { + return errors.WithStack(fosite.ErrInvalidRequest.WithHint(fmt.Sprintf("Origin URL %s from allowed_cors_origins could not be parsed: %s", origin, err))) + } + + if u.Scheme != "https" && u.Scheme != "http" { + return errors.WithStack(fosite.ErrInvalidRequest.WithHint(fmt.Sprintf("Origin URL %s must use https:// or http:// as HTTP scheme.", origin))) + } + + if u.User != nil && len(u.User.String()) > 0 { + return errors.WithStack(fosite.ErrInvalidRequest.WithHint(fmt.Sprintf("Origin URL %s has HTTP user and/or password set which is not allowed.", origin))) + } + + u.Path = strings.TrimRight(u.Path, "/") + if len(u.Path)+len(u.RawQuery)+len(u.Fragment) > 0 { + return errors.WithStack(fosite.ErrInvalidRequest.WithHint(fmt.Sprintf("Origin URL %s must have an empty path, query, and fragment but one of the parts is not empty.", origin))) + } + + c.AllowedCORSOrigins[k] = u.String() + } + // has to be 0 because it is not supposed to be set c.SecretExpiresAt = 0 diff --git a/cmd/server/handler.go b/cmd/server/handler.go index 2e14a8ecb36..d7332118d0d 100644 --- a/cmd/server/handler.go +++ b/cmd/server/handler.go @@ -51,16 +51,17 @@ import ( var _ = &consent.Handler{} -func enhanceRouter(c *config.Config, cmd *cobra.Command, serverHandler *Handler, router *httprouter.Router, middlewares []negroni.Handler) http.Handler { +func enhanceRouter(c *config.Config, cmd *cobra.Command, serverHandler *Handler, router *httprouter.Router, middlewares []negroni.Handler, enableCors bool) http.Handler { n := negroni.New() for _, m := range middlewares { n.Use(m) } n.UseFunc(serverHandler.rejectInsecureRequests) n.UseHandler(router) - if viper.GetString("CORS_ENABLED") == "true" { + if enableCors { c.GetLogger().Info("Enabled CORS") - return context.ClearHandler(cors.New(corsx.ParseOptions()).Handler(n)) + options := corsx.ParseOptions() + return context.ClearHandler(cors.New(options).Handler(n)) } else { return context.ClearHandler(n) } @@ -77,7 +78,7 @@ func RunServeAdmin(c *config.Config) func(cmd *cobra.Command, args []string) { cert := getOrCreateTLSCertificate(cmd, c) // go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, frontend), c.GetFrontendAddress(), &wg) - go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, backend, mws), c.GetBackendAddress(), &wg, cert) + go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, backend, mws, viper.GetString("CORS_ENABLED") == "true"), c.GetBackendAddress(), &wg, cert) wg.Wait() } @@ -93,7 +94,7 @@ func RunServePublic(c *config.Config) func(cmd *cobra.Command, args []string) { wg.Add(2) cert := getOrCreateTLSCertificate(cmd, c) - go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, frontend, mws), c.GetFrontendAddress(), &wg, cert) + go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, frontend, mws, false), c.GetFrontendAddress(), &wg, cert) // go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, backend), c.GetBackendAddress(), &wg) wg.Wait() @@ -109,8 +110,8 @@ func RunServeAll(c *config.Config) func(cmd *cobra.Command, args []string) { wg.Add(2) cert := getOrCreateTLSCertificate(cmd, c) - go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, frontend, mws), c.GetFrontendAddress(), &wg, cert) - go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, backend, mws), c.GetBackendAddress(), &wg, cert) + go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, frontend, mws, false), c.GetFrontendAddress(), &wg, cert) + go serve(c, cmd, enhanceRouter(c, cmd, serverHandler, backend, mws, viper.GetString("CORS_ENABLED") == "true"), c.GetBackendAddress(), &wg, cert) wg.Wait() } @@ -257,7 +258,7 @@ func (h *Handler) registerRoutes(frontend, backend *httprouter.Router) { h.Clients = newClientHandler(c, backend, clientsManager) h.Keys = newJWKHandler(c, frontend, backend) h.Consent = newConsentHandler(c, frontend, backend) - h.OAuth2 = newOAuth2Handler(c, frontend, backend, ctx.ConsentManager, oauth2Provider) + h.OAuth2 = newOAuth2Handler(c, frontend, backend, ctx.ConsentManager, oauth2Provider, clientsManager) _ = newHealthHandler(c, backend) } diff --git a/cmd/server/handler_oauth2_factory.go b/cmd/server/handler_oauth2_factory.go index e1235607787..95ffb80a958 100644 --- a/cmd/server/handler_oauth2_factory.go +++ b/cmd/server/handler_oauth2_factory.go @@ -41,6 +41,7 @@ import ( "github.com/ory/hydra/oauth2" "github.com/ory/hydra/pkg" "github.com/pborman/uuid" + "github.com/spf13/viper" ) func injectFositeStore(c *config.Config, clients client.Manager) { @@ -151,8 +152,8 @@ func setDefaultConsentURL(s string, c *config.Config, path string) string { } //func newOAuth2Handler(c *config.Config, router *httprouter.Router, cm oauth2.ConsentRequestManager, o fosite.OAuth2Provider, idTokenKeyID string) *oauth2.Handler { -func newOAuth2Handler(c *config.Config, frontend, backend *httprouter.Router, cm consent.Manager, o fosite.OAuth2Provider) *oauth2.Handler { - expectDependency(c.GetLogger(), c.Context().FositeStore) +func newOAuth2Handler(c *config.Config, frontend, backend *httprouter.Router, cm consent.Manager, o fosite.OAuth2Provider, clm client.Manager) *oauth2.Handler { + expectDependency(c.GetLogger(), c.Context().FositeStore, clm) c.ConsentURL = setDefaultConsentURL(c.ConsentURL, c, "oauth2/fallbacks/consent") c.LoginURL = setDefaultConsentURL(c.LoginURL, c, "oauth2/fallbacks/consent") @@ -214,6 +215,7 @@ func newOAuth2Handler(c *config.Config, frontend, backend *httprouter.Router, cm ShareOAuth2Debug: c.SendOAuth2DebugMessagesToClients, } - handler.SetRoutes(frontend, backend) + corsMiddleware := newCORSMiddleware(viper.GetString("CORS_ENABLED") == "true", c, o.IntrospectToken, clm.GetConcreteClient) + handler.SetRoutes(frontend, backend, corsMiddleware) return handler } diff --git a/cmd/server/helper_cors.go b/cmd/server/helper_cors.go new file mode 100644 index 00000000000..5a30225cc3d --- /dev/null +++ b/cmd/server/helper_cors.go @@ -0,0 +1,92 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @Copyright 2017-2018 Aeneas Rekkas + * @license Apache-2.0 + */ + +package server + +import ( + "context" + "net/http" + + "github.com/aeneasr/cors" + "github.com/ory/fosite" + "github.com/ory/go-convenience/corsx" + "github.com/ory/go-convenience/stringslice" + "github.com/ory/hydra/client" + "github.com/ory/hydra/config" + "github.com/ory/hydra/oauth2" +) + +func newCORSMiddleware( + enable bool, c *config.Config, + o func(ctx context.Context, token string, tokenType fosite.TokenType, session fosite.Session, scope ...string) (fosite.TokenType, fosite.AccessRequester, error), + clm func(id string) (*client.Client, error), +) func(h http.Handler) http.Handler { + if !enable { + return func(h http.Handler) http.Handler { + return h + } + } + + c.GetLogger().Info("Enabled CORS") + po := corsx.ParseOptions() + options := cors.Options{ + AllowedOrigins: po.AllowedOrigins, + AllowedMethods: po.AllowedMethods, + AllowedHeaders: po.AllowedHeaders, + ExposedHeaders: po.ExposedHeaders, + MaxAge: po.MaxAge, + AllowCredentials: po.AllowCredentials, + OptionsPassthrough: po.OptionsPassthrough, + Debug: po.Debug, + AllowOriginRequestFunc: func(r *http.Request, origin string) bool { + if stringslice.Has(po.AllowedOrigins, origin) { + return true + } + + username, _, ok := r.BasicAuth() + if !ok || username == "" { + token := fosite.AccessTokenFromRequest(r) + if token == "" { + return false + } + + session := oauth2.NewSession("") + _, ar, err := o(context.Background(), token, fosite.AccessToken, session) + if err != nil { + return false + } + + username = ar.GetClient().GetID() + } + + cl, err := clm(username) + if err != nil { + return false + } + + if stringslice.Has(cl.AllowedCORSOrigins, origin) { + return true + } + + return false + }, + } + return cors.New(options).Handler +} diff --git a/cmd/server/helper_cors_test.go b/cmd/server/helper_cors_test.go new file mode 100644 index 00000000000..38a3e757e1f --- /dev/null +++ b/cmd/server/helper_cors_test.go @@ -0,0 +1,129 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @Copyright 2017-2018 Aeneas Rekkas + * @license Apache-2.0 + */ + +package server + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/julienschmidt/httprouter" + "github.com/ory/fosite" + "github.com/ory/hydra/client" + "github.com/ory/hydra/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCORSMiddleware(t *testing.T) { + handler := httprouter.New() + handler.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + w.WriteHeader(http.StatusNoContent) + }) + + c := new(config.Config) + for k, tc := range []struct { + d string + mw func(http.Handler) http.Handler + code int + header http.Header + expectHeader http.Header + }{ + { + d: "should ignore when disabled", + mw: newCORSMiddleware(false, nil, nil, nil), + code: 204, + header: http.Header{}, + expectHeader: http.Header{}, + }, + { + d: "should reject when basic auth but client does not exist", + mw: newCORSMiddleware(true, c, nil, func(id string) (*client.Client, error) { + return nil, errors.New("") + }), + code: 204, + header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vOmJhcg=="}}, + expectHeader: http.Header{"Vary": {"Origin"}}, + }, + { + d: "should reject when basic auth client exists but origin not allowed", + mw: newCORSMiddleware(true, c, nil, func(id string) (*client.Client, error) { + return &client.Client{AllowedCORSOrigins: []string{"http://not-foobar.com"}}, nil + }), + code: 204, + header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vOmJhcg=="}}, + expectHeader: http.Header{"Vary": {"Origin"}}, + }, + { + d: "should accept when basic auth client exists and origin allowed", + mw: newCORSMiddleware(true, c, nil, func(id string) (*client.Client, error) { + return &client.Client{AllowedCORSOrigins: []string{"http://foobar.com"}}, nil + }), + code: 204, + header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vOmJhcg=="}}, + expectHeader: http.Header{"Vary": {"Origin"}, "Access-Control-Allow-Origin": {"http://foobar.com"}}, + }, + { + d: "should fail when token introspection fails", + mw: newCORSMiddleware(true, c, func(ctx context.Context, token string, tokenType fosite.TokenType, session fosite.Session, scope ...string) (fosite.TokenType, fosite.AccessRequester, error) { + return "", nil, errors.New("") + }, func(id string) (*client.Client, error) { + return &client.Client{AllowedCORSOrigins: []string{"http://foobar.com"}}, nil + }), + code: 204, + header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vOmJhcg=="}}, + expectHeader: http.Header{"Vary": {"Origin"}, "Access-Control-Allow-Origin": {"http://foobar.com"}}, + }, + { + d: "should fail when token introspection fails", + mw: newCORSMiddleware(true, c, func(ctx context.Context, token string, tokenType fosite.TokenType, session fosite.Session, scope ...string) (fosite.TokenType, fosite.AccessRequester, error) { + if token != "1234" { + return "", nil, errors.New("") + } + return "", &fosite.AccessRequest{Request: fosite.Request{Client: &client.Client{ClientID: "asdf"}}}, nil + }, func(id string) (*client.Client, error) { + if id != "asdf" { + return nil, errors.New("") + } + return &client.Client{AllowedCORSOrigins: []string{"http://foobar.com"}}, nil + }), + code: 204, + header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"bearer 1234"}}, + expectHeader: http.Header{"Vary": {"Origin"}, "Access-Control-Allow-Origin": {"http://foobar.com"}}, + }, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + req, err := http.NewRequest("GET", "http://foobar.com/", nil) + require.NoError(t, err) + for k := range tc.header { + req.Header.Set(k, tc.header.Get(k)) + } + + res := httptest.NewRecorder() + tc.mw(handler).ServeHTTP(res, req) + require.NoError(t, err) + assert.EqualValues(t, tc.expectHeader, res.Header()) + }) + } +} diff --git a/oauth2/handler.go b/oauth2/handler.go index 5920dcdd69b..387b2f566d8 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -156,17 +156,17 @@ type FlushInactiveOAuth2TokensRequest struct { NotAfter time.Time `json:"notAfter"` } -func (h *Handler) SetRoutes(frontend, backend *httprouter.Router) { - frontend.POST(TokenPath, h.TokenHandler) +func (h *Handler) SetRoutes(frontend, backend *httprouter.Router, corsMiddleware func(http.Handler) http.Handler) { + frontend.Handler("POST", TokenPath, corsMiddleware(http.HandlerFunc(h.TokenHandler))) frontend.GET(AuthPath, h.AuthHandler) frontend.POST(AuthPath, h.AuthHandler) frontend.GET(DefaultConsentPath, h.DefaultConsentHandler) frontend.GET(DefaultErrorPath, h.DefaultErrorHandler) frontend.GET(DefaultLogoutPath, h.DefaultLogoutHandler) - frontend.POST(RevocationPath, h.RevocationHandler) + frontend.Handler("POST", RevocationPath, corsMiddleware(http.HandlerFunc(h.RevocationHandler))) frontend.GET(WellKnownPath, h.WellKnownHandler) - frontend.GET(UserinfoPath, h.UserinfoHandler) - frontend.POST(UserinfoPath, h.UserinfoHandler) + frontend.Handler("GET", UserinfoPath, corsMiddleware(http.HandlerFunc(h.UserinfoHandler))) + frontend.Handler("POST", UserinfoPath, corsMiddleware(http.HandlerFunc(h.UserinfoHandler))) backend.POST(IntrospectPath, h.IntrospectHandler) backend.POST(FlushPath, h.FlushHandler) @@ -251,7 +251,7 @@ func (h *Handler) WellKnownHandler(w http.ResponseWriter, r *http.Request, _ htt // 200: userinfoResponse // 401: genericError // 500: genericError -func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) { session := NewSession("") tokenType, ar, err := h.OAuth2.IntrospectToken(r.Context(), fosite.AccessTokenFromRequest(r), fosite.AccessToken, session) if err != nil { @@ -341,7 +341,7 @@ func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request, _ http // 200: emptyResponse // 401: genericError // 500: genericError -func (h *Handler) RevocationHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func (h *Handler) RevocationHandler(w http.ResponseWriter, r *http.Request) { var ctx = fosite.NewContext() err := h.OAuth2.NewRevocationRequest(ctx, r) @@ -513,7 +513,7 @@ func (h *Handler) FlushHandler(w http.ResponseWriter, r *http.Request, _ httprou // 200: oauthTokenResponse // 401: genericError // 500: genericError -func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) { var session = NewSession("") var ctx = fosite.NewContext() diff --git a/oauth2/handler_fallback_endpoints_test.go b/oauth2/handler_fallback_endpoints_test.go index 45cef5d0a0b..ba487c4fd51 100644 --- a/oauth2/handler_fallback_endpoints_test.go +++ b/oauth2/handler_fallback_endpoints_test.go @@ -38,7 +38,9 @@ func TestHandlerConsent(t *testing.T) { ScopeStrategy: fosite.HierarchicScopeStrategy, } r := httprouter.New() - h.SetRoutes(r, r) + h.SetRoutes(r, r, func(h http.Handler) http.Handler { + return h + }) ts := httptest.NewServer(r) res, err := http.Get(ts.URL + DefaultConsentPath) diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go index f7b69245b83..f2f45237525 100644 --- a/oauth2/handler_test.go +++ b/oauth2/handler_test.go @@ -96,7 +96,9 @@ func TestHandlerFlushHandler(t *testing.T) { } r := httprouter.New() - h.SetRoutes(r, r) + h.SetRoutes(r, r, func(h http.Handler) http.Handler { + return h + }) ts := httptest.NewServer(r) c := hydra.NewOAuth2ApiWithBasePath(ts.URL) @@ -154,7 +156,9 @@ func TestUserinfo(t *testing.T) { OpenIDJWTStrategy: jwtStrategy, } router := httprouter.New() - h.SetRoutes(router, router) + h.SetRoutes(router, router, func(h http.Handler) http.Handler { + return h + }) ts := httptest.NewServer(router) defer ts.Close() @@ -368,7 +372,9 @@ func TestHandlerWellKnown(t *testing.T) { JWKPathT := "/.well-known/jwks.json" r := httprouter.New() - h.SetRoutes(r, r) + h.SetRoutes(r, r, func(h http.Handler) http.Handler { + return h + }) ts := httptest.NewServer(r) res, err := http.Get(ts.URL + "/.well-known/openid-configuration") diff --git a/oauth2/introspector_test.go b/oauth2/introspector_test.go index f9983aa756c..b2a36497c20 100644 --- a/oauth2/introspector_test.go +++ b/oauth2/introspector_test.go @@ -75,7 +75,9 @@ func TestIntrospectorSDK(t *testing.T) { IssuerURL: "foobariss", OpenIDJWTStrategy: jwtStrategy, } - handler.SetRoutes(router, router) + handler.SetRoutes(router, router, func(h http.Handler) http.Handler { + return h + }) server := httptest.NewServer(router) now := time.Now().UTC().Round(time.Minute) diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index eaff87af60a..e79eb380266 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -193,7 +193,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { IssuerURL: ts.URL, ForcedHTTP: true, L: l, OpenIDJWTStrategy: jwtStrategy, } - handler.SetRoutes(router, router) + handler.SetRoutes(router, router, func(h http.Handler) http.Handler { + return h + }) apiHandler := consent.NewHandler(herodot.NewJSONWriter(l), cm, cookieStore, "") apiRouter := httprouter.New() @@ -745,7 +747,9 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { IssuerURL: ts.URL, OpenIDJWTStrategy: jwtStrategy, } - handler.SetRoutes(router, router) + handler.SetRoutes(router, router, func(h http.Handler) http.Handler { + return h + }) var callbackHandler *httprouter.Handle router.GET("/callback", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index 73efcd0de1a..6b739149c67 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -23,6 +23,7 @@ package oauth2_test import ( "context" "encoding/json" + "net/http" "net/http/httptest" "strings" "testing" @@ -110,7 +111,9 @@ func TestClientCredentials(t *testing.T) { OpenIDJWTStrategy: jwtStrategy, } - handler.SetRoutes(router, router) + handler.SetRoutes(router, router, func(h http.Handler) http.Handler { + return h + }) require.NoError(t, store.CreateClient(&hc.Client{ ClientID: "app-client", diff --git a/oauth2/revocator_test.go b/oauth2/revocator_test.go index c7fcca15c62..4eb23798a92 100644 --- a/oauth2/revocator_test.go +++ b/oauth2/revocator_test.go @@ -92,7 +92,9 @@ func TestRevoke(t *testing.T) { } router := httprouter.New() - handler.SetRoutes(router, router) + handler.SetRoutes(router, router, func(h http.Handler) http.Handler { + return h + }) server := httptest.NewServer(router) createAccessTokenSession("alice", "my-client", tokens[0][0], now.Add(time.Hour), store, nil)