Skip to content

Commit

Permalink
feat: extend latest and session related consents
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmam committed Nov 1, 2022
1 parent d37b323 commit 47cbb3c
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 6 deletions.
8 changes: 8 additions & 0 deletions consent/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,14 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
h.r.Writer().WriteError(w, r, errorsx.WithStack(err))
return
} else if hr.Skip {
if p.Remember && p.RememberFor > 0 { // TODO: Consider removing 'p.RememberFor > 0' to update consent validity in both ways (limited (RememberFor > 0) -> indefinitely (RememberFor = 0) and vice versa)
var ctx = r.Context()
err = h.r.ConsentManager().ExtendConsentRequest(r.Context(), h.r.Config().GetScopeStrategy(ctx), hr, p.RememberFor)
if err != nil {
h.r.Writer().WriteError(w, r, errorsx.WithStack(err))
return
}
}
p.Remember = false
}

Expand Down
95 changes: 93 additions & 2 deletions consent/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ import (
"testing"
"time"

"github.com/ory/x/pointerx"

"github.com/ory/hydra/consent"
"github.com/ory/hydra/x"
"github.com/ory/x/contextx"
"github.com/ory/x/pointerx"
"github.com/ory/x/sqlxx"

"github.com/ory/hydra/internal"
Expand Down Expand Up @@ -231,6 +230,98 @@ func TestGetConsentRequest(t *testing.T) {
}
}

func TestExtendConsentRequest(t *testing.T) {
t.Run("case=extend consent expiry time", func(t *testing.T) {
conf := internal.NewConfigurationWithDefaults()
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
h := NewHandler(reg, conf)
r := x.NewRouterAdmin(conf.AdminURL)
h.SetRoutes(r)
ts := httptest.NewServer(r)
defer ts.Close()

c := &http.Client{}
cl := &client.Client{LegacyClientID: "client-1"}
require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl))

var initialRememberFor time.Duration = 300
var remainingValidTime time.Duration = 100

require.NoError(t, reg.ConsentManager().CreateLoginSession(context.Background(), &LoginSession{
ID: makeID("fk-login-session", "1", "1"),
Subject: "subject-1",
}))
requestedTimeInPast := time.Now().UTC().Add(-(initialRememberFor - remainingValidTime) * time.Second)
require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{
ID: makeID("challenge", "1", "1"),
SessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")),
Client: cl,
Subject: "subject-1",
RequestedAt: requestedTimeInPast,
}))
require.NoError(t, reg.ConsentManager().CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{
ID: makeID("challenge", "1", "1"),
Subject: "subject-1",
Client: cl,
LoginSessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")),
LoginChallenge: sqlxx.NullString(makeID("challenge", "1", "1")),
Verifier: makeID("verifier", "1", "1"),
CSRF: "csrf1",
Skip: false,
ACR: "1",
}))
_, err := reg.ConsentManager().HandleConsentRequest(context.Background(), &AcceptOAuth2ConsentRequest{
ID: makeID("challenge", "1", "1"),
Remember: true,
RememberFor: int(initialRememberFor),
WasHandled: true,
HandledAt: sqlxx.NullTime(time.Now().UTC()),
})
require.NoError(t, err)

require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{
ID: makeID("challenge", "1", "2"),
SessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")),
Verifier: makeID("verifier", "1", "1"),
Client: cl,
RequestedAt: time.Now().UTC(),
Subject: "subject-1",
}))
require.NoError(t, reg.ConsentManager().CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{
ID: makeID("challenge", "1", "2"),
Subject: "subject-1",
Client: cl,
LoginSessionID: sqlxx.NullString(makeID("fk-login-session", "1", "1")),
LoginChallenge: sqlxx.NullString(makeID("challenge", "1", "2")),
Verifier: makeID("verifier", "1", "2"),
CSRF: "csrf2",
Skip: true,
}))

var b bytes.Buffer
var extendRememberFor time.Duration = 300
require.NoError(t, json.NewEncoder(&b).Encode(&AcceptOAuth2ConsentRequest{
Remember: true,
RememberFor: int(extendRememberFor),
}))

req, err := http.NewRequest(http.MethodPut, ts.URL+"/admin"+ConsentPath+"/accept?challenge=challenge-1-2", &b)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
require.EqualValues(t, 200, resp.StatusCode)

crs, err := reg.ConsentManager().FindSubjectsGrantedConsentRequests(context.Background(), "subject-1", 100, 0)
require.NoError(t, err)
require.NotNil(t, crs)
require.EqualValues(t, 1, len(crs))
expectedRememberFor := int(initialRememberFor + extendRememberFor - remainingValidTime)
cr := crs[0]
require.EqualValues(t, "challenge-1-1", cr.ID)
require.InDelta(t, expectedRememberFor, cr.RememberFor, 1)
})
}

func TestGetLoginRequestWithDuplicateAccept(t *testing.T) {
t.Run("Test get login request with duplicate accept", func(t *testing.T) {
challenge := "challenge"
Expand Down
3 changes: 3 additions & 0 deletions consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"context"
"time"

"github.com/ory/fosite"

"github.com/gofrs/uuid"

"github.com/ory/hydra/client"
Expand All @@ -44,6 +46,7 @@ type Manager interface {
CreateConsentRequest(ctx context.Context, req *OAuth2ConsentRequest) error
GetConsentRequest(ctx context.Context, challenge string) (*OAuth2ConsentRequest, error)
HandleConsentRequest(ctx context.Context, r *AcceptOAuth2ConsentRequest) (*OAuth2ConsentRequest, error)
ExtendConsentRequest(ctx context.Context, scopeStrategy fosite.ScopeStrategy, req *OAuth2ConsentRequest, extendBy int) error
RevokeSubjectConsentSession(ctx context.Context, user string) error
RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error

Expand Down
125 changes: 124 additions & 1 deletion consent/manager_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ func TestHelperNID(t1ClientManager client.Manager, t1ValidNID Manager, t2Invalid
}
}

func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) {
func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.FositeStorer, scopeStrategy fosite.ScopeStrategy, network string, parallel bool) func(t *testing.T) {
lr := make(map[string]*LoginRequest)

return func(t *testing.T) {
Expand Down Expand Up @@ -548,6 +548,129 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}
})

t.Run("case=extend consent request", func(t *testing.T) {
cl := &client.Client{LegacyClientID: "client-1"}
_ = clientManager.CreateClient(context.Background(), cl)
consentFlow := func(subject, sessionId, challenge string, rememberFor time.Duration, requestedAt time.Time, requestedScope string, skip bool) *OAuth2ConsentRequest {
require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{
ID: makeID("challenge", network, challenge),
SessionID: sqlxx.NullString(makeID("fk-login-session", network, sessionId)),
Client: cl,
Subject: subject,
Verifier: uuid.New().String(),
RequestedAt: requestedAt,
RequestedScope: []string{requestedScope},
}))

require.NoError(t, m.CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{
ID: makeID("challenge", network, challenge),
Client: cl,
Subject: subject,
LoginSessionID: sqlxx.NullString(makeID("fk-login-session", network, sessionId)),
LoginChallenge: sqlxx.NullString(makeID("challenge", network, challenge)),
Skip: skip,
Verifier: uuid.New().String(),
CSRF: "csrf1",
}))
cr, err := m.HandleConsentRequest(context.Background(), &AcceptOAuth2ConsentRequest{
ID: makeID("challenge", network, challenge),
Remember: true,
RememberFor: int(rememberFor),
WasHandled: true,
HandledAt: sqlxx.NullTime(time.Now().UTC()),
GrantedScope: []string{"scope-a"},
})
require.NoError(t, err)
return cr
}

t.Run("case=extend session related and latest consent expiry times", func(t *testing.T) {
var rememberForSession1 time.Duration = 300
var remainingValidTimeSession1 time.Duration = 100
var rememberForSession2 time.Duration = 300
var remainingValidTimeSession2 time.Duration = 150
var extendRememberFor time.Duration = 1000
requestedAt1 := time.Now().UTC().Round(time.Second).Add(-(rememberForSession1 - remainingValidTimeSession1) * time.Second)
requestedAt2 := time.Now().UTC().Round(time.Second).Add(-(rememberForSession2 - remainingValidTimeSession2) * time.Second)
requestedAt3 := time.Now().UTC()
require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
ID: makeID("fk-login-session", network, "ec1"),
Subject: "subject-1",
}))
require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
ID: makeID("fk-login-session", network, "ec2"),
Subject: "subject-1",
}))
consentFlow("subject-1", "ec1", "c1", rememberForSession1, requestedAt1, "scope-a", false)
consentFlow("subject-1", "ec2", "c2", rememberForSession2, requestedAt2, "scope-a", false)
cr := consentFlow("subject-1", "ec1", "c3", extendRememberFor, requestedAt3, "scope-a", true)

require.NoError(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, int(extendRememberFor)))

crs, err := m.FindSubjectsGrantedConsentRequests(context.Background(), "subject-1", 100, 0)
require.NoError(t, err)
require.EqualValues(t, 2, len(crs))
crSession := crs[1]
require.EqualValues(t, makeID("challenge", network, "c1"), crSession.ID)
expectedExtendedRememberFor1 := int(rememberForSession1 + extendRememberFor - remainingValidTimeSession1)
require.InDelta(t, expectedExtendedRememberFor1, crSession.RememberFor, 1)
crLatest := crs[0]
require.EqualValues(t, makeID("challenge", network, "c2"), crLatest.ID)
expectedExtendedRememberFor2 := int(rememberForSession2 + extendRememberFor - remainingValidTimeSession2)
require.InDelta(t, expectedExtendedRememberFor2, crLatest.RememberFor, 1)
})

t.Run("case=no previous consent found", func(t *testing.T) {
require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
ID: makeID("fk-login-session", network, "ec3"),
Subject: "subject-1",
}))
cr := consentFlow("subject-1", "ec3", "c4", 300, time.Now().UTC(), "scope-a", true)

require.ErrorIs(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, 1000), ErrNoPreviousConsentFound)
})

t.Run("case=invalid requested scope", func(t *testing.T) {
var rememberForSession1 time.Duration = 300
var remainingValidTimeSession1 time.Duration = 100
requestedAt1 := time.Now().UTC().Round(time.Second).Add(-(rememberForSession1 - remainingValidTimeSession1) * time.Second)
requestedAt2 := time.Now().UTC()
require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
ID: makeID("fk-login-session", network, "ec4"),
Subject: "subject-2",
}))
consentFlow("subject-2", "ec4", "c5", 300, requestedAt1, "scope-a", false)
cr := consentFlow("subject-2", "ec4", "c6", 300, requestedAt2, "scope-b", true)

require.NoError(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, 1000))

crs, err := m.FindSubjectsGrantedConsentRequests(context.Background(), "subject-2", 10, 0)
require.NoError(t, err)
require.EqualValues(t, 1, len(crs))
cr1 := crs[0]
require.EqualValues(t, makeID("challenge", network, "c5"), cr1.ID)
require.EqualValues(t, 300, cr1.RememberFor)
})

t.Run("case=initial consent request expired", func(t *testing.T) {
var rememberForSession1 time.Duration = 300
var remainingValidTimeSession1 time.Duration = 0
requestedAtExpired := time.Now().UTC().Round(time.Second).Add(-(rememberForSession1 - remainingValidTimeSession1) * time.Second)
require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
ID: makeID("fk-login-session", network, "ec5"),
Subject: "subject-3",
}))
consentFlow("subject-3", "ec5", "c7", 300, requestedAtExpired, "scope-a", false)
time.Sleep(time.Second)
cr := consentFlow("subject-3", "ec5", "c8", 300, time.Now().UTC(), "scope-a", true)

require.NoError(t, m.ExtendConsentRequest(context.Background(), scopeStrategy, cr, 1000))

_, err := m.FindSubjectsGrantedConsentRequests(context.Background(), "subject-3", 100, 0)
require.Error(t, err, ErrNoPreviousConsentFound)
})
})

t.Run("case=revoke-auth-request", func(t *testing.T) {
require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
ID: makeID("rev-session", network, "-1"),
Expand Down
82 changes: 82 additions & 0 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,88 @@ func (p *Persister) HandleConsentRequest(ctx context.Context, r *consent.AcceptO
return p.GetConsentRequest(ctx, r.ID)
}

func (p *Persister) ExtendConsentRequest(ctx context.Context, scopeStrategy fosite.ScopeStrategy, cr *consent.OAuth2ConsentRequest, extendBy int) error {
return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ExtendConsentRequest")
defer span.End()

sessionFlow := &flow.Flow{}
if err := c.
Where(
strings.TrimSpace(fmt.Sprintf(`
(state = %d OR state = %d) AND
subject = ? AND
client_id = ? AND
login_session_id = ? AND
consent_skip=FALSE AND
consent_error='{}' AND
consent_remember=TRUE AND
nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
)),
cr.Subject, cr.ClientID, cr.LoginSessionID.String(), p.NetworkID(ctx)).
Order("requested_at DESC").
Limit(1).
First(sessionFlow); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return errorsx.WithStack(consent.ErrNoPreviousConsentFound)
}
return sqlcon.HandleError(err)
}

latestFlow := &flow.Flow{}
if err := c.
Where(
strings.TrimSpace(fmt.Sprintf(`
(state = %d OR state = %d) AND
subject = ? AND
client_id = ? AND
consent_skip=FALSE AND
consent_error='{}' AND
consent_remember=TRUE AND
nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
)),
cr.Subject, cr.ClientID, p.NetworkID(ctx)).
Order("requested_at DESC").
Limit(1).
First(latestFlow); err != nil {
return sqlcon.HandleError(err)
}

if err := p.extendHandledConsentRequest(ctx, cr, scopeStrategy, sessionFlow, extendBy); err != nil {
return err
}

if latestFlow.ID != sessionFlow.ID {
if err := p.extendHandledConsentRequest(ctx, cr, scopeStrategy, latestFlow, extendBy); err != nil {
return err
}
}
return nil
})
}

func (p *Persister) extendHandledConsentRequest(ctx context.Context, cr *consent.OAuth2ConsentRequest, scopeStrategy fosite.ScopeStrategy, f *flow.Flow, extendBy int) error {
for _, scope := range cr.RequestedScope {
if !scopeStrategy(f.GrantedScope, scope) {
return nil
}
}
hcr := f.GetHandledConsentRequest()
if isConsentRequestExpired := hcr.RememberFor > 0 && hcr.RequestedAt.Add(time.Duration(hcr.RememberFor)*time.Second).Before(time.Now().UTC()); isConsentRequestExpired {
return nil
}
remainingTime := hcr.RequestedAt.Unix() + int64(hcr.RememberFor) - time.Now().Unix()
extendedRememberFor := hcr.RememberFor + extendBy - int(remainingTime)
f.ConsentRememberFor = &extendedRememberFor

_, err := p.UpdateWithNetwork(ctx, f)
if err != nil {
return sqlcon.HandleError(err)
} else {
return nil
}
}

func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*consent.AcceptOAuth2ConsentRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateConsentRequest")
defer span.End()
Expand Down
6 changes: 3 additions & 3 deletions persistence/sql/persister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registr
if k == "memory" || k == "mysql" || k == "cockroach" { // TODO enable parallel tests for cockroach once we configure the cockroach integration test server to support retry
parallel = false
}

t.Run("package=consent/manager="+k, consent.ManagerTests(t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel))
t.Run("package=consent/manager="+k, consent.ManagerTests(t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel))
scopeStrategy := t1.Config().GetScopeStrategy(ctx)
t.Run("package=consent/manager="+k, consent.ManagerTests(t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), scopeStrategy, "t1", parallel))
t.Run("package=consent/manager="+k, consent.ManagerTests(t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), scopeStrategy, "t2", parallel))

t.Run("parallel-boundary", func(t *testing.T) {
t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t1.Config(), t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel))
Expand Down

0 comments on commit 47cbb3c

Please sign in to comment.