diff --git a/access_request_handler.go b/access_request_handler.go index 9e500f42b..2c373dddc 100644 --- a/access_request_handler.go +++ b/access_request_handler.go @@ -59,6 +59,9 @@ import ( func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session Session) (AccessRequester, error) { accessRequest := NewAccessRequest(session) + ctx = context.WithValue(ctx, RequestContextKey, r) + ctx = context.WithValue(ctx, AccessRequestContextKey, accessRequest) + if r.Method != "POST" { return accessRequest, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method)) } else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart { diff --git a/access_request_handler_test.go b/access_request_handler_test.go index 3ce796d6b..9ea79f10a 100644 --- a/access_request_handler_test.go +++ b/access_request_handler_test.go @@ -47,6 +47,8 @@ func TestNewAccessRequest(t *testing.T) { hasher := internal.NewMockHasher(ctrl) defer ctrl.Finish() + ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil)) + client := &DefaultClient{} fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} for k, c := range []struct { @@ -136,7 +138,7 @@ func TestNewAccessRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) }, handlers: TokenEndpointHandlers{handler}, }, @@ -153,7 +155,7 @@ func TestNewAccessRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(ErrServerError) }, handlers: TokenEndpointHandlers{handler}, @@ -170,7 +172,7 @@ func TestNewAccessRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, handlers: TokenEndpointHandlers{handler}, @@ -369,6 +371,8 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) { hasher := internal.NewMockHasher(ctrl) defer ctrl.Finish() + ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil)) + client := &DefaultClient{} fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} for k, c := range []struct { @@ -391,7 +395,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err")) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err")) handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, method: "POST", @@ -409,7 +413,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) handlerWithClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, diff --git a/access_response_writer.go b/access_response_writer.go index 0547266d2..0b3f302d4 100644 --- a/access_response_writer.go +++ b/access_response_writer.go @@ -34,6 +34,10 @@ func (f *Fosite) NewAccessResponse(ctx context.Context, requester AccessRequeste var tk TokenEndpointHandler response := NewAccessResponse() + + ctx = context.WithValue(ctx, AccessRequestContextKey, requester) + ctx = context.WithValue(ctx, AccessResponseContextKey, response) + for _, tk = range f.TokenEndpointHandlers { if err = tk.PopulateTokenEndpointResponse(ctx, requester, response); err == nil { // do nothing diff --git a/access_response_writer_test.go b/access_response_writer_test.go index a3ecf5e86..54c719a83 100644 --- a/access_response_writer_test.go +++ b/access_response_writer_test.go @@ -92,7 +92,7 @@ func TestNewAccessResponse(t *testing.T) { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { f.TokenEndpointHandlers = c.handlers c.mock() - ar, err := f.NewAccessResponse(nil, nil) + ar, err := f.NewAccessResponse(context.TODO(), nil) if c.expectErr != nil { assert.EqualError(t, err, c.expectErr.Error()) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 271a644d6..de16356aa 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -275,6 +275,9 @@ func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (AuthorizeRequester, error) { request := NewAuthorizeRequest() + ctx = context.WithValue(ctx, RequestContextKey, r) + ctx = context.WithValue(ctx, AuthorizeRequestContextKey, request) + if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart { return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error())) } diff --git a/authorize_response_writer.go b/authorize_response_writer.go index 4422c63aa..7c56d5576 100644 --- a/authorize_response_writer.go +++ b/authorize_response_writer.go @@ -35,6 +35,9 @@ func (f *Fosite) NewAuthorizeResponse(ctx context.Context, ar AuthorizeRequester Parameters: url.Values{}, } + ctx = context.WithValue(ctx, AuthorizeRequestContextKey, ar) + ctx = context.WithValue(ctx, AuthorizeResponseContextKey, resp) + ar.SetSession(session) for _, h := range f.AuthorizeEndpointHandlers { if err := h.HandleAuthorizeEndpointRequest(ctx, ar, resp); err != nil { diff --git a/context.go b/context.go index 90f9aa9b4..48558e9a3 100644 --- a/context.go +++ b/context.go @@ -26,3 +26,13 @@ import "context" func NewContext() context.Context { return context.Background() } + +type ContextKey string + +const ( + RequestContextKey = ContextKey("request") + AccessRequestContextKey = ContextKey("accessRequest") + AccessResponseContextKey = ContextKey("accessResponse") + AuthorizeRequestContextKey = ContextKey("authorizeRequest") + AuthorizeResponseContextKey = ContextKey("authorizeResponse") +) diff --git a/introspection_request_handler.go b/introspection_request_handler.go index f2f1b28b0..893ba21fe 100644 --- a/introspection_request_handler.go +++ b/introspection_request_handler.go @@ -110,6 +110,8 @@ import ( // // token=mF_9.B5f-4.1JqM&token_type_hint=access_token func (f *Fosite) NewIntrospectionRequest(ctx context.Context, r *http.Request, session Session) (IntrospectionResponder, error) { + ctx = context.WithValue(ctx, RequestContextKey, r) + if r.Method != "POST" { return &IntrospectionResponse{Active: false}, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s' but expected 'POST'.", r.Method)) } else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart { diff --git a/introspection_request_handler_test.go b/introspection_request_handler_test.go index 955808183..31d7c0339 100644 --- a/introspection_request_handler_test.go +++ b/introspection_request_handler_test.go @@ -45,6 +45,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) { validator := internal.NewMockTokenIntrospector(ctrl) defer ctrl.Finish() + ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil)) + f := compose.ComposeAllEnabled(new(compose.Config), storage.NewExampleStore(), []byte{}, nil).(*Fosite) httpreq := &http.Request{ Method: "POST", @@ -65,8 +67,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) { description: "introspecting access token", setup: func() { f.TokenIntrospectionHandlers = TokenIntrospectionHandlers{validator} - validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) - validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(AccessToken, nil) + validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) + validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(AccessToken, nil) }, expectedATT: BearerAccessToken, expectedTU: AccessToken, @@ -75,8 +77,8 @@ func TestIntrospectionResponseTokenUse(t *testing.T) { description: "introspecting refresh token", setup: func() { f.TokenIntrospectionHandlers = TokenIntrospectionHandlers{validator} - validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) - validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(RefreshToken, nil) + validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) + validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(RefreshToken, nil) }, expectedATT: "", expectedTU: RefreshToken, @@ -106,6 +108,8 @@ func TestNewIntrospectionRequest(t *testing.T) { validator := internal.NewMockTokenIntrospector(ctrl) defer ctrl.Finish() + ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil)) + f := compose.ComposeAllEnabled(new(compose.Config), storage.NewExampleStore(), []byte{}, nil).(*Fosite) httpreq := &http.Request{ Method: "POST", @@ -139,8 +143,8 @@ func TestNewIntrospectionRequest(t *testing.T) { "token": []string{"introspect-token"}, }, } - validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) - validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), newErr) + validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) + validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), newErr) }, isActive: false, expectErr: ErrInactiveToken, @@ -158,8 +162,8 @@ func TestNewIntrospectionRequest(t *testing.T) { "token": []string{"introspect-token"}, }, } - validator.EXPECT().IntrospectToken(context.TODO(), "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) - validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) + validator.EXPECT().IntrospectToken(ctx, "some-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) + validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) }, isActive: true, }, @@ -177,7 +181,7 @@ func TestNewIntrospectionRequest(t *testing.T) { "token": []string{"introspect-token"}, }, } - validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) + validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) }, isActive: true, }, @@ -195,7 +199,7 @@ func TestNewIntrospectionRequest(t *testing.T) { "token": []string{"introspect-token"}, }, } - validator.EXPECT().IntrospectToken(context.TODO(), "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) + validator.EXPECT().IntrospectToken(ctx, "introspect-token", gomock.Any(), gomock.Any(), gomock.Any()).Return(TokenUse(""), nil) }, isActive: true, }, diff --git a/revoke_handler.go b/revoke_handler.go index 66cab3dda..d10d8effd 100644 --- a/revoke_handler.go +++ b/revoke_handler.go @@ -50,6 +50,8 @@ import ( // An invalid token type hint value is ignored by the authorization // server and does not influence the revocation response. func (f *Fosite) NewRevocationRequest(ctx context.Context, r *http.Request) error { + ctx = context.WithValue(ctx, RequestContextKey, r) + if r.Method != "POST" { return errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s' but expected 'POST'.", r.Method)) } else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart { diff --git a/revoke_handler_test.go b/revoke_handler_test.go index 6e76f37f0..81ba59bc9 100644 --- a/revoke_handler_test.go +++ b/revoke_handler_test.go @@ -44,6 +44,8 @@ func TestNewRevocationRequest(t *testing.T) { hasher := internal.NewMockHasher(ctrl) defer ctrl.Finish() + ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil)) + client := &DefaultClient{} fosite := &Fosite{Store: store, Hasher: hasher} for k, c := range []struct { @@ -102,7 +104,7 @@ func TestNewRevocationRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) }, }, { @@ -118,7 +120,7 @@ func TestNewRevocationRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, handlers: RevocationHandlers{handler}, @@ -137,7 +139,7 @@ func TestNewRevocationRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, handlers: RevocationHandlers{handler}, @@ -173,7 +175,7 @@ func TestNewRevocationRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, handlers: RevocationHandlers{handler}, @@ -192,7 +194,7 @@ func TestNewRevocationRequest(t *testing.T) { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Secret = []byte("foo") client.Public = false - hasher.EXPECT().Compare(context.TODO(), gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, handlers: RevocationHandlers{handler},