Skip to content

Commit

Permalink
refactor: complete login flow refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Aug 25, 2020
1 parent 88f581f commit ad2b3db
Show file tree
Hide file tree
Showing 26 changed files with 185 additions and 94 deletions.
9 changes: 8 additions & 1 deletion .schema/api.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1125,8 +1125,10 @@
"parameters": [
{
"type": "string",
"description": "Error is the container's ID",
"name": "error",
"in": "query"
"in": "query",
"required": true
}
],
"responses": {
Expand Down Expand Up @@ -1558,8 +1560,13 @@
},
"errorContainer": {
"type": "object",
"required": [
"id",
"errors"
],
"properties": {
"errors": {
"description": "Errors in the container",
"type": "object"
},
"id": {
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ func (m *RegistryDefault) RecoveryRequestPersister() recovery.RequestPersister {
return m.persister
}

func (m *RegistryDefault) LoginRequestPersister() login.RequestPersister {
func (m *RegistryDefault) LoginFlowPersister() login.FlowPersister {
return m.persister
}

Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (m *RegistryDefault) LoginHandler() *login.Handler {

func (m *RegistryDefault) LoginRequestErrorHandler() *login.ErrorHandler {
if m.selfserviceLoginRequestErrorHandler == nil {
m.selfserviceLoginRequestErrorHandler = login.NewErrorHandler(m, m.c)
m.selfserviceLoginRequestErrorHandler = login.NewFlowErrorHandler(m, m.c)
}

return m.selfserviceLoginRequestErrorHandler
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0
github.com/armon/go-metrics v0.3.3 // indirect
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869
github.com/bxcodec/faker/v3 v3.3.1
github.com/cenkalti/backoff v2.2.1+incompatible
github.com/coreos/go-oidc v2.2.1+incompatible
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 18 additions & 6 deletions internal/httpclient/models/error_container.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions internal/testhelpers/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ import (
"github.com/stretchr/testify/require"
)

func NewHTTPGetRequestJSON(t *testing.T, url string) *http.Request {
req, err := http.NewRequest("GET", url, nil)
require.NoError(t, err)

// req.Header.Set("Content-Type", "application/json;charset=utf-8")
req.Header.Set("Accept", "application/json")
return req
}

func HTTPRequestJSON(t *testing.T, client *http.Client, method string, url string, in interface{}) ([]byte, *http.Response) {
var body bytes.Buffer
require.NoError(t, json.NewEncoder(&body).Encode(in))
Expand Down
2 changes: 1 addition & 1 deletion internal/testhelpers/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

func NewLoginUIRequestEchoServer(t *testing.T, reg driver.Registry) *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
e, err := reg.LoginRequestPersister().GetLoginRequest(r.Context(), x.ParseUUID(r.URL.Query().Get("flow")))
e, err := reg.LoginFlowPersister().GetLoginFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("flow")))
require.NoError(t, err)
reg.Writer().Write(w, r, e)
}))
Expand Down
2 changes: 1 addition & 1 deletion persistence/reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Persister interface {
continuity.Persister
identity.PrivilegedPool
registration.RequestPersister
login.RequestPersister
login.FlowPersister
settings.RequestPersister
courier.Persister
session.Persister
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestMigrations(t *testing.T) {
require.NoError(t, c.Select("id").All(&ids))

for _, id := range ids {
actual, err := d.Registry().LoginRequestPersister().GetLoginRequest(context.Background(), id.ID)
actual, err := d.Registry().LoginFlowPersister().GetLoginFlow(context.Background(), id.ID)
require.NoError(t, err)
compareWithFixture(t, actual, "login_request", id.ID.String())
}
Expand Down
18 changes: 9 additions & 9 deletions persistence/sql/persister_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ import (
"github.com/ory/kratos/selfservice/flow/login"
)

var _ login.RequestPersister = new(Persister)
var _ login.FlowPersister = new(Persister)

func (p *Persister) CreateLoginRequest(ctx context.Context, r *login.Flow) error {
func (p *Persister) CreateLoginFlow(ctx context.Context, r *login.Flow) error {
return p.GetConnection(ctx).Eager().Create(r)
}

func (p *Persister) UpdateLoginRequest(ctx context.Context, r *login.Flow) error {
func (p *Persister) UpdateLoginFlow(ctx context.Context, r *login.Flow) error {
return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {

rr, err := p.GetLoginRequest(ctx, r.ID)
rr, err := p.GetLoginFlow(ctx, r.ID)
if err != nil {
return err
}
Expand All @@ -44,7 +44,7 @@ func (p *Persister) UpdateLoginRequest(ctx context.Context, r *login.Flow) error
})
}

func (p *Persister) GetLoginRequest(ctx context.Context, id uuid.UUID) (*login.Flow, error) {
func (p *Persister) GetLoginFlow(ctx context.Context, id uuid.UUID) (*login.Flow, error) {
conn := p.GetConnection(ctx)
var r login.Flow
if err := conn.Eager().Find(&r, id); err != nil {
Expand All @@ -58,10 +58,10 @@ func (p *Persister) GetLoginRequest(ctx context.Context, id uuid.UUID) (*login.F
return &r, nil
}

func (p *Persister) MarkRequestForced(ctx context.Context, id uuid.UUID) error {
func (p *Persister) ForceLoginFlow(ctx context.Context, id uuid.UUID) error {
return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {

lr, err := p.GetLoginRequest(ctx, id)
lr, err := p.GetLoginFlow(ctx, id)
if err != nil {
return err
}
Expand All @@ -71,10 +71,10 @@ func (p *Persister) MarkRequestForced(ctx context.Context, id uuid.UUID) error {
})
}

func (p *Persister) UpdateLoginRequestMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.FlowMethod) error {
func (p *Persister) UpdateLoginFlowMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.FlowMethod) error {
return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {

rr, err := p.GetLoginRequest(ctx, id)
rr, err := p.GetLoginFlow(ctx, id)
if err != nil {
return err
}
Expand Down
20 changes: 10 additions & 10 deletions persistence/sql/persister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,23 @@ func TestPersister(t *testing.T) {
pop.SetLogger(pl(t))
identity.TestPool(p.(identity.PrivilegedPool))(t)
})
t.Run("contract=registration.TestRequestPersister", func(t *testing.T) {
t.Run("contract=registration.TestFlowPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
registration.TestRequestPersister(p)(t)
})
t.Run("contract=errorx.TestPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
errorx.TestPersister(p)(t)
})
t.Run("contract=login.TestRequestPersister", func(t *testing.T) {
t.Run("contract=login.TestFlowPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
login.TestRequestPersister(p)(t)
login.TestFlowPersister(p)(t)
})
t.Run("contract=settings.TestRequestPersister", func(t *testing.T) {
t.Run("contract=settings.TestFlowPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
settings.TestRequestPersister(p)(t)
})
t.Run("contract=session.TestRequestPersister", func(t *testing.T) {
t.Run("contract=session.TestFlowPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
session.TestPersister(p)(t)
})
Expand All @@ -166,7 +166,7 @@ func TestPersister(t *testing.T) {
pop.SetLogger(pl(t))
verification.TestPersister(p)(t)
})
t.Run("contract=recovery.TestRequestPersister", func(t *testing.T) {
t.Run("contract=recovery.TestFlowPersister", func(t *testing.T) {
pop.SetLogger(pl(t))
recovery.TestRequestPersister(p)(t)
})
Expand Down Expand Up @@ -224,14 +224,14 @@ func TestPersister_Transaction(t *testing.T) {
}
err := c.Transaction(func(tx *pop.Connection) error {
ctx := sql.WithTransaction(context.Background(), tx)
require.NoError(t, p.CreateLoginRequest(ctx, lr), "%+v", lr)
require.NoError(t, p.UpdateLoginRequestMethod(ctx, lr.ID, identity.CredentialsTypePassword, &login.FlowMethod{}))
require.NoError(t, getErr(p.GetLoginRequest(ctx, lr.ID)), "%+v", lr)
require.NoError(t, p.CreateLoginFlow(ctx, lr), "%+v", lr)
require.NoError(t, p.UpdateLoginFlowMethod(ctx, lr.ID, identity.CredentialsTypePassword, &login.FlowMethod{}))
require.NoError(t, getErr(p.GetLoginFlow(ctx, lr.ID)), "%+v", lr)
return errors.Errorf(errMessage)
})
require.Error(t, err)
assert.Contains(t, err.Error(), errMessage)
_, err = p.GetLoginRequest(context.Background(), lr.ID)
_, err = p.GetLoginFlow(context.Background(), lr.ID)
require.Error(t, err)
assert.Equal(t, sqlcon.ErrNoRows.Error(), err.Error())
})
Expand Down
6 changes: 6 additions & 0 deletions selfservice/errorx/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@ import (

// swagger:model errorContainer
type ErrorContainer struct {
// ID of the error container.
//
// required: true
ID uuid.UUID `db:"id" json:"id"`

CSRFToken string `db:"csrf_token" json:"-"`

// Errors in the container
//
// required: true
Errors json.RawMessage `json:"errors" db:"errors"`

// CreatedAt is a helper struct field for gobuffalo.pop.
Expand Down
3 changes: 3 additions & 0 deletions selfservice/errorx/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ type errorContainerResponse struct {
// nolint:deadcode,unused
// swagger:parameters getSelfServiceError
type errorContainerParameters struct {
// Error is the container's ID
//
// in: query
// required: true
Error string `json:"error"`
}

Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (f Flow) TableName() string {

func (f *Flow) Valid() error {
if f.ExpiresAt.Before(time.Now()) {
return errors.WithStack(newRequestExpiredError(time.Since(f.ExpiresAt)))
return errors.WithStack(NewRequestExpiredError(time.Since(f.ExpiresAt)))
}
return nil
}
Expand Down
8 changes: 4 additions & 4 deletions selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (h *Handler) NewLoginFlow(w http.ResponseWriter, r *http.Request, flow flow
return nil, err
}

if err := h.d.LoginRequestPersister().CreateLoginRequest(r.Context(), a); err != nil {
if err := h.d.LoginFlowPersister().CreateLoginFlow(r.Context(), a); err != nil {
return nil, err
}
return a, nil
Expand Down Expand Up @@ -133,7 +133,7 @@ func (h *Handler) initAPIFlow(w http.ResponseWriter, r *http.Request, _ httprout
}

if a.Forced {
if err := h.d.LoginRequestPersister().MarkRequestForced(r.Context(), a.ID); err != nil {
if err := h.d.LoginFlowPersister().ForceLoginFlow(r.Context(), a.ID); err != nil {
h.d.Writer().WriteError(w, r, err)
return
}
Expand Down Expand Up @@ -184,7 +184,7 @@ func (h *Handler) initBrowserFlow(w http.ResponseWriter, r *http.Request, ps htt
}

if a.Forced {
if err := h.d.LoginRequestPersister().MarkRequestForced(r.Context(), a.ID); err != nil {
if err := h.d.LoginFlowPersister().ForceLoginFlow(r.Context(), a.ID); err != nil {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
return
}
Expand Down Expand Up @@ -251,7 +251,7 @@ func (h *Handler) adminFetchLoginRequest(w http.ResponseWriter, r *http.Request,
}

func (h *Handler) fetchLoginRequest(w http.ResponseWriter, r *http.Request) error {
ar, err := h.d.LoginRequestPersister().GetLoginRequest(r.Context(), x.ParseUUID(r.URL.Query().Get("id")))
ar, err := h.d.LoginFlowPersister().GetLoginFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("id")))
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit ad2b3db

Please sign in to comment.