diff --git a/util/testutil/session.go b/util/testutil/session.go new file mode 100644 index 00000000..53af10df --- /dev/null +++ b/util/testutil/session.go @@ -0,0 +1,187 @@ +package testutil + +import ( + "context" + stderrors "errors" + "sync/atomic" + + "github.com/samber/lo" + "goyave.dev/goyave/v5/util/errors" + "goyave.dev/goyave/v5/util/session" +) + +const ( + SessionCreated uint32 = iota + SessionCommitted + SessionRolledBack +) + +var ( + ErrSessionEnded = stderrors.New("testutil.Session: session already ended") + ErrEndRootSession = stderrors.New("testutil.Session: cannot commit/rollback root session") + ErrChildRunning = stderrors.New("testutil.Session: cannot commit/rollback if a child session is still running") + ErrNotParentContext = stderrors.New("testutil.Session: cannot create a child session with an unrelated context. Parent context should be the context or a child context of the parent session.") +) + +// ctxKey the key used to store the `*Session` into a context value. +type ctxKey struct{} + +// Session is an advanced mock for the `session.Session` interface. This implementation is designed to +// provide a realistic, observable transaction system and help identify incorrect usage. +// +// Each transaction created with this implementation has a cancellable context created from its parent. +// The context is canceled when the session is committed or rolled back. This helps detecting cases where +// your code tries to use a terminated transaction. +// +// A transaction cannot be committed or rolled back several times. It cannot be committed after being rolled back +// or the other way around. +// +// For nested transactions, all child sessions should be ended (committed or rolled back) before the parent can be ended. +// Moreover, the context given on `Begin()` should be the context or a child context of the parent session. +// +// A child session cannot be created or committed if its parent context is done. +// +// The root transaction cannot be committed or rolledback. This helps detecting cases where your codes +// uses the root session without creating a child session. +// +// This implementation is not meant to be used concurrently. You should create a new instance for each test. +type Session struct { + ctx context.Context + cancel func() + children []*Session + status atomic.Uint32 +} + +// NewTestSession create a new root session with the `context.Background()`. +func NewTestSession() *Session { + return &Session{ + ctx: context.Background(), + cancel: nil, + children: []*Session{}, + status: atomic.Uint32{}, + } +} + +// Begin returns a new child session with the given context. +func (s *Session) Begin(ctx context.Context) (session.Session, error) { + if s.status.Load() != SessionCreated { + return nil, ErrSessionEnded + } + if err := ctx.Err(); err != nil { + return nil, errors.New(err) + } + + parent := ctx.Value(ctxKey{}) + if s.cancel != nil && parent != s { + // Not the root session, we are creating a nested transaction + // The given context should belong to the parent session. + return nil, errors.New(ErrNotParentContext) + } + + childCtx, cancel := context.WithCancel(ctx) + tx := &Session{ + cancel: cancel, + children: []*Session{}, + status: atomic.Uint32{}, + } + tx.ctx = context.WithValue(childCtx, ctxKey{}, tx) + if parent != nil { + parentSession := parent.(*Session) + parentSession.children = append(parentSession.children, tx) + } else { + s.children = append(s.children, tx) + } + return tx, nil +} + +// Transaction executes a transaction. If the given function returns an error, the transaction +// is rolled back. Otherwise it is automatically committed before `Transaction()` returns. +// The underlying transaction mechanism is injected into the context as a value. +func (s *Session) Transaction(ctx context.Context, f func(context.Context) error) error { + tx, err := s.Begin(ctx) + if err != nil { + return errors.New(err) + } + + err = errors.New(f(tx.Context())) + if err != nil { + rollbackErr := errors.New(tx.Rollback()) + return errors.New([]error{err, rollbackErr}) + } + return errors.New(tx.Commit()) +} + +// Rollback the transaction. For this test utility, it only sets the sessions status to `SessionRollbedBack`. +// If the session status is not `testutil.SessionCreated`, returns `testutil.ErrSessionEnded`. +// +// It is not possible to roll back the root session. In this case, `testutil.ErrEndRootSession` is returned. +// +// This action is final. +func (s *Session) Rollback() error { + if s.cancel == nil { + return errors.New(ErrEndRootSession) + } + if s.hasRunningChild() { + return errors.New(ErrChildRunning) + } + swapped := s.status.CompareAndSwap(SessionCreated, SessionRolledBack) + if !swapped { + return errors.New(ErrSessionEnded) + } + s.cancel() + return nil +} + +// Commit the transaction. For this test utility, it only sets the sessions status to `SessionCommitted`. +// If the session status is not `testutil.SessionCreated`, returns `testutil.ErrSessionEnded`. +// +// It is not possible to commit the root session. In this case, `testutil.ErrEndRootSession` is returned. +// +// This action is final. +func (s *Session) Commit() error { + if s.cancel == nil { + return errors.New(ErrEndRootSession) + } + if err := s.ctx.Err(); err != nil { + if s.Status() != SessionCreated { + return errors.New([]error{err, ErrSessionEnded}) + } + return errors.New(err) + } + if s.hasRunningChild() { + return errors.New(ErrChildRunning) + } + swapped := s.status.CompareAndSwap(SessionCreated, SessionCommitted) + if !swapped { + return errors.New(ErrSessionEnded) + } + s.cancel() + return nil +} + +// Context returns the session's context. +func (s *Session) Context() context.Context { + return s.ctx +} + +// Status returns the session status. The value will be equal to `testutil.SessionCreated`, +// `testutil.SessionCommitted` or `testutil.SessionRolledBack`. +func (s *Session) Status() uint32 { + return s.status.Load() +} + +// Children returns the direct child sessions. You can use the returned values for your test assertions. +// The returned values are sorted in the order in which the child transactions were started. +// +// To access nested transactions, call `Children()` on the returned values. +// +// This method always returns a non-nil value but can return an empty slice. +func (s *Session) Children() []*Session { + return s.children +} + +func (s *Session) hasRunningChild() bool { + return lo.SomeBy(s.children, func(child *Session) bool { + return child.Status() == SessionCreated + }) +} diff --git a/util/testutil/session_test.go b/util/testutil/session_test.go new file mode 100644 index 00000000..06b8185d --- /dev/null +++ b/util/testutil/session_test.go @@ -0,0 +1,297 @@ +package testutil + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testCtxKey struct{} + +func TestSession(t *testing.T) { + t.Run("new", func(t *testing.T) { + session := NewTestSession() + assert.Equal(t, context.Background(), session.Context()) + assert.Nil(t, session.cancel) + assert.NotNil(t, session.children) + assert.Empty(t, session.children) + assert.Equal(t, SessionCreated, session.Status()) + }) + + t.Run("manual_commit", func(t *testing.T) { + session := NewTestSession() + ctx := context.WithValue(context.Background(), testCtxKey{}, "test-value") + child, err := session.Begin(ctx) + require.NoError(t, err) + assert.Equal(t, SessionCreated, session.Status()) // Parent session status unchanged + + // Check the child session has been correctly created + childSession, ok := child.(*Session) + require.True(t, ok) + assert.Equal(t, SessionCreated, childSession.Status()) + assert.NotNil(t, childSession.cancel) + + // Uses the parent context + childCtx := childSession.Context() + assert.NotEqual(t, ctx, childCtx) + assert.Equal(t, "test-value", childCtx.Value(testCtxKey{})) + assert.Equal(t, child, childCtx.Value(ctxKey{})) + + err = childSession.Commit() + require.NoError(t, err) + assert.Equal(t, SessionCommitted, childSession.Status()) + assert.ErrorIs(t, childCtx.Err(), context.Canceled) // Make sure the child context has been canceled + + assert.NoError(t, session.Context().Err()) // The parent context should not be canceled + assert.Equal(t, SessionCreated, session.Status()) // Parent session status unchanged + }) + + t.Run("manual_rollback", func(t *testing.T) { + session := NewTestSession() + ctx := context.WithValue(context.Background(), testCtxKey{}, "test-value") + child, err := session.Begin(ctx) + require.NoError(t, err) + + childSession, ok := child.(*Session) + require.True(t, ok) + childCtx := childSession.Context() + + err = childSession.Rollback() + require.NoError(t, err) + assert.Equal(t, SessionRolledBack, childSession.Status()) + assert.ErrorIs(t, childCtx.Err(), context.Canceled) // Make sure the child context has been canceled + + assert.NoError(t, session.Context().Err()) // The parent context should not be canceled + assert.Equal(t, SessionCreated, session.Status()) // Parent session status unchanged + }) + + t.Run("manual_children_added", func(t *testing.T) { + session := NewTestSession() + child1, err := session.Begin(context.Background()) + require.NoError(t, err) + child2, err := session.Begin(context.Background()) + require.NoError(t, err) + + assert.Equal(t, []*Session{child1.(*Session), child2.(*Session)}, session.Children()) + + nested, err := child1.Begin(child1.Context()) + require.NoError(t, err) + assert.Equal(t, []*Session{nested.(*Session)}, child1.(*Session).Children()) + + nested2, err := session.Begin(child1.Context()) // Parent is child1 because of the context, not root session + require.NoError(t, err) + assert.Equal(t, []*Session{nested.(*Session), nested2.(*Session)}, child1.(*Session).Children()) + assert.Equal(t, []*Session{child1.(*Session), child2.(*Session)}, session.Children()) // Root session children unchanged + }) + + t.Run("full_tx_children_added", func(t *testing.T) { + session := NewTestSession() + ctx := context.WithValue(context.Background(), testCtxKey{}, "test-value") + for range 2 { + err := session.Transaction(ctx, func(_ context.Context) error { + return nil + }) + require.NoError(t, err) + } + + // Both children were added to the children slice + // and both should be committed. + children := session.Children() + assert.Len(t, children, 2) + for _, c := range children { + assert.Equal(t, SessionCommitted, c.Status()) + + // Uses the parent context + childCtx := c.Context() + assert.NotEqual(t, ctx, childCtx) + assert.Equal(t, "test-value", childCtx.Value(testCtxKey{})) + assert.ErrorIs(t, childCtx.Err(), context.Canceled) // Make sure the child context has been canceled + } + assert.NoError(t, session.Context().Err()) // The parent context should not be canceled + assert.Equal(t, SessionCreated, session.Status()) // Parent session status unchanged + }) + + t.Run("full_tx_rollback", func(t *testing.T) { + testError := errors.New("test error") + session := NewTestSession() + err := session.Transaction(context.Background(), func(_ context.Context) error { + return testError + }) + require.ErrorIs(t, err, testError) + assert.NotEqual(t, testError, err) // Error should be wrapped + + // Children should be rolled back. + children := session.Children() + assert.Len(t, children, 1) + for _, c := range children { + assert.Equal(t, SessionRolledBack, c.Status()) + assert.ErrorIs(t, c.Context().Err(), context.Canceled) // Make sure the child context has been canceled + } + + assert.NoError(t, session.Context().Err()) // The parent context should not be canceled + assert.Equal(t, SessionCreated, session.Status()) // Parent session status unchanged + }) + + t.Run("cannot_commit_root_session", func(t *testing.T) { + session := NewTestSession() + assert.ErrorIs(t, session.Commit(), ErrEndRootSession) + }) + + t.Run("cannot_rollback_root_session", func(t *testing.T) { + session := NewTestSession() + assert.ErrorIs(t, session.Rollback(), ErrEndRootSession) + }) + + t.Run("cannot_begin_from_ended_session", func(t *testing.T) { + session := NewTestSession() + child, err := session.Begin(context.Background()) + require.NoError(t, err) + require.NoError(t, child.Commit()) + + c, err := child.Begin(context.Background()) + assert.Nil(t, c) + assert.ErrorIs(t, err, ErrSessionEnded) + }) + + t.Run("cannot_commit_committed_session", func(t *testing.T) { + session := NewTestSession() + child, err := session.Begin(context.Background()) + require.NoError(t, err) + require.NoError(t, child.Commit()) + assert.ErrorIs(t, child.Commit(), ErrSessionEnded) + assert.Equal(t, SessionCommitted, child.(*Session).Status()) + }) + + t.Run("cannot_commit_rolledback_session", func(t *testing.T) { + session := NewTestSession() + child, err := session.Begin(context.Background()) + require.NoError(t, err) + require.NoError(t, child.Rollback()) + assert.ErrorIs(t, child.Commit(), ErrSessionEnded) + assert.Equal(t, SessionRolledBack, child.(*Session).Status()) + }) + + t.Run("cannot_rollback_rolledback_session", func(t *testing.T) { + session := NewTestSession() + child, err := session.Begin(context.Background()) + require.NoError(t, err) + require.NoError(t, child.Rollback()) + assert.ErrorIs(t, child.Rollback(), ErrSessionEnded) + assert.Equal(t, SessionRolledBack, child.(*Session).Status()) + }) + + t.Run("cannot_rollback_committed_session", func(t *testing.T) { + session := NewTestSession() + child, err := session.Begin(context.Background()) + require.NoError(t, err) + require.NoError(t, child.Commit()) + assert.ErrorIs(t, child.Rollback(), ErrSessionEnded) + assert.Equal(t, SessionCommitted, child.(*Session).Status()) + }) + + t.Run("cannot_commit_if_child_is_running", func(t *testing.T) { + session := NewTestSession() + mainSession, err := session.Begin(context.Background()) + require.NoError(t, err) + + committedChild, err := mainSession.Begin(mainSession.Context()) + require.NoError(t, err) + require.NoError(t, committedChild.Commit()) + rolledBackchild, err := mainSession.Begin(mainSession.Context()) + require.NoError(t, err) + require.NoError(t, rolledBackchild.Rollback()) + + _, err = mainSession.Begin(mainSession.Context()) // Running child + require.NoError(t, err) + + err = mainSession.Commit() + require.ErrorIs(t, err, ErrChildRunning) + assert.Equal(t, SessionCreated, mainSession.(*Session).Status()) // Status unchanged + }) + + t.Run("cannot_rollback_if_child_is_running", func(t *testing.T) { + session := NewTestSession() + mainSession, err := session.Begin(context.Background()) + require.NoError(t, err) + + committedChild, err := mainSession.Begin(mainSession.Context()) + require.NoError(t, err) + require.NoError(t, committedChild.Commit()) + rolledBackchild, err := mainSession.Begin(mainSession.Context()) + require.NoError(t, err) + require.NoError(t, rolledBackchild.Rollback()) + + _, err = mainSession.Begin(mainSession.Context()) // Running child + require.NoError(t, err) + + err = mainSession.Rollback() + require.ErrorIs(t, err, ErrChildRunning) + assert.Equal(t, SessionCreated, mainSession.(*Session).Status()) // Status unchanged + }) + + t.Run("cannot_begin_if_parent_context_canceled", func(t *testing.T) { + session := NewTestSession() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := session.Begin(ctx) + require.ErrorIs(t, err, context.Canceled) + + ctx, cancel = context.WithCancel(context.Background()) + mainSession, err := session.Begin(ctx) + require.NoError(t, err) + cancel() + _, err = mainSession.Begin(context.WithValue(ctx, testCtxKey{}, "test-value")) + require.ErrorIs(t, err, context.Canceled) + require.NotErrorIs(t, err, ErrSessionEnded) + }) + + t.Run("cannot_full_tx_if_parent_context_canceled", func(t *testing.T) { + session := NewTestSession() + ctx, cancel := context.WithCancel(context.Background()) + mainSession, err := session.Begin(ctx) + require.NoError(t, err) + cancel() + err = mainSession.Transaction(context.WithValue(ctx, testCtxKey{}, "test-value"), func(_ context.Context) error { + return nil + }) + require.ErrorIs(t, err, context.Canceled) + require.NotErrorIs(t, err, ErrSessionEnded) + }) + + t.Run("cannot_commit_if_parent_context_canceled", func(t *testing.T) { + session := NewTestSession() + ctx, cancel := context.WithCancel(context.Background()) + mainSession, err := session.Begin(ctx) + require.NoError(t, err) + cancel() + err = mainSession.Commit() + require.ErrorIs(t, err, context.Canceled) + require.NotErrorIs(t, err, ErrSessionEnded) + }) + + t.Run("can_rollback_if_parent_context_canceled", func(t *testing.T) { + session := NewTestSession() + ctx, cancel := context.WithCancel(context.Background()) + mainSession, err := session.Begin(ctx) + require.NoError(t, err) + cancel() + err = mainSession.Rollback() + require.NoError(t, err) + }) + + t.Run("begin_context_should_be_from_parent_session", func(t *testing.T) { + session := NewTestSession() + mainSession, err := session.Begin(context.Background()) + require.NoError(t, err) + + _, err = mainSession.Begin(context.Background()) + require.ErrorIs(t, err, ErrNotParentContext) + + ctx := context.WithValue(mainSession.Context(), testCtxKey{}, "test-value") + _, err = mainSession.Begin(ctx) + require.NoError(t, err) + }) +}