Skip to content

Commit

Permalink
improve tests, add tests for AddFuture, AddSend
Browse files Browse the repository at this point in the history
  • Loading branch information
yuandrew committed Nov 1, 2024
1 parent 5d2778e commit ceadefd
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 17 deletions.
141 changes: 129 additions & 12 deletions internal/internal_coroutines_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/sdk/converter"
)

Expand Down Expand Up @@ -552,17 +553,18 @@ func TestBlockingSelect(t *testing.T) {
}

func TestSelectBlockingDefault(t *testing.T) {
// manually create a dispatcher to ensure sdkFlags are not set
var history []string
env := &workflowEnvironmentImpl{
sdkFlags: &sdkFlags{},
sdkFlags: newSDKFlags(&workflowservice.GetSystemInfoResponse_Capabilities{SdkMetadata: true}),
commandsHelper: newCommandsHelper(),
dataConverter: converter.GetDefaultDataConverter(),
workflowInfo: &WorkflowInfo{
Namespace: "namespace:" + t.Name(),
TaskQueueName: "taskqueue:" + t.Name(),
},
}
// Verify that the flag is not set
require.False(t, env.GetFlag(SDKFlagBlockedSelectorSignalReceive))
interceptor, ctx, err := newWorkflowContext(env, nil)
require.NoError(t, err, "newWorkflowContext failed")
d, _ := newDispatcher(ctx, interceptor, func(ctx Context) {
Expand Down Expand Up @@ -594,19 +596,18 @@ func TestSelectBlockingDefault(t *testing.T) {
history = append(history, fmt.Sprintf("c2-%v", v))
})
history = append(history, "select1")
require.False(t, selector.HasPending())
selector.Select(ctx)

// Default behavior this signal is lost
require.True(t, c1.Len() == 0 && v == "two")

history = append(history, "select2")
require.False(t, selector.HasPending())
selector.Select(ctx)
history = append(history, "done")
}, func() bool { return false })
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked(defaultDeadlockDetectionTimeout))
require.True(t, d.IsDone())
require.False(t, d.IsDone())

expected := []string{
"select1",
Expand All @@ -616,15 +617,25 @@ func TestSelectBlockingDefault(t *testing.T) {
"add-two-done",
"c2-two",
"select2",
"done",
}
require.EqualValues(t, expected, history)
}

func TestSelectBlockingDefaultWithFlag(t *testing.T) {
// sdkFlags are set by default for tests
var history []string
d := createNewDispatcher(func(ctx Context) {
env := &workflowEnvironmentImpl{
sdkFlags: newSDKFlags(&workflowservice.GetSystemInfoResponse_Capabilities{SdkMetadata: true}),
commandsHelper: newCommandsHelper(),
dataConverter: converter.GetDefaultDataConverter(),
workflowInfo: &WorkflowInfo{
Namespace: "namespace:" + t.Name(),
TaskQueueName: "taskqueue:" + t.Name(),
},
}
require.True(t, env.TryUse(SDKFlagBlockedSelectorSignalReceive))
interceptor, ctx, err := newWorkflowContext(env, nil)
require.NoError(t, err, "newWorkflowContext failed")
d, _ := newDispatcher(ctx, interceptor, func(ctx Context) {
c1 := NewChannel(ctx)
c2 := NewChannel(ctx)

Expand Down Expand Up @@ -653,18 +664,15 @@ func TestSelectBlockingDefaultWithFlag(t *testing.T) {
history = append(history, fmt.Sprintf("c2-%v", v))
})
history = append(history, "select1")
require.False(t, selector.HasPending())
selector.Select(ctx)

// Signal should not be lost
require.False(t, c1.Len() == 0 && v == "two")

history = append(history, "select2")
require.True(t, selector.HasPending())
selector.Select(ctx)
require.False(t, selector.HasPending())
history = append(history, "done")
})
}, func() bool { return false })
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked(defaultDeadlockDetectionTimeout))
require.True(t, d.IsDone())
Expand All @@ -680,6 +688,115 @@ func TestSelectBlockingDefaultWithFlag(t *testing.T) {
"c1-one",
"done",
}

require.EqualValues(t, expected, history)
}

func TestBlockingSelectFuture(t *testing.T) {
var history []string
d := createNewDispatcher(func(ctx Context) {
c1 := NewChannel(ctx)
f1, s1 := NewFuture(ctx)

Go(ctx, func(ctx Context) {
history = append(history, "add-one")
c1.Send(ctx, "one")
history = append(history, "add-one-done")
})
Go(ctx, func(ctx Context) {
history = append(history, "add-two")
s1.SetValue("one-future")
})

selector := NewSelector(ctx)
selector.
AddReceive(c1, func(c ReceiveChannel, more bool) {
var v string
c.Receive(ctx, &v)
history = append(history, fmt.Sprintf("c1-%v", v))
}).
AddFuture(f1, func(f Future) {
var v string
err := f.Get(ctx, &v)
require.NoError(t, err)
history = append(history, fmt.Sprintf("f1-%v", v))
})
history = append(history, "select1")
selector.Select(ctx)
fmt.Println("select1 done", history)

history = append(history, "select2")
selector.Select(ctx)
history = append(history, "done")

})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked(defaultDeadlockDetectionTimeout))
require.True(t, d.IsDone(), strings.Join(history, "\n"))
expected := []string{
"select1",
"add-one",
"add-one-done",
"add-two",
"c1-one",
"select2",
"f1-one-future",
"done",
}
require.EqualValues(t, expected, history)
}

func TestBlockingSelectSend(t *testing.T) {
var history []string
d := createNewDispatcher(func(ctx Context) {
c1 := NewChannel(ctx)
c2 := NewChannel(ctx)

Go(ctx, func(ctx Context) {
history = append(history, "add-one")
c1.Send(ctx, "one")
history = append(history, "add-one-done")
})
Go(ctx, func(ctx Context) {
require.True(t, c2.Len() == 1)
history = append(history, "receiver")
var v string
more := c2.Receive(ctx, &v)
require.True(t, more)
history = append(history, fmt.Sprintf("c2-%v", v))
require.True(t, c2.Len() == 0)
})

selector := NewSelector(ctx)
selector.
AddReceive(c1, func(c ReceiveChannel, more bool) {
var v string
c.Receive(ctx, &v)
history = append(history, fmt.Sprintf("c1-%v", v))
}).
AddSend(c2, "two", func() { history = append(history, "send2") })
history = append(history, "select1")
selector.Select(ctx)

history = append(history, "select2")
selector.Select(ctx)
history = append(history, "done")

})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked(defaultDeadlockDetectionTimeout))
require.True(t, d.IsDone(), strings.Join(history, "\n"))
expected := []string{
"select1",
"add-one",
"add-one-done",
"receiver",
"c1-one",
"select2",
"send2",
"done",
"c2-two",
}
require.EqualValues(t, expected, history)
}

Expand Down
7 changes: 5 additions & 2 deletions internal/internal_workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ type (

workflowFunctionExecuting bool
bufferedUpdateRequests map[string][]func()

sdkFlags *sdkFlags
}

testSessionEnvironmentImpl struct {
Expand Down Expand Up @@ -289,6 +291,7 @@ func newTestWorkflowEnvironmentImpl(s *WorkflowTestSuite, parentRegistry *regist
failureConverter: GetDefaultFailureConverter(),
runTimeout: maxWorkflowTimeout,
bufferedUpdateRequests: make(map[string][]func()),
sdkFlags: newSDKFlags(&workflowservice.GetSystemInfoResponse_Capabilities{SdkMetadata: true}),
}

if debugMode {
Expand Down Expand Up @@ -581,11 +584,11 @@ func (env *testWorkflowEnvironmentImpl) getWorkflowDefinition(wt WorkflowType) (
}

func (env *testWorkflowEnvironmentImpl) TryUse(flag sdkFlag) bool {
return true
return env.sdkFlags.tryUse(flag, true)
}

func (env *testWorkflowEnvironmentImpl) GetFlag(flag sdkFlag) bool {
return true
return env.sdkFlags.getFlag(flag)
}

func (env *testWorkflowEnvironmentImpl) QueueUpdate(name string, f func()) {
Expand Down
10 changes: 7 additions & 3 deletions internal/internal_workflow_testsuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4254,8 +4254,8 @@ func (s *WorkflowTestSuiteUnitTest) Test_SignalLoss() {
ch2.Receive(ctx, &v)
})
selector.Select(ctx)
s.Require().True(ch1.Len() == 1 && v == "s2")
s.Require().True(selector.HasPending())
s.Require().True(ch1.Len() == 0 && v == "s2")
selector.Select(ctx)

return nil
}
Expand All @@ -4268,5 +4268,9 @@ func (s *WorkflowTestSuiteUnitTest) Test_SignalLoss() {
}, 5*time.Second)
env.ExecuteWorkflow(workflowFn)
s.True(env.IsWorkflowCompleted())
s.NoError(env.GetWorkflowError())
err := env.GetWorkflowError()
s.Error(err)
var workflowErr *WorkflowExecutionError
s.True(errors.As(err, &workflowErr))
s.Equal("deadline exceeded (type: ScheduleToClose)", workflowErr.cause.Error())
}

0 comments on commit ceadefd

Please sign in to comment.