diff --git a/driver/config/test_config.go b/driver/config/test_config.go index 459ae15ac89c..c95fba7b7876 100644 --- a/driver/config/test_config.go +++ b/driver/config/test_config.go @@ -5,9 +5,6 @@ package config import ( "context" - "strings" - - "github.com/knadh/koanf/maps" "github.com/ory/kratos/embedx" "github.com/ory/x/configx" @@ -19,8 +16,7 @@ type ( contextx.Contextualizer Options []configx.OptionModifier } - contextKey int - mapProvider map[string]any + contextKey int ) func (t *TestConfigProvider) NewProvider(ctx context.Context, opts ...configx.OptionModifier) (*configx.Provider, error) { @@ -29,11 +25,15 @@ func (t *TestConfigProvider) NewProvider(ctx context.Context, opts ...configx.Op func (t *TestConfigProvider) Config(ctx context.Context, config *configx.Provider) *configx.Provider { config = t.Contextualizer.Config(ctx, config) - values, ok := ctx.Value(contextConfigKey).(mapProvider) + values, ok := ctx.Value(contextConfigKey).([]map[string]any) if !ok { return config } - config, err := t.NewProvider(ctx, configx.WithValues(values)) + opts := make([]configx.OptionModifier, 0, len(values)) + for _, v := range values { + opts = append(opts, configx.WithValues(v)) + } + config, err := t.NewProvider(ctx, opts...) if err != nil { // This is not production code. The provider is only used in tests. panic(err) @@ -51,25 +51,14 @@ func WithConfigValue(ctx context.Context, key string, value any) context.Context return WithConfigValues(ctx, map[string]any{key: value}) } -func WithConfigValues(ctx context.Context, newValues map[string]any) context.Context { - values, ok := ctx.Value(contextConfigKey).(mapProvider) +func WithConfigValues(ctx context.Context, setValues map[string]any) context.Context { + values, ok := ctx.Value(contextConfigKey).([]map[string]any) if !ok { - values = make(mapProvider) - } - expandedValues := make([]map[string]any, 0, len(newValues)) - for k, v := range newValues { - parts := strings.Split(k, ".") - val := map[string]any{parts[len(parts)-1]: v} - if len(parts) > 1 { - for i := len(parts) - 2; i >= 0; i-- { - val = map[string]any{parts[i]: val} - } - } - expandedValues = append(expandedValues, val) - } - for _, v := range expandedValues { - maps.Merge(v, values) + values = make([]map[string]any, 0) } + newValues := make([]map[string]any, len(values), len(values)+1) + copy(newValues, values) + newValues = append(newValues, setValues) - return context.WithValue(ctx, contextConfigKey, values) + return context.WithValue(ctx, contextConfigKey, newValues) } diff --git a/identity/manager_test.go b/identity/manager_test.go index e0346b8ee0c0..f7ad7d3b2da4 100644 --- a/identity/manager_test.go +++ b/identity/manager_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/ory/x/configx" "github.com/ory/x/pointerx" "github.com/ory/x/sqlcon" @@ -29,12 +30,13 @@ import ( ) func TestManager(t *testing.T) { - conf, reg := internal.NewFastRegistryWithMocks(t) - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/manager.schema.json") - extensionSchemaID := testhelpers.UseIdentitySchema(t, conf, "file://./stub/extension.schema.json") - conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "https://www.ory.sh/") - conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, "smtp://foo@bar@dev.null/") - conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationLoginHints, true) + conf, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(map[string]interface{}{ + config.ViperKeyPublicBaseURL: "https://www.ory.sh/", + config.ViperKeyCourierSMTPURL: "smtp://foo@bar@dev.null/", + config.ViperKeySelfServiceRegistrationLoginHints: true, + })) + ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/manager.schema.json") + ctx, extensionSchemaID := testhelpers.WithAddIdentitySchema(ctx, t, conf, "file://./stub/extension.schema.json") t.Run("case=should fail to create because validation fails", func(t *testing.T) { i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) @@ -682,12 +684,13 @@ func TestManager(t *testing.T) { } func TestManagerNoDefaultNamedSchema(t *testing.T) { - conf, reg := internal.NewFastRegistryWithMocks(t) - conf.MustSet(ctx, config.ViperKeyDefaultIdentitySchemaID, "user_v0") - conf.MustSet(ctx, config.ViperKeyIdentitySchemas, config.Schemas{ - {ID: "user_v0", URL: "file://./stub/manager.schema.json"}, - }) - conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "https://www.ory.sh/") + _, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(map[string]interface{}{ + config.ViperKeyDefaultIdentitySchemaID: "user_v0", + config.ViperKeyIdentitySchemas: config.Schemas{ + {ID: "user_v0", URL: "file://./stub/manager.schema.json"}, + }, + config.ViperKeyPublicBaseURL: "https://www.ory.sh/", + })) t.Run("case=should create identity with default schema", func(t *testing.T) { stateChangedAt := sqlxx.NullTime(time.Now().UTC()) diff --git a/identity/test/pool.go b/identity/test/pool.go index 450b5c1ea881..bf6d114b8510 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -36,12 +36,11 @@ import ( "github.com/ory/x/urlx" ) -func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, m *identity.Manager, dbname string) func(t *testing.T) { +func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, dbname string) func(t *testing.T) { return func(t *testing.T) { - exampleServerURL := urlx.ParseOrPanic("http://example.com") - conf.MustSet(ctx, config.ViperKeyPublicBaseURL, exampleServerURL.String()) - nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) + + exampleServerURL := urlx.ParseOrPanic("http://example.com") expandSchema := schema.Schema{ ID: "expandSchema", URL: urlx.ParseOrPanic("file://./stub/expand.schema.json"), @@ -62,22 +61,25 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, URL: urlx.ParseOrPanic("file://./stub/handler/multiple_emails.schema.json"), RawURL: "file://./stub/identity-2.schema.json", } - conf.MustSet(ctx, config.ViperKeyIdentitySchemas, []config.Schema{ - { - ID: altSchema.ID, - URL: altSchema.RawURL, - }, - { - ID: defaultSchema.ID, - URL: defaultSchema.RawURL, - }, - { - ID: expandSchema.ID, - URL: expandSchema.RawURL, - }, - { - ID: multipleEmailsSchema.ID, - URL: multipleEmailsSchema.RawURL, + ctx := config.WithConfigValues(ctx, map[string]any{ + config.ViperKeyPublicBaseURL: exampleServerURL.String(), + config.ViperKeyIdentitySchemas: []config.Schema{ + { + ID: altSchema.ID, + URL: altSchema.RawURL, + }, + { + ID: defaultSchema.ID, + URL: defaultSchema.RawURL, + }, + { + ID: expandSchema.ID, + URL: expandSchema.RawURL, + }, + { + ID: multipleEmailsSchema.ID, + URL: multipleEmailsSchema.RawURL, + }, }, }) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/driver.go b/internal/driver.go index 3499a83b5b9b..5d31cfe5ceb2 100644 --- a/internal/driver.go +++ b/internal/driver.go @@ -84,7 +84,7 @@ func NewFastRegistryWithMocks(t *testing.T, opts ...configx.OptionModifier) (*co func NewRegistryDefaultWithDSN(t testing.TB, dsn string, opts ...configx.OptionModifier) (*config.Config, *driver.RegistryDefault) { ctx := context.Background() c := NewConfigurationWithDefaults(t, append(opts, configx.WithValues(map[string]interface{}{ - config.ViperKeyDSN: stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t)), + config.ViperKeyDSN: stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t)+"&lock=false&max_conns=1"), "dev": true, }))...) reg, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("", "", logrusx.ForceLevel(logrus.ErrorLevel))) diff --git a/persistence/sql/persister_cleanup_test.go b/persistence/sql/persister_cleanup_test.go index 65e95ea6ea00..efb14a05e6c9 100644 --- a/persistence/sql/persister_cleanup_test.go +++ b/persistence/sql/persister_cleanup_test.go @@ -14,6 +14,8 @@ import ( ) func TestPersister_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() ctx := context.Background() @@ -29,6 +31,8 @@ func TestPersister_Cleanup(t *testing.T) { } func TestPersister_Continuity_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() @@ -45,6 +49,8 @@ func TestPersister_Continuity_Cleanup(t *testing.T) { } func TestPersister_Login_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() @@ -61,6 +67,8 @@ func TestPersister_Login_Cleanup(t *testing.T) { } func TestPersister_Recovery_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() @@ -77,6 +85,8 @@ func TestPersister_Recovery_Cleanup(t *testing.T) { } func TestPersister_Registration_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() @@ -93,6 +103,8 @@ func TestPersister_Registration_Cleanup(t *testing.T) { } func TestPersister_Session_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() @@ -109,6 +121,8 @@ func TestPersister_Session_Cleanup(t *testing.T) { } func TestPersister_Settings_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() @@ -125,6 +139,8 @@ func TestPersister_Settings_Cleanup(t *testing.T) { } func TestPersister_Verification_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() @@ -141,6 +157,8 @@ func TestPersister_Verification_Cleanup(t *testing.T) { } func TestPersister_SessionTokenExchange_Cleanup(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() currentTime := time.Now() diff --git a/persistence/sql/persister_code.go b/persistence/sql/persister_code.go index 31e0b80dc2d2..ece7dea75ec3 100644 --- a/persistence/sql/persister_code.go +++ b/persistence/sql/persister_code.go @@ -94,7 +94,7 @@ func useOneTimeCode[P any, U interface { secrets: for _, secret := range p.r.Config().SecretsSession(ctx) { - suppliedCode := []byte(p.hmacValueWithSecret(ctx, userProvidedCode, secret)) + suppliedCode := []byte(hmacValueWithSecret(ctx, userProvidedCode, secret)) for i := range codes { c := codes[i] if subtle.ConstantTimeCompare([]byte(c.GetHMACCode()), suppliedCode) == 0 { diff --git a/persistence/sql/persister_errorx.go b/persistence/sql/persister_errorx.go index 15faf9fd163b..fc656074f0fc 100644 --- a/persistence/sql/persister_errorx.go +++ b/persistence/sql/persister_errorx.go @@ -4,18 +4,17 @@ package sql import ( - "bytes" "context" "encoding/json" "time" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" "go.opentelemetry.io/otel/attribute" - "github.com/ory/jsonschema/v3" - "github.com/ory/herodot" + "github.com/ory/jsonschema/v3" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" @@ -28,7 +27,7 @@ func (p *Persister) CreateErrorContainer(ctx context.Context, csrfToken string, ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateErrorContainer") defer otelx.End(span, &err) - message, err := p.encodeSelfServiceErrors(ctx, errs) + message, err := encodeSelfServiceErrors(errs) if err != nil { return uuid.Nil, err } @@ -55,14 +54,19 @@ func (p *Persister) ReadErrorContainer(ctx context.Context, id uuid.UUID) (_ *er defer otelx.End(span, &err) var ec errorx.ErrorContainer - if err := p.GetConnection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).First(&ec); err != nil { - return nil, sqlcon.HandleError(err) - } - - if err := p.GetConnection(ctx).RawQuery( - "UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?", - time.Now().UTC(), id, p.NetworkID(ctx)).Exec(); err != nil { - return nil, sqlcon.HandleError(err) + if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + if err := c.Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).First(&ec); err != nil { + return sqlcon.HandleError(err) + } + + if err := c.RawQuery( + "UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?", + time.Now().UTC(), id, p.NetworkID(ctx)).Exec(); err != nil { + return sqlcon.HandleError(err) + } + return nil + }); err != nil { + return nil, err } return &ec, nil @@ -85,7 +89,7 @@ func (p *Persister) ClearErrorContainers(ctx context.Context, olderThan time.Dur return sqlcon.HandleError(err) } -func (p *Persister) encodeSelfServiceErrors(ctx context.Context, e error) ([]byte, error) { +func encodeSelfServiceErrors(e error) ([]byte, error) { if e == nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithDebug("A nil error was passed to the error manager which is most likely a code bug.")) } @@ -98,10 +102,10 @@ func (p *Persister) encodeSelfServiceErrors(ctx context.Context, e error) ([]byt e = herodot.ToDefaultError(e, "") } - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(e); err != nil { + enc, err := json.Marshal(e) + if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to encode error messages.").WithDebug(err.Error())) } - return b.Bytes(), nil + return enc, nil } diff --git a/persistence/sql/persister_hmac.go b/persistence/sql/persister_hmac.go index 9c4d6636f14c..5fc964d676b6 100644 --- a/persistence/sql/persister_hmac.go +++ b/persistence/sql/persister_hmac.go @@ -7,27 +7,19 @@ import ( "context" "crypto/hmac" "crypto/sha512" - "crypto/subtle" "fmt" + + "go.opentelemetry.io/otel/trace" ) func (p *Persister) hmacValue(ctx context.Context, value string) string { - return p.hmacValueWithSecret(ctx, value, p.r.Config().SecretsSession(ctx)[0]) + return hmacValueWithSecret(ctx, value, p.r.Config().SecretsSession(ctx)[0]) } -func (p *Persister) hmacValueWithSecret(ctx context.Context, value string, secret []byte) string { - _, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.hmacValueWithSecret") +func hmacValueWithSecret(ctx context.Context, value string, secret []byte) string { + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "persistence.sql.hmacValueWithSecret") defer span.End() h := hmac.New(sha512.New512_256, secret) _, _ = h.Write([]byte(value)) return fmt.Sprintf("%x", h.Sum(nil)) } - -func (p *Persister) hmacConstantCompare(ctx context.Context, value, hash string) bool { - for _, secret := range p.r.Config().SecretsSession(ctx) { - if subtle.ConstantTimeCompare([]byte(p.hmacValueWithSecret(ctx, value, secret)), []byte(hash)) == 1 { - return true - } - } - return false -} diff --git a/persistence/sql/persister_hmac_test.go b/persistence/sql/persister_hmac_test.go index 7b8cc8575368..c569affa0cd9 100644 --- a/persistence/sql/persister_hmac_test.go +++ b/persistence/sql/persister_hmac_test.go @@ -8,9 +8,10 @@ import ( "os" "testing" + "github.com/ory/x/configx" + "github.com/ory/x/contextx" - "github.com/ory/x/configx" "github.com/ory/x/otelx" "github.com/gobuffalo/pop/v6" @@ -49,10 +50,10 @@ func (l *logRegistryOnly) Audit() *logrusx.Logger { panic("implement me") } -func (l *logRegistryOnly) Tracer(ctx context.Context) *otelx.Tracer { +func (l *logRegistryOnly) Tracer(context.Context) *otelx.Tracer { return otelx.NewNoop(l.l, new(otelx.Config)) } -func (l *logRegistryOnly) IdentityTraitsSchemas(ctx context.Context) (schema.IdentitySchemaList, error) { +func (l *logRegistryOnly) IdentityTraitsSchemas(context.Context) (schema.IdentitySchemaList, error) { panic("implement me") } @@ -63,25 +64,36 @@ func (l *logRegistryOnly) IdentityValidator() *identity.Validator { var _ persisterDependencies = &logRegistryOnly{} func TestPersisterHMAC(t *testing.T) { + t.Parallel() + ctx := context.Background() - conf := config.MustNew(t, logrusx.New("", ""), os.Stderr, &contextx.Default{}, configx.SkipValidation()) - conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"foobarbaz"}) + baseSecret := "foobarbaz" + baseSecretBytes := []byte(baseSecret) + opts := []configx.OptionModifier{configx.SkipValidation(), configx.WithValue(config.ViperKeySecretsDefault, []string{baseSecret})} + conf := config.MustNew(t, logrusx.New("", ""), os.Stderr, &config.TestConfigProvider{Contextualizer: &contextx.Default{}, Options: opts}, opts...) c, err := pop.NewConnection(&pop.ConnectionDetails{URL: "sqlite://foo?mode=memory"}) require.NoError(t, err) - p, err := NewPersister(context.Background(), &logRegistryOnly{c: conf}, c) + p, err := NewPersister(ctx, &logRegistryOnly{c: conf}, c) require.NoError(t, err) - assert.True(t, p.hmacConstantCompare(context.Background(), "hashme", p.hmacValue(context.Background(), "hashme"))) - assert.False(t, p.hmacConstantCompare(context.Background(), "notme", p.hmacValue(context.Background(), "hashme"))) - assert.False(t, p.hmacConstantCompare(context.Background(), "hashme", p.hmacValue(context.Background(), "notme"))) - - hash := p.hmacValue(context.Background(), "hashme") - conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"notfoobarbaz"}) - assert.False(t, p.hmacConstantCompare(context.Background(), "hashme", hash)) - assert.True(t, p.hmacConstantCompare(context.Background(), "hashme", p.hmacValue(context.Background(), "hashme"))) - - conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"notfoobarbaz", "foobarbaz"}) - assert.True(t, p.hmacConstantCompare(context.Background(), "hashme", hash)) - assert.True(t, p.hmacConstantCompare(context.Background(), "hashme", p.hmacValue(context.Background(), "hashme"))) - assert.NotEqual(t, hash, p.hmacValue(context.Background(), "hashme")) + t.Run("case=behaves deterministically", func(t *testing.T) { + assert.Equal(t, hmacValueWithSecret(ctx, "hashme", baseSecretBytes), p.hmacValue(ctx, "hashme")) + assert.NotEqual(t, hmacValueWithSecret(ctx, "notme", baseSecretBytes), p.hmacValue(ctx, "hashme")) + assert.NotEqual(t, hmacValueWithSecret(ctx, "hashme", baseSecretBytes), p.hmacValue(ctx, "notme")) + }) + + hash := p.hmacValue(ctx, "hashme") + newSecret := "not" + baseSecret + + t.Run("case=with only new sectet", func(t *testing.T) { + ctx = config.WithConfigValue(ctx, config.ViperKeySecretsDefault, []string{newSecret}) + assert.NotEqual(t, hmacValueWithSecret(ctx, "hashme", baseSecretBytes), p.hmacValue(ctx, "hashme")) + assert.Equal(t, hmacValueWithSecret(ctx, "hashme", []byte(newSecret)), p.hmacValue(ctx, "hashme")) + }) + + t.Run("case=with new and old secret", func(t *testing.T) { + ctx = config.WithConfigValue(ctx, config.ViperKeySecretsDefault, []string{newSecret, baseSecret}) + assert.Equal(t, hmacValueWithSecret(ctx, "hashme", []byte(newSecret)), p.hmacValue(ctx, "hashme")) + assert.NotEqual(t, hash, p.hmacValue(ctx, "hashme")) + }) } diff --git a/persistence/sql/persister_recovery.go b/persistence/sql/persister_recovery.go index 468ba5a2b144..bb23d3fd319e 100644 --- a/persistence/sql/persister_recovery.go +++ b/persistence/sql/persister_recovery.go @@ -82,7 +82,7 @@ func (p *Persister) UseRecoveryToken(ctx context.Context, fID uuid.UUID, token s nid := p.NetworkID(ctx) if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { for _, secret := range p.r.Config().SecretsSession(ctx) { - if err = tx.Where("token = ? AND nid = ? AND NOT used AND selfservice_recovery_flow_id = ?", p.hmacValueWithSecret(ctx, token, secret), nid, fID).First(&rt); err != nil { + if err = tx.Where("token = ? AND nid = ? AND NOT used AND selfservice_recovery_flow_id = ?", hmacValueWithSecret(ctx, token, secret), nid, fID).First(&rt); err != nil { if !errors.Is(sqlcon.HandleError(err), sqlcon.ErrNoRows) { return err } diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index f88d7380a5c7..6a48e763d0f3 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -94,7 +94,7 @@ func pl(t testing.TB) func(lvl logging.Level, s string, args ...interface{}) { func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault { conns := map[string]string{ - "sqlite": "sqlite://file:" + t.TempDir() + "/db.sqlite?_fk=true", + "sqlite": "sqlite://file:" + t.TempDir() + "/db.sqlite?_fk=true&max_conns=1&lock=false", } var l sync.Mutex @@ -160,111 +160,104 @@ func createCleanDatabases(t testing.TB) map[string]*driver.RegistryDefault { } func TestPersister(t *testing.T) { + t.Parallel() + conns := createCleanDatabases(t) - ctx := context.Background() + ctx := testhelpers.WithDefaultIdentitySchema(context.Background(), "file://./stub/identity.schema.json") - for name := range conns { - name := name - reg := conns[name] + for name, reg := range conns { t.Run(fmt.Sprintf("database=%s", name), func(t *testing.T) { t.Parallel() _, p := testhelpers.NewNetwork(t, ctx, reg.Persister()) - conf := reg.Config() - t.Logf("DSN: %s", conf.DSN(ctx)) + t.Logf("DSN: %s", reg.Config().DSN(ctx)) - // This test must remain the first test in the test suite! t.Run("racy identity creation", func(t *testing.T) { - defaultSchema := schema.Schema{ - ID: config.DefaultIdentityTraitsSchemaID, - URL: urlx.ParseOrPanic("file://./stub/identity.schema.json"), - RawURL: "file://./stub/identity.schema.json", - } + t.Parallel() var wg sync.WaitGroup - testhelpers.SetDefaultIdentitySchema(reg.Config(), defaultSchema.RawURL) + _, ps := testhelpers.NewNetwork(t, ctx, reg.Persister()) - for i := 0; i < 10; i++ { + for i := range 10 { wg.Add(1) - // capture i - ii := i go func() { defer wg.Done() id := ri.NewIdentity("") id.SetCredentials(ri.CredentialsTypePassword, ri.Credentials{ Type: ri.CredentialsTypePassword, - Identifiers: []string{fmt.Sprintf("racy identity %d", ii)}, + Identifiers: []string{fmt.Sprintf("racy identity %d", i)}, Config: sqlxx.JSONRawMessage(`{"foo":"bar"}`), }) id.Traits = ri.Traits("{}") - require.NoError(t, ps.CreateIdentity(context.Background(), id)) + require.NoError(t, ps.CreateIdentity(ctx, id)) }() } wg.Wait() }) - t.Run("case=credentials types", func(t *testing.T) { + t.Run("case=credential types exist", func(t *testing.T) { + t.Parallel() for _, ct := range []ri.CredentialsType{ri.CredentialsTypeOIDC, ri.CredentialsTypePassword} { require.NoError(t, p.(*sql.Persister).Connection(context.Background()).Where("name = ?", ct).First(&ri.CredentialsTypeTable{})) } }) t.Run("contract=identity.TestPool", func(t *testing.T) { - pop.SetLogger(pl(t)) - identity.TestPool(ctx, conf, p, reg.IdentityManager(), name)(t) + t.Parallel() + identity.TestPool(ctx, p, reg.IdentityManager(), name)(t) }) t.Run("contract=registration.TestFlowPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) + t.Parallel() registration.TestFlowPersister(ctx, p)(t) }) t.Run("contract=errorx.TestPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) + t.Parallel() errorx.TestPersister(ctx, p)(t) }) t.Run("contract=login.TestFlowPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) + t.Parallel() login.TestFlowPersister(ctx, p)(t) }) t.Run("contract=settings.TestFlowPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) - settings.TestFlowPersister(ctx, conf, p)(t) + t.Parallel() + settings.TestFlowPersister(ctx, p)(t) }) t.Run("contract=session.TestPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) - session.TestPersister(ctx, conf, p)(t) + t.Parallel() + session.TestPersister(ctx, reg.Config(), p)(t) }) t.Run("contract=sessiontokenexchange.TestPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) - sessiontokenexchange.TestPersister(ctx, conf, p)(t) + t.Parallel() + sessiontokenexchange.TestPersister(ctx, p)(t) }) t.Run("contract=courier.TestPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) + t.Parallel() upsert, insert := sqltesthelpers.DefaultNetworkWrapper(p) courier.TestPersister(ctx, upsert, insert)(t) }) t.Run("contract=verification.TestFlowPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) - verification.TestFlowPersister(ctx, conf, p)(t) + t.Parallel() + verification.TestFlowPersister(ctx, p)(t) }) t.Run("contract=recovery.TestFlowPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) - recovery.TestFlowPersister(ctx, conf, p)(t) + t.Parallel() + recovery.TestFlowPersister(ctx, p)(t) }) t.Run("contract=link.TestPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) - link.TestPersister(ctx, conf, p)(t) + t.Parallel() + link.TestPersister(ctx, p)(t) }) t.Run("contract=code.TestPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) - code.TestPersister(ctx, conf, p)(t) + t.Parallel() + code.TestPersister(ctx, p)(t) }) t.Run("contract=continuity.TestPersister", func(t *testing.T) { - pop.SetLogger(pl(t)) + t.Parallel() continuity.TestPersister(ctx, p)(t) }) }) @@ -283,6 +276,8 @@ func getErr(args ...interface{}) error { } func TestPersister_Transaction(t *testing.T) { + t.Parallel() + _, reg := internal.NewFastRegistryWithMocks(t) p := reg.Persister() diff --git a/persistence/sql/persister_verification.go b/persistence/sql/persister_verification.go index 8d983ed1635d..7feae0592ae7 100644 --- a/persistence/sql/persister_verification.go +++ b/persistence/sql/persister_verification.go @@ -82,7 +82,7 @@ func (p *Persister) UseVerificationToken(ctx context.Context, fID uuid.UUID, tok nid := p.NetworkID(ctx) if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) { for _, secret := range p.r.Config().SecretsSession(ctx) { - if err = tx.Where("token = ? AND nid = ? AND NOT used AND selfservice_verification_flow_id = ?", p.hmacValueWithSecret(ctx, token, secret), nid, fID).First(&rt); err != nil { + if err = tx.Where("token = ? AND nid = ? AND NOT used AND selfservice_verification_flow_id = ?", hmacValueWithSecret(ctx, token, secret), nid, fID).First(&rt); err != nil { if !errors.Is(sqlcon.HandleError(err), sqlcon.ErrNoRows) { return err } diff --git a/selfservice/flow/recovery/test/persistence.go b/selfservice/flow/recovery/test/persistence.go index 8bc9efad88e0..caea28d8e43e 100644 --- a/selfservice/flow/recovery/test/persistence.go +++ b/selfservice/flow/recovery/test/persistence.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/persistence" "github.com/ory/kratos/selfservice/flow" @@ -23,7 +22,7 @@ import ( "github.com/ory/x/sqlcon" ) -func TestFlowPersister(ctx context.Context, conf *config.Config, p interface { +func TestFlowPersister(ctx context.Context, p interface { persistence.Persister }, ) func(t *testing.T) { @@ -33,7 +32,7 @@ func TestFlowPersister(ctx context.Context, conf *config.Config, p interface { return func(t *testing.T) { nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") + ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") t.Run("case=should error when the recovery request does not exist", func(t *testing.T) { _, err := p.GetRecoveryFlow(ctx, x.NewUUID()) diff --git a/selfservice/flow/settings/test/persistence.go b/selfservice/flow/settings/test/persistence.go index 85c80e49d74e..2af6c75aba6c 100644 --- a/selfservice/flow/settings/test/persistence.go +++ b/selfservice/flow/settings/test/persistence.go @@ -27,7 +27,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/x" ) @@ -37,11 +36,11 @@ func clearids(r *settings.Flow) { r.IdentityID = uuid.Nil } -func TestFlowPersister(ctx context.Context, conf *config.Config, p persistence.Persister) func(t *testing.T) { +func TestFlowPersister(ctx context.Context, p persistence.Persister) func(t *testing.T) { return func(t *testing.T) { _, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") + ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") t.Run("case=should error when the settings request does not exist", func(t *testing.T) { _, err := p.GetSettingsFlow(ctx, x.NewUUID()) diff --git a/selfservice/flow/verification/test/persistence.go b/selfservice/flow/verification/test/persistence.go index 57c35cba8d2e..825a5129f757 100644 --- a/selfservice/flow/verification/test/persistence.go +++ b/selfservice/flow/verification/test/persistence.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/persistence" "github.com/ory/kratos/selfservice/flow" @@ -23,7 +22,7 @@ import ( "github.com/ory/x/sqlcon" ) -func TestFlowPersister(ctx context.Context, conf *config.Config, p interface { +func TestFlowPersister(ctx context.Context, p interface { persistence.Persister }, ) func(t *testing.T) { @@ -34,7 +33,7 @@ func TestFlowPersister(ctx context.Context, conf *config.Config, p interface { return func(t *testing.T) { nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") + ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") t.Run("case=should error when the verification request does not exist", func(t *testing.T) { _, err := p.GetVerificationFlow(ctx, x.NewUUID()) diff --git a/selfservice/sessiontokenexchange/test/persistence.go b/selfservice/sessiontokenexchange/test/persistence.go index 53db63db04f3..da19c3edc1a3 100644 --- a/selfservice/sessiontokenexchange/test/persistence.go +++ b/selfservice/sessiontokenexchange/test/persistence.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/persistence" "github.com/ory/kratos/selfservice/sessiontokenexchange" @@ -36,11 +35,10 @@ func (t *testParams) setCodes(e *sessiontokenexchange.Exchanger) { t.returnToCode = e.ReturnToCode } -func TestPersister(ctx context.Context, _ *config.Config, p interface { +func TestPersister(ctx context.Context, p interface { persistence.Persister }) func(t *testing.T) { return func(t *testing.T) { - t.Parallel() nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) t.Run("suite=create-update-get", func(t *testing.T) { diff --git a/selfservice/strategy/code/test/persistence.go b/selfservice/strategy/code/test/persistence.go index f3c120402ddb..2e0ce7d00ca9 100644 --- a/selfservice/strategy/code/test/persistence.go +++ b/selfservice/strategy/code/test/persistence.go @@ -24,15 +24,15 @@ import ( "github.com/ory/kratos/x" ) -func TestPersister(ctx context.Context, conf *config.Config, p interface { +func TestPersister(ctx context.Context, p interface { persistence.Persister }, ) func(t *testing.T) { return func(t *testing.T) { nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") - conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"secret-a", "secret-b"}) + ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") + ctx = config.WithConfigValue(ctx, config.ViperKeySecretsDefault, []string{"secret-a", "secret-b"}) t.Run("code=recovery", func(t *testing.T) { newRecoveryCodeDTO := func(t *testing.T, email string) (*code.CreateRecoveryCodeParams, *recovery.Flow, *identity.RecoveryAddress) { diff --git a/selfservice/strategy/link/test/persistence.go b/selfservice/strategy/link/test/persistence.go index af5738eaae31..5207a810d75b 100644 --- a/selfservice/strategy/link/test/persistence.go +++ b/selfservice/strategy/link/test/persistence.go @@ -28,15 +28,15 @@ import ( "github.com/ory/kratos/x" ) -func TestPersister(ctx context.Context, conf *config.Config, p interface { +func TestPersister(ctx context.Context, p interface { persistence.Persister }, ) func(t *testing.T) { return func(t *testing.T) { nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") - conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"secret-a", "secret-b"}) + ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") + ctx = config.WithConfigValue(ctx, config.ViperKeySecretsDefault, []string{"secret-a", "secret-b"}) t.Run("token=recovery", func(t *testing.T) { newRecoveryToken := func(t *testing.T, email string) (*link.RecoveryToken, *recovery.Flow) { diff --git a/session/test/persistence.go b/session/test/persistence.go index 0db6964468d8..0b709a3866b8 100644 --- a/session/test/persistence.go +++ b/session/test/persistence.go @@ -42,8 +42,6 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { return func(t *testing.T) { _, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p) - ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") - t.Run("case=not found", func(t *testing.T) { _, err := p.GetSession(ctx, x.NewUUID(), session.ExpandNothing) require.Error(t, err)