From 7215a3253fb4e465f484ca723f02a7ce7492e7ca Mon Sep 17 00:00:00 2001 From: Yu Xia Date: Wed, 4 Jan 2023 16:40:52 -0800 Subject: [PATCH] Validate shard id in shard controller (#3776) --- service/history/shard/controller_impl.go | 36 ++++++++++++++++++------ service/history/shard/controller_test.go | 23 +++++++++++---- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/service/history/shard/controller_impl.go b/service/history/shard/controller_impl.go index bd70be1fb5e..cfffdbed4f6 100644 --- a/service/history/shard/controller_impl.go +++ b/service/history/shard/controller_impl.go @@ -33,6 +33,7 @@ import ( "time" "go.opentelemetry.io/otel/trace" + "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/client" @@ -58,6 +59,11 @@ const ( shardControllerMembershipUpdateListenerName = "ShardController" ) +var ( + invalidShardIdLowerBound = serviceerror.NewInvalidArgument("shard Id cannot be equal or lower than zero") + invalidShardIdUpperBound = serviceerror.NewInvalidArgument("shard Id cannot be larger than max shard count") +) + type ( ControllerImpl struct { membershipUpdateCh chan *membership.ChangedEvent @@ -211,6 +217,17 @@ func (c *ControllerImpl) CloseShardByID(shardID int32) { } } +func (c *ControllerImpl) ShardIDs() []int32 { + c.RLock() + defer c.RUnlock() + + ids := make([]int32, 0, len(c.historyShards)) + for id := range c.historyShards { + ids = append(ids, id) + } + return ids +} + func (c *ControllerImpl) shardClosedCallback(shard *ContextImpl) { startTime := time.Now().UTC() defer func() { @@ -231,6 +248,10 @@ func (c *ControllerImpl) shardClosedCallback(shard *ContextImpl) { // if necessary. If a shard context is created, it will initialize in the background. // This function won't block on rangeid lease acquisition. func (c *ControllerImpl) getOrCreateShardContext(shardID int32) (*ContextImpl, error) { + err := c.validateShardId(shardID) + if err != nil { + return nil, err + } c.RLock() if shard, ok := c.historyShards[shardID]; ok { if shard.isValid() { @@ -443,15 +464,14 @@ func (c *ControllerImpl) doShutdown() { c.historyShards = nil } -func (c *ControllerImpl) ShardIDs() []int32 { - c.RLock() - defer c.RUnlock() - - ids := make([]int32, 0, len(c.historyShards)) - for id := range c.historyShards { - ids = append(ids, id) +func (c *ControllerImpl) validateShardId(shardID int32) error { + if shardID <= 0 { + return invalidShardIdLowerBound } - return ids + if shardID > c.config.NumberOfShards { + return invalidShardIdUpperBound + } + return nil } func IsShardOwnershipLostError(err error) bool { diff --git a/service/history/shard/controller_test.go b/service/history/shard/controller_test.go index 555c64cbd5f..d48a066f75b 100644 --- a/service/history/shard/controller_test.go +++ b/service/history/shard/controller_test.go @@ -594,9 +594,9 @@ func (s *controllerSuite) TestShardExplicitUnload() { s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestSingleDCClusterInfo).AnyTimes() mockEngine := NewMockEngine(s.controller) mockEngine.EXPECT().Stop().AnyTimes() - s.setupMocksForAcquireShard(0, mockEngine, 5, 6, false) + s.setupMocksForAcquireShard(1, mockEngine, 5, 6, false) - shard, err := s.shardController.getOrCreateShardContext(0) + shard, err := s.shardController.getOrCreateShardContext(1) s.NoError(err) s.Equal(1, len(s.shardController.ShardIDs())) @@ -618,7 +618,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelGetOrCreate() { mockEngine := NewMockEngine(s.controller) mockEngine.EXPECT().Stop().AnyTimes() - shardID := int32(0) + shardID := int32(1) s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil) ready := make(chan struct{}) @@ -638,7 +638,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelGetOrCreate() { }) // get shard, will start initializing in background - shard, err := s.shardController.getOrCreateShardContext(0) + shard, err := s.shardController.getOrCreateShardContext(1) s.NoError(err) <-ready @@ -659,7 +659,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelAcquire() { mockEngine := NewMockEngine(s.controller) mockEngine.EXPECT().Stop().AnyTimes() - shardID := int32(0) + shardID := int32(1) s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil) // return success from GetOrCreateShard s.mockShardManager.EXPECT().GetOrCreateShard(gomock.Any(), getOrCreateShardRequestMatcher(shardID)).Return( @@ -691,7 +691,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelAcquire() { }) // get shard, will start initializing in background - shard, err := s.shardController.getOrCreateShardContext(0) + shard, err := s.shardController.getOrCreateShardContext(1) s.NoError(err) <-ready @@ -834,6 +834,17 @@ func (s *controllerSuite) TestShardControllerFuzz() { }, 1*time.Second, 50*time.Millisecond, "engine start/stop") } +func (s *controllerSuite) Test_GetOrCreateShard_InvalidShardID() { + numShards := int32(2) + s.config.NumberOfShards = numShards + + _, err := s.shardController.getOrCreateShardContext(0) + s.ErrorIs(err, invalidShardIdLowerBound) + + _, err = s.shardController.getOrCreateShardContext(3) + s.ErrorIs(err, invalidShardIdUpperBound) +} + func (s *controllerSuite) setupMocksForAcquireShard( shardID int32, mockEngine *MockEngine,