diff --git a/README.md b/README.md index 4abd82c3dee..a6934942ee9 100644 --- a/README.md +++ b/README.md @@ -265,6 +265,16 @@ Then run it with in-memory database: DATABASE_URL=memory go run main.go host ``` +If you want to add mocks for the interfaces you are being testing, refer to [generator script](generate-mocks.sh). +* Add needed mocks generation to this script. +* Then run +```bash +go get github.com/golang/mock/gomock # if you don't have it +go get github.com/golang/mock/mockgen # if you don't have it +go get golang.org/x/tools/cmd/goimports # if you don't have it +./generate-mocks.sh +``` + **Notes** * We changed organization name from `ory-am` to `ory`. In order to keep backwards compatibility, we did not rename Go packages. diff --git a/cmd/server/handler_oauth2_factory.go b/cmd/server/handler_oauth2_factory.go index 53a4695af83..06599bbf68c 100644 --- a/cmd/server/handler_oauth2_factory.go +++ b/cmd/server/handler_oauth2_factory.go @@ -56,7 +56,11 @@ func injectFositeStore(c *config.Config, clients client.Manager) { func newOAuth2Provider(c *config.Config, km jwk.Manager) fosite.OAuth2Provider { var ctx = c.Context() - var store = ctx.FositeStore + var store = oauth2.CommonStore{ + FositeStorer: ctx.FositeStore, + KeyManager: km, + ClusterURL: c.ClusterURL, + } createRS256KeysIfNotExist(c, oauth2.OpenIDConnectKeyName, "private", "sig") keys, err := km.GetKey(oauth2.OpenIDConnectKeyName, "private") @@ -80,6 +84,7 @@ func newOAuth2Provider(c *config.Config, km jwk.Manager) fosite.OAuth2Provider { IDTokenLifespan: c.GetIDTokenLifespan(), HashCost: c.BCryptWorkFactor, } + return compose.Compose( fc, store, @@ -91,6 +96,7 @@ func newOAuth2Provider(c *config.Config, km jwk.Manager) fosite.OAuth2Provider { compose.OAuth2AuthorizeExplicitFactory, compose.OAuth2AuthorizeImplicitFactory, compose.OAuth2ClientCredentialsGrantFactory, + oauth2.JWTBearerGrantFactory, compose.OAuth2RefreshTokenGrantFactory, compose.OpenIDConnectExplicitFactory, compose.OpenIDConnectHybridFactory, diff --git a/generate-mocks.sh b/generate-mocks.sh new file mode 100755 index 00000000000..891c249a0e6 --- /dev/null +++ b/generate-mocks.sh @@ -0,0 +1,16 @@ +#!/bin/sh + +PROJECT_ROOT=github.com/ory/hydra +MOCKS_DIR=internal/mocks + +mockgen -package internal -destination $MOCKS_DIR/fosite_access_token_strategy.go github.com/ory/fosite/handler/oauth2 AccessTokenStrategy +mockgen -package internal -destination $MOCKS_DIR/fosite_access_token_storage.go github.com/ory/fosite/handler/oauth2 AccessTokenStorage +mockgen -package internal -destination $MOCKS_DIR/fosite_access_request.go github.com/ory/fosite AccessRequester +mockgen -package internal -destination $MOCKS_DIR/fosite_client.go github.com/ory/fosite Client +mockgen -package internal -destination $MOCKS_DIR/hydra_key_manager.go github.com/ory/hydra/jwk Manager + +# See https://github.com/golang/mock/issues/30 +find $MOCKS_DIR -type f -exec sed -i.bak "s,$PROJECT_ROOT/vendor/,,g" {} \; +rm -rf $MOCKS_DIR/*.bak + +goimports -w $MOCKS_DIR/ diff --git a/glide.lock b/glide.lock index f3e42dc3cb3..ca7997b6528 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 4f5430a29f1bb6f18cc5a1d401f62b5f4f6ebda2eb461f04b4084be1f975109d -updated: 2017-10-10T21:36:28.006157-06:00 +hash: ce07cdc316b9dfe7950e78afc664b6b3b32116e8a4402f68b160f3352d644cf2 +updated: 2017-11-23T14:46:01.87753904+03:00 imports: - name: github.com/asaskevich/govalidator version: 4918b99a7cb949bb295f3c7bbaf24b577d806e35 @@ -55,6 +55,10 @@ imports: version: d2a6d0596004cc01062a2a068540b817f911e6dc - name: github.com/go-sql-driver/mysql version: a0583e0143b1624142adab07e0e97fe106d99561 +- name: github.com/golang/mock + version: 13f360950a79f5864a972c786a10a50e44b69541 + subpackages: + - gomock - name: github.com/golang/protobuf version: 130e6b02ab059e7b717a096f397c5b60111cae74 subpackages: diff --git a/glide.yaml b/glide.yaml index 7cbbdd262cf..63268869535 100644 --- a/glide.yaml +++ b/glide.yaml @@ -67,6 +67,10 @@ import: - clientcredentials - package: gopkg.in/yaml.v2 - package: github.com/mohae/deepcopy +- package: github.com/golang/mock + version: ^1.0.0 + subpackages: + - gomock testImport: - package: github.com/bmizerany/assert - package: github.com/ory/dockertest diff --git a/internal/mocks/fosite_access_request.go b/internal/mocks/fosite_access_request.go new file mode 100644 index 00000000000..1ac46ecad62 --- /dev/null +++ b/internal/mocks/fosite_access_request.go @@ -0,0 +1,183 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ory/fosite (interfaces: AccessRequester) + +// Package internal is a generated GoMock package. +package internal + +import ( + url "net/url" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" +) + +// MockAccessRequester is a mock of AccessRequester interface +type MockAccessRequester struct { + ctrl *gomock.Controller + recorder *MockAccessRequesterMockRecorder +} + +// MockAccessRequesterMockRecorder is the mock recorder for MockAccessRequester +type MockAccessRequesterMockRecorder struct { + mock *MockAccessRequester +} + +// NewMockAccessRequester creates a new mock instance +func NewMockAccessRequester(ctrl *gomock.Controller) *MockAccessRequester { + mock := &MockAccessRequester{ctrl: ctrl} + mock.recorder = &MockAccessRequesterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAccessRequester) EXPECT() *MockAccessRequesterMockRecorder { + return m.recorder +} + +// AppendRequestedScope mocks base method +func (m *MockAccessRequester) AppendRequestedScope(arg0 string) { + m.ctrl.Call(m, "AppendRequestedScope", arg0) +} + +// AppendRequestedScope indicates an expected call of AppendRequestedScope +func (mr *MockAccessRequesterMockRecorder) AppendRequestedScope(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendRequestedScope", reflect.TypeOf((*MockAccessRequester)(nil).AppendRequestedScope), arg0) +} + +// GetClient mocks base method +func (m *MockAccessRequester) GetClient() fosite.Client { + ret := m.ctrl.Call(m, "GetClient") + ret0, _ := ret[0].(fosite.Client) + return ret0 +} + +// GetClient indicates an expected call of GetClient +func (mr *MockAccessRequesterMockRecorder) GetClient() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockAccessRequester)(nil).GetClient)) +} + +// GetGrantTypes mocks base method +func (m *MockAccessRequester) GetGrantTypes() fosite.Arguments { + ret := m.ctrl.Call(m, "GetGrantTypes") + ret0, _ := ret[0].(fosite.Arguments) + return ret0 +} + +// GetGrantTypes indicates an expected call of GetGrantTypes +func (mr *MockAccessRequesterMockRecorder) GetGrantTypes() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGrantTypes", reflect.TypeOf((*MockAccessRequester)(nil).GetGrantTypes)) +} + +// GetGrantedScopes mocks base method +func (m *MockAccessRequester) GetGrantedScopes() fosite.Arguments { + ret := m.ctrl.Call(m, "GetGrantedScopes") + ret0, _ := ret[0].(fosite.Arguments) + return ret0 +} + +// GetGrantedScopes indicates an expected call of GetGrantedScopes +func (mr *MockAccessRequesterMockRecorder) GetGrantedScopes() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGrantedScopes", reflect.TypeOf((*MockAccessRequester)(nil).GetGrantedScopes)) +} + +// GetID mocks base method +func (m *MockAccessRequester) GetID() string { + ret := m.ctrl.Call(m, "GetID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetID indicates an expected call of GetID +func (mr *MockAccessRequesterMockRecorder) GetID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockAccessRequester)(nil).GetID)) +} + +// GetRequestForm mocks base method +func (m *MockAccessRequester) GetRequestForm() url.Values { + ret := m.ctrl.Call(m, "GetRequestForm") + ret0, _ := ret[0].(url.Values) + return ret0 +} + +// GetRequestForm indicates an expected call of GetRequestForm +func (mr *MockAccessRequesterMockRecorder) GetRequestForm() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestForm", reflect.TypeOf((*MockAccessRequester)(nil).GetRequestForm)) +} + +// GetRequestedAt mocks base method +func (m *MockAccessRequester) GetRequestedAt() time.Time { + ret := m.ctrl.Call(m, "GetRequestedAt") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetRequestedAt indicates an expected call of GetRequestedAt +func (mr *MockAccessRequesterMockRecorder) GetRequestedAt() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestedAt", reflect.TypeOf((*MockAccessRequester)(nil).GetRequestedAt)) +} + +// GetRequestedScopes mocks base method +func (m *MockAccessRequester) GetRequestedScopes() fosite.Arguments { + ret := m.ctrl.Call(m, "GetRequestedScopes") + ret0, _ := ret[0].(fosite.Arguments) + return ret0 +} + +// GetRequestedScopes indicates an expected call of GetRequestedScopes +func (mr *MockAccessRequesterMockRecorder) GetRequestedScopes() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestedScopes", reflect.TypeOf((*MockAccessRequester)(nil).GetRequestedScopes)) +} + +// GetSession mocks base method +func (m *MockAccessRequester) GetSession() fosite.Session { + ret := m.ctrl.Call(m, "GetSession") + ret0, _ := ret[0].(fosite.Session) + return ret0 +} + +// GetSession indicates an expected call of GetSession +func (mr *MockAccessRequesterMockRecorder) GetSession() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockAccessRequester)(nil).GetSession)) +} + +// GrantScope mocks base method +func (m *MockAccessRequester) GrantScope(arg0 string) { + m.ctrl.Call(m, "GrantScope", arg0) +} + +// GrantScope indicates an expected call of GrantScope +func (mr *MockAccessRequesterMockRecorder) GrantScope(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantScope", reflect.TypeOf((*MockAccessRequester)(nil).GrantScope), arg0) +} + +// Merge mocks base method +func (m *MockAccessRequester) Merge(arg0 fosite.Requester) { + m.ctrl.Call(m, "Merge", arg0) +} + +// Merge indicates an expected call of Merge +func (mr *MockAccessRequesterMockRecorder) Merge(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Merge", reflect.TypeOf((*MockAccessRequester)(nil).Merge), arg0) +} + +// SetRequestedScopes mocks base method +func (m *MockAccessRequester) SetRequestedScopes(arg0 fosite.Arguments) { + m.ctrl.Call(m, "SetRequestedScopes", arg0) +} + +// SetRequestedScopes indicates an expected call of SetRequestedScopes +func (mr *MockAccessRequesterMockRecorder) SetRequestedScopes(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRequestedScopes", reflect.TypeOf((*MockAccessRequester)(nil).SetRequestedScopes), arg0) +} + +// SetSession mocks base method +func (m *MockAccessRequester) SetSession(arg0 fosite.Session) { + m.ctrl.Call(m, "SetSession", arg0) +} + +// SetSession indicates an expected call of SetSession +func (mr *MockAccessRequesterMockRecorder) SetSession(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSession", reflect.TypeOf((*MockAccessRequester)(nil).SetSession), arg0) +} diff --git a/internal/mocks/fosite_access_token_storage.go b/internal/mocks/fosite_access_token_storage.go new file mode 100644 index 00000000000..5c3faa47af6 --- /dev/null +++ b/internal/mocks/fosite_access_token_storage.go @@ -0,0 +1,73 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ory/fosite/handler/oauth2 (interfaces: AccessTokenStorage) + +// Package internal is a generated GoMock package. +package internal + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" +) + +// MockAccessTokenStorage is a mock of AccessTokenStorage interface +type MockAccessTokenStorage struct { + ctrl *gomock.Controller + recorder *MockAccessTokenStorageMockRecorder +} + +// MockAccessTokenStorageMockRecorder is the mock recorder for MockAccessTokenStorage +type MockAccessTokenStorageMockRecorder struct { + mock *MockAccessTokenStorage +} + +// NewMockAccessTokenStorage creates a new mock instance +func NewMockAccessTokenStorage(ctrl *gomock.Controller) *MockAccessTokenStorage { + mock := &MockAccessTokenStorage{ctrl: ctrl} + mock.recorder = &MockAccessTokenStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAccessTokenStorage) EXPECT() *MockAccessTokenStorageMockRecorder { + return m.recorder +} + +// CreateAccessTokenSession mocks base method +func (m *MockAccessTokenStorage) CreateAccessTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { + ret := m.ctrl.Call(m, "CreateAccessTokenSession", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateAccessTokenSession indicates an expected call of CreateAccessTokenSession +func (mr *MockAccessTokenStorageMockRecorder) CreateAccessTokenSession(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAccessTokenSession", reflect.TypeOf((*MockAccessTokenStorage)(nil).CreateAccessTokenSession), arg0, arg1, arg2) +} + +// DeleteAccessTokenSession mocks base method +func (m *MockAccessTokenStorage) DeleteAccessTokenSession(arg0 context.Context, arg1 string) error { + ret := m.ctrl.Call(m, "DeleteAccessTokenSession", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccessTokenSession indicates an expected call of DeleteAccessTokenSession +func (mr *MockAccessTokenStorageMockRecorder) DeleteAccessTokenSession(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccessTokenSession", reflect.TypeOf((*MockAccessTokenStorage)(nil).DeleteAccessTokenSession), arg0, arg1) +} + +// GetAccessTokenSession mocks base method +func (m *MockAccessTokenStorage) GetAccessTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { + ret := m.ctrl.Call(m, "GetAccessTokenSession", arg0, arg1, arg2) + ret0, _ := ret[0].(fosite.Requester) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccessTokenSession indicates an expected call of GetAccessTokenSession +func (mr *MockAccessTokenStorageMockRecorder) GetAccessTokenSession(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccessTokenSession", reflect.TypeOf((*MockAccessTokenStorage)(nil).GetAccessTokenSession), arg0, arg1, arg2) +} diff --git a/internal/mocks/fosite_access_token_strategy.go b/internal/mocks/fosite_access_token_strategy.go new file mode 100644 index 00000000000..8b143f0331e --- /dev/null +++ b/internal/mocks/fosite_access_token_strategy.go @@ -0,0 +1,74 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ory/fosite/handler/oauth2 (interfaces: AccessTokenStrategy) + +// Package internal is a generated GoMock package. +package internal + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" +) + +// MockAccessTokenStrategy is a mock of AccessTokenStrategy interface +type MockAccessTokenStrategy struct { + ctrl *gomock.Controller + recorder *MockAccessTokenStrategyMockRecorder +} + +// MockAccessTokenStrategyMockRecorder is the mock recorder for MockAccessTokenStrategy +type MockAccessTokenStrategyMockRecorder struct { + mock *MockAccessTokenStrategy +} + +// NewMockAccessTokenStrategy creates a new mock instance +func NewMockAccessTokenStrategy(ctrl *gomock.Controller) *MockAccessTokenStrategy { + mock := &MockAccessTokenStrategy{ctrl: ctrl} + mock.recorder = &MockAccessTokenStrategyMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAccessTokenStrategy) EXPECT() *MockAccessTokenStrategyMockRecorder { + return m.recorder +} + +// AccessTokenSignature mocks base method +func (m *MockAccessTokenStrategy) AccessTokenSignature(arg0 string) string { + ret := m.ctrl.Call(m, "AccessTokenSignature", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// AccessTokenSignature indicates an expected call of AccessTokenSignature +func (mr *MockAccessTokenStrategyMockRecorder) AccessTokenSignature(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenSignature", reflect.TypeOf((*MockAccessTokenStrategy)(nil).AccessTokenSignature), arg0) +} + +// GenerateAccessToken mocks base method +func (m *MockAccessTokenStrategy) GenerateAccessToken(arg0 context.Context, arg1 fosite.Requester) (string, string, error) { + ret := m.ctrl.Call(m, "GenerateAccessToken", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GenerateAccessToken indicates an expected call of GenerateAccessToken +func (mr *MockAccessTokenStrategyMockRecorder) GenerateAccessToken(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateAccessToken", reflect.TypeOf((*MockAccessTokenStrategy)(nil).GenerateAccessToken), arg0, arg1) +} + +// ValidateAccessToken mocks base method +func (m *MockAccessTokenStrategy) ValidateAccessToken(arg0 context.Context, arg1 fosite.Requester, arg2 string) error { + ret := m.ctrl.Call(m, "ValidateAccessToken", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateAccessToken indicates an expected call of ValidateAccessToken +func (mr *MockAccessTokenStrategyMockRecorder) ValidateAccessToken(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateAccessToken", reflect.TypeOf((*MockAccessTokenStrategy)(nil).ValidateAccessToken), arg0, arg1, arg2) +} diff --git a/internal/mocks/fosite_client.go b/internal/mocks/fosite_client.go new file mode 100644 index 00000000000..3863552a7e1 --- /dev/null +++ b/internal/mocks/fosite_client.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ory/fosite (interfaces: Client) + +// Package internal is a generated GoMock package. +package internal + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" +) + +// MockClient is a mock of Client interface +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// GetGrantTypes mocks base method +func (m *MockClient) GetGrantTypes() fosite.Arguments { + ret := m.ctrl.Call(m, "GetGrantTypes") + ret0, _ := ret[0].(fosite.Arguments) + return ret0 +} + +// GetGrantTypes indicates an expected call of GetGrantTypes +func (mr *MockClientMockRecorder) GetGrantTypes() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGrantTypes", reflect.TypeOf((*MockClient)(nil).GetGrantTypes)) +} + +// GetHashedSecret mocks base method +func (m *MockClient) GetHashedSecret() []byte { + ret := m.ctrl.Call(m, "GetHashedSecret") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// GetHashedSecret indicates an expected call of GetHashedSecret +func (mr *MockClientMockRecorder) GetHashedSecret() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHashedSecret", reflect.TypeOf((*MockClient)(nil).GetHashedSecret)) +} + +// GetID mocks base method +func (m *MockClient) GetID() string { + ret := m.ctrl.Call(m, "GetID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetID indicates an expected call of GetID +func (mr *MockClientMockRecorder) GetID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID)) +} + +// GetRedirectURIs mocks base method +func (m *MockClient) GetRedirectURIs() []string { + ret := m.ctrl.Call(m, "GetRedirectURIs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetRedirectURIs indicates an expected call of GetRedirectURIs +func (mr *MockClientMockRecorder) GetRedirectURIs() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRedirectURIs", reflect.TypeOf((*MockClient)(nil).GetRedirectURIs)) +} + +// GetResponseTypes mocks base method +func (m *MockClient) GetResponseTypes() fosite.Arguments { + ret := m.ctrl.Call(m, "GetResponseTypes") + ret0, _ := ret[0].(fosite.Arguments) + return ret0 +} + +// GetResponseTypes indicates an expected call of GetResponseTypes +func (mr *MockClientMockRecorder) GetResponseTypes() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseTypes", reflect.TypeOf((*MockClient)(nil).GetResponseTypes)) +} + +// GetScopes mocks base method +func (m *MockClient) GetScopes() fosite.Arguments { + ret := m.ctrl.Call(m, "GetScopes") + ret0, _ := ret[0].(fosite.Arguments) + return ret0 +} + +// GetScopes indicates an expected call of GetScopes +func (mr *MockClientMockRecorder) GetScopes() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScopes", reflect.TypeOf((*MockClient)(nil).GetScopes)) +} + +// IsPublic mocks base method +func (m *MockClient) IsPublic() bool { + ret := m.ctrl.Call(m, "IsPublic") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPublic indicates an expected call of IsPublic +func (mr *MockClientMockRecorder) IsPublic() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublic", reflect.TypeOf((*MockClient)(nil).IsPublic)) +} diff --git a/internal/mocks/hydra_key_manager.go b/internal/mocks/hydra_key_manager.go new file mode 100644 index 00000000000..bf8ca61fe9a --- /dev/null +++ b/internal/mocks/hydra_key_manager.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ory/hydra/jwk (interfaces: Manager) + +// Package internal is a generated GoMock package. +package internal + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + go_jose "github.com/square/go-jose" +) + +// MockManager is a mock of Manager interface +type MockManager struct { + ctrl *gomock.Controller + recorder *MockManagerMockRecorder +} + +// MockManagerMockRecorder is the mock recorder for MockManager +type MockManagerMockRecorder struct { + mock *MockManager +} + +// NewMockManager creates a new mock instance +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockManager) EXPECT() *MockManagerMockRecorder { + return m.recorder +} + +// AddKey mocks base method +func (m *MockManager) AddKey(arg0 string, arg1 *go_jose.JSONWebKey) error { + ret := m.ctrl.Call(m, "AddKey", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddKey indicates an expected call of AddKey +func (mr *MockManagerMockRecorder) AddKey(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKey", reflect.TypeOf((*MockManager)(nil).AddKey), arg0, arg1) +} + +// AddKeySet mocks base method +func (m *MockManager) AddKeySet(arg0 string, arg1 *go_jose.JSONWebKeySet) error { + ret := m.ctrl.Call(m, "AddKeySet", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddKeySet indicates an expected call of AddKeySet +func (mr *MockManagerMockRecorder) AddKeySet(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKeySet", reflect.TypeOf((*MockManager)(nil).AddKeySet), arg0, arg1) +} + +// DeleteKey mocks base method +func (m *MockManager) DeleteKey(arg0, arg1 string) error { + ret := m.ctrl.Call(m, "DeleteKey", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteKey indicates an expected call of DeleteKey +func (mr *MockManagerMockRecorder) DeleteKey(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteKey", reflect.TypeOf((*MockManager)(nil).DeleteKey), arg0, arg1) +} + +// DeleteKeySet mocks base method +func (m *MockManager) DeleteKeySet(arg0 string) error { + ret := m.ctrl.Call(m, "DeleteKeySet", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteKeySet indicates an expected call of DeleteKeySet +func (mr *MockManagerMockRecorder) DeleteKeySet(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteKeySet", reflect.TypeOf((*MockManager)(nil).DeleteKeySet), arg0) +} + +// GetKey mocks base method +func (m *MockManager) GetKey(arg0, arg1 string) (*go_jose.JSONWebKeySet, error) { + ret := m.ctrl.Call(m, "GetKey", arg0, arg1) + ret0, _ := ret[0].(*go_jose.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKey indicates an expected call of GetKey +func (mr *MockManagerMockRecorder) GetKey(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*MockManager)(nil).GetKey), arg0, arg1) +} + +// GetKeySet mocks base method +func (m *MockManager) GetKeySet(arg0 string) (*go_jose.JSONWebKeySet, error) { + ret := m.ctrl.Call(m, "GetKeySet", arg0) + ret0, _ := ret[0].(*go_jose.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKeySet indicates an expected call of GetKeySet +func (mr *MockManagerMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockManager)(nil).GetKeySet), arg0) +} diff --git a/oauth2/flow_jwt_bearer.go b/oauth2/flow_jwt_bearer.go new file mode 100644 index 00000000000..dc2cba48570 --- /dev/null +++ b/oauth2/flow_jwt_bearer.go @@ -0,0 +1,174 @@ +package oauth2 + +import ( + "context" + "crypto/rsa" + "fmt" + "strings" + "time" + + "github.com/ory/hydra/jwk" + + "github.com/dgrijalva/jwt-go" + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" +) + +// JWT bearer grant type mark. According to the latest https://tools.ietf.org/html/rfc7523 +const jwtBearerGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" + +// JWTBearerGrantFactory function for creating handler for JWT Bearer Grant +func JWTBearerGrantFactory(config *compose.Config, storage interface{}, strategy interface{}) interface{} { + return &JWTBearerGrantHandler{ + HandleHelper: &oauth2.HandleHelper{ + AccessTokenStrategy: strategy.(oauth2.AccessTokenStrategy), + AccessTokenStorage: storage.(oauth2.AccessTokenStorage), + AccessTokenLifespan: config.GetAccessTokenLifespan(), + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + KeyManager: storage.(CommonStore).KeyManager, + Audience: strings.Trim(storage.(CommonStore).ClusterURL, "/") + "/oauth2/token", + } +} + +// JWTBearerGrantHandler handles JWT bearer flow +type JWTBearerGrantHandler struct { + *oauth2.HandleHelper + ScopeStrategy fosite.ScopeStrategy + KeyManager jwk.Manager + Audience string +} + +// HandleTokenEndpointRequest implements https://tools.ietf.org/html/rfc7523#section-3 +func (c *JWTBearerGrantHandler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { + // grant_type REQUIRED. + // Value MUST be set to "urn:ietf:params:oauth:client-assertion-type:jwt-bearer". + if !request.GetGrantTypes().Exact(jwtBearerGrantType) { + return errors.WithStack(fosite.ErrUnknownRequest) + } + + client := request.GetClient() + + if !client.GetGrantTypes().Has(jwtBearerGrantType) { + return errors.Wrap(fosite.ErrInvalidGrant, + fmt.Sprintf("The client is not allowed to use grant type %s", jwtBearerGrantType)) + } + + for _, scope := range request.GetRequestedScopes() { + if !c.ScopeStrategy(client.GetScopes(), scope) { + return errors.Wrap(fosite.ErrInvalidScope, fmt.Sprintf("The client is not allowed to request scope %s", scope)) + } + } + + // assertion REQUIRED. + // Value MUST be set to JWT string value. + jwtToken := request.GetRequestForm().Get("assertion") + if jwtToken == "" { + return errors.Wrap(fosite.ErrInvalidRequest, "Field 'assertion' is missing") + } + + token, err := jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) { + // We stick to this option: https://tools.ietf.org/html/rfc7515#section-4.1.4 + keyID, _ := token.Header["kid"].(string) + if keyID == "" { + return nil, fmt.Errorf("your key-set ID should be present in 'kid' of the JOSE header") + } + switch token.Method.(type) { + case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA: + ks, err := c.KeyManager.GetKey(keyID, "public") + if err != nil { + return nil, err + } + rsaKey, ok := jwk.First(ks.Keys).Key.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("could not convert to RSA Public Key") + } + return rsaKey, nil + default: + return nil, fmt.Errorf("unexpected signing method: '%v'. We support only RSA, ECDSA", token.Header["alg"]) + } + }) + + if err != nil { + // Catch possible jwt.Parse errors. + if e, ok := errors.Cause(err).(*jwt.ValidationError); ok { + switch e.Errors { + case jwt.ValidationErrorUnverifiable, jwt.ValidationErrorSignatureInvalid: + return errors.Wrap(fosite.ErrTokenSignatureMismatch, err.Error()) + case jwt.ValidationErrorExpired: + return errors.Wrap(fosite.ErrTokenExpired, err.Error()) + case jwt.ValidationErrorIssuedAt: + return errors.Wrap(fosite.ErrInactiveToken, err.Error()) + default: + return errors.Wrap(fosite.ErrInvalidTokenFormat, err.Error()) + } + } + // It means we have some unknown error. + return errors.Wrap(fosite.ErrServerError, err.Error()) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return errors.Wrap(fosite.ErrInvalidTokenFormat, "JWT claims were not found or are malformed") + } + + // For https://tools.ietf.org/html/rfc7523#section-3.1 + // We check that client ID obtained from the Basic auth is what was stated as 'iss' in JWT. + // Otherwise it can be a client ID forgery attempt. + if !claims.VerifyIssuer(client.GetID(), true) { + return errors.Wrap(fosite.ErrTokenClaim, "Issuer (iss) claim should be present and should be your client ID") + } + // For https://tools.ietf.org/html/rfc7523#section-3.2 + if val, ok := claims["sub"].(string); !ok || val == "" { + return errors.Wrap(fosite.ErrTokenClaim, "Subject (sub) claim should be a nonempty string") + } + // For https://tools.ietf.org/html/rfc7523#section-3.3 + if !claims.VerifyAudience(c.Audience, true) { + return errors.Wrap(fosite.ErrTokenClaim, "Audience (aud) is invalid or missing") + } + // For https://tools.ietf.org/html/rfc7523#section-3.3 + // Actually jwt.Parse already checks exp value, but it does not require it to be present, so re-checking it. + if _, ok := claims["exp"]; !ok { + return errors.Wrap(fosite.ErrTokenClaim, "Expires (exp) claim should be present") + } + // Actually jwt.Parse already checks iat value, but it does not require it to be present, so re-checking it. + if _, ok := claims["iat"]; !ok { + return errors.Wrap(fosite.ErrTokenClaim, "Issued at (iat) claim should be present") + } + + // The client MUST authenticate with the authorization server as described in Section 3.2.1. + // in https://tools.ietf.org/html/rfc6749#section-3.2.1 + if client.IsPublic() { + return errors.Wrap(fosite.ErrInvalidGrant, + fmt.Sprintf("The client is public and thus not allowed to use grant type '%s'", jwtBearerGrantType)) + } + + session, ok := request.GetSession().(*Session) + if !ok { + return errors.WithStack(openid.ErrInvalidSession) + } + + session.SetExpiresAt(fosite.AccessToken, time.Now().Add(c.AccessTokenLifespan)) + session.Subject = claims["sub"].(string) + // Use custom claim for detecting tenant ID. + if val, ok := claims["tnt"]; ok && val != "" { + session.SetExtra("tenant", val) + } + return nil +} + +// PopulateTokenEndpointResponse implements https://tools.ietf.org/html/rfc6749#section-4.4.3 +func (c *JWTBearerGrantHandler) PopulateTokenEndpointResponse(ctx context.Context, request fosite.AccessRequester, response fosite.AccessResponder) error { + if !request.GetGrantTypes().Exact(jwtBearerGrantType) { + return errors.WithStack(fosite.ErrUnknownRequest) + } + + if !request.GetClient().GetGrantTypes().Has(jwtBearerGrantType) { + return errors.Wrap(fosite.ErrInvalidGrant, fmt.Sprintf("The client is not allowed to use grant type %s", jwtBearerGrantType)) + } + + return c.IssueAccessToken(ctx, request, response) +} diff --git a/oauth2/flow_jwt_bearer_test.go b/oauth2/flow_jwt_bearer_test.go new file mode 100644 index 00000000000..5f779bb2541 --- /dev/null +++ b/oauth2/flow_jwt_bearer_test.go @@ -0,0 +1,646 @@ +package oauth2 + +import ( + "crypto/rsa" + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "github.com/ory/hydra/internal/mocks" + "github.com/ory/hydra/jwk" + + "github.com/dgrijalva/jwt-go" + "github.com/golang/mock/gomock" + "github.com/ory/fosite" + "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" + "github.com/square/go-jose" + "github.com/stretchr/testify/assert" +) + +func TestJWTBearerFlow_HandleTokenEndpointRequest_Validation(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + strategy := internal.NewMockAccessTokenStrategy(ctrl) + store := internal.NewMockAccessTokenStorage(ctrl) + keyManager := internal.NewMockManager(ctrl) + areq := internal.NewMockAccessRequester(ctrl) + client := internal.NewMockClient(ctrl) + + h := JWTBearerGrantHandler{ + HandleHelper: &oauth2.HandleHelper{ + AccessTokenStrategy: strategy, + AccessTokenStorage: store, + AccessTokenLifespan: time.Hour, + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + KeyManager: keyManager, + Audience: "http://hydra-cluster.url/oauth2/token", + } + + for k, c := range []struct { + description string + mock func() + req *http.Request + expectErr error + expectErrorMsg string + }{ + { + description: "should fail because request handler is not responsible for this grant type", + expectErr: fosite.ErrUnknownRequest, + expectErrorMsg: "not responsible", + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{""}) + }, + }, + { + description: "should fail because client is not assigned to this this grant type", + expectErr: fosite.ErrInvalidGrant, + expectErrorMsg: "client is not allowed to use grant type " + jwtBearerGrantType, + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{""}) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetClient().Return(client) + }, + }, + { + description: "should fail because client is not assigned to this scope", + expectErr: fosite.ErrInvalidScope, + expectErrorMsg: "client is not allowed to request scope foo-scope", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"bar-scope"}) + areq.EXPECT().GetClient().Return(client) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + }, + }, + { + description: "should fail because field 'assertion' is missing in the request form data", + expectErr: fosite.ErrInvalidRequest, + expectErrorMsg: "'assertion' is missing", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + areq.EXPECT().GetClient().Return(client) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + areq.EXPECT().GetRequestForm().Return(url.Values{}) + }, + }, + { + description: "should fail because 'assertion' field contains malformed JWT token", + expectErr: fosite.ErrInvalidTokenFormat, + expectErrorMsg: "token contains an invalid number of segments", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + areq.EXPECT().GetClient().Return(client) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{"i-am-not-jwt"}}) + }, + }, + { + description: "should fail because JWT JOSE header has no 'kid' field", + expectErr: fosite.ErrTokenSignatureMismatch, + expectErrorMsg: "key-set ID should be present in 'kid'", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + areq.EXPECT().GetClient().Return(client) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + token := generateTestJWT(make(jwt.MapClaims), make(map[string]interface{}), generateTestJWKSet()) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because key set used to sign JWT was not found by Hydra's Key Manager", + expectErr: fosite.ErrTokenSignatureMismatch, + expectErrorMsg: "no key found", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + areq.EXPECT().GetClient().Return(client) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + headers := map[string]interface{}{ + "kid": "123set", + } + token := generateTestJWT(make(jwt.MapClaims), headers, generateTestJWKSet()) + keyManager.EXPECT().GetKey("123set", "public").Return(nil, fmt.Errorf("no key found")) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because Hydra's Key Manager wrong key type", + expectErr: fosite.ErrTokenSignatureMismatch, + expectErrorMsg: "convert to RSA Public", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + areq.EXPECT().GetClient().Return(client) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + token := generateTestJWT(make(jwt.MapClaims), headers, jwkSet) + keySetWithOnlyPrivateKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[0]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPrivateKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because JWT was signed with the key of another client", + expectErr: fosite.ErrTokenSignatureMismatch, + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + areq.EXPECT().GetClient().Return(client) + + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + anotherClientJwkSet := generateTestJWKSet() + myJwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + token := generateTestJWT(make(jwt.MapClaims), headers, anotherClientJwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{myJwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'iss' in JWT has wrong client ID", + expectErr: fosite.ErrTokenClaim, + expectErrorMsg: "should be your client ID", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_2", + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'sub' is missing in JWT", + expectErr: fosite.ErrTokenClaim, + expectErrorMsg: "Subject (sub) claim should be a nonempty string", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'sub' is an empty string", + expectErr: fosite.ErrTokenClaim, + expectErrorMsg: "Subject (sub) claim should be a nonempty string", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "", + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'sub' is not a string", + expectErr: fosite.ErrTokenClaim, + expectErrorMsg: "Subject (sub) claim should be a nonempty string", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": 123456, + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'aud' in JWT has unknown OAuth2 token endpoint URL", + expectErr: fosite.ErrTokenClaim, + expectErrorMsg: "Audience (aud) is invalid or missing", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "some-user-id", + "aud": "http://some-other-oauth2-cluster.url", + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'exp' token time is expired", + expectErr: fosite.ErrTokenExpired, + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "some-user-id", + "aud": "http://hydra-cluster.url/oauth2/token", + "exp": time.Now().Add(time.Hour * -24).Unix(), + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'exp' claim is not set in JWT", + expectErr: fosite.ErrTokenClaim, + expectErrorMsg: "Expires (exp) claim should be present", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "some-user-id", + "aud": "http://hydra-cluster.url/oauth2/token", + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because JWT is issued at some time in the future", + expectErr: fosite.ErrInactiveToken, + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iat": time.Now().Add(time.Hour * 200).Unix(), + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because 'iat' claim is not set in JWT", + expectErr: fosite.ErrTokenClaim, + expectErrorMsg: "Issued at (iat) claim should be present", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "some-user-id", + "aud": "http://hydra-cluster.url/oauth2/token", + "exp": time.Now().Add(time.Hour).Unix(), + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because public clients are not allowed to use this grant-type", + expectErr: fosite.ErrInvalidGrant, + expectErrorMsg: "not allowed to use grant type", + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + client.EXPECT().IsPublic().Return(true) + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "some-user-id", + "aud": "http://hydra-cluster.url/oauth2/token", + "iat": time.Now().Add(time.Hour * -1).Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + { + description: "should fail because the request session is not of an appropriate type", + expectErr: openid.ErrInvalidSession, + mock: func() { + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + client.EXPECT().IsPublic().Return(false) + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + areq.EXPECT().GetSession().Return(nil) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "some-user-id", + "aud": "http://hydra-cluster.url/oauth2/token", + "iat": time.Now().Add(time.Hour * -1).Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + }, + }, + } { + t.Logf("Running test case %d", k) + c.mock() + err := h.HandleTokenEndpointRequest(nil, areq) + assert.True(t, errors.Cause(err) == c.expectErr, "(%d) %s\nExpected: %s\nGot: %s", k, c.description, c.expectErr, err) + assert.Contains(t, err.Error(), c.expectErrorMsg, + "(%d) %s\nMessage expected to contain: %s\nGot: %s", k, c.description, c.expectErrorMsg, err.Error()) + } +} + +func TestJWTBearerFlow_HandleTokenEndpointRequest_SessionPopulation(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + strategy := internal.NewMockAccessTokenStrategy(ctrl) + store := internal.NewMockAccessTokenStorage(ctrl) + keyManager := internal.NewMockManager(ctrl) + areq := internal.NewMockAccessRequester(ctrl) + client := internal.NewMockClient(ctrl) + + h := JWTBearerGrantHandler{ + HandleHelper: &oauth2.HandleHelper{ + AccessTokenStrategy: strategy, + AccessTokenStorage: store, + AccessTokenLifespan: time.Hour, + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + KeyManager: keyManager, + Audience: "http://hydra-cluster.url/oauth2/token", + } + + client.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + client.EXPECT().GetScopes().Return(fosite.Arguments{"foo-scope"}) + client.EXPECT().GetID().Return("client_1") + client.EXPECT().IsPublic().Return(false) + + areq.EXPECT().GetClient().Return(client) + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{jwtBearerGrantType}) + areq.EXPECT().GetRequestedScopes().Return(fosite.Arguments{"foo-scope"}) + + jwkSet := generateTestJWKSet() + headers := map[string]interface{}{ + "kid": "123set", + } + claims := jwt.MapClaims{ + "iss": "client_1", + "sub": "some-user-id", + "aud": "http://hydra-cluster.url/oauth2/token", + "iat": time.Now().Add(time.Hour * -1).Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "tnt": "some_tenant_1", + } + token := generateTestJWT(claims, headers, jwkSet) + keySetWithOnlyPublicKey := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{jwkSet.Keys[1]}} + keyManager.EXPECT().GetKey("123set", "public").Return(keySetWithOnlyPublicKey, nil) + areq.EXPECT().GetRequestForm().Return(url.Values{"assertion": []string{token}}) + + session := &Session{ + DefaultSession: &openid.DefaultSession{}, + } + areq.EXPECT().GetSession().Return(session) + + err := h.HandleTokenEndpointRequest(nil, areq) + assert.Nil(t, err, "Should finish without errors") + assert.NotNil(t, session.ExpiresAt, "Should set expires to session") + assert.Equal(t, "some-user-id", session.Subject, "Should set Subject based on 'sub' claim") + assert.NotNil(t, session.Extra) + assert.Contains(t, session.Extra, "tenant") + assert.Equal(t, "some_tenant_1", session.Extra["tenant"], "Should set tenant to extra fields based on 'tnt' claim") +} + +func TestJWTBearerFlow_PopulateTokenEndpointResponse(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + store := internal.NewMockAccessTokenStorage(ctrl) + tokenStrategy := internal.NewMockAccessTokenStrategy(ctrl) + areq := fosite.NewAccessRequest(new(fosite.DefaultSession)) + aresp := fosite.NewAccessResponse() + + h := JWTBearerGrantHandler{ + HandleHelper: &oauth2.HandleHelper{ + AccessTokenStorage: store, + AccessTokenStrategy: tokenStrategy, + AccessTokenLifespan: time.Hour, + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + KeyManager: nil, + Audience: "http://hydra-cluster.url/oauth2/token", + } + + for k, c := range []struct { + description string + mock func() + req *http.Request + expectErr error + }{ + { + description: "should fail because not responsible", + expectErr: fosite.ErrUnknownRequest, + mock: func() { + areq.GrantTypes = fosite.Arguments{""} + }, + }, + { + description: "should fail because client not allowed", + expectErr: fosite.ErrInvalidGrant, + mock: func() { + areq.GrantTypes = fosite.Arguments{jwtBearerGrantType} + areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{"foo"}} + }, + }, + { + description: "should pass", + mock: func() { + areq.GrantTypes = fosite.Arguments{jwtBearerGrantType} + areq.Session = &fosite.DefaultSession{} + areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{jwtBearerGrantType}} + tokenStrategy.EXPECT().GenerateAccessToken(nil, areq).Return("tokenfoo.bar", "bar", nil) + store.EXPECT().CreateAccessTokenSession(nil, "bar", areq).Return(nil) + }, + }, + } { + t.Logf("Running test case %d", k) + c.mock() + err := h.PopulateTokenEndpointResponse(nil, areq, aresp) + assert.True(t, errors.Cause(err) == c.expectErr, "(%d) %s\nExpected: %s\nGot: %s", k, c.description, c.expectErr, err) + } +} + +// generateTestJWT creates a valid test RSA-256 signed JWT with provided claims. +func generateTestJWT(claims jwt.MapClaims, headers map[string]interface{}, keySet *jose.JSONWebKeySet) string { + token := jwt.New(jwt.SigningMethodRS256) + token.Claims = claims + for k, v := range headers { + token.Header[k] = v + } + + rsaKey, _ := jwk.First(keySet.Keys).Key.(*rsa.PrivateKey) + encoded, _ := token.SigningString() + signature, _ := token.Method.Sign(encoded, rsaKey) + + return fmt.Sprintf("%s.%s", encoded, signature) +} + +// generateTestJWKSet creates a RSA-256 JWK for test purposes. +func generateTestJWKSet() *jose.JSONWebKeySet { + var keyGenerator = &jwk.RS256Generator{} + pk, _ := keyGenerator.Generate("") + return pk +} diff --git a/oauth2/handler.go b/oauth2/handler.go index 8c533ff8365..d9062fcaf84 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -275,6 +275,14 @@ func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request, _ httprou } } + if accessRequest.GetGrantTypes().Exact(jwtBearerGrantType) { + for _, scope := range accessRequest.GetRequestedScopes() { + if fosite.HierarchicScopeStrategy(accessRequest.GetClient().GetScopes(), scope) { + accessRequest.GrantScope(scope) + } + } + } + accessResponse, err := h.OAuth2.NewAccessResponse(ctx, accessRequest) if err != nil { pkg.LogError(err, h.L) diff --git a/oauth2/session.go b/oauth2/session.go index 2647d4d3fa3..12c552cf1e7 100644 --- a/oauth2/session.go +++ b/oauth2/session.go @@ -7,9 +7,12 @@ import ( "github.com/ory/fosite/token/jwt" ) +// SessionExtraInfo is for extra values we want to be stored in session +type SessionExtraInfo map[string]interface{} + type Session struct { *openid.DefaultSession `json:"idToken"` - Extra map[string]interface{} `json:"extra"` + Extra SessionExtraInfo `json:"extra"` } func NewSession(subject string) *Session { @@ -29,3 +32,13 @@ func (s *Session) Clone() fosite.Session { return deepcopy.Copy(s).(fosite.Session) } + +// SetExtra sets one extra attribute to session. +// Additionally lazy-allocated Extra field. +func (s *Session) SetExtra(key string, value interface{}) { + // Deferred initialization. + if s.Extra == nil { + s.Extra = make(SessionExtraInfo) + } + s.Extra[key] = value +} diff --git a/oauth2/session_test.go b/oauth2/session_test.go new file mode 100644 index 00000000000..2a5f24b29bb --- /dev/null +++ b/oauth2/session_test.go @@ -0,0 +1,25 @@ +package oauth2_test + +import ( + "testing" + + "github.com/ory/hydra/oauth2" + "github.com/stretchr/testify/assert" +) + +func TestSetExtra(t *testing.T) { + session := oauth2.NewSession("foo") + assert.Nil(t, session.Extra) + + session.SetExtra("one", 1) + assert.NotNil(t, session.Extra) + assert.Contains(t, session.Extra, "one") + assert.Equal(t, 1, session.Extra["one"]) + + session.SetExtra("two", 2) + assert.NotNil(t, session.Extra) + assert.Contains(t, session.Extra, "one") + assert.Equal(t, 1, session.Extra["one"]) + assert.Contains(t, session.Extra, "two") + assert.Equal(t, 2, session.Extra["two"]) +} diff --git a/oauth2/store.go b/oauth2/store.go new file mode 100644 index 00000000000..6dc147f05c0 --- /dev/null +++ b/oauth2/store.go @@ -0,0 +1,13 @@ +package oauth2 + +import ( + "github.com/ory/hydra/jwk" + "github.com/ory/hydra/pkg" +) + +// CommonStore is Hydra specific store that obtains additional information for the application. +type CommonStore struct { + pkg.FositeStorer + KeyManager jwk.Manager + ClusterURL string +}