Skip to content

Commit

Permalink
fix: add ability to resume continuity sessions from several cookies
Browse files Browse the repository at this point in the history
Closes #2016
Closes #1786
Closes ory-corp/cloud#1786
Closes #2108
  • Loading branch information
aeneasr committed Jan 10, 2022
1 parent 8e4b4fb commit 406eb47
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 27 deletions.
43 changes: 43 additions & 0 deletions continuity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,49 @@ func TestManager(t *testing.T) {
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)
})

t.Run("case=can deal with duplicate cookies", func(t *testing.T) {
tc := &persisterTestCase{expected: &persisterTestPayload{"bar"}}
ts := newServer(t, p, tc)
href := ts.URL + "/" + x.NewUUID().String()

res, err := http.DefaultClient.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

// We change the key to another one
href = ts.URL + "/" + x.NewUUID().String()
req := x.NewTestHTTPRequest(t, "GET", href, nil)
require.Len(t, res.Cookies(), 1)
for _, c := range res.Cookies() {
req.AddCookie(c)
}

tc.ro = []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"})}
res, err = http.DefaultClient.Do(x.NewTestHTTPRequest(t, "PUT", href, nil))
require.NoError(t, err)
require.NoError(t, res.Body.Close())
require.Equal(t, http.StatusNoContent, res.StatusCode)

require.Len(t, res.Cookies(), 1)
for _, c := range res.Cookies() {
req.AddCookie(c)
}

res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 1, "continuing the flow with a broken cookie should instruct the browser to forget it")
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)

var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(tc.expected))
body := ioutilx.MustReadAll(res.Body)
assert.JSONEq(t, b.String(), gjson.GetBytes(body, "payload").Raw, "%s", body)
assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body)
})

for k, tc := range []persisterTestCase{
{},
{
Expand Down
6 changes: 3 additions & 3 deletions driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ type Registry interface {
WithCSRFHandler(c nosurf.Handler)
WithCSRFTokenGenerator(cg x.CSRFToken)

HealthHandler(ctx context.Context) *healthx.Handler
CookieManager(ctx context.Context) sessions.Store
MetricsHandler() *prometheus.Handler
ContinuityCookieManager(ctx context.Context) sessions.Store
HealthHandler(ctx context.Context) *healthx.Handler
CookieManager(ctx context.Context) sessions.StoreExact
ContinuityCookieManager(ctx context.Context) sessions.StoreExact

RegisterRoutes(ctx context.Context, public *x.RouterPublic, admin *x.RouterAdmin)
RegisterPublicRoutes(ctx context.Context, public *x.RouterPublic)
Expand Down
4 changes: 2 additions & 2 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func (m *RegistryDefault) SelfServiceErrorHandler() *errorx.Handler {
return m.errorHandler
}

func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.Store {
func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.StoreExact {
cs := sessions.NewCookieStore(m.Config(ctx).SecretsSession()...)
cs.Options.Secure = !m.Config(ctx).IsInsecureDevMode()
cs.Options.HttpOnly = true
Expand All @@ -447,7 +447,7 @@ func (m *RegistryDefault) CookieManager(ctx context.Context) sessions.Store {
return cs
}

func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.Store {
func (m *RegistryDefault) ContinuityCookieManager(ctx context.Context) sessions.StoreExact {
// To support hot reloading, this can not be instantiated only once.
cs := sessions.NewCookieStore(m.Config(ctx).SecretsSession()...)
cs.Options.Secure = !m.Config(ctx).IsInsecureDevMode()
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ replace (
go.mongodb.org/mongo-driver => go.mongodb.org/mongo-driver v1.4.6
golang.org/x/sys => golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac
gopkg.in/DataDog/dd-trace-go.v1 => gopkg.in/DataDog/dd-trace-go.v1 v1.27.1-0.20201005154917-54b73b3e126a
github.com/gorilla/sessions => ../../gorilla/sessions
)

require (
Expand Down
45 changes: 26 additions & 19 deletions x/cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ import (
)

// SessionPersistValues adds values to the session store and persists the changes.
func SessionPersistValues(w http.ResponseWriter, r *http.Request, s sessions.Store, id string, values map[string]interface{}) error {
func SessionPersistValues(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id string, values map[string]interface{}) error {
// The error does not matter because in the worst case we're re-writing the session cookie.
cookie, err := s.Get(r, id)
if err != nil {
cookie = sessions.NewSession(s, id)
}

cookie, _ := s.Get(r, id)
for k, v := range values {
cookie.Values[k] = v
}
Expand All @@ -24,32 +20,43 @@ func SessionPersistValues(w http.ResponseWriter, r *http.Request, s sessions.Sto

// SessionGetString returns a string for the given id and key or an error if the session is invalid,
// the key does not exist, or the key value is not a string.
func SessionGetString(r *http.Request, s sessions.Store, id string, key interface{}) (string, error) {
cookie, err := s.Get(r, id)
if err != nil {
return "", errors.WithStack(err)
func SessionGetString(r *http.Request, s sessions.StoreExact, id string, key interface{}) (string, error) {
check := func(v map[interface{}]interface{}) (string, error) {
vv, ok := v[key]
if !ok {
return "", errors.Errorf("key %s does not exist in cookie: %+v", key, id)
} else if vvv, ok := vv.(string); !ok {
return "", errors.Errorf("value of key %s is not of type string in cookie", key)
} else {
return vvv, nil
}
}

if v, ok := cookie.Values[key]; !ok {
return "", errors.Errorf("key %s does not exist in cookie: %+v", key, cookie.Values)
} else if vv, ok := v.(string); !ok {
return "", errors.Errorf("value of key %s is not of type string in cookie", key)
} else {
return vv, nil
var exactErr error
cookie, err := s.GetExact(r, id, func(s *sessions.Session) bool {
_, exactErr = check(s.Values)
return exactErr == nil
})
if err != nil {
return "", err
} else if exactErr != nil {
return "", exactErr
}

return check(cookie.Values)
}

// SessionGetStringOr returns a string for the given id and key or the fallback value if the session is invalid,
// the key does not exist, or the key value is not a string.
func SessionGetStringOr(r *http.Request, s sessions.Store, id, key, fallback string) string {
func SessionGetStringOr(r *http.Request, s sessions.StoreExact, id, key, fallback string) string {
v, err := SessionGetString(r, s, id, key)
if err != nil {
return fallback
}
return v
}

func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.Store, id string) error {
func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id string) error {
cookie, err := s.Get(r, id)
if err == nil && cookie.IsNew {
// No cookie was sent in the request. We have nothing to do.
Expand All @@ -61,7 +68,7 @@ func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.Store, id s
return errors.WithStack(cookie.Save(r, w))
}

func SessionUnsetKey(w http.ResponseWriter, r *http.Request, s sessions.Store, id, key string) error {
func SessionUnsetKey(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id, key string) error {
cookie, err := s.Get(r, id)
if err == nil && cookie.IsNew {
// No cookie was sent in the request. We have nothing to do.
Expand Down
43 changes: 42 additions & 1 deletion x/cookie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,48 @@ func TestSession(t *testing.T) {

w.WriteHeader(http.StatusNoContent)
})
mr(t, id)
})

t.Run("case=GetStringMultipleCookies", func(t *testing.T) {
id := "get-string-multiple"

router.GET("/set/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
"multiple-string-1": "foo",
}))
require.NoError(t, SessionPersistValues(w, r, s, sid, map[string]interface{}{
"multiple-string-2": "bar",
}))
isExpiryCorrect(t, r)
w.WriteHeader(http.StatusNoContent)
})

router.GET("/get/"+id, func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
got, err := SessionGetString(r, s, sid, "multiple-string-1")
require.NoError(t, err)
assert.EqualValues(t, "foo", got)

got, err = SessionGetString(r, s, sid, "multiple-string-2")
require.NoError(t, err)
assert.EqualValues(t, "bar", got)

w.WriteHeader(http.StatusNoContent)
})

res, err := http.DefaultClient.Get(ts.URL + "/set/" + id)
require.NoError(t, err)
require.EqualValues(t, http.StatusNoContent, res.StatusCode)
require.NoError(t, res.Body.Close())

req, _ := http.NewRequest("GET", ts.URL+"/get/"+id, nil)
for _, c := range res.Cookies() {
req.AddCookie(c)
}

res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.EqualValues(t, http.StatusNoContent, res.StatusCode)
require.NoError(t, res.Body.Close())
})

t.Run("case=GetStringOr", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions x/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ type WriterProvider interface {
}

type CookieProvider interface {
CookieManager(ctx context.Context) sessions.Store
ContinuityCookieManager(ctx context.Context) sessions.Store
CookieManager(ctx context.Context) sessions.StoreExact
ContinuityCookieManager(ctx context.Context) sessions.StoreExact
}

type TracingProvider interface {
Expand Down

0 comments on commit 406eb47

Please sign in to comment.