Skip to content

Commit

Permalink
feat(oidc): Make state timeout duration configurable to support long …
Browse files Browse the repository at this point in the history
…taking sign ins (#362)

* feat(oidc): Make state timeout duration configurable to support long taking sign in

* chore: CR feedback
  • Loading branch information
carstendietrich authored Oct 18, 2023
1 parent 258bd2d commit f410758
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 16 deletions.
1 change: 1 addition & 0 deletions core/auth/oauth/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ core: auth: {
}
enableEndSessionEndpoint: bool | *true
overrideIssuerURL: string | *""
stateLifeTime: string | *"30m"
}
}
`
Expand Down
20 changes: 19 additions & 1 deletion core/auth/oauth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type (
oidcConfig oidcConfig
verifierConfigurator []func(*oidc.Config)
callbackErrorHandler CallbackErrorHandler
stateTimeout *time.Duration
}

sessionData struct {
Expand Down Expand Up @@ -85,6 +86,7 @@ type (
} `json:"claims"`
EnableEndSessionEndpoint bool `json:"enableEndSessionEndpoint"`
OverrideIssuerURL string `json:"overrideIssuerURL"`
StateLifeTime string `json:"stateLifeTime"`
}
)

Expand Down Expand Up @@ -140,6 +142,15 @@ func oidcFactory(cfg config.Map) (auth.RequestIdentifier, error) {
authCodeOptions = append(authCodeOptions, oauth2AuthCodeOption{authCodeOption: authCodeOption})
}

stateTimeout := defaultStateTimeout

if oidcConfig.StateLifeTime != "" {
stateTimeout, err = time.ParseDuration(oidcConfig.StateLifeTime)
if err != nil {
panic("invalid value for oidc broker config stateLifeTime")
}
}

return &openIDIdentifier{
oauth2Config: &oauth2.Config{
ClientID: oidcConfig.ClientID,
Expand All @@ -152,6 +163,7 @@ func oidcFactory(cfg config.Map) (auth.RequestIdentifier, error) {
provider: provider,
authcodeOptions: authCodeOptions,
oidcConfig: oidcConfig,
stateTimeout: &stateTimeout,
}, nil
}

Expand Down Expand Up @@ -362,7 +374,7 @@ type StateEntry struct {
TS time.Time
}

const stateTimeout = time.Minute * 30
const defaultStateTimeout = time.Minute * 30

func init() {
gob.Register([]StateEntry(nil))
Expand All @@ -374,6 +386,12 @@ const sessionStatesKey = "states"
var now = time.Now

func (i *openIDIdentifier) validateSessionCode(request *web.Request, code string) bool {
stateTimeout := defaultStateTimeout

if i.stateTimeout != nil {
stateTimeout = *i.stateTimeout
}

sessionStates, ok := request.Session().Load(i.sessionCode(sessionStatesKey))
if !ok {
return false
Expand Down
78 changes: 63 additions & 15 deletions core/auth/oauth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,27 @@ func (m *mockCallbackErrorHandler) Handle(_ context.Context, _ string, _ *web.Re

var _ CallbackErrorHandler = &mockCallbackErrorHandler{}

//nolint:paralleltest // time.Now lives in global var `now` running this parallel will cause race cond
func TestParallelStateRaceConditions(t *testing.T) {
identifier := &openIDIdentifier{
authCodeOptionerProvider: func() []AuthCodeOptioner { return nil },
oauth2Config: &oauth2.Config{},
reverseRouter: new(mockRouter),
responder: &web.Responder{},
}
//nolint:paralleltest // time.Now lives in global var `now` running this parallel will cause race cond
t.Run("test states", func(t *testing.T) {
identifier := &openIDIdentifier{
authCodeOptionerProvider: func() []AuthCodeOptioner { return nil },
oauth2Config: &oauth2.Config{},
reverseRouter: new(mockRouter),
responder: &web.Responder{},
}

session := web.EmptySession()
session := web.EmptySession()

resp := identifier.Authenticate(context.Background(), web.CreateRequest(nil, session))
state1 := resp.(*web.URLRedirectResponse).URL.Query().Get("state")
resp = identifier.Authenticate(context.Background(), web.CreateRequest(nil, session))
state2 := resp.(*web.URLRedirectResponse).URL.Query().Get("state")
resp := identifier.Authenticate(context.Background(), web.CreateRequest(nil, session))
state1 := resp.(*web.URLRedirectResponse).URL.Query().Get("state")
resp = identifier.Authenticate(context.Background(), web.CreateRequest(nil, session))
state2 := resp.(*web.URLRedirectResponse).URL.Query().Get("state")

request, err := http.NewRequest(http.MethodGet, "http://example.com/callback", nil)
assert.NoError(t, err)
request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/callback", nil)
assert.NoError(t, err)

t.Run("test states", func(t *testing.T) {
request.URL.RawQuery = url.Values{"state": []string{"invalid-state"}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
errResp := resp.(*web.ServerErrorResponse)
Expand All @@ -94,7 +96,20 @@ func TestParallelStateRaceConditions(t *testing.T) {
assert.EqualError(t, errResp.Error, "state mismatch")
})

t.Run("test timeshift", func(t *testing.T) {
//nolint:paralleltest // time.Now lives in global var `now` running this parallel will cause race cond
t.Run("test default time shift", func(t *testing.T) {
identifier := &openIDIdentifier{
authCodeOptionerProvider: func() []AuthCodeOptioner { return nil },
oauth2Config: &oauth2.Config{},
reverseRouter: new(mockRouter),
responder: &web.Responder{},
}

request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/callback", nil)
assert.NoError(t, err)

session := web.EmptySession()

resp := identifier.Authenticate(context.Background(), web.CreateRequest(nil, session))
state1 := resp.(*web.URLRedirectResponse).URL.Query().Get("state")

Expand All @@ -109,6 +124,37 @@ func TestParallelStateRaceConditions(t *testing.T) {

now = time.Now
})

//nolint:paralleltest // time.Now lives in global var `now` running this parallel will cause race cond
t.Run("test custom time shift", func(t *testing.T) {
identifier := &openIDIdentifier{
authCodeOptionerProvider: func() []AuthCodeOptioner { return nil },
oauth2Config: &oauth2.Config{},
reverseRouter: new(mockRouter),
responder: &web.Responder{},
}

request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/callback", nil)
assert.NoError(t, err)

session := web.EmptySession()

resp := identifier.Authenticate(context.Background(), web.CreateRequest(nil, session))
state1 := resp.(*web.URLRedirectResponse).URL.Query().Get("state")
oneHour := time.Hour
identifier.stateTimeout = &oneHour

now = func() time.Time {
return time.Now().Add(35 * time.Minute)
}

request.URL.RawQuery = url.Values{"state": []string{state1}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
errResp := resp.(*web.ServerErrorResponse)
assert.NotContains(t, errResp.Error.Error(), "state mismatch")

now = time.Now
})
}

type testOidcProvider struct {
Expand Down Expand Up @@ -303,6 +349,8 @@ func TestOidcCallback(t *testing.T) {
}

func Test_openIDIdentifier_RefreshIdentity(t *testing.T) {
t.Parallel()

var identifier auth.RequestIdentifier = &openIDIdentifier{broker: "broker"}
session := web.EmptySession()
session.Store("core.auth.oidc.broker.sessiondata", sessionData{
Expand Down

0 comments on commit f410758

Please sign in to comment.