Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(SelfService/Strategy/oidc): rework auth session expiry #242

Merged
merged 4 commits into from
Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions selfservice/flow/login/error.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package login

import (
"context"
"fmt"
"github.com/ory/kratos/selfservice/form"
"github.com/ory/x/errorsx"
"net/http"
"net/url"
"time"

"github.com/pkg/errors"

Expand All @@ -18,10 +22,6 @@ import (

var (
ErrHookAbortRequest = errors.New("abort hook")

ErrRequestExpired = herodot.ErrBadRequest.
WithError("login request expired").
WithReasonf(`The login request has expired. Please restart the flow.`)
)

type (
Expand All @@ -31,6 +31,7 @@ type (
x.LoggingProvider

RequestPersistenceProvider
HandlerProvider
}

ErrorHandlerProvider interface{ LoginRequestErrorHandler() *ErrorHandler }
Expand All @@ -39,8 +40,21 @@ type (
d errorHandlerDependencies
c configuration.Provider
}

requestExpiredError struct {
*herodot.DefaultError
}
)

func newRequestExpiredError(since time.Duration) requestExpiredError {
return requestExpiredError{
herodot.ErrBadRequest.
WithError("login request expired").
WithReasonf(`The login request has expired. Please restart the flow.`).
WithReasonf("The login request expired %.2f minutes ago, please try again.", since.Minutes()),
}
}

func NewErrorHandler(d errorHandlerDependencies, c configuration.Provider) *ErrorHandler {
return &ErrorHandler{
d: d,
Expand All @@ -61,6 +75,25 @@ func (s *ErrorHandler) HandleLoginError(
WithField("login_request", rr).
Warn("Encountered login error.")

if _, ok := errorsx.Cause(err).(requestExpiredError); ok {
// create new request because the old one is not valid
if err = s.d.LoginHandler().NewLoginRequest(w, r, func(a *Request) (string, error) {
for name, method := range a.Methods {
method.Config.AddError(&form.Error{Message: "Your session expired, please try again."})
if err := s.d.LoginRequestPersister().UpdateLoginRequest(context.TODO(), a.ID, name, method); err != nil {
return s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err)
}
a.Methods[name] = method
}

return urlx.CopyWithQuery(s.c.LoginURL(), url.Values{"request": {a.ID.String()}}).String(), nil
}); err != nil {
// failed to create a new session and redirect to it, handle that error as a new one
s.HandleLoginError(w, r, ct, rr, err)
}
return
}

if rr == nil {
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
return
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (r Request) TableName() string {

func (r *Request) Valid() error {
if r.ExpiresAt.Before(time.Now()) {
return errors.WithStack(ErrRequestExpired.WithReasonf("The login request expired %.2f minutes ago, please try again.", time.Since(r.ExpiresAt).Minutes()))
return errors.WithStack(newRequestExpiredError(time.Since(r.ExpiresAt)))
}

if r.IssuedAt.After(time.Now()) {
Expand Down
41 changes: 37 additions & 4 deletions selfservice/flow/registration/error.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package registration

import (
"context"
"fmt"
"github.com/ory/kratos/selfservice/form"
"github.com/ory/x/errorsx"
"net/http"
"net/url"
"time"

"github.com/pkg/errors"

Expand All @@ -18,10 +22,6 @@ import (

var (
ErrHookAbortRequest = errors.New("abort hook")

ErrRequestExpired = herodot.ErrBadRequest.
WithError("registration request expired").
WithReasonf(`The registration request has expired. Please restart the flow.`)
)

type (
Expand All @@ -31,6 +31,7 @@ type (
x.LoggingProvider

RequestPersistenceProvider
HandlerProvider
}

ErrorHandlerProvider interface{ RegistrationRequestErrorHandler() *ErrorHandler }
Expand All @@ -39,8 +40,21 @@ type (
d errorHandlerDependencies
c configuration.Provider
}

requestExpiredError struct {
*herodot.DefaultError
}
)

func newRequestExpiredError(since time.Duration) requestExpiredError {
return requestExpiredError{
herodot.ErrBadRequest.
WithError("registration request expired").
WithReasonf(`The registration request has expired. Please restart the flow.`).
WithReasonf("The registration request expired %.2f minutes ago, please try again.", since.Minutes()),
}
}

func NewErrorHandler(d errorHandlerDependencies, c configuration.Provider) *ErrorHandler {
return &ErrorHandler{
d: d,
Expand All @@ -61,6 +75,25 @@ func (s *ErrorHandler) HandleRegistrationError(
WithField("login_request", rr).
Warn("Encountered login error.")

if _, ok := errorsx.Cause(err).(requestExpiredError); ok {
// create new request because the old one is not valid
if err = s.d.RegistrationHandler().NewRegistrationRequest(w, r, func(a *Request) (string, error) {
for name, method := range a.Methods {
method.Config.AddError(&form.Error{Message: "Your session expired, please try again."})
if err := s.d.RegistrationRequestPersister().UpdateRegistrationRequest(context.TODO(), a.ID, name, method); err != nil {
return s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err)
}
a.Methods[name] = method
}

return urlx.CopyWithQuery(s.c.RegisterURL(), url.Values{"request": {a.ID.String()}}).String(), nil
}); err != nil {
// failed to create a new session and redirect to it, handle that error as a new one
s.HandleRegistrationError(w, r, ct, rr, err)
}
return
}

if rr == nil {
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
return
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/registration/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (r *Request) GetID() uuid.UUID {

func (r *Request) Valid() error {
if r.ExpiresAt.Before(time.Now()) {
return errors.WithStack(ErrRequestExpired.WithReasonf("The registration request expired %.2f minutes ago, please try again.", time.Since(r.ExpiresAt).Minutes()))
return errors.WithStack(newRequestExpiredError(time.Since(r.ExpiresAt)))
}
if r.IssuedAt.After(time.Now()) {
return errors.WithStack(herodot.ErrBadRequest.WithReason("The registration request was issued in the future."))
Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (s *Strategy) validateRequest(ctx context.Context, rid uuid.UUID) (request,

if ar, err := s.d.RegistrationRequestPersister().GetRegistrationRequest(ctx, rid); err == nil {
if err := ar.Valid(); err != nil {
return nil, err
return ar, err
}
return ar, nil
}
Expand All @@ -205,7 +205,7 @@ func (s *Strategy) validateRequest(ctx context.Context, rid uuid.UUID) (request,
}

if err := ar.Valid(); err != nil {
return nil, err
return ar, err
}

return ar, nil
Expand Down
8 changes: 6 additions & 2 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,17 @@ func TestStrategy(t *testing.T) {
t.Run("case=should fail because the login request is expired", func(t *testing.T) {
r := nlr(t, returnTS.URL, -time.Minute)
res, body := mr(t, "valid", r.ID, url.Values{})
aue(t, res, body, "login request expired")

assert.NotEqual(t, r.ID, gjson.GetBytes(body, "id"))
aue(t, res, body, "session expired")
})

t.Run("case=should fail because the registration request is expired", func(t *testing.T) {
r := nrr(t, returnTS.URL, -time.Minute)
res, body := mr(t, "valid", r.ID, url.Values{})
aue(t, res, body, "registration request expired")

assert.NotEqual(t, r.ID, gjson.GetBytes(body, "id"))
aue(t, res, body, "session expired")
})

t.Run("case=should fail registration because scope was not provided", func(t *testing.T) {
Expand Down
17 changes: 1 addition & 16 deletions selfservice/strategy/password/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package password

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/url"
Expand Down Expand Up @@ -75,21 +74,7 @@ func (s *Strategy) handleLogin(w http.ResponseWriter, r *http.Request, _ httprou
}

if err := ar.Valid(); err != nil {
// create new request if the old one is not valid
if err = s.d.LoginHandler().NewLoginRequest(w, r, func(a *login.Request) (string, error) {
for name, method := range a.Methods {
method.Config.AddError(&form.Error{Message: "Your session expired, please try again."})
if err := s.d.LoginRequestPersister().UpdateLoginRequest(context.TODO(), a.ID, name, method); err != nil {
return s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err)
}
a.Methods[name] = method
}

return urlx.CopyWithQuery(s.c.LoginURL(), url.Values{"request": {a.ID.String()}}).String(), nil
}); err != nil {
s.handleLoginError(w, r, ar, err)
return
}
s.handleLoginError(w, r, ar, err)
return
}

Expand Down
18 changes: 1 addition & 17 deletions selfservice/strategy/password/registration.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package password

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -108,22 +107,7 @@ func (s *Strategy) handleRegistration(w http.ResponseWriter, r *http.Request, _
}

if err := ar.Valid(); err != nil {
// create new request if the old one is not valid
if err = s.d.RegistrationHandler().NewRegistrationRequest(w, r, func(a *registration.Request) (string, error) {
for name, method := range a.Methods {
method.Config.AddError(&form.Error{Message: "Your session expired, please try again."})
if err := s.d.RegistrationRequestPersister().UpdateRegistrationRequest(context.TODO(), a.ID, name, method); err != nil {
return s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err)
}
a.Methods[name] = method
}

return urlx.CopyWithQuery(s.c.RegisterURL(), url.Values{"request": {a.ID.String()}}).String(), nil
}); err != nil {
s.handleRegistrationError(w, r, ar, nil, err)
return
}

s.handleRegistrationError(w, r, ar, nil, err)
return
}

Expand Down