diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index 767f33444d5..f1e0a03c65a 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -52,7 +52,7 @@ func setupRegistries(t *testing.T) { } func TestManagers(t *testing.T) { - ctx := context.TODO() + ctx := context.Background() tests := []struct { name string enableSessionEncrypted bool @@ -67,7 +67,7 @@ func TestManagers(t *testing.T) { }, } for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { + t.Run("suite="+tc.name, func(t *testing.T) { setupRegistries(t) require.NoError(t, registries["memory"].ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foobar"})) // this is a workaround because the client is not being created for memory store by test helpers. diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 7c4db0c81d2..e08834a9b4a 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -60,7 +60,6 @@ func CompareWithFixture(t *testing.T, actual interface{}, prefix string, id stri } func TestMigrations(t *testing.T) { - //pop.Debug = true connections := make(map[string]*pop.Connection, 1) if testing.Short() { diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index 11ca9b28a46..ced61f3c0e6 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -461,6 +461,11 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) (deletedS ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginSession") defer otelx.End(span, &err) + if p.Connection(ctx).Dialect.Name() == "mysql" { + // MySQL does not support RETURNING. + return p.mySQLDeleteLoginSession(ctx, id) + } + var session flow.LoginSession err = p.Connection(ctx).RawQuery( @@ -477,6 +482,36 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) (deletedS return &session, nil } +func (p *Persister) mySQLDeleteLoginSession(ctx context.Context, id string) (*flow.LoginSession, error) { + var session flow.LoginSession + + err := p.Connection(ctx).Transaction(func(tx *pop.Connection) error { + err := tx.RawQuery(` +SELECT * FROM hydra_oauth2_authentication_session +WHERE id = ? AND nid = ?`, + id, + p.NetworkID(ctx), + ).First(&session) + if err != nil { + return err + } + + return p.Connection(ctx).RawQuery(` +DELETE FROM hydra_oauth2_authentication_session +WHERE id = ? AND nid = ?`, + id, + p.NetworkID(ctx), + ).Exec() + }) + + if err != nil { + return nil, sqlcon.HandleError(err) + } + + return &session, nil + +} + func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) (rs []flow.AcceptOAuth2ConsentRequest, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequests") defer span.End() diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index e3c72bd3856..b60ec326874 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -190,7 +190,7 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() { require.NoError(t, r.Persister().CreateLoginSession(s.t1, ls)) // Expects the login session to be confirmed in the correct context. - require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember)) + require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now().UTC(), ls.Subject, !ls.Remember)) actual := &flow.LoginSession{} require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID)) exp, _ := json.Marshal(ls) @@ -199,7 +199,7 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() { // Can't find the login session in the wrong context. require.ErrorIs(t, - r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember), + r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now().UTC(), ls.Subject, !ls.Remember), x.ErrNotFound, ) })