diff --git a/.schema/api.swagger.json b/.schema/api.swagger.json index a9e82f5a0a82..5d239ba79ac2 100755 --- a/.schema/api.swagger.json +++ b/.schema/api.swagger.json @@ -1125,8 +1125,10 @@ "parameters": [ { "type": "string", + "description": "Error is the container's ID", "name": "error", - "in": "query" + "in": "query", + "required": true } ], "responses": { @@ -1558,8 +1560,13 @@ }, "errorContainer": { "type": "object", + "required": [ + "id", + "errors" + ], "properties": { "errors": { + "description": "Errors in the container", "type": "object" }, "id": { diff --git a/driver/registry_default.go b/driver/registry_default.go index 3abd13c03b0f..557d2bfed5e9 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -490,7 +490,7 @@ func (m *RegistryDefault) RecoveryRequestPersister() recovery.RequestPersister { return m.persister } -func (m *RegistryDefault) LoginRequestPersister() login.RequestPersister { +func (m *RegistryDefault) LoginFlowPersister() login.FlowPersister { return m.persister } diff --git a/driver/registry_default_login.go b/driver/registry_default_login.go index 9f28de5f71fd..ef9a4610d857 100644 --- a/driver/registry_default_login.go +++ b/driver/registry_default_login.go @@ -40,7 +40,7 @@ func (m *RegistryDefault) LoginHandler() *login.Handler { func (m *RegistryDefault) LoginRequestErrorHandler() *login.ErrorHandler { if m.selfserviceLoginRequestErrorHandler == nil { - m.selfserviceLoginRequestErrorHandler = login.NewErrorHandler(m, m.c) + m.selfserviceLoginRequestErrorHandler = login.NewFlowErrorHandler(m, m.c) } return m.selfserviceLoginRequestErrorHandler diff --git a/go.mod b/go.mod index cf581ed3ec73..ce4eeee7fbbc 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 github.com/armon/go-metrics v0.3.3 // indirect + github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 github.com/bxcodec/faker/v3 v3.3.1 github.com/cenkalti/backoff v2.2.1+incompatible github.com/coreos/go-oidc v2.2.1+incompatible diff --git a/internal/httpclient/client/common/get_self_service_error_parameters.go b/internal/httpclient/client/common/get_self_service_error_parameters.go index 70758db564b0..4c451b2002f1 100644 --- a/internal/httpclient/client/common/get_self_service_error_parameters.go +++ b/internal/httpclient/client/common/get_self_service_error_parameters.go @@ -60,8 +60,11 @@ for the get self service error operation typically these are written to a http.R */ type GetSelfServiceErrorParams struct { - /*Error*/ - Error *string + /*Error + Error is the container's ID + + */ + Error string timeout time.Duration Context context.Context @@ -102,13 +105,13 @@ func (o *GetSelfServiceErrorParams) SetHTTPClient(client *http.Client) { } // WithError adds the error to the get self service error params -func (o *GetSelfServiceErrorParams) WithError(error *string) *GetSelfServiceErrorParams { +func (o *GetSelfServiceErrorParams) WithError(error string) *GetSelfServiceErrorParams { o.SetError(error) return o } // SetError adds the error to the get self service error params -func (o *GetSelfServiceErrorParams) SetError(error *string) { +func (o *GetSelfServiceErrorParams) SetError(error string) { o.Error = error } @@ -120,20 +123,13 @@ func (o *GetSelfServiceErrorParams) WriteToRequest(r runtime.ClientRequest, reg } var res []error - if o.Error != nil { - - // query param error - var qrError string - if o.Error != nil { - qrError = *o.Error + // query param error + qrError := o.Error + qError := qrError + if qError != "" { + if err := r.SetQueryParam("error", qError); err != nil { + return err } - qError := qrError - if qError != "" { - if err := r.SetQueryParam("error", qError); err != nil { - return err - } - } - } if len(res) > 0 { diff --git a/internal/httpclient/models/error_container.go b/internal/httpclient/models/error_container.go index 8978b5bb67f0..87aef47f881b 100644 --- a/internal/httpclient/models/error_container.go +++ b/internal/httpclient/models/error_container.go @@ -9,6 +9,7 @@ import ( "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" "github.com/go-openapi/swag" + "github.com/go-openapi/validate" ) // ErrorContainer error container @@ -16,18 +17,24 @@ import ( // swagger:model errorContainer type ErrorContainer struct { - // errors - Errors interface{} `json:"errors,omitempty"` + // Errors in the container + // Required: true + Errors interface{} `json:"errors"` // id + // Required: true // Format: uuid4 - ID UUID `json:"id,omitempty"` + ID UUID `json:"id"` } // Validate validates this error container func (m *ErrorContainer) Validate(formats strfmt.Registry) error { var res []error + if err := m.validateErrors(formats); err != nil { + res = append(res, err) + } + if err := m.validateID(formats); err != nil { res = append(res, err) } @@ -38,12 +45,17 @@ func (m *ErrorContainer) Validate(formats strfmt.Registry) error { return nil } -func (m *ErrorContainer) validateID(formats strfmt.Registry) error { +func (m *ErrorContainer) validateErrors(formats strfmt.Registry) error { - if swag.IsZero(m.ID) { // not required - return nil + if err := validate.Required("errors", "body", m.Errors); err != nil { + return err } + return nil +} + +func (m *ErrorContainer) validateID(formats strfmt.Registry) error { + if err := m.ID.Validate(formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("id") diff --git a/internal/testhelpers/http.go b/internal/testhelpers/http.go index 75930db4b06f..62ae474d3873 100644 --- a/internal/testhelpers/http.go +++ b/internal/testhelpers/http.go @@ -11,6 +11,15 @@ import ( "github.com/stretchr/testify/require" ) +func NewHTTPGetRequestJSON(t *testing.T, url string) *http.Request { + req, err := http.NewRequest("GET", url, nil) + require.NoError(t, err) + + // req.Header.Set("Content-Type", "application/json;charset=utf-8") + req.Header.Set("Accept", "application/json") + return req +} + func HTTPRequestJSON(t *testing.T, client *http.Client, method string, url string, in interface{}) ([]byte, *http.Response) { var body bytes.Buffer require.NoError(t, json.NewEncoder(&body).Encode(in)) diff --git a/internal/testhelpers/login.go b/internal/testhelpers/login.go index a64bb1c964e6..85470466eec7 100644 --- a/internal/testhelpers/login.go +++ b/internal/testhelpers/login.go @@ -16,7 +16,7 @@ import ( func NewLoginUIRequestEchoServer(t *testing.T, reg driver.Registry) *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - e, err := reg.LoginRequestPersister().GetLoginRequest(r.Context(), x.ParseUUID(r.URL.Query().Get("flow"))) + e, err := reg.LoginFlowPersister().GetLoginFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("flow"))) require.NoError(t, err) reg.Writer().Write(w, r, e) })) diff --git a/persistence/reference.go b/persistence/reference.go index 1c46957996f5..601460ed05c4 100644 --- a/persistence/reference.go +++ b/persistence/reference.go @@ -27,7 +27,7 @@ type Persister interface { continuity.Persister identity.PrivilegedPool registration.RequestPersister - login.RequestPersister + login.FlowPersister settings.RequestPersister courier.Persister session.Persister diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 1b0ea452e1a4..cd7d681bb392 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -129,7 +129,7 @@ func TestMigrations(t *testing.T) { require.NoError(t, c.Select("id").All(&ids)) for _, id := range ids { - actual, err := d.Registry().LoginRequestPersister().GetLoginRequest(context.Background(), id.ID) + actual, err := d.Registry().LoginFlowPersister().GetLoginFlow(context.Background(), id.ID) require.NoError(t, err) compareWithFixture(t, actual, "login_request", id.ID.String()) } diff --git a/persistence/sql/persister_login.go b/persistence/sql/persister_login.go index 4bf8cc126b67..a7849f7699db 100644 --- a/persistence/sql/persister_login.go +++ b/persistence/sql/persister_login.go @@ -13,16 +13,16 @@ import ( "github.com/ory/kratos/selfservice/flow/login" ) -var _ login.RequestPersister = new(Persister) +var _ login.FlowPersister = new(Persister) -func (p *Persister) CreateLoginRequest(ctx context.Context, r *login.Flow) error { +func (p *Persister) CreateLoginFlow(ctx context.Context, r *login.Flow) error { return p.GetConnection(ctx).Eager().Create(r) } -func (p *Persister) UpdateLoginRequest(ctx context.Context, r *login.Flow) error { +func (p *Persister) UpdateLoginFlow(ctx context.Context, r *login.Flow) error { return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { - rr, err := p.GetLoginRequest(ctx, r.ID) + rr, err := p.GetLoginFlow(ctx, r.ID) if err != nil { return err } @@ -44,7 +44,7 @@ func (p *Persister) UpdateLoginRequest(ctx context.Context, r *login.Flow) error }) } -func (p *Persister) GetLoginRequest(ctx context.Context, id uuid.UUID) (*login.Flow, error) { +func (p *Persister) GetLoginFlow(ctx context.Context, id uuid.UUID) (*login.Flow, error) { conn := p.GetConnection(ctx) var r login.Flow if err := conn.Eager().Find(&r, id); err != nil { @@ -58,10 +58,10 @@ func (p *Persister) GetLoginRequest(ctx context.Context, id uuid.UUID) (*login.F return &r, nil } -func (p *Persister) MarkRequestForced(ctx context.Context, id uuid.UUID) error { +func (p *Persister) ForceLoginFlow(ctx context.Context, id uuid.UUID) error { return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { - lr, err := p.GetLoginRequest(ctx, id) + lr, err := p.GetLoginFlow(ctx, id) if err != nil { return err } @@ -71,10 +71,10 @@ func (p *Persister) MarkRequestForced(ctx context.Context, id uuid.UUID) error { }) } -func (p *Persister) UpdateLoginRequestMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.FlowMethod) error { +func (p *Persister) UpdateLoginFlowMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.FlowMethod) error { return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { - rr, err := p.GetLoginRequest(ctx, id) + rr, err := p.GetLoginFlow(ctx, id) if err != nil { return err } diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 589f0a1902cc..bf112d6e2e33 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -138,7 +138,7 @@ func TestPersister(t *testing.T) { pop.SetLogger(pl(t)) identity.TestPool(p.(identity.PrivilegedPool))(t) }) - t.Run("contract=registration.TestRequestPersister", func(t *testing.T) { + t.Run("contract=registration.TestFlowPersister", func(t *testing.T) { pop.SetLogger(pl(t)) registration.TestRequestPersister(p)(t) }) @@ -146,15 +146,15 @@ func TestPersister(t *testing.T) { pop.SetLogger(pl(t)) errorx.TestPersister(p)(t) }) - t.Run("contract=login.TestRequestPersister", func(t *testing.T) { + t.Run("contract=login.TestFlowPersister", func(t *testing.T) { pop.SetLogger(pl(t)) - login.TestRequestPersister(p)(t) + login.TestFlowPersister(p)(t) }) - t.Run("contract=settings.TestRequestPersister", func(t *testing.T) { + t.Run("contract=settings.TestFlowPersister", func(t *testing.T) { pop.SetLogger(pl(t)) settings.TestRequestPersister(p)(t) }) - t.Run("contract=session.TestRequestPersister", func(t *testing.T) { + t.Run("contract=session.TestFlowPersister", func(t *testing.T) { pop.SetLogger(pl(t)) session.TestPersister(p)(t) }) @@ -166,7 +166,7 @@ func TestPersister(t *testing.T) { pop.SetLogger(pl(t)) verification.TestPersister(p)(t) }) - t.Run("contract=recovery.TestRequestPersister", func(t *testing.T) { + t.Run("contract=recovery.TestFlowPersister", func(t *testing.T) { pop.SetLogger(pl(t)) recovery.TestRequestPersister(p)(t) }) @@ -224,14 +224,14 @@ func TestPersister_Transaction(t *testing.T) { } err := c.Transaction(func(tx *pop.Connection) error { ctx := sql.WithTransaction(context.Background(), tx) - require.NoError(t, p.CreateLoginRequest(ctx, lr), "%+v", lr) - require.NoError(t, p.UpdateLoginRequestMethod(ctx, lr.ID, identity.CredentialsTypePassword, &login.FlowMethod{})) - require.NoError(t, getErr(p.GetLoginRequest(ctx, lr.ID)), "%+v", lr) + require.NoError(t, p.CreateLoginFlow(ctx, lr), "%+v", lr) + require.NoError(t, p.UpdateLoginFlowMethod(ctx, lr.ID, identity.CredentialsTypePassword, &login.FlowMethod{})) + require.NoError(t, getErr(p.GetLoginFlow(ctx, lr.ID)), "%+v", lr) return errors.Errorf(errMessage) }) require.Error(t, err) assert.Contains(t, err.Error(), errMessage) - _, err = p.GetLoginRequest(context.Background(), lr.ID) + _, err = p.GetLoginFlow(context.Background(), lr.ID) require.Error(t, err) assert.Equal(t, sqlcon.ErrNoRows.Error(), err.Error()) }) diff --git a/selfservice/errorx/error.go b/selfservice/errorx/error.go index 8d93409817b4..f8f4cef91925 100644 --- a/selfservice/errorx/error.go +++ b/selfservice/errorx/error.go @@ -10,10 +10,16 @@ import ( // swagger:model errorContainer type ErrorContainer struct { + // ID of the error container. + // + // required: true ID uuid.UUID `db:"id" json:"id"` CSRFToken string `db:"csrf_token" json:"-"` + // Errors in the container + // + // required: true Errors json.RawMessage `json:"errors" db:"errors"` // CreatedAt is a helper struct field for gobuffalo.pop. diff --git a/selfservice/errorx/handler.go b/selfservice/errorx/handler.go index 69990eb3105d..15046998ebdd 100644 --- a/selfservice/errorx/handler.go +++ b/selfservice/errorx/handler.go @@ -61,7 +61,10 @@ type errorContainerResponse struct { // nolint:deadcode,unused // swagger:parameters getSelfServiceError type errorContainerParameters struct { + // Error is the container's ID + // // in: query + // required: true Error string `json:"error"` } diff --git a/selfservice/flow/login/flow.go b/selfservice/flow/login/flow.go index ab148a926563..46cb7db4bcd0 100644 --- a/selfservice/flow/login/flow.go +++ b/selfservice/flow/login/flow.go @@ -143,7 +143,7 @@ func (f Flow) TableName() string { func (f *Flow) Valid() error { if f.ExpiresAt.Before(time.Now()) { - return errors.WithStack(newRequestExpiredError(time.Since(f.ExpiresAt))) + return errors.WithStack(NewRequestExpiredError(time.Since(f.ExpiresAt))) } return nil } diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index d1e43b572718..c1f69be3037c 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -71,7 +71,7 @@ func (h *Handler) NewLoginFlow(w http.ResponseWriter, r *http.Request, flow flow return nil, err } - if err := h.d.LoginRequestPersister().CreateLoginRequest(r.Context(), a); err != nil { + if err := h.d.LoginFlowPersister().CreateLoginFlow(r.Context(), a); err != nil { return nil, err } return a, nil @@ -133,7 +133,7 @@ func (h *Handler) initAPIFlow(w http.ResponseWriter, r *http.Request, _ httprout } if a.Forced { - if err := h.d.LoginRequestPersister().MarkRequestForced(r.Context(), a.ID); err != nil { + if err := h.d.LoginFlowPersister().ForceLoginFlow(r.Context(), a.ID); err != nil { h.d.Writer().WriteError(w, r, err) return } @@ -184,7 +184,7 @@ func (h *Handler) initBrowserFlow(w http.ResponseWriter, r *http.Request, ps htt } if a.Forced { - if err := h.d.LoginRequestPersister().MarkRequestForced(r.Context(), a.ID); err != nil { + if err := h.d.LoginFlowPersister().ForceLoginFlow(r.Context(), a.ID); err != nil { h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) return } @@ -251,7 +251,7 @@ func (h *Handler) adminFetchLoginRequest(w http.ResponseWriter, r *http.Request, } func (h *Handler) fetchLoginRequest(w http.ResponseWriter, r *http.Request) error { - ar, err := h.d.LoginRequestPersister().GetLoginRequest(r.Context(), x.ParseUUID(r.URL.Query().Get("id"))) + ar, err := h.d.LoginFlowPersister().GetLoginFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("id"))) if err != nil { return err } diff --git a/selfservice/flow/login/handler_test.go b/selfservice/flow/login/handler_test.go index cdfb8b35d380..8d3ed76ec355 100644 --- a/selfservice/flow/login/handler_test.go +++ b/selfservice/flow/login/handler_test.go @@ -55,7 +55,7 @@ func TestHandlerSettingForced(t *testing.T) { for _, s := range reg.LoginStrategies() { require.NoError(t, s.PopulateLoginMethod(req, loginFlow)) } - require.NoError(t, reg.LoginRequestPersister().CreateLoginRequest(context.TODO(), loginFlow), "%+v", loginFlow) + require.NoError(t, reg.LoginFlowPersister().CreateLoginFlow(context.TODO(), loginFlow), "%+v", loginFlow) q := url.Values{"id": {flowID.String()}} for key := range extQuery { @@ -164,7 +164,7 @@ func TestLoginHandler(t *testing.T) { t.Run("case=expired", func(t *testing.T) { lr := newExpiredRequest() - require.NoError(t, reg.LoginRequestPersister().CreateLoginRequest(context.Background(), lr)) + require.NoError(t, reg.LoginFlowPersister().CreateLoginFlow(context.Background(), lr)) res, body := x.EasyGet(t, admin.Client(), endpoint.URL+login.RouteGetFlow+"?id="+lr.ID.String()) assertExpiredPayload(t, res, body) }) diff --git a/selfservice/flow/login/persistence.go b/selfservice/flow/login/persistence.go index 03c9521c2693..ca1119406bed 100644 --- a/selfservice/flow/login/persistence.go +++ b/selfservice/flow/login/persistence.go @@ -16,19 +16,19 @@ import ( ) type ( - RequestPersister interface { - UpdateLoginRequest(context.Context, *Flow) error - CreateLoginRequest(context.Context, *Flow) error - GetLoginRequest(context.Context, uuid.UUID) (*Flow, error) - UpdateLoginRequestMethod(context.Context, uuid.UUID, identity.CredentialsType, *FlowMethod) error - MarkRequestForced(ctx context.Context, id uuid.UUID) error + FlowPersister interface { + UpdateLoginFlow(context.Context, *Flow) error + CreateLoginFlow(context.Context, *Flow) error + GetLoginFlow(context.Context, uuid.UUID) (*Flow, error) + UpdateLoginFlowMethod(context.Context, uuid.UUID, identity.CredentialsType, *FlowMethod) error + ForceLoginFlow(ctx context.Context, id uuid.UUID) error } RequestPersistenceProvider interface { - LoginRequestPersister() RequestPersister + LoginFlowPersister() FlowPersister } ) -func TestRequestPersister(p RequestPersister) func(t *testing.T) { +func TestFlowPersister(p FlowPersister) func(t *testing.T) { var clearids = func(r *Flow) { r.ID = uuid.UUID{} for k := range r.Methods { @@ -38,7 +38,7 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { return func(t *testing.T) { t.Run("case=should error when the login flow does not exist", func(t *testing.T) { - _, err := p.GetLoginRequest(context.Background(), x.NewUUID()) + _, err := p.GetLoginFlow(context.Background(), x.NewUUID()) require.Error(t, err) }) @@ -56,13 +56,13 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { t.Run("case=should create with set ids", func(t *testing.T) { var r Flow require.NoError(t, faker.FakeData(&r)) - require.NoError(t, p.CreateLoginRequest(context.Background(), &r)) + require.NoError(t, p.CreateLoginFlow(context.Background(), &r)) }) t.Run("case=should create a new login flow and properly set IDs", func(t *testing.T) { r := newRequest(t) methods := len(r.Methods) - err := p.CreateLoginRequest(context.Background(), r) + err := p.CreateLoginFlow(context.Background(), r) require.NoError(t, err, "%#v", err) assert.Nil(t, r.MethodsRaw) @@ -75,10 +75,10 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { t.Run("case=should create and fetch a login flow", func(t *testing.T) { expected := newRequest(t) - err := p.CreateLoginRequest(context.Background(), expected) + err := p.CreateLoginFlow(context.Background(), expected) require.NoError(t, err) - actual, err := p.GetLoginRequest(context.Background(), expected.ID) + actual, err := p.GetLoginFlow(context.Background(), expected.ID) require.NoError(t, err) assert.Empty(t, actual.MethodsRaw) @@ -104,10 +104,10 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { Config: &FlowMethodConfig{FlowMethodConfigurator: form.NewHTMLForm(string(identity.CredentialsTypePassword))}, }, } - err := p.CreateLoginRequest(context.Background(), expected) + err := p.CreateLoginFlow(context.Background(), expected) require.NoError(t, err) - actual, err := p.GetLoginRequest(context.Background(), expected.ID) + actual, err := p.GetLoginFlow(context.Background(), expected.ID) require.NoError(t, err) assert.Equal(t, flow.TypeAPI, actual.Type) @@ -118,9 +118,9 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { actual.Type = flow.TypeBrowser actual.Forced = true - require.NoError(t, p.UpdateLoginRequest(context.Background(), actual)) + require.NoError(t, p.UpdateLoginFlow(context.Background(), actual)) - actual, err = p.GetLoginRequest(context.Background(), actual.ID) + actual, err = p.GetLoginFlow(context.Background(), actual.ID) require.NoError(t, err) assert.Equal(t, flow.TypeBrowser, actual.Type) assert.True(t, actual.Forced) @@ -133,10 +133,10 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { t.Run("case=should properly update a flow", func(t *testing.T) { expected := newRequest(t) expected.Type = flow.TypeAPI - err := p.CreateLoginRequest(context.Background(), expected) + err := p.CreateLoginFlow(context.Background(), expected) require.NoError(t, err) - actual, err := p.GetLoginRequest(context.Background(), expected.ID) + actual, err := p.GetLoginFlow(context.Background(), expected.ID) require.NoError(t, err) assert.Equal(t, flow.TypeAPI, actual.Type) }) @@ -144,24 +144,24 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { t.Run("case=should update a login flow", func(t *testing.T) { expected := newRequest(t) delete(expected.Methods, identity.CredentialsTypeOIDC) - err := p.CreateLoginRequest(context.Background(), expected) + err := p.CreateLoginFlow(context.Background(), expected) require.NoError(t, err) - actual, err := p.GetLoginRequest(context.Background(), expected.ID) + actual, err := p.GetLoginFlow(context.Background(), expected.ID) require.NoError(t, err) assert.Len(t, actual.Methods, 1) - require.NoError(t, p.UpdateLoginRequestMethod(context.Background(), expected.ID, identity.CredentialsTypeOIDC, &FlowMethod{ + require.NoError(t, p.UpdateLoginFlowMethod(context.Background(), expected.ID, identity.CredentialsTypeOIDC, &FlowMethod{ Method: identity.CredentialsTypeOIDC, Config: &FlowMethodConfig{FlowMethodConfigurator: form.NewHTMLForm(string(identity.CredentialsTypeOIDC))}, })) - require.NoError(t, p.UpdateLoginRequestMethod(context.Background(), expected.ID, identity.CredentialsTypePassword, &FlowMethod{ + require.NoError(t, p.UpdateLoginFlowMethod(context.Background(), expected.ID, identity.CredentialsTypePassword, &FlowMethod{ Method: identity.CredentialsTypePassword, Config: &FlowMethodConfig{FlowMethodConfigurator: form.NewHTMLForm(string(identity.CredentialsTypePassword))}, })) - actual, err = p.GetLoginRequest(context.Background(), expected.ID) + actual, err = p.GetLoginFlow(context.Background(), expected.ID) require.NoError(t, err) require.Len(t, actual.Methods, 2) assert.EqualValues(t, identity.CredentialsTypePassword, actual.Active) diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 40e92b33c87a..8d9ccf30e1d2 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -219,7 +219,7 @@ func (s *Strategy) validateRequest(ctx context.Context, r *http.Request, rid uui return ar, nil } - if ar, err := s.d.LoginRequestPersister().GetLoginRequest(ctx, rid); err == nil { + if ar, err := s.d.LoginFlowPersister().GetLoginFlow(ctx, rid); err == nil { if err := ar.Valid(); err != nil { return ar, err } @@ -413,8 +413,8 @@ func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, rid uuid. return } - if lr, rerr := s.d.LoginRequestPersister().GetLoginRequest(r.Context(), rid); rerr == nil { - s.d.LoginRequestErrorHandler().HandleLoginError(w, r, s.ID(), lr, err) + if lr, rerr := s.d.LoginFlowPersister().GetLoginFlow(r.Context(), rid); rerr == nil { + s.d.LoginRequestErrorHandler().WriteFlowError(w, r, s.ID(), lr, err) return } else if sr, rerr := s.d.SettingsRequestPersister().GetSettingsRequest(r.Context(), rid); rerr == nil { s.d.SettingsRequestErrorHandler().HandleSettingsError(w, r, sr, err, s.SettingsStrategyID()) diff --git a/selfservice/strategy/oidc/strategy_helper_test.go b/selfservice/strategy/oidc/strategy_helper_test.go index 5bbb7a9024e9..3eb00c946a4f 100644 --- a/selfservice/strategy/oidc/strategy_helper_test.go +++ b/selfservice/strategy/oidc/strategy_helper_test.go @@ -168,7 +168,7 @@ func newUI(t *testing.T, reg driver.Registry) *httptest.Server { var e interface{} var err error if r.URL.Path == "/login" { - e, err = reg.LoginRequestPersister().GetLoginRequest(r.Context(), x.ParseUUID(r.URL.Query().Get("request"))) + e, err = reg.LoginFlowPersister().GetLoginFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("request"))) } else if r.URL.Path == "/registration" { e, err = reg.RegistrationRequestPersister().GetRegistrationRequest(r.Context(), x.ParseUUID(r.URL.Query().Get("request"))) } else if r.URL.Path == "/settings" { diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 716888961435..3bf78242242f 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -90,7 +90,7 @@ func TestStrategy(t *testing.T) { require.NotNil(t, method) config = method.Config.RequestMethodConfigurator.(*form.HTMLForm) require.NotNil(t, config) - } else if req, err := reg.LoginRequestPersister().GetLoginRequest(context.Background(), request); err == nil { + } else if req, err := reg.LoginFlowPersister().GetLoginFlow(context.Background(), request); err == nil { require.EqualValues(t, req.ID, request) method := req.Methods[identity.CredentialsTypeOIDC] require.NotNil(t, method) @@ -176,10 +176,10 @@ func TestStrategy(t *testing.T) { require.NoError(t, err) req.RequestURL = redirectTo req.ExpiresAt = time.Now().Add(exp) - require.NoError(t, reg.LoginRequestPersister().UpdateLoginRequest(context.Background(), req)) + require.NoError(t, reg.LoginFlowPersister().UpdateLoginFlow(context.Background(), req)) // sanity check - got, err := reg.LoginRequestPersister().GetLoginRequest(context.Background(), req.ID) + got, err := reg.LoginFlowPersister().GetLoginFlow(context.Background(), req.ID) require.NoError(t, err) require.Len(t, got.Methods, len(req.Methods)) @@ -399,7 +399,7 @@ func TestStrategy(t *testing.T) { res1, body1 := mrj(t, "valid", afv(t, r1.ID, "valid"), fv, jar) ai(t, res1, body1) r2 := nlr(t, returnTS.URL, time.Minute) - require.NoError(t, reg.LoginRequestPersister().MarkRequestForced(context.Background(), r2.ID)) + require.NoError(t, reg.LoginFlowPersister().ForceLoginFlow(context.Background(), r2.ID)) res2, body2 := mrj(t, "valid", afv(t, r2.ID, "valid"), fv, jar) ai(t, res2, body2) assert.NotEqual(t, gjson.GetBytes(body1, "sid"), gjson.GetBytes(body2, "sid")) diff --git a/selfservice/strategy/password/login.go b/selfservice/strategy/password/login.go index 399f53860ad6..ca11295524f4 100644 --- a/selfservice/strategy/password/login.go +++ b/selfservice/strategy/password/login.go @@ -39,7 +39,7 @@ func (s *Strategy) handleLoginError(w http.ResponseWriter, r *http.Request, rr * } } - s.d.LoginRequestErrorHandler().HandleLoginError(w, r, identity.CredentialsTypePassword, rr, err) + s.d.LoginRequestErrorHandler().WriteFlowError(w, r, identity.CredentialsTypePassword, rr, err) } // nolint:deadcode,unused @@ -102,7 +102,7 @@ func (s *Strategy) handleLogin(w http.ResponseWriter, r *http.Request, _ httprou return } - ar, err := s.d.LoginRequestPersister().GetLoginRequest(r.Context(), rid) + ar, err := s.d.LoginFlowPersister().GetLoginFlow(r.Context(), rid) if err != nil { s.handleLoginError(w, r, nil, nil, err) return diff --git a/selfservice/strategy/password/login_test.go b/selfservice/strategy/password/login_test.go index 6ca74248101b..801a8100f13a 100644 --- a/selfservice/strategy/password/login_test.go +++ b/selfservice/strategy/password/login_test.go @@ -126,7 +126,7 @@ func TestLoginNew(t *testing.T) { fakeRequest := func(t *testing.T, lr *login.Flow, isAPI bool, payload string, forceRequestID *string, jar *cookiejar.Jar) (*http.Response, []byte) { lr.RequestURL = ts.URL - require.NoError(t, reg.LoginRequestPersister().CreateLoginRequest(context.TODO(), lr)) + require.NoError(t, reg.LoginFlowPersister().CreateLoginFlow(context.TODO(), lr)) requestID := lr.ID.String() if forceRequestID != nil { diff --git a/x/err.go b/x/err.go index 97d660eb4b12..46b503897824 100644 --- a/x/err.go +++ b/x/err.go @@ -1,6 +1,7 @@ package x import ( + "errors" "net/http" "github.com/ory/herodot" @@ -12,3 +13,15 @@ var PseudoPanic = herodot.DefaultError{ ReasonField: "The code ended up at a place where it should not have. Please report this as an issue at https://github.com/ory/kratos", CodeField: http.StatusConflict, } + +type StatusCodeCarrier interface { + StatusCode() int +} + +func RecoverStatusCode(err error, fallback int) int { + var sc StatusCodeCarrier + if errors.As(err, &sc) { + return (sc).StatusCode() + } + return fallback +} diff --git a/x/isjsonrequest.go b/x/isjsonrequest.go index d78b1a891ed5..cf7bc22fdc6b 100644 --- a/x/isjsonrequest.go +++ b/x/isjsonrequest.go @@ -6,9 +6,13 @@ import ( "github.com/golang/gddo/httputil" ) +var offers = []string{"text/html", "text/*", "*/*", "application/json"} +var defaultOffer = "text/html" + func IsJSONRequest(r *http.Request) bool { - return httputil.NegotiateContentType(r, - []string{"application/json", "text/html", "text/*", "*/*"}, - "text/*", - ) == "application/json" + return httputil.NegotiateContentType(r, offers, defaultOffer) == "application/json" +} + +func IsBrowserRequest(r *http.Request) bool { + return httputil.NegotiateContentType(r, offers, defaultOffer) == "text/html" } diff --git a/x/isjsonrequest_test.go b/x/isjsonrequest_test.go new file mode 100644 index 000000000000..af93294bafad --- /dev/null +++ b/x/isjsonrequest_test.go @@ -0,0 +1,40 @@ +package x + +import ( + "fmt" + "net/http" + "testing" + + "github.com/golang/gddo/httputil" + "github.com/stretchr/testify/assert" +) + +func TestIsBrowserOrAPIRequest(t *testing.T) { + for k, tc := range []struct { + ua string + h string + e bool + }{ + {ua: "firefox-66", h: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", e: true}, + {ua: "safari-chrome", h: "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8", e: true}, + {ua: "ie8", h: "image/jpeg,application/x-ms-application,image/gif,application/xaml+xml,image/pjpeg,application/x-ms-xbap,application/x-shockwave-flash,application/msword,*/*", e: true}, + {ua: "ie8-any", h: "*/*", e: true}, + {ua: "edge", h: "text/html,application/xhtml+xml,image/jxr,*/*", e: true}, + {ua: "opera", h: "text/html,application/xml;q=0.9,application/xhtml+xml,image/png,image/webp,image/jpeg,image/gif,image/x-xbitmap,*/*;q=0.1", e: true}, + {ua: "json-api", h: "application/json", e: false}, + {ua: "no-accept", h: "", e: true}, + } { + t.Run(fmt.Sprintf("case=%d/ua=%s", k, tc.ua), func(t *testing.T) { + r := &http.Request{Header: map[string][]string{"Accept": {tc.h}}} + t.Logf("isBrowser: %s", httputil.NegotiateContentType(r, offers, defaultOffer)) + + t.Logf("isJSON: %s", httputil.NegotiateContentType(r, + []string{"application/json"}, + "text/html", + )) + + assert.Equal(t, tc.e, IsBrowserRequest(r)) + assert.Equal(t, !tc.e, IsJSONRequest(r)) + }) + } +}