Skip to content

Commit

Permalink
Validate shard id in shard controller (#3776)
Browse files Browse the repository at this point in the history
  • Loading branch information
yux0 authored Jan 5, 2023
1 parent 32b5d91 commit 7215a32
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
36 changes: 28 additions & 8 deletions service/history/shard/controller_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"time"

"go.opentelemetry.io/otel/trace"
"go.temporal.io/api/serviceerror"

"go.temporal.io/server/api/historyservice/v1"
"go.temporal.io/server/client"
Expand All @@ -58,6 +59,11 @@ const (
shardControllerMembershipUpdateListenerName = "ShardController"
)

var (
invalidShardIdLowerBound = serviceerror.NewInvalidArgument("shard Id cannot be equal or lower than zero")
invalidShardIdUpperBound = serviceerror.NewInvalidArgument("shard Id cannot be larger than max shard count")
)

type (
ControllerImpl struct {
membershipUpdateCh chan *membership.ChangedEvent
Expand Down Expand Up @@ -211,6 +217,17 @@ func (c *ControllerImpl) CloseShardByID(shardID int32) {
}
}

func (c *ControllerImpl) ShardIDs() []int32 {
c.RLock()
defer c.RUnlock()

ids := make([]int32, 0, len(c.historyShards))
for id := range c.historyShards {
ids = append(ids, id)
}
return ids
}

func (c *ControllerImpl) shardClosedCallback(shard *ContextImpl) {
startTime := time.Now().UTC()
defer func() {
Expand All @@ -231,6 +248,10 @@ func (c *ControllerImpl) shardClosedCallback(shard *ContextImpl) {
// if necessary. If a shard context is created, it will initialize in the background.
// This function won't block on rangeid lease acquisition.
func (c *ControllerImpl) getOrCreateShardContext(shardID int32) (*ContextImpl, error) {
err := c.validateShardId(shardID)
if err != nil {
return nil, err
}
c.RLock()
if shard, ok := c.historyShards[shardID]; ok {
if shard.isValid() {
Expand Down Expand Up @@ -443,15 +464,14 @@ func (c *ControllerImpl) doShutdown() {
c.historyShards = nil
}

func (c *ControllerImpl) ShardIDs() []int32 {
c.RLock()
defer c.RUnlock()

ids := make([]int32, 0, len(c.historyShards))
for id := range c.historyShards {
ids = append(ids, id)
func (c *ControllerImpl) validateShardId(shardID int32) error {
if shardID <= 0 {
return invalidShardIdLowerBound
}
return ids
if shardID > c.config.NumberOfShards {
return invalidShardIdUpperBound
}
return nil
}

func IsShardOwnershipLostError(err error) bool {
Expand Down
23 changes: 17 additions & 6 deletions service/history/shard/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,9 @@ func (s *controllerSuite) TestShardExplicitUnload() {
s.mockClusterMetadata.EXPECT().GetAllClusterInfo().Return(cluster.TestSingleDCClusterInfo).AnyTimes()
mockEngine := NewMockEngine(s.controller)
mockEngine.EXPECT().Stop().AnyTimes()
s.setupMocksForAcquireShard(0, mockEngine, 5, 6, false)
s.setupMocksForAcquireShard(1, mockEngine, 5, 6, false)

shard, err := s.shardController.getOrCreateShardContext(0)
shard, err := s.shardController.getOrCreateShardContext(1)
s.NoError(err)
s.Equal(1, len(s.shardController.ShardIDs()))

Expand All @@ -618,7 +618,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelGetOrCreate() {
mockEngine := NewMockEngine(s.controller)
mockEngine.EXPECT().Stop().AnyTimes()

shardID := int32(0)
shardID := int32(1)
s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil)

ready := make(chan struct{})
Expand All @@ -638,7 +638,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelGetOrCreate() {
})

// get shard, will start initializing in background
shard, err := s.shardController.getOrCreateShardContext(0)
shard, err := s.shardController.getOrCreateShardContext(1)
s.NoError(err)

<-ready
Expand All @@ -659,7 +659,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelAcquire() {
mockEngine := NewMockEngine(s.controller)
mockEngine.EXPECT().Stop().AnyTimes()

shardID := int32(0)
shardID := int32(1)
s.mockServiceResolver.EXPECT().Lookup(convert.Int32ToString(shardID)).Return(s.hostInfo, nil)
// return success from GetOrCreateShard
s.mockShardManager.EXPECT().GetOrCreateShard(gomock.Any(), getOrCreateShardRequestMatcher(shardID)).Return(
Expand Down Expand Up @@ -691,7 +691,7 @@ func (s *controllerSuite) TestShardExplicitUnloadCancelAcquire() {
})

// get shard, will start initializing in background
shard, err := s.shardController.getOrCreateShardContext(0)
shard, err := s.shardController.getOrCreateShardContext(1)
s.NoError(err)

<-ready
Expand Down Expand Up @@ -834,6 +834,17 @@ func (s *controllerSuite) TestShardControllerFuzz() {
}, 1*time.Second, 50*time.Millisecond, "engine start/stop")
}

func (s *controllerSuite) Test_GetOrCreateShard_InvalidShardID() {
numShards := int32(2)
s.config.NumberOfShards = numShards

_, err := s.shardController.getOrCreateShardContext(0)
s.ErrorIs(err, invalidShardIdLowerBound)

_, err = s.shardController.getOrCreateShardContext(3)
s.ErrorIs(err, invalidShardIdUpperBound)
}

func (s *controllerSuite) setupMocksForAcquireShard(
shardID int32,
mockEngine *MockEngine,
Expand Down

0 comments on commit 7215a32

Please sign in to comment.