diff --git a/codec.go b/codec.go new file mode 100644 index 0000000..be7e6cb --- /dev/null +++ b/codec.go @@ -0,0 +1,49 @@ +package scs + +import ( + "bytes" + "encoding/gob" + "time" +) + +// Codec is the interface for encoding/decoding session data to and from a byte +// slice for use by the session store. +type Codec interface { + Encode(deadline time.Time, values map[string]interface{}) ([]byte, error) + Decode([]byte) (deadline time.Time, values map[string]interface{}, err error) +} + +type gobCodec struct{} + +func (gobCodec) Encode(deadline time.Time, values map[string]interface{}) ([]byte, error) { + aux := &struct { + Deadline time.Time + Values map[string]interface{} + }{ + Deadline: deadline, + Values: values, + } + + var b bytes.Buffer + err := gob.NewEncoder(&b).Encode(&aux) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +func (gobCodec) Decode(b []byte) (time.Time, map[string]interface{}, error) { + aux := &struct { + Deadline time.Time + Values map[string]interface{} + }{} + + r := bytes.NewReader(b) + err := gob.NewDecoder(r).Decode(&aux) + if err != nil { + return time.Time{}, nil, err + } + + return aux.Deadline, aux.Values, nil +} diff --git a/data.go b/data.go index e4fe95b..c18bfdc 100644 --- a/data.go +++ b/data.go @@ -1,11 +1,9 @@ package scs import ( - "bytes" "context" "crypto/rand" "encoding/base64" - "encoding/gob" "fmt" "sort" "sync" @@ -30,18 +28,18 @@ const ( ) type sessionData struct { - Deadline time.Time // Exported for gob encoding. + deadline time.Time status Status token string - Values map[string]interface{} // Exported for gob encoding. + values map[string]interface{} mu sync.Mutex } func newSessionData(lifetime time.Duration) *sessionData { return &sessionData{ - Deadline: time.Now().Add(lifetime).UTC(), + deadline: time.Now().Add(lifetime).UTC(), status: Unmodified, - Values: make(map[string]interface{}), + values: make(map[string]interface{}), } } @@ -71,7 +69,7 @@ func (s *SessionManager) Load(ctx context.Context, token string) (context.Contex status: Unmodified, token: token, } - err = sd.decode(b) + sd.deadline, sd.values, err = s.Codec.Decode(b) if err != nil { return nil, err } @@ -104,12 +102,12 @@ func (s *SessionManager) Commit(ctx context.Context) (string, time.Time, error) } } - b, err := sd.encode() + b, err := s.Codec.Encode(sd.deadline, sd.values) if err != nil { return "", time.Time{}, err } - expiry := sd.Deadline + expiry := sd.deadline if s.IdleTimeout > 0 { ie := time.Now().Add(s.IdleTimeout) if ie.Before(expiry) { @@ -143,9 +141,9 @@ func (s *SessionManager) Destroy(ctx context.Context) error { // Reset everything else to defaults. sd.token = "" - sd.Deadline = time.Now().Add(s.Lifetime).UTC() - for key := range sd.Values { - delete(sd.Values, key) + sd.deadline = time.Now().Add(s.Lifetime).UTC() + for key := range sd.values { + delete(sd.values, key) } return nil @@ -158,7 +156,7 @@ func (s *SessionManager) Put(ctx context.Context, key string, val interface{}) { sd := s.getSessionDataFromContext(ctx) sd.mu.Lock() - sd.Values[key] = val + sd.values[key] = val sd.status = Modified sd.mu.Unlock() } @@ -180,7 +178,7 @@ func (s *SessionManager) Get(ctx context.Context, key string) interface{} { sd.mu.Lock() defer sd.mu.Unlock() - return sd.Values[key] + return sd.values[key] } // Pop acts like a one-time Get. It returns the value for a given key from the @@ -193,11 +191,11 @@ func (s *SessionManager) Pop(ctx context.Context, key string) interface{} { sd.mu.Lock() defer sd.mu.Unlock() - val, exists := sd.Values[key] + val, exists := sd.values[key] if !exists { return nil } - delete(sd.Values, key) + delete(sd.values, key) sd.status = Modified return val @@ -212,12 +210,12 @@ func (s *SessionManager) Remove(ctx context.Context, key string) { sd.mu.Lock() defer sd.mu.Unlock() - _, exists := sd.Values[key] + _, exists := sd.values[key] if !exists { return } - delete(sd.Values, key) + delete(sd.values, key) sd.status = Modified } @@ -230,12 +228,12 @@ func (s *SessionManager) Clear(ctx context.Context) error { sd.mu.Lock() defer sd.mu.Unlock() - if len(sd.Values) == 0 { + if len(sd.values) == 0 { return nil } - for key := range sd.Values { - delete(sd.Values, key) + for key := range sd.values { + delete(sd.values, key) } sd.status = Modified return nil @@ -246,7 +244,7 @@ func (s *SessionManager) Exists(ctx context.Context, key string) bool { sd := s.getSessionDataFromContext(ctx) sd.mu.Lock() - _, exists := sd.Values[key] + _, exists := sd.values[key] sd.mu.Unlock() return exists @@ -259,9 +257,9 @@ func (s *SessionManager) Keys(ctx context.Context) []string { sd := s.getSessionDataFromContext(ctx) sd.mu.Lock() - keys := make([]string, len(sd.Values)) + keys := make([]string, len(sd.values)) i := 0 - for key := range sd.Values { + for key := range sd.values { keys[i] = key i++ } @@ -298,7 +296,7 @@ func (s *SessionManager) RenewToken(ctx context.Context) error { } sd.token = newToken - sd.Deadline = time.Now().Add(s.Lifetime).UTC() + sd.deadline = time.Now().Add(s.Lifetime).UTC() sd.status = Modified return nil @@ -477,21 +475,6 @@ func (s *SessionManager) getSessionDataFromContext(ctx context.Context) *session return c } -func (sd *sessionData) encode() ([]byte, error) { - var b bytes.Buffer - err := gob.NewEncoder(&b).Encode(sd) - if err != nil { - return nil, err - } - - return b.Bytes(), nil -} - -func (sd *sessionData) decode(b []byte) error { - r := bytes.NewReader(b) - return gob.NewDecoder(r).Decode(sd) -} - func generateToken() (string, error) { b := make([]byte, 32) _, err := rand.Read(b) diff --git a/data_test.go b/data_test.go index ec7c220..a50b6b9 100644 --- a/data_test.go +++ b/data_test.go @@ -26,8 +26,8 @@ func TestPut(t *testing.T) { s.Put(ctx, "foo", "bar") - if sd.Values["foo"] != "bar" { - t.Errorf("got %q: expected %q", sd.Values["foo"], "bar") + if sd.values["foo"] != "bar" { + t.Errorf("got %q: expected %q", sd.values["foo"], "bar") } if sd.status != Modified { @@ -38,7 +38,7 @@ func TestPut(t *testing.T) { func TestGet(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" + sd.values["foo"] = "bar" ctx := s.addSessionDataToContext(context.Background(), sd) str, ok := s.Get(ctx, "foo").(string) @@ -54,7 +54,7 @@ func TestGet(t *testing.T) { func TestPop(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" + sd.values["foo"] = "bar" ctx := s.addSessionDataToContext(context.Background(), sd) str, ok := s.Pop(ctx, "foo").(string) @@ -66,7 +66,7 @@ func TestPop(t *testing.T) { t.Errorf("got %q: expected %q", str, "bar") } - _, ok = sd.Values["foo"] + _, ok = sd.values["foo"] if ok { t.Errorf("got %v: expected %v", ok, false) } @@ -79,13 +79,13 @@ func TestPop(t *testing.T) { func TestRemove(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" + sd.values["foo"] = "bar" ctx := s.addSessionDataToContext(context.Background(), sd) s.Remove(ctx, "foo") - if sd.Values["foo"] != nil { - t.Errorf("got %v: expected %v", sd.Values["foo"], nil) + if sd.values["foo"] != nil { + t.Errorf("got %v: expected %v", sd.values["foo"], nil) } if sd.status != Modified { @@ -96,18 +96,18 @@ func TestRemove(t *testing.T) { func TestClear(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" - sd.Values["baz"] = "boz" + sd.values["foo"] = "bar" + sd.values["baz"] = "boz" ctx := s.addSessionDataToContext(context.Background(), sd) s.Clear(ctx) - if sd.Values["foo"] != nil { - t.Errorf("got %v: expected %v", sd.Values["foo"], nil) + if sd.values["foo"] != nil { + t.Errorf("got %v: expected %v", sd.values["foo"], nil) } - if sd.Values["baz"] != nil { - t.Errorf("got %v: expected %v", sd.Values["baz"], nil) + if sd.values["baz"] != nil { + t.Errorf("got %v: expected %v", sd.values["baz"], nil) } if sd.status != Modified { @@ -118,7 +118,7 @@ func TestClear(t *testing.T) { func TestExists(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" + sd.values["foo"] = "bar" ctx := s.addSessionDataToContext(context.Background(), sd) if !s.Exists(ctx, "foo") { @@ -133,8 +133,8 @@ func TestExists(t *testing.T) { func TestKeys(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" - sd.Values["woo"] = "waa" + sd.values["foo"] = "bar" + sd.values["woo"] = "waa" ctx := s.addSessionDataToContext(context.Background(), sd) keys := s.Keys(ctx) @@ -146,7 +146,7 @@ func TestKeys(t *testing.T) { func TestGetString(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" + sd.values["foo"] = "bar" ctx := s.addSessionDataToContext(context.Background(), sd) str := s.GetString(ctx, "foo") @@ -163,7 +163,7 @@ func TestGetString(t *testing.T) { func TestGetBool(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = true + sd.values["foo"] = true ctx := s.addSessionDataToContext(context.Background(), sd) b := s.GetBool(ctx, "foo") @@ -180,7 +180,7 @@ func TestGetBool(t *testing.T) { func TestGetInt(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = 123 + sd.values["foo"] = 123 ctx := s.addSessionDataToContext(context.Background(), sd) i := s.GetInt(ctx, "foo") @@ -197,7 +197,7 @@ func TestGetInt(t *testing.T) { func TestGetFloat(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = 123.456 + sd.values["foo"] = 123.456 ctx := s.addSessionDataToContext(context.Background(), sd) f := s.GetFloat(ctx, "foo") @@ -214,7 +214,7 @@ func TestGetFloat(t *testing.T) { func TestGetBytes(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = []byte("bar") + sd.values["foo"] = []byte("bar") ctx := s.addSessionDataToContext(context.Background(), sd) b := s.GetBytes(ctx, "foo") @@ -233,7 +233,7 @@ func TestGetTime(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = now + sd.values["foo"] = now ctx := s.addSessionDataToContext(context.Background(), sd) tm := s.GetTime(ctx, "foo") @@ -250,7 +250,7 @@ func TestGetTime(t *testing.T) { func TestPopString(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = "bar" + sd.values["foo"] = "bar" ctx := s.addSessionDataToContext(context.Background(), sd) str := s.PopString(ctx, "foo") @@ -258,7 +258,7 @@ func TestPopString(t *testing.T) { t.Errorf("got %q: expected %q", str, "bar") } - _, ok := sd.Values["foo"] + _, ok := sd.values["foo"] if ok { t.Errorf("got %v: expected %v", ok, false) } @@ -276,7 +276,7 @@ func TestPopString(t *testing.T) { func TestPopBool(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = true + sd.values["foo"] = true ctx := s.addSessionDataToContext(context.Background(), sd) b := s.PopBool(ctx, "foo") @@ -284,7 +284,7 @@ func TestPopBool(t *testing.T) { t.Errorf("got %v: expected %v", b, true) } - _, ok := sd.Values["foo"] + _, ok := sd.values["foo"] if ok { t.Errorf("got %v: expected %v", ok, false) } @@ -302,7 +302,7 @@ func TestPopBool(t *testing.T) { func TestPopInt(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = 123 + sd.values["foo"] = 123 ctx := s.addSessionDataToContext(context.Background(), sd) i := s.PopInt(ctx, "foo") @@ -310,7 +310,7 @@ func TestPopInt(t *testing.T) { t.Errorf("got %d: expected %d", i, 123) } - _, ok := sd.Values["foo"] + _, ok := sd.values["foo"] if ok { t.Errorf("got %v: expected %v", ok, false) } @@ -328,7 +328,7 @@ func TestPopInt(t *testing.T) { func TestPopFloat(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = 123.456 + sd.values["foo"] = 123.456 ctx := s.addSessionDataToContext(context.Background(), sd) f := s.PopFloat(ctx, "foo") @@ -336,7 +336,7 @@ func TestPopFloat(t *testing.T) { t.Errorf("got %f: expected %f", f, 123.456) } - _, ok := sd.Values["foo"] + _, ok := sd.values["foo"] if ok { t.Errorf("got %v: expected %v", ok, false) } @@ -354,14 +354,14 @@ func TestPopFloat(t *testing.T) { func TestPopBytes(t *testing.T) { s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = []byte("bar") + sd.values["foo"] = []byte("bar") ctx := s.addSessionDataToContext(context.Background(), sd) b := s.PopBytes(ctx, "foo") if !bytes.Equal(b, []byte("bar")) { t.Errorf("got %v: expected %v", b, []byte("bar")) } - _, ok := sd.Values["foo"] + _, ok := sd.values["foo"] if ok { t.Errorf("got %v: expected %v", ok, false) } @@ -380,7 +380,7 @@ func TestPopTime(t *testing.T) { now := time.Now() s := NewSession() sd := newSessionData(time.Hour) - sd.Values["foo"] = now + sd.values["foo"] = now ctx := s.addSessionDataToContext(context.Background(), sd) tm := s.PopTime(ctx, "foo") @@ -388,7 +388,7 @@ func TestPopTime(t *testing.T) { t.Errorf("got %v: expected %v", tm, now) } - _, ok := sd.Values["foo"] + _, ok := sd.values["foo"] if ok { t.Errorf("got %v: expected %v", ok, false) } diff --git a/session.go b/session.go index c420a2e..6a3f047 100644 --- a/session.go +++ b/session.go @@ -34,6 +34,11 @@ type SessionManager struct { // Cookie contains the configuration settings for session cookies. Cookie SessionCookie + // Codec controls the encoder/decoder used to transform session data to a + // byte slice for use by the session store. By default session data is + // encoded/decoded using encoding/gob. + Codec Codec + // contextKey is the key used to set and retrieve the session data from a // context.Context. It's automatically generated to ensure uniqueness. contextKey contextKey @@ -87,6 +92,7 @@ func New() *SessionManager { IdleTimeout: 0, Lifetime: 24 * time.Hour, Store: memstore.New(), + Codec: gobCodec{}, contextKey: generateContextKey(), Cookie: SessionCookie{ Name: "session",