diff --git a/.circleci/config.yml b/.circleci/config.yml index 483fb5678..c536386d6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -3,7 +3,7 @@ version: 2.1 orbs: changelog: ory/changelog@0.1.4 nancy: ory/nancy@0.0.13 - golangci: ory/golangci@0.0.9 + golangci: ory/golangci@0.0.11 jobs: test: diff --git a/access_request_handler.go b/access_request_handler.go index 8b92b5c35..9e500f42b 100644 --- a/access_request_handler.go +++ b/access_request_handler.go @@ -57,7 +57,6 @@ import ( // client MUST authenticate with the authorization server as described // in Section 3.2.1. func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session Session) (AccessRequester, error) { - var err error accessRequest := NewAccessRequest(session) if r.Method != "POST" { @@ -80,18 +79,34 @@ func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session return accessRequest, errorsx.WithStack(ErrInvalidRequest.WithHint("Request parameter 'grant_type' is missing")) } - client, err := f.AuthenticateClient(ctx, r, r.PostForm) - if err != nil { - return accessRequest, err + client, clientErr := f.AuthenticateClient(ctx, r, r.PostForm) + if clientErr == nil { + accessRequest.Client = client } - accessRequest.Client = client var found = false for _, loader := range f.TokenEndpointHandlers { + // Is the loader responsible for handling the request? + if !loader.CanHandleTokenEndpointRequest(accessRequest) { + continue + } + + // The handler **is** responsible! + + // Is the client supplied in the request? If not can this handler skip client auth? + if !loader.CanSkipClientAuth(accessRequest) && clientErr != nil { + // No client and handler can not skip client auth -> error. + return accessRequest, clientErr + } + + // All good. if err := loader.HandleTokenEndpointRequest(ctx, accessRequest); err == nil { found = true } else if errors.Is(err, ErrUnknownRequest) { - // do nothing + // This is a duplicate because it should already have been handled by + // `loader.CanHandleTokenEndpointRequest(accessRequest)` but let's keep it for sanity. + // + continue } else if err != nil { return accessRequest, err } diff --git a/access_request_handler_test.go b/access_request_handler_test.go index 1af5842e0..3ce796d6b 100644 --- a/access_request_handler_test.go +++ b/access_request_handler_test.go @@ -42,6 +42,8 @@ func TestNewAccessRequest(t *testing.T) { ctrl := gomock.NewController(t) store := internal.NewMockStorage(ctrl) handler := internal.NewMockTokenEndpointHandler(ctrl) + handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() + handler.EXPECT().CanSkipClientAuth(gomock.Any()).Return(false).AnyTimes() hasher := internal.NewMockHasher(ctrl) defer ctrl.Finish() @@ -94,6 +96,7 @@ func TestNewAccessRequest(t *testing.T) { mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, + handlers: TokenEndpointHandlers{handler}, }, { header: http.Header{ @@ -118,6 +121,7 @@ func TestNewAccessRequest(t *testing.T) { mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, + handlers: TokenEndpointHandlers{handler}, }, { header: http.Header{ @@ -134,6 +138,7 @@ func TestNewAccessRequest(t *testing.T) { client.Secret = []byte("foo") hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) }, + handlers: TokenEndpointHandlers{handler}, }, { header: http.Header{ @@ -221,6 +226,239 @@ func TestNewAccessRequest(t *testing.T) { } } +func TestNewAccessRequestWithoutClientAuth(t *testing.T) { + ctrl := gomock.NewController(t) + store := internal.NewMockStorage(ctrl) + handler := internal.NewMockTokenEndpointHandler(ctrl) + handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() + handler.EXPECT().CanSkipClientAuth(gomock.Any()).Return(true).AnyTimes() + hasher := internal.NewMockHasher(ctrl) + defer ctrl.Finish() + + client := &DefaultClient{} + anotherClient := &DefaultClient{ID: "another"} + fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} + for k, c := range []struct { + header http.Header + form url.Values + mock func() + method string + expectErr error + expect *AccessRequest + handlers TokenEndpointHandlers + }{ + // No grant type -> error + { + form: url.Values{}, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0) + }, + method: "POST", + expectErr: ErrInvalidRequest, + }, + // No registered handlers -> error + { + form: url.Values{ + "grant_type": {"foo"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0) + }, + method: "POST", + expectErr: ErrInvalidRequest, + handlers: TokenEndpointHandlers{}, + }, + // Handler can skip client auth and ignores missing client. + { + header: http.Header{ + "Authorization": {basicAuth("foo", "bar")}, + }, + form: url.Values{ + "grant_type": {"foo"}, + }, + mock: func() { + // despite error from storage, we should success, because client auth is not required + store.EXPECT().GetClient(gomock.Any(), "foo").Return(nil, errors.New("no client")).Times(1) + handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) + }, + method: "POST", + expect: &AccessRequest{ + GrantTypes: Arguments{"foo"}, + Request: Request{ + Client: client, + }, + }, + handlers: TokenEndpointHandlers{handler}, + }, + // Should pass if no auth is set in the header and can skip! + { + form: url.Values{ + "grant_type": {"foo"}, + }, + mock: func() { + handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) + }, + method: "POST", + expect: &AccessRequest{ + GrantTypes: Arguments{"foo"}, + Request: Request{ + Client: client, + }, + }, + handlers: TokenEndpointHandlers{handler}, + }, + // Should also pass if client auth is set! + { + header: http.Header{ + "Authorization": {basicAuth("foo", "bar")}, + }, + form: url.Values{ + "grant_type": {"foo"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "foo").Return(anotherClient, nil).Times(1) + hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) + handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) + }, + method: "POST", + expect: &AccessRequest{ + GrantTypes: Arguments{"foo"}, + Request: Request{ + Client: anotherClient, + }, + }, + handlers: TokenEndpointHandlers{handler}, + }, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + r := &http.Request{ + Header: c.header, + PostForm: c.form, + Form: c.form, + Method: c.method, + } + c.mock() + ctx := NewContext() + fosite.TokenEndpointHandlers = c.handlers + ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession)) + + if c.expectErr != nil { + assert.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client") + assert.NotNil(t, ar.GetRequestedAt()) + } + }) + } +} + +// In this test case one handler requires client auth and another handler not. +func TestNewAccessRequestWithMixedClientAuth(t *testing.T) { + ctrl := gomock.NewController(t) + store := internal.NewMockStorage(ctrl) + + handlerWithClientAuth := internal.NewMockTokenEndpointHandler(ctrl) + handlerWithClientAuth.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() + handlerWithClientAuth.EXPECT().CanSkipClientAuth(gomock.Any()).Return(false).AnyTimes() + + handlerWithoutClientAuth := internal.NewMockTokenEndpointHandler(ctrl) + handlerWithoutClientAuth.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() + handlerWithoutClientAuth.EXPECT().CanSkipClientAuth(gomock.Any()).Return(true).AnyTimes() + + hasher := internal.NewMockHasher(ctrl) + defer ctrl.Finish() + + client := &DefaultClient{} + fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} + for k, c := range []struct { + header http.Header + form url.Values + mock func() + method string + expectErr error + expect *AccessRequest + handlers TokenEndpointHandlers + }{ + { + header: http.Header{ + "Authorization": {basicAuth("foo", "bar")}, + }, + form: url.Values{ + "grant_type": {"foo"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) + client.Public = false + client.Secret = []byte("foo") + hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err")) + handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) + }, + method: "POST", + expectErr: ErrInvalidClient, + handlers: TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth}, + }, + { + header: http.Header{ + "Authorization": {basicAuth("foo", "bar")}, + }, + form: url.Values{ + "grant_type": {"foo"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) + client.Public = false + client.Secret = []byte("foo") + hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) + handlerWithClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) + }, + method: "POST", + expect: &AccessRequest{ + GrantTypes: Arguments{"foo"}, + Request: Request{ + Client: client, + }, + }, + handlers: TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth}, + }, + { + header: http.Header{}, + form: url.Values{ + "grant_type": {"foo"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0) + handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) + }, + method: "POST", + expectErr: ErrInvalidRequest, + handlers: TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth}, + }, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + r := &http.Request{ + Header: c.header, + PostForm: c.form, + Form: c.form, + Method: c.method, + } + c.mock() + ctx := NewContext() + fosite.TokenEndpointHandlers = c.handlers + ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession)) + + if c.expectErr != nil { + assert.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client") + assert.NotNil(t, ar.GetRequestedAt()) + } + }) + } +} + func basicAuth(username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password))) } diff --git a/compose/compose.go b/compose/compose.go index 03e58b87b..50fa73703 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -111,6 +111,7 @@ func ComposeAllEnabled(config *Config, storage interface{}, secret []byte, key * OAuth2ClientCredentialsGrantFactory, OAuth2RefreshTokenGrantFactory, OAuth2ResourceOwnerPasswordCredentialsFactory, + RFC7523AssertionGrantFactory, OpenIDConnectExplicitFactory, OpenIDConnectImplicitFactory, diff --git a/compose/compose_rfc7523.go b/compose/compose_rfc7523.go new file mode 100644 index 000000000..bb584ae12 --- /dev/null +++ b/compose/compose_rfc7523.go @@ -0,0 +1,26 @@ +package compose + +import ( + "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/handler/rfc7523" +) + +// RFC7523AssertionGrantFactory creates an OAuth2 Authorize JWT Grant (using JWTs as Authorization Grants) handler +// and registers an access token, refresh token and authorize code validator. +func RFC7523AssertionGrantFactory(config *Config, storage interface{}, strategy interface{}) interface{} { + return &rfc7523.Handler{ + Storage: storage.(rfc7523.RFC7523KeyStorage), + ScopeStrategy: config.GetScopeStrategy(), + AudienceMatchingStrategy: config.GetAudienceStrategy(), + TokenURL: config.TokenURL, + SkipClientAuth: config.GrantTypeJWTBearerCanSkipClientAuth, + JWTIDOptional: config.GrantTypeJWTBearerIDOptional, + JWTIssuedDateOptional: config.GrantTypeJWTBearerIssuedDateOptional, + JWTMaxDuration: config.GetJWTMaxDuration(), + HandleHelper: &oauth2.HandleHelper{ + AccessTokenStrategy: strategy.(oauth2.AccessTokenStrategy), + AccessTokenStorage: storage.(oauth2.AccessTokenStorage), + AccessTokenLifespan: config.GetAccessTokenLifespan(), + }, + } +} diff --git a/compose/config.go b/compose/config.go index 1466dfe09..fd90e0099 100644 --- a/compose/config.go +++ b/compose/config.go @@ -99,6 +99,18 @@ type Config struct { // UseLegacyErrorFormat controls whether the legacy error format (with `error_debug`, `error_hint`, ...) // should be used or not. UseLegacyErrorFormat bool + + // GrantTypeJWTBearerCanSkipClientAuth indicates, if client authentication can be skipped, when using jwt as assertion. + GrantTypeJWTBearerCanSkipClientAuth bool + + // GrantTypeJWTBearerIDOptional indicates, if jti (JWT ID) claim required or not in JWT. + GrantTypeJWTBearerIDOptional bool + + // GrantTypeJWTBearerIssuedDateOptional indicates, if "iat" (issued at) claim required or not in JWT. + GrantTypeJWTBearerIssuedDateOptional bool + + // GrantTypeJWTBearerMaxDuration sets the maximum time after JWT issued date, during which the JWT is considered valid. + GrantTypeJWTBearerMaxDuration time.Duration } // GetScopeStrategy returns the scope strategy to be used. Defaults to glob scope strategy. @@ -198,3 +210,14 @@ func (c *Config) GetMinParameterEntropy() int { return c.MinParameterEntropy } } + +// GetJWTMaxDuration specified the maximum amount of allowed `exp` time for a JWT. It compares +// the time with the JWT's `exp` time if the JWT time is larger, will cause the JWT to be invalid. +// +// Defaults to a day. +func (c *Config) GetJWTMaxDuration() time.Duration { + if c.GrantTypeJWTBearerMaxDuration == 0 { + return time.Hour * 24 + } + return c.GrantTypeJWTBearerMaxDuration +} diff --git a/generate-mocks.sh b/generate-mocks.sh index 05ece298f..d4dded4ea 100755 --- a/generate-mocks.sh +++ b/generate-mocks.sh @@ -6,6 +6,7 @@ mockgen -package internal -destination internal/transactional.go github.com/ory/ mockgen -package internal -destination internal/oauth2_storage.go github.com/ory/fosite/handler/oauth2 CoreStorage mockgen -package internal -destination internal/oauth2_strategy.go github.com/ory/fosite/handler/oauth2 CoreStrategy mockgen -package internal -destination internal/authorize_code_storage.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStorage +mockgen -package internal -destination internal/oauth2_auth_jwt_storage.go github.com/ory/fosite/handler/rfc7523 RFC7523KeyStorage mockgen -package internal -destination internal/access_token_storage.go github.com/ory/fosite/handler/oauth2 AccessTokenStorage mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/fosite/handler/oauth2 RefreshTokenStorage mockgen -package internal -destination internal/oauth2_client_storage.go github.com/ory/fosite/handler/oauth2 ClientCredentialsGrantStorage diff --git a/go.mod b/go.mod index 79f32c2d8..3ab8afc5a 100644 --- a/go.mod +++ b/go.mod @@ -1,30 +1,31 @@ module github.com/ory/fosite -// Using replace for reflection libs allows to by pass fosite 0.29 broken dependecie -// ory/x should be updated to latest version of fosite if possible to avoid -// this line. -replace github.com/oleiade/reflections v1.0.0 => github.com/oleiade/reflections v1.0.1 - require ( github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 + github.com/dgraph-io/ristretto v0.0.3 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible - github.com/golang/mock v1.4.3 + github.com/golang/mock v1.4.4 + github.com/golang/protobuf v1.4.0 // indirect github.com/gorilla/mux v1.7.3 github.com/gorilla/websocket v1.4.2 github.com/magiconair/properties v1.8.1 github.com/mattn/goveralls v0.0.6 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 - github.com/oleiade/reflections v1.0.1 + github.com/oleiade/reflections v1.0.0 github.com/ory/go-acc v0.2.5 github.com/ory/go-convenience v0.1.0 github.com/ory/x v0.0.162 github.com/parnurzeal/gorequest v0.2.15 github.com/pborman/uuid v1.2.0 github.com/pkg/errors v0.9.1 + github.com/spf13/afero v1.3.2 // indirect github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c golang.org/x/net v0.0.0-20200625001655-4c5254603344 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d + golang.org/x/sys v0.0.0-20200720211630-cb9d2d5c5666 // indirect + golang.org/x/text v0.3.3 // indirect + golang.org/x/tools v0.0.0-20200721223218-6123e77877b2 // indirect gopkg.in/square/go-jose.v2 v2.5.1 ) diff --git a/go.sum b/go.sum index de1b65492..9b9af9943 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/dgraph-io/ristretto v0.0.1/go.mod h1:T40EBc7CJke8TkpiYfGGKAeFjSaxuFXhuXRyumBd6RE= github.com/dgraph-io/ristretto v0.0.2 h1:a5WaUrDa0qm0YrAAS1tUykT5El3kt62KNZZeMxQn3po= github.com/dgraph-io/ristretto v0.0.2/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= +github.com/dgraph-io/ristretto v0.0.3 h1:jh22xisGBjrEVnRZ1DVTpBVQm0Xndu8sMl0CWDzSIBI= +github.com/dgraph-io/ristretto v0.0.3/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= @@ -354,19 +356,26 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4er github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.1.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-jsonnet v0.16.0/go.mod h1:sOcuej3UW1vpPTZOr8L7RQimqai1a57bt5j22LzGZCw= @@ -466,6 +475,7 @@ github.com/konsorten/go-windows-terminal-sequences v0.0.0-20180402223658-b729f26 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= @@ -544,8 +554,8 @@ github.com/moul/http2curl v0.0.0-20170919181001-9ac6cf4d929b/go.mod h1:8UbvGypXm github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nicksnyder/go-i18n v1.10.0/go.mod h1:HrK7VCrbOvQoUAQ7Vpy7i87N7JZZZ7R2xBGjv0j365Q= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= -github.com/oleiade/reflections v1.0.1 h1:D1XO3LVEYroYskEsoSiGItp9RUxG6jWnCVvrqH0HHQM= -github.com/oleiade/reflections v1.0.1/go.mod h1:rdFxbxq4QXVZWj0F+e9jqjDkc7dbp97vkRixKo2JR60= +github.com/oleiade/reflections v1.0.0 h1:0ir4pc6v8/PJ0yw5AEtMddfXpWBXg9cnG7SgSoJuCgY= +github.com/oleiade/reflections v1.0.0/go.mod h1:RbATFBbKYkVdqmSFtx13Bb/tVhR0lgOBXunWTZKeL4w= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -611,6 +621,7 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= +github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -686,6 +697,8 @@ github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B github.com/spf13/afero v1.2.0/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/afero v1.2.2 h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= +github.com/spf13/afero v1.3.2 h1:GDarE4TJQI52kYSbSAmLiId1Elfj+xgSDqrUZxFhxlU= +github.com/spf13/afero v1.3.2/go.mod h1:5KUK8ByomD5Ti5Artl0RtHeI5pTF7MIDuXL3yY520V4= github.com/spf13/cast v1.2.0/go.mod h1:r2rcYCSwa1IExKTDiTfzaxqT2FNHs8hODu4LnUfgKEg= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= @@ -743,6 +756,7 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c/go.mod h1:UrdRz5enIKZ63MEE3IF9l2/ebyx59GyGgPi+tICQdmM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= go.elastic.co/apm v1.8.0/go.mod h1:tCw6CkOJgkWnzEthFN9HUP1uL3Gjc/Ur6m7gRPLaoH0= @@ -813,6 +827,8 @@ golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180816102801-aaf60122140d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -862,6 +878,7 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180816055513-1c9583448a9c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180831094639-fa5fdf94c789/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -911,12 +928,14 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y= golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200720211630-cb9d2d5c5666 h1:gVCS+QOncANNPlmlO1AhlU3oxs4V9z+gTtPwIk3p2N8= +golang.org/x/sys v0.0.0-20200720211630-cb9d2d5c5666/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -971,6 +990,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK golang.org/x/tools v0.0.0-20200203215610-ab391d50b528/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375 h1:SjQ2+AKWgZLc1xej6WSzL+Dfs5Uyd5xcZH1mGC411IA= golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200721223218-6123e77877b2 h1:kxDWg8KNMtpGjI/XVKGgOtSljTnVg/PrjhS8+0pxjLE= +golang.org/x/tools v0.0.0-20200721223218-6123e77877b2/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1010,6 +1031,12 @@ google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ij google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.22.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zimw= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= gopkg.in/DataDog/dd-trace-go.v1 v1.27.0/go.mod h1:Sp1lku8WJMvNV0kjDI4Ni/T7J/U3BO5ct5kEaoVU8+I= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= @@ -1061,5 +1088,3 @@ modernc.org/strutil v1.1.0/go.mod h1:lstksw84oURvj9y3tn8lGvRxyRC1S2+g5uuIzNfIOBs modernc.org/xc v1.0.0/go.mod h1:mRNCo0bvLjGhHO9WsyuKVU4q0ceiDDDoEeWDJHrNx8I= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/handler.go b/handler.go index 090538369..a73dc760b 100644 --- a/handler.go +++ b/handler.go @@ -48,6 +48,15 @@ type TokenEndpointHandler interface { // HandleTokenEndpointRequest handles an authorize request. If the handler is not responsible for handling // the request, this method should return ErrUnknownRequest and otherwise handle the request. HandleTokenEndpointRequest(ctx context.Context, requester AccessRequester) error + + // CanSkipClientAuth indicates if client authentication can be skipped. By default it MUST be false, unless you are + // implementing extension grant type, which allows unauthenticated client. CanSkipClientAuth must be called + // before HandleTokenEndpointRequest to decide, if AccessRequester will contain authenticated client. + CanSkipClientAuth(requester AccessRequester) bool + + // CanHandleRequest indicates, if TokenEndpointHandler can handle this request or not. If true, + // HandleTokenEndpointRequest can be called. + CanHandleTokenEndpointRequest(requester AccessRequester) bool } // RevocationHandler is the interface that allows token revocation for an OAuth2.0 provider. diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 957619816..986663ef6 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -37,9 +37,7 @@ import ( // HandleTokenEndpointRequest implements // * https://tools.ietf.org/html/rfc6749#section-4.1.3 (everything) func (c *AuthorizeExplicitGrantHandler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { - // grant_type REQUIRED. - // Value MUST be set to "authorization_code". - if !request.GetGrantTypes().ExactOne("authorization_code") { + if !c.CanHandleTokenEndpointRequest(request) { return errorsx.WithStack(errorsx.WithStack(fosite.ErrUnknownRequest)) } @@ -57,7 +55,7 @@ func (c *AuthorizeExplicitGrantHandler) HandleTokenEndpointRequest(ctx context.C WithDebug("GetAuthorizeCodeSession must return a value for \"fosite.Requester\" when returning \"ErrInvalidatedAuthorizeCode\".") } - //If an authorize code is used twice, we revoke all refresh and access tokens associated with this request. + // If an authorize code is used twice, we revoke all refresh and access tokens associated with this request. reqID := authorizeRequest.GetID() hint := "The authorization code has already been used." debug := "" @@ -133,9 +131,7 @@ func canIssueRefreshToken(c *AuthorizeExplicitGrantHandler, request fosite.Reque } func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { - // grant_type REQUIRED. - // Value MUST be set to "authorization_code", as this is the explicit grant handler. - if !requester.GetGrantTypes().ExactOne("authorization_code") { + if !c.CanHandleTokenEndpointRequest(requester) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -208,3 +204,13 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex return nil } + +func (c *AuthorizeExplicitGrantHandler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (c *AuthorizeExplicitGrantHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + // grant_type REQUIRED. + // Value MUST be set to "authorization_code" + return requester.GetGrantTypes().ExactOne("authorization_code") +} diff --git a/handler/oauth2/flow_client_credentials.go b/handler/oauth2/flow_client_credentials.go index 01bfea127..f86e56096 100644 --- a/handler/oauth2/flow_client_credentials.go +++ b/handler/oauth2/flow_client_credentials.go @@ -38,9 +38,7 @@ type ClientCredentialsGrantHandler struct { // IntrospectTokenEndpointRequest implements https://tools.ietf.org/html/rfc6749#section-4.4.2 func (c *ClientCredentialsGrantHandler) HandleTokenEndpointRequest(_ context.Context, request fosite.AccessRequester) error { - // grant_type REQUIRED. - // Value MUST be set to "client_credentials". - if !request.GetGrantTypes().ExactOne("client_credentials") { + if !c.CanHandleTokenEndpointRequest(request) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -69,7 +67,7 @@ func (c *ClientCredentialsGrantHandler) HandleTokenEndpointRequest(_ context.Con // PopulateTokenEndpointResponse implements https://tools.ietf.org/html/rfc6749#section-4.4.3 func (c *ClientCredentialsGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, request fosite.AccessRequester, response fosite.AccessResponder) error { - if !request.GetGrantTypes().ExactOne("client_credentials") { + if !c.CanHandleTokenEndpointRequest(request) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -79,3 +77,13 @@ func (c *ClientCredentialsGrantHandler) PopulateTokenEndpointResponse(ctx contex return c.IssueAccessToken(ctx, request, response) } + +func (c *ClientCredentialsGrantHandler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (c *ClientCredentialsGrantHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + // grant_type REQUIRED. + // Value MUST be set to "client_credentials". + return requester.GetGrantTypes().ExactOne("client_credentials") +} diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index 17c3036cc..adb05e1e9 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -53,9 +53,7 @@ type RefreshTokenGrantHandler struct { // HandleTokenEndpointRequest implements https://tools.ietf.org/html/rfc6749#section-6 func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { - // grant_type REQUIRED. - // Value MUST be set to "refresh_token". - if !request.GetGrantTypes().ExactOne("refresh_token") { + if !c.CanHandleTokenEndpointRequest(request) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -117,7 +115,7 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex // PopulateTokenEndpointResponse implements https://tools.ietf.org/html/rfc6749#section-6 func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { - if !requester.GetGrantTypes().ExactOne("refresh_token") { + if !c.CanHandleTokenEndpointRequest(requester) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -194,3 +192,13 @@ func handleRefreshTokenEndpointResponseStorageError(ctx context.Context, rollbac return errorsx.WithStack(fosite.ErrServerError.WithWrap(storageErr).WithDebug(storageErr.Error())) } + +func (c *RefreshTokenGrantHandler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (c *RefreshTokenGrantHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + // grant_type REQUIRED. + // Value MUST be set to "refresh_token". + return requester.GetGrantTypes().ExactOne("refresh_token") +} diff --git a/handler/oauth2/flow_resource_owner.go b/handler/oauth2/flow_resource_owner.go index f0e49a51c..89263ea83 100644 --- a/handler/oauth2/flow_resource_owner.go +++ b/handler/oauth2/flow_resource_owner.go @@ -46,9 +46,7 @@ type ResourceOwnerPasswordCredentialsGrantHandler struct { // HandleTokenEndpointRequest implements https://tools.ietf.org/html/rfc6749#section-4.3.2 func (c *ResourceOwnerPasswordCredentialsGrantHandler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { - // grant_type REQUIRED. - // Value MUST be set to "password". - if !request.GetGrantTypes().ExactOne("password") { + if !c.CanHandleTokenEndpointRequest(request) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -90,7 +88,7 @@ func (c *ResourceOwnerPasswordCredentialsGrantHandler) HandleTokenEndpointReques // PopulateTokenEndpointResponse implements https://tools.ietf.org/html/rfc6749#section-4.3.3 func (c *ResourceOwnerPasswordCredentialsGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { - if !requester.GetGrantTypes().ExactOne("password") { + if !c.CanHandleTokenEndpointRequest(requester) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -115,3 +113,13 @@ func (c *ResourceOwnerPasswordCredentialsGrantHandler) PopulateTokenEndpointResp return nil } + +func (c *ResourceOwnerPasswordCredentialsGrantHandler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (c *ResourceOwnerPasswordCredentialsGrantHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + // grant_type REQUIRED. + // Value MUST be set to "password". + return requester.GetGrantTypes().ExactOne("password") +} diff --git a/handler/oauth2/strategy_jwt_session.go b/handler/oauth2/strategy_jwt_session.go index a2bbd454a..7334566a1 100644 --- a/handler/oauth2/strategy_jwt_session.go +++ b/handler/oauth2/strategy_jwt_session.go @@ -63,43 +63,47 @@ func (j *JWTSession) GetJWTHeader() *jwt.Headers { return j.JWTHeader } -func (s *JWTSession) SetExpiresAt(key fosite.TokenType, exp time.Time) { - if s.ExpiresAt == nil { - s.ExpiresAt = make(map[fosite.TokenType]time.Time) +func (j *JWTSession) SetExpiresAt(key fosite.TokenType, exp time.Time) { + if j.ExpiresAt == nil { + j.ExpiresAt = make(map[fosite.TokenType]time.Time) } - s.ExpiresAt[key] = exp + j.ExpiresAt[key] = exp } -func (s *JWTSession) GetExpiresAt(key fosite.TokenType) time.Time { - if s.ExpiresAt == nil { - s.ExpiresAt = make(map[fosite.TokenType]time.Time) +func (j *JWTSession) GetExpiresAt(key fosite.TokenType) time.Time { + if j.ExpiresAt == nil { + j.ExpiresAt = make(map[fosite.TokenType]time.Time) } - if _, ok := s.ExpiresAt[key]; !ok { + if _, ok := j.ExpiresAt[key]; !ok { return time.Time{} } - return s.ExpiresAt[key] + return j.ExpiresAt[key] } -func (s *JWTSession) GetUsername() string { - if s == nil { +func (j *JWTSession) GetUsername() string { + if j == nil { return "" } - return s.Username + return j.Username } -func (s *JWTSession) GetSubject() string { - if s == nil { +func (j *JWTSession) SetSubject(subject string) { + j.Subject = subject +} + +func (j *JWTSession) GetSubject() string { + if j == nil { return "" } - return s.Subject + return j.Subject } -func (s *JWTSession) Clone() fosite.Session { - if s == nil { +func (j *JWTSession) Clone() fosite.Session { + if j == nil { return nil } - return deepcopy.Copy(s).(fosite.Session) + return deepcopy.Copy(j).(fosite.Session) } diff --git a/handler/openid/flow_explicit_token.go b/handler/openid/flow_explicit_token.go index 376680b40..8f2904831 100644 --- a/handler/openid/flow_explicit_token.go +++ b/handler/openid/flow_explicit_token.go @@ -36,7 +36,7 @@ func (c *OpenIDConnectExplicitHandler) HandleTokenEndpointRequest(ctx context.Co } func (c *OpenIDConnectExplicitHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { - if !requester.GetGrantTypes().ExactOne("authorization_code") { + if !c.CanHandleTokenEndpointRequest(requester) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -76,3 +76,11 @@ func (c *OpenIDConnectExplicitHandler) PopulateTokenEndpointResponse(ctx context return c.IssueExplicitIDToken(ctx, authorize, responder) } + +func (c *OpenIDConnectExplicitHandler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (c *OpenIDConnectExplicitHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + return requester.GetGrantTypes().ExactOne("authorization_code") +} diff --git a/handler/openid/flow_refresh_token.go b/handler/openid/flow_refresh_token.go index 7a2ba0a98..90a23f7e9 100644 --- a/handler/openid/flow_refresh_token.go +++ b/handler/openid/flow_refresh_token.go @@ -39,7 +39,7 @@ type OpenIDConnectRefreshHandler struct { } func (c *OpenIDConnectRefreshHandler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { - if !request.GetGrantTypes().ExactOne("refresh_token") { + if !c.CanHandleTokenEndpointRequest(request) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -77,7 +77,7 @@ func (c *OpenIDConnectRefreshHandler) HandleTokenEndpointRequest(ctx context.Con } func (c *OpenIDConnectRefreshHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { - if !requester.GetGrantTypes().ExactOne("refresh_token") { + if !c.CanHandleTokenEndpointRequest(requester) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -111,3 +111,13 @@ func (c *OpenIDConnectRefreshHandler) PopulateTokenEndpointResponse(ctx context. return c.IssueExplicitIDToken(ctx, requester, responder) } + +func (c *OpenIDConnectRefreshHandler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (c *OpenIDConnectRefreshHandler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + // grant_type REQUIRED. + // Value MUST be set to "refresh_token" + return requester.GetGrantTypes().ExactOne("refresh_token") +} diff --git a/handler/openid/strategy_jwt.go b/handler/openid/strategy_jwt.go index 300708309..77f0f0a81 100644 --- a/handler/openid/strategy_jwt.go +++ b/handler/openid/strategy_jwt.go @@ -101,6 +101,10 @@ func (s *DefaultSession) GetUsername() string { return s.Username } +func (s *DefaultSession) SetSubject(subject string) { + s.Subject = subject +} + func (s *DefaultSession) GetSubject() string { if s == nil { return "" diff --git a/handler/pkce/handler.go b/handler/pkce/handler.go index 489491339..2098384b3 100644 --- a/handler/pkce/handler.go +++ b/handler/pkce/handler.go @@ -127,7 +127,7 @@ func (c *Handler) validate(challenge, method string, client fosite.Client) error } func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { - if !request.GetGrantTypes().ExactOne("authorization_code") { + if !c.CanHandleTokenEndpointRequest(request) { return errorsx.WithStack(fosite.ErrUnknownRequest) } @@ -229,3 +229,13 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request fosite func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { return nil } + +func (c *Handler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (c *Handler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + // grant_type REQUIRED. + // Value MUST be set to "authorization_code" + return requester.GetGrantTypes().ExactOne("authorization_code") +} diff --git a/handler/rfc7523/handler.go b/handler/rfc7523/handler.go new file mode 100644 index 000000000..2b474b148 --- /dev/null +++ b/handler/rfc7523/handler.go @@ -0,0 +1,331 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package rfc7523 + +import ( + "context" + "time" + + "github.com/ory/fosite/handler/oauth2" + + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/ory/fosite" + "github.com/ory/x/errorsx" +) + +const grantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" + +type Handler struct { + Storage RFC7523KeyStorage + ScopeStrategy fosite.ScopeStrategy + AudienceMatchingStrategy fosite.AudienceMatchingStrategy + + // TokenURL is the the URL of the Authorization Server's Token Endpoint. + TokenURL string + // SkipClientAuth indicates, if client authentication can be skipped. + SkipClientAuth bool + // JWTIDOptional indicates, if jti (JWT ID) claim required or not. + JWTIDOptional bool + // JWTIssuedDateOptional indicates, if "iat" (issued at) claim required or not. + JWTIssuedDateOptional bool + // JWTMaxDuration sets the maximum time after token issued date (if present), during which the token is + // considered valid. If "iat" claim is not present, then current time will be used as issued date. + JWTMaxDuration time.Duration + + *oauth2.HandleHelper +} + +// HandleTokenEndpointRequest implements https://tools.ietf.org/html/rfc6749#section-4.1.3 (everything) and +// https://tools.ietf.org/html/rfc7523#section-2.1 (everything) +func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { + if err := c.CheckRequest(request); err != nil { + return err + } + + assertion := request.GetRequestForm().Get("assertion") + if assertion == "" { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("The assertion request parameter must be set when using grant_type of '%s'.", grantTypeJWTBearer)) + } + + token, err := jwt.ParseSigned(assertion) + if err != nil { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("Unable to parse JSON Web Token passed in \"assertion\" request parameter."). + WithWrap(err).WithDebug(err.Error()), + ) + } + + // Check fo required claims in token, so we can later find public key based on them. + if err := c.validateTokenPreRequisites(token); err != nil { + return err + } + + key, err := c.findPublicKeyForToken(ctx, token) + if err != nil { + return err + } + + claims := jwt.Claims{} + if err := token.Claims(key, &claims); err != nil { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("Unable to verify the integrity of the 'assertion' value."). + WithWrap(err).WithDebug(err.Error()), + ) + } + + if err := c.validateTokenClaims(ctx, claims, key); err != nil { + return err + } + + scopes, err := c.Storage.GetPublicKeyScopes(ctx, claims.Issuer, claims.Subject, key.KeyID) + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + for _, scope := range request.GetRequestedScopes() { + if !c.ScopeStrategy(scopes, scope) { + return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The public key registered for issuer \"%s\" and subject \"%s\" is not allowed to request scope \"%s\".", claims.Issuer, claims.Subject, scope)) + } + } + + if claims.ID != "" { + if err := c.Storage.MarkJWTUsedForTime(ctx, claims.ID, claims.Expiry.Time()); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + } + + for _, scope := range request.GetRequestedScopes() { + request.GrantScope(scope) + } + + for _, audience := range claims.Audience { + request.GrantAudience(audience) + } + + session, err := c.getSessionFromRequest(request) + if err != nil { + return err + } + session.SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(c.HandleHelper.AccessTokenLifespan).Round(time.Second)) + session.SetSubject(claims.Subject) + + return nil +} + +func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, request fosite.AccessRequester, response fosite.AccessResponder) error { + if err := c.CheckRequest(request); err != nil { + return err + } + + return c.IssueAccessToken(ctx, request, response) +} + +func (c *Handler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return c.SkipClientAuth +} + +func (c *Handler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + // grant_type REQUIRED. + // Value MUST be set to "authorization_code" + return requester.GetGrantTypes().ExactOne(grantTypeJWTBearer) +} + +func (c *Handler) CheckRequest(request fosite.AccessRequester) error { + if !c.CanHandleTokenEndpointRequest(request) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + // Client Authentication is optional: + // + // Authentication of the client is optional, as described in + // Section 3.2.1 of OAuth 2.0 [RFC6749] and consequently, the + // "client_id" is only needed when a form of client authentication that + // relies on the parameter is used. + + // if client is authenticated, check grant types + if !c.CanSkipClientAuth(request) && !request.GetClient().GetGrantTypes().Has(grantTypeJWTBearer) { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHintf("The OAuth 2.0 Client is not allowed to use authorization grant \"%s\".", grantTypeJWTBearer)) + } + + return nil +} + +func (c *Handler) validateTokenPreRequisites(token *jwt.JSONWebToken) error { + unverifiedClaims := jwt.Claims{} + if err := token.UnsafeClaimsWithoutVerification(&unverifiedClaims); err != nil { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("Looks like there are no claims in JWT in \"assertion\" request parameter."). + WithWrap(err).WithDebug(err.Error()), + ) + } + if unverifiedClaims.Issuer == "" { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The JWT in \"assertion\" request parameter MUST contain an \"iss\" (issuer) claim."), + ) + } + if unverifiedClaims.Subject == "" { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The JWT in \"assertion\" request parameter MUST contain a \"sub\" (subject) claim."), + ) + } + + return nil +} + +func (c *Handler) findPublicKeyForToken(ctx context.Context, token *jwt.JSONWebToken) (*jose.JSONWebKey, error) { + unverifiedClaims := jwt.Claims{} + if err := token.UnsafeClaimsWithoutVerification(&unverifiedClaims); err != nil { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithWrap(err).WithDebug(err.Error())) + } + + var keyID string + for _, header := range token.Headers { + if header.KeyID != "" { + keyID = header.KeyID + break + } + } + + keyNotFoundErr := fosite.ErrInvalidGrant.WithHintf( + "No public JWK was registered for issuer \"%s\" and subject \"%s\", and public key is required to check signature of JWT in \"assertion\" request parameter.", + unverifiedClaims.Issuer, + unverifiedClaims.Subject, + ) + if keyID != "" { + key, err := c.Storage.GetPublicKey(ctx, unverifiedClaims.Issuer, unverifiedClaims.Subject, keyID) + if err != nil { + return nil, errorsx.WithStack(keyNotFoundErr.WithWrap(err).WithDebug(err.Error())) + } + return key, nil + } + + keys, err := c.Storage.GetPublicKeys(ctx, unverifiedClaims.Issuer, unverifiedClaims.Subject) + if err != nil { + return nil, errorsx.WithStack(keyNotFoundErr.WithWrap(err).WithDebug(err.Error())) + } + + claims := jwt.Claims{} + for _, key := range keys.Keys { + err := token.Claims(key, &claims) + if err == nil { + return &key, nil + } + } + + return nil, errorsx.WithStack(keyNotFoundErr) +} + +func (c *Handler) validateTokenClaims(ctx context.Context, claims jwt.Claims, key *jose.JSONWebKey) error { + if len(claims.Audience) == 0 { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The JWT in \"assertion\" request parameter MUST contain an \"aud\" (audience) claim."), + ) + } + + if !claims.Audience.Contains(c.TokenURL) { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHintf( + "The JWT in \"assertion\" request parameter MUST contain an \"aud\" (audience) claim containing a value \"%s\" that identifies the authorization server as an intended audience.", + c.TokenURL, + ), + ) + } + + if claims.Expiry == nil { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The JWT in \"assertion\" request parameter MUST contain an \"exp\" (expiration time) claim."), + ) + } + + if claims.Expiry.Time().Before(time.Now()) { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The JWT in \"assertion\" request parameter expired."), + ) + } + + if claims.NotBefore != nil && !claims.NotBefore.Time().Before(time.Now()) { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHintf( + "The JWT in \"assertion\" request parameter contains an \"nbf\" (not before) claim, that identifies the time '%s' before which the token MUST NOT be accepted.", + claims.NotBefore.Time().Format(time.RFC3339), + ), + ) + } + + if !c.JWTIssuedDateOptional && claims.IssuedAt == nil { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The JWT in \"assertion\" request parameter MUST contain an \"iat\" (issued at) claim."), + ) + } + + var issuedDate time.Time + if claims.IssuedAt != nil { + issuedDate = claims.IssuedAt.Time() + } else { + issuedDate = time.Now() + } + if claims.Expiry.Time().Sub(issuedDate) > c.JWTMaxDuration { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHintf( + "The JWT in \"assertion\" request parameter contains an \"exp\" (expiration time) claim with value \"%s\" that is unreasonably far in the future, considering token issued at \"%s\".", + claims.Expiry.Time().Format(time.RFC3339), + issuedDate.Format(time.RFC3339), + ), + ) + } + + if !c.JWTIDOptional && claims.ID == "" { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The JWT in \"assertion\" request parameter MUST contain an \"jti\" (JWT ID) claim."), + ) + } + + if claims.ID != "" { + used, err := c.Storage.IsJWTUsed(ctx, claims.ID) + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + if used { + return errorsx.WithStack(fosite.ErrJTIKnown) + } + } + + return nil +} + +type extendedSession interface { + Session + fosite.Session +} + +func (c *Handler) getSessionFromRequest(requester fosite.AccessRequester) (extendedSession, error) { + session := requester.GetSession() + if jwtSession, ok := session.(extendedSession); !ok { + return nil, errorsx.WithStack( + fosite.ErrServerError.WithHintf("Session must be of type *rfc7523.Session but got type: %T", session), + ) + } else { + return jwtSession, nil + } +} diff --git a/handler/rfc7523/handler_test.go b/handler/rfc7523/handler_test.go new file mode 100644 index 000000000..b53baf721 --- /dev/null +++ b/handler/rfc7523/handler_test.go @@ -0,0 +1,916 @@ +package rfc7523 + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + mrand "math/rand" + "net/url" + "strconv" + "testing" + "time" + + "github.com/ory/fosite/handler/oauth2" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/suite" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/ory/fosite" + "github.com/ory/fosite/internal" +) + +// Define the suite, and absorb the built-in basic suite +// functionality from testify - including a T() method which +// returns the current testing context. +type AuthorizeJWTGrantRequestHandlerTestSuite struct { + suite.Suite + + privateKey *rsa.PrivateKey + mockCtrl *gomock.Controller + mockStore *internal.MockRFC7523KeyStorage + mockAccessTokenStrategy *internal.MockAccessTokenStrategy + mockAccessTokenStore *internal.MockAccessTokenStorage + accessRequest *fosite.AccessRequest + handler *Handler +} + +// Setup before each test in the suite. +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) SetupSuite() { + privateKey, err := rsa.GenerateKey(rand.Reader, 512) // fast RSA for testing + if err != nil { + s.FailNowf("failed to setup test suite", "failed to generate RSA private key: %s", err.Error()) + } + s.privateKey = privateKey +} + +// Will run after all the tests in the suite have been run. +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TearDownSuite() { +} + +// Will run after each test in the suite. +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TearDownTest() { + s.mockCtrl.Finish() +} + +// Setup before each test. +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) SetupTest() { + s.mockCtrl = gomock.NewController(s.T()) + s.mockStore = internal.NewMockRFC7523KeyStorage(s.mockCtrl) + s.mockAccessTokenStrategy = internal.NewMockAccessTokenStrategy(s.mockCtrl) + s.mockAccessTokenStore = internal.NewMockAccessTokenStorage(s.mockCtrl) + s.accessRequest = fosite.NewAccessRequest(new(fosite.DefaultSession)) + s.accessRequest.Form = url.Values{} + s.accessRequest.Client = &fosite.DefaultClient{GrantTypes: []string{grantTypeJWTBearer}} + s.handler = &Handler{ + Storage: s.mockStore, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + TokenURL: "https://www.example.com/token", + SkipClientAuth: false, + JWTIDOptional: false, + JWTIssuedDateOptional: false, + JWTMaxDuration: time.Hour * 24 * 30, + HandleHelper: &oauth2.HandleHelper{ + AccessTokenStrategy: s.mockAccessTokenStrategy, + AccessTokenStorage: s.mockAccessTokenStore, + AccessTokenLifespan: time.Hour, + }, + } +} + +// In order for 'go test' to run this suite, we need to create +// a normal test function and pass our suite to suite.Run. +func TestAuthorizeJWTGrantRequestHandlerTestSuite(t *testing.T) { + suite.Run(t, new(AuthorizeJWTGrantRequestHandlerTestSuite)) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestRequestWithInvalidGrantType() { + // arrange + s.accessRequest.GrantTypes = []string{"authorization_code"} + + // act + err := s.handler.HandleTokenEndpointRequest(context.Background(), s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrUnknownRequest)) + s.EqualError(err, fosite.ErrUnknownRequest.Error(), "expected error, because of invalid grant type") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestClientIsNotRegisteredForGrantType() { + // arrange + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + s.accessRequest.Client = &fosite.DefaultClient{GrantTypes: []string{"authorization_code"}} + s.handler.SkipClientAuth = false + + // act + err := s.handler.HandleTokenEndpointRequest(context.Background(), s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrUnauthorizedClient)) + s.EqualError(err, fosite.ErrUnauthorizedClient.Error(), "expected error, because client is not registered to use this grant type") + s.Equal( + "The OAuth 2.0 Client is not allowed to use authorization grant \"urn:ietf:params:oauth:grant-type:jwt-bearer\".", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestRequestWithoutAssertion() { + // arrange + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + + // act + err := s.handler.HandleTokenEndpointRequest(context.Background(), s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidRequest)) + s.EqualError(err, fosite.ErrInvalidRequest.Error(), "expected error, because of missing assertion") + s.Equal( + "The assertion request parameter must be set when using grant_type of 'urn:ietf:params:oauth:grant-type:jwt-bearer'.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestRequestWithMalformedAssertion() { + // arrange + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + s.accessRequest.Form.Add("assertion", "fjigjgfkjgkf") + + // act + err := s.handler.HandleTokenEndpointRequest(context.Background(), s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of malformed assertion") + s.Equal( + "Unable to parse JSON Web Token passed in \"assertion\" request parameter.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestRequestAssertionWithoutIssuer() { + // arrange + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + cl := s.createStandardClaim() + cl.Issuer = "" + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + + // act + err := s.handler.HandleTokenEndpointRequest(context.Background(), s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing issuer claim in assertion") + s.Equal( + "The JWT in \"assertion\" request parameter MUST contain an \"iss\" (issuer) claim.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestRequestAssertionWithoutSubject() { + // arrange + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + cl := s.createStandardClaim() + cl.Subject = "" + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + + // act + err := s.handler.HandleTokenEndpointRequest(context.Background(), s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing subject claim in assertion") + s.Equal( + "The JWT in \"assertion\" request parameter MUST contain a \"sub\" (subject) claim.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestNoMatchingPublicKeyToCheckAssertionSignature() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + cl := s.createStandardClaim() + keyID := "my_key" + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(nil, fosite.ErrNotFound) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing public key to check assertion") + s.Equal( + fmt.Sprintf( + "No public JWK was registered for issuer \"%s\" and subject \"%s\", and public key is required to check signature of JWT in \"assertion\" request parameter.", + cl.Issuer, cl.Subject, + ), + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestNoMatchingPublicKeysToCheckAssertionSignature() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "" // provide no hint of what key was used to sign assertion + cl := s.createStandardClaim() + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKeys(ctx, cl.Issuer, cl.Subject).Return(nil, fosite.ErrNotFound) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing public keys to check assertion") + s.Equal( + fmt.Sprintf( + "No public JWK was registered for issuer \"%s\" and subject \"%s\", and public key is required to check signature of JWT in \"assertion\" request parameter.", + cl.Issuer, cl.Subject, + ), + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestWrongPublicKeyToCheckAssertionSignature() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "wrong_key" + cl := s.createStandardClaim() + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + jwk := s.createRandomTestJWK() + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&jwk, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because wrong public key was registered for assertion") + s.Equal("Unable to verify the integrity of the 'assertion' value.", err.(*fosite.RFC6749Error).HintField) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestWrongPublicKeysToCheckAssertionSignature() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "" // provide no hint of what key was used to sign assertion + cl := s.createStandardClaim() + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKeys(ctx, cl.Issuer, cl.Subject).Return(s.createJWS(s.createRandomTestJWK(), s.createRandomTestJWK()), nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because wrong public keys was registered for assertion") + s.Equal( + fmt.Sprintf( + "No public JWK was registered for issuer \"%s\" and subject \"%s\", and public key is required to check signature of JWT in \"assertion\" request parameter.", + cl.Issuer, cl.Subject, + ), + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestNoAudienceInAssertion() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.Audience = []string{} + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing audience claim in assertion") + s.Equal( + "The JWT in \"assertion\" request parameter MUST contain an \"aud\" (audience) claim.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestNotValidAudienceInAssertion() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.Audience = jwt.Audience{"leela", "fry"} + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of invalid audience claim in assertion") + s.Equal( + fmt.Sprintf( + "The JWT in \"assertion\" request parameter MUST contain an \"aud\" (audience) claim containing a value \"%s\" that identifies the authorization server as an intended audience.", + s.handler.TokenURL, + ), + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestNoExpirationInAssertion() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.Expiry = nil + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing expiration claim in assertion") + s.Equal( + "The JWT in \"assertion\" request parameter MUST contain an \"exp\" (expiration time) claim.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestExpiredAssertion() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.Expiry = jwt.NewNumericDate(time.Now().AddDate(0, -1, 0)) + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because assertion expired") + s.Equal( + "The JWT in \"assertion\" request parameter expired.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionNotAcceptedBeforeDate() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + nbf := time.Now().AddDate(0, 1, 0) + cl := s.createStandardClaim() + cl.NotBefore = jwt.NewNumericDate(nbf) + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, nbf claim in assertion indicates, that assertion can not be accepted now") + s.Equal( + fmt.Sprintf( + "The JWT in \"assertion\" request parameter contains an \"nbf\" (not before) claim, that identifies the time '%s' before which the token MUST NOT be accepted.", + nbf.Format(time.RFC3339), + ), + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionWithoutRequiredIssueDate() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.IssuedAt = nil + s.handler.JWTIssuedDateOptional = false + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing iat claim in assertion") + s.Equal( + "The JWT in \"assertion\" request parameter MUST contain an \"iat\" (issued at) claim.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionWithIssueDateFarInPast() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + issuedAt := time.Now().AddDate(0, 0, -31) + cl := s.createStandardClaim() + cl.IssuedAt = jwt.NewNumericDate(issuedAt) + s.handler.JWTIssuedDateOptional = false + s.handler.JWTMaxDuration = time.Hour * 24 * 30 + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because assertion was issued far in the past") + s.Equal( + fmt.Sprintf( + "The JWT in \"assertion\" request parameter contains an \"exp\" (expiration time) claim with value \"%s\" that is unreasonably far in the future, considering token issued at \"%s\".", + cl.Expiry.Time().Format(time.RFC3339), + cl.IssuedAt.Time().Format(time.RFC3339), + ), + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionWithExpirationDateFarInFuture() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.IssuedAt = jwt.NewNumericDate(time.Now().AddDate(0, 0, -15)) + cl.Expiry = jwt.NewNumericDate(time.Now().AddDate(0, 0, 20)) + s.handler.JWTIssuedDateOptional = false + s.handler.JWTMaxDuration = time.Hour * 24 * 30 + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because assertion will expire unreasonably far in the future.") + s.Equal( + fmt.Sprintf( + "The JWT in \"assertion\" request parameter contains an \"exp\" (expiration time) claim with value \"%s\" that is unreasonably far in the future, considering token issued at \"%s\".", + cl.Expiry.Time().Format(time.RFC3339), + cl.IssuedAt.Time().Format(time.RFC3339), + ), + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionWithExpirationDateFarInFutureWithNoIssuerDate() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.IssuedAt = nil + cl.Expiry = jwt.NewNumericDate(time.Now().AddDate(0, 0, 31)) + s.handler.JWTIssuedDateOptional = true + s.handler.JWTMaxDuration = time.Hour * 24 * 30 + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because assertion will expire unreasonably far in the future.") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionWithoutRequiredTokenID() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.ID = "" + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidGrant)) + s.EqualError(err, fosite.ErrInvalidGrant.Error(), "expected error, because of missing jti claim in assertion") + s.Equal( + "The JWT in \"assertion\" request parameter MUST contain an \"jti\" (JWT ID) claim.", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionAlreadyUsed() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(true, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrJTIKnown)) + s.EqualError(err, fosite.ErrJTIKnown.Error(), "expected error, because assertion was used") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestErrWhenCheckingIfJWTWasUsed() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, fosite.ErrServerError) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrServerError)) + s.EqualError(err, fosite.ErrServerError.Error(), "expected error, because error occurred while trying to check if jwt was used") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestErrWhenMarkingJWTAsUsed() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{"valid_scope"}, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, nil) + s.mockStore.EXPECT().MarkJWTUsedForTime(ctx, cl.ID, cl.Expiry.Time()).Return(fosite.ErrServerError) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrServerError)) + s.EqualError(err, fosite.ErrServerError.Error(), "expected error, because error occurred while trying to mark jwt as used") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestErrWhileFetchingPublicKeyScope() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{}, fosite.ErrServerError) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrServerError)) + s.EqualError(err, fosite.ErrServerError.Error(), "expected error, because error occurred while fetching public key scopes") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionWithInvalidScopes() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.accessRequest.RequestedScope = []string{"some_scope"} + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{"valid_scope"}, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.True(errors.Is(err, fosite.ErrInvalidScope)) + s.EqualError(err, fosite.ErrInvalidScope.Error(), "expected error, because requested scopes don't match allowed scope for this assertion") + s.Equal( + "The public key registered for issuer \"trusted_issuer\" and subject \"some_ro\" is not allowed to request scope \"some_scope\".", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestValidAssertion() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.accessRequest.RequestedScope = []string{"valid_scope"} + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{"valid_scope", "openid"}, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, nil) + s.mockStore.EXPECT().MarkJWTUsedForTime(ctx, cl.ID, cl.Expiry.Time()).Return(nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.NoError(err, "no error expected, because assertion must be valid") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionIsValidWhenNoScopesPassed() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{"valid_scope"}, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, nil) + s.mockStore.EXPECT().MarkJWTUsedForTime(ctx, cl.ID, cl.Expiry.Time()).Return(nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.NoError(err, "no error expected, because assertion must be valid") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionIsValidWhenJWTIDIsOptional() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + s.handler.JWTIDOptional = true + cl.ID = "" + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{"valid_scope"}, nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.NoError(err, "no error expected, because assertion must be valid, when no jti claim and it is allowed by option") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestAssertionIsValidWhenJWTIssuedDateOptional() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + cl.IssuedAt = nil + s.handler.JWTIssuedDateOptional = true + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{"valid_scope"}, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, nil) + s.mockStore.EXPECT().MarkJWTUsedForTime(ctx, cl.ID, cl.Expiry.Time()).Return(nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.NoError(err, "no error expected, because assertion must be valid, when no iss claim and it is allowed by option") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) TestRequestIsValidWhenClientAuthOptional() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + keyID := "my_key" + pubKey := s.createJWK(s.privateKey.Public(), keyID) + cl := s.createStandardClaim() + s.accessRequest.Client = &fosite.DefaultClient{} + s.handler.SkipClientAuth = true + s.accessRequest.Form.Add("assertion", s.createTestAssertion(cl, keyID)) + s.mockStore.EXPECT().GetPublicKey(ctx, cl.Issuer, cl.Subject, keyID).Return(&pubKey, nil) + s.mockStore.EXPECT().GetPublicKeyScopes(ctx, cl.Issuer, cl.Subject, keyID).Return([]string{"valid_scope"}, nil) + s.mockStore.EXPECT().IsJWTUsed(ctx, cl.ID).Return(false, nil) + s.mockStore.EXPECT().MarkJWTUsedForTime(ctx, cl.ID, cl.Expiry.Time()).Return(nil) + + // act + err := s.handler.HandleTokenEndpointRequest(ctx, s.accessRequest) + + // assert + s.NoError(err, "no error expected, because request must be valid, when no client unauthenticated and it is allowed by option") +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) createTestAssertion(cl jwt.Claims, keyID string) string { + jwk := jose.JSONWebKey{Key: s.privateKey, KeyID: keyID, Algorithm: string(jose.RS256)} + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: jwk}, (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + s.FailNowf("failed to create test assertion", "failed to create signer: %s", err.Error()) + } + + raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + if err != nil { + s.FailNowf("failed to create test assertion", "failed to sign assertion: %s", err.Error()) + } + + return raw +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) createStandardClaim() jwt.Claims { + return jwt.Claims{ + Issuer: "trusted_issuer", + Subject: "some_ro", + Audience: jwt.Audience{"https://www.example.com/token", "leela", "fry"}, + Expiry: jwt.NewNumericDate(time.Now().AddDate(0, 0, 23)), + NotBefore: nil, + IssuedAt: jwt.NewNumericDate(time.Now().AddDate(0, 0, -7)), + ID: "my_token", + } +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) createRandomTestJWK() jose.JSONWebKey { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + s.FailNowf("failed to create random test JWK", "failed to generate RSA private key: %s", err.Error()) + } + + return s.createJWK(privateKey.Public(), strconv.Itoa(mrand.Int())) +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) createJWK(key interface{}, keyID string) jose.JSONWebKey { + return jose.JSONWebKey{ + Key: key, + KeyID: keyID, + Algorithm: string(jose.RS256), + Use: "sig", + } +} + +func (s *AuthorizeJWTGrantRequestHandlerTestSuite) createJWS(keys ...jose.JSONWebKey) *jose.JSONWebKeySet { + return &jose.JSONWebKeySet{Keys: keys} +} + +// Define the suite, and absorb the built-in basic suite +// functionality from testify - including a T() method which +// returns the current testing context. +type AuthorizeJWTGrantPopulateTokenEndpointTestSuite struct { + suite.Suite + + privateKey *rsa.PrivateKey + mockCtrl *gomock.Controller + mockStore *internal.MockRFC7523KeyStorage + mockAccessTokenStrategy *internal.MockAccessTokenStrategy + mockAccessTokenStore *internal.MockAccessTokenStorage + accessRequest *fosite.AccessRequest + accessResponse *fosite.AccessResponse + handler *Handler +} + +// Setup before each test in the suite. +func (s *AuthorizeJWTGrantPopulateTokenEndpointTestSuite) SetupSuite() { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + s.FailNowf("failed to setup test suite", "failed to generate RSA private key: %s", err.Error()) + } + s.privateKey = privateKey +} + +// Will run after all the tests in the suite have been run. +func (s *AuthorizeJWTGrantPopulateTokenEndpointTestSuite) TearDownSuite() { +} + +// Will run after each test in the suite. +func (s *AuthorizeJWTGrantPopulateTokenEndpointTestSuite) TearDownTest() { + s.mockCtrl.Finish() +} + +// Setup before each test. +func (s *AuthorizeJWTGrantPopulateTokenEndpointTestSuite) SetupTest() { + s.mockCtrl = gomock.NewController(s.T()) + s.mockStore = internal.NewMockRFC7523KeyStorage(s.mockCtrl) + s.mockAccessTokenStrategy = internal.NewMockAccessTokenStrategy(s.mockCtrl) + s.mockAccessTokenStore = internal.NewMockAccessTokenStorage(s.mockCtrl) + s.accessRequest = fosite.NewAccessRequest(new(fosite.DefaultSession)) + s.accessRequest.Form = url.Values{} + s.accessRequest.Client = &fosite.DefaultClient{GrantTypes: []string{grantTypeJWTBearer}} + s.accessResponse = fosite.NewAccessResponse() + s.handler = &Handler{ + Storage: s.mockStore, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + TokenURL: "https://www.example.com/token", + SkipClientAuth: false, + JWTIDOptional: false, + JWTIssuedDateOptional: false, + JWTMaxDuration: time.Hour * 24 * 30, + HandleHelper: &oauth2.HandleHelper{ + AccessTokenStrategy: s.mockAccessTokenStrategy, + AccessTokenStorage: s.mockAccessTokenStore, + AccessTokenLifespan: time.Hour, + }, + } +} + +// In order for 'go test' to run this suite, we need to create +// a normal test function and pass our suite to suite.Run. +func TestAuthorizeJWTGrantPopulateTokenEndpointTestSuite(t *testing.T) { + suite.Run(t, new(AuthorizeJWTGrantPopulateTokenEndpointTestSuite)) +} + +func (s *AuthorizeJWTGrantPopulateTokenEndpointTestSuite) TestRequestWithInvalidGrantType() { + // arrange + s.accessRequest.GrantTypes = []string{"authorization_code"} + + // act + err := s.handler.PopulateTokenEndpointResponse(context.Background(), s.accessRequest, s.accessResponse) + + // assert + s.True(errors.Is(err, fosite.ErrUnknownRequest)) + s.EqualError(err, fosite.ErrUnknownRequest.Error(), "expected error, because of invalid grant type") +} + +func (s *AuthorizeJWTGrantPopulateTokenEndpointTestSuite) TestClientIsNotRegisteredForGrantType() { + // arrange + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + s.accessRequest.Client = &fosite.DefaultClient{GrantTypes: []string{"authorization_code"}} + s.handler.SkipClientAuth = false + + // act + err := s.handler.PopulateTokenEndpointResponse(context.Background(), s.accessRequest, s.accessResponse) + + // assert + s.True(errors.Is(err, fosite.ErrUnauthorizedClient)) + s.EqualError(err, fosite.ErrUnauthorizedClient.Error(), "expected error, because client is not registered to use this grant type") + s.Equal( + "The OAuth 2.0 Client is not allowed to use authorization grant \"urn:ietf:params:oauth:grant-type:jwt-bearer\".", + err.(*fosite.RFC6749Error).HintField, + ) +} + +func (s *AuthorizeJWTGrantPopulateTokenEndpointTestSuite) TestAccessTokenIssuedSuccessfully() { + // arrange + ctx := context.Background() + s.accessRequest.GrantTypes = []string{grantTypeJWTBearer} + token := "token" + sig := "sig" + s.mockAccessTokenStrategy.EXPECT().GenerateAccessToken(ctx, s.accessRequest).Return(token, sig, nil) + s.mockAccessTokenStore.EXPECT().CreateAccessTokenSession(ctx, sig, s.accessRequest.Sanitize([]string{})) + + // act + err := s.handler.PopulateTokenEndpointResponse(context.Background(), s.accessRequest, s.accessResponse) + + // assert + s.NoError(err, "no error expected") + s.Equal(s.accessResponse.AccessToken, token, "access token expected in response") + s.Equal(s.accessResponse.TokenType, "bearer", "token type expected to be \"bearer\"") + s.Equal( + s.accessResponse.GetExtra("expires_in"), int64(s.handler.HandleHelper.AccessTokenLifespan.Seconds()), + "token expiration time expected in response to be equal to AccessTokenLifespan setting in handler", + ) + s.Equal(s.accessResponse.GetExtra("scope"), "", "no scopes expected in response") + s.Nil(s.accessResponse.GetExtra("refresh_token"), "refresh token not expected in response") +} diff --git a/handler/rfc7523/session.go b/handler/rfc7523/session.go new file mode 100644 index 000000000..dbc2a95c8 --- /dev/null +++ b/handler/rfc7523/session.go @@ -0,0 +1,28 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package rfc7523 + +// Session must be implemented by the session if RFC7523 is to be supported. +type Session interface { + // SetSubject sets the session's subject. + SetSubject(subject string) +} diff --git a/handler/rfc7523/storage.go b/handler/rfc7523/storage.go new file mode 100644 index 000000000..f00fd17a5 --- /dev/null +++ b/handler/rfc7523/storage.go @@ -0,0 +1,51 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package rfc7523 + +import ( + "context" + "time" + + "gopkg.in/square/go-jose.v2" +) + +// RFC7523KeyStorage holds information needed to validate jwt assertion in authorization grants. +type RFC7523KeyStorage interface { + // GetPublicKey returns public key, issued by 'issuer', and assigned for subject. Public key is used to check + // signature of jwt assertion in authorization grants. + GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (*jose.JSONWebKey, error) + + // GetPublicKeys returns public key, set issued by 'issuer', and assigned for subject. + GetPublicKeys(ctx context.Context, issuer string, subject string) (*jose.JSONWebKeySet, error) + + // GetPublicKeyScopes returns assigned scope for assertion, identified by public key, issued by 'issuer'. + GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) ([]string, error) + + // IsJWTUsed returns true, if JWT is not known yet or it can not be considered valid, because it must be already + // expired. + IsJWTUsed(ctx context.Context, jti string) (bool, error) + + // MarkJWTUsedForTime marks JWT as used for a time passed in exp parameter. This helps ensure that JWTs are not + // replayed by maintaining the set of used "jti" values for the length of time for which the JWT would be + // considered valid based on the applicable "exp" instant. (https://tools.ietf.org/html/rfc7523#section-3) + MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error +} diff --git a/integration/authorize_jwt_bearer_required_iat_test.go b/integration/authorize_jwt_bearer_required_iat_test.go new file mode 100644 index 000000000..5e09c7695 --- /dev/null +++ b/integration/authorize_jwt_bearer_required_iat_test.go @@ -0,0 +1,129 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package integration_test + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/integration/clients" +) + +type authorizeJWTBearerRequiredIATSuite struct { + suite.Suite + + client *clients.JWTBearer +} + +func (s *authorizeJWTBearerRequiredIATSuite) TestBadResponseWithoutIssuedAt() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + ID: uuid.New(), + }, + }, []string{"fosite"}) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerRequiredIATSuite) TestSuccessResponseWithIssuedAt() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: uuid.New(), + }, + }, []string{"fosite"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerRequiredIATSuite) getClient() *clients.JWTBearer { + client := *s.client + + return &client +} + +func (s *authorizeJWTBearerRequiredIATSuite) assertSuccessResponse(t *testing.T, token *clients.Token, err error) { + assert.Nil(t, err) + assert.NotNil(t, token) + + assert.Equal(t, token.TokenType, "bearer") + assert.Empty(t, token.RefreshToken) + assert.NotEmpty(t, token.ExpiresIn) + assert.NotEmpty(t, token.AccessToken) +} + +func (s *authorizeJWTBearerRequiredIATSuite) assertBadResponse(t *testing.T, token *clients.Token, err error) { + assert.Nil(t, token) + assert.NotNil(t, err) + + retrieveError, ok := err.(*clients.RequestError) + assert.True(t, ok) + assert.Equal(t, retrieveError.Response.StatusCode, http.StatusBadRequest) +} + +func TestAuthorizeJWTBearerRequiredIATSuite(t *testing.T) { + provider := compose.Compose( + &compose.Config{ + GrantTypeJWTBearerCanSkipClientAuth: true, + GrantTypeJWTBearerIDOptional: true, + GrantTypeJWTBearerIssuedDateOptional: false, + TokenURL: tokenURL, + }, + fositeStore, + jwtStrategy, + nil, + compose.OAuth2ClientCredentialsGrantFactory, + compose.RFC7523AssertionGrantFactory, + ) + testServer := mockServer(t, provider, &fosite.DefaultSession{}) + defer testServer.Close() + + client := newJWTBearerAppClient(testServer) + if err := client.SetPrivateKey(firstKeyID, firstPrivateKey); err != nil { + assert.Nil(t, err) + } + + suite.Run(t, &authorizeJWTBearerRequiredIATSuite{ + client: client, + }) +} diff --git a/integration/authorize_jwt_bearer_required_jti_test.go b/integration/authorize_jwt_bearer_required_jti_test.go new file mode 100644 index 000000000..4dc6e8cb9 --- /dev/null +++ b/integration/authorize_jwt_bearer_required_jti_test.go @@ -0,0 +1,129 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package integration_test + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/integration/clients" +) + +type authorizeJWTBearerRequiredJtiSuite struct { + suite.Suite + + client *clients.JWTBearer +} + +func (s *authorizeJWTBearerRequiredJtiSuite) TestBadResponseWithoutJTI() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, []string{"fosite"}) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerRequiredJtiSuite) TestSuccessResponseWithJTI() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: uuid.New(), + }, + }, []string{"fosite"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerRequiredJtiSuite) getClient() *clients.JWTBearer { + client := *s.client + + return &client +} + +func (s *authorizeJWTBearerRequiredJtiSuite) assertSuccessResponse(t *testing.T, token *clients.Token, err error) { + assert.Nil(t, err) + assert.NotNil(t, token) + + assert.Equal(t, token.TokenType, "bearer") + assert.Empty(t, token.RefreshToken) + assert.NotEmpty(t, token.ExpiresIn) + assert.NotEmpty(t, token.AccessToken) +} + +func (s *authorizeJWTBearerRequiredJtiSuite) assertBadResponse(t *testing.T, token *clients.Token, err error) { + assert.Nil(t, token) + assert.NotNil(t, err) + + retrieveError, ok := err.(*clients.RequestError) + assert.True(t, ok) + assert.Equal(t, retrieveError.Response.StatusCode, http.StatusBadRequest) +} + +func TestAuthorizeJWTBearerRequiredJtiSuite(t *testing.T) { + provider := compose.Compose( + &compose.Config{ + GrantTypeJWTBearerCanSkipClientAuth: true, + GrantTypeJWTBearerIDOptional: false, + GrantTypeJWTBearerIssuedDateOptional: true, + TokenURL: tokenURL, + }, + fositeStore, + jwtStrategy, + nil, + compose.OAuth2ClientCredentialsGrantFactory, + compose.RFC7523AssertionGrantFactory, + ) + testServer := mockServer(t, provider, &fosite.DefaultSession{}) + defer testServer.Close() + + client := newJWTBearerAppClient(testServer) + if err := client.SetPrivateKey(firstKeyID, firstPrivateKey); err != nil { + assert.Nil(t, err) + } + + suite.Run(t, &authorizeJWTBearerRequiredJtiSuite{ + client: client, + }) +} diff --git a/integration/authorize_jwt_bearer_test.go b/integration/authorize_jwt_bearer_test.go new file mode 100644 index 000000000..cec1365a7 --- /dev/null +++ b/integration/authorize_jwt_bearer_test.go @@ -0,0 +1,446 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package integration_test + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/integration/clients" +) + +type authorizeJWTBearerSuite struct { + suite.Suite + + client *clients.JWTBearer +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseWithRequiredParamsOnly() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + }, []string{"fosite"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseWithMultipleAudienceInAssertion() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL, "https://example.com/oauth"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, []string{"fosite"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseWithMultipleScopesInRequest() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, []string{"fosite", "gitlab"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseWithoutScopes() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, nil) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseWithExtraClaim() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + PrivateClaims: map[string]interface{}{"extraClaim": "extraClaimValue"}, + }, []string{"fosite"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseWithNotBeforeClaim() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + }, []string{"fosite"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseWithJTIClaim() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: uuid.New(), + }, + }, []string{"fosite"}) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponse() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL, "example.com"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + ID: uuid.New(), + }, + PrivateClaims: map[string]interface{}{"random": "random"}, + }, nil) + + s.assertSuccessResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithExpiredJWT() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, []string{"fosite"}) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithExpiryMaxDuration() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(365 * 24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, []string{"fosite"}) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithInvalidPrivateKey() { + ctx := context.Background() + client := s.getClient() + wrongPrivateKey := secondPrivateKey + + if err := client.SetPrivateKey(firstKeyID, wrongPrivateKey); err != nil { + assert.Nil(s.T(), err) + } + + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, nil) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithInvalidKeyID() { + ctx := context.Background() + client := s.getClient() + + if err := client.SetPrivateKey("wrongKeyID", firstPrivateKey); err != nil { + assert.Nil(s.T(), err) + } + + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, nil) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithInvalidAudience() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{"https://example.com/oauth"}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, nil) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseForSecondRequestWithSameJTI() { + ctx := context.Background() + client := s.getClient() + config := &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + ID: uuid.New(), + }, + } + + client.GetToken(ctx, config, nil) + token2, err := client.GetToken(ctx, config, nil) + + s.assertBadResponse(s.T(), token2, err) +} + +func (s *authorizeJWTBearerSuite) TestSuccessResponseForSecondRequestWithSameJTIAfterFirstExpired() { + ctx := context.Background() + client := s.getClient() + config := &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Second)), + IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), + ID: uuid.New(), + }, + } + + client.GetToken(ctx, config, nil) + + time.Sleep(time.Second) + config.Expiry = jwt.NewNumericDate(time.Now().Add(time.Hour)) + + token2, err := client.GetToken(ctx, config, nil) + + s.assertSuccessResponse(s.T(), token2, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithNotBeforeLaterThenIssueAt() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + }, nil) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithoutSubject() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: "", + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, nil) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithWrongSubject() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: "wrong_subject", + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, nil) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithWrongIssuer() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: "wrong_issuer", + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, nil) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) TestBadResponseWithWrongScope() { + ctx := context.Background() + client := s.getClient() + token, err := client.GetToken(ctx, &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + }, []string{"fosite", "permission"}) + + s.assertBadResponse(s.T(), token, err) +} + +func (s *authorizeJWTBearerSuite) getClient() *clients.JWTBearer { + client := *s.client + + return &client +} + +func (s *authorizeJWTBearerSuite) assertSuccessResponse(t *testing.T, token *clients.Token, err error) { + assert.Nil(t, err) + assert.NotNil(t, token) + + assert.Equal(t, token.TokenType, "bearer") + assert.Empty(t, token.RefreshToken) + assert.NotEmpty(t, token.ExpiresIn) + assert.NotEmpty(t, token.AccessToken) +} + +func (s *authorizeJWTBearerSuite) assertBadResponse(t *testing.T, token *clients.Token, err error) { + assert.Nil(t, token) + assert.NotNil(t, err) + + retrieveError, ok := err.(*clients.RequestError) + assert.True(t, ok) + assert.Equal(t, retrieveError.Response.StatusCode, http.StatusBadRequest) +} + +func TestAuthorizeJWTBearerSuite(t *testing.T) { + provider := compose.Compose( + &compose.Config{ + GrantTypeJWTBearerCanSkipClientAuth: true, + GrantTypeJWTBearerIDOptional: true, + GrantTypeJWTBearerIssuedDateOptional: true, + GrantTypeJWTBearerMaxDuration: 24 * time.Hour, + TokenURL: tokenURL, + }, + fositeStore, + jwtStrategy, + nil, + compose.OAuth2ClientCredentialsGrantFactory, + compose.RFC7523AssertionGrantFactory, + ) + testServer := mockServer(t, provider, &fosite.DefaultSession{}) + defer testServer.Close() + + client := newJWTBearerAppClient(testServer) + if err := client.SetPrivateKey(firstKeyID, firstPrivateKey); err != nil { + assert.Nil(t, err) + } + + suite.Run(t, &authorizeJWTBearerSuite{ + client: client, + }) +} diff --git a/integration/clients/error.go b/integration/clients/error.go new file mode 100644 index 000000000..e7359a811 --- /dev/null +++ b/integration/clients/error.go @@ -0,0 +1,36 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package clients + +import ( + "fmt" + "net/http" +) + +type RequestError struct { + Response *http.Response + Body []byte +} + +func (r *RequestError) Error() string { + return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) +} diff --git a/integration/clients/introspect.go b/integration/clients/introspect.go new file mode 100644 index 000000000..a8667a9c9 --- /dev/null +++ b/integration/clients/introspect.go @@ -0,0 +1,120 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package clients + +import ( + "context" + "encoding/json" + "io/ioutil" + "net/http" + "net/url" + "strings" +) + +type IntrospectForm struct { + Token string + Scopes []string +} + +type IntrospectResponse struct { + Active bool `json:"active"` + ClientID string `json:"client_id,omitempty"` + Scope string `json:"scope,omitempty"` + Audience []string `json:"aud,omitempty"` + ExpiresAt int64 `json:"exp,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + Subject string `json:"sub,omitempty"` + Username string `json:"username,omitempty"` +} + +type Introspect struct { + endpointURL string + client *http.Client +} + +func (c *Introspect) IntrospectToken( + ctx context.Context, + form IntrospectForm, + header map[string]string, +) (*IntrospectResponse, error) { + data := url.Values{} + data.Set("token", form.Token) + data.Set("scope", strings.Join(form.Scopes, " ")) + + request, err := c.getRequest(ctx, data, header) + if err != nil { + return nil, err + } + + response, err := c.client.Do(request) + if err != nil { + return nil, err + } + + defer response.Body.Close() + + body, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, err + } + + if c := response.StatusCode; c < 200 || c > 299 { + return nil, &RequestError{ + Response: response, + Body: body, + } + } + + result := &IntrospectResponse{} + + if err := json.Unmarshal(body, result); err != nil { + return nil, err + } + + return result, nil +} + +func (c *Introspect) getRequest( + ctx context.Context, + data url.Values, + header map[string]string, +) (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, "POST", c.endpointURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + for header, value := range header { + request.Header.Set(header, value) + } + + return request, nil +} + +func NewIntrospectClient(endpointURL string) *Introspect { + return &Introspect{ + endpointURL: endpointURL, + client: &http.Client{}, + } +} diff --git a/integration/clients/jwt_bearer.go b/integration/clients/jwt_bearer.go new file mode 100644 index 000000000..7e56a40c7 --- /dev/null +++ b/integration/clients/jwt_bearer.go @@ -0,0 +1,153 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package clients + +import ( + "context" + "crypto/rsa" + "encoding/json" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +const jwtBearerGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" + +type JWTBearer struct { + tokenURL string + header *Header + client *http.Client + + Signer jose.Signer +} + +type Token struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` +} + +type Header struct { + Algorithm string `json:"alg"` + Typ string `json:"typ"` + KeyID string `json:"kid,omitempty"` +} + +type JWTBearerPayload struct { + *jwt.Claims + + PrivateClaims map[string]interface{} +} + +func (c *JWTBearer) SetPrivateKey(keyID string, privateKey *rsa.PrivateKey) error { + jwk := jose.JSONWebKey{Key: privateKey, KeyID: keyID, Algorithm: string(jose.RS256)} + signingKey := jose.SigningKey{ + Algorithm: jose.RS256, + Key: jwk, + } + signerOptions := &jose.SignerOptions{} + signerOptions.WithType("JWT") + + sig, err := jose.NewSigner(signingKey, signerOptions) + if err != nil { + return err + } + + c.Signer = sig + + return nil +} + +func (c *JWTBearer) GetToken(ctx context.Context, payloadData *JWTBearerPayload, scope []string) (*Token, error) { + builder := jwt.Signed(c.Signer). + Claims(payloadData.Claims). + Claims(payloadData.PrivateClaims) + + assertion, err := builder.CompactSerialize() + if err != nil { + return nil, err + } + + requestBodyReader, err := c.getRequestBodyReader(assertion, scope) + if err != nil { + return nil, err + } + + request, err := http.NewRequestWithContext(ctx, "POST", c.tokenURL, requestBodyReader) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + response, err := c.client.Do(request) + if err != nil { + return nil, err + } + + defer response.Body.Close() + + body, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, err + } + + if c := response.StatusCode; c < 200 || c > 299 { + return nil, &RequestError{ + Response: response, + Body: body, + } + } + + token := &Token{} + + if err := json.Unmarshal(body, token); err != nil { + return nil, err + } + + return token, err +} + +func (c *JWTBearer) getRequestBodyReader(assertion string, scope []string) (io.Reader, error) { + data := url.Values{} + data.Set("grant_type", jwtBearerGrantType) + data.Set("assertion", string(assertion)) + + if len(scope) != 0 { + data.Set("scope", strings.Join(scope, " ")) + } + + return strings.NewReader(data.Encode()), nil +} + +func NewJWTBearer(tokenURL string) *JWTBearer { + return &JWTBearer{ + client: &http.Client{}, + tokenURL: tokenURL, + } +} diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index 98834028f..17cd0fbd2 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -22,6 +22,9 @@ package integration_test import ( + "crypto" + "crypto/rand" + "crypto/rsa" "net/http/httptest" "testing" "time" @@ -29,16 +32,37 @@ import ( "github.com/gorilla/mux" goauth "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "gopkg.in/square/go-jose.v2" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/integration/clients" "github.com/ory/fosite/internal" "github.com/ory/fosite/storage" "github.com/ory/fosite/token/hmac" "github.com/ory/fosite/token/jwt" ) +const ( + firstKeyID = "123" + secondKeyID = "321" + + firstJWTBearerIssuer = "first@example.com" + secondJWTBearerIssuer = "second@example.com" + + firstJWTBearerSubject = "first-service-client" + secondJWTBearerSubject = "second-service-client" + + tokenURL = "https://www.ory.sh/api" + tokenRelativePath = "/token" +) + +var ( + firstPrivateKey, _ = rsa.GenerateKey(rand.Reader, 2048) + secondPrivateKey, _ = rsa.GenerateKey(rand.Reader, 2048) +) + var fositeStore = &storage.MemoryStore{ Clients: map[string]fosite.Client{ "my-client": &fosite.DefaultClient{ @@ -48,7 +72,7 @@ var fositeStore = &storage.MemoryStore{ ResponseTypes: []string{"id_token", "code", "token", "token code", "id_token code", "token id_token", "token code id_token"}, GrantTypes: []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}, Scopes: []string{"fosite", "offline", "openid"}, - Audience: []string{"https://www.ory.sh/api"}, + Audience: []string{tokenURL}, }, "public-client": &fosite.DefaultClient{ ID: "public-client", @@ -58,7 +82,7 @@ var fositeStore = &storage.MemoryStore{ ResponseTypes: []string{"id_token", "code", "code id_token"}, GrantTypes: []string{"refresh_token", "authorization_code"}, Scopes: []string{"fosite", "offline", "openid"}, - Audience: []string{"https://www.ory.sh/api"}, + Audience: []string{tokenURL}, }, }, Users: map[string]storage.MemoryUserRelation{ @@ -67,6 +91,23 @@ var fositeStore = &storage.MemoryStore{ Password: "secret", }, }, + IssuerPublicKeys: map[string]storage.IssuerPublicKeys{ + firstJWTBearerIssuer: createIssuerPublicKey( + firstJWTBearerIssuer, + firstJWTBearerSubject, + firstKeyID, + firstPrivateKey.Public(), + []string{"fosite", "gitlab", "example.com", "docker"}, + ), + secondJWTBearerIssuer: createIssuerPublicKey( + secondJWTBearerIssuer, + secondJWTBearerSubject, + secondKeyID, + secondPrivateKey.Public(), + []string{"fosite"}, + ), + }, + BlacklistedJTIs: map[string]time.Time{}, AuthorizeCodes: map[string]storage.StoreAuthorizeCode{}, PKCES: map[string]fosite.Requester{}, AccessTokens: map[string]fosite.Requester{}, @@ -84,6 +125,28 @@ var accessTokenLifespan = time.Hour var authCodeLifespan = time.Minute +func createIssuerPublicKey(issuer, subject, keyID string, key crypto.PublicKey, scopes []string) storage.IssuerPublicKeys { + return storage.IssuerPublicKeys{ + Issuer: issuer, + KeysBySub: map[string]storage.SubjectPublicKeys{ + subject: { + Subject: subject, + Keys: map[string]storage.PublicKeyScopes{ + keyID: { + Key: &jose.JSONWebKey{ + Key: key, + Algorithm: string(jose.RS256), + Use: "sig", + KeyID: keyID, + }, + Scopes: scopes, + }, + }, + }, + }, + } +} + func newOAuth2Client(ts *httptest.Server) *goauth.Config { return &goauth.Config{ ClientID: "my-client", @@ -91,7 +154,7 @@ func newOAuth2Client(ts *httptest.Server) *goauth.Config { RedirectURL: ts.URL + "/callback", Scopes: []string{"fosite"}, Endpoint: goauth.Endpoint{ - TokenURL: ts.URL + "/token", + TokenURL: ts.URL + tokenRelativePath, AuthURL: ts.URL + "/auth", AuthStyle: goauth.AuthStyleInHeader, }, @@ -103,10 +166,14 @@ func newOAuth2AppClient(ts *httptest.Server) *clientcredentials.Config { ClientID: "my-client", ClientSecret: "foobar", Scopes: []string{"fosite"}, - TokenURL: ts.URL + "/token", + TokenURL: ts.URL + tokenRelativePath, } } +func newJWTBearerAppClient(ts *httptest.Server) *clients.JWTBearer { + return clients.NewJWTBearer(ts.URL + tokenRelativePath) +} + var hmacStrategy = &oauth2.HMACSHAStrategy{ Enigma: &hmac.HMACStrategy{ GlobalSecret: []byte("some-super-cool-secret-that-nobody-knows"), @@ -125,7 +192,7 @@ var jwtStrategy = &oauth2.DefaultJWTStrategy{ func mockServer(t *testing.T, f fosite.OAuth2Provider, session fosite.Session) *httptest.Server { router := mux.NewRouter() router.HandleFunc("/auth", authEndpointHandler(t, f, session)) - router.HandleFunc("/token", tokenEndpointHandler(t, f)) + router.HandleFunc(tokenRelativePath, tokenEndpointHandler(t, f)) router.HandleFunc("/callback", authCallbackHandler(t)) router.HandleFunc("/info", tokenInfoHandler(t, f, session)) router.HandleFunc("/introspect", tokenIntrospectionHandler(t, f, session)) diff --git a/integration/introspect_jwt_bearer_token_test.go b/integration/introspect_jwt_bearer_token_test.go new file mode 100644 index 000000000..e5a112b72 --- /dev/null +++ b/integration/introspect_jwt_bearer_token_test.go @@ -0,0 +1,290 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package integration_test + +import ( + "context" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/integration/clients" +) + +type introspectJWTBearerTokenSuite struct { + suite.Suite + + clientJWT *clients.JWTBearer + clientIntrospect *clients.Introspect + clientTokenPayload *clients.JWTBearerPayload + appTokenPayload *clients.JWTBearerPayload + + authorizationHeader string + scopes []string + audience []string +} + +func (s *introspectJWTBearerTokenSuite) SetupTest() { + s.scopes = []string{"fosite"} + s.audience = []string{tokenURL, "https://example.com"} + + s.clientTokenPayload = &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: firstJWTBearerIssuer, + Subject: firstJWTBearerSubject, + Audience: s.audience, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + s.appTokenPayload = &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: secondJWTBearerIssuer, + Subject: secondJWTBearerSubject, + Audience: s.audience, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } +} + +func (s *introspectJWTBearerTokenSuite) TestSuccessResponseWithMultipleScopesToken() { + ctx := context.Background() + + scopes := []string{"fosite", "docker"} + token, err := s.getJWTClient().GetToken(ctx, s.clientTokenPayload, scopes) + assert.Nil(s.T(), err) + + response, err := s.clientIntrospect.IntrospectToken( + ctx, + clients.IntrospectForm{ + Token: token.AccessToken, + Scopes: nil, + }, + map[string]string{"Authorization": s.authorizationHeader}, + ) + + s.assertSuccessResponse(s.T(), response, err, firstJWTBearerSubject) + assert.Equal(s.T(), strings.Split(response.Scope, " "), scopes) +} + +func (s *introspectJWTBearerTokenSuite) TestUnActiveResponseWithInvalidScopes() { + ctx := context.Background() + + token, err := s.getJWTClient().GetToken(ctx, s.clientTokenPayload, s.scopes) + assert.Nil(s.T(), err) + + response, err := s.clientIntrospect.IntrospectToken( + ctx, + clients.IntrospectForm{ + Token: token.AccessToken, + Scopes: []string{"invalid"}, + }, + map[string]string{"Authorization": s.authorizationHeader}, + ) + + assert.Nil(s.T(), err) + assert.NotNil(s.T(), response) + assert.False(s.T(), response.Active) +} + +func (s *introspectJWTBearerTokenSuite) TestSuccessResponseWithoutScopesForIntrospection() { + ctx := context.Background() + + token, err := s.getJWTClient().GetToken(ctx, s.clientTokenPayload, s.scopes) + assert.Nil(s.T(), err) + + response, err := s.clientIntrospect.IntrospectToken( + ctx, + clients.IntrospectForm{ + Token: token.AccessToken, + Scopes: nil, + }, + map[string]string{"Authorization": s.authorizationHeader}, + ) + + s.assertSuccessResponse(s.T(), response, err, firstJWTBearerSubject) +} + +func (s *introspectJWTBearerTokenSuite) TestSuccessResponseWithoutScopes() { + ctx := context.Background() + + token, err := s.getJWTClient().GetToken(ctx, s.clientTokenPayload, nil) + assert.Nil(s.T(), err) + + response, err := s.clientIntrospect.IntrospectToken( + ctx, + clients.IntrospectForm{ + Token: token.AccessToken, + Scopes: nil, + }, + map[string]string{"Authorization": s.authorizationHeader}, + ) + + s.assertSuccessResponse(s.T(), response, err, firstJWTBearerSubject) +} + +func (s *introspectJWTBearerTokenSuite) TestSubjectHasAccessToScopeButNotInited() { + ctx := context.Background() + + token, err := s.getJWTClient().GetToken(ctx, s.clientTokenPayload, nil) + assert.Nil(s.T(), err) + + response, err := s.clientIntrospect.IntrospectToken( + ctx, + clients.IntrospectForm{ + Token: token.AccessToken, + Scopes: s.scopes, + }, + map[string]string{"Authorization": s.authorizationHeader}, + ) + + assert.Nil(s.T(), err) + assert.NotNil(s.T(), response) + assert.False(s.T(), response.Active) +} + +func (s *introspectJWTBearerTokenSuite) TestTheSameTokenInRequestAndHeader() { + ctx := context.Background() + token, err := s.getJWTClient().GetToken(ctx, s.clientTokenPayload, s.scopes) + assert.Nil(s.T(), err) + + response, err := s.clientIntrospect.IntrospectToken( + ctx, + clients.IntrospectForm{ + Token: token.AccessToken, + Scopes: nil, + }, + map[string]string{"Authorization": "bearer " + token.AccessToken}, + ) + + s.assertUnauthorizedResponse(s.T(), response, err) +} + +func (s *introspectJWTBearerTokenSuite) TestUnauthorizedResponseForRequestWithoutAuthorization() { + ctx := context.Background() + token, err := s.getJWTClient().GetToken(ctx, s.clientTokenPayload, s.scopes) + assert.Nil(s.T(), err) + + response, err := s.clientIntrospect.IntrospectToken( + ctx, + clients.IntrospectForm{ + Token: token.AccessToken, + Scopes: nil, + }, + nil, + ) + + s.assertUnauthorizedResponse(s.T(), response, err) +} + +func (s *introspectJWTBearerTokenSuite) getJWTClient() *clients.JWTBearer { + client := *s.clientJWT + + return &client +} + +func (s *introspectJWTBearerTokenSuite) assertSuccessResponse( + t *testing.T, + response *clients.IntrospectResponse, + err error, + subject string, +) { + assert.Nil(t, err) + assert.NotNil(t, response) + + assert.True(t, response.Active) + assert.Equal(t, response.Subject, subject) + assert.NotEmpty(t, response.ExpiresAt) + assert.NotEmpty(t, response.IssuedAt) + assert.Equal(t, response.Audience, s.audience) + + tokenDuration := time.Unix(response.ExpiresAt, 0).Sub(time.Unix(response.IssuedAt, 0)) + assert.Less(t, int64(tokenDuration), int64(time.Hour+time.Minute)) + assert.Greater(t, int64(tokenDuration), int64(time.Hour-time.Minute)) +} + +func (s *introspectJWTBearerTokenSuite) assertUnauthorizedResponse( + t *testing.T, + response *clients.IntrospectResponse, + err error, +) { + assert.Nil(t, response) + assert.NotNil(t, err) + + retrieveError, ok := err.(*clients.RequestError) + assert.True(t, ok) + assert.Equal(t, retrieveError.Response.StatusCode, http.StatusUnauthorized) +} + +func TestIntrospectJWTBearerTokenSuite(t *testing.T) { + provider := compose.Compose( + &compose.Config{ + GrantTypeJWTBearerCanSkipClientAuth: true, + GrantTypeJWTBearerIDOptional: true, + GrantTypeJWTBearerIssuedDateOptional: true, + AccessTokenLifespan: time.Hour, + TokenURL: tokenURL, + }, + fositeStore, + jwtStrategy, + nil, + compose.OAuth2ClientCredentialsGrantFactory, + compose.RFC7523AssertionGrantFactory, + compose.OAuth2TokenIntrospectionFactory, + ) + testServer := mockServer(t, provider, &fosite.DefaultSession{}) + defer testServer.Close() + + client := newJWTBearerAppClient(testServer) + if err := client.SetPrivateKey(secondKeyID, secondPrivateKey); err != nil { + assert.Nil(t, err) + } + + token, err := client.GetToken(context.Background(), &clients.JWTBearerPayload{ + Claims: &jwt.Claims{ + Issuer: secondJWTBearerIssuer, + Subject: secondJWTBearerSubject, + Audience: []string{tokenURL}, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + }, []string{"fosite"}) + if err != nil { + assert.Nil(t, err) + } + + if err := client.SetPrivateKey(firstKeyID, firstPrivateKey); err != nil { + assert.Nil(t, err) + } + + suite.Run(t, &introspectJWTBearerTokenSuite{ + clientJWT: client, + clientIntrospect: clients.NewIntrospectClient(testServer.URL + "/introspect"), + authorizationHeader: "bearer " + token.AccessToken, + }) +} diff --git a/internal/authorize_request.go b/internal/authorize_request.go index 91717d790..e1228a11d 100644 --- a/internal/authorize_request.go +++ b/internal/authorize_request.go @@ -206,7 +206,7 @@ func (mr *MockAuthorizeRequesterMockRecorder) GetRequestedScopes() *gomock.Call // GetResponseMode mocks base method func (m *MockAuthorizeRequester) GetResponseMode() fosite.ResponseModeType { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetResponseModes") + ret := m.ctrl.Call(m, "GetResponseMode") ret0, _ := ret[0].(fosite.ResponseModeType) return ret0 } @@ -214,7 +214,7 @@ func (m *MockAuthorizeRequester) GetResponseMode() fosite.ResponseModeType { // GetResponseMode indicates an expected call of GetResponseMode func (mr *MockAuthorizeRequesterMockRecorder) GetResponseMode() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseModes", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetResponseMode)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseMode", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetResponseMode)) } // GetResponseTypes mocks base method @@ -371,18 +371,6 @@ func (mr *MockAuthorizeRequesterMockRecorder) SetRequestedScopes(arg0 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRequestedScopes", reflect.TypeOf((*MockAuthorizeRequester)(nil).SetRequestedScopes), arg0) } -// SetResponseMode mocks base method -func (m *MockAuthorizeRequester) SetResponseMode(arg0 fosite.ResponseModeType) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetResponseMode", arg0) -} - -// SetResponseMode indicates an expected call of SetResponseMode -func (mr *MockAuthorizeRequesterMockRecorder) SetResponseMode(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetResponseMode", reflect.TypeOf((*MockAuthorizeRequester)(nil).SetResponseMode), arg0) -} - // SetResponseTypeHandled mocks base method func (m *MockAuthorizeRequester) SetResponseTypeHandled(arg0 string) { m.ctrl.T.Helper() diff --git a/internal/oauth2_auth_jwt_storage.go b/internal/oauth2_auth_jwt_storage.go new file mode 100644 index 000000000..e38e63d68 --- /dev/null +++ b/internal/oauth2_auth_jwt_storage.go @@ -0,0 +1,111 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ory/fosite/handler/rfc7523 (interfaces: RFC7523KeyStorage) + +// Package internal is a generated GoMock package. +package internal + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + go_jose_v2 "gopkg.in/square/go-jose.v2" +) + +// MockRFC7523KeyStorage is a mock of RFC7523KeyStorage interface +type MockRFC7523KeyStorage struct { + ctrl *gomock.Controller + recorder *MockRFC7523KeyStorageMockRecorder +} + +// MockRFC7523KeyStorageMockRecorder is the mock recorder for MockRFC7523KeyStorage +type MockRFC7523KeyStorageMockRecorder struct { + mock *MockRFC7523KeyStorage +} + +// NewMockRFC7523KeyStorage creates a new mock instance +func NewMockRFC7523KeyStorage(ctrl *gomock.Controller) *MockRFC7523KeyStorage { + mock := &MockRFC7523KeyStorage{ctrl: ctrl} + mock.recorder = &MockRFC7523KeyStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRFC7523KeyStorage) EXPECT() *MockRFC7523KeyStorageMockRecorder { + return m.recorder +} + +// GetPublicKey mocks base method +func (m *MockRFC7523KeyStorage) GetPublicKey(arg0 context.Context, arg1, arg2, arg3 string) (*go_jose_v2.JSONWebKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPublicKey", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*go_jose_v2.JSONWebKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPublicKey indicates an expected call of GetPublicKey +func (mr *MockRFC7523KeyStorageMockRecorder) GetPublicKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKey", reflect.TypeOf((*MockRFC7523KeyStorage)(nil).GetPublicKey), arg0, arg1, arg2, arg3) +} + +// GetPublicKeyScopes mocks base method +func (m *MockRFC7523KeyStorage) GetPublicKeyScopes(arg0 context.Context, arg1, arg2, arg3 string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPublicKeyScopes", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPublicKeyScopes indicates an expected call of GetPublicKeyScopes +func (mr *MockRFC7523KeyStorageMockRecorder) GetPublicKeyScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeyScopes", reflect.TypeOf((*MockRFC7523KeyStorage)(nil).GetPublicKeyScopes), arg0, arg1, arg2, arg3) +} + +// GetPublicKeys mocks base method +func (m *MockRFC7523KeyStorage) GetPublicKeys(arg0 context.Context, arg1, arg2 string) (*go_jose_v2.JSONWebKeySet, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPublicKeys", arg0, arg1, arg2) + ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPublicKeys indicates an expected call of GetPublicKeys +func (mr *MockRFC7523KeyStorageMockRecorder) GetPublicKeys(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeys", reflect.TypeOf((*MockRFC7523KeyStorage)(nil).GetPublicKeys), arg0, arg1, arg2) +} + +// IsJWTUsed mocks base method +func (m *MockRFC7523KeyStorage) IsJWTUsed(arg0 context.Context, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsJWTUsed", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsJWTUsed indicates an expected call of IsJWTUsed +func (mr *MockRFC7523KeyStorageMockRecorder) IsJWTUsed(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsJWTUsed", reflect.TypeOf((*MockRFC7523KeyStorage)(nil).IsJWTUsed), arg0, arg1) +} + +// MarkJWTUsedForTime mocks base method +func (m *MockRFC7523KeyStorage) MarkJWTUsedForTime(arg0 context.Context, arg1 string, arg2 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkJWTUsedForTime", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkJWTUsedForTime indicates an expected call of MarkJWTUsedForTime +func (mr *MockRFC7523KeyStorageMockRecorder) MarkJWTUsedForTime(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkJWTUsedForTime", reflect.TypeOf((*MockRFC7523KeyStorage)(nil).MarkJWTUsedForTime), arg0, arg1, arg2) +} diff --git a/internal/token_handler.go b/internal/token_handler.go index a0626da5c..d9bfddcb9 100644 --- a/internal/token_handler.go +++ b/internal/token_handler.go @@ -36,6 +36,34 @@ func (m *MockTokenEndpointHandler) EXPECT() *MockTokenEndpointHandlerMockRecorde return m.recorder } +// CanHandleTokenEndpointRequest mocks base method +func (m *MockTokenEndpointHandler) CanHandleTokenEndpointRequest(arg0 fosite.AccessRequester) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CanHandleTokenEndpointRequest", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// CanHandleTokenEndpointRequest indicates an expected call of CanHandleTokenEndpointRequest +func (mr *MockTokenEndpointHandlerMockRecorder) CanHandleTokenEndpointRequest(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanHandleTokenEndpointRequest", reflect.TypeOf((*MockTokenEndpointHandler)(nil).CanHandleTokenEndpointRequest), arg0) +} + +// CanSkipClientAuth mocks base method +func (m *MockTokenEndpointHandler) CanSkipClientAuth(arg0 fosite.AccessRequester) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CanSkipClientAuth", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// CanSkipClientAuth indicates an expected call of CanSkipClientAuth +func (mr *MockTokenEndpointHandlerMockRecorder) CanSkipClientAuth(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSkipClientAuth", reflect.TypeOf((*MockTokenEndpointHandler)(nil).CanSkipClientAuth), arg0) +} + // HandleTokenEndpointRequest mocks base method func (m *MockTokenEndpointHandler) HandleTokenEndpointRequest(arg0 context.Context, arg1 fosite.AccessRequester) error { m.ctrl.T.Helper() diff --git a/session.go b/session.go index c551e297c..8fb593605 100644 --- a/session.go +++ b/session.go @@ -82,6 +82,10 @@ func (s *DefaultSession) GetUsername() string { return s.Username } +func (s *DefaultSession) SetSubject(subject string) { + s.Subject = subject +} + func (s *DefaultSession) GetSubject() string { if s == nil { return "" diff --git a/storage/memory.go b/storage/memory.go index 04cdcc719..d5d658801 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -27,6 +27,7 @@ import ( "time" "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" "github.com/ory/fosite" ) @@ -36,6 +37,21 @@ type MemoryUserRelation struct { Password string } +type IssuerPublicKeys struct { + Issuer string + KeysBySub map[string]SubjectPublicKeys +} + +type SubjectPublicKeys struct { + Subject string + Keys map[string]PublicKeyScopes +} + +type PublicKeyScopes struct { + Key *jose.JSONWebKey + Scopes []string +} + type MemoryStore struct { Clients map[string]fosite.Client AuthorizeCodes map[string]StoreAuthorizeCode @@ -48,6 +64,8 @@ type MemoryStore struct { // In-memory request ID to token signatures AccessTokenRequestIDs map[string]string RefreshTokenRequestIDs map[string]string + // Public keys to check signature in auth grant jwt assertion. + IssuerPublicKeys map[string]IssuerPublicKeys clientsMutex sync.RWMutex authorizeCodesMutex sync.RWMutex @@ -59,6 +77,7 @@ type MemoryStore struct { blacklistedJTIsMutex sync.RWMutex accessTokenRequestIDsMutex sync.RWMutex refreshTokenRequestIDsMutex sync.RWMutex + issuerPublicKeysMutex sync.RWMutex } func NewMemoryStore() *MemoryStore { @@ -73,6 +92,7 @@ func NewMemoryStore() *MemoryStore { AccessTokenRequestIDs: make(map[string]string), RefreshTokenRequestIDs: make(map[string]string), BlacklistedJTIs: make(map[string]time.Time), + IssuerPublicKeys: make(map[string]IssuerPublicKeys), } } @@ -116,6 +136,7 @@ func NewExampleStore() *MemoryStore { PKCES: map[string]fosite.Requester{}, AccessTokenRequestIDs: map[string]string{}, RefreshTokenRequestIDs: map[string]string{}, + IssuerPublicKeys: map[string]IssuerPublicKeys{}, } } @@ -356,3 +377,67 @@ func (s *MemoryStore) RevokeAccessToken(ctx context.Context, requestID string) e } return nil } + +func (s *MemoryStore) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (*jose.JSONWebKey, error) { + s.issuerPublicKeysMutex.RLock() + defer s.issuerPublicKeysMutex.RUnlock() + + if issuerKeys, ok := s.IssuerPublicKeys[issuer]; ok { + if subKeys, ok := issuerKeys.KeysBySub[subject]; ok { + if keyScopes, ok := subKeys.Keys[keyId]; ok { + return keyScopes.Key, nil + } + } + } + + return nil, fosite.ErrNotFound +} +func (s *MemoryStore) GetPublicKeys(ctx context.Context, issuer string, subject string) (*jose.JSONWebKeySet, error) { + s.issuerPublicKeysMutex.RLock() + defer s.issuerPublicKeysMutex.RUnlock() + + if issuerKeys, ok := s.IssuerPublicKeys[issuer]; ok { + if subKeys, ok := issuerKeys.KeysBySub[subject]; ok { + if len(subKeys.Keys) == 0 { + return nil, fosite.ErrNotFound + } + + keys := make([]jose.JSONWebKey, 0, len(subKeys.Keys)) + for _, keyScopes := range subKeys.Keys { + keys = append(keys, *keyScopes.Key) + } + + return &jose.JSONWebKeySet{Keys: keys}, nil + } + } + + return nil, fosite.ErrNotFound +} + +func (s *MemoryStore) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) ([]string, error) { + s.issuerPublicKeysMutex.RLock() + defer s.issuerPublicKeysMutex.RUnlock() + + if issuerKeys, ok := s.IssuerPublicKeys[issuer]; ok { + if subKeys, ok := issuerKeys.KeysBySub[subject]; ok { + if keyScopes, ok := subKeys.Keys[keyId]; ok { + return keyScopes.Scopes, nil + } + } + } + + return nil, fosite.ErrNotFound +} + +func (s *MemoryStore) IsJWTUsed(ctx context.Context, jti string) (bool, error) { + err := s.ClientAssertionJWTValid(ctx, jti) + if err != nil { + return true, nil + } + + return false, nil +} + +func (s *MemoryStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error { + return s.SetClientAssertionJWT(ctx, jti, exp) +}