diff --git a/client/handler.go b/client/handler.go index d59e95a54b3..b6771e1f6fd 100644 --- a/client/handler.go +++ b/client/handler.go @@ -107,7 +107,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request, _ httprouter.Pa } secret := c.Secret - if err := h.Manager.CreateClient(&c); err != nil { + if err := h.Manager.CreateClient(r.Context(), &c); err != nil { h.H.WriteError(w, r, err) return } @@ -159,7 +159,7 @@ func (h *Handler) Update(w http.ResponseWriter, r *http.Request, ps httprouter.P return } - if err := h.Manager.UpdateClient(&c); err != nil { + if err := h.Manager.UpdateClient(r.Context(), &c); err != nil { h.H.WriteError(w, r, err) return } @@ -191,7 +191,7 @@ func (h *Handler) Update(w http.ResponseWriter, r *http.Request, ps httprouter.P // 500: genericError func (h *Handler) List(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { limit, offset := pagination.Parse(r, 100, 0, 500) - c, err := h.Manager.GetClients(limit, offset) + c, err := h.Manager.GetClients(r.Context(), limit, offset) if err != nil { h.H.WriteError(w, r, err) return @@ -233,7 +233,7 @@ func (h *Handler) List(w http.ResponseWriter, r *http.Request, ps httprouter.Par func (h *Handler) Get(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { var id = ps.ByName("id") - c, err := h.Manager.GetConcreteClient(id) + c, err := h.Manager.GetConcreteClient(r.Context(), id) if err != nil { h.H.WriteError(w, r, err) return @@ -267,7 +267,7 @@ func (h *Handler) Get(w http.ResponseWriter, r *http.Request, ps httprouter.Para func (h *Handler) Delete(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { var id = ps.ByName("id") - if err := h.Manager.DeleteClient(id); err != nil { + if err := h.Manager.DeleteClient(r.Context(), id); err != nil { h.H.WriteError(w, r, err) return } diff --git a/client/manager.go b/client/manager.go index d68074de3b9..4641ed49078 100644 --- a/client/manager.go +++ b/client/manager.go @@ -21,25 +21,27 @@ package client import ( + "context" + "github.com/ory/fosite" ) type Manager interface { Storage - Authenticate(id string, secret []byte) (*Client, error) + Authenticate(ctx context.Context, id string, secret []byte) (*Client, error) } type Storage interface { fosite.Storage - CreateClient(c *Client) error + CreateClient(ctx context.Context, c *Client) error - UpdateClient(c *Client) error + UpdateClient(ctx context.Context, c *Client) error - DeleteClient(id string) error + DeleteClient(ctx context.Context, id string) error - GetClients(limit, offset int) (map[string]Client, error) + GetClients(ctx context.Context, limit, offset int) (map[string]Client, error) - GetConcreteClient(id string) (*Client, error) + GetConcreteClient(ctx context.Context, id string) (*Client, error) } diff --git a/client/manager_0_sql_migrations_test.go b/client/manager_0_sql_migrations_test.go index f4a1a85e463..fb9214f48e4 100644 --- a/client/manager_0_sql_migrations_test.go +++ b/client/manager_0_sql_migrations_test.go @@ -26,6 +26,8 @@ import ( "sync" "testing" + "context" + "github.com/jmoiron/sqlx" "github.com/ory/fosite" "github.com/ory/hydra/client" @@ -197,7 +199,7 @@ func TestMigrations(t *testing.T) { for _, key := range []string{"1-data", "2-data", "3-data", "4-data", "5-data"} { t.Run("client="+key, func(t *testing.T) { s := &client.SQLManager{DB: db, Hasher: &fosite.BCrypt{WorkFactor: 4}} - c, err := s.GetConcreteClient(key) + c, err := s.GetConcreteClient(context.TODO(), key) require.NoError(t, err) assert.EqualValues(t, c.GetID(), key) }) diff --git a/client/manager_memory.go b/client/manager_memory.go index 64b433a0d9f..75d0797425e 100644 --- a/client/manager_memory.go +++ b/client/manager_memory.go @@ -48,7 +48,7 @@ func NewMemoryManager(hasher fosite.Hasher) *MemoryManager { } } -func (m *MemoryManager) GetConcreteClient(id string) (*Client, error) { +func (m *MemoryManager) GetConcreteClient(ctx context.Context, id string) (*Client, error) { m.RLock() defer m.RUnlock() @@ -61,12 +61,12 @@ func (m *MemoryManager) GetConcreteClient(id string) (*Client, error) { return nil, errors.WithStack(sqlcon.ErrNoRows) } -func (m *MemoryManager) GetClient(_ context.Context, id string) (fosite.Client, error) { - return m.GetConcreteClient(id) +func (m *MemoryManager) GetClient(ctx context.Context, id string) (fosite.Client, error) { + return m.GetConcreteClient(ctx, id) } -func (m *MemoryManager) UpdateClient(c *Client) error { - o, err := m.GetClient(context.Background(), c.GetID()) +func (m *MemoryManager) UpdateClient(ctx context.Context, c *Client) error { + o, err := m.GetClient(ctx, c.GetID()) if err != nil { return err } @@ -95,11 +95,11 @@ func (m *MemoryManager) UpdateClient(c *Client) error { return nil } -func (m *MemoryManager) Authenticate(id string, secret []byte) (*Client, error) { +func (m *MemoryManager) Authenticate(ctx context.Context, id string, secret []byte) (*Client, error) { m.RLock() defer m.RUnlock() - c, err := m.GetConcreteClient(id) + c, err := m.GetConcreteClient(ctx, id) if err != nil { return nil, err } @@ -111,8 +111,8 @@ func (m *MemoryManager) Authenticate(id string, secret []byte) (*Client, error) return c, nil } -func (m *MemoryManager) CreateClient(c *Client) error { - if _, err := m.GetConcreteClient(c.GetID()); err == nil { +func (m *MemoryManager) CreateClient(ctx context.Context, c *Client) error { + if _, err := m.GetConcreteClient(ctx, c.GetID()); err == nil { return sqlcon.ErrUniqueViolation } @@ -129,7 +129,7 @@ func (m *MemoryManager) CreateClient(c *Client) error { return nil } -func (m *MemoryManager) DeleteClient(id string) error { +func (m *MemoryManager) DeleteClient(ctx context.Context, id string) error { m.Lock() defer m.Unlock() @@ -143,7 +143,7 @@ func (m *MemoryManager) DeleteClient(id string) error { return nil } -func (m *MemoryManager) GetClients(limit, offset int) (clients map[string]Client, err error) { +func (m *MemoryManager) GetClients(ctx context.Context, limit, offset int) (clients map[string]Client, err error) { m.RLock() defer m.RUnlock() clients = make(map[string]Client) diff --git a/client/manager_sql.go b/client/manager_sql.go index 6f581d5a173..20b874c3503 100644 --- a/client/manager_sql.go +++ b/client/manager_sql.go @@ -336,7 +336,7 @@ func (m *SQLManager) CreateSchemas() (int, error) { return n, nil } -func (m *SQLManager) GetConcreteClient(id string) (*Client, error) { +func (m *SQLManager) GetConcreteClient(ctx context.Context, id string) (*Client, error) { var d sqlData if err := m.DB.Get(&d, m.DB.Rebind("SELECT * FROM hydra_client WHERE id=?"), id); err != nil { return nil, sqlcon.HandleError(err) @@ -345,11 +345,11 @@ func (m *SQLManager) GetConcreteClient(id string) (*Client, error) { return d.ToClient() } -func (m *SQLManager) GetClient(_ context.Context, id string) (fosite.Client, error) { - return m.GetConcreteClient(id) +func (m *SQLManager) GetClient(ctx context.Context, id string) (fosite.Client, error) { + return m.GetConcreteClient(ctx, id) } -func (m *SQLManager) UpdateClient(c *Client) error { +func (m *SQLManager) UpdateClient(ctx context.Context, c *Client) error { o, err := m.GetClient(context.Background(), c.GetID()) if err != nil { return errors.WithStack(err) @@ -381,8 +381,8 @@ func (m *SQLManager) UpdateClient(c *Client) error { return nil } -func (m *SQLManager) Authenticate(id string, secret []byte) (*Client, error) { - c, err := m.GetConcreteClient(id) +func (m *SQLManager) Authenticate(ctx context.Context, id string, secret []byte) (*Client, error) { + c, err := m.GetConcreteClient(ctx, id) if err != nil { return nil, errors.WithStack(err) } @@ -394,7 +394,7 @@ func (m *SQLManager) Authenticate(id string, secret []byte) (*Client, error) { return c, nil } -func (m *SQLManager) CreateClient(c *Client) error { +func (m *SQLManager) CreateClient(ctx context.Context, c *Client) error { h, err := m.Hasher.Hash([]byte(c.Secret)) if err != nil { return errors.WithStack(err) @@ -417,14 +417,14 @@ func (m *SQLManager) CreateClient(c *Client) error { return nil } -func (m *SQLManager) DeleteClient(id string) error { +func (m *SQLManager) DeleteClient(ctx context.Context, id string) error { if _, err := m.DB.Exec(m.DB.Rebind(`DELETE FROM hydra_client WHERE id=?`), id); err != nil { return sqlcon.HandleError(err) } return nil } -func (m *SQLManager) GetClients(limit, offset int) (clients map[string]Client, err error) { +func (m *SQLManager) GetClients(ctx context.Context, limit, offset int) (clients map[string]Client, err error) { d := make([]sqlData, 0) clients = make(map[string]Client) diff --git a/client/manager_test_helpers.go b/client/manager_test_helpers.go index ec7b407d518..b59acf3b3b0 100644 --- a/client/manager_test_helpers.go +++ b/client/manager_test_helpers.go @@ -24,6 +24,8 @@ import ( "crypto/x509" "testing" + "context" + "github.com/ory/fosite" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,6 +34,7 @@ import ( func TestHelperClientAutoGenerateKey(k string, m Storage) func(t *testing.T) { return func(t *testing.T) { + ctx := context.TODO() t.Parallel() c := &Client{ ClientID: "foo", @@ -39,25 +42,26 @@ func TestHelperClientAutoGenerateKey(k string, m Storage) func(t *testing.T) { RedirectURIs: []string{"http://redirect"}, TermsOfServiceURI: "foo", } - assert.NoError(t, m.CreateClient(c)) + assert.NoError(t, m.CreateClient(ctx, c)) //assert.NotEmpty(t, c.ID) - assert.NoError(t, m.DeleteClient(c.GetID())) + assert.NoError(t, m.DeleteClient(ctx, c.GetID())) } } func TestHelperClientAuthenticate(k string, m Manager) func(t *testing.T) { return func(t *testing.T) { + ctx := context.TODO() t.Parallel() - m.CreateClient(&Client{ + m.CreateClient(ctx, &Client{ ClientID: "1234321", Secret: "secret", RedirectURIs: []string{"http://redirect"}, }) - c, err := m.Authenticate("1234321", []byte("secret1")) + c, err := m.Authenticate(ctx, "1234321", []byte("secret1")) require.NotNil(t, err) - c, err = m.Authenticate("1234321", []byte("secret")) + c, err = m.Authenticate(ctx, "1234321", []byte("secret")) require.NoError(t, err) assert.Equal(t, "1234321", c.GetID()) } @@ -69,6 +73,8 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { _, err := m.GetClient(nil, "4321") assert.NotNil(t, err) + ctx := context.TODO() + c := &Client{ ClientID: "1234", Name: "name", @@ -94,13 +100,13 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { UserinfoSignedResponseAlg: "RS256", } - assert.NoError(t, m.CreateClient(c)) + assert.NoError(t, m.CreateClient(ctx, c)) assert.Equal(t, c.GetID(), "1234") if k != "http" { assert.NotEmpty(t, c.GetHashedSecret()) } - assert.NoError(t, m.CreateClient(&Client{ + assert.NoError(t, m.CreateClient(ctx, &Client{ ClientID: "2-1234", Name: "name", Secret: "secret", @@ -114,7 +120,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { compare(t, c, d, k) - ds, err := m.GetClients(100, 0) + ds, err := m.GetClients(ctx, 100, 0) assert.NoError(t, err) assert.Len(t, ds, 2) assert.NotEqual(t, ds["1234"].ClientID, ds["2-1234"].ClientID) @@ -124,15 +130,15 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { assert.Equal(t, ds["1234"].SecretExpiresAt, 0) assert.Equal(t, ds["2-1234"].SecretExpiresAt, 1) - ds, err = m.GetClients(1, 0) + ds, err = m.GetClients(ctx, 1, 0) assert.NoError(t, err) assert.Len(t, ds, 1) - ds, err = m.GetClients(100, 100) + ds, err = m.GetClients(ctx, 100, 100) assert.NoError(t, err) assert.Len(t, ds, 0) - err = m.UpdateClient(&Client{ + err = m.UpdateClient(ctx, &Client{ ClientID: "2-1234", Name: "name-new", Secret: "secret-new", @@ -141,7 +147,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { }) require.NoError(t, err) - nc, err := m.GetConcreteClient("2-1234") + nc, err := m.GetConcreteClient(ctx, "2-1234") require.NoError(t, err) if k != "http" { @@ -153,7 +159,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { assert.EqualValues(t, []string{"http://redirect/new"}, nc.GetRedirectURIs()) assert.Zero(t, len(nc.Contacts)) - err = m.DeleteClient("1234") + err = m.DeleteClient(ctx, "1234") assert.NoError(t, err) _, err = m.GetClient(nil, "1234") diff --git a/cmd/server/helper_cert.go b/cmd/server/helper_cert.go index 0a3c70e8894..90a6f7518c2 100644 --- a/cmd/server/helper_cert.go +++ b/cmd/server/helper_cert.go @@ -30,6 +30,8 @@ import ( "strings" "time" + "context" + "github.com/ory/hydra/config" "github.com/ory/hydra/jwk" "github.com/ory/hydra/pkg" @@ -105,10 +107,10 @@ func getOrCreateTLSCertificate(cmd *cobra.Command, c *config.Config) tls.Certifi } privateKey.Certificates = []*x509.Certificate{cert} - if err := ctx.KeyManager.DeleteKey(tlsKeyName, privateKey.KeyID); err != nil { + if err := ctx.KeyManager.DeleteKey(context.TODO(), tlsKeyName, privateKey.KeyID); err != nil { c.GetLogger().WithError(err).Fatalf(`Could not update (delete) the self signed TLS certificate.`) } - if err := ctx.KeyManager.AddKey(tlsKeyName, privateKey); err != nil { + if err := ctx.KeyManager.AddKey(context.TODO(), tlsKeyName, privateKey); err != nil { c.GetLogger().WithError(err).Fatalf(`Could not update (add) the self signed TLS certificate.`) } } diff --git a/cmd/server/helper_cors.go b/cmd/server/helper_cors.go index 3a042ee72e3..20b14301494 100644 --- a/cmd/server/helper_cors.go +++ b/cmd/server/helper_cors.go @@ -36,7 +36,7 @@ import ( func newCORSMiddleware( enable bool, c *config.Config, o func(ctx context.Context, token string, tokenType fosite.TokenType, session fosite.Session, scope ...string) (fosite.TokenType, fosite.AccessRequester, error), - clm func(id string) (*client.Client, error), + clm func(ctx context.Context, id string) (*client.Client, error), ) func(h http.Handler) http.Handler { if !enable { return func(h http.Handler) http.Handler { @@ -76,7 +76,7 @@ func newCORSMiddleware( username = ar.GetClient().GetID() } - cl, err := clm(username) + cl, err := clm(r.Context(), username) if err != nil { return false } diff --git a/cmd/server/helper_cors_test.go b/cmd/server/helper_cors_test.go index 38a3e757e1f..c6aa61d4cb2 100644 --- a/cmd/server/helper_cors_test.go +++ b/cmd/server/helper_cors_test.go @@ -59,7 +59,7 @@ func TestCORSMiddleware(t *testing.T) { }, { d: "should reject when basic auth but client does not exist", - mw: newCORSMiddleware(true, c, nil, func(id string) (*client.Client, error) { + mw: newCORSMiddleware(true, c, nil, func(ctx context.Context, id string) (*client.Client, error) { return nil, errors.New("") }), code: 204, @@ -68,7 +68,7 @@ func TestCORSMiddleware(t *testing.T) { }, { d: "should reject when basic auth client exists but origin not allowed", - mw: newCORSMiddleware(true, c, nil, func(id string) (*client.Client, error) { + mw: newCORSMiddleware(true, c, nil, func(ctx context.Context, id string) (*client.Client, error) { return &client.Client{AllowedCORSOrigins: []string{"http://not-foobar.com"}}, nil }), code: 204, @@ -77,7 +77,7 @@ func TestCORSMiddleware(t *testing.T) { }, { d: "should accept when basic auth client exists and origin allowed", - mw: newCORSMiddleware(true, c, nil, func(id string) (*client.Client, error) { + mw: newCORSMiddleware(true, c, nil, func(ctx context.Context, id string) (*client.Client, error) { return &client.Client{AllowedCORSOrigins: []string{"http://foobar.com"}}, nil }), code: 204, @@ -88,7 +88,7 @@ func TestCORSMiddleware(t *testing.T) { d: "should fail when token introspection fails", mw: newCORSMiddleware(true, c, func(ctx context.Context, token string, tokenType fosite.TokenType, session fosite.Session, scope ...string) (fosite.TokenType, fosite.AccessRequester, error) { return "", nil, errors.New("") - }, func(id string) (*client.Client, error) { + }, func(ctx context.Context, id string) (*client.Client, error) { return &client.Client{AllowedCORSOrigins: []string{"http://foobar.com"}}, nil }), code: 204, @@ -102,7 +102,7 @@ func TestCORSMiddleware(t *testing.T) { return "", nil, errors.New("") } return "", &fosite.AccessRequest{Request: fosite.Request{Client: &client.Client{ClientID: "asdf"}}}, nil - }, func(id string) (*client.Client, error) { + }, func(ctx context.Context, id string) (*client.Client, error) { if id != "asdf" { return nil, errors.New("") } diff --git a/cmd/server/helper_keys.go b/cmd/server/helper_keys.go index 81809e8ea4e..e3fec28ec23 100644 --- a/cmd/server/helper_keys.go +++ b/cmd/server/helper_keys.go @@ -24,6 +24,8 @@ import ( "crypto/ecdsa" "crypto/rsa" + "context" + "github.com/ory/hydra/config" "github.com/ory/hydra/jwk" "github.com/ory/hydra/pkg" @@ -36,7 +38,7 @@ func createOrGetJWK(c *config.Config, set string, kid string, prefix string) (ke expectDependency(c.GetLogger(), ctx.KeyManager) - keys, err := ctx.KeyManager.GetKeySet(set) + keys, err := ctx.KeyManager.GetKeySet(context.TODO(), set) if errors.Cause(err) == pkg.ErrNotFound || keys != nil && len(keys.Keys) == 0 { c.GetLogger().Infof("JSON Web Key Set %s does not exist yet, generating new key pair...", set) keys, err = createJWKS(ctx, set, kid) @@ -77,7 +79,7 @@ func createJWKS(ctx *config.Context, set, kid string) (*jose.JSONWebKeySet, erro keys.Keys[i] = k } - err = ctx.KeyManager.AddKeySet(set, keys) + err = ctx.KeyManager.AddKeySet(context.TODO(), set, keys) if err != nil { return nil, errors.Wrapf(err, "Could not persist %s key", set) } diff --git a/consent/handler.go b/consent/handler.go index 0ee4cccd6ee..471e3303969 100644 --- a/consent/handler.go +++ b/consent/handler.go @@ -101,7 +101,7 @@ func (h *Handler) SetRoutes(frontend, backend *httprouter.Router) { // 500: genericError func (h *Handler) DeleteUserConsentSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { user := ps.ByName("user") - if err := h.M.RevokeUserConsentSession(user); err != nil { + if err := h.M.RevokeUserConsentSession(r.Context(), user); err != nil { h.H.WriteError(w, r, err) return } @@ -137,7 +137,7 @@ func (h *Handler) DeleteUserClientConsentSession(w http.ResponseWriter, r *http. return } - if err := h.M.RevokeUserClientConsentSession(user, client); err != nil { + if err := h.M.RevokeUserClientConsentSession(r.Context(), user, client); err != nil { h.H.WriteError(w, r, err) return } @@ -172,7 +172,7 @@ func (h *Handler) GetConsentSessions(w http.ResponseWriter, r *http.Request, ps } limit, offset := pagination.Parse(r, 100, 0, 500) - sessions, err := h.M.FindPreviouslyGrantedConsentRequestsByUser(user, limit, offset) + sessions, err := h.M.FindPreviouslyGrantedConsentRequestsByUser(r.Context(), user, limit, offset) if errors.Cause(err) == ErrNoPreviousConsentFound { h.H.Write(w, r, []PreviousConsentSession{}) return @@ -218,7 +218,7 @@ func (h *Handler) GetConsentSessions(w http.ResponseWriter, r *http.Request, ps func (h *Handler) DeleteLoginSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { user := ps.ByName("user") - if err := h.M.RevokeUserAuthenticationSession(user); err != nil { + if err := h.M.RevokeUserAuthenticationSession(r.Context(), user); err != nil { h.H.WriteError(w, r, err) return } @@ -252,7 +252,7 @@ func (h *Handler) DeleteLoginSession(w http.ResponseWriter, r *http.Request, ps // 401: genericError // 500: genericError func (h *Handler) GetLoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - request, err := h.M.GetAuthenticationRequest(ps.ByName("challenge")) + request, err := h.M.GetAuthenticationRequest(r.Context(), ps.ByName("challenge")) if err != nil { h.H.WriteError(w, r, err) return @@ -303,7 +303,7 @@ func (h *Handler) AcceptLoginRequest(w http.ResponseWriter, r *http.Request, ps } p.Challenge = ps.ByName("challenge") - ar, err := h.M.GetAuthenticationRequest(ps.ByName("challenge")) + ar, err := h.M.GetAuthenticationRequest(r.Context(), ps.ByName("challenge")) if err != nil { h.H.WriteError(w, r, err) return @@ -322,7 +322,7 @@ func (h *Handler) AcceptLoginRequest(w http.ResponseWriter, r *http.Request, ps } p.RequestedAt = ar.RequestedAt - request, err := h.M.HandleAuthenticationRequest(ps.ByName("challenge"), &p) + request, err := h.M.HandleAuthenticationRequest(r.Context(), ps.ByName("challenge"), &p) if err != nil { h.H.WriteError(w, r, errors.WithStack(err)) return @@ -377,13 +377,13 @@ func (h *Handler) RejectLoginRequest(w http.ResponseWriter, r *http.Request, ps return } - ar, err := h.M.GetAuthenticationRequest(ps.ByName("challenge")) + ar, err := h.M.GetAuthenticationRequest(r.Context(), ps.ByName("challenge")) if err != nil { h.H.WriteError(w, r, err) return } - request, err := h.M.HandleAuthenticationRequest(ps.ByName("challenge"), &HandledAuthenticationRequest{ + request, err := h.M.HandleAuthenticationRequest(r.Context(), ps.ByName("challenge"), &HandledAuthenticationRequest{ Error: &p, Challenge: ps.ByName("challenge"), RequestedAt: ar.RequestedAt, @@ -432,7 +432,7 @@ func (h *Handler) RejectLoginRequest(w http.ResponseWriter, r *http.Request, ps // 401: genericError // 500: genericError func (h *Handler) GetConsentRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - request, err := h.M.GetConsentRequest(ps.ByName("challenge")) + request, err := h.M.GetConsentRequest(r.Context(), ps.ByName("challenge")) if err != nil { h.H.WriteError(w, r, err) return @@ -485,7 +485,7 @@ func (h *Handler) AcceptConsentRequest(w http.ResponseWriter, r *http.Request, p return } - cr, err := h.M.GetConsentRequest(ps.ByName("challenge")) + cr, err := h.M.GetConsentRequest(r.Context(), ps.ByName("challenge")) if err != nil { h.H.WriteError(w, r, errors.WithStack(err)) return @@ -494,7 +494,7 @@ func (h *Handler) AcceptConsentRequest(w http.ResponseWriter, r *http.Request, p p.Challenge = ps.ByName("challenge") p.RequestedAt = cr.RequestedAt - hr, err := h.M.HandleConsentRequest(ps.ByName("challenge"), &p) + hr, err := h.M.HandleConsentRequest(r.Context(), ps.ByName("challenge"), &p) if err != nil { h.H.WriteError(w, r, errors.WithStack(err)) return @@ -555,13 +555,13 @@ func (h *Handler) RejectConsentRequest(w http.ResponseWriter, r *http.Request, p return } - hr, err := h.M.GetConsentRequest(ps.ByName("challenge")) + hr, err := h.M.GetConsentRequest(r.Context(), ps.ByName("challenge")) if err != nil { h.H.WriteError(w, r, errors.WithStack(err)) return } - request, err := h.M.HandleConsentRequest(ps.ByName("challenge"), &HandledConsentRequest{ + request, err := h.M.HandleConsentRequest(r.Context(), ps.ByName("challenge"), &HandledConsentRequest{ Error: &p, Challenge: ps.ByName("challenge"), RequestedAt: hr.RequestedAt, @@ -607,7 +607,7 @@ func (h *Handler) LogoutUser(w http.ResponseWriter, r *http.Request, ps httprout } if sid != "" { - if err := h.M.DeleteAuthenticationSession(sid); err != nil { + if err := h.M.DeleteAuthenticationSession(r.Context(), sid); err != nil { h.H.WriteError(w, r, err) return } diff --git a/consent/handler_test.go b/consent/handler_test.go index 6490ce348c0..63dd753e33a 100644 --- a/consent/handler_test.go +++ b/consent/handler_test.go @@ -28,6 +28,8 @@ import ( "testing" "time" + "context" + "github.com/gorilla/sessions" "github.com/julienschmidt/httprouter" "github.com/ory/herodot" @@ -50,7 +52,7 @@ func TestLogout(t *testing.T) { r.Handle("GET", "/login", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { cookie, _ := cs.Get(r, cookieAuthenticationName) - require.NoError(t, h.M.CreateAuthenticationSession(&AuthenticationSession{ + require.NoError(t, h.M.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{ ID: sid, Subject: "foo", AuthenticatedAt: time.Now(), diff --git a/consent/manager.go b/consent/manager.go index 4d0dc0fb4e5..ea143a4ba22 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -20,6 +20,8 @@ package consent +import "context" + type ForcedObfuscatedAuthenticationSession struct { ClientID string `db:"client_id"` Subject string `db:"subject"` @@ -27,27 +29,27 @@ type ForcedObfuscatedAuthenticationSession struct { } type Manager interface { - CreateConsentRequest(*ConsentRequest) error - GetConsentRequest(challenge string) (*ConsentRequest, error) - HandleConsentRequest(challenge string, r *HandledConsentRequest) (*ConsentRequest, error) - RevokeUserConsentSession(user string) error - RevokeUserClientConsentSession(user, client string) error + CreateConsentRequest(ctx context.Context, req *ConsentRequest) error + GetConsentRequest(ctx context.Context, challenge string) (*ConsentRequest, error) + HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error) + RevokeUserConsentSession(ctx context.Context, user string) error + RevokeUserClientConsentSession(ctx context.Context, user, client string) error - VerifyAndInvalidateConsentRequest(verifier string) (*HandledConsentRequest, error) - FindPreviouslyGrantedConsentRequests(client string, user string) ([]HandledConsentRequest, error) - FindPreviouslyGrantedConsentRequestsByUser(user string, limit, offset int) ([]HandledConsentRequest, error) + VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error) + FindPreviouslyGrantedConsentRequests(ctx context.Context, client string, user string) ([]HandledConsentRequest, error) + FindPreviouslyGrantedConsentRequestsByUser(ctx context.Context, user string, limit, offset int) ([]HandledConsentRequest, error) // Cookie management - GetAuthenticationSession(id string) (*AuthenticationSession, error) - CreateAuthenticationSession(*AuthenticationSession) error - DeleteAuthenticationSession(id string) error - RevokeUserAuthenticationSession(user string) error - - CreateAuthenticationRequest(*AuthenticationRequest) error - GetAuthenticationRequest(challenge string) (*AuthenticationRequest, error) - HandleAuthenticationRequest(challenge string, r *HandledAuthenticationRequest) (*AuthenticationRequest, error) - VerifyAndInvalidateAuthenticationRequest(verifier string) (*HandledAuthenticationRequest, error) - - CreateForcedObfuscatedAuthenticationSession(*ForcedObfuscatedAuthenticationSession) error - GetForcedObfuscatedAuthenticationSession(client, obfuscated string) (*ForcedObfuscatedAuthenticationSession, error) + GetAuthenticationSession(ctx context.Context, id string) (*AuthenticationSession, error) + CreateAuthenticationSession(ctx context.Context, session *AuthenticationSession) error + DeleteAuthenticationSession(ctx context.Context, id string) error + RevokeUserAuthenticationSession(ctx context.Context, user string) error + + CreateAuthenticationRequest(ctx context.Context, req *AuthenticationRequest) error + GetAuthenticationRequest(ctx context.Context, challenge string) (*AuthenticationRequest, error) + HandleAuthenticationRequest(ctx context.Context, challenge string, r *HandledAuthenticationRequest) (*AuthenticationRequest, error) + VerifyAndInvalidateAuthenticationRequest(ctx context.Context, verifier string) (*HandledAuthenticationRequest, error) + + CreateForcedObfuscatedAuthenticationSession(ctx context.Context, session *ForcedObfuscatedAuthenticationSession) error + GetForcedObfuscatedAuthenticationSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedAuthenticationSession, error) } diff --git a/consent/manager_memory.go b/consent/manager_memory.go index f9f9c7eac2d..61f347087b7 100644 --- a/consent/manager_memory.go +++ b/consent/manager_memory.go @@ -24,6 +24,8 @@ import ( "sync" "time" + "context" + "github.com/ory/fosite" "github.com/ory/hydra/pkg" "github.com/ory/pagination" @@ -60,7 +62,7 @@ func NewMemoryManager(store pkg.FositeStorer) *MemoryManager { } } -func (m *MemoryManager) CreateForcedObfuscatedAuthenticationSession(s *ForcedObfuscatedAuthenticationSession) error { +func (m *MemoryManager) CreateForcedObfuscatedAuthenticationSession(ctx context.Context, s *ForcedObfuscatedAuthenticationSession) error { for k, v := range m.pairwise { if v.Subject == s.Subject && v.ClientID == s.ClientID { m.pairwise[k] = *s @@ -72,7 +74,7 @@ func (m *MemoryManager) CreateForcedObfuscatedAuthenticationSession(s *ForcedObf return nil } -func (m *MemoryManager) GetForcedObfuscatedAuthenticationSession(client, obfuscated string) (*ForcedObfuscatedAuthenticationSession, error) { +func (m *MemoryManager) GetForcedObfuscatedAuthenticationSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedAuthenticationSession, error) { for _, v := range m.pairwise { if v.SubjectObfuscated == obfuscated && v.ClientID == client { return &v, nil @@ -82,11 +84,11 @@ func (m *MemoryManager) GetForcedObfuscatedAuthenticationSession(client, obfusca return nil, errors.WithStack(pkg.ErrNotFound) } -func (m *MemoryManager) RevokeUserConsentSession(user string) error { - return m.RevokeUserClientConsentSession(user, "") +func (m *MemoryManager) RevokeUserConsentSession(ctx context.Context, user string) error { + return m.RevokeUserClientConsentSession(ctx, user, "") } -func (m *MemoryManager) RevokeUserClientConsentSession(user, client string) error { +func (m *MemoryManager) RevokeUserClientConsentSession(ctx context.Context, user, client string) error { m.m["handledConsentRequests"].Lock() defer m.m["handledConsentRequests"].Unlock() m.m["consentRequests"].Lock() @@ -117,7 +119,7 @@ func (m *MemoryManager) RevokeUserClientConsentSession(user, client string) erro return nil } -func (m *MemoryManager) RevokeUserAuthenticationSession(user string) error { +func (m *MemoryManager) RevokeUserAuthenticationSession(ctx context.Context, user string) error { m.m["authSessions"].Lock() defer m.m["authSessions"].Unlock() @@ -135,7 +137,7 @@ func (m *MemoryManager) RevokeUserAuthenticationSession(user string) error { return nil } -func (m *MemoryManager) CreateConsentRequest(c *ConsentRequest) error { +func (m *MemoryManager) CreateConsentRequest(ctx context.Context, c *ConsentRequest) error { m.m["consentRequests"].Lock() defer m.m["consentRequests"].Unlock() if _, ok := m.consentRequests[c.Challenge]; ok { @@ -145,7 +147,7 @@ func (m *MemoryManager) CreateConsentRequest(c *ConsentRequest) error { return nil } -func (m *MemoryManager) GetConsentRequest(challenge string) (*ConsentRequest, error) { +func (m *MemoryManager) GetConsentRequest(ctx context.Context, challenge string) (*ConsentRequest, error) { m.m["consentRequests"].RLock() defer m.m["consentRequests"].RUnlock() if c, ok := m.consentRequests[challenge]; ok { @@ -155,14 +157,14 @@ func (m *MemoryManager) GetConsentRequest(challenge string) (*ConsentRequest, er return nil, errors.WithStack(pkg.ErrNotFound) } -func (m *MemoryManager) HandleConsentRequest(challenge string, r *HandledConsentRequest) (*ConsentRequest, error) { +func (m *MemoryManager) HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error) { m.m["handledConsentRequests"].Lock() m.handledConsentRequests[r.Challenge] = *r m.m["handledConsentRequests"].Unlock() - return m.GetConsentRequest(challenge) + return m.GetConsentRequest(ctx, challenge) } -func (m *MemoryManager) VerifyAndInvalidateConsentRequest(verifier string) (*HandledConsentRequest, error) { +func (m *MemoryManager) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error) { for _, c := range m.consentRequests { if c.Verifier == verifier { for _, h := range m.handledConsentRequests { @@ -172,7 +174,7 @@ func (m *MemoryManager) VerifyAndInvalidateConsentRequest(verifier string) (*Han } h.WasUsed = true - if _, err := m.HandleConsentRequest(h.Challenge, &h); err != nil { + if _, err := m.HandleConsentRequest(ctx, h.Challenge, &h); err != nil { return nil, err } @@ -186,9 +188,9 @@ func (m *MemoryManager) VerifyAndInvalidateConsentRequest(verifier string) (*Han return nil, errors.WithStack(pkg.ErrNotFound) } -func (m *MemoryManager) FindPreviouslyGrantedConsentRequests(client string, subject string) ([]HandledConsentRequest, error) { +func (m *MemoryManager) FindPreviouslyGrantedConsentRequests(ctx context.Context, client string, subject string) ([]HandledConsentRequest, error) { var rs []HandledConsentRequest - filteredByUser, err := m.FindPreviouslyGrantedConsentRequestsByUser(subject, -1, -1) + filteredByUser, err := m.FindPreviouslyGrantedConsentRequestsByUser(ctx, subject, -1, -1) if errors.Cause(err) == pkg.ErrNotFound { return nil, errors.WithStack(ErrNoPreviousConsentFound) } else if err != nil { @@ -208,10 +210,10 @@ func (m *MemoryManager) FindPreviouslyGrantedConsentRequests(client string, subj return rs, nil } -func (m *MemoryManager) FindPreviouslyGrantedConsentRequestsByUser(subject string, limit, offset int) ([]HandledConsentRequest, error) { +func (m *MemoryManager) FindPreviouslyGrantedConsentRequestsByUser(ctx context.Context, subject string, limit, offset int) ([]HandledConsentRequest, error) { var rs []HandledConsentRequest for _, c := range m.handledConsentRequests { - cr, err := m.GetConsentRequest(c.Challenge) + cr, err := m.GetConsentRequest(ctx, c.Challenge) if err != nil { return nil, err } @@ -253,7 +255,7 @@ func (m *MemoryManager) FindPreviouslyGrantedConsentRequestsByUser(subject strin return rs[start:end], nil } -func (m *MemoryManager) GetAuthenticationSession(id string) (*AuthenticationSession, error) { +func (m *MemoryManager) GetAuthenticationSession(ctx context.Context, id string) (*AuthenticationSession, error) { m.m["authSessions"].RLock() defer m.m["authSessions"].RUnlock() if c, ok := m.authSessions[id]; ok { @@ -262,7 +264,7 @@ func (m *MemoryManager) GetAuthenticationSession(id string) (*AuthenticationSess return nil, errors.WithStack(pkg.ErrNotFound) } -func (m *MemoryManager) CreateAuthenticationSession(a *AuthenticationSession) error { +func (m *MemoryManager) CreateAuthenticationSession(ctx context.Context, a *AuthenticationSession) error { m.m["authSessions"].Lock() defer m.m["authSessions"].Unlock() if _, ok := m.authSessions[a.ID]; ok { @@ -272,14 +274,14 @@ func (m *MemoryManager) CreateAuthenticationSession(a *AuthenticationSession) er return nil } -func (m *MemoryManager) DeleteAuthenticationSession(id string) error { +func (m *MemoryManager) DeleteAuthenticationSession(ctx context.Context, id string) error { m.m["authSessions"].Lock() defer m.m["authSessions"].Unlock() delete(m.authSessions, id) return nil } -func (m *MemoryManager) CreateAuthenticationRequest(a *AuthenticationRequest) error { +func (m *MemoryManager) CreateAuthenticationRequest(ctx context.Context, a *AuthenticationRequest) error { m.m["authRequests"].Lock() defer m.m["authRequests"].Unlock() if _, ok := m.authRequests[a.Challenge]; ok { @@ -289,7 +291,7 @@ func (m *MemoryManager) CreateAuthenticationRequest(a *AuthenticationRequest) er return nil } -func (m *MemoryManager) GetAuthenticationRequest(challenge string) (*AuthenticationRequest, error) { +func (m *MemoryManager) GetAuthenticationRequest(ctx context.Context, challenge string) (*AuthenticationRequest, error) { m.m["authRequests"].RLock() defer m.m["authRequests"].RUnlock() if c, ok := m.authRequests[challenge]; ok { @@ -299,14 +301,14 @@ func (m *MemoryManager) GetAuthenticationRequest(challenge string) (*Authenticat return nil, errors.WithStack(pkg.ErrNotFound) } -func (m *MemoryManager) HandleAuthenticationRequest(challenge string, r *HandledAuthenticationRequest) (*AuthenticationRequest, error) { +func (m *MemoryManager) HandleAuthenticationRequest(ctx context.Context, challenge string, r *HandledAuthenticationRequest) (*AuthenticationRequest, error) { m.m["handledAuthRequests"].Lock() m.handledAuthRequests[r.Challenge] = *r m.m["handledAuthRequests"].Unlock() - return m.GetAuthenticationRequest(challenge) + return m.GetAuthenticationRequest(ctx, challenge) } -func (m *MemoryManager) VerifyAndInvalidateAuthenticationRequest(verifier string) (*HandledAuthenticationRequest, error) { +func (m *MemoryManager) VerifyAndInvalidateAuthenticationRequest(ctx context.Context, verifier string) (*HandledAuthenticationRequest, error) { for _, c := range m.authRequests { if c.Verifier == verifier { for _, h := range m.handledAuthRequests { @@ -316,7 +318,7 @@ func (m *MemoryManager) VerifyAndInvalidateAuthenticationRequest(verifier string } h.WasUsed = true - if _, err := m.HandleAuthenticationRequest(h.Challenge, &h); err != nil { + if _, err := m.HandleAuthenticationRequest(ctx, h.Challenge, &h); err != nil { return nil, err } diff --git a/consent/manager_sql.go b/consent/manager_sql.go index ffe17921dd3..b0d39acf30d 100644 --- a/consent/manager_sql.go +++ b/consent/manager_sql.go @@ -26,6 +26,8 @@ import ( "strings" "time" + "context" + "github.com/jmoiron/sqlx" "github.com/ory/fosite" "github.com/ory/hydra/client" @@ -58,15 +60,15 @@ func (m *SQLManager) CreateSchemas() (int, error) { return n, nil } -func (m *SQLManager) RevokeUserConsentSession(user string) error { - return m.revokeConsentSession(user, "") +func (m *SQLManager) RevokeUserConsentSession(ctx context.Context, user string) error { + return m.revokeConsentSession(ctx, user, "") } -func (m *SQLManager) RevokeUserClientConsentSession(user, client string) error { - return m.revokeConsentSession(user, client) +func (m *SQLManager) RevokeUserClientConsentSession(ctx context.Context, user, client string) error { + return m.revokeConsentSession(ctx, user, client) } -func (m *SQLManager) revokeConsentSession(user, client string) error { +func (m *SQLManager) revokeConsentSession(ctx context.Context, user, client string) error { args := []interface{}{user} part := "r.subject=?" if client != "" { @@ -87,12 +89,12 @@ JOIN hydra_oauth2_consent_request as r ON r.challenge = h.challenge WHERE %s`, } for _, challenge := range challenges { - if err := m.store.RevokeAccessToken(nil, challenge); errors.Cause(err) == fosite.ErrNotFound { + if err := m.store.RevokeAccessToken(ctx, challenge); errors.Cause(err) == fosite.ErrNotFound { // do nothing } else if err != nil { return err } - if err := m.store.RevokeRefreshToken(nil, challenge); errors.Cause(err) == fosite.ErrNotFound { + if err := m.store.RevokeRefreshToken(ctx, challenge); errors.Cause(err) == fosite.ErrNotFound { // do nothing } else if err != nil { return err @@ -131,7 +133,7 @@ WHERE challenge IN (SELECT r.challenge FROM hydra_oauth2_consent_request as r WH return nil } -func (m *SQLManager) RevokeUserAuthenticationSession(user string) error { +func (m *SQLManager) RevokeUserAuthenticationSession(ctx context.Context, user string) error { rows, err := m.db.Exec( m.db.Rebind("DELETE FROM hydra_oauth2_authentication_session WHERE subject=?"), user, @@ -150,7 +152,7 @@ func (m *SQLManager) RevokeUserAuthenticationSession(user string) error { return nil } -func (m *SQLManager) CreateForcedObfuscatedAuthenticationSession(s *ForcedObfuscatedAuthenticationSession) error { +func (m *SQLManager) CreateForcedObfuscatedAuthenticationSession(ctx context.Context, s *ForcedObfuscatedAuthenticationSession) error { tx, err := m.db.Beginx() if err != nil { return sqlcon.HandleError(err) @@ -183,7 +185,7 @@ func (m *SQLManager) CreateForcedObfuscatedAuthenticationSession(s *ForcedObfusc return nil } -func (m *SQLManager) GetForcedObfuscatedAuthenticationSession(client, obfuscated string) (*ForcedObfuscatedAuthenticationSession, error) { +func (m *SQLManager) GetForcedObfuscatedAuthenticationSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedAuthenticationSession, error) { var d ForcedObfuscatedAuthenticationSession if err := m.db.Get(&d, m.db.Rebind("SELECT * FROM hydra_oauth2_obfuscated_authentication_session WHERE client_id=? AND subject_obfuscated=?"), client, obfuscated); err != nil { @@ -196,7 +198,7 @@ func (m *SQLManager) GetForcedObfuscatedAuthenticationSession(client, obfuscated return &d, nil } -func (m *SQLManager) CreateConsentRequest(c *ConsentRequest) error { +func (m *SQLManager) CreateConsentRequest(ctx context.Context, c *ConsentRequest) error { d, err := newSQLConsentRequest(c) if err != nil { return err @@ -213,7 +215,7 @@ func (m *SQLManager) CreateConsentRequest(c *ConsentRequest) error { return nil } -func (m *SQLManager) GetConsentRequest(challenge string) (*ConsentRequest, error) { +func (m *SQLManager) GetConsentRequest(ctx context.Context, challenge string) (*ConsentRequest, error) { var d sqlConsentRequest if err := m.db.Get(&d, m.db.Rebind("SELECT * FROM hydra_oauth2_consent_request WHERE challenge=?"), challenge); err != nil { @@ -223,7 +225,7 @@ func (m *SQLManager) GetConsentRequest(challenge string) (*ConsentRequest, error return nil, sqlcon.HandleError(err) } - c, err := m.c.GetConcreteClient(d.Client) + c, err := m.c.GetConcreteClient(ctx, d.Client) if err != nil { return nil, err } @@ -231,7 +233,7 @@ func (m *SQLManager) GetConsentRequest(challenge string) (*ConsentRequest, error return d.toConsentRequest(c) } -func (m *SQLManager) CreateAuthenticationRequest(c *AuthenticationRequest) error { +func (m *SQLManager) CreateAuthenticationRequest(ctx context.Context, c *AuthenticationRequest) error { d, err := newSQLAuthenticationRequest(c) if err != nil { return err @@ -248,7 +250,7 @@ func (m *SQLManager) CreateAuthenticationRequest(c *AuthenticationRequest) error return nil } -func (m *SQLManager) GetAuthenticationRequest(challenge string) (*AuthenticationRequest, error) { +func (m *SQLManager) GetAuthenticationRequest(ctx context.Context, challenge string) (*AuthenticationRequest, error) { var d sqlConsentRequest if err := m.db.Get(&d, m.db.Rebind("SELECT * FROM hydra_oauth2_authentication_request WHERE challenge=?"), challenge); err != nil { @@ -258,7 +260,7 @@ func (m *SQLManager) GetAuthenticationRequest(challenge string) (*Authentication return nil, sqlcon.HandleError(err) } - c, err := m.c.GetConcreteClient(d.Client) + c, err := m.c.GetConcreteClient(ctx, d.Client) if err != nil { return nil, err } @@ -266,7 +268,7 @@ func (m *SQLManager) GetAuthenticationRequest(challenge string) (*Authentication return d.toAuthenticationRequest(c) } -func (m *SQLManager) HandleConsentRequest(challenge string, r *HandledConsentRequest) (*ConsentRequest, error) { +func (m *SQLManager) HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error) { d, err := newSQLHandledConsentRequest(r) if err != nil { return nil, err @@ -280,10 +282,10 @@ func (m *SQLManager) HandleConsentRequest(challenge string, r *HandledConsentReq return nil, sqlcon.HandleError(err) } - return m.GetConsentRequest(challenge) + return m.GetConsentRequest(ctx, challenge) } -func (m *SQLManager) VerifyAndInvalidateConsentRequest(verifier string) (*HandledConsentRequest, error) { +func (m *SQLManager) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error) { var d sqlHandledConsentRequest var challenge string @@ -301,7 +303,7 @@ func (m *SQLManager) VerifyAndInvalidateConsentRequest(verifier string) (*Handle return nil, errors.WithStack(fosite.ErrInvalidRequest.WithDebug("Consent verifier has been used already")) } - r, err := m.GetConsentRequest(challenge) + r, err := m.GetConsentRequest(ctx, challenge) if err != nil { return nil, err } @@ -313,7 +315,7 @@ func (m *SQLManager) VerifyAndInvalidateConsentRequest(verifier string) (*Handle return d.toHandledConsentRequest(r) } -func (m *SQLManager) HandleAuthenticationRequest(challenge string, r *HandledAuthenticationRequest) (*AuthenticationRequest, error) { +func (m *SQLManager) HandleAuthenticationRequest(ctx context.Context, challenge string, r *HandledAuthenticationRequest) (*AuthenticationRequest, error) { d, err := newSQLHandledAuthenticationRequest(r) if err != nil { return nil, err @@ -327,10 +329,10 @@ func (m *SQLManager) HandleAuthenticationRequest(challenge string, r *HandledAut return nil, sqlcon.HandleError(err) } - return m.GetAuthenticationRequest(challenge) + return m.GetAuthenticationRequest(ctx, challenge) } -func (m *SQLManager) VerifyAndInvalidateAuthenticationRequest(verifier string) (*HandledAuthenticationRequest, error) { +func (m *SQLManager) VerifyAndInvalidateAuthenticationRequest(ctx context.Context, verifier string) (*HandledAuthenticationRequest, error) { var d sqlHandledAuthenticationRequest var challenge string @@ -352,7 +354,7 @@ func (m *SQLManager) VerifyAndInvalidateAuthenticationRequest(verifier string) ( return nil, sqlcon.HandleError(err) } - r, err := m.GetAuthenticationRequest(challenge) + r, err := m.GetAuthenticationRequest(ctx, challenge) if err != nil { return nil, err } @@ -360,7 +362,7 @@ func (m *SQLManager) VerifyAndInvalidateAuthenticationRequest(verifier string) ( return d.toHandledAuthenticationRequest(r) } -func (m *SQLManager) GetAuthenticationSession(id string) (*AuthenticationSession, error) { +func (m *SQLManager) GetAuthenticationSession(ctx context.Context, id string) (*AuthenticationSession, error) { var a AuthenticationSession if err := m.db.Get(&a, m.db.Rebind("SELECT * FROM hydra_oauth2_authentication_session WHERE id=?"), id); err != nil { if err == sql.ErrNoRows { @@ -372,7 +374,7 @@ func (m *SQLManager) GetAuthenticationSession(id string) (*AuthenticationSession return &a, nil } -func (m *SQLManager) CreateAuthenticationSession(a *AuthenticationSession) error { +func (m *SQLManager) CreateAuthenticationSession(ctx context.Context, a *AuthenticationSession) error { if _, err := m.db.NamedExec(fmt.Sprintf( "INSERT INTO hydra_oauth2_authentication_session (%s) VALUES (%s)", strings.Join(sqlParamsAuthSession, ", "), @@ -384,7 +386,7 @@ func (m *SQLManager) CreateAuthenticationSession(a *AuthenticationSession) error return nil } -func (m *SQLManager) DeleteAuthenticationSession(id string) error { +func (m *SQLManager) DeleteAuthenticationSession(ctx context.Context, id string) error { if _, err := m.db.Exec(m.db.Rebind("DELETE FROM hydra_oauth2_authentication_session WHERE id=?"), id); err != nil { return sqlcon.HandleError(err) } @@ -392,7 +394,7 @@ func (m *SQLManager) DeleteAuthenticationSession(id string) error { return nil } -func (m *SQLManager) FindPreviouslyGrantedConsentRequests(client string, subject string) ([]HandledConsentRequest, error) { +func (m *SQLManager) FindPreviouslyGrantedConsentRequests(ctx context.Context, client string, subject string) ([]HandledConsentRequest, error) { var a []sqlHandledConsentRequest if err := m.db.Select(&a, m.db.Rebind(`SELECT h.* FROM @@ -410,10 +412,10 @@ WHERE return nil, sqlcon.HandleError(err) } - return m.resolveHandledConsentRequests(a) + return m.resolveHandledConsentRequests(ctx, a) } -func (m *SQLManager) FindPreviouslyGrantedConsentRequestsByUser(subject string, limit, offset int) ([]HandledConsentRequest, error) { +func (m *SQLManager) FindPreviouslyGrantedConsentRequestsByUser(ctx context.Context, subject string, limit, offset int) ([]HandledConsentRequest, error) { var a []sqlHandledConsentRequest if err := m.db.Select(&a, m.db.Rebind(`SELECT h.* FROM @@ -429,13 +431,13 @@ LIMIT ? OFFSET ? return nil, sqlcon.HandleError(err) } - return m.resolveHandledConsentRequests(a) + return m.resolveHandledConsentRequests(ctx, a) } -func (m *SQLManager) resolveHandledConsentRequests(requests []sqlHandledConsentRequest) ([]HandledConsentRequest, error) { +func (m *SQLManager) resolveHandledConsentRequests(ctx context.Context, requests []sqlHandledConsentRequest) ([]HandledConsentRequest, error) { var aa []HandledConsentRequest for _, v := range requests { - r, err := m.GetConsentRequest(v.Challenge) + r, err := m.GetConsentRequest(ctx, v.Challenge) if err != nil { return nil, err } else if errors.Cause(err) == sqlcon.ErrNoRows { diff --git a/consent/manager_test.go b/consent/manager_test.go index 592cb5c6ef9..5ca8c5546fd 100644 --- a/consent/manager_test.go +++ b/consent/manager_test.go @@ -28,6 +28,8 @@ import ( "testing" "time" + "context" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "github.com/ory/fosite" @@ -223,13 +225,13 @@ func TestManagers(t *testing.T) { }, } { t.Run("case=create-get-"+tc.s.ID, func(t *testing.T) { - _, err := m.GetAuthenticationSession(tc.s.ID) + _, err := m.GetAuthenticationSession(context.TODO(), tc.s.ID) require.EqualError(t, err, pkg.ErrNotFound.Error()) - err = m.CreateAuthenticationSession(&tc.s) + err = m.CreateAuthenticationSession(context.TODO(), &tc.s) require.NoError(t, err) - got, err := m.GetAuthenticationSession(tc.s.ID) + got, err := m.GetAuthenticationSession(context.TODO(), tc.s.ID) require.NoError(t, err) assert.EqualValues(t, tc.s.ID, got.ID) assert.EqualValues(t, tc.s.AuthenticatedAt.Unix(), got.AuthenticatedAt.Unix()) @@ -247,10 +249,10 @@ func TestManagers(t *testing.T) { }, } { t.Run("case=delete-get-"+tc.id, func(t *testing.T) { - err := m.DeleteAuthenticationSession(tc.id) + err := m.DeleteAuthenticationSession(context.TODO(), tc.id) require.NoError(t, err) - _, err = m.GetAuthenticationSession(tc.id) + _, err = m.GetAuthenticationSession(context.TODO(), tc.id) require.Error(t, err) }) } @@ -279,27 +281,27 @@ func TestManagers(t *testing.T) { } { t.Run("key="+tc.key, func(t *testing.T) { c, h := mockConsentRequest(tc.key, tc.remember, tc.rememberFor, tc.hasError, tc.skip, tc.authAt) - clientManager.CreateClient(c.Client) // Ignore errors that are caused by duplication + clientManager.CreateClient(context.TODO(), c.Client) // Ignore errors that are caused by duplication - _, err := m.GetConsentRequest("challenge" + tc.key) + _, err := m.GetConsentRequest(context.TODO(), "challenge"+tc.key) require.Error(t, err) - require.NoError(t, m.CreateConsentRequest(c)) + require.NoError(t, m.CreateConsentRequest(context.TODO(), c)) - got1, err := m.GetConsentRequest("challenge" + tc.key) + got1, err := m.GetConsentRequest(context.TODO(), "challenge"+tc.key) require.NoError(t, err) compareConsentRequest(t, c, got1) - got1, err = m.HandleConsentRequest("challenge"+tc.key, h) + got1, err = m.HandleConsentRequest(context.TODO(), "challenge"+tc.key, h) require.NoError(t, err) compareConsentRequest(t, c, got1) - got2, err := m.VerifyAndInvalidateConsentRequest("verifier" + tc.key) + got2, err := m.VerifyAndInvalidateConsentRequest(context.TODO(), "verifier"+tc.key) require.NoError(t, err) compareConsentRequest(t, c, got2.ConsentRequest) assert.Equal(t, c.Challenge, got2.Challenge) - _, err = m.VerifyAndInvalidateConsentRequest("verifier" + tc.key) + _, err = m.VerifyAndInvalidateConsentRequest(context.TODO(), "verifier"+tc.key) require.Error(t, err) }) } @@ -319,7 +321,7 @@ func TestManagers(t *testing.T) { {"6", "6", 0}, } { t.Run("key="+tc.keyC+"-"+tc.keyS, func(t *testing.T) { - rs, err := m.FindPreviouslyGrantedConsentRequests("client"+tc.keyC, "subject"+tc.keyS) + rs, err := m.FindPreviouslyGrantedConsentRequests(context.TODO(), "client"+tc.keyC, "subject"+tc.keyS) if tc.expectedLength == 0 { assert.EqualError(t, err, ErrNoPreviousConsentFound.Error()) } else { @@ -348,27 +350,27 @@ func TestManagers(t *testing.T) { } { t.Run("key="+tc.key, func(t *testing.T) { c, h := mockAuthRequest(tc.key, tc.authAt) - clientManager.CreateClient(c.Client) // Ignore errors that are caused by duplication + clientManager.CreateClient(context.TODO(), c.Client) // Ignore errors that are caused by duplication - _, err := m.GetAuthenticationRequest("challenge" + tc.key) + _, err := m.GetAuthenticationRequest(context.TODO(), "challenge"+tc.key) require.Error(t, err) - require.NoError(t, m.CreateAuthenticationRequest(c)) + require.NoError(t, m.CreateAuthenticationRequest(context.TODO(), c)) - got1, err := m.GetAuthenticationRequest("challenge" + tc.key) + got1, err := m.GetAuthenticationRequest(context.TODO(), "challenge"+tc.key) require.NoError(t, err) compareAuthenticationRequest(t, c, got1) - got1, err = m.HandleAuthenticationRequest("challenge"+tc.key, h) + got1, err = m.HandleAuthenticationRequest(context.TODO(), "challenge"+tc.key, h) require.NoError(t, err) compareAuthenticationRequest(t, c, got1) - got2, err := m.VerifyAndInvalidateAuthenticationRequest("verifier" + tc.key) + got2, err := m.VerifyAndInvalidateAuthenticationRequest(context.TODO(), "verifier"+tc.key) require.NoError(t, err) compareAuthenticationRequest(t, c, got2.AuthenticationRequest) assert.Equal(t, c.Challenge, got2.Challenge) - _, err = m.VerifyAndInvalidateAuthenticationRequest("verifier" + tc.key) + _, err = m.VerifyAndInvalidateAuthenticationRequest(context.TODO(), "verifier"+tc.key) require.Error(t, err) }) } @@ -378,19 +380,19 @@ func TestManagers(t *testing.T) { t.Run("case=revoke-auth-request", func(t *testing.T) { for k, m := range managers { - require.NoError(t, m.CreateAuthenticationSession(&AuthenticationSession{ + require.NoError(t, m.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{ ID: "rev-session-1", AuthenticatedAt: time.Now(), Subject: "subject-1", })) - require.NoError(t, m.CreateAuthenticationSession(&AuthenticationSession{ + require.NoError(t, m.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{ ID: "rev-session-2", AuthenticatedAt: time.Now(), Subject: "subject-2", })) - require.NoError(t, m.CreateAuthenticationSession(&AuthenticationSession{ + require.NoError(t, m.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{ ID: "rev-session-3", AuthenticatedAt: time.Now(), Subject: "subject-1", @@ -411,11 +413,11 @@ func TestManagers(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) { - require.NoError(t, m.RevokeUserAuthenticationSession(tc.subject)) + require.NoError(t, m.RevokeUserAuthenticationSession(context.TODO(), tc.subject)) for _, id := range tc.ids { t.Run(fmt.Sprintf("id=%s", id), func(t *testing.T) { - _, err := m.GetAuthenticationSession(id) + _, err := m.GetAuthenticationSession(context.TODO(), id) assert.EqualError(t, err, pkg.ErrNotFound.Error()) }) } @@ -429,14 +431,14 @@ func TestManagers(t *testing.T) { for k, m := range managers { cr1, hcr1 := mockConsentRequest("rv1", false, 0, false, false, false) cr2, hcr2 := mockConsentRequest("rv2", false, 0, false, false, false) - clientManager.CreateClient(cr1.Client) - clientManager.CreateClient(cr2.Client) + clientManager.CreateClient(context.TODO(), cr1.Client) + clientManager.CreateClient(context.TODO(), cr2.Client) - require.NoError(t, m.CreateConsentRequest(cr1)) - require.NoError(t, m.CreateConsentRequest(cr2)) - _, err := m.HandleConsentRequest("challengerv1", hcr1) + require.NoError(t, m.CreateConsentRequest(context.TODO(), cr1)) + require.NoError(t, m.CreateConsentRequest(context.TODO(), cr2)) + _, err := m.HandleConsentRequest(context.TODO(), "challengerv1", hcr1) require.NoError(t, err) - _, err = m.HandleConsentRequest("challengerv2", hcr2) + _, err = m.HandleConsentRequest(context.TODO(), "challengerv2", hcr2) require.NoError(t, err) t.Run("manager="+k, func(t *testing.T) { @@ -472,14 +474,14 @@ func TestManagers(t *testing.T) { assert.True(t, found) if tc.client == "" { - require.NoError(t, m.RevokeUserConsentSession(tc.subject)) + require.NoError(t, m.RevokeUserConsentSession(context.TODO(), tc.subject)) } else { - require.NoError(t, m.RevokeUserClientConsentSession(tc.subject, tc.client)) + require.NoError(t, m.RevokeUserClientConsentSession(context.TODO(), tc.subject, tc.client)) } for _, id := range tc.ids { t.Run(fmt.Sprintf("id=%s", id), func(t *testing.T) { - _, err := m.GetConsentRequest(id) + _, err := m.GetConsentRequest(context.TODO(), id) assert.EqualError(t, err, pkg.ErrNotFound.Error()) }) } @@ -500,14 +502,14 @@ func TestManagers(t *testing.T) { for k, m := range managers { cr1, hcr1 := mockConsentRequest("rv1", true, 0, false, false, false) cr2, hcr2 := mockConsentRequest("rv2", false, 0, false, false, false) - clientManager.CreateClient(cr1.Client) - clientManager.CreateClient(cr2.Client) + clientManager.CreateClient(context.TODO(), cr1.Client) + clientManager.CreateClient(context.TODO(), cr2.Client) - require.NoError(t, m.CreateConsentRequest(cr1)) - require.NoError(t, m.CreateConsentRequest(cr2)) - _, err := m.HandleConsentRequest("challengerv1", hcr1) + require.NoError(t, m.CreateConsentRequest(context.TODO(), cr1)) + require.NoError(t, m.CreateConsentRequest(context.TODO(), cr2)) + _, err := m.HandleConsentRequest(context.TODO(), "challengerv1", hcr1) require.NoError(t, err) - _, err = m.HandleConsentRequest("challengerv2", hcr2) + _, err = m.HandleConsentRequest(context.TODO(), "challengerv2", hcr2) require.NoError(t, err) t.Run("manager="+k, func(t *testing.T) { @@ -528,7 +530,7 @@ func TestManagers(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) { - consents, err := m.FindPreviouslyGrantedConsentRequestsByUser(tc.subject, 100, 0) + consents, err := m.FindPreviouslyGrantedConsentRequestsByUser(context.TODO(), tc.subject, 100, 0) assert.Equal(t, len(tc.challenges), len(consents)) if len(tc.challenges) == 0 { @@ -549,7 +551,7 @@ func TestManagers(t *testing.T) { t.Run("case=obfuscated", func(t *testing.T) { for k, m := range managers { t.Run(fmt.Sprintf("manager=%s", k), func(t *testing.T) { - got, err := m.GetForcedObfuscatedAuthenticationSession("client-1", "obfuscated-1") + got, err := m.GetForcedObfuscatedAuthenticationSession(context.TODO(), "client-1", "obfuscated-1") require.EqualError(t, err, pkg.ErrNotFound.Error()) expect := &ForcedObfuscatedAuthenticationSession{ @@ -557,9 +559,9 @@ func TestManagers(t *testing.T) { Subject: "subject-1", SubjectObfuscated: "obfuscated-1", } - require.NoError(t, m.CreateForcedObfuscatedAuthenticationSession(expect)) + require.NoError(t, m.CreateForcedObfuscatedAuthenticationSession(context.TODO(), expect)) - got, err = m.GetForcedObfuscatedAuthenticationSession("client-1", "obfuscated-1") + got, err = m.GetForcedObfuscatedAuthenticationSession(context.TODO(), "client-1", "obfuscated-1") require.NoError(t, err) assert.EqualValues(t, expect, got) @@ -568,13 +570,13 @@ func TestManagers(t *testing.T) { Subject: "subject-1", SubjectObfuscated: "obfuscated-2", } - require.NoError(t, m.CreateForcedObfuscatedAuthenticationSession(expect)) + require.NoError(t, m.CreateForcedObfuscatedAuthenticationSession(context.TODO(), expect)) - got, err = m.GetForcedObfuscatedAuthenticationSession("client-1", "obfuscated-2") + got, err = m.GetForcedObfuscatedAuthenticationSession(context.TODO(), "client-1", "obfuscated-2") require.NoError(t, err) assert.EqualValues(t, expect, got) - got, err = m.GetForcedObfuscatedAuthenticationSession("client-1", "obfuscated-1") + got, err = m.GetForcedObfuscatedAuthenticationSession(context.TODO(), "client-1", "obfuscated-1") require.EqualError(t, err, pkg.ErrNotFound.Error()) }) } diff --git a/consent/sdk_test.go b/consent/sdk_test.go index 54e79cc18b5..03954b87d3d 100644 --- a/consent/sdk_test.go +++ b/consent/sdk_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "context" + "github.com/gorilla/sessions" "github.com/julienschmidt/httprouter" "github.com/ory/herodot" @@ -51,27 +53,27 @@ func TestSDK(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, m.CreateAuthenticationSession(&AuthenticationSession{ + require.NoError(t, m.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{ ID: "session1", Subject: "subject1", })) ar1, _ := mockAuthRequest("1", false) ar2, _ := mockAuthRequest("2", false) - require.NoError(t, m.CreateAuthenticationRequest(ar1)) - require.NoError(t, m.CreateAuthenticationRequest(ar2)) + require.NoError(t, m.CreateAuthenticationRequest(context.TODO(), ar1)) + require.NoError(t, m.CreateAuthenticationRequest(context.TODO(), ar2)) cr1, hcr1 := mockConsentRequest("1", false, 0, false, false, false) cr2, hcr2 := mockConsentRequest("2", false, 0, false, false, false) cr3, hcr3 := mockConsentRequest("3", true, 3600, false, false, false) - require.NoError(t, m.CreateConsentRequest(cr1)) - require.NoError(t, m.CreateConsentRequest(cr2)) - require.NoError(t, m.CreateConsentRequest(cr3)) - _, err = m.HandleConsentRequest("challenge1", hcr1) + require.NoError(t, m.CreateConsentRequest(context.TODO(), cr1)) + require.NoError(t, m.CreateConsentRequest(context.TODO(), cr2)) + require.NoError(t, m.CreateConsentRequest(context.TODO(), cr3)) + _, err = m.HandleConsentRequest(context.TODO(), "challenge1", hcr1) require.NoError(t, err) - _, err = m.HandleConsentRequest("challenge2", hcr2) + _, err = m.HandleConsentRequest(context.TODO(), "challenge2", hcr2) require.NoError(t, err) - _, err = m.HandleConsentRequest("challenge3", hcr3) + _, err = m.HandleConsentRequest(context.TODO(), "challenge3", hcr3) require.NoError(t, err) crGot, res, err := sdk.GetConsentRequest("challenge1") diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 0dc1ea6a279..07f24ada8f3 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -28,6 +28,8 @@ import ( "strings" "time" + "context" + jwtgo "github.com/dgrijalva/jwt-go" "github.com/gorilla/sessions" "github.com/ory/fosite" @@ -117,7 +119,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R return s.forwardAuthenticationRequest(w, r, ar, "", time.Time{}, nil) } - session, err := s.M.GetAuthenticationSession(sessionID) + session, err := s.M.GetAuthenticationSession(context.TODO(), sessionID) if errors.Cause(err) == pkg.ErrNotFound { return s.forwardAuthenticationRequest(w, r, ar, "", time.Time{}, nil) } else if err != nil { @@ -160,7 +162,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R return err } - if s, err := s.M.GetForcedObfuscatedAuthenticationSession(ar.GetClient().GetID(), hintSub); errors.Cause(err) == pkg.ErrNotFound { + if s, err := s.M.GetForcedObfuscatedAuthenticationSession(context.TODO(), ar.GetClient().GetID(), hintSub); errors.Cause(err) == pkg.ErrNotFound { // do nothing } else if err != nil { return err @@ -221,6 +223,7 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r // Set the session if err := s.M.CreateAuthenticationRequest( + context.TODO(), &AuthenticationRequest{ Challenge: challenge, Verifier: verifier, @@ -274,7 +277,7 @@ func (s *DefaultStrategy) revokeAuthenticationSession(w http.ResponseWriter, r * return nil } - return s.M.DeleteAuthenticationSession(sid) + return s.M.DeleteAuthenticationSession(context.TODO(), sid) } func revokeAuthenticationCookie(w http.ResponseWriter, r *http.Request, s sessions.Store) (string, error) { @@ -310,7 +313,7 @@ func (s *DefaultStrategy) obfuscateSubjectIdentifier(subject string, req fosite. } func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledAuthenticationRequest, error) { - session, err := s.M.VerifyAndInvalidateAuthenticationRequest(verifier) + session, err := s.M.VerifyAndInvalidateAuthenticationRequest(context.TODO(), verifier) if errors.Cause(err) == pkg.ErrNotFound { return nil, errors.WithStack(fosite.ErrAccessDenied.WithDebug("The login verifier has already been used, has not been granted, or is invalid.")) } else if err != nil { @@ -383,7 +386,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re } if session.ForceSubjectIdentifier != "" { - if err := s.M.CreateForcedObfuscatedAuthenticationSession(&ForcedObfuscatedAuthenticationSession{ + if err := s.M.CreateForcedObfuscatedAuthenticationSession(context.TODO(), &ForcedObfuscatedAuthenticationSession{ Subject: session.Subject, ClientID: req.GetClient().GetID(), SubjectObfuscated: session.ForceSubjectIdentifier, @@ -407,7 +410,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re cookie, _ := s.CookieStore.Get(r, cookieAuthenticationName) sid := uuid.New() - if err := s.M.CreateAuthenticationSession(&AuthenticationSession{ + if err := s.M.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{ ID: sid, Subject: session.Subject, AuthenticatedAt: session.AuthenticatedAt, @@ -469,7 +472,7 @@ func (s *DefaultStrategy) requestConsent(w http.ResponseWriter, r *http.Request, // return s.forwardConsentRequest(w, r, ar, authenticationSession, nil) // } - consentSessions, err := s.M.FindPreviouslyGrantedConsentRequests(ar.GetClient().GetID(), authenticationSession.Subject) + consentSessions, err := s.M.FindPreviouslyGrantedConsentRequests(context.TODO(), ar.GetClient().GetID(), authenticationSession.Subject) if errors.Cause(err) == ErrNoPreviousConsentFound { return s.forwardConsentRequest(w, r, ar, authenticationSession, nil) } else if err != nil { @@ -500,6 +503,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R csrf := strings.Replace(uuid.New(), "-", "", -1) if err := s.M.CreateConsentRequest( + context.TODO(), &ConsentRequest{ Challenge: challenge, Verifier: verifier, @@ -540,7 +544,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R } func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledConsentRequest, error) { - session, err := s.M.VerifyAndInvalidateConsentRequest(verifier) + session, err := s.M.VerifyAndInvalidateConsentRequest(context.TODO(), verifier) if errors.Cause(err) == pkg.ErrNotFound { return nil, errors.WithStack(fosite.ErrAccessDenied.WithDebug("The consent verifier has already been used, has not been granted, or is invalid.")) } else if err != nil { diff --git a/integration/sql_schema_test.go b/integration/sql_schema_test.go index 105682e7c34..0b21bb42dcc 100644 --- a/integration/sql_schema_test.go +++ b/integration/sql_schema_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "context" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "github.com/ory/fosite" @@ -74,10 +76,10 @@ func TestSQLSchema(t *testing.T) { _, err = crm.CreateSchemas() require.NoError(t, err) - require.NoError(t, jm.AddKey("integration-test-foo", jwk.First(p1))) + require.NoError(t, jm.AddKey(context.TODO(), "integration-test-foo", jwk.First(p1))) require.NoError(t, pm.Create(&ladon.DefaultPolicy{ID: "integration-test-foo", Resources: []string{"foo"}, Actions: []string{"bar"}, Subjects: []string{"baz"}, Effect: "allow"})) - require.NoError(t, cm.CreateClient(&client.Client{ClientID: "integration-test-foo"})) - require.NoError(t, crm.CreateAuthenticationSession(&consent.AuthenticationSession{ + require.NoError(t, cm.CreateClient(context.TODO(), &client.Client{ClientID: "integration-test-foo"})) + require.NoError(t, crm.CreateAuthenticationSession(context.TODO(), &consent.AuthenticationSession{ ID: "foo", AuthenticatedAt: time.Now(), Subject: "bar", diff --git a/jwk/handler.go b/jwk/handler.go index 21ebce60361..b4ae137f14c 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -110,7 +110,7 @@ func (h *Handler) WellKnown(w http.ResponseWriter, r *http.Request, ps httproute var jwks jose.JSONWebKeySet for _, set := range h.WellKnownKeys { - keys, err := h.Manager.GetKeySet(set) + keys, err := h.Manager.GetKeySet(r.Context(), set) if err != nil { h.H.WriteError(w, r, err) return @@ -153,7 +153,7 @@ func (h *Handler) GetKey(w http.ResponseWriter, r *http.Request, ps httprouter.P var setName = ps.ByName("set") var keyName = ps.ByName("key") - keys, err := h.Manager.GetKey(setName, keyName) + keys, err := h.Manager.GetKey(r.Context(), setName, keyName) if err != nil { h.H.WriteError(w, r, err) return @@ -186,7 +186,7 @@ func (h *Handler) GetKey(w http.ResponseWriter, r *http.Request, ps httprouter.P func (h *Handler) GetKeySet(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { var setName = ps.ByName("set") - keys, err := h.Manager.GetKeySet(setName) + keys, err := h.Manager.GetKeySet(r.Context(), setName) if err != nil { h.H.WriteError(w, r, err) return @@ -236,7 +236,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request, ps httprouter.P return } - if err := h.Manager.AddKeySet(set, keys); err != nil { + if err := h.Manager.AddKeySet(r.Context(), set, keys); err != nil { h.H.WriteError(w, r, err) return } @@ -274,7 +274,7 @@ func (h *Handler) UpdateKeySet(w http.ResponseWriter, r *http.Request, ps httpro return } - if err := h.Manager.AddKeySet(set, &keySet); err != nil { + if err := h.Manager.AddKeySet(r.Context(), set, &keySet); err != nil { h.H.WriteError(w, r, err) return } @@ -312,12 +312,12 @@ func (h *Handler) UpdateKey(w http.ResponseWriter, r *http.Request, ps httproute return } - if err := h.Manager.DeleteKey(set, key.KeyID); err != nil { + if err := h.Manager.DeleteKey(r.Context(), set, key.KeyID); err != nil { h.H.WriteError(w, r, err) return } - if err := h.Manager.AddKey(set, &key); err != nil { + if err := h.Manager.AddKey(r.Context(), set, &key); err != nil { h.H.WriteError(w, r, err) return } @@ -349,7 +349,7 @@ func (h *Handler) UpdateKey(w http.ResponseWriter, r *http.Request, ps httproute func (h *Handler) DeleteKeySet(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { var setName = ps.ByName("set") - if err := h.Manager.DeleteKeySet(setName); err != nil { + if err := h.Manager.DeleteKeySet(r.Context(), setName); err != nil { h.H.WriteError(w, r, err) return } @@ -382,7 +382,7 @@ func (h *Handler) DeleteKey(w http.ResponseWriter, r *http.Request, ps httproute var setName = ps.ByName("set") var keyName = ps.ByName("key") - if err := h.Manager.DeleteKey(setName, keyName); err != nil { + if err := h.Manager.DeleteKey(r.Context(), setName, keyName); err != nil { h.H.WriteError(w, r, err) return } diff --git a/jwk/handler_test.go b/jwk/handler_test.go index 7eb780c4fbb..a81e6a8c2e7 100644 --- a/jwk/handler_test.go +++ b/jwk/handler_test.go @@ -26,6 +26,8 @@ import ( "net/http/httptest" "testing" + "context" + "github.com/julienschmidt/httprouter" "github.com/ory/herodot" . "github.com/ory/hydra/jwk" @@ -47,7 +49,7 @@ func init() { herodot.NewJSONWriter(nil), []string{}, ) - h.Manager.AddKeySet(IDTokenKeyName, IDKS) + h.Manager.AddKeySet(context.TODO(), IDTokenKeyName, IDKS) h.SetRoutes(router, router) testServer = httptest.NewServer(router) } diff --git a/jwk/jwt_strategy.go b/jwk/jwt_strategy.go index 1b0a6f85d67..f1f59beea82 100644 --- a/jwk/jwt_strategy.go +++ b/jwk/jwt_strategy.go @@ -24,6 +24,8 @@ import ( "crypto/rsa" "strings" + "context" + jwt2 "github.com/dgrijalva/jwt-go" "github.com/ory/fosite/token/jwt" "github.com/pkg/errors" @@ -104,7 +106,7 @@ func (j *RS256JWTStrategy) GetPublicKeyID() (string, error) { } func (j *RS256JWTStrategy) refresh() error { - keys, err := j.Manager.GetKeySet(j.Set) + keys, err := j.Manager.GetKeySet(context.TODO(), j.Set) if err != nil { return err } diff --git a/jwk/jwt_strategy_test.go b/jwk/jwt_strategy_test.go index 2839af443e1..bb059eda7c6 100644 --- a/jwk/jwt_strategy_test.go +++ b/jwk/jwt_strategy_test.go @@ -23,6 +23,8 @@ package jwk import ( "testing" + "context" + jwt2 "github.com/dgrijalva/jwt-go" "github.com/ory/fosite/token/jwt" "github.com/stretchr/testify/assert" @@ -39,7 +41,7 @@ func TestRS256JWTStrategy(t *testing.T) { ks, err := testGenerator.Generate("foo", "sig") require.NoError(t, err) - require.NoError(t, m.AddKeySet("foo-set", ks)) + require.NoError(t, m.AddKeySet(context.TODO(), "foo-set", ks)) s, err := NewRS256JWTStrategy(m, "foo-set") require.NoError(t, err) @@ -57,7 +59,7 @@ func TestRS256JWTStrategy(t *testing.T) { ks, err = testGenerator.Generate("bar", "sig") require.NoError(t, err) - require.NoError(t, m.AddKeySet("foo-set", ks)) + require.NoError(t, m.AddKeySet(context.TODO(), "foo-set", ks)) a, b, err = s.Generate(jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{}) require.NoError(t, err) diff --git a/jwk/manager.go b/jwk/manager.go index 46ef370a60e..2afaff9e52a 100644 --- a/jwk/manager.go +++ b/jwk/manager.go @@ -20,18 +20,22 @@ package jwk -import "gopkg.in/square/go-jose.v2" +import ( + "context" + + "gopkg.in/square/go-jose.v2" +) type Manager interface { - AddKey(set string, key *jose.JSONWebKey) error + AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error - AddKeySet(set string, keys *jose.JSONWebKeySet) error + AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error - GetKey(set, kid string) (*jose.JSONWebKeySet, error) + GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) - GetKeySet(set string) (*jose.JSONWebKeySet, error) + GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) - DeleteKey(set, kid string) error + DeleteKey(ctx context.Context, set, kid string) error - DeleteKeySet(set string) error + DeleteKeySet(ctx context.Context, set string) error } diff --git a/jwk/manager_memory.go b/jwk/manager_memory.go index 07c8e8e44de..a01f54bc4fc 100644 --- a/jwk/manager_memory.go +++ b/jwk/manager_memory.go @@ -26,6 +26,8 @@ import ( "fmt" "net/http" + "context" + "github.com/ory/fosite" "github.com/ory/hydra/pkg" "github.com/pkg/errors" @@ -37,7 +39,7 @@ type MemoryManager struct { sync.RWMutex } -func (m *MemoryManager) AddKey(set string, key *jose.JSONWebKey) error { +func (m *MemoryManager) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error { m.Lock() defer m.Unlock() @@ -60,14 +62,14 @@ func (m *MemoryManager) AddKey(set string, key *jose.JSONWebKey) error { return nil } -func (m *MemoryManager) AddKeySet(set string, keys *jose.JSONWebKeySet) error { +func (m *MemoryManager) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { for _, key := range keys.Keys { - m.AddKey(set, &key) + m.AddKey(ctx, set, &key) } return nil } -func (m *MemoryManager) GetKey(set, kid string) (*jose.JSONWebKeySet, error) { +func (m *MemoryManager) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) { m.RLock() defer m.RUnlock() @@ -87,7 +89,7 @@ func (m *MemoryManager) GetKey(set, kid string) (*jose.JSONWebKeySet, error) { }, nil } -func (m *MemoryManager) GetKeySet(set string) (*jose.JSONWebKeySet, error) { +func (m *MemoryManager) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) { m.RLock() defer m.RUnlock() @@ -104,8 +106,8 @@ func (m *MemoryManager) GetKeySet(set string) (*jose.JSONWebKeySet, error) { return keys, nil } -func (m *MemoryManager) DeleteKey(set, kid string) error { - keys, err := m.GetKeySet(set) +func (m *MemoryManager) DeleteKey(ctx context.Context, set, kid string) error { + keys, err := m.GetKeySet(ctx, set) if err != nil { return err } @@ -123,7 +125,7 @@ func (m *MemoryManager) DeleteKey(set, kid string) error { return nil } -func (m *MemoryManager) DeleteKeySet(set string) error { +func (m *MemoryManager) DeleteKeySet(ctx context.Context, set string) error { m.Lock() defer m.Unlock() diff --git a/jwk/manager_sql.go b/jwk/manager_sql.go index f7a757bd228..8384817a6ed 100644 --- a/jwk/manager_sql.go +++ b/jwk/manager_sql.go @@ -24,6 +24,8 @@ import ( "encoding/json" "time" + "context" + "github.com/jmoiron/sqlx" "github.com/ory/hydra/pkg" "github.com/ory/sqlcon" @@ -95,7 +97,7 @@ func (m *SQLManager) CreateSchemas() (int, error) { return n, nil } -func (m *SQLManager) AddKey(set string, key *jose.JSONWebKey) error { +func (m *SQLManager) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error { out, err := json.Marshal(key) if err != nil { return errors.WithStack(err) @@ -117,13 +119,13 @@ func (m *SQLManager) AddKey(set string, key *jose.JSONWebKey) error { return nil } -func (m *SQLManager) AddKeySet(set string, keys *jose.JSONWebKeySet) error { +func (m *SQLManager) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { tx, err := m.DB.Beginx() if err != nil { return errors.WithStack(err) } - if err := m.addKeySet(tx, m.Cipher, set, keys); err != nil { + if err := m.addKeySet(ctx, tx, m.Cipher, set, keys); err != nil { if re := tx.Rollback(); re != nil { return errors.Wrap(err, re.Error()) } @@ -139,7 +141,7 @@ func (m *SQLManager) AddKeySet(set string, keys *jose.JSONWebKeySet) error { return nil } -func (m *SQLManager) addKeySet(tx *sqlx.Tx, cipher *AEAD, set string, keys *jose.JSONWebKeySet) error { +func (m *SQLManager) addKeySet(ctx context.Context, tx *sqlx.Tx, cipher *AEAD, set string, keys *jose.JSONWebKeySet) error { for _, key := range keys.Keys { out, err := json.Marshal(key) if err != nil { @@ -164,7 +166,7 @@ func (m *SQLManager) addKeySet(tx *sqlx.Tx, cipher *AEAD, set string, keys *jose return nil } -func (m *SQLManager) GetKey(set, kid string) (*jose.JSONWebKeySet, error) { +func (m *SQLManager) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) { var d sqlData if err := m.DB.Get(&d, m.DB.Rebind("SELECT * FROM hydra_jwk WHERE sid=? AND kid=? ORDER BY created_at DESC"), set, kid); err != nil { return nil, sqlcon.HandleError(err) @@ -185,7 +187,7 @@ func (m *SQLManager) GetKey(set, kid string) (*jose.JSONWebKeySet, error) { }, nil } -func (m *SQLManager) GetKeySet(set string) (*jose.JSONWebKeySet, error) { +func (m *SQLManager) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) { var ds []sqlData if err := m.DB.Select(&ds, m.DB.Rebind("SELECT * FROM hydra_jwk WHERE sid=? ORDER BY created_at DESC"), set); err != nil { return nil, sqlcon.HandleError(err) @@ -216,20 +218,20 @@ func (m *SQLManager) GetKeySet(set string) (*jose.JSONWebKeySet, error) { return keys, nil } -func (m *SQLManager) DeleteKey(set, kid string) error { +func (m *SQLManager) DeleteKey(ctx context.Context, set, kid string) error { if _, err := m.DB.Exec(m.DB.Rebind(`DELETE FROM hydra_jwk WHERE sid=? AND kid=?`), set, kid); err != nil { return sqlcon.HandleError(err) } return nil } -func (m *SQLManager) DeleteKeySet(set string) error { +func (m *SQLManager) DeleteKeySet(ctx context.Context, set string) error { tx, err := m.DB.Beginx() if err != nil { return errors.WithStack(err) } - if err := m.deleteKeySet(tx, set); err != nil { + if err := m.deleteKeySet(ctx, tx, set); err != nil { if re := tx.Rollback(); re != nil { return errors.Wrap(err, re.Error()) } @@ -245,7 +247,7 @@ func (m *SQLManager) DeleteKeySet(set string) error { return nil } -func (m *SQLManager) deleteKeySet(tx *sqlx.Tx, set string) error { +func (m *SQLManager) deleteKeySet(ctx context.Context, tx *sqlx.Tx, set string) error { if _, err := tx.Exec(m.DB.Rebind(`DELETE FROM hydra_jwk WHERE sid=?`), set); err != nil { return sqlcon.HandleError(err) } @@ -260,7 +262,7 @@ func (m *SQLManager) RotateKeys(new *AEAD) error { sets := make([]jose.JSONWebKeySet, 0) for _, sid := range sids { - set, err := m.GetKeySet(sid) + set, err := m.GetKeySet(context.TODO(), sid) if err != nil { return errors.WithStack(err) } @@ -273,14 +275,14 @@ func (m *SQLManager) RotateKeys(new *AEAD) error { } for k, set := range sets { - if err := m.deleteKeySet(tx, sids[k]); err != nil { + if err := m.deleteKeySet(context.TODO(), tx, sids[k]); err != nil { if re := tx.Rollback(); re != nil { return errors.Wrap(err, re.Error()) } return sqlcon.HandleError(err) } - if err := m.addKeySet(tx, new, sids[k], &set); err != nil { + if err := m.addKeySet(context.TODO(), tx, new, sids[k], &set); err != nil { if re := tx.Rollback(); re != nil { return errors.Wrap(err, re.Error()) } diff --git a/jwk/manager_test.go b/jwk/manager_test.go index 5dec6376deb..fa92607eee3 100644 --- a/jwk/manager_test.go +++ b/jwk/manager_test.go @@ -27,6 +27,8 @@ import ( "sync" "testing" + "context" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" . "github.com/ory/hydra/jwk" @@ -130,15 +132,15 @@ func TestManagerRotate(t *testing.T) { require.NoError(t, err) t.Logf("Applied %d migrations to %s", n, name) - require.NoError(t, m.AddKeySet("TestManagerRotate", ks)) + require.NoError(t, m.AddKeySet(context.TODO(), "TestManagerRotate", ks)) require.NoError(t, m.RotateKeys(newCipher)) - _, err = m.GetKeySet("TestManagerRotate") + _, err = m.GetKeySet(context.TODO(), "TestManagerRotate") require.Error(t, err) m.Cipher = newCipher - got, err := m.GetKeySet("TestManagerRotate") + got, err := m.GetKeySet(context.TODO(), "TestManagerRotate") require.NoError(t, err) for _, key := range ks.Keys { diff --git a/jwk/manager_test_helpers.go b/jwk/manager_test_helpers.go index 61ee51bfa2c..5ef975036fb 100644 --- a/jwk/manager_test_helpers.go +++ b/jwk/manager_test_helpers.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "context" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -46,24 +48,24 @@ func TestHelperManagerKey(m Manager, keys *jose.JSONWebKeySet, suffix string) fu return func(t *testing.T) { t.Parallel() - _, err := m.GetKey("faz", "baz") + _, err := m.GetKey(context.TODO(), "faz", "baz") assert.NotNil(t, err) - err = m.AddKey("faz", First(priv)) + err = m.AddKey(context.TODO(), "faz", First(priv)) require.NoError(t, err) - got, err := m.GetKey("faz", "private:"+suffix) + got, err := m.GetKey(context.TODO(), "faz", "private:"+suffix) require.NoError(t, err) assert.Equal(t, priv, got.Keys) - err = m.AddKey("faz", First(pub)) + err = m.AddKey(context.TODO(), "faz", First(pub)) require.NoError(t, err) - got, err = m.GetKey("faz", "private:"+suffix) + got, err = m.GetKey(context.TODO(), "faz", "private:"+suffix) require.NoError(t, err) assert.Equal(t, priv, got.Keys) - got, err = m.GetKey("faz", "public:"+suffix) + got, err = m.GetKey(context.TODO(), "faz", "public:"+suffix) require.NoError(t, err) assert.Equal(t, pub, got.Keys) @@ -71,20 +73,20 @@ func TestHelperManagerKey(m Manager, keys *jose.JSONWebKeySet, suffix string) fu time.Sleep(time.Second * 2) First(pub).KeyID = "new-key-id:" + suffix - err = m.AddKey("faz", First(pub)) + err = m.AddKey(context.TODO(), "faz", First(pub)) require.NoError(t, err) - _, err = m.GetKey("faz", "new-key-id:"+suffix) + _, err = m.GetKey(context.TODO(), "faz", "new-key-id:"+suffix) require.NoError(t, err) - keys, err = m.GetKeySet("faz") + keys, err = m.GetKeySet(context.TODO(), "faz") require.NoError(t, err) assert.EqualValues(t, "new-key-id:"+suffix, First(keys.Keys).KeyID) - err = m.DeleteKey("faz", "public:"+suffix) + err = m.DeleteKey(context.TODO(), "faz", "public:"+suffix) require.NoError(t, err) - _, err = m.GetKey("faz", "public:"+suffix) + _, err = m.GetKey(context.TODO(), "faz", "public:"+suffix) require.Error(t, err) } } @@ -92,21 +94,21 @@ func TestHelperManagerKey(m Manager, keys *jose.JSONWebKeySet, suffix string) fu func TestHelperManagerKeySet(m Manager, keys *jose.JSONWebKeySet, suffix string) func(t *testing.T) { return func(t *testing.T) { t.Parallel() - _, err := m.GetKeySet("foo") + _, err := m.GetKeySet(context.TODO(), "foo") require.Error(t, err) - err = m.AddKeySet("bar", keys) + err = m.AddKeySet(context.TODO(), "bar", keys) require.NoError(t, err) - got, err := m.GetKeySet("bar") + got, err := m.GetKeySet(context.TODO(), "bar") require.NoError(t, err) assert.Equal(t, keys.Key("public:"+suffix), got.Key("public:"+suffix)) assert.Equal(t, keys.Key("private:"+suffix), got.Key("private:"+suffix)) - err = m.DeleteKeySet("bar") + err = m.DeleteKeySet(context.TODO(), "bar") require.NoError(t, err) - _, err = m.GetKeySet("bar") + _, err = m.GetKeySet(context.TODO(), "bar") require.Error(t, err) } } diff --git a/oauth2/fosite_store_sql.go b/oauth2/fosite_store_sql.go index fb2f40f9a63..b070e11d8be 100644 --- a/oauth2/fosite_store_sql.go +++ b/oauth2/fosite_store_sql.go @@ -265,7 +265,7 @@ func (s *FositeSQLStore) hashSignature(signature, table string) string { return signature } -func (s *FositeSQLStore) createSession(signature string, requester fosite.Requester, table string) error { +func (s *FositeSQLStore) createSession(ctx context.Context, signature string, requester fosite.Requester, table string) error { signature = s.hashSignature(signature, table) data, err := sqlSchemaFromRequest(signature, requester, s.L) @@ -285,7 +285,7 @@ func (s *FositeSQLStore) createSession(signature string, requester fosite.Reques return nil } -func (s *FositeSQLStore) findSessionBySignature(signature string, session fosite.Session, table string) (fosite.Requester, error) { +func (s *FositeSQLStore) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table string) (fosite.Requester, error) { signature = s.hashSignature(signature, table) var d sqlData @@ -306,7 +306,7 @@ func (s *FositeSQLStore) findSessionBySignature(signature string, session fosite return d.toRequest(session, s.Manager, s.L) } -func (s *FositeSQLStore) deleteSession(signature string, table string) error { +func (s *FositeSQLStore) deleteSession(ctx context.Context, signature string, table string) error { signature = s.hashSignature(signature, table) if _, err := s.DB.Exec(s.DB.Rebind(fmt.Sprintf("DELETE FROM hydra_oauth2_%s WHERE signature=?", table)), signature); err != nil { @@ -324,24 +324,24 @@ func (s *FositeSQLStore) CreateSchemas() (int, error) { return n, nil } -func (s *FositeSQLStore) CreateOpenIDConnectSession(_ context.Context, signature string, requester fosite.Requester) error { - return s.createSession(signature, requester, sqlTableOpenID) +func (s *FositeSQLStore) CreateOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) error { + return s.createSession(ctx, signature, requester, sqlTableOpenID) } -func (s *FositeSQLStore) GetOpenIDConnectSession(_ context.Context, signature string, requester fosite.Requester) (fosite.Requester, error) { - return s.findSessionBySignature(signature, requester.GetSession(), sqlTableOpenID) +func (s *FositeSQLStore) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (fosite.Requester, error) { + return s.findSessionBySignature(ctx, signature, requester.GetSession(), sqlTableOpenID) } -func (s *FositeSQLStore) DeleteOpenIDConnectSession(_ context.Context, signature string) error { - return s.deleteSession(signature, sqlTableOpenID) +func (s *FositeSQLStore) DeleteOpenIDConnectSession(ctx context.Context, signature string) error { + return s.deleteSession(ctx, signature, sqlTableOpenID) } -func (s *FositeSQLStore) CreateAuthorizeCodeSession(_ context.Context, signature string, requester fosite.Requester) error { - return s.createSession(signature, requester, sqlTableCode) +func (s *FositeSQLStore) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { + return s.createSession(ctx, signature, requester, sqlTableCode) } -func (s *FositeSQLStore) GetAuthorizeCodeSession(_ context.Context, signature string, session fosite.Session) (fosite.Requester, error) { - return s.findSessionBySignature(signature, session, sqlTableCode) +func (s *FositeSQLStore) GetAuthorizeCodeSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { + return s.findSessionBySignature(ctx, signature, session, sqlTableCode) } func (s *FositeSQLStore) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) error { @@ -355,40 +355,40 @@ func (s *FositeSQLStore) InvalidateAuthorizeCodeSession(ctx context.Context, sig return nil } -func (s *FositeSQLStore) CreateAccessTokenSession(_ context.Context, signature string, requester fosite.Requester) error { - return s.createSession(signature, requester, sqlTableAccess) +func (s *FositeSQLStore) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) error { + return s.createSession(ctx, signature, requester, sqlTableAccess) } -func (s *FositeSQLStore) GetAccessTokenSession(_ context.Context, signature string, session fosite.Session) (fosite.Requester, error) { - return s.findSessionBySignature(signature, session, sqlTableAccess) +func (s *FositeSQLStore) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { + return s.findSessionBySignature(ctx, signature, session, sqlTableAccess) } -func (s *FositeSQLStore) DeleteAccessTokenSession(_ context.Context, signature string) error { - return s.deleteSession(signature, sqlTableAccess) +func (s *FositeSQLStore) DeleteAccessTokenSession(ctx context.Context, signature string) error { + return s.deleteSession(ctx, signature, sqlTableAccess) } -func (s *FositeSQLStore) CreateRefreshTokenSession(_ context.Context, signature string, requester fosite.Requester) error { - return s.createSession(signature, requester, sqlTableRefresh) +func (s *FositeSQLStore) CreateRefreshTokenSession(ctx context.Context, signature string, requester fosite.Requester) error { + return s.createSession(ctx, signature, requester, sqlTableRefresh) } -func (s *FositeSQLStore) GetRefreshTokenSession(_ context.Context, signature string, session fosite.Session) (fosite.Requester, error) { - return s.findSessionBySignature(signature, session, sqlTableRefresh) +func (s *FositeSQLStore) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { + return s.findSessionBySignature(ctx, signature, session, sqlTableRefresh) } -func (s *FositeSQLStore) DeleteRefreshTokenSession(_ context.Context, signature string) error { - return s.deleteSession(signature, sqlTableRefresh) +func (s *FositeSQLStore) DeleteRefreshTokenSession(ctx context.Context, signature string) error { + return s.deleteSession(ctx, signature, sqlTableRefresh) } -func (s *FositeSQLStore) CreatePKCERequestSession(_ context.Context, signature string, requester fosite.Requester) error { - return s.createSession(signature, requester, sqlTablePKCE) +func (s *FositeSQLStore) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error { + return s.createSession(ctx, signature, requester, sqlTablePKCE) } -func (s *FositeSQLStore) GetPKCERequestSession(_ context.Context, signature string, session fosite.Session) (fosite.Requester, error) { - return s.findSessionBySignature(signature, session, sqlTablePKCE) +func (s *FositeSQLStore) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { + return s.findSessionBySignature(ctx, signature, session, sqlTablePKCE) } -func (s *FositeSQLStore) DeletePKCERequestSession(_ context.Context, signature string) error { - return s.deleteSession(signature, sqlTablePKCE) +func (s *FositeSQLStore) DeletePKCERequestSession(ctx context.Context, signature string) error { + return s.deleteSession(ctx, signature, sqlTablePKCE) } func (s *FositeSQLStore) CreateImplicitAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) error { diff --git a/oauth2/handler.go b/oauth2/handler.go index 387b2f566d8..6213bceceb6 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -21,7 +21,6 @@ package oauth2 import ( - "context" "encoding/json" "fmt" "net/http" @@ -342,7 +341,7 @@ func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) { // 401: genericError // 500: genericError func (h *Handler) RevocationHandler(w http.ResponseWriter, r *http.Request) { - var ctx = fosite.NewContext() + var ctx = r.Context() err := h.OAuth2.NewRevocationRequest(ctx, r) if err != nil { @@ -378,7 +377,7 @@ func (h *Handler) RevocationHandler(w http.ResponseWriter, r *http.Request) { // 500: genericError func (h *Handler) IntrospectHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { var session = NewSession("") - var ctx = fosite.NewContext() + var ctx = r.Context() if r.Method != "POST" { err := errors.WithStack(fosite.ErrInvalidRequest.WithHintf("HTTP method is \"%s\", expected \"POST\".", r.Method)) @@ -480,7 +479,7 @@ func (h *Handler) FlushHandler(w http.ResponseWriter, r *http.Request, _ httprou fr.NotAfter = time.Now() } - if err := h.Storage.FlushInactiveAccessTokens(context.Background(), fr.NotAfter); err != nil { + if err := h.Storage.FlushInactiveAccessTokens(r.Context(), fr.NotAfter); err != nil { h.H.WriteError(w, r, err) return } @@ -515,7 +514,7 @@ func (h *Handler) FlushHandler(w http.ResponseWriter, r *http.Request, _ httprou // 500: genericError func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) { var session = NewSession("") - var ctx = fosite.NewContext() + var ctx = r.Context() accessRequest, err := h.OAuth2.NewAccessRequest(ctx, r, session) if err != nil { @@ -577,7 +576,7 @@ func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) { // 401: genericError // 500: genericError func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - var ctx = fosite.NewContext() + var ctx = r.Context() authorizeRequest, err := h.OAuth2.NewAuthorizeRequest(ctx, r) if err != nil { diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go index f2f45237525..59b15fb3c80 100644 --- a/oauth2/handler_test.go +++ b/oauth2/handler_test.go @@ -147,7 +147,7 @@ func TestUserinfo(t *testing.T) { jm := &jwk.MemoryManager{Keys: map[string]*jose.JSONWebKeySet{}} keys, err := (&jwk.RS256Generator{}).Generate("signing", "sig") require.NoError(t, err) - require.NoError(t, jm.AddKeySet(oauth2.OpenIDConnectKeyName, keys)) + require.NoError(t, jm.AddKeySet(context.TODO(), oauth2.OpenIDConnectKeyName, keys)) jwtStrategy, err := jwk.NewRS256JWTStrategy(jm, oauth2.OpenIDConnectKeyName) h := &oauth2.Handler{ diff --git a/oauth2/introspector_test.go b/oauth2/introspector_test.go index e72ff9ffb01..872a9a7cee9 100644 --- a/oauth2/introspector_test.go +++ b/oauth2/introspector_test.go @@ -28,6 +28,8 @@ import ( "testing" "time" + "context" + "github.com/julienschmidt/httprouter" "github.com/ory/fosite" "github.com/ory/fosite/compose" @@ -54,7 +56,7 @@ func TestIntrospectorSDK(t *testing.T) { jm := &jwk.MemoryManager{Keys: map[string]*jose.JSONWebKeySet{}} keys, err := (&jwk.RS256Generator{}).Generate("", "sig") require.NoError(t, err) - require.NoError(t, jm.AddKeySet(oauth2.OpenIDConnectKeyName, keys)) + require.NoError(t, jm.AddKeySet(context.TODO(), oauth2.OpenIDConnectKeyName, keys)) jwtStrategy, err := jwk.NewRS256JWTStrategy(jm, oauth2.OpenIDConnectKeyName) router := httprouter.New() diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index e79eb380266..5339fa1b93e 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -34,6 +34,8 @@ import ( "testing" "time" + "context" + djwt "github.com/dgrijalva/jwt-go" "github.com/gorilla/sessions" "github.com/julienschmidt/httprouter" @@ -80,7 +82,7 @@ func mockProvider(h *func(w http.ResponseWriter, r *http.Request)) *httptest.Ser } type clientCreator interface { - CreateClient(client *hc.Client) error + CreateClient(cxt context.Context, client *hc.Client) error } // TestAuthCodeWithDefaultStrategy runs proper integration tests against in-memory and database connectors, specifically @@ -169,7 +171,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { jm := &jwk.MemoryManager{Keys: map[string]*jose.JSONWebKeySet{}} keys, err := (&jwk.RS256Generator{}).Generate("", "sig") require.NoError(t, err) - require.NoError(t, jm.AddKeySet(OpenIDConnectKeyName, keys)) + require.NoError(t, jm.AddKeySet(context.TODO(), OpenIDConnectKeyName, keys)) jwtStrategy, err := jwk.NewRS256JWTStrategy(jm, OpenIDConnectKeyName) handler := &Handler{ @@ -214,7 +216,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { RedirectURL: client.RedirectURIs[0], Scopes: []string{"hydra", "offline", "openid"}, } - require.NoError(t, fs.(clientCreator).CreateClient(&client)) + require.NoError(t, fs.(clientCreator).CreateClient(context.TODO(), &client)) apiClient := swagger.NewOAuth2ApiWithBasePath(api.URL) var callbackHandler *httprouter.Handle @@ -718,7 +720,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { jm := &jwk.MemoryManager{Keys: map[string]*jose.JSONWebKeySet{}} keys, err := (&jwk.RS256Generator{}).Generate("", "sig") require.NoError(t, err) - require.NoError(t, jm.AddKeySet(OpenIDConnectKeyName, keys)) + require.NoError(t, jm.AddKeySet(context.TODO(), OpenIDConnectKeyName, keys)) jwtStrategy, err := jwk.NewRS256JWTStrategy(jm, OpenIDConnectKeyName) handler := &Handler{ @@ -757,7 +759,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { }) m := sync.Mutex{} - store.CreateClient(&hc.Client{ + store.CreateClient(context.TODO(), &hc.Client{ ClientID: "app-client", Secret: "secret", RedirectURIs: []string{ts.URL + "/callback"}, diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index 6b739149c67..4c7d7d0db1d 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -87,7 +87,7 @@ func TestClientCredentials(t *testing.T) { jm := &jwk.MemoryManager{Keys: map[string]*jose.JSONWebKeySet{}} keys, err := (&jwk.RS256Generator{}).Generate("", "sig") require.NoError(t, err) - require.NoError(t, jm.AddKeySet(OpenIDConnectKeyName, keys)) + require.NoError(t, jm.AddKeySet(context.TODO(), OpenIDConnectKeyName, keys)) jwtStrategy, err := jwk.NewRS256JWTStrategy(jm, OpenIDConnectKeyName) ts := httptest.NewServer(router) @@ -115,7 +115,7 @@ func TestClientCredentials(t *testing.T) { return h }) - require.NoError(t, store.CreateClient(&hc.Client{ + require.NoError(t, store.CreateClient(context.TODO(), &hc.Client{ ClientID: "app-client", Secret: "secret", RedirectURIs: []string{ts.URL + "/callback"}, diff --git a/oauth2/revocator_test.go b/oauth2/revocator_test.go index a66c938793e..28ba7ea49b4 100644 --- a/oauth2/revocator_test.go +++ b/oauth2/revocator_test.go @@ -27,6 +27,8 @@ import ( "testing" "time" + "context" + "github.com/julienschmidt/httprouter" "github.com/ory/fosite" "github.com/ory/fosite/compose" @@ -72,7 +74,7 @@ func TestRevoke(t *testing.T) { jm := &jwk.MemoryManager{Keys: map[string]*jose.JSONWebKeySet{}} keys, err := (&jwk.RS256Generator{}).Generate("", "sig") require.NoError(t, err) - require.NoError(t, jm.AddKeySet(oauth2.OpenIDConnectKeyName, keys)) + require.NoError(t, jm.AddKeySet(context.TODO(), oauth2.OpenIDConnectKeyName, keys)) jwtStrategy, err := jwk.NewRS256JWTStrategy(jm, oauth2.OpenIDConnectKeyName) handler := &oauth2.Handler{