diff --git a/consent/handler.go b/consent/handler.go index 81dd69b7541..6f96e3240b7 100644 --- a/consent/handler.go +++ b/consent/handler.go @@ -695,6 +695,14 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) return } else if hr.Skip { + if p.Remember && p.RememberFor > 0 { // TODO: Consider removing 'p.RememberFor > 0' to update consent validity in both ways (limited (RememberFor > 0) -> indefinitely (RememberFor = 0) and vice versa) + var ctx = r.Context() + err = h.r.ConsentManager().ExtendConsentRequest(r.Context(), h.r.Config().GetScopeStrategy(ctx), hr, p.RememberFor) + if err != nil { + h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) + return + } + } p.Remember = false } diff --git a/consent/handler_test.go b/consent/handler_test.go index 6022674eeb1..174242900d2 100644 --- a/consent/handler_test.go +++ b/consent/handler_test.go @@ -13,10 +13,9 @@ import ( "testing" "time" - "github.com/ory/x/pointerx" - "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" + "github.com/ory/x/pointerx" "github.com/ory/x/sqlxx" "github.com/ory/hydra/v2/internal" @@ -213,6 +212,98 @@ func TestGetConsentRequest(t *testing.T) { } } +func TestExtendConsentRequest(t *testing.T) { + t.Run("case=extend consent expiry time", func(t *testing.T) { + conf := internal.NewConfigurationWithDefaults() + reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + h := NewHandler(reg, conf) + r := x.NewRouterAdmin(conf.AdminURL) + h.SetRoutes(r) + ts := httptest.NewServer(r) + defer ts.Close() + + c := &http.Client{} + cl := &client.Client{LegacyClientID: "client-1"} + require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl)) + + var initialRememberFor time.Duration = 300 + var remainingValidTime time.Duration = 100 + + require.NoError(t, reg.ConsentManager().CreateLoginSession(context.Background(), &LoginSession{ + ID: makeID("fk-login-session", "1", "1"), + Subject: "subject-1", + })) + requestedTimeInPast := time.Now().UTC().Add(-(initialRememberFor - remainingValidTime) * time.Second) + require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{ + ID: makeID("challenge", "1", "1"), + SessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")), + Client: cl, + Subject: "subject-1", + RequestedAt: requestedTimeInPast, + })) + require.NoError(t, reg.ConsentManager().CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{ + ID: makeID("challenge", "1", "1"), + Subject: "subject-1", + Client: cl, + LoginSessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")), + LoginChallenge: sqlxx.NullString(makeID("challenge", "1", "1")), + Verifier: makeID("verifier", "1", "1"), + CSRF: "csrf1", + Skip: false, + ACR: "1", + })) + _, err := reg.ConsentManager().HandleConsentRequest(context.Background(), &AcceptOAuth2ConsentRequest{ + ID: makeID("challenge", "1", "1"), + Remember: true, + RememberFor: int(initialRememberFor), + WasHandled: true, + HandledAt: sqlxx.NullTime(time.Now().UTC()), + }) + require.NoError(t, err) + + require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{ + ID: makeID("challenge", "1", "2"), + SessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")), + Verifier: makeID("verifier", "1", "1"), + Client: cl, + RequestedAt: time.Now().UTC(), + Subject: "subject-1", + })) + require.NoError(t, reg.ConsentManager().CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{ + ID: makeID("challenge", "1", "2"), + Subject: "subject-1", + Client: cl, + LoginSessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")), + LoginChallenge: sqlxx.NullString(makeID("challenge", "1", "2")), + Verifier: makeID("verifier", "1", "2"), + CSRF: "csrf2", + Skip: true, + })) + + var b bytes.Buffer + var extendRememberFor time.Duration = 300 + require.NoError(t, json.NewEncoder(&b).Encode(&AcceptOAuth2ConsentRequest{ + Remember: true, + RememberFor: int(extendRememberFor), + })) + + req, err := http.NewRequest(http.MethodPut, ts.URL+"/admin"+ConsentPath+"/accept?challenge=challenge-1-2", &b) + require.NoError(t, err) + resp, err := c.Do(req) + require.NoError(t, err) + require.EqualValues(t, 200, resp.StatusCode) + + crs, err := reg.ConsentManager().FindSubjectsGrantedConsentRequests(context.Background(), "subject-1", 100, 0) + require.NoError(t, err) + require.NotNil(t, crs) + require.EqualValues(t, 1, len(crs)) + expectedRememberFor := int(initialRememberFor + extendRememberFor - remainingValidTime) + cr := crs[0] + require.EqualValues(t, "challenge-1-1", cr.ID) + require.InDelta(t, expectedRememberFor, cr.RememberFor, 1) + }) +} + func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { t.Run("Test get login request with duplicate accept", func(t *testing.T) { challenge := "challenge" diff --git a/consent/manager.go b/consent/manager.go index 2910bcc9e40..318218fe985 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -7,6 +7,8 @@ import ( "context" "time" + "github.com/ory/fosite" + "github.com/gofrs/uuid" "github.com/ory/hydra/v2/client" @@ -27,6 +29,7 @@ type Manager interface { CreateConsentRequest(ctx context.Context, req *OAuth2ConsentRequest) error GetConsentRequest(ctx context.Context, challenge string) (*OAuth2ConsentRequest, error) HandleConsentRequest(ctx context.Context, r *AcceptOAuth2ConsentRequest) (*OAuth2ConsentRequest, error) + ExtendConsentRequest(ctx context.Context, scopeStrategy fosite.ScopeStrategy, req *OAuth2ConsentRequest, extendBy int) error RevokeSubjectConsentSession(ctx context.Context, user string) error RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error diff --git a/consent/manager_test_helpers.go b/consent/manager_test_helpers.go index 084b9d4c4a4..1461724e82c 100644 --- a/consent/manager_test_helpers.go +++ b/consent/manager_test_helpers.go @@ -294,7 +294,7 @@ func TestHelperNID(t1ClientManager client.Manager, t1ValidNID Manager, t2Invalid } } -func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) { +func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.FositeStorer, scopeStrategy fosite.ScopeStrategy, network string, parallel bool) func(t *testing.T) { lr := make(map[string]*LoginRequest) return func(t *testing.T) { @@ -534,6 +534,129 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit } }) + t.Run("case=extend consent request", func(t *testing.T) { + cl := &client.Client{LegacyClientID: "client-1"} + _ = clientManager.CreateClient(context.Background(), cl) + consentFlow := func(subject, sessionId, challenge string, rememberFor time.Duration, requestedAt time.Time, requestedScope string, skip bool) *OAuth2ConsentRequest { + require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ + ID: makeID("challenge", network, challenge), + SessionID: sqlxx.NullString(makeID("fk-login-session", network, sessionId)), + Client: cl, + Subject: subject, + Verifier: uuid.New().String(), + RequestedAt: requestedAt, + RequestedScope: []string{requestedScope}, + })) + + require.NoError(t, m.CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{ + ID: makeID("challenge", network, challenge), + Client: cl, + Subject: subject, + LoginSessionID: sqlxx.NullString(makeID("fk-login-session", network, sessionId)), + LoginChallenge: sqlxx.NullString(makeID("challenge", network, challenge)), + Skip: skip, + Verifier: uuid.New().String(), + CSRF: "csrf1", + })) + cr, err := m.HandleConsentRequest(context.Background(), &AcceptOAuth2ConsentRequest{ + ID: makeID("challenge", network, challenge), + Remember: true, + RememberFor: int(rememberFor), + WasHandled: true, + HandledAt: sqlxx.NullTime(time.Now().UTC()), + GrantedScope: []string{"scope-a"}, + }) + require.NoError(t, err) + return cr + } + + t.Run("case=extend session related and latest consent expiry times", func(t *testing.T) { + var rememberForSession1 time.Duration = 300 + var remainingValidTimeSession1 time.Duration = 100 + var rememberForSession2 time.Duration = 300 + var remainingValidTimeSession2 time.Duration = 150 + var extendRememberFor time.Duration = 1000 + requestedAt1 := time.Now().UTC().Round(time.Second).Add(-(rememberForSession1 - remainingValidTimeSession1) * time.Second) + requestedAt2 := time.Now().UTC().Round(time.Second).Add(-(rememberForSession2 - remainingValidTimeSession2) * time.Second) + requestedAt3 := time.Now().UTC() + require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + ID: makeID("fk-login-session", network, "ec1"), + Subject: "subject-1", + })) + require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + ID: makeID("fk-login-session", network, "ec2"), + Subject: "subject-1", + })) + consentFlow("subject-1", "ec1", "c1", rememberForSession1, requestedAt1, "scope-a", false) + consentFlow("subject-1", "ec2", "c2", rememberForSession2, requestedAt2, "scope-a", false) + cr := consentFlow("subject-1", "ec1", "c3", extendRememberFor, requestedAt3, "scope-a", true) + + require.NoError(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, int(extendRememberFor))) + + crs, err := m.FindSubjectsGrantedConsentRequests(context.Background(), "subject-1", 100, 0) + require.NoError(t, err) + require.EqualValues(t, 2, len(crs)) + crSession := crs[1] + require.EqualValues(t, makeID("challenge", network, "c1"), crSession.ID) + expectedExtendedRememberFor1 := int(rememberForSession1 + extendRememberFor - remainingValidTimeSession1) + require.InDelta(t, expectedExtendedRememberFor1, crSession.RememberFor, 1) + crLatest := crs[0] + require.EqualValues(t, makeID("challenge", network, "c2"), crLatest.ID) + expectedExtendedRememberFor2 := int(rememberForSession2 + extendRememberFor - remainingValidTimeSession2) + require.InDelta(t, expectedExtendedRememberFor2, crLatest.RememberFor, 1) + }) + + t.Run("case=no previous consent found", func(t *testing.T) { + require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + ID: makeID("fk-login-session", network, "ec3"), + Subject: "subject-1", + })) + cr := consentFlow("subject-1", "ec3", "c4", 300, time.Now().UTC(), "scope-a", true) + + require.ErrorIs(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, 1000), ErrNoPreviousConsentFound) + }) + + t.Run("case=invalid requested scope", func(t *testing.T) { + var rememberForSession1 time.Duration = 300 + var remainingValidTimeSession1 time.Duration = 100 + requestedAt1 := time.Now().UTC().Round(time.Second).Add(-(rememberForSession1 - remainingValidTimeSession1) * time.Second) + requestedAt2 := time.Now().UTC() + require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + ID: makeID("fk-login-session", network, "ec4"), + Subject: "subject-2", + })) + consentFlow("subject-2", "ec4", "c5", 300, requestedAt1, "scope-a", false) + cr := consentFlow("subject-2", "ec4", "c6", 300, requestedAt2, "scope-b", true) + + require.NoError(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, 1000)) + + crs, err := m.FindSubjectsGrantedConsentRequests(context.Background(), "subject-2", 10, 0) + require.NoError(t, err) + require.EqualValues(t, 1, len(crs)) + cr1 := crs[0] + require.EqualValues(t, makeID("challenge", network, "c5"), cr1.ID) + require.EqualValues(t, 300, cr1.RememberFor) + }) + + t.Run("case=initial consent request expired", func(t *testing.T) { + var rememberForSession1 time.Duration = 300 + var remainingValidTimeSession1 time.Duration = 0 + requestedAtExpired := time.Now().UTC().Round(time.Second).Add(-(rememberForSession1 - remainingValidTimeSession1) * time.Second) + require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + ID: makeID("fk-login-session", network, "ec5"), + Subject: "subject-3", + })) + consentFlow("subject-3", "ec5", "c7", 300, requestedAtExpired, "scope-a", false) + time.Sleep(time.Second) + cr := consentFlow("subject-3", "ec5", "c8", 300, time.Now().UTC(), "scope-a", true) + + require.NoError(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, 1000)) + + _, err := m.FindSubjectsGrantedConsentRequests(context.Background(), "subject-3", 100, 0) + require.Error(t, err, ErrNoPreviousConsentFound) + }) + }) + t.Run("case=revoke-auth-request", func(t *testing.T) { require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ ID: makeID("rev-session", network, "-1"), diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index 8f1fca3d490..b696fa712a4 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -283,6 +283,88 @@ func (p *Persister) HandleConsentRequest(ctx context.Context, r *consent.AcceptO return p.GetConsentRequest(ctx, r.ID) } +func (p *Persister) ExtendConsentRequest(ctx context.Context, scopeStrategy fosite.ScopeStrategy, cr *consent.OAuth2ConsentRequest, extendBy int) error { + return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ExtendConsentRequest") + defer span.End() + + sessionFlow := &flow.Flow{} + if err := c. + Where( + strings.TrimSpace(fmt.Sprintf(` +(state = %d OR state = %d) AND +subject = ? AND +client_id = ? AND +login_session_id = ? AND +consent_skip=FALSE AND +consent_error='{}' AND +consent_remember=TRUE AND +nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused, + )), + cr.Subject, cr.ClientID, cr.LoginSessionID.String(), p.NetworkID(ctx)). + Order("requested_at DESC"). + Limit(1). + First(sessionFlow); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(consent.ErrNoPreviousConsentFound) + } + return sqlcon.HandleError(err) + } + + latestFlow := &flow.Flow{} + if err := c. + Where( + strings.TrimSpace(fmt.Sprintf(` +(state = %d OR state = %d) AND +subject = ? AND +client_id = ? AND +consent_skip=FALSE AND +consent_error='{}' AND +consent_remember=TRUE AND +nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused, + )), + cr.Subject, cr.ClientID, p.NetworkID(ctx)). + Order("requested_at DESC"). + Limit(1). + First(latestFlow); err != nil { + return sqlcon.HandleError(err) + } + + if err := p.extendHandledConsentRequest(ctx, cr, scopeStrategy, sessionFlow, extendBy); err != nil { + return err + } + + if latestFlow.ID != sessionFlow.ID { + if err := p.extendHandledConsentRequest(ctx, cr, scopeStrategy, latestFlow, extendBy); err != nil { + return err + } + } + return nil + }) +} + +func (p *Persister) extendHandledConsentRequest(ctx context.Context, cr *consent.OAuth2ConsentRequest, scopeStrategy fosite.ScopeStrategy, f *flow.Flow, extendBy int) error { + for _, scope := range cr.RequestedScope { + if !scopeStrategy(f.GrantedScope, scope) { + return nil + } + } + hcr := f.GetHandledConsentRequest() + if isConsentRequestExpired := hcr.RememberFor > 0 && hcr.RequestedAt.Add(time.Duration(hcr.RememberFor)*time.Second).Before(time.Now().UTC()); isConsentRequestExpired { + return nil + } + remainingTime := hcr.RequestedAt.Unix() + int64(hcr.RememberFor) - time.Now().Unix() + extendedRememberFor := hcr.RememberFor + extendBy - int(remainingTime) + f.ConsentRememberFor = &extendedRememberFor + + _, err := p.UpdateWithNetwork(ctx, f) + if err != nil { + return sqlcon.HandleError(err) + } else { + return nil + } +} + func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*consent.AcceptOAuth2ConsentRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateConsentRequest") defer span.End() diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 475d32b88a8..fb3b28e5302 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -51,9 +51,9 @@ func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registr if k == "memory" || k == "mysql" || k == "cockroach" { // TODO enable parallel tests for cockroach once we configure the cockroach integration test server to support retry parallel = false } - - t.Run("package=consent/manager="+k, consent.ManagerTests(t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel)) - t.Run("package=consent/manager="+k, consent.ManagerTests(t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel)) + scopeStrategy := t1.Config().GetScopeStrategy(ctx) + t.Run("package=consent/manager="+k, consent.ManagerTests(t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), scopeStrategy, "t1", parallel)) + t.Run("package=consent/manager="+k, consent.ManagerTests(t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), scopeStrategy, "t2", parallel)) t.Run("parallel-boundary", func(t *testing.T) { t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t1.Config(), t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel))