Skip to content

Commit

Permalink
feat: propagate logout to identity provider (#3596)
Browse files Browse the repository at this point in the history
* feat: propagate logout to identity provider

This commit improves the integration between Hydra and Kratos when logging
out the user.

This adds a new configuration key for configuring a Kratos admin URL.
Additionally, Kratos can send a session ID when accepting a login request.
If a session ID was specified and a Kratos admin URL was configured,
Hydra will disable the corresponding Kratos session through the admin API
if a frontchannel or backchannel logout was triggered.

* fix: add special case for MySQL

* chore: update sdk

* chore: consistent naming

* fix: cleanup persister
  • Loading branch information
hperl authored Aug 14, 2023
1 parent dc878b8 commit c004fee
Show file tree
Hide file tree
Showing 63 changed files with 645 additions and 61 deletions.
2 changes: 1 addition & 1 deletion consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type (
// Cookie management
GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (*flow.LoginSession, error)
CreateLoginSession(ctx context.Context, session *flow.LoginSession) error
DeleteLoginSession(ctx context.Context, id string) error
DeleteLoginSession(ctx context.Context, id string) (deletedSession *flow.LoginSession, err error)
RevokeSubjectLoginSession(ctx context.Context, user string) error
ConfirmLoginSession(ctx context.Context, session *flow.LoginSession, id string, authTime time.Time, subject string, remember bool) error

Expand Down
14 changes: 10 additions & 4 deletions consent/manager_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,12 @@ func TestHelperNID(r interface {
require.NoError(t, err)
require.Error(t, t2InvalidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
require.NoError(t, t1ValidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
require.Error(t, t2InvalidNID.DeleteLoginSession(ctx, testLS.ID))
require.NoError(t, t1ValidNID.DeleteLoginSession(ctx, testLS.ID))
ls, err := t2InvalidNID.DeleteLoginSession(ctx, testLS.ID)
require.Error(t, err)
assert.Nil(t, ls)
ls, err = t1ValidNID.DeleteLoginSession(ctx, testLS.ID)
require.NoError(t, err)
assert.Equal(t, testLS.ID, ls.ID)
}
}

Expand Down Expand Up @@ -429,8 +433,9 @@ func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeMana
},
} {
t.Run("case=delete-get-"+tc.id, func(t *testing.T) {
err := m.DeleteLoginSession(ctx, tc.id)
ls, err := m.DeleteLoginSession(ctx, tc.id)
require.NoError(t, err)
assert.EqualValues(t, tc.id, ls.ID)

_, err = m.GetRememberedLoginSession(ctx, nil, tc.id)
require.Error(t, err)
Expand Down Expand Up @@ -1083,7 +1088,8 @@ func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeMana
require.NoError(t, err)
assert.EqualValues(t, expected.ID, result.ID)

require.NoError(t, m.DeleteLoginSession(ctx, s.ID))
_, err = m.DeleteLoginSession(ctx, s.ID)
require.NoError(t, err)

result, err = m.GetConsentRequest(ctx, expected.ID)
require.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions consent/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/ory/fosite/handler/openid"
"github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/x"
)

Expand All @@ -17,6 +18,7 @@ type InternalRegistry interface {
x.RegistryCookieStore
x.RegistryLogger
x.HTTPClientProvider
kratos.Provider
Registry
client.Registry

Expand Down
25 changes: 18 additions & 7 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ func (s *DefaultStrategy) revokeAuthenticationSession(ctx context.Context, w htt
return nil
}

return s.r.ConsentManager().DeleteLoginSession(r.Context(), sid)
_, err = s.r.ConsentManager().DeleteLoginSession(r.Context(), sid)

return err
}

func (s *DefaultStrategy) revokeAuthenticationCookie(w http.ResponseWriter, r *http.Request, ss sessions.Store) (string, error) {
Expand Down Expand Up @@ -458,6 +460,7 @@ func (s *DefaultStrategy) verifyAuthentication(
return nil, fosite.ErrAccessDenied.WithHint("The login session cookie was not found or malformed.")
}

loginSession.IdentityProviderSessionID = f.IdentityProviderSessionID
if err := s.r.ConsentManager().ConfirmLoginSession(ctx, loginSession, sessionID, time.Time(session.AuthenticatedAt), session.Subject, session.Remember); err != nil {
return nil, err
}
Expand Down Expand Up @@ -731,7 +734,8 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
return urls, nil
}

func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, subject, sid string) error {
func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error {
ctx := r.Context()
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
if err != nil {
return err
Expand Down Expand Up @@ -1000,8 +1004,9 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
}

func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Context, r *http.Request, subject string, sid string) error {
if err := s.executeBackChannelLogout(r.Context(), r, subject, sid); err != nil {
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error {
ctx := r.Context()
if err := s.executeBackChannelLogout(r, subject, sid); err != nil {
return err
}

Expand All @@ -1010,10 +1015,16 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Con
//
// executeBackChannelLogout only fails on system errors so not on URL errors, so this should be fine
// even if an upstream URL fails!
if err := s.r.ConsentManager().DeleteLoginSession(r.Context(), sid); errors.Is(err, sqlcon.ErrNoRows) {
if session, err := s.r.ConsentManager().DeleteLoginSession(ctx, sid); errors.Is(err, sqlcon.ErrNoRows) {
// This is ok (session probably already revoked), do nothing!
} else if err != nil {
return err
} else {
innerErr := s.r.Kratos().DisableSession(ctx, session.IdentityProviderSessionID.String())
if innerErr != nil {
s.r.Logger().WithError(innerErr).WithField("sid", sid).Error("Unable to revoke session in ORY Kratos.")
}
// We don't return the error here because we don't want to break the logout flow if Kratos is down.
}

return nil
Expand Down Expand Up @@ -1068,7 +1079,7 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
return nil, err
}

if err := s.performBackChannelLogoutAndDeleteSession(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1105,7 +1116,7 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo
return lsErr
}

if err := s.performBackChannelLogoutAndDeleteSession(r.Context(), r, loginSession.Subject, sid); err != nil {
if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil {
return err
}

Expand Down
14 changes: 5 additions & 9 deletions consent/strategy_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,21 @@ import (
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"testing"

hydra "github.com/ory/hydra-client-go/v2"

"github.com/stretchr/testify/require"

"github.com/ory/fosite/token/jwt"
"github.com/ory/x/urlx"

"net/url"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/fosite/token/jwt"
hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/client"
. "github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/driver"
"github.com/ory/hydra/v2/internal/testhelpers"
"github.com/ory/x/ioutilx"
"github.com/ory/x/urlx"
)

func checkAndAcceptLoginHandler(t *testing.T, apiClient *hydra.APIClient, subject string, cb func(*testing.T, *hydra.OAuth2LoginRequest, error) hydra.AcceptOAuth2LoginRequest) http.HandlerFunc {
Expand Down
22 changes: 21 additions & 1 deletion consent/strategy_logout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"testing"
"time"

"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/x/pointerx"

"github.com/stretchr/testify/assert"
Expand All @@ -35,9 +36,11 @@ import (

func TestLogoutFlows(t *testing.T) {
ctx := context.Background()
fakeKratos := kratos.NewFake()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour)
reg.WithKratos(fakeKratos)

defaultRedirectedMessage := "redirected to default server"
postLogoutCallback := func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -181,7 +184,10 @@ func TestLogoutFlows(t *testing.T) {
checkAndAcceptLoginHandler(t, adminApi, subject, func(t *testing.T, res *hydra.OAuth2LoginRequest, err error) hydra.AcceptOAuth2LoginRequest {
require.NoError(t, err)
//res.Payload.SessionID
return hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Bool(true)}
return hydra.AcceptOAuth2LoginRequest{
Remember: pointerx.Ptr(true),
IdentityProviderSessionId: pointerx.Ptr(kratos.FakeSessionID),
}
}),
checkAndAcceptConsentHandler(t, adminApi, func(t *testing.T, res *hydra.OAuth2ConsentRequest, err error) hydra.AcceptOAuth2ConsentRequest {
require.NoError(t, err)
Expand Down Expand Up @@ -476,6 +482,7 @@ func TestLogoutFlows(t *testing.T) {
})

t.Run("case=should return to default post logout because session was revoked in browser context", func(t *testing.T) {
fakeKratos.Reset()
c := createSampleClient(t)
sid := make(chan string)
acceptLoginAsAndWatchSid(t, subject, sid)
Expand Down Expand Up @@ -518,9 +525,13 @@ func TestLogoutFlows(t *testing.T) {
assert.NotEmpty(t, res.Request.URL.Query().Get("code"))

wg.Wait()

assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should execute backchannel logout in headless flow with sid", func(t *testing.T) {
fakeKratos.Reset()
numSidConsumers := 2
sid := make(chan string, numSidConsumers)
acceptLoginAsAndWatchSidForConsumers(t, subject, sid, true, numSidConsumers)
Expand All @@ -535,22 +546,31 @@ func TestLogoutFlows(t *testing.T) {
logoutViaHeadlessAndExpectNoContent(t, createBrowserWithSession(t, c), url.Values{"sid": {<-sid}})

backChannelWG.Wait() // we want to ensure that all back channels have been called!
assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should logout in headless flow with non-existing sid", func(t *testing.T) {
fakeKratos.Reset()
logoutViaHeadlessAndExpectNoContent(t, browserWithoutSession, url.Values{"sid": {"non-existing-sid"}})
assert.False(t, fakeKratos.DisableSessionWasCalled)
})

t.Run("case=should logout in headless flow with session that has remember=false", func(t *testing.T) {
fakeKratos.Reset()
sid := make(chan string)
acceptLoginAsAndWatchSidForConsumers(t, subject, sid, false, 1)

c := createSampleClient(t)

logoutViaHeadlessAndExpectNoContent(t, createBrowserWithSession(t, c), url.Values{"sid": {<-sid}})
assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should fail headless logout because neither sid nor subject were provided", func(t *testing.T) {
fakeKratos.Reset()
logoutViaHeadlessAndExpectError(t, browserWithoutSession, url.Values{}, `Either 'subject' or 'sid' query parameters need to be defined.`)
assert.False(t, fakeKratos.DisableSessionWasCalled)
})
}
13 changes: 11 additions & 2 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ const (
KeyPublicURL = "urls.self.public"
KeyAdminURL = "urls.self.admin"
KeyIssuerURL = "urls.self.issuer"
KeyIdentityProviderAdminURL = "urls.identity_provider.admin_base_url"
KeyAccessTokenStrategy = "strategies.access_token"
KeyJWTScopeClaimStrategy = "strategies.jwt.scope_claim"
KeyDBIgnoreUnknownTableColumns = "db.ignore_unknown_table_columns"
Expand All @@ -104,8 +105,10 @@ const (

const DSNMemory = "memory"

var _ hasherx.PBKDF2Configurator = (*DefaultProvider)(nil)
var _ hasherx.BCryptConfigurator = (*DefaultProvider)(nil)
var (
_ hasherx.PBKDF2Configurator = (*DefaultProvider)(nil)
_ hasherx.BCryptConfigurator = (*DefaultProvider)(nil)
)

type DefaultProvider struct {
l *logrusx.Logger
Expand Down Expand Up @@ -393,6 +396,12 @@ func (p *DefaultProvider) IssuerURL(ctx context.Context) *url.URL {
)
}

func (p *DefaultProvider) KratosAdminURL(ctx context.Context) (*url.URL, bool) {
u := p.getProvider(ctx).RequestURIF(KeyIdentityProviderAdminURL, nil)

return u, u != nil
}

func (p *DefaultProvider) OAuth2ClientRegistrationURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyOAuth2ClientRegistrationURL, new(url.URL))
}
Expand Down
4 changes: 4 additions & 0 deletions driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"go.opentelemetry.io/otel/trace"

"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/x/httprouterx"
"github.com/ory/x/popx"

Expand Down Expand Up @@ -54,6 +55,7 @@ type Registry interface {
WithLogger(l *logrusx.Logger) Registry
WithTracer(t trace.Tracer) Registry
WithTracerWrapper(TracerWrapper) Registry
WithKratos(k kratos.Client) Registry
x.HTTPClientProvider
GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy

Expand All @@ -72,6 +74,8 @@ type Registry interface {
x.TracingProvider
FlowCipher() *aead.XChaCha20Poly1305

kratos.Provider

RegisterRoutes(ctx context.Context, admin *httprouterx.RouterAdmin, public *httprouterx.RouterPublic)
ClientHandler() *client.Handler
KeyHandler() *jwk.Handler
Expand Down
14 changes: 14 additions & 0 deletions driver/registry_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/fositex"
"github.com/ory/hydra/v2/hsm"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/oauth2/trust"
Expand Down Expand Up @@ -88,6 +89,7 @@ type RegistryBase struct {
hmacs *foauth2.HMACSHAStrategy
fc *fositex.Config
publicCORS *cors.Cors
kratos kratos.Client
}

func (m *RegistryBase) GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy {
Expand Down Expand Up @@ -201,6 +203,11 @@ func (m *RegistryBase) WithTracerWrapper(wrapper TracerWrapper) Registry {
return m.r
}

func (m *RegistryBase) WithKratos(k kratos.Client) Registry {
m.kratos = k
return m.r
}

func (m *RegistryBase) Logger() *logrusx.Logger {
if m.l == nil {
m.l = logrusx.New("Ory Hydra", m.BuildVersion())
Expand Down Expand Up @@ -552,3 +559,10 @@ func (m *RegistryBase) HSMContext() hsm.Context {
func (m *RegistrySQL) ClientAuthenticator() x.ClientAuthenticator {
return m.OAuth2Provider().(*fosite.Fosite)
}

func (m *RegistryBase) Kratos() kratos.Client {
if m.kratos == nil {
m.kratos = kratos.New(m)
}
return m.kratos
}
17 changes: 12 additions & 5 deletions flow/consent_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ type OAuth2RedirectTo struct {

// swagger:ignore
type LoginSession struct {
ID string `db:"id"`
NID uuid.UUID `db:"nid"`
AuthenticatedAt sqlxx.NullTime `db:"authenticated_at"`
Subject string `db:"subject"`
Remember bool `db:"remember"`
ID string `db:"id"`
NID uuid.UUID `db:"nid"`
AuthenticatedAt sqlxx.NullTime `db:"authenticated_at"`
Subject string `db:"subject"`
IdentityProviderSessionID sqlxx.NullString `db:"identity_provider_session_id"`
Remember bool `db:"remember"`
}

func (LoginSession) TableName() string {
Expand Down Expand Up @@ -292,6 +293,12 @@ type HandledLoginRequest struct {
// required: true
Subject string `json:"subject"`

// IdentityProviderSessionID is the session ID of the end-user that authenticated.
// If specified, we will use this value to propagate the logout.
//
// required: false
IdentityProviderSessionID string `json:"identity_provider_session_id,omitempty"`

// ForceSubjectIdentifier forces the "pairwise" user ID of the end-user that authenticated. The "pairwise" user ID refers to the
// (Pairwise Identifier Algorithm)[http://openid.net/specs/openid-connect-core-1_0.html#PairwiseAlg] of the OpenID
// Connect specification. It allows you to set an obfuscated subject ("user") identifier that is unique to the client.
Expand Down
Loading

0 comments on commit c004fee

Please sign in to comment.