Skip to content

Commit

Permalink
feat: add and test id hint in reauth flow
Browse files Browse the repository at this point in the history
Closes #323
  • Loading branch information
aeneasr committed Apr 10, 2020
1 parent 56a44fa commit 2298f01
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 30 deletions.
3 changes: 2 additions & 1 deletion continuity/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestPersister(p interface {
}

var createContainer = func(t *testing.T) Container {
m := sqlxx.NullJSONRawMessage(`{"foo":"bar"}`)
m := sqlxx.NullJSONRawMessage(`{"foo": "bar"}`)
return Container{Name: "foo", IdentityID: x.PointToUUID(createIdentity(t).ID),
ExpiresAt: time.Now().Add(time.Hour).UTC().Truncate(time.Second),
Payload: m,
Expand All @@ -56,6 +56,7 @@ func TestPersister(p interface {

actual, err := p.GetContinuitySession(context.Background(), expected.ID)
require.NoError(t, err)
actual.UpdatedAt, actual.CreatedAt, expected.UpdatedAt, expected.CreatedAt = time.Time{}, time.Time{}, time.Time{}, time.Time{}
assert.EqualValues(t, expected.UTC(), actual.UTC())
})

Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (h *Handler) initLoginRequest(w http.ResponseWriter, r *http.Request, ps ht
return urlx.CopyWithQuery(h.c.LoginURL(), url.Values{"request": {a.ID.String()}}).String(), nil
}

if r.URL.Query().Get("prompt") == "login" {
if a.Forced {
if err := h.d.LoginRequestPersister().MarkRequestForced(r.Context(), a.ID); err != nil {
return "", err
}
Expand Down
1 change: 1 addition & 0 deletions selfservice/flow/login/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func NewLoginRequest(exp time.Duration, csrf string, r *http.Request) *Request {
RequestURL: source.String(),
Methods: map[identity.CredentialsType]*RequestMethod{},
CSRFToken: csrf,
Forced: r.URL.Query().Get("prompt") == "login",
}
}

Expand Down
5 changes: 5 additions & 0 deletions selfservice/strategy/password/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,19 @@ func (s *Strategy) PopulateLoginMethod(r *http.Request, sr *login.Request) error

var identifier string
if !sr.IsForced() {
print("forced")
// do nothing
} else if sess, err := s.d.SessionManager().FetchFromRequest(r.Context(), r); err != nil {
print("sm")
// do nothing
} else if id, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), sess.IdentityID); err != nil {
print("confidential")
// do nothing
} else if creds, ok := id.GetCredentials(s.ID()); !ok {
print("nocreds")
// do nothing
} else if len(creds.Identifiers) == 0 {
print("noids")
// do nothing
} else {
identifier = creds.Identifiers[0]
Expand Down
120 changes: 92 additions & 28 deletions selfservice/strategy/password/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,29 +100,50 @@ func TestLoginNew(t *testing.T) {
viper.Set(configuration.ViperKeySecretsSession, []string{"not-a-secure-session-key"})
viper.Set(configuration.ViperKeyURLsDefaultReturnTo, returnTs.URL+"/return-ts")

makeRequest := func(lr *login.Request, payload string, forceRequestID *string, jar *cookiejar.Jar) (*http.Response, []byte) {
lr.RequestURL = ts.URL
require.NoError(t, reg.LoginRequestPersister().CreateLoginRequest(context.TODO(), lr))
mr := func(t *testing.T, payload string, requestID string, c *http.Client) (*http.Response, []byte) {
res, err := c.Post(ts.URL+password.LoginPath+"?request="+requestID, "application/x-www-form-urlencoded", strings.NewReader(payload))
require.NoError(t, err)
defer res.Body.Close()
require.EqualValues(t, http.StatusOK, res.StatusCode, "Request: %+v\n\t\tResponse: %s", res.Request, res)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
return res, body
}

c := ts.Client()
makeRequest := func(t *testing.T, payload string, jar *cookiejar.Jar, force bool) (*http.Response, []byte) {
c := &http.Client{Jar: jar}
if jar == nil {
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
} else {
c.Jar = jar
}

u := ts.URL + login.BrowserLoginPath
if force {
u = u + "?prompt=login"
}

res, err := c.Get(u)
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, res.StatusCode, "Request: %+v\n\t\tResponse: %s", res.Request, res)
assert.NotEmpty(t, res.Request.URL.Query().Get("request"))

return mr(t, payload, res.Request.URL.Query().Get("request"), c)
}

fakeRequest := func(t *testing.T, lr *login.Request, payload string, forceRequestID *string, jar *cookiejar.Jar) (*http.Response, []byte) {
lr.RequestURL = ts.URL
require.NoError(t, reg.LoginRequestPersister().CreateLoginRequest(context.TODO(), lr))

requestID := lr.ID.String()
if forceRequestID != nil {
requestID = *forceRequestID
}

res, err := c.Post(ts.URL+password.LoginPath+"?request="+requestID, "application/x-www-form-urlencoded", strings.NewReader(payload))
require.NoError(t, err)
defer res.Body.Close()
require.EqualValues(t, http.StatusOK, res.StatusCode, "Request: %+v\n\t\tResponse: %s", res.Request, res)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
return res, body
c := &http.Client{Jar: jar}
if jar == nil {
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
}

return mr(t, payload, requestID, c)
}

ensureFieldsExist := func(t *testing.T, body []byte) {
Expand All @@ -148,7 +169,7 @@ func TestLoginNew(t *testing.T) {

t.Run("should show the error ui because the request is malformed", func(t *testing.T) {
lr := nlr(0)
res, body := makeRequest(lr, "14=)=!(%)$/ZP()GHIÖ", nil, nil)
res, body := fakeRequest(t, lr, "14=)=!(%)$/ZP()GHIÖ", nil, nil)

require.Contains(t, res.Request.URL.Path, "login-ts", "%+v", res.Request)
assert.Equal(t, lr.ID.String(), gjson.GetBytes(body, "id").String(), "%s", body)
Expand All @@ -158,7 +179,7 @@ func TestLoginNew(t *testing.T) {

t.Run("should show the error ui because the request id missing", func(t *testing.T) {
lr := nlr(time.Minute)
res, body := makeRequest(lr, url.Values{}.Encode(), pointerx.String(""), nil)
res, body := fakeRequest(t, lr, url.Values{}.Encode(), pointerx.String(""), nil)

require.Contains(t, res.Request.URL.Path, "error-ts")
assert.Equal(t, int64(http.StatusBadRequest), gjson.GetBytes(body, "0.code").Int(), "%s", body)
Expand All @@ -168,7 +189,7 @@ func TestLoginNew(t *testing.T) {

t.Run("should return an error because the request does not exist", func(t *testing.T) {
lr := nlr(0)
res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"identifier": {"identifier"},
"password": {"password"},
}.Encode(), pointerx.String(x.NewUUID().String()), nil)
Expand All @@ -181,7 +202,7 @@ func TestLoginNew(t *testing.T) {

t.Run("should redirect to login init because the request is expired", func(t *testing.T) {
lr := nlr(-time.Hour)
res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"identifier": {"identifier"},
"password": {"password"},
}.Encode(), nil, nil)
Expand All @@ -194,7 +215,7 @@ func TestLoginNew(t *testing.T) {

t.Run("should return an error because the credentials are invalid (user does not exist)", func(t *testing.T) {
lr := nlr(time.Hour)
res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"identifier": {"identifier"},
"password": {"password"},
}.Encode(), nil, nil)
Expand All @@ -207,7 +228,7 @@ func TestLoginNew(t *testing.T) {

t.Run("should return an error because no identifier is set", func(t *testing.T) {
lr := nlr(time.Hour)
res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"password": {"password"},
}.Encode(), nil, nil)

Expand All @@ -224,7 +245,7 @@ func TestLoginNew(t *testing.T) {

t.Run("should return an error because no password is set", func(t *testing.T) {
lr := nlr(time.Hour)
res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"identifier": {"identifier"},
}.Encode(), nil, nil)

Expand All @@ -247,7 +268,7 @@ func TestLoginNew(t *testing.T) {
createIdentity(identifier, pwd)

lr := nlr(time.Hour)
res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"identifier": {identifier},
"password": {"not-password"},
}.Encode(), nil, nil)
Expand All @@ -267,12 +288,12 @@ func TestLoginNew(t *testing.T) {
assert.Empty(t, gjson.GetBytes(body, "methods.password.config.fields.#(name==password).value").String())
})

t.Run("should pass because everything is a-ok", func(t *testing.T) {
t.Run("should pass because with fake request", func(t *testing.T) {
identifier, pwd := "login-identifier-7", "password"
createIdentity(identifier, pwd)

lr := nlr(time.Hour)
res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"identifier": {identifier},
"password": {pwd},
}.Encode(), nil, nil)
Expand All @@ -281,6 +302,49 @@ func TestLoginNew(t *testing.T) {
assert.Equal(t, identifier, gjson.GetBytes(body, "identity.traits.subject").String(), "%s", body)
})

t.Run("should pass with real request", func(t *testing.T) {
identifier, pwd := "login-identifier-7", "password"
createIdentity(identifier, pwd)

jar, _ := cookiejar.New(nil)
res, body := makeRequest(t, url.Values{
"identifier": {identifier},
"password": {pwd},
}.Encode(), jar, true)

require.Contains(t, res.Request.URL.Path, "return-ts", "%s", res.Request.URL.String())
assert.Equal(t, identifier, gjson.GetBytes(body, "identity.traits.subject").String(), "%s", body)

t.Run("retry with different prompts", func(t *testing.T) {
c := &http.Client{Jar: jar}

t.Run("redirect to returnTS if prompt is missing", func(t *testing.T) {
res, err := c.Get(ts.URL + login.BrowserLoginPath)
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, res.StatusCode)
})

t.Run("show UI and hint at username", func(t *testing.T) {
res, err := c.Get(ts.URL + login.BrowserLoginPath + "?prompt=login")
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, res.StatusCode)

rid := res.Request.URL.Query().Get("request")
assert.NotEmpty(t, rid, "%s", res.Request.URL)

res, err = c.Get(ts.URL + login.BrowserLoginRequestsPath + "?request=" + rid)
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, res.StatusCode)

body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.True(t, gjson.GetBytes(body, "forced").Bool())
assert.Equal(t, identifier, gjson.GetBytes(body, "methods.password.config.fields.#(name==identifier).value").String(), "%s", body)
assert.Empty(t, gjson.GetBytes(body, "methods.password.config.fields.#(name==password).value").String(), "%s", body)
})
})
})

t.Run("should return an error because not passing validation and reset previous errors and values", func(t *testing.T) {
lr := &login.Request{
ID: x.NewUUID(),
Expand Down Expand Up @@ -313,7 +377,7 @@ func TestLoginNew(t *testing.T) {
},
}

res, body := makeRequest(lr, url.Values{
res, body := fakeRequest(t, lr, url.Values{
"identifier": {"registration-identifier-9"},
// "password": {uuid.New().String()},
}.Encode(), nil, nil)
Expand All @@ -335,14 +399,14 @@ func TestLoginNew(t *testing.T) {

jar, err := cookiejar.New(&cookiejar.Options{})
require.NoError(t, err)
_, body1 := makeRequest(nlr(time.Hour), url.Values{
_, body1 := fakeRequest(t, nlr(time.Hour), url.Values{
"identifier": {identifier},
"password": {pwd},
}.Encode(), nil, jar)

lr2 := nlr(time.Hour)
lr2.Forced = true
res, body2 := makeRequest(lr2, url.Values{
res, body2 := fakeRequest(t, lr2, url.Values{
"identifier": {identifier},
"password": {pwd},
}.Encode(), nil, jar)
Expand All @@ -358,13 +422,13 @@ func TestLoginNew(t *testing.T) {

jar, err := cookiejar.New(&cookiejar.Options{})
require.NoError(t, err)
_, body1 := makeRequest(nlr(time.Hour), url.Values{
_, body1 := fakeRequest(t, nlr(time.Hour), url.Values{
"identifier": {identifier},
"password": {pwd},
}.Encode(), nil, jar)

lr2 := nlr(time.Hour)
res, body2 := makeRequest(lr2, url.Values{
res, body2 := fakeRequest(t, lr2, url.Values{
"identifier": {identifier},
"password": {pwd},
}.Encode(), nil, jar)
Expand Down

0 comments on commit 2298f01

Please sign in to comment.