Skip to content

Commit

Permalink
feat: enable simultaneous auth flows by creating flow related csrf co…
Browse files Browse the repository at this point in the history
…okie names
  • Loading branch information
aarmam committed Apr 6, 2022
1 parent 00100a1 commit 3b6c96f
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 20 deletions.
6 changes: 4 additions & 2 deletions consent/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package consent

import (
"net/http"
"time"

"github.com/ory/x/errorsx"

Expand Down Expand Up @@ -63,19 +64,20 @@ func matchScopes(scopeStrategy fosite.ScopeStrategy, previousConsent []HandledCo
return nil
}

func createCsrfSession(w http.ResponseWriter, r *http.Request, store sessions.Store, name, csrf string, secure bool, sameSiteMode http.SameSite, sameSiteLegacyWorkaround bool) error {
func createCsrfSession(w http.ResponseWriter, r *http.Request, store sessions.Store, name, csrf string, secure bool, sameSiteMode http.SameSite, sameSiteLegacyWorkaround bool, maxAge time.Duration) error {
// Errors can be ignored here, because we always get a session session back. Error typically means that the
// session doesn't exist yet.
session, _ := store.Get(r, CookieName(secure, name))
session.Values["csrf"] = csrf
session.Options.HttpOnly = true
session.Options.Secure = secure
session.Options.SameSite = sameSiteMode
session.Options.MaxAge = int(maxAge.Seconds())
if err := session.Save(r, w); err != nil {
return errorsx.WithStack(err)
}
if sameSiteMode == http.SameSiteNoneMode && sameSiteLegacyWorkaround {
return createCsrfSession(w, r, store, legacyCsrfSessionName(name), csrf, secure, 0, false)
return createCsrfSession(w, r, store, legacyCsrfSessionName(name), csrf, secure, 0, false, maxAge)
}
return nil
}
Expand Down
39 changes: 27 additions & 12 deletions consent/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
Expand Down Expand Up @@ -127,7 +128,7 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: cookieAuthenticationCSRFName,
name: "oauth2_authentication_csrf",
csrfValue: "WRONG-CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -139,7 +140,7 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: cookieAuthenticationCSRFName,
name: "oauth2_authentication_csrf",
csrfValue: "WRONG-CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -151,7 +152,7 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: cookieAuthenticationCSRFName,
name: "oauth2_authentication_csrf",
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -163,7 +164,7 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: cookieAuthenticationCSRFName,
name: "oauth2_authentication_csrf",
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -175,7 +176,7 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: legacyCsrfSessionName(cookieAuthenticationCSRFName),
name: legacyCsrfSessionName("oauth2_authentication_csrf"),
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -187,7 +188,7 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: legacyCsrfSessionName(cookieAuthenticationCSRFName),
name: legacyCsrfSessionName("oauth2_authentication_csrf"),
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -199,12 +200,12 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: cookieAuthenticationCSRFName,
name: "oauth2_authentication_csrf",
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteNoneMode,
},
{
name: legacyCsrfSessionName(cookieAuthenticationCSRFName),
name: legacyCsrfSessionName("oauth2_authentication_csrf"),
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -216,12 +217,12 @@ func TestValidateCsrfSession(t *testing.T) {
{
cookies: []cookie{
{
name: cookieAuthenticationCSRFName,
name: "oauth2_authentication_csrf",
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteNoneMode,
},
{
name: legacyCsrfSessionName(cookieAuthenticationCSRFName),
name: legacyCsrfSessionName("oauth2_authentication_csrf"),
csrfValue: "CSRF-VALUE",
sameSite: http.SameSiteDefaultMode,
},
Expand All @@ -246,7 +247,7 @@ func TestValidateCsrfSession(t *testing.T) {
assert.NoError(t, err, "failed to save cookie %s", c.name)
}

err := validateCsrfSession(r, store, cookieAuthenticationCSRFName, tc.csrfValue, tc.sameSiteLegacyWorkaround, true)
err := validateCsrfSession(r, store, "oauth2_authentication_csrf", tc.csrfValue, tc.sameSiteLegacyWorkaround, true)
if tc.expectError {
assert.Error(t, err)
} else {
Expand All @@ -261,81 +262,94 @@ func TestCreateCsrfSession(t *testing.T) {
httpOnly bool
secure bool
sameSite http.SameSite
maxAge int
}
for _, tc := range []struct {
name string
secure bool
sameSite http.SameSite
maxAge time.Duration
sameSiteLegacyWorkaround bool
expectedCookies map[string]cookie
}{
{
name: "csrf_default",
secure: true,
sameSite: http.SameSiteDefaultMode,
maxAge: 10 * time.Second,
sameSiteLegacyWorkaround: false,
expectedCookies: map[string]cookie{
"csrf_default": {
httpOnly: true,
secure: true,
sameSite: 0, // see https://golang.org/doc/go1.16#net/http
maxAge: 10,
},
},
},
{
name: "csrf_lax_insecure",
secure: false,
sameSite: http.SameSiteLaxMode,
maxAge: 20 * time.Second,
sameSiteLegacyWorkaround: false,
expectedCookies: map[string]cookie{
"csrf_lax_insecure_insecure": {
httpOnly: true,
secure: false,
sameSite: http.SameSiteLaxMode,
maxAge: 20,
},
},
},
{
name: "csrf_none",
secure: true,
sameSite: http.SameSiteNoneMode,
maxAge: 30 * time.Second,
sameSiteLegacyWorkaround: false,
expectedCookies: map[string]cookie{
"csrf_none": {
httpOnly: true,
secure: true,
sameSite: http.SameSiteNoneMode,
maxAge: 30,
},
},
},
{
name: "csrf_none_fallback",
secure: true,
sameSite: http.SameSiteNoneMode,
maxAge: 40 * time.Second,
sameSiteLegacyWorkaround: true,
expectedCookies: map[string]cookie{
"csrf_none_fallback": {
httpOnly: true,
secure: true,
sameSite: http.SameSiteNoneMode,
maxAge: 40,
},
"csrf_none_fallback_legacy": {
httpOnly: true,
secure: true,
sameSite: 0,
maxAge: 40,
},
},
},
{
name: "csrf_strict_fallback_ignored",
secure: true,
sameSite: http.SameSiteStrictMode,
maxAge: 50 * time.Second,
sameSiteLegacyWorkaround: true,
expectedCookies: map[string]cookie{
"csrf_strict_fallback_ignored": {
httpOnly: true,
secure: true,
sameSite: http.SameSiteStrictMode,
maxAge: 50,
},
},
},
Expand All @@ -345,7 +359,7 @@ func TestCreateCsrfSession(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)

err := createCsrfSession(rr, req, store, tc.name, "value", tc.secure, tc.sameSite, tc.sameSiteLegacyWorkaround)
err := createCsrfSession(rr, req, store, tc.name, "value", tc.secure, tc.sameSite, tc.sameSiteLegacyWorkaround, tc.maxAge)
assert.NoError(t, err)

cookies := make(map[string]cookie)
Expand All @@ -354,6 +368,7 @@ func TestCreateCsrfSession(t *testing.T) {
httpOnly: c.HttpOnly,
secure: c.Secure,
sameSite: c.SameSite,
maxAge: c.MaxAge,
}
}
assert.Equal(t, tc.expectedCookies, cookies)
Expand Down
21 changes: 15 additions & 6 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ package consent

import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/twmb/murmur3"

"github.com/ory/hydra/driver/config"

"github.com/ory/x/errorsx"
Expand Down Expand Up @@ -62,8 +65,8 @@ const (
CookieAuthenticationName = "oauth2_authentication_session"
CookieAuthenticationSIDName = "sid"

cookieAuthenticationCSRFName = "oauth2_authentication_csrf"
cookieConsentCSRFName = "oauth2_consent_csrf"
cookieAuthenticationCSRFNamePrefix = "oauth2_authentication_csrf_"
cookieConsentCSRFNamePrefix = "oauth2_consent_csrf_"
)

type DefaultStrategy struct {
Expand Down Expand Up @@ -254,6 +257,7 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r
}

// Set the session
cl := sanitizeClientFromRequest(ar)
if err := s.r.ConsentManager().CreateLoginRequest(
r.Context(),
&LoginRequest{
Expand All @@ -264,7 +268,7 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r
RequestedScope: []string(ar.GetRequestedScopes()),
RequestedAudience: []string(ar.GetRequestedAudience()),
Subject: subject,
Client: sanitizeClientFromRequest(ar),
Client: cl,
RequestURL: iu.String(),
AuthenticatedAt: sqlxx.NullTime(authenticatedAt),
RequestedAt: time.Now().Truncate(time.Second).UTC(),
Expand All @@ -281,7 +285,8 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r
return errorsx.WithStack(err)
}

if err := createCsrfSession(w, r, s.r.CookieStore(), cookieAuthenticationCSRFName, csrf, s.c.TLS(config.PublicInterface).Enabled(), s.c.CookieSameSiteMode(), s.c.CookieSameSiteLegacyWorkaround()); err != nil {
cookieAuthenticationCSRFName := cookieAuthenticationCSRFNamePrefix + fmt.Sprint(murmur3.Sum32([]byte(cl.OutfacingID)))
if err := createCsrfSession(w, r, s.r.CookieStore(), cookieAuthenticationCSRFName, csrf, s.c.TLS(config.PublicInterface).Enabled(), s.c.CookieSameSiteMode(), s.c.CookieSameSiteLegacyWorkaround(), s.c.ConsentRequestMaxAge()); err != nil {
return errorsx.WithStack(err)
}

Expand Down Expand Up @@ -357,6 +362,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The login request has expired. Please try again."))
}

cookieAuthenticationCSRFName := cookieAuthenticationCSRFNamePrefix + fmt.Sprint(murmur3.Sum32([]byte(session.LoginRequest.Client.OutfacingID)))
if err := validateCsrfSession(r, s.r.CookieStore(), cookieAuthenticationCSRFName, session.LoginRequest.CSRF, s.c.CookieSameSiteLegacyWorkaround(), s.c.TLS(config.PublicInterface).Enabled()); err != nil {
return nil, err
}
Expand Down Expand Up @@ -544,6 +550,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R
challenge := strings.Replace(uuid.New(), "-", "", -1)
csrf := strings.Replace(uuid.New(), "-", "", -1)

cl := sanitizeClientFromRequest(ar)
if err := s.r.ConsentManager().CreateConsentRequest(
r.Context(),
&ConsentRequest{
Expand All @@ -556,7 +563,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R
RequestedScope: []string(ar.GetRequestedScopes()),
RequestedAudience: []string(ar.GetRequestedAudience()),
Subject: as.Subject,
Client: sanitizeClientFromRequest(ar),
Client: cl,
RequestURL: as.LoginRequest.RequestURL,
AuthenticatedAt: as.AuthenticatedAt,
RequestedAt: as.RequestedAt,
Expand All @@ -570,7 +577,8 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R
return errorsx.WithStack(err)
}

if err := createCsrfSession(w, r, s.r.CookieStore(), cookieConsentCSRFName, csrf, s.c.TLS(config.PublicInterface).Enabled(), s.c.CookieSameSiteMode(), s.c.CookieSameSiteLegacyWorkaround()); err != nil {
cookieConsentCSRFName := cookieConsentCSRFNamePrefix + fmt.Sprint(murmur3.Sum32([]byte(cl.OutfacingID)))
if err := createCsrfSession(w, r, s.r.CookieStore(), cookieConsentCSRFName, csrf, s.c.TLS(config.PublicInterface).Enabled(), s.c.CookieSameSiteMode(), s.c.CookieSameSiteLegacyWorkaround(), s.c.ConsentRequestMaxAge()); err != nil {
return errorsx.WithStack(err)
}

Expand Down Expand Up @@ -605,6 +613,7 @@ func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request,
return nil, errorsx.WithStack(fosite.ErrServerError.WithHint("The authenticatedAt value was not set."))
}

cookieConsentCSRFName := cookieConsentCSRFNamePrefix + fmt.Sprint(murmur3.Sum32([]byte(session.ConsentRequest.Client.OutfacingID)))
if err := validateCsrfSession(r, s.r.CookieStore(), cookieConsentCSRFName, session.ConsentRequest.CSRF, s.c.CookieSameSiteLegacyWorkaround(), s.c.TLS(config.PublicInterface).Enabled()); err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 3b6c96f

Please sign in to comment.