Skip to content

Commit

Permalink
Add unit tests for CreateFailoverMarkerTasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Mar 5, 2024
1 parent 50b4b84 commit c99c7f8
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 18 deletions.
52 changes: 34 additions & 18 deletions common/persistence/sql/sql_execution_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/collection"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/common/log/tag"
p "github.com/uber/cadence/common/persistence"
"github.com/uber/cadence/common/persistence/serialization"
"github.com/uber/cadence/common/persistence/sql/sqlplugin"
Expand Down Expand Up @@ -1241,25 +1240,42 @@ func (m *sqlExecutionStore) CreateFailoverMarkerTasks(
request *p.CreateFailoverMarkersRequest,
) error {
dbShardID := sqlplugin.GetDBShardIDFromHistoryShardID(m.shardID, m.db.GetTotalNumDBShards())
return m.txExecuteShardLocked(ctx, dbShardID, "CreateFailoverMarkerTasks", request.RangeID, func(tx sqlplugin.Tx) error {
for _, task := range request.Markers {
t := []p.Task{task}
if err := createReplicationTasks(
ctx,
tx,
t,
m.shardID,
serialization.MustParseUUID(task.DomainID),
emptyWorkflowID,
serialization.MustParseUUID(emptyReplicationRunID),
m.parser,
); err != nil {
rollBackErr := tx.Rollback()
if rollBackErr != nil {
m.logger.Error("transaction rollback error", tag.Error(rollBackErr))
}
return m.txExecuteShardLockedFn(ctx, dbShardID, "CreateFailoverMarkerTasks", request.RangeID, func(tx sqlplugin.Tx) error {
replicationTasksRows := make([]sqlplugin.ReplicationTasksRow, len(request.Markers))
for i, task := range request.Markers {
blob, err := m.parser.ReplicationTaskInfoToBlob(&serialization.ReplicationTaskInfo{
DomainID: serialization.MustParseUUID(task.DomainID),
WorkflowID: emptyWorkflowID,
RunID: serialization.MustParseUUID(emptyReplicationRunID),
TaskType: int16(task.GetType()),
FirstEventID: common.EmptyEventID,
NextEventID: common.EmptyEventID,
Version: task.GetVersion(),
ScheduledID: common.EmptyEventID,
EventStoreVersion: p.EventStoreVersion,
NewRunEventStoreVersion: p.EventStoreVersion,
BranchToken: nil,
NewRunBranchToken: nil,
CreationTimestamp: task.GetVisibilityTimestamp(),
})
if err != nil {
return err
}
replicationTasksRows[i].ShardID = m.shardID
replicationTasksRows[i].TaskID = task.GetTaskID()
replicationTasksRows[i].Data = blob.Data
replicationTasksRows[i].DataEncoding = string(blob.Encoding)
}
result, err := tx.InsertIntoReplicationTasks(ctx, replicationTasksRows)
if err != nil {
return convertCommonErrors(tx, "CreateFailoverMarkerTasks", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{Message: fmt.Sprintf("CreateFailoverMarkerTasks failed. Could not verify number of rows inserted. Error: %v", err)}
}
if int(rowsAffected) != len(replicationTasksRows) {
return &types.InternalServiceError{Message: fmt.Sprintf("CreateFailoverMarkerTasks failed. Inserted %v instead of %v rows into replication_tasks.", rowsAffected, len(replicationTasksRows))}
}
return nil
})
Expand Down
180 changes: 180 additions & 0 deletions common/persistence/sql/sql_execution_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2763,3 +2763,183 @@ func TestConflictResolveWorkflowExecution(t *testing.T) {
})
}
}

func TestCreateFailoverMarkerTasks(t *testing.T) {
testCases := []struct {
name string
req *persistence.CreateFailoverMarkersRequest
mockSetup func(*sqlplugin.MockTx, *serialization.MockParser)
wantErr bool
}{
{
name: "Success case",
req: &persistence.CreateFailoverMarkersRequest{
RangeID: 1,
Markers: []*persistence.FailoverMarkerTask{
{
TaskID: 1,
VisibilityTimestamp: time.Unix(11, 12),
Version: 101,
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
},
},
},
mockSetup: func(tx *sqlplugin.MockTx, parser *serialization.MockParser) {
parser.EXPECT().ReplicationTaskInfoToBlob(gomock.Any()).Return(persistence.DataBlob{
Encoding: common.EncodingTypeThriftRW,
Data: []byte("test data"),
}, nil)
tx.EXPECT().InsertIntoReplicationTasks(gomock.Any(), []sqlplugin.ReplicationTasksRow{
{
ShardID: 0,
TaskID: 1,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}).Return(&sqlResult{
rowsAffected: 1,
}, nil)
},
wantErr: false,
},
{
name: "Error - ReplicationTaskInfoToBlob failed",
req: &persistence.CreateFailoverMarkersRequest{
RangeID: 1,
Markers: []*persistence.FailoverMarkerTask{
{
TaskID: 1,
VisibilityTimestamp: time.Unix(11, 12),
Version: 101,
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
},
},
},
mockSetup: func(tx *sqlplugin.MockTx, parser *serialization.MockParser) {
parser.EXPECT().ReplicationTaskInfoToBlob(gomock.Any()).Return(persistence.DataBlob{}, errors.New("some random error"))
},
wantErr: true,
},
{
name: "Error - InsertIntoReplicationTasks failed",
req: &persistence.CreateFailoverMarkersRequest{
RangeID: 1,
Markers: []*persistence.FailoverMarkerTask{
{
TaskID: 1,
VisibilityTimestamp: time.Unix(11, 12),
Version: 101,
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
},
},
},
mockSetup: func(tx *sqlplugin.MockTx, parser *serialization.MockParser) {
parser.EXPECT().ReplicationTaskInfoToBlob(gomock.Any()).Return(persistence.DataBlob{
Encoding: common.EncodingTypeThriftRW,
Data: []byte("test data"),
}, nil)
tx.EXPECT().InsertIntoReplicationTasks(gomock.Any(), []sqlplugin.ReplicationTasksRow{
{
ShardID: 0,
TaskID: 1,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}).Return(nil, errors.New("some random error"))
tx.EXPECT().IsNotFoundError(gomock.Any()).Return(true)
},
wantErr: true,
},
{
name: "Error - row affected error",
req: &persistence.CreateFailoverMarkersRequest{
RangeID: 1,
Markers: []*persistence.FailoverMarkerTask{
{
TaskID: 1,
VisibilityTimestamp: time.Unix(11, 12),
Version: 101,
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
},
},
},
mockSetup: func(tx *sqlplugin.MockTx, parser *serialization.MockParser) {
parser.EXPECT().ReplicationTaskInfoToBlob(gomock.Any()).Return(persistence.DataBlob{
Encoding: common.EncodingTypeThriftRW,
Data: []byte("test data"),
}, nil)
tx.EXPECT().InsertIntoReplicationTasks(gomock.Any(), []sqlplugin.ReplicationTasksRow{
{
ShardID: 0,
TaskID: 1,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}).Return(&sqlResult{
err: errors.New("some error"),
}, nil)
},
wantErr: true,
},
{
name: "Error - row affected number mismatch",
req: &persistence.CreateFailoverMarkersRequest{
RangeID: 1,
Markers: []*persistence.FailoverMarkerTask{
{
TaskID: 1,
VisibilityTimestamp: time.Unix(11, 12),
Version: 101,
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
},
},
},
mockSetup: func(tx *sqlplugin.MockTx, parser *serialization.MockParser) {
parser.EXPECT().ReplicationTaskInfoToBlob(gomock.Any()).Return(persistence.DataBlob{
Encoding: common.EncodingTypeThriftRW,
Data: []byte("test data"),
}, nil)
tx.EXPECT().InsertIntoReplicationTasks(gomock.Any(), []sqlplugin.ReplicationTasksRow{
{
ShardID: 0,
TaskID: 1,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}).Return(&sqlResult{
rowsAffected: 0,
}, nil)
},
wantErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
db := sqlplugin.NewMockDB(ctrl)
db.EXPECT().GetTotalNumDBShards().Return(1)
tx := sqlplugin.NewMockTx(ctrl)
parser := serialization.NewMockParser(ctrl)
tc.mockSetup(tx, parser)
s := &sqlExecutionStore{
shardID: 0,
sqlStore: sqlStore{
db: db,
logger: testlogger.New(t),
parser: parser,
},
txExecuteShardLockedFn: func(_ context.Context, _ int, _ string, _ int64, fn func(sqlplugin.Tx) error) error {
return fn(tx)
},
}

err := s.CreateFailoverMarkerTasks(context.Background(), tc.req)
if tc.wantErr {
assert.Error(t, err, "Expected an error for test case")
} else {
assert.NoError(t, err, "Did not expect an error for test case")
}
})
}
}

0 comments on commit c99c7f8

Please sign in to comment.