Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checksum validation for SQL implementation #5790

Merged
merged 3 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/persistence/data_manager_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ type (
DomainID string
Execution types.WorkflowExecution
DomainName string
RangeID int64
}

// GetWorkflowExecutionResponse is the response to GetWorkflowExecutionRequest
Expand Down
1 change: 1 addition & 0 deletions common/persistence/data_store_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ type (
InternalGetWorkflowExecutionRequest struct {
DomainID string
Execution types.WorkflowExecution
RangeID int64
}

// InternalGetWorkflowExecutionResponse is the response to GetWorkflowExecution for Persistence Interface
Expand Down
1 change: 1 addition & 0 deletions common/persistence/executionManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (m *executionManagerImpl) GetWorkflowExecution(
internalRequest := &InternalGetWorkflowExecutionRequest{
DomainID: request.DomainID,
Execution: request.Execution,
RangeID: request.RangeID,
}
response, err := m.persistence.GetWorkflowExecution(ctx, internalRequest)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions common/persistence/persistence-tests/persistenceTestBase.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ func (s *TestBase) GetWorkflowExecutionInfoWithStats(ctx context.Context, domain
response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{
DomainID: domainID,
Execution: workflowExecution,
RangeID: s.ShardInfo.RangeID,
})
if err != nil {
return nil, nil, err
Expand All @@ -490,6 +491,7 @@ func (s *TestBase) GetWorkflowExecutionInfo(ctx context.Context, domainID string
response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{
DomainID: domainID,
Execution: workflowExecution,
RangeID: s.ShardInfo.RangeID,
})
if err != nil {
return nil, err
Expand Down
3 changes: 3 additions & 0 deletions common/persistence/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ func (t *serializerImpl) DeserializeAsyncWorkflowsConfig(data *DataBlob) (*types
}

// SerializeChecksum encodes sum into a DataBlob using encodingType.
// A checksum whose Value is empty is treated as absent: the method returns
// a nil blob and a nil error without invoking the underlying serializer.
func (t *serializerImpl) SerializeChecksum(sum checksum.Checksum, encodingType common.EncodingType) (*DataBlob, error) {
	// Only a checksum that actually carries a value is worth serializing.
	if len(sum.Value) > 0 {
		return t.serialize(sum, encodingType)
	}
	return nil, nil
}

Expand Down
1 change: 1 addition & 0 deletions common/persistence/serializer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ func TestSerializers(t *testing.T) {
{
name: "checksum",
payloads: map[string]any{
"empty": checksum.Checksum{},
"normal": generateChecksum(),
},
serializeFn: func(payload any, encoding common.EncodingType) (*DataBlob, error) {
Expand Down
35 changes: 26 additions & 9 deletions common/persistence/sql/sql_execution_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,60 +307,60 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
var bufferedEvents []*p.DataBlob
var signalsRequested map[string]struct{}

g, ctx := errgroup.WithContext(ctx)
g, childCtx := errgroup.WithContext(ctx)

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
executions, e = m.getExecutions(ctx, request, domainID, wfID, runID)
executions, e = m.getExecutions(childCtx, request, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
activityInfos, e = getActivityInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
timerInfos, e = getTimerInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
childExecutionInfos, e = getChildExecutionInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
requestCancelInfos, e = getRequestCancelInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
signalInfos, e = getSignalInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
bufferedEvents, e = getBufferedEvents(
ctx, m.db, m.shardID, domainID, wfID, runID)
childCtx, m.db, m.shardID, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
signalsRequested, e = getSignalsRequested(
ctx, m.db, m.shardID, domainID, wfID, runID)
childCtx, m.db, m.shardID, domainID, wfID, runID)
return e
})

Expand All @@ -375,6 +375,23 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
Message: fmt.Sprintf("GetWorkflowExecution: failed. Error: %v", err),
}
}
// If we have a checksum, we need to make sure the rangeID did not change.
// If the rangeID changed, shard ownership might have changed and the workflow
// might have been updated while we were reading the data, so the data we read
// might not come from a consistent view and the checksum validation might fail.
// In that case, we clear the checksum data so that the validation is not performed.
Shaddoll marked this conversation as resolved.
Show resolved Hide resolved
if state.ChecksumData != nil {
row, err := m.db.SelectFromShards(ctx, &sqlplugin.ShardsFilter{ShardID: int64(m.shardID)})
if err != nil {
return nil, convertCommonErrors(m.db, "GetWorkflowExecution", "", err)
}
if row.RangeID != request.RangeID {
// The GetWorkflowExecution operation will not be impacted by this. ChecksumData is purely for validation purposes.
m.logger.Warn("GetWorkflowExecution's checksum is discarded. The shard might have changed owner.")
state.ChecksumData = nil
Shaddoll marked this conversation as resolved.
Show resolved Hide resolved
}
}

state.ActivityInfos = activityInfos
state.TimerInfos = timerInfos
state.ChildExecutionInfos = childExecutionInfos
Expand Down
127 changes: 127 additions & 0 deletions common/persistence/sql/sql_execution_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2975,6 +2975,7 @@ func TestGetWorkflowExecution(t *testing.T) {
mockSetup func(*sqlplugin.MockDB, *serialization.MockParser)
want *persistence.InternalGetWorkflowExecutionResponse
wantErr bool
assertErr func(t *testing.T, err error)
}{
{
name: "Success case",
Expand All @@ -2984,6 +2985,7 @@ func TestGetWorkflowExecution(t *testing.T) {
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
RangeID: 1,
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
Expand Down Expand Up @@ -3204,6 +3206,9 @@ func TestGetWorkflowExecution(t *testing.T) {
Control: []byte("test control"),
RequestID: "test-signal-request-id",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{
RangeID: 1,
}, nil)
},
want: &persistence.InternalGetWorkflowExecutionResponse{
State: &persistence.InternalWorkflowMutableState{
Expand Down Expand Up @@ -3366,6 +3371,125 @@ func TestGetWorkflowExecution(t *testing.T) {
},
wantErr: false,
},
{
name: "Error - Shard owner changed",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
{
ShardID: 0,
DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"),
WorkflowID: "test-workflow-id",
RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"),
NextEventID: 101,
LastWriteVersion: 11,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}, nil)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{
Checksum: []byte("test-checksum"),
ChecksumEncoding: "test-checksum-encoding",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{
RangeID: 1,
}, nil)
},
want: &persistence.InternalGetWorkflowExecutionResponse{
State: &persistence.InternalWorkflowMutableState{
ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
NextEventID: 101,
CompletionEventBatchID: -23,
},
ActivityInfos: map[int64]*persistence.InternalActivityInfo{},
TimerInfos: map[string]*persistence.TimerInfo{},
ChildExecutionInfos: map[int64]*persistence.InternalChildExecutionInfo{},
RequestCancelInfos: map[int64]*persistence.RequestCancelInfo{},
SignalInfos: map[int64]*persistence.SignalInfo{},
SignalRequestedIDs: map[string]struct{}{},
ChecksumData: nil,
},
},
wantErr: false,
},
{
name: "Error - failed to get shard",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
{
ShardID: 0,
DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"),
WorkflowID: "test-workflow-id",
RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"),
NextEventID: 101,
LastWriteVersion: 11,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}, nil)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{
Checksum: []byte("test-checksum"),
ChecksumEncoding: "test-checksum-encoding",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes()
},
wantErr: true,
Shaddoll marked this conversation as resolved.
Show resolved Hide resolved
},
{
name: "Error - SelectFromExecutions no row",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.IsType(t, &types.EntityNotExistsError{}, err)
},
},
{
name: "Error - SelectFromExecutions failed",
req: &persistence.InternalGetWorkflowExecutionRequest{
Expand Down Expand Up @@ -3562,6 +3686,9 @@ func TestGetWorkflowExecution(t *testing.T) {
resp, err := s.GetWorkflowExecution(context.Background(), tc.req)
if tc.wantErr {
assert.Error(t, err, "Expected an error for test case")
if tc.assertErr != nil {
tc.assertErr(t, err)
}
} else {
assert.NoError(t, err, "Did not expect an error for test case")
assert.Equal(t, tc.want, resp, "Response mismatch")
Expand Down
1 change: 1 addition & 0 deletions service/history/ndc/activity_replicator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ func (s *activityReplicatorSuite) TestSyncActivity_WorkflowNotFound() {
RunID: runID,
},
DomainName: domainName,
RangeID: 1,
}).Return(nil, &types.EntityNotExistsError{})
s.mockDomainCache.EXPECT().GetDomainByID(domainID).Return(
cache.NewGlobalDomainCacheEntryForTest(
Expand Down
2 changes: 2 additions & 0 deletions service/history/ndc/transaction_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ func (s *transactionManagerSuite) TestCheckWorkflowExists_DoesNotExists() {
RunID: runID,
},
DomainName: domainName,
RangeID: 1,
}).Return(nil, &types.EntityNotExistsError{}).Once()

exists, err := s.transactionManager.checkWorkflowExists(ctx, domainID, workflowID, runID)
Expand All @@ -465,6 +466,7 @@ func (s *transactionManagerSuite) TestCheckWorkflowExists_DoesExists() {
RunID: runID,
},
DomainName: domainName,
RangeID: 1,
}).Return(&persistence.GetWorkflowExecutionResponse{}, nil).Once()

exists, err := s.transactionManager.checkWorkflowExists(ctx, domainID, workflowID, runID)
Expand Down
1 change: 1 addition & 0 deletions service/history/shard/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ func (s *contextImpl) GetWorkflowExecution(
ctx context.Context,
request *persistence.GetWorkflowExecutionRequest,
) (*persistence.GetWorkflowExecutionResponse, error) {
request.RangeID = atomic.LoadInt64(&s.rangeID) // This is to make sure read is not blocked by write, s.rangeID is synced with s.shardInfo.RangeID
if s.isClosed() {
return nil, ErrShardClosed
}
Expand Down
3 changes: 3 additions & 0 deletions service/history/shard/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ func TestGetWorkflowExecution(t *testing.T) {
mockExecutionMgr := &mocks.ExecutionManager{}
shardContext := &contextImpl{
executionManager: mockExecutionMgr,
shardInfo: &persistence.ShardInfo{
RangeID: 12,
},
}
if tc.isClosed {
shardContext.closed = 1
Expand Down
Loading