Skip to content

Commit

Permalink
fix: remove stray non-ctx configs
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Feb 8, 2021
1 parent 343d02d commit 19ed782
Show file tree
Hide file tree
Showing 47 changed files with 273 additions and 221 deletions.
2 changes: 1 addition & 1 deletion cmd/daemon/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func bgTasks(d driver.Registry, wg *sync.WaitGroup, cmd *cobra.Command, args []s

d.Logger().Println("Courier worker started.")
if err := graceful.Graceful(func() error {
return d.Courier().Work(ctx)
return d.Courier(ctx).Work(ctx)
}, func(_ cx.Context) error {
cancel()
return nil
Expand Down
7 changes: 3 additions & 4 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ type (
smtpDependencies interface {
PersistenceProvider
x.LoggingProvider
config.Provider
}
Courier struct {
Dialer *gomail.Dialer
d smtpDependencies
c *config.Config
}
Provider interface {
Courier() *Courier
Courier(ctx context.Context) *Courier
}
)

Expand All @@ -50,7 +50,6 @@ func NewSMTP(d smtpDependencies, c *config.Config) *Courier {

return &Courier{
d: d,
c: c,
Dialer: &gomail.Dialer{
/* #nosec we need to support SMTP servers without TLS */
TLSConfig: tlsConfig,
Expand Down Expand Up @@ -130,7 +129,7 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) {

switch msg.Type {
case MessageTypeEmail:
from := m.c.CourierSMTPFrom()
from := m.d.Config(ctx).CourierSMTPFrom()
gm := gomail.NewMessage()
gm.SetHeader("From", from)
gm.SetHeader("To", msg.Recipient)
Expand Down
2 changes: 1 addition & 1 deletion courier/courier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestSMTP(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(config.ViperKeyCourierSMTPURL, smtp)
conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh")
c := reg.Courier()
c := reg.Courier(context.Background())

ctx, cancel := context.WithCancel(context.Background())

Expand Down
2 changes: 1 addition & 1 deletion driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type Registry interface {
WithCSRFTokenGenerator(cg x.CSRFToken)

HealthHandler() *healthx.Handler
CookieManager() sessions.Store
CookieManager(ctx context.Context) sessions.Store
ContinuityCookieManager(ctx context.Context) sessions.Store

RegisterRoutes(public *x.RouterPublic, admin *x.RouterAdmin)
Expand Down
125 changes: 45 additions & 80 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ type RegistryDefault struct {
healthxHandler *healthx.Handler
metricsHandler *prometheus.Handler

courier *courier.Courier
persister persistence.Persister

hookVerifier *hook.Verifier
Expand All @@ -84,7 +83,6 @@ type RegistryDefault struct {
schemaHandler *schema.Handler

sessionHandler *session.Handler
sessionsStore *sessions.CookieStore
sessionManager session.Manager

passwordHasher hash.Hasher
Expand Down Expand Up @@ -117,13 +115,7 @@ type RegistryDefault struct {

selfserviceLogoutHandler *logout.Handler

selfserviceStrategies []interface{}
loginStrategies []login.Strategy
activeCredentialsCounterStrategies []identity.ActiveCredentialsCounter
registrationStrategies []registration.Strategy
profileStrategies []settings.Strategy
recoveryStrategies []recovery.Strategy
verificationStrategies []verification.Strategy
selfserviceStrategies []interface{}

buildVersion string
buildHash string
Expand Down Expand Up @@ -151,13 +143,11 @@ func (m *RegistryDefault) RegisterPublicRoutes(router *x.RouterPublic) {
m.SelfServiceErrorHandler().RegisterPublicRoutes(router)
m.SchemaHandler().RegisterPublicRoutes(router)

if m.c.SelfServiceFlowRecoveryEnabled() {
m.AllRecoveryStrategies().RegisterPublicRoutes(router)
m.RecoveryHandler().RegisterPublicRoutes(router)
}
m.AllRecoveryStrategies().RegisterPublicRoutes(router)
m.RecoveryHandler().RegisterPublicRoutes(router)

m.VerificationHandler().RegisterPublicRoutes(router)
m.VerificationStrategies().RegisterPublicRoutes(router)
m.AllVerificationStrategies().RegisterPublicRoutes(router)

m.HealthHandler().SetRoutes(router.Router, false)
}
Expand All @@ -171,13 +161,11 @@ func (m *RegistryDefault) RegisterAdminRoutes(router *x.RouterAdmin) {
m.SessionHandler().RegisterAdminRoutes(router)
m.SelfServiceErrorHandler().RegisterAdminRoutes(router)

if m.c.SelfServiceFlowRecoveryEnabled() {
m.RecoveryHandler().RegisterAdminRoutes(router)
m.AllRecoveryStrategies().RegisterAdminRoutes(router)
}
m.RecoveryHandler().RegisterAdminRoutes(router)
m.AllRecoveryStrategies().RegisterAdminRoutes(router)

m.VerificationHandler().RegisterAdminRoutes(router)
m.VerificationStrategies().RegisterAdminRoutes(router)
m.AllVerificationStrategies().RegisterAdminRoutes(router)

m.HealthHandler().SetRoutes(router.Router, true)
m.MetricsHandler().SetRoutes(router.Router)
Expand All @@ -199,7 +187,7 @@ func (m *RegistryDefault) WithLogger(l *logrusx.Logger) Registry {

func (m *RegistryDefault) LogoutHandler() *logout.Handler {
if m.selfserviceLogoutHandler == nil {
m.selfserviceLogoutHandler = logout.NewHandler(m, m.c)
m.selfserviceLogoutHandler = logout.NewHandler(m)
}
return m.selfserviceLogoutHandler
}
Expand Down Expand Up @@ -252,17 +240,15 @@ func (m *RegistryDefault) selfServiceStrategies() []interface{} {
return m.selfserviceStrategies
}

func (m *RegistryDefault) RegistrationStrategies() registration.Strategies {
if len(m.registrationStrategies) == 0 {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(registration.Strategy); ok {
if m.c.SelfServiceStrategy(string(s.ID())).Enabled {
m.registrationStrategies = append(m.registrationStrategies, s)
}
func (m *RegistryDefault) RegistrationStrategies(ctx context.Context) (registrationStrategies registration.Strategies) {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(registration.Strategy); ok {
if m.Config(ctx).SelfServiceStrategy(string(s.ID())).Enabled {
registrationStrategies = append(registrationStrategies, s)
}
}
}
return m.registrationStrategies
return
}

func (m *RegistryDefault) AllRegistrationStrategies() registration.Strategies {
Expand All @@ -276,17 +262,15 @@ func (m *RegistryDefault) AllRegistrationStrategies() registration.Strategies {
return registrationStrategies
}

func (m *RegistryDefault) LoginStrategies() login.Strategies {
if len(m.loginStrategies) == 0 {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(login.Strategy); ok {
if m.c.SelfServiceStrategy(string(s.ID())).Enabled {
m.loginStrategies = append(m.loginStrategies, s)
}
func (m *RegistryDefault) LoginStrategies(ctx context.Context) (loginStrategies login.Strategies) {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(login.Strategy); ok {
if m.Config(ctx).SelfServiceStrategy(string(s.ID())).Enabled {
loginStrategies = append(loginStrategies, s)
}
}
}
return m.loginStrategies
return
}

func (m *RegistryDefault) AllLoginStrategies() login.Strategies {
Expand All @@ -299,26 +283,13 @@ func (m *RegistryDefault) AllLoginStrategies() login.Strategies {
return loginStrategies
}

func (m *RegistryDefault) VerificationStrategies() verification.Strategies {
if len(m.verificationStrategies) == 0 {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(verification.Strategy); ok {
m.verificationStrategies = append(m.verificationStrategies, s)
}
}
}
return m.verificationStrategies
}

func (m *RegistryDefault) ActiveCredentialsCounterStrategies(ctx context.Context) []identity.ActiveCredentialsCounter {
if len(m.activeCredentialsCounterStrategies) == 0 {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(identity.ActiveCredentialsCounter); ok {
m.activeCredentialsCounterStrategies = append(m.activeCredentialsCounterStrategies, s)
}
func (m *RegistryDefault) ActiveCredentialsCounterStrategies(ctx context.Context) (activeCredentialsCounterStrategies []identity.ActiveCredentialsCounter) {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(identity.ActiveCredentialsCounter); ok {
activeCredentialsCounterStrategies = append(activeCredentialsCounterStrategies, s)
}
}
return m.activeCredentialsCounterStrategies
return
}

func (m *RegistryDefault) IdentityValidator() *identity.Validator {
Expand Down Expand Up @@ -390,30 +361,27 @@ func (m *RegistryDefault) SelfServiceErrorHandler() *errorx.Handler {
return m.errorHandler
}

func (m *RegistryDefault) CookieManager() sessions.Store {
if m.sessionsStore == nil {
cs := sessions.NewCookieStore(m.c.SecretsSession()...)
cs.Options.Secure = !m.c.IsInsecureDevMode()
cs.Options.HttpOnly = true
if m.c.SessionDomain() != "" {
cs.Options.Domain = m.c.SessionDomain()
}
func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.Store {
cs := sessions.NewCookieStore(m.Config(ctx).SecretsSession()...)
cs.Options.Secure = !m.Config(ctx).IsInsecureDevMode()
cs.Options.HttpOnly = true
if domain := m.Config(ctx).SessionDomain(); domain != "" {
cs.Options.Domain = domain
}

if m.c.SessionPath() != "" {
cs.Options.Path = m.c.SessionPath()
}
if path := m.Config(ctx).SessionPath(); path != "" {
cs.Options.Path = path
}

if m.c.SessionSameSiteMode() != 0 {
cs.Options.SameSite = m.c.SessionSameSiteMode()
}
if sameSite := m.Config(ctx).SessionSameSiteMode(); sameSite != 0 {
cs.Options.SameSite = sameSite
}

cs.Options.MaxAge = 0
if m.c.SessionPersistentCookie() {
cs.Options.MaxAge = int(m.c.SessionLifespan().Seconds())
}
m.sessionsStore = cs
cs.Options.MaxAge = 0
if m.Config(ctx).SessionPersistentCookie() {
cs.Options.MaxAge = int(m.Config(ctx).SessionLifespan().Seconds())
}
return m.sessionsStore
return cs
}

func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.Store {
Expand Down Expand Up @@ -501,7 +469,7 @@ func (m *RegistryDefault) Init(ctx context.Context) error {
}

// if dsn is memory we have to run the migrations on every start
if dbal.InMemoryDSN == m.c.DSN() {
if dbal.InMemoryDSN == m.Config(ctx).DSN() {
m.Logger().Infoln("ORY Kratos is running migrations on every startup as DSN is memory. This means your data is lost when Kratos terminates.")
if err := p.MigrateUp(ctx); err != nil {
return err
Expand All @@ -514,11 +482,8 @@ func (m *RegistryDefault) Init(ctx context.Context) error {
)
}

func (m *RegistryDefault) Courier() *courier.Courier {
if m.courier == nil {
m.courier = courier.NewSMTP(m, m.c)
}
return m.courier
func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier {
return courier.NewSMTP(m, m.Config(ctx))
}

func (m *RegistryDefault) ContinuityManager() continuity.Manager {
Expand Down
11 changes: 6 additions & 5 deletions driver/registry_default_login.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package driver

import (
"context"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow/login"
)
Expand All @@ -12,17 +13,17 @@ func (m *RegistryDefault) LoginHookExecutor() *login.HookExecutor {
return m.selfserviceLoginExecutor
}

func (m *RegistryDefault) PreLoginHooks() (b []login.PreHookExecutor) {
for _, v := range m.getHooks("", m.c.SelfServiceFlowLoginBeforeHooks()) {
func (m *RegistryDefault) PreLoginHooks(ctx context.Context) (b []login.PreHookExecutor) {
for _, v := range m.getHooks("", m.Config(ctx).SelfServiceFlowLoginBeforeHooks()) {
if hook, ok := v.(login.PreHookExecutor); ok {
b = append(b, hook)
}
}
return
}

func (m *RegistryDefault) PostLoginHooks(credentialsType identity.CredentialsType) (b []login.PostHookExecutor) {
for _, v := range m.getHooks(string(credentialsType), m.c.SelfServiceFlowLoginAfterHooks(string(credentialsType))) {
func (m *RegistryDefault) PostLoginHooks(ctx context.Context, credentialsType identity.CredentialsType) (b []login.PostHookExecutor) {
for _, v := range m.getHooks(string(credentialsType), m.Config(ctx).SelfServiceFlowLoginAfterHooks(string(credentialsType))) {
if hook, ok := v.(login.PostHookExecutor); ok {
b = append(b, hook)
}
Expand All @@ -40,7 +41,7 @@ func (m *RegistryDefault) LoginHandler() *login.Handler {

func (m *RegistryDefault) LoginFlowErrorHandler() *login.ErrorHandler {
if m.selfserviceLoginRequestErrorHandler == nil {
m.selfserviceLoginRequestErrorHandler = login.NewFlowErrorHandler(m, m.c)
m.selfserviceLoginRequestErrorHandler = login.NewFlowErrorHandler(m)
}

return m.selfserviceLoginRequestErrorHandler
Expand Down
16 changes: 7 additions & 9 deletions driver/registry_default_recovery.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package driver

import (
"context"
"github.com/ory/kratos/selfservice/flow/recovery"
)

Expand All @@ -20,25 +21,22 @@ func (m *RegistryDefault) RecoveryHandler() *recovery.Handler {
return m.selfserviceRecoveryHandler
}

func (m *RegistryDefault) RecoveryStrategies() recovery.Strategies {
if len(m.recoveryStrategies) == 0 {
func (m *RegistryDefault) RecoveryStrategies(ctx context.Context) (recoveryStrategies recovery.Strategies) {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(recovery.Strategy); ok {
if m.c.SelfServiceStrategy(s.RecoveryStrategyID()).Enabled {
m.recoveryStrategies = append(m.recoveryStrategies, s)
if m.Config(ctx).SelfServiceStrategy(s.RecoveryStrategyID()).Enabled {
recoveryStrategies = append(recoveryStrategies, s)
}
}
}
}
return m.recoveryStrategies
return
}

func (m *RegistryDefault) AllRecoveryStrategies() recovery.Strategies {
var recoveryStrategies []recovery.Strategy
func (m *RegistryDefault) AllRecoveryStrategies() (recoveryStrategies recovery.Strategies ){
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(recovery.Strategy); ok {
recoveryStrategies = append(recoveryStrategies, s)
}
}
return recoveryStrategies
return
}
Loading

0 comments on commit 19ed782

Please sign in to comment.