Skip to content

Commit

Permalink
test: deflake and parallelize persister tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Jun 14, 2024
1 parent b192c92 commit b5fb0c4
Show file tree
Hide file tree
Showing 20 changed files with 179 additions and 170 deletions.
39 changes: 14 additions & 25 deletions driver/config/test_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ package config

import (
"context"
"strings"

"github.com/knadh/koanf/maps"

"github.com/ory/kratos/embedx"
"github.com/ory/x/configx"
Expand All @@ -19,8 +16,7 @@ type (
contextx.Contextualizer
Options []configx.OptionModifier
}
contextKey int
mapProvider map[string]any
contextKey int
)

func (t *TestConfigProvider) NewProvider(ctx context.Context, opts ...configx.OptionModifier) (*configx.Provider, error) {
Expand All @@ -29,11 +25,15 @@ func (t *TestConfigProvider) NewProvider(ctx context.Context, opts ...configx.Op

func (t *TestConfigProvider) Config(ctx context.Context, config *configx.Provider) *configx.Provider {
config = t.Contextualizer.Config(ctx, config)
values, ok := ctx.Value(contextConfigKey).(mapProvider)
values, ok := ctx.Value(contextConfigKey).([]map[string]any)
if !ok {
return config
}
config, err := t.NewProvider(ctx, configx.WithValues(values))
opts := make([]configx.OptionModifier, 0, len(values))
for _, v := range values {
opts = append(opts, configx.WithValues(v))
}
config, err := t.NewProvider(ctx, opts...)
if err != nil {
// This is not production code. The provider is only used in tests.
panic(err)
Expand All @@ -51,25 +51,14 @@ func WithConfigValue(ctx context.Context, key string, value any) context.Context
return WithConfigValues(ctx, map[string]any{key: value})
}

func WithConfigValues(ctx context.Context, newValues map[string]any) context.Context {
values, ok := ctx.Value(contextConfigKey).(mapProvider)
func WithConfigValues(ctx context.Context, setValues map[string]any) context.Context {
values, ok := ctx.Value(contextConfigKey).([]map[string]any)
if !ok {
values = make(mapProvider)
}
expandedValues := make([]map[string]any, 0, len(newValues))
for k, v := range newValues {
parts := strings.Split(k, ".")
val := map[string]any{parts[len(parts)-1]: v}
if len(parts) > 1 {
for i := len(parts) - 2; i >= 0; i-- {
val = map[string]any{parts[i]: val}
}
}
expandedValues = append(expandedValues, val)
}
for _, v := range expandedValues {
maps.Merge(v, values)
values = make([]map[string]any, 0)
}
newValues := make([]map[string]any, len(values), len(values)+1)
copy(newValues, values)
newValues = append(newValues, setValues)

return context.WithValue(ctx, contextConfigKey, values)
return context.WithValue(ctx, contextConfigKey, newValues)
}
27 changes: 15 additions & 12 deletions identity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/ory/x/configx"
"github.com/ory/x/pointerx"
"github.com/ory/x/sqlcon"

Expand All @@ -29,12 +30,13 @@ import (
)

func TestManager(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/manager.schema.json")
extensionSchemaID := testhelpers.UseIdentitySchema(t, conf, "file://./stub/extension.schema.json")
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "https://www.ory.sh/")
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, "smtp://foo@bar@dev.null/")
conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationLoginHints, true)
conf, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(map[string]interface{}{
config.ViperKeyPublicBaseURL: "https://www.ory.sh/",
config.ViperKeyCourierSMTPURL: "smtp://foo@bar@dev.null/",
config.ViperKeySelfServiceRegistrationLoginHints: true,
}))
ctx := testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/manager.schema.json")
ctx, extensionSchemaID := testhelpers.WithAddIdentitySchema(ctx, t, conf, "file://./stub/extension.schema.json")

t.Run("case=should fail to create because validation fails", func(t *testing.T) {
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
Expand Down Expand Up @@ -682,12 +684,13 @@ func TestManager(t *testing.T) {
}

func TestManagerNoDefaultNamedSchema(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(ctx, config.ViperKeyDefaultIdentitySchemaID, "user_v0")
conf.MustSet(ctx, config.ViperKeyIdentitySchemas, config.Schemas{
{ID: "user_v0", URL: "file://./stub/manager.schema.json"},
})
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "https://www.ory.sh/")
_, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(map[string]interface{}{
config.ViperKeyDefaultIdentitySchemaID: "user_v0",
config.ViperKeyIdentitySchemas: config.Schemas{
{ID: "user_v0", URL: "file://./stub/manager.schema.json"},
},
config.ViperKeyPublicBaseURL: "https://www.ory.sh/",
}))

t.Run("case=should create identity with default schema", func(t *testing.T) {
stateChangedAt := sqlxx.NullTime(time.Now().UTC())
Expand Down
42 changes: 22 additions & 20 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@ import (
"github.com/ory/x/urlx"
)

func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, m *identity.Manager, dbname string) func(t *testing.T) {
func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, dbname string) func(t *testing.T) {
return func(t *testing.T) {
exampleServerURL := urlx.ParseOrPanic("http://example.com")
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, exampleServerURL.String())

nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p)

exampleServerURL := urlx.ParseOrPanic("http://example.com")
expandSchema := schema.Schema{
ID: "expandSchema",
URL: urlx.ParseOrPanic("file://./stub/expand.schema.json"),
Expand All @@ -62,22 +61,25 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
URL: urlx.ParseOrPanic("file://./stub/handler/multiple_emails.schema.json"),
RawURL: "file://./stub/identity-2.schema.json",
}
conf.MustSet(ctx, config.ViperKeyIdentitySchemas, []config.Schema{
{
ID: altSchema.ID,
URL: altSchema.RawURL,
},
{
ID: defaultSchema.ID,
URL: defaultSchema.RawURL,
},
{
ID: expandSchema.ID,
URL: expandSchema.RawURL,
},
{
ID: multipleEmailsSchema.ID,
URL: multipleEmailsSchema.RawURL,
ctx := config.WithConfigValues(ctx, map[string]any{
config.ViperKeyPublicBaseURL: exampleServerURL.String(),
config.ViperKeyIdentitySchemas: []config.Schema{
{
ID: altSchema.ID,
URL: altSchema.RawURL,
},
{
ID: defaultSchema.ID,
URL: defaultSchema.RawURL,
},
{
ID: expandSchema.ID,
URL: expandSchema.RawURL,
},
{
ID: multipleEmailsSchema.ID,
URL: multipleEmailsSchema.RawURL,
},
},
})

Expand Down
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
2 changes: 1 addition & 1 deletion internal/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func NewFastRegistryWithMocks(t *testing.T, opts ...configx.OptionModifier) (*co
func NewRegistryDefaultWithDSN(t testing.TB, dsn string, opts ...configx.OptionModifier) (*config.Config, *driver.RegistryDefault) {
ctx := context.Background()
c := NewConfigurationWithDefaults(t, append(opts, configx.WithValues(map[string]interface{}{
config.ViperKeyDSN: stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t)),
config.ViperKeyDSN: stringsx.Coalesce(dsn, dbal.NewSQLiteTestDatabase(t)+"&lock=false&max_conns=1"),
"dev": true,
}))...)
reg, err := driver.NewRegistryFromDSN(ctx, c, logrusx.New("", "", logrusx.ForceLevel(logrus.ErrorLevel)))
Expand Down
18 changes: 18 additions & 0 deletions persistence/sql/persister_cleanup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
)

func TestPersister_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
ctx := context.Background()
Expand All @@ -29,6 +31,8 @@ func TestPersister_Cleanup(t *testing.T) {
}

func TestPersister_Continuity_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand All @@ -45,6 +49,8 @@ func TestPersister_Continuity_Cleanup(t *testing.T) {
}

func TestPersister_Login_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand All @@ -61,6 +67,8 @@ func TestPersister_Login_Cleanup(t *testing.T) {
}

func TestPersister_Recovery_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand All @@ -77,6 +85,8 @@ func TestPersister_Recovery_Cleanup(t *testing.T) {
}

func TestPersister_Registration_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand All @@ -93,6 +103,8 @@ func TestPersister_Registration_Cleanup(t *testing.T) {
}

func TestPersister_Session_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand All @@ -109,6 +121,8 @@ func TestPersister_Session_Cleanup(t *testing.T) {
}

func TestPersister_Settings_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand All @@ -125,6 +139,8 @@ func TestPersister_Settings_Cleanup(t *testing.T) {
}

func TestPersister_Verification_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand All @@ -141,6 +157,8 @@ func TestPersister_Verification_Cleanup(t *testing.T) {
}

func TestPersister_SessionTokenExchange_Cleanup(t *testing.T) {
t.Parallel()

_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()
currentTime := time.Now()
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func useOneTimeCode[P any, U interface {

secrets:
for _, secret := range p.r.Config().SecretsSession(ctx) {
suppliedCode := []byte(p.hmacValueWithSecret(ctx, userProvidedCode, secret))
suppliedCode := []byte(hmacValueWithSecret(ctx, userProvidedCode, secret))
for i := range codes {
c := codes[i]
if subtle.ConstantTimeCompare([]byte(c.GetHMACCode()), suppliedCode) == 0 {
Expand Down
36 changes: 20 additions & 16 deletions persistence/sql/persister_errorx.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
package sql

import (
"bytes"
"context"
"encoding/json"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"

"github.com/ory/jsonschema/v3"

"github.com/ory/herodot"
"github.com/ory/jsonschema/v3"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"

Expand All @@ -28,7 +27,7 @@ func (p *Persister) CreateErrorContainer(ctx context.Context, csrfToken string,
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateErrorContainer")
defer otelx.End(span, &err)

message, err := p.encodeSelfServiceErrors(ctx, errs)
message, err := encodeSelfServiceErrors(errs)
if err != nil {
return uuid.Nil, err
}
Expand All @@ -55,14 +54,19 @@ func (p *Persister) ReadErrorContainer(ctx context.Context, id uuid.UUID) (_ *er
defer otelx.End(span, &err)

var ec errorx.ErrorContainer
if err := p.GetConnection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).First(&ec); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := p.GetConnection(ctx).RawQuery(
"UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?",
time.Now().UTC(), id, p.NetworkID(ctx)).Exec(); err != nil {
return nil, sqlcon.HandleError(err)
if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := c.Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).First(&ec); err != nil {
return sqlcon.HandleError(err)
}

if err := c.RawQuery(
"UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ? AND nid = ?",
time.Now().UTC(), id, p.NetworkID(ctx)).Exec(); err != nil {
return sqlcon.HandleError(err)
}
return nil
}); err != nil {
return nil, err
}

return &ec, nil
Expand All @@ -85,7 +89,7 @@ func (p *Persister) ClearErrorContainers(ctx context.Context, olderThan time.Dur
return sqlcon.HandleError(err)
}

func (p *Persister) encodeSelfServiceErrors(ctx context.Context, e error) ([]byte, error) {
func encodeSelfServiceErrors(e error) ([]byte, error) {
if e == nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithDebug("A nil error was passed to the error manager which is most likely a code bug."))
}
Expand All @@ -98,10 +102,10 @@ func (p *Persister) encodeSelfServiceErrors(ctx context.Context, e error) ([]byt
e = herodot.ToDefaultError(e, "")
}

var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(e); err != nil {
enc, err := json.Marshal(e)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to encode error messages.").WithDebug(err.Error()))
}

return b.Bytes(), nil
return enc, nil
}
18 changes: 5 additions & 13 deletions persistence/sql/persister_hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,19 @@ import (
"context"
"crypto/hmac"
"crypto/sha512"
"crypto/subtle"
"fmt"

"go.opentelemetry.io/otel/trace"
)

func (p *Persister) hmacValue(ctx context.Context, value string) string {
return p.hmacValueWithSecret(ctx, value, p.r.Config().SecretsSession(ctx)[0])
return hmacValueWithSecret(ctx, value, p.r.Config().SecretsSession(ctx)[0])
}

func (p *Persister) hmacValueWithSecret(ctx context.Context, value string, secret []byte) string {
_, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.hmacValueWithSecret")
func hmacValueWithSecret(ctx context.Context, value string, secret []byte) string {
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "persistence.sql.hmacValueWithSecret")

Check warning

Code scanning / CodeQL

Useless assignment to local variable Warning

This definition of ctx is never used.

Check failure on line 20 in persistence/sql/persister_hmac.go

View workflow job for this annotation

GitHub Actions / Run tests and lints

SA4006: this value of `ctx` is never used (staticcheck)
defer span.End()
h := hmac.New(sha512.New512_256, secret)
_, _ = h.Write([]byte(value))
return fmt.Sprintf("%x", h.Sum(nil))
}

func (p *Persister) hmacConstantCompare(ctx context.Context, value, hash string) bool {
for _, secret := range p.r.Config().SecretsSession(ctx) {
if subtle.ConstantTimeCompare([]byte(p.hmacValueWithSecret(ctx, value, secret)), []byte(hash)) == 1 {
return true
}
}
return false
}
Loading

0 comments on commit b5fb0c4

Please sign in to comment.