Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added context to GetClient storage interface, see issue: ory/fosite#161 #162

Merged
merged 1 commit into from
May 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion access_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session
return accessRequest, errors.Wrap(ErrInvalidRequest, "HTTP authorization header missing or invalid")
}

client, err := f.Store.GetClient(clientID)
client, err := f.Store.GetClient(ctx, clientID)
if err != nil {
return accessRequest, errors.Wrap(ErrInvalidClient, err.Error())
}
Expand Down
12 changes: 6 additions & 6 deletions access_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(nil, errors.New(""))
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
},
{
Expand All @@ -92,7 +92,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(nil, errors.New(""))
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
},
{
Expand All @@ -105,7 +105,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().IsPublic().Return(false)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
Expand All @@ -121,7 +121,7 @@ func TestNewAccessRequest(t *testing.T) {
},
expectErr: ErrServerError,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().IsPublic().Return(false)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
Expand All @@ -138,7 +138,7 @@ func TestNewAccessRequest(t *testing.T) {
"grant_type": {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().IsPublic().Return(false)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
Expand All @@ -161,7 +161,7 @@ func TestNewAccessRequest(t *testing.T) {
"grant_type": {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().IsPublic().Return(true)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
Expand Down
2 changes: 1 addition & 1 deletion authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth
}

request.Form = r.Form
client, err := c.Store.GetClient(request.GetRequestForm().Get("client_id"))
client, err := c.Store.GetClient(ctx, request.GetRequestForm().Get("client_id"))
if err != nil {
return request, errors.WithStack(ErrInvalidClient)
}
Expand Down
18 changes: 9 additions & 9 deletions authorize_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
r: &http.Request{},
expectedError: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Any()).Return(nil, errors.New("foo"))
store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo"))
},
},
/* invalid redirect uri */
Expand All @@ -53,7 +53,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
query: url.Values{"redirect_uri": []string{"invalid"}},
expectedError: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Any()).Return(nil, errors.New("foo"))
store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo"))
},
},
/* invalid client */
Expand All @@ -63,7 +63,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
query: url.Values{"redirect_uri": []string{"https://foo.bar/cb"}},
expectedError: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Any()).Return(nil, errors.New("foo"))
store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo"))
},
},
/* redirect client mismatch */
Expand All @@ -75,7 +75,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil)
store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil)
},
},
/* redirect client mismatch */
Expand All @@ -88,7 +88,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil)
store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil)
},
},
/* redirect client mismatch */
Expand All @@ -101,7 +101,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidRequest,
mock: func() {
store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil)
store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}}, nil)
},
},
/* no state */
Expand All @@ -115,7 +115,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidState,
mock: func() {
store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil)
store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil)
},
},
/* short state */
Expand All @@ -130,7 +130,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
},
expectedError: ErrInvalidState,
mock: func() {
store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil)
store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil)
},
},
/* success case */
Expand All @@ -145,7 +145,7 @@ func TestNewAuthorizeRequest(t *testing.T) {
"scope": {"foo bar"},
},
mock: func() {
store.EXPECT().GetClient("1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil)
store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}}, nil)
},
expect: &AuthorizeRequest{
RedirectURI: redir,
Expand Down
4 changes: 3 additions & 1 deletion client_manager.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package fosite

import "context"

// ClientManager defines the (persistent) manager interface for clients.
type ClientManager interface {
// GetClient loads the client by its ID or returns an error
// if the client does not exist or another error occurred.
GetClient(id string) (Client, error)
GetClient(ctx context.Context, id string) (Client, error)
}
10 changes: 6 additions & 4 deletions internal/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package internal

import (
context "context"

gomock "github.com/golang/mock/gomock"
fosite "github.com/ory/fosite"
)
Expand All @@ -29,13 +31,13 @@ func (_m *MockStorage) EXPECT() *_MockStorageRecorder {
return _m.recorder
}

func (_m *MockStorage) GetClient(_param0 string) (fosite.Client, error) {
ret := _m.ctrl.Call(_m, "GetClient", _param0)
func (_m *MockStorage) GetClient(_param0 context.Context, _param1 string) (fosite.Client, error) {
ret := _m.ctrl.Call(_m, "GetClient", _param0, _param1)
ret0, _ := ret[0].(fosite.Client)
ret1, _ := ret[1].(error)
return ret0, ret1
}

func (_mr *_MockStorageRecorder) GetClient(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetClient", arg0)
func (_mr *_MockStorageRecorder) GetClient(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetClient", arg0, arg1)
}
2 changes: 1 addition & 1 deletion introspection_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (f *Fosite) NewIntrospectionRequest(ctx context.Context, r *http.Request, s
return &IntrospectionResponse{Active: false}, errors.Wrap(ErrRequestUnauthorized, "HTTP Authorization header missing, malformed or credentials used are invalid")
}

client, err := f.Store.GetClient(clientID)
client, err := f.Store.GetClient(ctx, clientID)
if err != nil {
return &IntrospectionResponse{Active: false}, errors.Wrap(ErrRequestUnauthorized, "HTTP Authorization header missing, malformed or credentials used are invalid")
}
Expand Down
2 changes: 1 addition & 1 deletion revoke_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (f *Fosite) NewRevocationRequest(ctx context.Context, r *http.Request) erro
return errors.Wrap(ErrInvalidRequest, "HTTP Authorization header missing or invalid")
}

client, err := f.Store.GetClient(clientID)
client, err := f.Store.GetClient(ctx, clientID)
if err != nil {
return errors.Wrap(ErrInvalidClient, err.Error())
}
Expand Down
14 changes: 7 additions & 7 deletions revoke_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(nil, errors.New(""))
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
},
{
Expand All @@ -74,7 +74,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
client.EXPECT().IsPublic().Return(false)
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
Expand All @@ -90,7 +90,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
client.EXPECT().IsPublic().Return(false)
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
Expand All @@ -109,7 +109,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
client.EXPECT().IsPublic().Return(false)
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
Expand All @@ -128,7 +128,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().IsPublic().Return(true)
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
Expand All @@ -146,7 +146,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
client.EXPECT().IsPublic().Return(false)
handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
Expand All @@ -164,7 +164,7 @@ func TestNewRevocationRequest(t *testing.T) {
},
expectErr: nil,
mock: func() {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
client.EXPECT().IsPublic().Return(false)
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
Expand Down
2 changes: 1 addition & 1 deletion storage/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (s *MemoryStore) DeleteOpenIDConnectSession(_ context.Context, authorizeCod
return nil
}

func (s *MemoryStore) GetClient(id string) (fosite.Client, error) {
func (s *MemoryStore) GetClient(_ context.Context, id string) (fosite.Client, error) {
cl, ok := s.Clients[id]
if !ok {
return nil, fosite.ErrNotFound
Expand Down