-
-
Notifications
You must be signed in to change notification settings - Fork 964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(selfservice/login): enable reauthentication functionality #248
Changes from 13 commits
2480ab2
6203e6b
c0400c8
a75bd82
f38157b
5c7d151
88f612e
2633b01
a96cc4b
e0a973c
502ed83
f449f0b
56b2e37
8492d11
550f1db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
drop_column("selfservice_login_requests", "is_reauthentication") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_column("selfservice_login_requests", "is_reauthentication", "bool", {default: false}) | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ type ( | |
errorx.ManagementProvider | ||
StrategyProvider | ||
session.HandlerProvider | ||
session.ManagementProvider | ||
x.WriterProvider | ||
} | ||
HandlerProvider interface { | ||
|
@@ -50,7 +51,7 @@ func (h *Handler) WithTokenGenerator(f func(r *http.Request) string) { | |
} | ||
|
||
func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) { | ||
public.GET(BrowserLoginPath, h.d.SessionHandler().IsNotAuthenticated(h.initLoginRequest, session.RedirectOnAuthenticated(h.c))) | ||
public.GET(BrowserLoginPath, h.initLoginRequest) | ||
public.GET(BrowserLoginRequestsPath, h.publicFetchLoginRequest) | ||
} | ||
|
||
|
@@ -110,6 +111,12 @@ func (h *Handler) NewLoginRequest(w http.ResponseWriter, r *http.Request, redir | |
// 500: genericError | ||
func (h *Handler) initLoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { | ||
if err := h.NewLoginRequest(w, r, func(a *Request) (string, error) { | ||
// we assume an error means the user has no session | ||
if _, err := h.d.SessionManager().FetchFromRequest(r.Context(), w, r); err == nil && r.URL.Query().Get("prompt") == "true" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The && doesn't make sense here. Please add a failing test case first. |
||
if err := h.d.LoginRequestPersister().UpdateLoginRequestReauth(r.Context(), a.ID, true); err != nil { | ||
zepatrik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return "", err | ||
} | ||
} | ||
return urlx.CopyWithQuery(h.c.LoginURL(), url.Values{"request": {a.ID.String()}}).String(), nil | ||
}); err != nil { | ||
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,91 @@ | ||
package login_test | ||
|
||
import ( | ||
"fmt" | ||
"context" | ||
"io/ioutil" | ||
"net/http" | ||
"net/http/cookiejar" | ||
"net/url" | ||
"testing" | ||
"time" | ||
|
||
"github.com/gobuffalo/httptest" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"github.com/tidwall/gjson" | ||
|
||
"github.com/ory/kratos/selfservice/form" | ||
"github.com/ory/viper" | ||
|
||
"github.com/ory/kratos/driver/configuration" | ||
"github.com/ory/kratos/internal" | ||
"github.com/ory/kratos/selfservice/errorx" | ||
"github.com/ory/kratos/selfservice/flow/login" | ||
"github.com/ory/kratos/selfservice/strategy/oidc" | ||
"github.com/ory/kratos/selfservice/strategy/password" | ||
"github.com/ory/kratos/session" | ||
"github.com/ory/kratos/x" | ||
) | ||
|
||
type ( | ||
withTokenGenerator interface { | ||
WithTokenGenerator(g form.CSRFGenerator) | ||
} | ||
) | ||
|
||
func init() { | ||
internal.RegisterFakes() | ||
} | ||
|
||
func TestEnsureSessionRedirect(t *testing.T) { | ||
func TestHandlerSettingReauth(t *testing.T) { | ||
_, reg := internal.NewRegistryDefault(t) | ||
for _, strategy := range reg.LoginStrategies() { | ||
// We need to know the csrf token | ||
strategy.(withTokenGenerator).WithTokenGenerator(x.FakeCSRFTokenGenerator) | ||
} | ||
|
||
router := x.NewRouterPublic() | ||
admin := x.NewRouterAdmin() | ||
reg.LoginHandler().RegisterPublicRoutes(router) | ||
reg.LoginHandler().RegisterAdminRoutes(admin) | ||
reg.LoginHandler().WithTokenGenerator(x.FakeCSRFTokenGenerator) | ||
reg.LoginStrategies().RegisterPublicRoutes(router) | ||
ts := httptest.NewServer(router) | ||
defer ts.Close() | ||
|
||
redirTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
_, _ = w.Write([]byte("already authenticated")) | ||
})) | ||
defer redirTS.Close() | ||
loginTS := httptest.NewServer(login.TestRequestHandler(t, reg)) | ||
|
||
viper.Set(configuration.ViperKeyURLsDefaultReturnTo, redirTS.URL) | ||
viper.Set(configuration.ViperKeyURLsSelfPublic, ts.URL) | ||
viper.Set(configuration.ViperKeyURLsLogin, loginTS.URL) | ||
viper.Set(configuration.ViperKeyDefaultIdentityTraitsSchemaURL, "file://./stub/login.schema.json") | ||
|
||
for k, tc := range [][]string{ | ||
{"GET", login.BrowserLoginPath}, | ||
t.Run("does not set reauth flag on unauthenticated request", func(t *testing.T) { | ||
c := ts.Client() | ||
res, err := c.Get(ts.URL + login.BrowserLoginPath) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should work for both |
||
require.NoError(t, err) | ||
defer res.Body.Close() | ||
body, err := ioutil.ReadAll(res.Body) | ||
|
||
{"POST", password.LoginPath}, | ||
assert.Equal(t, false, gjson.GetBytes(body, "is_reauthentication").Bool(), "%s", body) | ||
}) | ||
|
||
// it is ok that these contain the parameters as raw strings as we are only interested in checking if the middleware is working | ||
{"POST", oidc.AuthPath}, | ||
{"GET", oidc.AuthPath}, | ||
{"GET", oidc.CallbackPath}, | ||
} { | ||
t.Run(fmt.Sprintf("case=%d/method=%s/path=%s", k, tc[0], tc[1]), func(t *testing.T) { | ||
body, _ := session.MockMakeAuthenticatedRequest(t, reg, router.Router, x.NewTestHTTPRequest(t, tc[0], ts.URL+tc[1], nil)) | ||
assert.EqualValues(t, "already authenticated", string(body)) | ||
}) | ||
} | ||
t.Run("does set reauth flag on authenticated request", func(t *testing.T) { | ||
rid := x.NewUUID() | ||
req := x.NewTestHTTPRequest(t, "GET", ts.URL+login.BrowserLoginPath, nil) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is missing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. gotcha 👍 |
||
loginReq := login.NewLoginRequest(time.Minute, x.FakeCSRFToken, req) | ||
loginReq.ID = rid | ||
for _, s := range reg.LoginStrategies() { | ||
require.NoError(t, s.PopulateLoginMethod(req, loginReq)) | ||
} | ||
require.NoError(t, reg.LoginRequestPersister().CreateLoginRequest(context.TODO(), loginReq), "%+v", loginReq) | ||
|
||
req.URL.RawQuery = url.Values{ | ||
"request": {rid.String()}, | ||
"prompt": {"true"}, | ||
}.Encode() | ||
|
||
body, _ := session.MockMakeAuthenticatedRequest(t, reg, router.Router, req) | ||
|
||
assert.Equal(t, true, gjson.GetBytes(body, "is_reauthentication").Bool(), "%s", body) | ||
}) | ||
} | ||
|
||
func TestLoginHandler(t *testing.T) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,8 @@ type ( | |
RequestPersister interface { | ||
CreateLoginRequest(context.Context, *Request) error | ||
GetLoginRequest(context.Context, uuid.UUID) (*Request, error) | ||
UpdateLoginRequest(context.Context, uuid.UUID, identity.CredentialsType, *RequestMethod) error | ||
UpdateLoginRequestMethod(context.Context, uuid.UUID, identity.CredentialsType, *RequestMethod) error | ||
UpdateLoginRequestReauth(ctx context.Context, id uuid.UUID, reauth bool) error | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is a one-way method (reauth is false by default and can only be set to true) wouldn't it make sense to name this something like |
||
} | ||
RequestPersistenceProvider interface { | ||
LoginRequestPersister() RequestPersister | ||
|
@@ -97,12 +98,12 @@ func TestRequestPersister(p RequestPersister) func(t *testing.T) { | |
require.NoError(t, err) | ||
assert.Len(t, actual.Methods, 1) | ||
|
||
require.NoError(t, p.UpdateLoginRequest(context.Background(), expected.ID, identity.CredentialsTypeOIDC, &RequestMethod{ | ||
require.NoError(t, p.UpdateLoginRequestMethod(context.Background(), expected.ID, identity.CredentialsTypeOIDC, &RequestMethod{ | ||
Method: identity.CredentialsTypeOIDC, | ||
Config: &RequestMethodConfig{RequestMethodConfigurator: form.NewHTMLForm(string(identity.CredentialsTypeOIDC))}, | ||
})) | ||
|
||
require.NoError(t, p.UpdateLoginRequest(context.Background(), expected.ID, identity.CredentialsTypePassword, &RequestMethod{ | ||
require.NoError(t, p.UpdateLoginRequestMethod(context.Background(), expected.ID, identity.CredentialsTypePassword, &RequestMethod{ | ||
Method: identity.CredentialsTypePassword, | ||
Config: &RequestMethodConfig{RequestMethodConfigurator: form.NewHTMLForm(string(identity.CredentialsTypePassword))}, | ||
})) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
package login | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/gobuffalo/pop/v5" | ||
"github.com/gofrs/uuid" | ||
"github.com/pkg/errors" | ||
|
@@ -61,6 +65,9 @@ type Request struct { | |
|
||
// CSRFToken contains the anti-csrf token associated with this request. | ||
CSRFToken string `json:"-" db:"csrf_token"` | ||
|
||
// IsReauthentication stores whether this login request is a reauthenication request. | ||
IsReauthentication bool `json:"is_reauthentication" db:"is_reauthentication"` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
} | ||
|
||
func NewLoginRequest(exp time.Duration, csrf string, r *http.Request) *Request { | ||
|
@@ -130,3 +137,20 @@ func (r *Request) Valid() error { | |
func (r *Request) GetID() uuid.UUID { | ||
return r.ID | ||
} | ||
|
||
func (r *Request) IsReauth() bool { | ||
return r.IsReauthentication | ||
} | ||
|
||
type testRequestHandlerDependencies interface { | ||
RequestPersistenceProvider | ||
x.WriterProvider | ||
} | ||
|
||
func TestRequestHandler(t *testing.T, reg testRequestHandlerDependencies) http.HandlerFunc { | ||
return func(w http.ResponseWriter, r *http.Request) { | ||
e, err := reg.LoginRequestPersister().GetLoginRequest(context.Background(), x.ParseUUID(r.URL.Query().Get("request"))) | ||
require.NoError(t, err) | ||
reg.Writer().Write(w, r, e) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would like to rename this to
forced
or, alternativelyprompt
although I think that forced is clearer.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 forced