Skip to content

Commit

Permalink
Add unit tests for CreateWorkflowExecution
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Mar 5, 2024
1 parent e6e6cf9 commit c6fe9c3
Show file tree
Hide file tree
Showing 3 changed files with 318 additions and 8 deletions.
25 changes: 17 additions & 8 deletions common/persistence/sql/sql_execution_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ const (

type sqlExecutionStore struct {
sqlStore
shardID int
shardID int
txExecuteShardLockedFn func(context.Context, int, string, int64, func(sqlplugin.Tx) error) error
lockCurrentExecutionIfExistsFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error)
createOrUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, p.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error
applyWorkflowSnapshotTxAsNewFn func(context.Context, sqlplugin.Tx, int, *p.InternalWorkflowSnapshot, serialization.Parser) error
}

var _ p.ExecutionStore = (*sqlExecutionStore)(nil)
Expand All @@ -63,15 +67,20 @@ func NewSQLExecutionStore(
dc *p.DynamicConfiguration,
) (p.ExecutionStore, error) {

return &sqlExecutionStore{
shardID: shardID,
store := &sqlExecutionStore{
shardID: shardID,
lockCurrentExecutionIfExistsFn: lockCurrentExecutionIfExists,
createOrUpdateCurrentExecutionFn: createOrUpdateCurrentExecution,
applyWorkflowSnapshotTxAsNewFn: applyWorkflowSnapshotTxAsNew,
sqlStore: sqlStore{
db: db,
logger: logger,
parser: parser,
dc: dc,
},
}, nil
}
store.txExecuteShardLockedFn = store.txExecuteShardLocked
return store, nil
}

// txExecuteShardLocked executes f under transaction and with read lock on shard row
Expand Down Expand Up @@ -105,7 +114,7 @@ func (m *sqlExecutionStore) CreateWorkflowExecution(
) (response *p.CreateWorkflowExecutionResponse, err error) {
dbShardID := sqlplugin.GetDBShardIDFromHistoryShardID(m.shardID, m.db.GetTotalNumDBShards())

err = m.txExecuteShardLocked(ctx, dbShardID, "CreateWorkflowExecution", request.RangeID, func(tx sqlplugin.Tx) error {
err = m.txExecuteShardLockedFn(ctx, dbShardID, "CreateWorkflowExecution", request.RangeID, func(tx sqlplugin.Tx) error {
response, err = m.createWorkflowExecutionTx(ctx, tx, request)
return err
})
Expand Down Expand Up @@ -136,7 +145,7 @@ func (m *sqlExecutionStore) createWorkflowExecutionTx(

var err error
var row *sqlplugin.CurrentExecutionsRow
if row, err = lockCurrentExecutionIfExists(ctx, tx, m.shardID, domainID, workflowID); err != nil {
if row, err = m.lockCurrentExecutionIfExistsFn(ctx, tx, m.shardID, domainID, workflowID); err != nil {
return nil, err
}

Expand Down Expand Up @@ -204,7 +213,7 @@ func (m *sqlExecutionStore) createWorkflowExecutionTx(
}
}

if err := createOrUpdateCurrentExecution(
if err := m.createOrUpdateCurrentExecutionFn(
ctx,
tx,
request.Mode,
Expand All @@ -220,7 +229,7 @@ func (m *sqlExecutionStore) createWorkflowExecutionTx(
return nil, err
}

if err := applyWorkflowSnapshotTxAsNew(ctx, tx, shardID, &request.NewWorkflowSnapshot, m.parser); err != nil {
if err := m.applyWorkflowSnapshotTxAsNewFn(ctx, tx, shardID, &request.NewWorkflowSnapshot, m.parser); err != nil {
return nil, err
}

Expand Down
299 changes: 299 additions & 0 deletions common/persistence/sql/sql_execution_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1955,3 +1955,302 @@ func TestTxExecuteShardLocked(t *testing.T) {
})
}
}

func TestCreateWorkflowExecution(t *testing.T) {
testCases := []struct {
name string
req *persistence.InternalCreateWorkflowExecutionRequest
lockCurrentExecutionIfExistsFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error)
createOrUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error
applyWorkflowSnapshotTxAsNewFn func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error
wantErr bool
want *persistence.CreateWorkflowExecutionResponse
assertErr func(t *testing.T, err error)
}{
{
name: "Success - mode brand new",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeBrandNew,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return nil, nil
},
createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error {
return nil
},
applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error {
return nil
},
want: &persistence.CreateWorkflowExecutionResponse{},
},
{
name: "Success - mode workflow ID reuse",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeWorkflowIDReuse,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return &sqlplugin.CurrentExecutionsRow{
State: persistence.WorkflowStateCompleted,
}, nil
},
createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error {
return nil
},
applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error {
return nil
},
want: &persistence.CreateWorkflowExecutionResponse{},
},
{
name: "Success - mode zombie",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeZombie,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateZombie,
},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return &sqlplugin.CurrentExecutionsRow{
RunID: serialization.MustParseUUID("abdcea69-61d5-44c3-9d55-afe23505a54a"),
}, nil
},
createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error {
return nil
},
applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error {
return nil
},
want: &persistence.CreateWorkflowExecutionResponse{},
},
{
name: "Error - mode state validation failed",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeZombie,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateCreated,
},
},
},
wantErr: true,
},
{
name: "Error - lockCurrentExecutionIfExists failed",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeBrandNew,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return nil, errors.New("some random error")
},
wantErr: true,
},
{
name: "Error - mode brand new",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeBrandNew,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return &sqlplugin.CurrentExecutionsRow{
CreateRequestID: "test",
WorkflowID: "test",
RunID: serialization.MustParseUUID("abdcea69-61d5-44c3-9d55-afe23505a54a"),
State: persistence.WorkflowStateCreated,
CloseStatus: persistence.WorkflowCloseStatusNone,
LastWriteVersion: 10,
}, nil
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.Equal(t, &persistence.WorkflowExecutionAlreadyStartedError{
Msg: "Workflow execution already running. WorkflowId: test",
StartRequestID: "test",
RunID: "abdcea69-61d5-44c3-9d55-afe23505a54a",
State: persistence.WorkflowStateCreated,
CloseStatus: persistence.WorkflowCloseStatusNone,
LastWriteVersion: 10,
}, err)
},
},
{
name: "Error - mode workflow ID reuse, version mismatch",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeWorkflowIDReuse,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return &sqlplugin.CurrentExecutionsRow{
State: persistence.WorkflowStateCompleted,
LastWriteVersion: 10,
}, nil
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.Equal(t, &persistence.CurrentWorkflowConditionFailedError{
Msg: "Workflow execution creation condition failed. WorkflowId: , LastWriteVersion: 10, PreviousLastWriteVersion: 0",
}, err)
},
},
{
name: "Error - mode workflow ID reuse, state mismatch",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeWorkflowIDReuse,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return &sqlplugin.CurrentExecutionsRow{
State: persistence.WorkflowStateCreated,
}, nil
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.Equal(t, &persistence.CurrentWorkflowConditionFailedError{
Msg: "Workflow execution creation condition failed. WorkflowId: , State: 0, Expected: 2",
}, err)
},
},
{
name: "Error - mode workflow ID reuse, run ID mismatch",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeWorkflowIDReuse,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return &sqlplugin.CurrentExecutionsRow{
State: persistence.WorkflowStateCompleted,
RunID: serialization.MustParseUUID("abdcea69-61d5-44c3-9d55-afe23505a54a"),
}, nil
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.Equal(t, &persistence.CurrentWorkflowConditionFailedError{
Msg: "Workflow execution creation condition failed. WorkflowId: , RunID: abdcea69-61d5-44c3-9d55-afe23505a54a, PreviousRunID: ",
}, err)
},
},
{
name: "Error - mode zombie, run ID match",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeZombie,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
State: persistence.WorkflowStateZombie,
},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return &sqlplugin.CurrentExecutionsRow{}, nil
},
wantErr: true,
},
{
name: "Error - unknown mode",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowMode(100),
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
wantErr: true,
},
{
name: "Error - createOrUpdateCurrentExecution failed",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeBrandNew,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return nil, nil
},
createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error {
return errors.New("some random error")
},
wantErr: true,
},
{
name: "Error - applyWorkflowSnapshotTxAsNew failed",
req: &persistence.InternalCreateWorkflowExecutionRequest{
RangeID: 1,
Mode: persistence.CreateWorkflowModeBrandNew,
NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{},
},
},
lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) {
return nil, nil
},
createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error {
return nil
},
applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error {
return errors.New("some random error")
},
wantErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockDB := sqlplugin.NewMockDB(ctrl)
mockDB.EXPECT().GetTotalNumDBShards().Return(1)
s := &sqlExecutionStore{
shardID: 0,
sqlStore: sqlStore{
db: mockDB,
logger: testlogger.New(t),
},
txExecuteShardLockedFn: func(_ context.Context, _ int, _ string, _ int64, fn func(sqlplugin.Tx) error) error {
return fn(nil)
},
lockCurrentExecutionIfExistsFn: tc.lockCurrentExecutionIfExistsFn,
createOrUpdateCurrentExecutionFn: tc.createOrUpdateCurrentExecutionFn,
applyWorkflowSnapshotTxAsNewFn: tc.applyWorkflowSnapshotTxAsNewFn,
}

got, err := s.CreateWorkflowExecution(context.Background(), tc.req)
if tc.wantErr {
assert.Error(t, err, "Expected an error for test case")
if tc.assertErr != nil {
tc.assertErr(t, err)
}
} else {
assert.NoError(t, err, "Did not expect an error for test case")
assert.Equal(t, tc.want, got, "Unexpected result for test case")
}
})
}
}
2 changes: 2 additions & 0 deletions common/persistence/sql/sql_execution_store_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ func TestApplyWorkflowMutationTx(t *testing.T) {
DeleteSignalInfos: []int64{1, 2},
UpsertSignalRequestedIDs: []string{"a", "b"},
DeleteSignalRequestedIDs: []string{"c", "d"},
ClearBufferedEvents: true,
},
mockSetup: func(mockTx *sqlplugin.MockTx, mockParser *serialization.MockParser) {
mockSetupLockAndCheckNextEventID(mockTx, shardID, serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47602"), "abc", serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47603"), 9, false)
Expand All @@ -510,6 +511,7 @@ func TestApplyWorkflowMutationTx(t *testing.T) {
mockUpdateRequestCancelInfos(mockTx, mockParser, 1, 2, false)
mockUpdateSignalInfos(mockTx, mockParser, 1, 2, false)
mockUpdateSignalRequested(mockTx, mockParser, 1, 2, false)
mockDeleteBufferedEvents(mockTx, shardID, serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47602"), "abc", serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47603"), false)
},
wantErr: false,
},
Expand Down

0 comments on commit c6fe9c3

Please sign in to comment.