diff --git a/access_request_handler.go b/access_request_handler.go index 40766baf8..dcf937191 100644 --- a/access_request_handler.go +++ b/access_request_handler.go @@ -60,7 +60,7 @@ func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session return accessRequest, errors.Wrap(ErrInvalidRequest, "HTTP authorization header missing or invalid") } - client, err := f.Store.GetClient(clientID) + client, err := f.Store.GetClient(ctx, clientID) if err != nil { return accessRequest, errors.Wrap(ErrInvalidClient, err.Error()) } diff --git a/access_request_handler_test.go b/access_request_handler_test.go index f1f93216e..da188eb84 100644 --- a/access_request_handler_test.go +++ b/access_request_handler_test.go @@ -68,7 +68,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(nil, errors.New("")) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, }, { @@ -92,7 +92,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(nil, errors.New("")) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, }, { @@ -105,7 +105,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().IsPublic().Return(false) client.EXPECT().GetHashedSecret().Return([]byte("foo")) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) @@ -121,7 +121,7 @@ func TestNewAccessRequest(t *testing.T) { }, expectErr: ErrServerError, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().IsPublic().Return(false) client.EXPECT().GetHashedSecret().Return([]byte("foo")) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) @@ -138,7 +138,7 @@ func TestNewAccessRequest(t *testing.T) { "grant_type": {"foo"}, }, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().IsPublic().Return(false) client.EXPECT().GetHashedSecret().Return([]byte("foo")) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) @@ -161,7 +161,7 @@ func TestNewAccessRequest(t *testing.T) { "grant_type": {"foo"}, }, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().IsPublic().Return(true) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, diff --git a/authorize_request_handler.go b/authorize_request_handler.go index b75660a87..31ac2564a 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -21,7 +21,7 @@ func (c *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth } request.Form = r.Form - client, err := c.Store.GetClient(request.GetRequestForm().Get("client_id")) + client, err := c.Store.GetClient(ctx, request.GetRequestForm().Get("client_id")) if err != nil { return request, errors.WithStack(ErrInvalidClient) } diff --git a/authorize_request_handler_test.go b/authorize_request_handler_test.go index 5c1b9c70a..2f775192a 100644 --- a/authorize_request_handler_test.go +++ b/authorize_request_handler_test.go @@ -43,7 +43,7 @@ func TestNewAuthorizeRequest(t *testing.T) { r: &http.Request{}, expectedError: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Any()).Return(nil, errors.New("foo")) + store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")) }, }, /* invalid redirect uri */ @@ -53,7 +53,7 @@ func TestNewAuthorizeRequest(t *testing.T) { query: url.Values{"redirect_uri": []string{"invalid"}}, expectedError: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Any()).Return(nil, errors.New("foo")) + store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")) }, }, /* invalid client */ @@ -63,7 +63,7 @@ func TestNewAuthorizeRequest(t *testing.T) { query: url.Values{"redirect_uri": []string{"https://foo.bar/cb"}}, expectedError: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Any()).Return(nil, errors.New("foo")) + store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")) }, }, /* redirect client mismatch */ @@ -75,7 +75,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil) }, }, /* redirect client mismatch */ @@ -88,7 +88,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil) }, }, /* redirect client mismatch */ @@ -101,7 +101,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidRequest, mock: func() { - store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil) }, }, /* no state */ @@ -115,7 +115,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidState, mock: func() { - store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil) }, }, /* short state */ @@ -130,7 +130,7 @@ func TestNewAuthorizeRequest(t *testing.T) { }, expectedError: ErrInvalidState, mock: func() { - store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil) }, }, /* success case */ @@ -145,7 +145,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "scope": {"foo bar"}, }, mock: func() { - store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil) }, expect: &AuthorizeRequest{ RedirectURI: redir, diff --git a/client_manager.go b/client_manager.go index 18f92542b..9ccd56f25 100644 --- a/client_manager.go +++ b/client_manager.go @@ -1,8 +1,10 @@ package fosite +import "context" + // ClientManager defines the (persistent) manager interface for clients. type ClientManager interface { // GetClient loads the client by its ID or returns an error // if the client does not exist or another error occurred. - GetClient(id string) (Client, error) + GetClient(ctx context.Context, id string) (Client, error) } diff --git a/internal/storage.go b/internal/storage.go index 822606c74..de90b2945 100644 --- a/internal/storage.go +++ b/internal/storage.go @@ -4,6 +4,8 @@ package internal import ( + context "context" + gomock "github.com/golang/mock/gomock" fosite "github.com/ory/fosite" ) @@ -29,13 +31,13 @@ func (_m *MockStorage) EXPECT() *_MockStorageRecorder { return _m.recorder } -func (_m *MockStorage) GetClient(_param0 string) (fosite.Client, error) { - ret := _m.ctrl.Call(_m, "GetClient", _param0) +func (_m *MockStorage) GetClient(_param0 context.Context, _param1 string) (fosite.Client, error) { + ret := _m.ctrl.Call(_m, "GetClient", _param0, _param1) ret0, _ := ret[0].(fosite.Client) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockStorageRecorder) GetClient(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetClient", arg0) +func (_mr *_MockStorageRecorder) GetClient(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetClient", arg0, arg1) } diff --git a/introspection_request_handler.go b/introspection_request_handler.go index 30599c5f0..c80217dec 100644 --- a/introspection_request_handler.go +++ b/introspection_request_handler.go @@ -111,7 +111,7 @@ func (f *Fosite) NewIntrospectionRequest(ctx context.Context, r *http.Request, s return &IntrospectionResponse{Active: false}, errors.Wrap(ErrRequestUnauthorized, "HTTP Authorization header missing, malformed or credentials used are invalid") } - client, err := f.Store.GetClient(clientID) + client, err := f.Store.GetClient(ctx, clientID) if err != nil { return &IntrospectionResponse{Active: false}, errors.Wrap(ErrRequestUnauthorized, "HTTP Authorization header missing, malformed or credentials used are invalid") } diff --git a/revoke_handler.go b/revoke_handler.go index 52bca64f0..75a0f186f 100644 --- a/revoke_handler.go +++ b/revoke_handler.go @@ -39,7 +39,7 @@ func (f *Fosite) NewRevocationRequest(ctx context.Context, r *http.Request) erro return errors.Wrap(ErrInvalidRequest, "HTTP Authorization header missing or invalid") } - client, err := f.Store.GetClient(clientID) + client, err := f.Store.GetClient(ctx, clientID) if err != nil { return errors.Wrap(ErrInvalidClient, err.Error()) } diff --git a/revoke_handler_test.go b/revoke_handler_test.go index 8bb254b58..a1eeaa61c 100644 --- a/revoke_handler_test.go +++ b/revoke_handler_test.go @@ -61,7 +61,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(nil, errors.New("")) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, }, { @@ -74,7 +74,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: ErrInvalidClient, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().GetHashedSecret().Return([]byte("foo")) client.EXPECT().IsPublic().Return(false) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) @@ -90,7 +90,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().GetHashedSecret().Return([]byte("foo")) client.EXPECT().IsPublic().Return(false) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) @@ -109,7 +109,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().GetHashedSecret().Return([]byte("foo")) client.EXPECT().IsPublic().Return(false) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) @@ -128,7 +128,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().IsPublic().Return(true) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -146,7 +146,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().GetHashedSecret().Return([]byte("foo")) client.EXPECT().IsPublic().Return(false) handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -164,7 +164,7 @@ func TestNewRevocationRequest(t *testing.T) { }, expectErr: nil, mock: func() { - store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil) + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.EXPECT().GetHashedSecret().Return([]byte("foo")) client.EXPECT().IsPublic().Return(false) hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) diff --git a/storage/memory.go b/storage/memory.go index 76d5dbdca..a0e9c3116 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -87,7 +87,7 @@ func (s *MemoryStore) DeleteOpenIDConnectSession(_ context.Context, authorizeCod return nil } -func (s *MemoryStore) GetClient(id string) (fosite.Client, error) { +func (s *MemoryStore) GetClient(_ context.Context, id string) (fosite.Client, error) { cl, ok := s.Clients[id] if !ok { return nil, fosite.ErrNotFound