Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(selfservice/login): enable reauthentication functionality #248

Merged
merged 15 commits into from
Feb 28, 2020
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ test-resetdb:

.PHONY: test
test: test-resetdb
source scripts/test-env.sh && go test -tags sqlite ./...
source scripts/test-envs.sh && go test -tags sqlite ./...

# Generates the SDKs
.PHONY: sdk
Expand Down
4 changes: 4 additions & 0 deletions docs/api.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,10 @@
"id": {
"$ref": "#/definitions/UUID"
},
"is_reauthentication": {
"description": "IsReauthentication stores whether this login request is a reauthenication request.",
"type": "boolean"
},
"issued_at": {
"description": "IssuedAt is the time (UTC) when the request occurred.",
"type": "string",
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ require (
github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2
github.com/golang/mock v1.3.1
github.com/google/go-github/v27 v27.0.1
github.com/google/go-querystring v1.0.0
github.com/google/uuid v1.1.1
github.com/gorilla/context v1.1.1
github.com/gorilla/securecookie v1.1.1
Expand Down
12 changes: 6 additions & 6 deletions identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,12 @@ func TestPool(p PrivilegedPool) func(t *testing.T) {
}

compare := func(t *testing.T, expected, actual VerifiableAddress) {
actual.CreatedAt = actual.CreatedAt.UTC().Round(time.Hour * 24)
actual.UpdatedAt = actual.UpdatedAt.UTC().Round(time.Hour * 24)
actual.ExpiresAt = actual.ExpiresAt.UTC().Round(time.Hour * 24)
expected.CreatedAt = expected.CreatedAt.UTC().Round(time.Hour * 24)
expected.UpdatedAt = expected.UpdatedAt.UTC().Round(time.Hour * 24)
expected.ExpiresAt = expected.ExpiresAt.UTC().Round(time.Hour * 24)
actual.CreatedAt = actual.CreatedAt.UTC().Truncate(time.Hour * 24)
actual.UpdatedAt = actual.UpdatedAt.UTC().Truncate(time.Hour * 24)
actual.ExpiresAt = actual.ExpiresAt.UTC().Truncate(time.Hour * 24)
expected.CreatedAt = expected.CreatedAt.UTC().Truncate(time.Hour * 24)
expected.UpdatedAt = expected.UpdatedAt.UTC().Truncate(time.Hour * 24)
expected.ExpiresAt = expected.ExpiresAt.UTC().Truncate(time.Hour * 24)
assert.EqualValues(t, expected, actual)
}

Expand Down
3 changes: 3 additions & 0 deletions internal/httpclient/models/login_request.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
drop_column("selfservice_login_requests", "is_reauthentication")
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_column("selfservice_login_requests", "is_reauthentication", "bool", {default: false})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to rename this to forced or, alternatively prompt although I think that forced is clearer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 forced

15 changes: 14 additions & 1 deletion persistence/sql/persister_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,20 @@ func (p *Persister) GetLoginRequest(ctx context.Context, id uuid.UUID) (*login.R
return &r, nil
}

func (p *Persister) UpdateLoginRequest(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.RequestMethod) error {
func (p *Persister) UpdateLoginRequestReauth(ctx context.Context, id uuid.UUID, reauth bool) error {
return p.Transaction(ctx, func(tx *pop.Connection) error {
ctx := WithTransaction(ctx, tx)
lr, err := p.GetLoginRequest(ctx, id)
if err != nil {
return err
}

lr.IsReauthentication = reauth
return tx.Save(lr)
})
}

func (p *Persister) UpdateLoginRequestMethod(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.RequestMethod) error {
return p.Transaction(ctx, func(tx *pop.Connection) error {
ctx := WithTransaction(ctx, tx)
rr, err := p.GetLoginRequest(ctx, id)
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func TestPersister_Transaction(t *testing.T) {
err := c.Transaction(func(tx *pop.Connection) error {
ctx := sql.WithTransaction(context.Background(), tx)
require.NoError(t, p.CreateLoginRequest(ctx, lr), "%+v", lr)
require.NoError(t, p.UpdateLoginRequest(ctx, lr.ID, identity.CredentialsTypePassword, &login.RequestMethod{}))
require.NoError(t, p.UpdateLoginRequestMethod(ctx, lr.ID, identity.CredentialsTypePassword, &login.RequestMethod{}))
require.NoError(t, getErr(p.GetLoginRequest(ctx, lr.ID)), "%+v", lr)
return errors.Errorf(errMessage)
})
Expand Down
5 changes: 2 additions & 3 deletions selfservice/flow/login/error.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package login

import (
"context"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -81,7 +80,7 @@ func (s *ErrorHandler) HandleLoginError(
if err = s.d.LoginHandler().NewLoginRequest(w, r, func(a *Request) (string, error) {
for name, method := range a.Methods {
method.Config.AddError(&form.Error{Message: "Your session expired, please try again."})
if err := s.d.LoginRequestPersister().UpdateLoginRequest(context.TODO(), a.ID, name, method); err != nil {
if err := s.d.LoginRequestPersister().UpdateLoginRequestMethod(r.Context(), a.ID, name, method); err != nil {
return s.d.SelfServiceErrorManager().Create(r.Context(), w, r, err)
}
a.Methods[name] = method
Expand Down Expand Up @@ -114,7 +113,7 @@ func (s *ErrorHandler) HandleLoginError(
return
}

if err := s.d.LoginRequestPersister().UpdateLoginRequest(r.Context(), rr.ID, ct, method); err != nil {
if err := s.d.LoginRequestPersister().UpdateLoginRequestMethod(r.Context(), rr.ID, ct, method); err != nil {
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
return
}
Expand Down
9 changes: 8 additions & 1 deletion selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type (
errorx.ManagementProvider
StrategyProvider
session.HandlerProvider
session.ManagementProvider
x.WriterProvider
}
HandlerProvider interface {
Expand All @@ -50,7 +51,7 @@ func (h *Handler) WithTokenGenerator(f func(r *http.Request) string) {
}

func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
public.GET(BrowserLoginPath, h.d.SessionHandler().IsNotAuthenticated(h.initLoginRequest, session.RedirectOnAuthenticated(h.c)))
public.GET(BrowserLoginPath, h.initLoginRequest)
public.GET(BrowserLoginRequestsPath, h.publicFetchLoginRequest)
}

Expand Down Expand Up @@ -110,6 +111,12 @@ func (h *Handler) NewLoginRequest(w http.ResponseWriter, r *http.Request, redir
// 500: genericError
func (h *Handler) initLoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if err := h.NewLoginRequest(w, r, func(a *Request) (string, error) {
// we assume an error means the user has no session
if _, err := h.d.SessionManager().FetchFromRequest(r.Context(), w, r); err == nil && r.URL.Query().Get("prompt") == "true" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The && doesn't make sense here. Please add a failing test case first.

if err := h.d.LoginRequestPersister().UpdateLoginRequestReauth(r.Context(), a.ID, true); err != nil {
zepatrik marked this conversation as resolved.
Show resolved Hide resolved
return "", err
}
}
return urlx.CopyWithQuery(h.c.LoginURL(), url.Values{"request": {a.ID.String()}}).String(), nil
}); err != nil {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
Expand Down
68 changes: 46 additions & 22 deletions selfservice/flow/login/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,67 +1,91 @@
package login_test

import (
"fmt"
"context"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/url"
"testing"
"time"

"github.com/gobuffalo/httptest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/kratos/selfservice/form"
"github.com/ory/viper"

"github.com/ory/kratos/driver/configuration"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/selfservice/errorx"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/strategy/oidc"
"github.com/ory/kratos/selfservice/strategy/password"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
)

type (
withTokenGenerator interface {
WithTokenGenerator(g form.CSRFGenerator)
}
)

func init() {
internal.RegisterFakes()
}

func TestEnsureSessionRedirect(t *testing.T) {
func TestHandlerSettingReauth(t *testing.T) {
_, reg := internal.NewRegistryDefault(t)
for _, strategy := range reg.LoginStrategies() {
// We need to know the csrf token
strategy.(withTokenGenerator).WithTokenGenerator(x.FakeCSRFTokenGenerator)
}

router := x.NewRouterPublic()
admin := x.NewRouterAdmin()
reg.LoginHandler().RegisterPublicRoutes(router)
reg.LoginHandler().RegisterAdminRoutes(admin)
reg.LoginHandler().WithTokenGenerator(x.FakeCSRFTokenGenerator)
reg.LoginStrategies().RegisterPublicRoutes(router)
ts := httptest.NewServer(router)
defer ts.Close()

redirTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("already authenticated"))
}))
defer redirTS.Close()
loginTS := httptest.NewServer(login.TestRequestHandler(t, reg))

viper.Set(configuration.ViperKeyURLsDefaultReturnTo, redirTS.URL)
viper.Set(configuration.ViperKeyURLsSelfPublic, ts.URL)
viper.Set(configuration.ViperKeyURLsLogin, loginTS.URL)
viper.Set(configuration.ViperKeyDefaultIdentityTraitsSchemaURL, "file://./stub/login.schema.json")

for k, tc := range [][]string{
{"GET", login.BrowserLoginPath},
t.Run("does not set reauth flag on unauthenticated request", func(t *testing.T) {
c := ts.Client()
res, err := c.Get(ts.URL + login.BrowserLoginPath)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work for both prompt=true and prompt=false and no prompt

require.NoError(t, err)
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)

{"POST", password.LoginPath},
assert.Equal(t, false, gjson.GetBytes(body, "is_reauthentication").Bool(), "%s", body)
})

// it is ok that these contain the parameters as raw strings as we are only interested in checking if the middleware is working
{"POST", oidc.AuthPath},
{"GET", oidc.AuthPath},
{"GET", oidc.CallbackPath},
} {
t.Run(fmt.Sprintf("case=%d/method=%s/path=%s", k, tc[0], tc[1]), func(t *testing.T) {
body, _ := session.MockMakeAuthenticatedRequest(t, reg, router.Router, x.NewTestHTTPRequest(t, tc[0], ts.URL+tc[1], nil))
assert.EqualValues(t, "already authenticated", string(body))
})
}
t.Run("does set reauth flag on authenticated request", func(t *testing.T) {
rid := x.NewUUID()
req := x.NewTestHTTPRequest(t, "GET", ts.URL+login.BrowserLoginPath, nil)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing ?prompt=true

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gotcha 👍

loginReq := login.NewLoginRequest(time.Minute, x.FakeCSRFToken, req)
loginReq.ID = rid
for _, s := range reg.LoginStrategies() {
require.NoError(t, s.PopulateLoginMethod(req, loginReq))
}
require.NoError(t, reg.LoginRequestPersister().CreateLoginRequest(context.TODO(), loginReq), "%+v", loginReq)

req.URL.RawQuery = url.Values{
"request": {rid.String()},
"prompt": {"true"},
}.Encode()

body, _ := session.MockMakeAuthenticatedRequest(t, reg, router.Router, req)

assert.Equal(t, true, gjson.GetBytes(body, "is_reauthentication").Bool(), "%s", body)
})
}

func TestLoginHandler(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions selfservice/flow/login/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ type (
RequestPersister interface {
CreateLoginRequest(context.Context, *Request) error
GetLoginRequest(context.Context, uuid.UUID) (*Request, error)
UpdateLoginRequest(context.Context, uuid.UUID, identity.CredentialsType, *RequestMethod) error
UpdateLoginRequestMethod(context.Context, uuid.UUID, identity.CredentialsType, *RequestMethod) error
UpdateLoginRequestReauth(ctx context.Context, id uuid.UUID, reauth bool) error
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a one-way method (reauth is false by default and can only be set to true) wouldn't it make sense to name this something like MarkLoginRequestForced? Alternatively UpdateLoginRequest(context.Context, *Request).

}
RequestPersistenceProvider interface {
LoginRequestPersister() RequestPersister
Expand Down Expand Up @@ -97,12 +98,12 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) {
require.NoError(t, err)
assert.Len(t, actual.Methods, 1)

require.NoError(t, p.UpdateLoginRequest(context.Background(), expected.ID, identity.CredentialsTypeOIDC, &RequestMethod{
require.NoError(t, p.UpdateLoginRequestMethod(context.Background(), expected.ID, identity.CredentialsTypeOIDC, &RequestMethod{
Method: identity.CredentialsTypeOIDC,
Config: &RequestMethodConfig{RequestMethodConfigurator: form.NewHTMLForm(string(identity.CredentialsTypeOIDC))},
}))

require.NoError(t, p.UpdateLoginRequest(context.Background(), expected.ID, identity.CredentialsTypePassword, &RequestMethod{
require.NoError(t, p.UpdateLoginRequestMethod(context.Background(), expected.ID, identity.CredentialsTypePassword, &RequestMethod{
Method: identity.CredentialsTypePassword,
Config: &RequestMethodConfig{RequestMethodConfigurator: form.NewHTMLForm(string(identity.CredentialsTypePassword))},
}))
Expand Down
24 changes: 24 additions & 0 deletions selfservice/flow/login/request.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package login

import (
"context"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/gobuffalo/pop/v5"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
Expand Down Expand Up @@ -61,6 +65,9 @@ type Request struct {

// CSRFToken contains the anti-csrf token associated with this request.
CSRFToken string `json:"-" db:"csrf_token"`

// IsReauthentication stores whether this login request is a reauthenication request.
IsReauthentication bool `json:"is_reauthentication" db:"is_reauthentication"`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think forced is shorter and as concise :)

}

func NewLoginRequest(exp time.Duration, csrf string, r *http.Request) *Request {
Expand Down Expand Up @@ -130,3 +137,20 @@ func (r *Request) Valid() error {
func (r *Request) GetID() uuid.UUID {
return r.ID
}

func (r *Request) IsReauth() bool {
return r.IsReauthentication
}

type testRequestHandlerDependencies interface {
RequestPersistenceProvider
x.WriterProvider
}

func TestRequestHandler(t *testing.T, reg testRequestHandlerDependencies) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
e, err := reg.LoginRequestPersister().GetLoginRequest(context.Background(), x.ParseUUID(r.URL.Query().Get("request")))
require.NoError(t, err)
reg.Writer().Write(w, r, e)
}
}
24 changes: 5 additions & 19 deletions selfservice/flow/registration/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package registration_test

import (
"fmt"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
Expand All @@ -17,8 +16,6 @@ import (
"github.com/ory/kratos/internal"
"github.com/ory/kratos/selfservice/errorx"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/strategy/oidc"
"github.com/ory/kratos/selfservice/strategy/password"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
)
Expand All @@ -27,7 +24,7 @@ func init() {
internal.RegisterFakes()
}

func TestEnsureSessionRedirect(t *testing.T) {
func TestHandlerRedirectOnAuthenticated(t *testing.T) {
_, reg := internal.NewRegistryDefault(t)

router := x.NewRouterPublic()
Expand All @@ -45,21 +42,10 @@ func TestEnsureSessionRedirect(t *testing.T) {
viper.Set(configuration.ViperKeyURLsSelfPublic, ts.URL)
viper.Set(configuration.ViperKeyDefaultIdentityTraitsSchemaURL, "file://./stub/registration.schema.json")

for k, tc := range [][]string{
{"GET", registration.BrowserRegistrationPath},

{"POST", password.RegistrationPath},

// it is ok that these contain the parameters as arw strings as we are only interested in checking if the middleware is working
{"POST", oidc.AuthPath},
{"GET", oidc.AuthPath},
{"GET", oidc.CallbackPath},
} {
t.Run(fmt.Sprintf("case=%d/method=%s/path=%s", k, tc[0], tc[1]), func(t *testing.T) {
body, _ := session.MockMakeAuthenticatedRequest(t, reg, router.Router, x.NewTestHTTPRequest(t, tc[0], ts.URL+tc[1], nil))
assert.EqualValues(t, "already authenticated", string(body))
})
}
t.Run("does redirect to default on authenticated request", func(t *testing.T) {
body, _ := session.MockMakeAuthenticatedRequest(t, reg, router.Router, x.NewTestHTTPRequest(t, "GET", ts.URL+registration.BrowserRegistrationPath, nil))
assert.EqualValues(t, "already authenticated", string(body))
})
}

func TestRegistrationHandler(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions selfservice/flow/registration/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ func (r *Request) GetID() uuid.UUID {
return r.ID
}

func (r *Request) IsReauth() bool {
return false
}

func (r *Request) Valid() error {
if r.ExpiresAt.Before(time.Now()) {
return errors.WithStack(newRequestExpiredError(time.Since(r.ExpiresAt)))
Expand Down
Loading