Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(persistence/sql): move connection to context to enable transactions #254

Merged
merged 2 commits into from
Feb 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/fsnotify/fsnotify"
"github.com/gobuffalo/packr/v2"

"github.com/ory/viper"
"github.com/ory/x/flagx"
"github.com/ory/x/viperx"
Expand Down
4 changes: 2 additions & 2 deletions courier/courier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func TestSMTP(t *testing.T) {
}

smtp, api := runTestSMTP(t)
t.Logf("SMTP URL: %s",smtp)
t.Logf("API URL: %s",api)
t.Logf("SMTP URL: %s", smtp)
t.Logf("API URL: %s", api)

conf, reg := internal.NewRegistryDefault(t)
viper.Set(configuration.ViperKeyCourierSMTPURL, smtp)
Expand Down
4 changes: 4 additions & 0 deletions persistence/reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"io"

"github.com/gobuffalo/pop/v5"

"github.com/ory/kratos/courier"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/errorx"
Expand Down Expand Up @@ -33,4 +35,6 @@ type Persister interface {
MigrationStatus(c context.Context, b io.Writer) error
MigrateDown(c context.Context, steps int) error
MigrateUp(c context.Context) error
GetConnection(ctx context.Context) *pop.Connection
Transaction(ctx context.Context, callback func(connection *pop.Connection) error) error
}
14 changes: 7 additions & 7 deletions persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,26 @@ func NewPersister(r persisterDependencies, conf configuration.Provider, c *pop.C
return &Persister{c: c, mb: m, cf: conf, r: r}, nil
}

func (p *Persister) MigrationStatus(c context.Context, w io.Writer) error {
func (p *Persister) MigrationStatus(ctx context.Context, w io.Writer) error {
return errors.WithStack(p.mb.Status(w))
}

func (p *Persister) MigrateDown(c context.Context, steps int) error {
func (p *Persister) MigrateDown(ctx context.Context, steps int) error {
return errors.WithStack(p.mb.Down(steps))
}

func (p *Persister) MigrateUp(c context.Context) error {
func (p *Persister) MigrateUp(ctx context.Context) error {
return errors.WithStack(p.mb.Up())
}

func (p *Persister) Close(c context.Context) error {
return errors.WithStack(p.c.Close())
func (p *Persister) Close(ctx context.Context) error {
return errors.WithStack(p.GetConnection(ctx).Close())
}

func (p *Persister) Ping(c context.Context) error {
func (p *Persister) Ping(ctx context.Context) error {
type pinger interface {
Ping() error
}

return errors.WithStack(p.c.Store.(pinger).Ping())
return errors.WithStack(p.GetConnection(ctx).Store.(pinger).Ping())
}
8 changes: 4 additions & 4 deletions persistence/sql/persister_courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ var _ courier.Persister = new(Persister)

func (p *Persister) AddMessage(ctx context.Context, m *courier.Message) error {
m.Status = courier.MessageStatusQueued
return sqlcon.HandleError(p.c.Create(m)) // do not create eager to avoid identity injection.
return sqlcon.HandleError(p.GetConnection(ctx).Create(m)) // do not create eager to avoid identity injection.
}

func (p *Persister) NextMessages(ctx context.Context, limit uint8) ([]courier.Message, error) {
var m []courier.Message
if err := p.c.
if err := p.GetConnection(ctx).
Eager().
Where("status != ?", courier.MessageStatusSent).
Order("created_at ASC").Limit(int(limit)).All(&m); err != nil {
Expand All @@ -40,7 +40,7 @@ func (p *Persister) NextMessages(ctx context.Context, limit uint8) ([]courier.Me

func (p *Persister) LatestQueuedMessage(ctx context.Context) (*courier.Message, error) {
var m courier.Message
if err := p.c.
if err := p.GetConnection(ctx).
Eager().
Where("status != ?", courier.MessageStatusSent).
Order("created_at DESC").First(&m); err != nil {
Expand All @@ -54,7 +54,7 @@ func (p *Persister) LatestQueuedMessage(ctx context.Context) (*courier.Message,
}

func (p *Persister) SetMessageStatus(ctx context.Context, id uuid.UUID, ms courier.MessageStatus) error {
count, err := p.c.RawQuery("UPDATE courier_messages SET status = ? WHERE id = ?", ms, id).ExecWithCount()
count, err := p.GetConnection(ctx).RawQuery("UPDATE courier_messages SET status = ? WHERE id = ?", ms, id).ExecWithCount()
if err != nil {
return sqlcon.HandleError(err)
}
Expand Down
10 changes: 5 additions & 5 deletions persistence/sql/persister_errorx.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (p *Persister) Add(ctx context.Context, csrfToken string, errs ...error) (u
WasSeen: false,
}

if err := p.c.Create(c); err != nil {
if err := p.GetConnection(ctx).Create(c); err != nil {
return uuid.Nil, sqlcon.HandleError(err)
}

Expand All @@ -42,11 +42,11 @@ func (p *Persister) Add(ctx context.Context, csrfToken string, errs ...error) (u

func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContainer, error) {
var ec errorx.ErrorContainer
if err := p.c.Find(&ec, id); err != nil {
if err := p.GetConnection(ctx).Find(&ec, id); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := p.c.RawQuery("UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ?", time.Now().UTC(), id).Exec(); err != nil {
if err := p.GetConnection(ctx).RawQuery("UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ?", time.Now().UTC(), id).Exec(); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -55,9 +55,9 @@ func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContai

func (p *Persister) Clear(ctx context.Context, olderThan time.Duration, force bool) (err error) {
if force {
err = p.c.RawQuery("DELETE FROM selfservice_errors WHERE seen_at < ?", olderThan).Exec()
err = p.GetConnection(ctx).RawQuery("DELETE FROM selfservice_errors WHERE seen_at < ?", olderThan).Exec()
} else {
err = p.c.RawQuery("DELETE FROM selfservice_errors WHERE was_seen=true AND seen_at < ?", time.Now().UTC().Add(-olderThan)).Exec()
err = p.GetConnection(ctx).RawQuery("DELETE FROM selfservice_errors WHERE was_seen=true AND seen_at < ?", time.Now().UTC().Add(-olderThan)).Exec()
}

return sqlcon.HandleError(err)
Expand Down
28 changes: 14 additions & 14 deletions persistence/sql/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ var _ identity.PrivilegedPool = new(Persister)

func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity.CredentialsType, match string) (*identity.Identity, *identity.Credentials, error) {
var cts []identity.CredentialsTypeTable
if err := p.c.All(&cts); err != nil {
if err := p.GetConnection(ctx).All(&cts); err != nil {
return nil, nil, sqlcon.HandleError(err)
}

var find struct {
IdentityID uuid.UUID `db:"identity_id"`
}

if err := p.c.RawQuery(`SELECT
if err := p.GetConnection(ctx).RawQuery(`SELECT
ic.identity_id
FROM identity_credentials ic
INNER JOIN identity_credential_types ict on ic.identity_credential_type_id = ict.id
Expand Down Expand Up @@ -148,7 +148,7 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
return err
}

return sqlcon.HandleError(p.c.Transaction(func(tx *pop.Connection) error {
return sqlcon.HandleError(p.GetConnection(ctx).Transaction(func(tx *pop.Connection) error {
if err := tx.Create(i); err != nil {
return err
}
Expand All @@ -165,7 +165,7 @@ func (p *Persister) ListIdentities(ctx context.Context, limit, offset int) ([]id
is := make([]identity.Identity, 0)

/* #nosec G201 TableName is static */
if err := sqlcon.HandleError(p.c.RawQuery(fmt.Sprintf("SELECT * FROM %s LIMIT ? OFFSET ?", new(identity.Identity).TableName()), limit, offset).All(&is)); err != nil {
if err := sqlcon.HandleError(p.GetConnection(ctx).RawQuery(fmt.Sprintf("SELECT * FROM %s LIMIT ? OFFSET ?", new(identity.Identity).TableName()), limit, offset).All(&is)); err != nil {
return nil, err
}

Expand All @@ -183,7 +183,7 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er
return err
}

return sqlcon.HandleError(p.c.Transaction(func(tx *pop.Connection) error {
return sqlcon.HandleError(p.GetConnection(ctx).Transaction(func(tx *pop.Connection) error {
if count, err := tx.Where("id = ?", i.ID).Count(i); err != nil {
return err
} else if count == 0 {
Expand Down Expand Up @@ -214,7 +214,7 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er

func (p *Persister) DeleteIdentity(ctx context.Context, id uuid.UUID) error {
/* #nosec G201 TableName is static */
count, err := p.c.RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ?", new(identity.Identity).TableName()), id).ExecWithCount()
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ?", new(identity.Identity).TableName()), id).ExecWithCount()
if err != nil {
return sqlcon.HandleError(err)
}
Expand All @@ -226,7 +226,7 @@ func (p *Persister) DeleteIdentity(ctx context.Context, id uuid.UUID) error {

func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID) (*identity.Identity, error) {
var i identity.Identity
if err := p.c.Find(&i, id); err != nil {
if err := p.GetConnection(ctx).Find(&i, id); err != nil {
return nil, sqlcon.HandleError(err)
}
i.Credentials = nil
Expand All @@ -239,19 +239,19 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID) (*identity.Id

func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (*identity.Identity, error) {
var i identity.Identity
if err := p.c.Eager().Find(&i, id); err != nil {
if err := p.GetConnection(ctx).Eager().Find(&i, id); err != nil {
return nil, sqlcon.HandleError(err)
}

var cts []identity.CredentialsTypeTable
if err := p.c.All(&cts); err != nil {
if err := p.GetConnection(ctx).All(&cts); err != nil {
return nil, sqlcon.HandleError(err)
}

i.Credentials = map[identity.CredentialsType]identity.Credentials{}
for _, creds := range i.CredentialsCollection {
var cs identity.CredentialIdentifierCollection
if err := p.c.Where("identity_credential_id = ?", creds.ID).All(&cs); err != nil {
if err := p.GetConnection(ctx).Where("identity_credential_id = ?", creds.ID).All(&cs); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -277,7 +277,7 @@ func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (

func (p *Persister) FindAddressByCode(ctx context.Context, code string) (*identity.VerifiableAddress, error) {
var address identity.VerifiableAddress
if err := p.c.Where("code = ?", code).First(&address); err != nil {
if err := p.GetConnection(ctx).Where("code = ?", code).First(&address); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -286,7 +286,7 @@ func (p *Persister) FindAddressByCode(ctx context.Context, code string) (*identi

func (p *Persister) FindAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (*identity.VerifiableAddress, error) {
var address identity.VerifiableAddress
if err := p.c.Where("via = ? AND value = ?", via, value).First(&address); err != nil {
if err := p.GetConnection(ctx).Where("via = ? AND value = ?", via, value).First(&address); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -299,7 +299,7 @@ func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
return err
}

return sqlcon.HandleError(p.c.RawQuery(
return sqlcon.HandleError(p.GetConnection(ctx).RawQuery(
/* #nosec G201 TableName is static */
fmt.Sprintf(
"UPDATE %s SET status = ?, verified = true, verified_at = ?, code = ? WHERE code = ?",
Expand All @@ -313,7 +313,7 @@ func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
}

func (p *Persister) UpdateVerifiableAddress(ctx context.Context, address *identity.VerifiableAddress) error {
return sqlcon.HandleError(p.c.Update(address))
return sqlcon.HandleError(p.GetConnection(ctx).Update(address))
}

func (p *Persister) validateIdentity(i *identity.Identity) error {
Expand Down
40 changes: 23 additions & 17 deletions persistence/sql/persister_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sql
import (
"context"

"github.com/gobuffalo/pop/v5"

"github.com/gofrs/uuid"

"github.com/ory/x/sqlcon"
Expand All @@ -14,35 +16,39 @@ import (
var _ login.RequestPersister = new(Persister)

func (p *Persister) CreateLoginRequest(ctx context.Context, r *login.Request) error {
return p.c.Eager().Create(r)
return p.GetConnection(ctx).Eager().Create(r)
}

func (p *Persister) GetLoginRequest(ctx context.Context, id uuid.UUID) (*login.Request, error) {
conn := p.GetConnection(ctx)
var r login.Request
if err := p.c.Eager().Find(&r, id); err != nil {
if err := conn.Eager().Find(&r, id); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := (&r).AfterFind(p.c); err != nil {
if err := (&r).AfterFind(conn); err != nil {
return nil, err
}

return &r, nil
}

func (p *Persister) UpdateLoginRequest(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.RequestMethod) error {
rr, err := p.GetLoginRequest(ctx, id)
if err != nil {
return err
}

method, ok := rr.Methods[ct]
if !ok {
rm.RequestID = rr.ID
rm.Method = ct
return p.c.Save(rm)
}

method.Config = rm.Config
return p.c.Save(method)
return p.Transaction(ctx, func(tx *pop.Connection) error {
ctx := WithTransaction(ctx, tx)
rr, err := p.GetLoginRequest(ctx, id)
if err != nil {
return err
}

method, ok := rr.Methods[ct]
if !ok {
rm.RequestID = rr.ID
rm.Method = ct
return tx.Save(rm)
}

method.Config = rm.Config
return tx.Save(method)
})
}
6 changes: 3 additions & 3 deletions persistence/sql/persister_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ var _ profile.RequestPersister = new(Persister)

func (p *Persister) CreateProfileRequest(ctx context.Context, r *profile.Request) error {
r.IdentityID = r.Identity.ID
return sqlcon.HandleError(p.c.Create(r)) // This must not be eager or identities will be created / updated
return sqlcon.HandleError(p.GetConnection(ctx).Create(r)) // This must not be eager or identities will be created / updated
}

func (p *Persister) GetProfileRequest(ctx context.Context, id uuid.UUID) (*profile.Request, error) {
var r profile.Request
if err := p.c.Eager().Find(&r, id); err != nil {
if err := p.GetConnection(ctx).Eager().Find(&r, id); err != nil {
return nil, sqlcon.HandleError(err)
}
return &r, nil
}

func (p *Persister) UpdateProfileRequest(ctx context.Context, r *profile.Request) error {
return sqlcon.HandleError(p.c.Update(r)) // This must not be eager or identities will be created / updated
return sqlcon.HandleError(p.GetConnection(ctx).Update(r)) // This must not be eager or identities will be created / updated
}
10 changes: 5 additions & 5 deletions persistence/sql/persister_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
)

func (p *Persister) CreateRegistrationRequest(ctx context.Context, r *registration.Request) error {
return p.c.Eager().Create(r)
return p.GetConnection(ctx).Eager().Create(r)
}

func (p *Persister) GetRegistrationRequest(ctx context.Context, id uuid.UUID) (*registration.Request, error) {
var r registration.Request
if err := p.c.Eager().Find(&r, id); err != nil {
if err := p.GetConnection(ctx).Eager().Find(&r, id); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := (&r).AfterFind(p.c); err != nil {
if err := (&r).AfterFind(p.GetConnection(ctx)); err != nil {
return nil, err
}

Expand All @@ -38,9 +38,9 @@ func (p *Persister) UpdateRegistrationRequest(ctx context.Context, id uuid.UUID,
if !ok {
rm.RequestID = rr.ID
rm.Method = ct
return p.c.Save(rm)
return p.GetConnection(ctx).Save(rm)
}

method.Config = rm.Config
return p.c.Save(method)
return p.GetConnection(ctx).Save(method)
}
Loading