Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve session extend performance #3948

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ const (
ViperKeySessionTokenizerTemplates = "session.whoami.tokenizer.templates"
ViperKeySessionWhoAmIAAL = "session.whoami.required_aal"
ViperKeySessionWhoAmICaching = "feature_flags.cacheable_sessions"
ViperKeyFeatureFlagFasterSessionExtend = "feature_flags.faster_session_extend"
ViperKeySessionWhoAmICachingMaxAge = "feature_flags.cacheable_sessions_max_age"
ViperKeyUseContinueWithTransitions = "feature_flags.use_continue_with_transitions"
ViperKeySessionRefreshMinTimeLeft = "session.earliest_possible_extend"
Expand Down Expand Up @@ -1369,6 +1370,10 @@ func (p *Config) SessionWhoAmICaching(ctx context.Context) bool {
return p.GetProvider(ctx).Bool(ViperKeySessionWhoAmICaching)
}

func (p *Config) FeatureFlagFasterSessionExtend(ctx context.Context) bool {
return p.GetProvider(ctx).Bool(ViperKeyFeatureFlagFasterSessionExtend)
}

func (p *Config) SessionWhoAmICachingMaxAge(ctx context.Context) time.Duration {
return p.GetProvider(ctx).DurationF(ViperKeySessionWhoAmICachingMaxAge, 0)
}
Expand Down
6 changes: 6 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2852,6 +2852,12 @@
"title": "Enable new flow transitions using `continue_with` items",
"description": "If enabled allows new flow transitions using `continue_with` items.",
"default": false
},
"faster_session_extend": {
"type": "boolean",
"title": "Enable faster session extension",
"description": "If enabled allows faster session extension by skipping the session lookup. Disabling this feature will be deprecated in the future.",
"default": false
}
},
"additionalProperties": false
Expand Down
59 changes: 59 additions & 0 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
"fmt"
"time"

"github.com/ory/herodot"
"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
Expand Down Expand Up @@ -176,6 +179,61 @@
return s, t, nil
}

// ExtendSession updates the expiry of a session.
func (p *Persister) ExtendSession(ctx context.Context, sessionID uuid.UUID) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ExtendSession")
defer otelx.End(span, &err)

nid := p.NetworkID(ctx)
s := new(session.Session)
var didRefresh bool
if err := errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
lockBehavior := ""
if tx.Dialect.Name() == dbal.DriverCockroachDB {
// SKIP LOCKED returns no rows if the row is locked by another transaction.
lockBehavior = "FOR UPDATE SKIP LOCKED"
}

if err := tx.
Where(
// We make use of the fact that CRDB supports FOR UPDATE as part of the WHERE clause.
fmt.Sprintf("id = ? AND nid = ? %s", lockBehavior),
sessionID, nid,
).First(s); err != nil {

// This is a special case for CockroachDB. If the row is locked, we do not see the session. Therefor we return
// a 404 not found error indicating to the user that the session might already be updated by someone else.
if errors.Is(err, sqlcon.ErrNoRows) && tx.Dialect.Name() == dbal.DriverCockroachDB {
return errors.WithStack(herodot.ErrNotFound.WithReason("The session you are trying to extend is already being extended by another request or does not exist."))

Check warning on line 207 in persistence/sql/persister_session.go

View check run for this annotation

Codecov / codecov/patch

persistence/sql/persister_session.go#L207

Added line #L207 was not covered by tests
}

return sqlcon.HandleError(err)
}

if !s.CanBeRefreshed(ctx, p.r.Config()) {
// This prevents excessive writes to the database.
return nil
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
}

didRefresh = true
s = s.Refresh(ctx, p.r.Config())

if _, err := tx.Where("id = ? AND nid = ?", sessionID, nid).UpdateQuery(s, "expires_at"); err != nil {
return sqlcon.HandleError(err)

Check warning on line 222 in persistence/sql/persister_session.go

View check run for this annotation

Codecov / codecov/patch

persistence/sql/persister_session.go#L222

Added line #L222 was not covered by tests
}

return nil
})); err != nil {
return err
}

if didRefresh {
trace.SpanFromContext(ctx).AddEvent(events.NewSessionLifespanExtended(ctx, s.ID, s.IdentityID, s.ExpiresAt))
}

return nil
}

// UpsertSession creates a session if not found else updates.
// This operation also inserts Session device records when a session is being created.
// The update operation skips updating Session device records since only one record would need to be updated in this case.
Expand All @@ -196,6 +254,7 @@
trace.SpanFromContext(ctx).AddEvent(events.NewSessionIssued(ctx, string(s.AuthenticatorAssuranceLevel), s.ID, s.IdentityID))
}
}()

return errors.WithStack(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
updated = false
exists := false
Expand Down
27 changes: 18 additions & 9 deletions session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,10 @@
// Calling this endpoint extends the given session ID. If `session.earliest_possible_extend` is set it
// will only extend the session after the specified time has passed.
//
// This endpoint returns per default a 204 No Content response on success. Older Ory Network projects may
// return a 200 OK response with the session in the body. Returning the session as part of the response
// will be deprecated in the future and should not be relied upon.
//
// Retrieve the session ID from the `/sessions/whoami` endpoint / `toSession` SDK method.
//
// Schemes: http, https
Expand All @@ -882,30 +886,35 @@
//
// Responses:
// 200: session
// 204: emptyResponse
// 400: errorGeneric
// 404: errorGeneric
// default: errorGeneric
func (h *Handler) adminSessionExtend(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
iID, err := uuid.FromString(ps.ByName("id"))
id, err := uuid.FromString(ps.ByName("id"))
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error()).WithDebug("could not parse UUID")))
return
}

s, err := h.r.SessionPersister().GetSession(r.Context(), iID, ExpandDefault)
if err != nil {
c := h.r.Config()
if err := h.r.SessionPersister().ExtendSession(r.Context(), id); err != nil {
h.r.Writer().WriteError(w, r, err)
return
}

c := h.r.Config()
if s.CanBeRefreshed(r.Context(), c) {
if err := h.r.SessionPersister().UpsertSession(r.Context(), s.Refresh(r.Context(), c)); err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
// Default behavior going forward.
if c.FeatureFlagFasterSessionExtend(r.Context()) {
w.WriteHeader(http.StatusNoContent)
return

Check warning on line 909 in session/handler.go

View check run for this annotation

Codecov / codecov/patch

session/handler.go#L908-L909

Added lines #L908 - L909 were not covered by tests
}

// WARNING - this will be deprecated at some point!
s, err := h.r.SessionPersister().GetSession(r.Context(), id, ExpandDefault)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return

Check warning on line 916 in session/handler.go

View check run for this annotation

Codecov / codecov/patch

session/handler.go#L915-L916

Added lines #L915 - L916 were not covered by tests
}
h.r.Writer().Write(w, r, s)
}

Expand Down
3 changes: 3 additions & 0 deletions session/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ type Persister interface {
// UpsertSession inserts or updates a session into / in the store.
UpsertSession(ctx context.Context, s *Session) error

// ExtendSession updates the expiry of a session.
ExtendSession(ctx context.Context, sessionID uuid.UUID) error

// DeleteSession removes a session from the store.
DeleteSession(ctx context.Context, id uuid.UUID) error

Expand Down
82 changes: 82 additions & 0 deletions session/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import (
"testing"
"time"

"github.com/pkg/errors"
"golang.org/x/sync/errgroup"

"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"

"github.com/ory/x/pagination/keysetpagination"
Expand Down Expand Up @@ -604,5 +609,82 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
_, err = p.GetSessionByToken(ctx, t2, session.ExpandNothing, identity.ExpandDefault)
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})

t.Run("extend session lifespan but min time is not yet reached", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil)
})

var expected session.Session
require.NoError(t, faker.FakeData(&expected))
expected.ExpiresAt = time.Now().Add(time.Hour * 10).Round(time.Second).UTC()
require.NoError(t, p.CreateIdentity(ctx, expected.Identity))
require.NoError(t, p.UpsertSession(ctx, &expected))

require.NoError(t, p.ExtendSession(ctx, expected.ID))
actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing)
require.NoError(t, err)
assert.Equal(t, expected.ExpiresAt, actual.ExpiresAt)
})

t.Run("extend session lifespan", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil)
})

conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2)
var expected session.Session
require.NoError(t, faker.FakeData(&expected))
expected.ExpiresAt = time.Now().Add(time.Hour).UTC()
require.NoError(t, p.CreateIdentity(ctx, expected.Identity))
require.NoError(t, p.UpsertSession(ctx, &expected))

expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute)
require.NoError(t, p.ExtendSession(ctx, expected.ID))
actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing)
require.NoError(t, err)
assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute))
})

t.Run("extend session lifespan on CockroachDB", func(t *testing.T) {
if p.GetConnection(ctx).Dialect.Name() != dbal.DriverCockroachDB {
t.Skip("Skipping test because driver is not CockroachDB")
}

conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, nil)
})

conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, time.Hour*2)
var expected session.Session
require.NoError(t, faker.FakeData(&expected))
expected.ExpiresAt = time.Now().Add(time.Hour).UTC()
require.NoError(t, p.CreateIdentity(ctx, expected.Identity))
require.NoError(t, p.UpsertSession(ctx, &expected))

expectedExpiry := expected.Refresh(ctx, conf).ExpiresAt.Round(time.Minute)

var foundExpectedCockroachError bool
g := errgroup.Group{}
for i := 0; i < 10; i++ {
g.Go(func() error {
err := p.ExtendSession(ctx, expected.ID)
if errors.Is(err, sqlcon.ErrNoRows) {
foundExpectedCockroachError = true
return nil
}
return err
})
}
require.NoError(t, g.Wait())

actual, err := p.GetSession(ctx, expected.ID, session.ExpandNothing)
require.NoError(t, err)
assert.Equal(t, expectedExpiry, actual.ExpiresAt.Round(time.Minute))
assert.True(t, foundExpectedCockroachError, "We expect to find a not found error caused by ... FOR UPDATE SKIP LOCKED")
})
}
}
58 changes: 38 additions & 20 deletions x/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,33 @@ import (
)

const (
SessionIssued semconv.Event = "SessionIssued"
SessionChanged semconv.Event = "SessionChanged"
SessionRevoked semconv.Event = "SessionRevoked"
SessionChecked semconv.Event = "SessionChecked"
SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT"
RegistrationFailed semconv.Event = "RegistrationFailed"
RegistrationSucceeded semconv.Event = "RegistrationSucceeded"
LoginFailed semconv.Event = "LoginFailed"
LoginSucceeded semconv.Event = "LoginSucceeded"
SettingsFailed semconv.Event = "SettingsFailed"
SettingsSucceeded semconv.Event = "SettingsSucceeded"
RecoveryFailed semconv.Event = "RecoveryFailed"
RecoverySucceeded semconv.Event = "RecoverySucceeded"
VerificationFailed semconv.Event = "VerificationFailed"
VerificationSucceeded semconv.Event = "VerificationSucceeded"
IdentityCreated semconv.Event = "IdentityCreated"
IdentityUpdated semconv.Event = "IdentityUpdated"
WebhookDelivered semconv.Event = "WebhookDelivered"
WebhookSucceeded semconv.Event = "WebhookSucceeded"
WebhookFailed semconv.Event = "WebhookFailed"
SessionIssued semconv.Event = "SessionIssued"
SessionChanged semconv.Event = "SessionChanged"
SessionLifespanExtended semconv.Event = "SessionLifespanExtended"
SessionRevoked semconv.Event = "SessionRevoked"
SessionChecked semconv.Event = "SessionChecked"
SessionTokenizedAsJWT semconv.Event = "SessionTokenizedAsJWT"
RegistrationFailed semconv.Event = "RegistrationFailed"
RegistrationSucceeded semconv.Event = "RegistrationSucceeded"
LoginFailed semconv.Event = "LoginFailed"
LoginSucceeded semconv.Event = "LoginSucceeded"
SettingsFailed semconv.Event = "SettingsFailed"
SettingsSucceeded semconv.Event = "SettingsSucceeded"
RecoveryFailed semconv.Event = "RecoveryFailed"
RecoverySucceeded semconv.Event = "RecoverySucceeded"
VerificationFailed semconv.Event = "VerificationFailed"
VerificationSucceeded semconv.Event = "VerificationSucceeded"
IdentityCreated semconv.Event = "IdentityCreated"
IdentityUpdated semconv.Event = "IdentityUpdated"
WebhookDelivered semconv.Event = "WebhookDelivered"
WebhookSucceeded semconv.Event = "WebhookSucceeded"
WebhookFailed semconv.Event = "WebhookFailed"
)

const (
attributeKeySessionID semconv.AttributeKey = "SessionID"
attributeKeySessionAAL semconv.AttributeKey = "SessionAAL"
attributeKeySessionExpiresAt semconv.AttributeKey = "SessionExpiresAt"
attributeKeySelfServiceFlowType semconv.AttributeKey = "SelfServiceFlowType"
attributeKeySelfServiceMethodUsed semconv.AttributeKey = "SelfServiceMethodUsed"
attributeKeySelfServiceSSOProviderUsed semconv.AttributeKey = "SelfServiceSSOProviderUsed"
Expand Down Expand Up @@ -71,6 +73,10 @@ func attLoginRequestedAAL(val string) otelattr.KeyValue {
return otelattr.String(attributeKeyLoginRequestedAAL.String(), val)
}

func attSessionExpiresAt(expiresAt time.Time) otelattr.KeyValue {
return otelattr.String(attributeKeySessionExpiresAt.String(), expiresAt.String())
}

func attLoginRequestedPrivilegedSession(val bool) otelattr.KeyValue {
return otelattr.Bool(attributeKeyLoginRequestedPrivilegedSession.String(), val)
}
Expand Down Expand Up @@ -135,6 +141,18 @@ func NewSessionChanged(ctx context.Context, aal string, sessionID, identityID uu
)
}

func NewSessionLifespanExtended(ctx context.Context, sessionID, identityID uuid.UUID, newExpiry time.Time) (string, trace.EventOption) {
return SessionLifespanExtended.String(),
trace.WithAttributes(
append(
semconv.AttributesFromContext(ctx),
semconv.AttrIdentityID(identityID),
attrSessionID(sessionID),
attSessionExpiresAt(newExpiry),
)...,
)
}

type LoginSucceededOpts struct {
SessionID, IdentityID uuid.UUID
FlowType, RequestedAAL, Method, SSOProvider string
Expand Down
Loading