From 19ed782d7166cf77079b7a4c5ecef092de48e2bc Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Mon, 8 Feb 2021 17:34:53 +0100 Subject: [PATCH] fix: remove stray non-ctx configs --- cmd/daemon/serve.go | 2 +- courier/courier.go | 7 +- courier/courier_test.go | 2 +- driver/registry.go | 2 +- driver/registry_default.go | 125 +++++++----------- driver/registry_default_login.go | 11 +- driver/registry_default_recovery.go | 16 +-- driver/registry_default_registration.go | 15 ++- driver/registry_default_settings.go | 29 ++-- driver/registry_default_test.go | 20 +-- driver/registry_default_verify.go | 21 +++ internal/testhelpers/selfservice_settings.go | 4 +- persistence/sql/persister_courier.go | 1 + persistence/sql/persister_errorx.go | 3 + persistence/sql/persister_identity.go | 1 + persistence/sql/persister_session.go | 3 + selfservice/flow/login/error.go | 12 +- selfservice/flow/login/error_test.go | 2 +- selfservice/flow/login/handler.go | 2 +- selfservice/flow/login/hook.go | 13 +- selfservice/flow/login/strategy.go | 3 +- selfservice/flow/logout/handler.go | 12 +- selfservice/flow/recovery/error.go | 2 +- selfservice/flow/recovery/error_test.go | 2 +- selfservice/flow/recovery/handler.go | 20 ++- selfservice/flow/recovery/strategy.go | 3 +- selfservice/flow/registration/error_test.go | 2 +- selfservice/flow/registration/handler.go | 2 +- selfservice/flow/registration/hook.go | 21 +-- selfservice/flow/registration/strategy.go | 3 +- selfservice/flow/settings/error_test.go | 2 +- selfservice/flow/settings/handler.go | 2 +- selfservice/flow/settings/hook.go | 15 ++- selfservice/flow/settings/strategy.go | 3 +- selfservice/flow/verification/error.go | 2 +- selfservice/flow/verification/handler.go | 20 ++- selfservice/flow/verification/strategy.go | 3 +- selfservice/strategy/handler.go | 37 ++++-- selfservice/strategy/link/sender.go | 2 +- selfservice/strategy/link/sender_test.go | 4 +- .../strategy/link/strategy_recovery.go | 9 +- .../strategy/link/strategy_verification.go | 17 +-- .../strategy/oidc/strategy_settings_test.go | 2 +- selfservice/strategy/oidc/strategy_test.go | 4 +- .../strategy/password/registration_test.go | 3 +- session/manager_http.go | 6 +- x/provider.go | 2 +- 47 files changed, 273 insertions(+), 221 deletions(-) diff --git a/cmd/daemon/serve.go b/cmd/daemon/serve.go index dfdc48eff12e..bf64bc57848b 100644 --- a/cmd/daemon/serve.go +++ b/cmd/daemon/serve.go @@ -226,7 +226,7 @@ func bgTasks(d driver.Registry, wg *sync.WaitGroup, cmd *cobra.Command, args []s d.Logger().Println("Courier worker started.") if err := graceful.Graceful(func() error { - return d.Courier().Work(ctx) + return d.Courier(ctx).Work(ctx) }, func(_ cx.Context) error { cancel() return nil diff --git a/courier/courier.go b/courier/courier.go index 9c2b1f194c46..3766a01acdb1 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -23,14 +23,14 @@ type ( smtpDependencies interface { PersistenceProvider x.LoggingProvider + config.Provider } Courier struct { Dialer *gomail.Dialer d smtpDependencies - c *config.Config } Provider interface { - Courier() *Courier + Courier(ctx context.Context) *Courier } ) @@ -50,7 +50,6 @@ func NewSMTP(d smtpDependencies, c *config.Config) *Courier { return &Courier{ d: d, - c: c, Dialer: &gomail.Dialer{ /* #nosec we need to support SMTP servers without TLS */ TLSConfig: tlsConfig, @@ -130,7 +129,7 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { switch msg.Type { case MessageTypeEmail: - from := m.c.CourierSMTPFrom() + from := m.d.Config(ctx).CourierSMTPFrom() gm := gomail.NewMessage() gm.SetHeader("From", from) gm.SetHeader("To", msg.Recipient) diff --git a/courier/courier_test.go b/courier/courier_test.go index 139f51d89628..685de533867a 100644 --- a/courier/courier_test.go +++ b/courier/courier_test.go @@ -100,7 +100,7 @@ func TestSMTP(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) conf.MustSet(config.ViperKeyCourierSMTPURL, smtp) conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") - c := reg.Courier() + c := reg.Courier(context.Background()) ctx, cancel := context.WithCancel(context.Background()) diff --git a/driver/registry.go b/driver/registry.go index 178da8029e59..f57bba53a168 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -49,7 +49,7 @@ type Registry interface { WithCSRFTokenGenerator(cg x.CSRFToken) HealthHandler() *healthx.Handler - CookieManager() sessions.Store + CookieManager(ctx context.Context) sessions.Store ContinuityCookieManager(ctx context.Context) sessions.Store RegisterRoutes(public *x.RouterPublic, admin *x.RouterAdmin) diff --git a/driver/registry_default.go b/driver/registry_default.go index f879a77e608f..63e65da6f6df 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -68,7 +68,6 @@ type RegistryDefault struct { healthxHandler *healthx.Handler metricsHandler *prometheus.Handler - courier *courier.Courier persister persistence.Persister hookVerifier *hook.Verifier @@ -84,7 +83,6 @@ type RegistryDefault struct { schemaHandler *schema.Handler sessionHandler *session.Handler - sessionsStore *sessions.CookieStore sessionManager session.Manager passwordHasher hash.Hasher @@ -117,13 +115,7 @@ type RegistryDefault struct { selfserviceLogoutHandler *logout.Handler - selfserviceStrategies []interface{} - loginStrategies []login.Strategy - activeCredentialsCounterStrategies []identity.ActiveCredentialsCounter - registrationStrategies []registration.Strategy - profileStrategies []settings.Strategy - recoveryStrategies []recovery.Strategy - verificationStrategies []verification.Strategy + selfserviceStrategies []interface{} buildVersion string buildHash string @@ -151,13 +143,11 @@ func (m *RegistryDefault) RegisterPublicRoutes(router *x.RouterPublic) { m.SelfServiceErrorHandler().RegisterPublicRoutes(router) m.SchemaHandler().RegisterPublicRoutes(router) - if m.c.SelfServiceFlowRecoveryEnabled() { - m.AllRecoveryStrategies().RegisterPublicRoutes(router) - m.RecoveryHandler().RegisterPublicRoutes(router) - } + m.AllRecoveryStrategies().RegisterPublicRoutes(router) + m.RecoveryHandler().RegisterPublicRoutes(router) m.VerificationHandler().RegisterPublicRoutes(router) - m.VerificationStrategies().RegisterPublicRoutes(router) + m.AllVerificationStrategies().RegisterPublicRoutes(router) m.HealthHandler().SetRoutes(router.Router, false) } @@ -171,13 +161,11 @@ func (m *RegistryDefault) RegisterAdminRoutes(router *x.RouterAdmin) { m.SessionHandler().RegisterAdminRoutes(router) m.SelfServiceErrorHandler().RegisterAdminRoutes(router) - if m.c.SelfServiceFlowRecoveryEnabled() { - m.RecoveryHandler().RegisterAdminRoutes(router) - m.AllRecoveryStrategies().RegisterAdminRoutes(router) - } + m.RecoveryHandler().RegisterAdminRoutes(router) + m.AllRecoveryStrategies().RegisterAdminRoutes(router) m.VerificationHandler().RegisterAdminRoutes(router) - m.VerificationStrategies().RegisterAdminRoutes(router) + m.AllVerificationStrategies().RegisterAdminRoutes(router) m.HealthHandler().SetRoutes(router.Router, true) m.MetricsHandler().SetRoutes(router.Router) @@ -199,7 +187,7 @@ func (m *RegistryDefault) WithLogger(l *logrusx.Logger) Registry { func (m *RegistryDefault) LogoutHandler() *logout.Handler { if m.selfserviceLogoutHandler == nil { - m.selfserviceLogoutHandler = logout.NewHandler(m, m.c) + m.selfserviceLogoutHandler = logout.NewHandler(m) } return m.selfserviceLogoutHandler } @@ -252,17 +240,15 @@ func (m *RegistryDefault) selfServiceStrategies() []interface{} { return m.selfserviceStrategies } -func (m *RegistryDefault) RegistrationStrategies() registration.Strategies { - if len(m.registrationStrategies) == 0 { - for _, strategy := range m.selfServiceStrategies() { - if s, ok := strategy.(registration.Strategy); ok { - if m.c.SelfServiceStrategy(string(s.ID())).Enabled { - m.registrationStrategies = append(m.registrationStrategies, s) - } +func (m *RegistryDefault) RegistrationStrategies(ctx context.Context) (registrationStrategies registration.Strategies) { + for _, strategy := range m.selfServiceStrategies() { + if s, ok := strategy.(registration.Strategy); ok { + if m.Config(ctx).SelfServiceStrategy(string(s.ID())).Enabled { + registrationStrategies = append(registrationStrategies, s) } } } - return m.registrationStrategies + return } func (m *RegistryDefault) AllRegistrationStrategies() registration.Strategies { @@ -276,17 +262,15 @@ func (m *RegistryDefault) AllRegistrationStrategies() registration.Strategies { return registrationStrategies } -func (m *RegistryDefault) LoginStrategies() login.Strategies { - if len(m.loginStrategies) == 0 { - for _, strategy := range m.selfServiceStrategies() { - if s, ok := strategy.(login.Strategy); ok { - if m.c.SelfServiceStrategy(string(s.ID())).Enabled { - m.loginStrategies = append(m.loginStrategies, s) - } +func (m *RegistryDefault) LoginStrategies(ctx context.Context) (loginStrategies login.Strategies) { + for _, strategy := range m.selfServiceStrategies() { + if s, ok := strategy.(login.Strategy); ok { + if m.Config(ctx).SelfServiceStrategy(string(s.ID())).Enabled { + loginStrategies = append(loginStrategies, s) } } } - return m.loginStrategies + return } func (m *RegistryDefault) AllLoginStrategies() login.Strategies { @@ -299,26 +283,13 @@ func (m *RegistryDefault) AllLoginStrategies() login.Strategies { return loginStrategies } -func (m *RegistryDefault) VerificationStrategies() verification.Strategies { - if len(m.verificationStrategies) == 0 { - for _, strategy := range m.selfServiceStrategies() { - if s, ok := strategy.(verification.Strategy); ok { - m.verificationStrategies = append(m.verificationStrategies, s) - } - } - } - return m.verificationStrategies -} - -func (m *RegistryDefault) ActiveCredentialsCounterStrategies(ctx context.Context) []identity.ActiveCredentialsCounter { - if len(m.activeCredentialsCounterStrategies) == 0 { - for _, strategy := range m.selfServiceStrategies() { - if s, ok := strategy.(identity.ActiveCredentialsCounter); ok { - m.activeCredentialsCounterStrategies = append(m.activeCredentialsCounterStrategies, s) - } +func (m *RegistryDefault) ActiveCredentialsCounterStrategies(ctx context.Context) (activeCredentialsCounterStrategies []identity.ActiveCredentialsCounter) { + for _, strategy := range m.selfServiceStrategies() { + if s, ok := strategy.(identity.ActiveCredentialsCounter); ok { + activeCredentialsCounterStrategies = append(activeCredentialsCounterStrategies, s) } } - return m.activeCredentialsCounterStrategies + return } func (m *RegistryDefault) IdentityValidator() *identity.Validator { @@ -390,30 +361,27 @@ func (m *RegistryDefault) SelfServiceErrorHandler() *errorx.Handler { return m.errorHandler } -func (m *RegistryDefault) CookieManager() sessions.Store { - if m.sessionsStore == nil { - cs := sessions.NewCookieStore(m.c.SecretsSession()...) - cs.Options.Secure = !m.c.IsInsecureDevMode() - cs.Options.HttpOnly = true - if m.c.SessionDomain() != "" { - cs.Options.Domain = m.c.SessionDomain() - } +func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.Store { + cs := sessions.NewCookieStore(m.Config(ctx).SecretsSession()...) + cs.Options.Secure = !m.Config(ctx).IsInsecureDevMode() + cs.Options.HttpOnly = true + if domain := m.Config(ctx).SessionDomain(); domain != "" { + cs.Options.Domain = domain + } - if m.c.SessionPath() != "" { - cs.Options.Path = m.c.SessionPath() - } + if path := m.Config(ctx).SessionPath(); path != "" { + cs.Options.Path = path + } - if m.c.SessionSameSiteMode() != 0 { - cs.Options.SameSite = m.c.SessionSameSiteMode() - } + if sameSite := m.Config(ctx).SessionSameSiteMode(); sameSite != 0 { + cs.Options.SameSite = sameSite + } - cs.Options.MaxAge = 0 - if m.c.SessionPersistentCookie() { - cs.Options.MaxAge = int(m.c.SessionLifespan().Seconds()) - } - m.sessionsStore = cs + cs.Options.MaxAge = 0 + if m.Config(ctx).SessionPersistentCookie() { + cs.Options.MaxAge = int(m.Config(ctx).SessionLifespan().Seconds()) } - return m.sessionsStore + return cs } func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.Store { @@ -501,7 +469,7 @@ func (m *RegistryDefault) Init(ctx context.Context) error { } // if dsn is memory we have to run the migrations on every start - if dbal.InMemoryDSN == m.c.DSN() { + if dbal.InMemoryDSN == m.Config(ctx).DSN() { m.Logger().Infoln("ORY Kratos is running migrations on every startup as DSN is memory. This means your data is lost when Kratos terminates.") if err := p.MigrateUp(ctx); err != nil { return err @@ -514,11 +482,8 @@ func (m *RegistryDefault) Init(ctx context.Context) error { ) } -func (m *RegistryDefault) Courier() *courier.Courier { - if m.courier == nil { - m.courier = courier.NewSMTP(m, m.c) - } - return m.courier +func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier { + return courier.NewSMTP(m, m.Config(ctx)) } func (m *RegistryDefault) ContinuityManager() continuity.Manager { diff --git a/driver/registry_default_login.go b/driver/registry_default_login.go index 0969f54b7733..f08a88704db2 100644 --- a/driver/registry_default_login.go +++ b/driver/registry_default_login.go @@ -1,6 +1,7 @@ package driver import ( + "context" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/login" ) @@ -12,8 +13,8 @@ func (m *RegistryDefault) LoginHookExecutor() *login.HookExecutor { return m.selfserviceLoginExecutor } -func (m *RegistryDefault) PreLoginHooks() (b []login.PreHookExecutor) { - for _, v := range m.getHooks("", m.c.SelfServiceFlowLoginBeforeHooks()) { +func (m *RegistryDefault) PreLoginHooks(ctx context.Context) (b []login.PreHookExecutor) { + for _, v := range m.getHooks("", m.Config(ctx).SelfServiceFlowLoginBeforeHooks()) { if hook, ok := v.(login.PreHookExecutor); ok { b = append(b, hook) } @@ -21,8 +22,8 @@ func (m *RegistryDefault) PreLoginHooks() (b []login.PreHookExecutor) { return } -func (m *RegistryDefault) PostLoginHooks(credentialsType identity.CredentialsType) (b []login.PostHookExecutor) { - for _, v := range m.getHooks(string(credentialsType), m.c.SelfServiceFlowLoginAfterHooks(string(credentialsType))) { +func (m *RegistryDefault) PostLoginHooks(ctx context.Context, credentialsType identity.CredentialsType) (b []login.PostHookExecutor) { + for _, v := range m.getHooks(string(credentialsType), m.Config(ctx).SelfServiceFlowLoginAfterHooks(string(credentialsType))) { if hook, ok := v.(login.PostHookExecutor); ok { b = append(b, hook) } @@ -40,7 +41,7 @@ func (m *RegistryDefault) LoginHandler() *login.Handler { func (m *RegistryDefault) LoginFlowErrorHandler() *login.ErrorHandler { if m.selfserviceLoginRequestErrorHandler == nil { - m.selfserviceLoginRequestErrorHandler = login.NewFlowErrorHandler(m, m.c) + m.selfserviceLoginRequestErrorHandler = login.NewFlowErrorHandler(m) } return m.selfserviceLoginRequestErrorHandler diff --git a/driver/registry_default_recovery.go b/driver/registry_default_recovery.go index a86955dde4aa..a733b210795d 100644 --- a/driver/registry_default_recovery.go +++ b/driver/registry_default_recovery.go @@ -1,6 +1,7 @@ package driver import ( + "context" "github.com/ory/kratos/selfservice/flow/recovery" ) @@ -20,25 +21,22 @@ func (m *RegistryDefault) RecoveryHandler() *recovery.Handler { return m.selfserviceRecoveryHandler } -func (m *RegistryDefault) RecoveryStrategies() recovery.Strategies { - if len(m.recoveryStrategies) == 0 { +func (m *RegistryDefault) RecoveryStrategies(ctx context.Context) (recoveryStrategies recovery.Strategies) { for _, strategy := range m.selfServiceStrategies() { if s, ok := strategy.(recovery.Strategy); ok { - if m.c.SelfServiceStrategy(s.RecoveryStrategyID()).Enabled { - m.recoveryStrategies = append(m.recoveryStrategies, s) + if m.Config(ctx).SelfServiceStrategy(s.RecoveryStrategyID()).Enabled { + recoveryStrategies = append(recoveryStrategies, s) } } } - } - return m.recoveryStrategies + return } -func (m *RegistryDefault) AllRecoveryStrategies() recovery.Strategies { - var recoveryStrategies []recovery.Strategy +func (m *RegistryDefault) AllRecoveryStrategies() (recoveryStrategies recovery.Strategies ){ for _, strategy := range m.selfServiceStrategies() { if s, ok := strategy.(recovery.Strategy); ok { recoveryStrategies = append(recoveryStrategies, s) } } - return recoveryStrategies + return } diff --git a/driver/registry_default_registration.go b/driver/registry_default_registration.go index b7b763b3bae9..1e786070ddab 100644 --- a/driver/registry_default_registration.go +++ b/driver/registry_default_registration.go @@ -1,12 +1,13 @@ package driver import ( + "context" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/registration" ) -func (m *RegistryDefault) PostRegistrationPrePersistHooks(credentialsType identity.CredentialsType) (b []registration.PostHookPrePersistExecutor) { - for _, v := range m.getHooks(string(credentialsType), m.c.SelfServiceFlowRegistrationAfterHooks(string(credentialsType))) { +func (m *RegistryDefault) PostRegistrationPrePersistHooks(ctx context.Context, credentialsType identity.CredentialsType) (b []registration.PostHookPrePersistExecutor) { + for _, v := range m.getHooks(string(credentialsType), m.Config(ctx).SelfServiceFlowRegistrationAfterHooks(string(credentialsType))) { if hook, ok := v.(registration.PostHookPrePersistExecutor); ok { b = append(b, hook) } @@ -14,12 +15,12 @@ func (m *RegistryDefault) PostRegistrationPrePersistHooks(credentialsType identi return } -func (m *RegistryDefault) PostRegistrationPostPersistHooks(credentialsType identity.CredentialsType) (b []registration.PostHookPostPersistExecutor) { - if m.c.SelfServiceFlowVerificationEnabled() { +func (m *RegistryDefault) PostRegistrationPostPersistHooks(ctx context.Context, credentialsType identity.CredentialsType) (b []registration.PostHookPostPersistExecutor) { + if m.Config(ctx).SelfServiceFlowVerificationEnabled() { b = append(b, m.HookVerifier()) } - for _, v := range m.getHooks(string(credentialsType), m.c.SelfServiceFlowRegistrationAfterHooks(string(credentialsType))) { + for _, v := range m.getHooks(string(credentialsType), m.Config(ctx).SelfServiceFlowRegistrationAfterHooks(string(credentialsType))) { if hook, ok := v.(registration.PostHookPostPersistExecutor); ok { b = append(b, hook) } @@ -27,8 +28,8 @@ func (m *RegistryDefault) PostRegistrationPostPersistHooks(credentialsType ident return } -func (m *RegistryDefault) PreRegistrationHooks() (b []registration.PreHookExecutor) { - for _, v := range m.getHooks("", m.c.SelfServiceFlowRegistrationBeforeHooks()) { +func (m *RegistryDefault) PreRegistrationHooks(ctx context.Context, ) (b []registration.PreHookExecutor) { + for _, v := range m.getHooks("", m.Config(ctx).SelfServiceFlowRegistrationBeforeHooks()) { if hook, ok := v.(registration.PreHookExecutor); ok { b = append(b, hook) } diff --git a/driver/registry_default_settings.go b/driver/registry_default_settings.go index ca3cfa41a08f..94b3ee91e419 100644 --- a/driver/registry_default_settings.go +++ b/driver/registry_default_settings.go @@ -1,9 +1,12 @@ package driver -import "github.com/ory/kratos/selfservice/flow/settings" +import ( + "context" + "github.com/ory/kratos/selfservice/flow/settings" +) -func (m *RegistryDefault) PostSettingsPrePersistHooks(settingsType string) (b []settings.PostHookPrePersistExecutor) { - for _, v := range m.getHooks(settingsType, m.c.SelfServiceFlowSettingsAfterHooks(settingsType)) { +func (m *RegistryDefault) PostSettingsPrePersistHooks(ctx context.Context, settingsType string) (b []settings.PostHookPrePersistExecutor) { + for _, v := range m.getHooks(settingsType, m.Config(ctx).SelfServiceFlowSettingsAfterHooks(settingsType)) { if hook, ok := v.(settings.PostHookPrePersistExecutor); ok { b = append(b, hook) } @@ -11,12 +14,12 @@ func (m *RegistryDefault) PostSettingsPrePersistHooks(settingsType string) (b [] return } -func (m *RegistryDefault) PostSettingsPostPersistHooks(settingsType string) (b []settings.PostHookPostPersistExecutor) { - if m.c.SelfServiceFlowVerificationEnabled() { +func (m *RegistryDefault) PostSettingsPostPersistHooks(ctx context.Context, settingsType string) (b []settings.PostHookPostPersistExecutor) { + if m.Config(ctx).SelfServiceFlowVerificationEnabled() { b = append(b, m.HookVerifier()) } - for _, v := range m.getHooks(settingsType, m.c.SelfServiceFlowSettingsAfterHooks(settingsType)) { + for _, v := range m.getHooks(settingsType, m.Config(ctx).SelfServiceFlowSettingsAfterHooks(settingsType)) { if hook, ok := v.(settings.PostHookPostPersistExecutor); ok { b = append(b, hook) } @@ -45,17 +48,15 @@ func (m *RegistryDefault) SettingsFlowErrorHandler() *settings.ErrorHandler { return m.selfserviceSettingsErrorHandler } -func (m *RegistryDefault) SettingsStrategies() settings.Strategies { - if len(m.profileStrategies) == 0 { - for _, strategy := range m.selfServiceStrategies() { - if s, ok := strategy.(settings.Strategy); ok { - if m.c.SelfServiceStrategy(s.SettingsStrategyID()).Enabled { - m.profileStrategies = append(m.profileStrategies, s) - } +func (m *RegistryDefault) SettingsStrategies(ctx context.Context) (profileStrategies settings.Strategies) { + for _, strategy := range m.selfServiceStrategies() { + if s, ok := strategy.(settings.Strategy); ok { + if m.Config(ctx).SelfServiceStrategy(s.SettingsStrategyID()).Enabled { + profileStrategies = append(profileStrategies, s) } } } - return m.profileStrategies + return } func (m *RegistryDefault) AllSettingsStrategies() settings.Strategies { diff --git a/driver/registry_default_test.go b/driver/registry_default_test.go index d293414013a9..92712dbe13da 100644 --- a/driver/registry_default_test.go +++ b/driver/registry_default_test.go @@ -1,6 +1,7 @@ package driver_test import ( + "context" "fmt" "testing" @@ -21,19 +22,20 @@ import ( ) func TestDriverDefault_Hooks(t *testing.T) { + ctx := context.Background() t.Run("case=verification", func(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) conf.MustSet(config.ViperKeySelfServiceVerificationEnabled, true) t.Run("type=registration", func(t *testing.T) { - h := reg.PostRegistrationPostPersistHooks(identity.CredentialsTypePassword) + h := reg.PostRegistrationPostPersistHooks(ctx, identity.CredentialsTypePassword) require.Len(t, h, 1) assert.Equal(t, []registration.PostHookPostPersistExecutor{hook.NewVerifier(reg)}, h) conf.MustSet(config.ViperKeySelfServiceRegistrationAfter+".password.hooks", []map[string]interface{}{{"hook": "session"}}) - h = reg.PostRegistrationPostPersistHooks(identity.CredentialsTypePassword) + h = reg.PostRegistrationPostPersistHooks(ctx, identity.CredentialsTypePassword) require.Len(t, h, 2) assert.Equal(t, []registration.PostHookPostPersistExecutor{ hook.NewVerifier(reg), @@ -42,19 +44,19 @@ func TestDriverDefault_Hooks(t *testing.T) { }) t.Run("type=login", func(t *testing.T) { - h := reg.PostLoginHooks(identity.CredentialsTypePassword) + h := reg.PostLoginHooks(ctx, identity.CredentialsTypePassword) require.Len(t, h, 0) conf.MustSet(config.ViperKeySelfServiceLoginAfter+".password.hooks", []map[string]interface{}{{"hook": "revoke_active_sessions"}}) - h = reg.PostLoginHooks(identity.CredentialsTypePassword) + h = reg.PostLoginHooks(ctx, identity.CredentialsTypePassword) require.Len(t, h, 1) assert.Equal(t, []login.PostHookExecutor{hook.NewSessionDestroyer(reg)}, h) }) t.Run("type=settings", func(t *testing.T) { - h := reg.PostSettingsPostPersistHooks("profile") + h := reg.PostSettingsPostPersistHooks(ctx, "profile") require.Len(t, h, 1) assert.Equal(t, []settings.PostHookPostPersistExecutor{hook.NewVerifier(reg)}, h) }) @@ -82,7 +84,7 @@ func TestDriverDefault_Strategies(t *testing.T) { tc.prep(conf) t.Run("case=registration", func(t *testing.T) { - s := reg.RegistrationStrategies() + s := reg.RegistrationStrategies(context.Background()) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { assert.Equal(t, e, s[k].ID().String()) @@ -90,7 +92,7 @@ func TestDriverDefault_Strategies(t *testing.T) { }) t.Run("case=login", func(t *testing.T) { - s := reg.LoginStrategies() + s := reg.LoginStrategies(context.Background()) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { assert.Equal(t, e, s[k].ID().String()) @@ -115,7 +117,7 @@ func TestDriverDefault_Strategies(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) tc.prep(conf) - s := reg.RecoveryStrategies() + s := reg.RecoveryStrategies(context.Background()) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { assert.Equal(t, e, s[k].RecoveryStrategyID()) @@ -179,7 +181,7 @@ func TestDriverDefault_Strategies(t *testing.T) { reg, err := driver.NewRegistryFromDSN(conf, logrusx.New("", "")) require.NoError(t, err) - s := reg.SettingsStrategies() + s := reg.SettingsStrategies(context.Background()) require.Len(t, s, len(tc.expect)) for k, e := range tc.expect { diff --git a/driver/registry_default_verify.go b/driver/registry_default_verify.go index 81943fc2bc37..eb8ad01dddfd 100644 --- a/driver/registry_default_verify.go +++ b/driver/registry_default_verify.go @@ -1,6 +1,7 @@ package driver import ( + "context" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/verification" "github.com/ory/kratos/selfservice/strategy/link" @@ -41,3 +42,23 @@ func (m *RegistryDefault) LinkSender() *link.Sender { return m.selfserviceLinkSender } + +func (m *RegistryDefault) VerificationStrategies(ctx context.Context) (verificationStrategies verification.Strategies) { + for _, strategy := range m.selfServiceStrategies() { + if s, ok := strategy.(verification.Strategy); ok { + if m.Config(ctx).SelfServiceStrategy(s.VerificationStrategyID()).Enabled { + verificationStrategies = append(verificationStrategies, s) + } + } + } + return +} + +func (m *RegistryDefault) AllVerificationStrategies() (recoveryStrategies verification.Strategies) { + for _, strategy := range m.selfServiceStrategies() { + if s, ok := strategy.(verification.Strategy); ok { + recoveryStrategies = append(recoveryStrategies, s) + } + } + return +} diff --git a/internal/testhelpers/selfservice_settings.go b/internal/testhelpers/selfservice_settings.go index 0944d5f03c91..2607fac29c1a 100644 --- a/internal/testhelpers/selfservice_settings.go +++ b/internal/testhelpers/selfservice_settings.go @@ -187,10 +187,10 @@ func NewSettingsAPIServer(t *testing.T, reg *driver.RegistryDefault, ids map[str reg.WithCSRFHandler(hh) reg.SettingsHandler().RegisterPublicRoutes(public) - reg.SettingsStrategies().RegisterPublicRoutes(public) + reg.SettingsStrategies(context.Background()).RegisterPublicRoutes(public) reg.LoginHandler().RegisterPublicRoutes(public) reg.LoginHandler().RegisterAdminRoutes(admin) - reg.LoginStrategies().RegisterPublicRoutes(public) + reg.LoginStrategies(context.Background()).RegisterPublicRoutes(public) tsp, tsa := httptest.NewServer(hh), httptest.NewServer(admin) t.Cleanup(tsp.Close) diff --git a/persistence/sql/persister_courier.go b/persistence/sql/persister_courier.go index 244b98f314d8..e951b7ec50de 100644 --- a/persistence/sql/persister_courier.go +++ b/persistence/sql/persister_courier.go @@ -67,6 +67,7 @@ func (p *Persister) LatestQueuedMessage(ctx context.Context) (*courier.Message, func (p *Persister) SetMessageStatus(ctx context.Context, id uuid.UUID, ms courier.MessageStatus) error { count, err := p.GetConnection(ctx).RawQuery( + // #nosec G201 fmt.Sprintf( "UPDATE %s SET status = ? WHERE id = ?", corp.ContextualizeTableName(ctx, "courier_messages"), diff --git a/persistence/sql/persister_errorx.go b/persistence/sql/persister_errorx.go index eb85be1795f7..4991b88c8dc6 100644 --- a/persistence/sql/persister_errorx.go +++ b/persistence/sql/persister_errorx.go @@ -47,6 +47,7 @@ func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContai return nil, sqlcon.HandleError(err) } + // #nosec G201 if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("UPDATE %s SET was_seen = true, seen_at = ? WHERE id = ?", corp.ContextualizeTableName(ctx, "selfservice_errors")), time.Now().UTC(), id).Exec(); err != nil { return nil, sqlcon.HandleError(err) } @@ -56,8 +57,10 @@ func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContai func (p *Persister) Clear(ctx context.Context, olderThan time.Duration, force bool) (err error) { if force { + // #nosec G201 err = p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE seen_at < ? AND seen_at IS NOT NULL", corp.ContextualizeTableName(ctx, "selfservice_errors")), olderThan).Exec() } else { + // #nosec G201 err = p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE was_seen=true AND seen_at < ? AND seen_at IS NOT NULL", corp.ContextualizeTableName(ctx, "selfservice_errors")), time.Now().UTC().Add(-olderThan)).Exec() } diff --git a/persistence/sql/persister_identity.go b/persistence/sql/persister_identity.go index 5d03c9b80605..6129df250e6f 100644 --- a/persistence/sql/persister_identity.go +++ b/persistence/sql/persister_identity.go @@ -61,6 +61,7 @@ func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity match = strings.ToLower(match) } + // #nosec G201 if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT ic.identity_id FROM %s ic diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index de802a8ab751..acabfbde0abc 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -40,6 +40,7 @@ func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) error { } func (p *Persister) DeleteSessionsByIdentity(ctx context.Context, identityID uuid.UUID) error { + // #nosec G201 if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( "DELETE FROM %s WHERE identity_id = ?", corp.ContextualizeTableName(ctx, "sessions"), @@ -66,6 +67,7 @@ func (p *Persister) GetSessionByToken(ctx context.Context, token string) (*sessi } func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) error { + // #nosec G201 if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( "DELETE FROM %s WHERE token = ?", corp.ContextualizeTableName(ctx, "sessions"), @@ -76,6 +78,7 @@ func (p *Persister) DeleteSessionByToken(ctx context.Context, token string) erro } func (p *Persister) RevokeSessionByToken(ctx context.Context, token string) error { + // #nosec G201 if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( "UPDATE %s SET active = false WHERE token = ?", corp.ContextualizeTableName(ctx, "sessions"), diff --git a/selfservice/flow/login/error.go b/selfservice/flow/login/error.go index fdf21d354345..1e2af94765f0 100644 --- a/selfservice/flow/login/error.go +++ b/selfservice/flow/login/error.go @@ -30,6 +30,7 @@ type ( errorx.ManagementProvider x.WriterProvider x.LoggingProvider + config.Provider FlowPersistenceProvider HandlerProvider @@ -39,7 +40,6 @@ type ( ErrorHandler struct { d errorHandlerDependencies - c *config.Config } FlowExpiredError struct { @@ -59,8 +59,8 @@ func NewFlowExpiredError(at time.Time) *FlowExpiredError { } } -func NewFlowErrorHandler(d errorHandlerDependencies, c *config.Config) *ErrorHandler { - return &ErrorHandler{d: d, c: c} +func NewFlowErrorHandler(d errorHandlerDependencies) *ErrorHandler { + return &ErrorHandler{d: d} } func (s *ErrorHandler) WriteFlowError(w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, f *Flow, err error) { @@ -91,10 +91,10 @@ func (s *ErrorHandler) WriteFlowError(w http.ResponseWriter, r *http.Request, ct } if f.Type == flow.TypeAPI { - http.Redirect(w, r, urlx.CopyWithQuery(urlx.AppendPaths(s.c.SelfPublicURL(), + http.Redirect(w, r, urlx.CopyWithQuery(urlx.AppendPaths(s.d.Config(r.Context()).SelfPublicURL(), RouteGetFlow), url.Values{"id": {a.ID.String()}}).String(), http.StatusFound) } else { - http.Redirect(w, r, a.AppendTo(s.c.SelfServiceFlowLoginUI()).String(), http.StatusFound) + http.Redirect(w, r, a.AppendTo(s.d.Config(r.Context()).SelfServiceFlowLoginUI()).String(), http.StatusFound) } return } @@ -117,7 +117,7 @@ func (s *ErrorHandler) WriteFlowError(w http.ResponseWriter, r *http.Request, ct } if f.Type == flow.TypeBrowser { - http.Redirect(w, r, f.AppendTo(s.c.SelfServiceFlowLoginUI()).String(), http.StatusFound) + http.Redirect(w, r, f.AppendTo(s.d.Config(r.Context()).SelfServiceFlowLoginUI()).String(), http.StatusFound) return } diff --git a/selfservice/flow/login/error_test.go b/selfservice/flow/login/error_test.go index c8be282ba401..cbe9affd4f25 100644 --- a/selfservice/flow/login/error_test.go +++ b/selfservice/flow/login/error_test.go @@ -60,7 +60,7 @@ func TestHandleError(t *testing.T) { newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *login.Flow { req := &http.Request{URL: urlx.ParseOrPanic("/")} f := login.NewFlow(ttl, "csrf_token", req, ft) - for _, s := range reg.LoginStrategies() { + for _, s := range reg.LoginStrategies(context.Background()) { require.NoError(t, s.PopulateLoginMethod(req, f)) } diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index 948d95bf3231..c0c7f61286de 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -62,7 +62,7 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) { func (h *Handler) NewLoginFlow(w http.ResponseWriter, r *http.Request, flow flow.Type) (*Flow, error) { a := NewFlow(h.d.Config(r.Context()).SelfServiceFlowLoginRequestLifespan(), h.d.GenerateCSRFToken(r), r, flow) - for _, s := range h.d.LoginStrategies() { + for _, s := range h.d.LoginStrategies(r.Context()) { if err := s.PopulateLoginMethod(r, a); err != nil { return nil, err } diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index a88ee77f5b4d..ec7f25a834bb 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -1,6 +1,7 @@ package login import ( + "context" "fmt" "net/http" "time" @@ -24,8 +25,8 @@ type ( } HooksProvider interface { - PreLoginHooks() []PreHookExecutor - PostLoginHooks(credentialsType identity.CredentialsType) []PostHookExecutor + PreLoginHooks(ctx context.Context) []PreHookExecutor + PostLoginHooks(ctx context.Context, credentialsType identity.CredentialsType) []PostHookExecutor } ) @@ -67,14 +68,14 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, ct WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("Running ExecuteLoginPostHook.") - for k, executor := range e.d.PostLoginHooks(ct) { + for k, executor := range e.d.PostLoginHooks(r.Context(), ct) { if err := executor.ExecuteLoginPostHook(w, r, a, s); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(ct))). + WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(r.Context(), ct))). WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("A ExecuteLoginPostHook hook aborted early.") @@ -87,7 +88,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, ct WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(ct))). + WithField("executors", PostHookExecutorNames(e.d.PostLoginHooks(r.Context(), ct))). WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("ExecuteLoginPostHook completed successfully.") @@ -121,7 +122,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, ct } func (e *HookExecutor) PreLoginHook(w http.ResponseWriter, r *http.Request, a *Flow) error { - for _, executor := range e.d.PreLoginHooks() { + for _, executor := range e.d.PreLoginHooks(r.Context()) { if err := executor.ExecuteLoginPreHook(w, r, a); err != nil { return err } diff --git a/selfservice/flow/login/strategy.go b/selfservice/flow/login/strategy.go index feb18762d66d..01a186ea632b 100644 --- a/selfservice/flow/login/strategy.go +++ b/selfservice/flow/login/strategy.go @@ -1,6 +1,7 @@ package login import ( + "context" "net/http" "github.com/pkg/errors" @@ -44,5 +45,5 @@ func (s Strategies) RegisterPublicRoutes(r *x.RouterPublic) { } type StrategyProvider interface { - LoginStrategies() Strategies + LoginStrategies(ctx context.Context) Strategies } diff --git a/selfservice/flow/logout/handler.go b/selfservice/flow/logout/handler.go index dfee7a356305..4cca5c4acf63 100644 --- a/selfservice/flow/logout/handler.go +++ b/selfservice/flow/logout/handler.go @@ -21,18 +21,18 @@ type ( x.CSRFProvider session.ManagementProvider errorx.ManagementProvider + config.Provider } HandlerProvider interface { LogoutHandler() *Handler } Handler struct { - c *config.Config d handlerDependencies } ) -func NewHandler(d handlerDependencies, c *config.Config) *Handler { - return &Handler{d: d, c: c} +func NewHandler(d handlerDependencies) *Handler { + return &Handler{d: d} } func (h *Handler) RegisterPublicRoutes(router *x.RouterPublic) { @@ -66,10 +66,10 @@ func (h *Handler) logout(w http.ResponseWriter, r *http.Request, ps httprouter.P return } - ret, err := x.SecureRedirectTo(r, h.c.SelfServiceFlowLogoutRedirectURL(), + ret, err := x.SecureRedirectTo(r, h.d.Config(r.Context()).SelfServiceFlowLogoutRedirectURL(), x.SecureRedirectUseSourceURL(r.RequestURI), - x.SecureRedirectAllowURLs(h.c.SelfServiceBrowserWhitelistedReturnToDomains()), - x.SecureRedirectAllowSelfServiceURLs(h.c.SelfPublicURL()), + x.SecureRedirectAllowURLs(h.d.Config(r.Context()).SelfServiceBrowserWhitelistedReturnToDomains()), + x.SecureRedirectAllowSelfServiceURLs(h.d.Config(r.Context()).SelfPublicURL()), ) if err != nil { fmt.Printf("\n%s\n\n", err.Error()) diff --git a/selfservice/flow/recovery/error.go b/selfservice/flow/recovery/error.go index f9bb0b32e213..c6193729abcd 100644 --- a/selfservice/flow/recovery/error.go +++ b/selfservice/flow/recovery/error.go @@ -84,7 +84,7 @@ func (s *ErrorHandler) WriteFlowError( if e := new(FlowExpiredError); errors.As(err, &e) { // create new flow because the old one is not valid - a, err := NewFlow(s.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(), f.Type) + a, err := NewFlow(s.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(r.Context()), f.Type) if err != nil { // failed to create a new session and redirect to it, handle that error as a new one s.WriteFlowError(w, r, methodName, f, err) diff --git a/selfservice/flow/recovery/error_test.go b/selfservice/flow/recovery/error_test.go index 364b4b7dbab0..09da59b0fb55 100644 --- a/selfservice/flow/recovery/error_test.go +++ b/selfservice/flow/recovery/error_test.go @@ -61,7 +61,7 @@ func TestHandleError(t *testing.T) { newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *recovery.Flow { req := &http.Request{URL: urlx.ParseOrPanic("/")} - f, err := recovery.NewFlow(ttl, x.FakeCSRFToken, req, reg.RecoveryStrategies(), ft) + f, err := recovery.NewFlow(ttl, x.FakeCSRFToken, req, reg.RecoveryStrategies(context.Background()), ft) require.NoError(t, err) require.NoError(t, reg.RecoveryFlowPersister().CreateRecoveryFlow(context.Background(), f)) f, err = reg.RecoveryFlowPersister().GetRecoveryFlow(context.Background(), f.ID) diff --git a/selfservice/flow/recovery/handler.go b/selfservice/flow/recovery/handler.go index b5791bd7aad6..71cc7caa9812 100644 --- a/selfservice/flow/recovery/handler.go +++ b/selfservice/flow/recovery/handler.go @@ -1,6 +1,7 @@ package recovery import ( + "github.com/ory/herodot" "net/http" "time" @@ -91,7 +92,12 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) { // 500: genericError // 400: genericError func (h *Handler) initAPIFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.RecoveryStrategies(), flow.TypeAPI) + if !h.d.Config(r.Context()).SelfServiceFlowRecoveryEnabled() { + h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled."))) + return + } + + req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.RecoveryStrategies(r.Context()), flow.TypeAPI) if err != nil { h.d.Writer().WriteError(w, r, err) return @@ -123,7 +129,12 @@ func (h *Handler) initAPIFlow(w http.ResponseWriter, r *http.Request, _ httprout // 302: emptyResponse // 500: genericError func (h *Handler) initBrowserFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.RecoveryStrategies(), flow.TypeBrowser) + if !h.d.Config(r.Context()).SelfServiceFlowRecoveryEnabled() { + h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled."))) + return + } + + req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.RecoveryStrategies(r.Context()), flow.TypeBrowser) if err != nil { h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) return @@ -169,6 +180,11 @@ type getSelfServiceRecoveryFlowParameters struct { // 410: genericError // 500: genericError func (h *Handler) fetch(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + if !h.d.Config(r.Context()).SelfServiceFlowRecoveryEnabled() { + h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled."))) + return + } + rid := x.ParseUUID(r.URL.Query().Get("id")) req, err := h.d.RecoveryFlowPersister().GetRecoveryFlow(r.Context(), rid) if err != nil { diff --git a/selfservice/flow/recovery/strategy.go b/selfservice/flow/recovery/strategy.go index 55fe2422783b..f32c3f5764c1 100644 --- a/selfservice/flow/recovery/strategy.go +++ b/selfservice/flow/recovery/strategy.go @@ -1,6 +1,7 @@ package recovery import ( + "context" "net/http" "github.com/pkg/errors" @@ -25,7 +26,7 @@ type ( } Strategies []Strategy StrategyProvider interface { - RecoveryStrategies() Strategies + RecoveryStrategies(ctx context.Context) Strategies } ) diff --git a/selfservice/flow/registration/error_test.go b/selfservice/flow/registration/error_test.go index 529dfd4969ac..31d480a84cbd 100644 --- a/selfservice/flow/registration/error_test.go +++ b/selfservice/flow/registration/error_test.go @@ -63,7 +63,7 @@ func TestHandleError(t *testing.T) { newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *registration.Flow { req := &http.Request{URL: urlx.ParseOrPanic("/")} f := registration.NewFlow(ttl, "csrf_token", req, ft) - for _, s := range reg.RegistrationStrategies() { + for _, s := range reg.RegistrationStrategies(context.Background()) { require.NoError(t, s.PopulateRegistrationMethod(req, f)) } diff --git a/selfservice/flow/registration/handler.go b/selfservice/flow/registration/handler.go index 04d3a01b5fd7..0cea8e2ec585 100644 --- a/selfservice/flow/registration/handler.go +++ b/selfservice/flow/registration/handler.go @@ -64,7 +64,7 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) { func (h *Handler) NewRegistrationFlow(w http.ResponseWriter, r *http.Request, ft flow.Type) (*Flow, error) { a := NewFlow(h.d.Config(r.Context()).SelfServiceFlowRegistrationRequestLifespan(), h.d.GenerateCSRFToken(r), r, ft) - for _, s := range h.d.RegistrationStrategies() { + for _, s := range h.d.RegistrationStrategies(r.Context()) { if err := s.PopulateRegistrationMethod(r, a); err != nil { return nil, err } diff --git a/selfservice/flow/registration/hook.go b/selfservice/flow/registration/hook.go index 3d1168190df1..21b40b275bb1 100644 --- a/selfservice/flow/registration/hook.go +++ b/selfservice/flow/registration/hook.go @@ -1,6 +1,7 @@ package registration import ( + "context" "fmt" "net/http" "time" @@ -34,9 +35,9 @@ type ( PostHookPrePersistExecutorFunc func(w http.ResponseWriter, r *http.Request, a *Flow, i *identity.Identity) error HooksProvider interface { - PreRegistrationHooks() []PreHookExecutor - PostRegistrationPrePersistHooks(credentialsType identity.CredentialsType) []PostHookPrePersistExecutor - PostRegistrationPostPersistHooks(credentialsType identity.CredentialsType) []PostHookPostPersistExecutor + PreRegistrationHooks(ctx context.Context, ) []PreHookExecutor + PostRegistrationPrePersistHooks(ctx context.Context, credentialsType identity.CredentialsType) []PostHookPrePersistExecutor + PostRegistrationPostPersistHooks(ctx context.Context, credentialsType identity.CredentialsType) []PostHookPostPersistExecutor } ) @@ -86,14 +87,14 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("Running PostRegistrationPrePersistHooks.") - for k, executor := range e.d.PostRegistrationPrePersistHooks(ct) { + for k, executor := range e.d.PostRegistrationPrePersistHooks(r.Context(), ct) { if err := executor.ExecutePostRegistrationPrePersistHook(w, r, a, i); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(ct))). + WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(r.Context(), ct))). WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("A ExecutePostRegistrationPrePersistHook hook aborted early.") @@ -105,7 +106,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque e.d.Logger().WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(ct))). + WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(r.Context(), ct))). WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("ExecutePostRegistrationPrePersistHook completed successfully.") @@ -133,14 +134,14 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("Running PostRegistrationPostPersistHooks.") - for k, executor := range e.d.PostRegistrationPostPersistHooks(ct) { + for k, executor := range e.d.PostRegistrationPostPersistHooks(r.Context(), ct) { if err := executor.ExecutePostRegistrationPostPersistHook(w, r, a, s); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(ct))). + WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(r.Context(), ct))). WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("A ExecutePostRegistrationPostPersistHook hook aborted early.") @@ -152,7 +153,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque e.d.Logger().WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(ct))). + WithField("executors", PostHookPostPersistExecutorNames(e.d.PostRegistrationPostPersistHooks(r.Context(), ct))). WithField("identity_id", i.ID). WithField("flow_method", ct). Debug("ExecutePostRegistrationPostPersistHook completed successfully.") @@ -174,7 +175,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque } func (e *HookExecutor) PreRegistrationHook(w http.ResponseWriter, r *http.Request, a *Flow) error { - for _, executor := range e.d.PreRegistrationHooks() { + for _, executor := range e.d.PreRegistrationHooks(r.Context()) { if err := executor.ExecuteRegistrationPreHook(w, r, a); err != nil { return err } diff --git a/selfservice/flow/registration/strategy.go b/selfservice/flow/registration/strategy.go index e870abbe50a4..46c8b2f3e0f0 100644 --- a/selfservice/flow/registration/strategy.go +++ b/selfservice/flow/registration/strategy.go @@ -1,6 +1,7 @@ package registration import ( + "context" "net/http" "github.com/pkg/errors" @@ -44,5 +45,5 @@ func (s Strategies) RegisterPublicRoutes(r *x.RouterPublic) { } type StrategyProvider interface { - RegistrationStrategies() Strategies + RegistrationStrategies(ctx context.Context) Strategies } diff --git a/selfservice/flow/settings/error_test.go b/selfservice/flow/settings/error_test.go index 2bba70ddb9d9..f64030803f79 100644 --- a/selfservice/flow/settings/error_test.go +++ b/selfservice/flow/settings/error_test.go @@ -70,7 +70,7 @@ func TestHandleError(t *testing.T) { newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *settings.Flow { req := &http.Request{URL: urlx.ParseOrPanic("/")} f := settings.NewFlow(ttl, req, &id, ft) - for _, s := range reg.SettingsStrategies() { + for _, s := range reg.SettingsStrategies(context.Background()) { require.NoError(t, s.PopulateSettingsMethod(req, &id, f)) } diff --git a/selfservice/flow/settings/handler.go b/selfservice/flow/settings/handler.go index f04276d0f69b..d6e766a9b427 100644 --- a/selfservice/flow/settings/handler.go +++ b/selfservice/flow/settings/handler.go @@ -88,7 +88,7 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) { func (h *Handler) NewFlow(w http.ResponseWriter, r *http.Request, i *identity.Identity, ft flow.Type) (*Flow, error) { f := NewFlow(h.d.Config(r.Context()).SelfServiceFlowSettingsFlowLifespan(), r, i, ft) - for _, strategy := range h.d.SettingsStrategies() { + for _, strategy := range h.d.SettingsStrategies(r.Context()) { if err := h.d.ContinuityManager().Abort(r.Context(), w, r, ContinuityKey(strategy.SettingsStrategyID())); err != nil { return nil, err } diff --git a/selfservice/flow/settings/hook.go b/selfservice/flow/settings/hook.go index 8685721ec7bc..108a54be8dd2 100644 --- a/selfservice/flow/settings/hook.go +++ b/selfservice/flow/settings/hook.go @@ -1,6 +1,7 @@ package settings import ( + "context" "fmt" "net/http" "time" @@ -27,8 +28,8 @@ type ( } PostHookPostPersistExecutorFunc func(w http.ResponseWriter, r *http.Request, a *Flow, s *identity.Identity) error HooksProvider interface { - PostSettingsPrePersistHooks(settingsType string) []PostHookPrePersistExecutor - PostSettingsPostPersistHooks(settingsType string) []PostHookPostPersistExecutor + PostSettingsPrePersistHooks(ctx context.Context, settingsType string) []PostHookPrePersistExecutor + PostSettingsPostPersistHooks(ctx context.Context, settingsType string) []PostHookPostPersistExecutor } executorDependencies interface { identity.ManagementProvider @@ -101,11 +102,11 @@ func (e *HookExecutor) PostSettingsHook(w http.ResponseWriter, r *http.Request, f(config) } - for k, executor := range e.d.PostSettingsPrePersistHooks(settingsType) { + for k, executor := range e.d.PostSettingsPrePersistHooks(r.Context(), settingsType) { logFields := logrus.Fields{ "executor": fmt.Sprintf("%T", executor), "executor_position": k, - "executors": PostHookPrePersistExecutorNames(e.d.PostSettingsPrePersistHooks(settingsType)), + "executors": PostHookPrePersistExecutorNames(e.d.PostSettingsPrePersistHooks(r.Context(), settingsType)), "identity_id": i.ID, "flow_method": settingsType, } @@ -159,14 +160,14 @@ func (e *HookExecutor) PostSettingsHook(w http.ResponseWriter, r *http.Request, return err } - for k, executor := range e.d.PostSettingsPostPersistHooks(settingsType) { + for k, executor := range e.d.PostSettingsPostPersistHooks(r.Context(), settingsType) { if err := executor.ExecuteSettingsPostPersistHook(w, r, ctxUpdate.Flow, i); err != nil { if errors.Is(err, ErrHookAbortRequest) { e.d.Logger(). WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookPostPersistExecutorNames(e.d.PostSettingsPostPersistHooks(settingsType))). + WithField("executors", PostHookPostPersistExecutorNames(e.d.PostSettingsPostPersistHooks(r.Context(), settingsType))). WithField("identity_id", i.ID). WithField("flow_method", settingsType). Debug("A ExecuteSettingsPostPersistHook hook aborted early.") @@ -178,7 +179,7 @@ func (e *HookExecutor) PostSettingsHook(w http.ResponseWriter, r *http.Request, e.d.Logger().WithRequest(r). WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). - WithField("executors", PostHookPostPersistExecutorNames(e.d.PostSettingsPostPersistHooks(settingsType))). + WithField("executors", PostHookPostPersistExecutorNames(e.d.PostSettingsPostPersistHooks(r.Context(), settingsType))). WithField("identity_id", i.ID). WithField("flow_method", settingsType). Debug("ExecuteSettingsPostPersistHook completed successfully.") diff --git a/selfservice/flow/settings/strategy.go b/selfservice/flow/settings/strategy.go index 13fb1d3f2126..578947a56b49 100644 --- a/selfservice/flow/settings/strategy.go +++ b/selfservice/flow/settings/strategy.go @@ -1,6 +1,7 @@ package settings import ( + "context" "net/http" "reflect" @@ -51,5 +52,5 @@ func (s Strategies) RegisterPublicRoutes(r *x.RouterPublic) { } type StrategyProvider interface { - SettingsStrategies() Strategies + SettingsStrategies(ctx context.Context) Strategies } diff --git a/selfservice/flow/verification/error.go b/selfservice/flow/verification/error.go index 507a58207a80..1ce5c9ae738a 100644 --- a/selfservice/flow/verification/error.go +++ b/selfservice/flow/verification/error.go @@ -80,7 +80,7 @@ func (s *ErrorHandler) WriteFlowError( if e := new(FlowExpiredError); errors.As(err, &e) { // create new flow because the old one is not valid a, err := NewFlow(s.d.Config(r.Context()).SelfServiceFlowVerificationRequestLifespan(), - s.d.GenerateCSRFToken(r), r, s.d.VerificationStrategies(), f.Type) + s.d.GenerateCSRFToken(r), r, s.d.VerificationStrategies(r.Context()), f.Type) if err != nil { // failed to create a new session and redirect to it, handle that error as a new one s.WriteFlowError(w, r, methodName, f, err) diff --git a/selfservice/flow/verification/handler.go b/selfservice/flow/verification/handler.go index a119e17eb67f..304c5e6dce50 100644 --- a/selfservice/flow/verification/handler.go +++ b/selfservice/flow/verification/handler.go @@ -1,6 +1,7 @@ package verification import ( + "github.com/ory/herodot" "net/http" "time" @@ -88,7 +89,12 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) { // 500: genericError // 400: genericError func (h *Handler) initAPIFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowVerificationRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.VerificationStrategies(), flow.TypeAPI) + if !h.d.Config(r.Context()).SelfServiceFlowVerificationEnabled() { + h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled."))) + return + } + + req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowVerificationRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.VerificationStrategies(r.Context()), flow.TypeAPI) if err != nil { h.d.Writer().WriteError(w, r, err) return @@ -119,7 +125,12 @@ func (h *Handler) initAPIFlow(w http.ResponseWriter, r *http.Request, _ httprout // 302: emptyResponse // 500: genericError func (h *Handler) initBrowserFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowVerificationRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.VerificationStrategies(), flow.TypeBrowser) + if !h.d.Config(r.Context()).SelfServiceFlowVerificationEnabled() { + h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled."))) + return + } + + req, err := NewFlow(h.d.Config(r.Context()).SelfServiceFlowVerificationRequestLifespan(), h.d.GenerateCSRFToken(r), r, h.d.VerificationStrategies(r.Context()), flow.TypeBrowser) if err != nil { h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) return @@ -165,6 +176,11 @@ type getSelfServiceVerificationFlowParameters struct { // 404: genericError // 500: genericError func (h *Handler) fetch(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + if !h.d.Config(r.Context()).SelfServiceFlowVerificationEnabled() { + h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification is not allowed because it was disabled."))) + return + } + rid := x.ParseUUID(r.URL.Query().Get("id")) req, err := h.d.VerificationFlowPersister().GetVerificationFlow(r.Context(), rid) if err != nil { diff --git a/selfservice/flow/verification/strategy.go b/selfservice/flow/verification/strategy.go index a92e7365812c..81f591f022ad 100644 --- a/selfservice/flow/verification/strategy.go +++ b/selfservice/flow/verification/strategy.go @@ -1,6 +1,7 @@ package verification import ( + "context" "net/http" "github.com/pkg/errors" @@ -25,7 +26,7 @@ type ( } Strategies []Strategy StrategyProvider interface { - VerificationStrategies() Strategies + VerificationStrategies(ctx context.Context) Strategies } ) diff --git a/selfservice/strategy/handler.go b/selfservice/strategy/handler.go index b73f60778cfd..05461dcd1088 100644 --- a/selfservice/strategy/handler.go +++ b/selfservice/strategy/handler.go @@ -10,16 +10,37 @@ import ( const EndpointDisabledMessage = "This endpoint was disabled by system administrator. Please check your url or contact the system administrator to enable it." -func IsDisabled(c interface { +type disabledChecker interface { config.Provider x.WriterProvider -}, strategy string, wrap httprouter.Handle) httprouter.Handle { +} + +func disabledWriter(c disabledChecker, enabled bool, wrap httprouter.Handle, w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + if !enabled { + c.Writer().WriteError(w, r, herodot.ErrNotFound.WithReason(EndpointDisabledMessage)) + return + } + wrap(w, r, ps) +} + +func IsDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + disabledWriter(c, c.Config(r.Context()).SelfServiceStrategy(strategy).Enabled, wrap, w, r, ps) + } +} + +func IsRecoveryDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + disabledWriter(c, + c.Config(r.Context()).SelfServiceStrategy(strategy).Enabled && c.Config(r.Context()).SelfServiceFlowRecoveryEnabled(), + wrap, w, r, ps) + } +} + +func IsVerificationDisabled(c disabledChecker, strategy string, wrap httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - enabled := c.Config(r.Context()).SelfServiceStrategy(strategy).Enabled - if !enabled { - c.Writer().WriteError(w, r, herodot.ErrNotFound.WithReason(EndpointDisabledMessage)) - return - } - wrap(w, r, ps) + disabledWriter(c, + c.Config(r.Context()).SelfServiceStrategy(strategy).Enabled && c.Config(r.Context()).SelfServiceFlowVerificationEnabled(), + wrap, w, r, ps) } } diff --git a/selfservice/strategy/link/sender.go b/selfservice/strategy/link/sender.go index 9171ea176771..48bea5bb92d2 100644 --- a/selfservice/strategy/link/sender.go +++ b/selfservice/strategy/link/sender.go @@ -142,7 +142,7 @@ func (s *Sender) SendVerificationTokenTo(ctx context.Context, address *identity. func (s *Sender) send(ctx context.Context, via string, t courier.EmailTemplate) error { switch via { case identity.AddressTypeEmail: - _, err := s.r.Courier().QueueEmail(ctx, t) + _, err := s.r.Courier(ctx).QueueEmail(ctx, t) return err default: return errors.Errorf("received unexpected via type: %s", via) diff --git a/selfservice/strategy/link/sender_test.go b/selfservice/strategy/link/sender_test.go index f490b9a9b706..e260b50a4dff 100644 --- a/selfservice/strategy/link/sender_test.go +++ b/selfservice/strategy/link/sender_test.go @@ -32,7 +32,7 @@ func TestManager(t *testing.T) { require.NoError(t, reg.IdentityManager().Create(context.Background(), i)) t.Run("method=SendRecoveryLink", func(t *testing.T) { - f, err := recovery.NewFlow(time.Hour, "", u, reg.RecoveryStrategies(), flow.TypeBrowser) + f, err := recovery.NewFlow(time.Hour, "", u, reg.RecoveryStrategies(context.Background()), flow.TypeBrowser) require.NoError(t, err) require.NoError(t, reg.RecoveryFlowPersister().CreateRecoveryFlow(context.Background(), f)) @@ -54,7 +54,7 @@ func TestManager(t *testing.T) { }) t.Run("method=SendVerificationLink", func(t *testing.T) { - f, err := verification.NewFlow(time.Hour, "", u, reg.VerificationStrategies(), flow.TypeBrowser) + f, err := verification.NewFlow(time.Hour, "", u, reg.VerificationStrategies(context.Background()), flow.TypeBrowser) require.NoError(t, err) require.NoError(t, reg.VerificationFlowPersister().CreateVerificationFlow(context.Background(), f)) diff --git a/selfservice/strategy/link/strategy_recovery.go b/selfservice/strategy/link/strategy_recovery.go index c18f231d8e2f..1776ff0178c5 100644 --- a/selfservice/strategy/link/strategy_recovery.go +++ b/selfservice/strategy/link/strategy_recovery.go @@ -43,9 +43,8 @@ func (s *Strategy) RecoveryStrategyID() string { func (s *Strategy) RegisterPublicRecoveryRoutes(public *x.RouterPublic) { redirect := session.RedirectOnAuthenticated(s.d) - wrappedHandleRecovery := strategy.IsDisabled(s.d, s.RecoveryStrategyID(), s.handleRecovery) + wrappedHandleRecovery := strategy.IsRecoveryDisabled(s.d, s.RecoveryStrategyID(), s.handleRecovery) public.GET(RouteRecovery, s.d.SessionHandler().IsNotAuthenticated(wrappedHandleRecovery, redirect)) - public.POST(RouteRecovery, s.d.SessionHandler().IsNotAuthenticated(wrappedHandleRecovery, redirect)) } @@ -157,7 +156,7 @@ func (s *Strategy) createRecoveryLink(w http.ResponseWriter, r *http.Request, _ return } - req, err := recovery.NewFlow(expiresIn, s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(), flow.TypeBrowser) + req, err := recovery.NewFlow(expiresIn, s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(r.Context()), flow.TypeBrowser) if err != nil { s.d.Writer().WriteError(w, r, err) return @@ -363,7 +362,7 @@ func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, body var f *recovery.Flow if !token.FlowID.Valid { - f, err = recovery.NewFlow(time.Until(token.ExpiresAt), s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(), flow.TypeBrowser) + f, err = recovery.NewFlow(time.Until(token.ExpiresAt), s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(r.Context()), flow.TypeBrowser) if err != nil { s.handleRecoveryError(w, r, nil, body, err) return @@ -392,7 +391,7 @@ func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, body func (s *Strategy) retryRecoveryFlowWithMessage(w http.ResponseWriter, r *http.Request, ft flow.Type, message *text.Message) { s.d.Logger().WithRequest(r).WithField("message", message).Debug("A recovery flow is being retried because a validation error occurred.") - req, err := recovery.NewFlow(s.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(), ft) + req, err := recovery.NewFlow(s.d.Config(r.Context()).SelfServiceFlowRecoveryRequestLifespan(), s.d.GenerateCSRFToken(r), r, s.d.RecoveryStrategies(r.Context()), ft) if err != nil { s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) return diff --git a/selfservice/strategy/link/strategy_verification.go b/selfservice/strategy/link/strategy_verification.go index 7c627c045895..62300d345e6b 100644 --- a/selfservice/strategy/link/strategy_verification.go +++ b/selfservice/strategy/link/strategy_verification.go @@ -1,12 +1,11 @@ package link import ( + "github.com/ory/kratos/selfservice/strategy" "net/http" "net/url" "time" - "github.com/ory/herodot" - "github.com/ory/x/pkgerx" "github.com/gofrs/uuid" @@ -37,8 +36,9 @@ func (s *Strategy) VerificationStrategyID() string { } func (s *Strategy) RegisterPublicVerificationRoutes(public *x.RouterPublic) { - public.POST(RouteVerification, s.handleVerification) - public.GET(RouteVerification, s.handleVerification) + wrappedHandleVerification := strategy.IsVerificationDisabled(s.d, s.RecoveryStrategyID(), s.handleVerification) + public.POST(RouteVerification, wrappedHandleVerification) + public.GET(RouteVerification, wrappedHandleVerification) } func (s *Strategy) RegisterAdminVerificationRoutes(admin *x.RouterAdmin) { @@ -170,11 +170,6 @@ type completeSelfServiceVerificationFlowWithLinkMethod struct { // 302: emptyResponse // 500: genericError func (s *Strategy) handleVerification(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - if !s.d.Config(r.Context()).SelfServiceStrategy(s.VerificationStrategyID()).Enabled { - s.handleVerificationError(w, r, nil, nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Verification using this method is not allowed because it was disabled."))) - return - } - body, err := s.decodeVerification(r, false) if err != nil { s.handleVerificationError(w, r, nil, body, err) @@ -293,7 +288,7 @@ func (s *Strategy) verificationUseToken(w http.ResponseWriter, r *http.Request, var f *verification.Flow if !token.FlowID.Valid { - f, err = verification.NewFlow(time.Until(token.ExpiresAt), s.d.GenerateCSRFToken(r), r, s.d.VerificationStrategies(), flow.TypeBrowser) + f, err = verification.NewFlow(time.Until(token.ExpiresAt), s.d.GenerateCSRFToken(r), r, s.d.VerificationStrategies(r.Context()), flow.TypeBrowser) if err != nil { s.handleVerificationError(w, r, nil, body, err) return @@ -339,7 +334,7 @@ func (s *Strategy) verificationUseToken(w http.ResponseWriter, r *http.Request, func (s *Strategy) retryVerificationFlowWithMessage(w http.ResponseWriter, r *http.Request, ft flow.Type, message *text.Message) { s.d.Logger().WithRequest(r).WithField("message", message).Debug("A verification flow is being retried because a validation error occurred.") - req, err := verification.NewFlow(s.d.Config(r.Context()).SelfServiceFlowVerificationRequestLifespan(), s.d.GenerateCSRFToken(r), r, s.d.VerificationStrategies(), ft) + req, err := verification.NewFlow(s.d.Config(r.Context()).SelfServiceFlowVerificationRequestLifespan(), s.d.GenerateCSRFToken(r), r, s.d.VerificationStrategies(r.Context()), ft) if err != nil { s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) return diff --git a/selfservice/strategy/oidc/strategy_settings_test.go b/selfservice/strategy/oidc/strategy_settings_test.go index ee4805a9f318..7508c946cd96 100644 --- a/selfservice/strategy/oidc/strategy_settings_test.go +++ b/selfservice/strategy/oidc/strategy_settings_test.go @@ -538,7 +538,7 @@ func TestPopulateSettingsMethod(t *testing.T) { } ns := func(t *testing.T, reg *driver.RegistryDefault) *oidc.Strategy { - ss, err := reg.SettingsStrategies().Strategy(identity.CredentialsTypeOIDC.String()) + ss, err := reg.SettingsStrategies(context.Background()).Strategy(identity.CredentialsTypeOIDC.String()) require.NoError(t, err) return ss.(*oidc.Strategy) } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 6774a92ea72d..c14e2df4b542 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -409,7 +409,7 @@ func TestStrategy(t *testing.T) { conf.MustSet(config.ViperKeyPublicBaseURL, "https://foo/") sr := registration.NewFlow(time.Minute, "nosurf", &http.Request{URL: urlx.ParseOrPanic("/")}, flow.TypeBrowser) - require.NoError(t, reg.RegistrationStrategies().MustStrategy(identity.CredentialsTypeOIDC).(*oidc.Strategy).PopulateRegistrationMethod(&http.Request{}, sr)) + require.NoError(t, reg.RegistrationStrategies(context.Background()).MustStrategy(identity.CredentialsTypeOIDC).(*oidc.Strategy).PopulateRegistrationMethod(&http.Request{}, sr)) expected := ®istration.FlowMethod{ Method: identity.CredentialsTypeOIDC, @@ -449,7 +449,7 @@ func TestStrategy(t *testing.T) { conf.MustSet(config.ViperKeyPublicBaseURL, "https://foo/") sr := login.NewFlow(time.Minute, "nosurf", &http.Request{URL: urlx.ParseOrPanic("/")}, flow.TypeBrowser) - require.NoError(t, reg.LoginStrategies().MustStrategy(identity.CredentialsTypeOIDC).(*oidc.Strategy).PopulateLoginMethod(&http.Request{}, sr)) + require.NoError(t, reg.LoginStrategies(context.Background()).MustStrategy(identity.CredentialsTypeOIDC).(*oidc.Strategy).PopulateLoginMethod(&http.Request{}, sr)) expected := &login.FlowMethod{ Method: identity.CredentialsTypeOIDC, diff --git a/selfservice/strategy/password/registration_test.go b/selfservice/strategy/password/registration_test.go index be294b075bd1..74616c94271a 100644 --- a/selfservice/strategy/password/registration_test.go +++ b/selfservice/strategy/password/registration_test.go @@ -2,6 +2,7 @@ package password_test import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -562,7 +563,7 @@ func TestRegistration(t *testing.T) { "enabled": true}) sr := registration.NewFlow(time.Minute, "nosurf", &http.Request{URL: urlx.ParseOrPanic("/")}, flow.TypeBrowser) - require.NoError(t, reg.RegistrationStrategies().MustStrategy(identity.CredentialsTypePassword).(*password.Strategy).PopulateRegistrationMethod(&http.Request{}, sr)) + require.NoError(t, reg.RegistrationStrategies(context.Background()).MustStrategy(identity.CredentialsTypePassword).(*password.Strategy).PopulateRegistrationMethod(&http.Request{}, sr)) expected := ®istration.FlowMethod{ Method: identity.CredentialsTypePassword, diff --git a/session/manager_http.go b/session/manager_http.go index 11a163e466ab..c3138490e479 100644 --- a/session/manager_http.go +++ b/session/manager_http.go @@ -50,7 +50,7 @@ func (s *ManagerHTTP) CreateAndIssueCookie(ctx context.Context, w http.ResponseW } func (s *ManagerHTTP) IssueCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, session *Session) error { - cookie, _ := s.r.CookieManager().Get(r, s.cookieName) + cookie, _ := s.r.CookieManager(r.Context()).Get(r, s.cookieName) if s.r.Config(ctx).SessionDomain() != "" { cookie.Options.Domain = s.r.Config(ctx).SessionDomain() } @@ -93,7 +93,7 @@ func (s *ManagerHTTP) extractToken(r *http.Request) string { return token } - cookie, err := s.r.CookieManager().Get(r, s.cookieName) + cookie, err := s.r.CookieManager(r.Context()).Get(r, s.cookieName) if err != nil { return "" } @@ -133,7 +133,7 @@ func (s *ManagerHTTP) PurgeFromRequest(ctx context.Context, w http.ResponseWrite return errors.WithStack(s.r.SessionPersister().RevokeSessionByToken(ctx, token)) } - cookie, _ := s.r.CookieManager().Get(r, s.cookieName) + cookie, _ := s.r.CookieManager(r.Context()).Get(r, s.cookieName) token, ok := cookie.Values["session_token"].(string) if !ok { return nil diff --git a/x/provider.go b/x/provider.go index 6b3054e05268..e194b276b704 100644 --- a/x/provider.go +++ b/x/provider.go @@ -19,6 +19,6 @@ type WriterProvider interface { } type CookieProvider interface { - CookieManager() sessions.Store + CookieManager(ctx context.Context) sessions.Store ContinuityCookieManager(ctx context.Context) sessions.Store }