diff --git a/common/persistence/nosql/sharded_nosql_store.go b/common/persistence/nosql/sharded_nosql_store.go index 26a2a496105..5d755f5c2b4 100644 --- a/common/persistence/nosql/sharded_nosql_store.go +++ b/common/persistence/nosql/sharded_nosql_store.go @@ -100,6 +100,8 @@ func (sn *shardedNosqlStoreImpl) GetDefaultShard() nosqlStore { } func (sn *shardedNosqlStoreImpl) Close() { + sn.RLock() + defer sn.RUnlock() for name, shard := range sn.connectedShards { sn.logger.Warn("Closing store shard", tag.StoreShard(name)) shard.Close() @@ -135,8 +137,8 @@ func (sn *shardedNosqlStoreImpl) getShard(shardName string) (*nosqlStore, error) } sn.Lock() + defer sn.Unlock() if shard, ok := sn.connectedShards[shardName]; ok { // read again to double-check - sn.Unlock() return &shard, nil } @@ -146,7 +148,6 @@ func (sn *shardedNosqlStoreImpl) getShard(shardName string) (*nosqlStore, error) } sn.connectedShards[shardName] = *s sn.logger.Info("Connected to store shard", tag.StoreShard(shardName)) - sn.Unlock() return s, nil } diff --git a/common/persistence/nosql/sharded_nosql_store_test.go b/common/persistence/nosql/sharded_nosql_store_test.go index 34a03e6c5f6..d9a3ccd92bb 100644 --- a/common/persistence/nosql/sharded_nosql_store_test.go +++ b/common/persistence/nosql/sharded_nosql_store_test.go @@ -21,6 +21,7 @@ package nosql import ( + "errors" "testing" "github.com/golang/mock/gomock" @@ -66,12 +67,7 @@ func TestShardedNosqlStoreTestSuite(t *testing.T) { } func (s *shardedNosqlStoreTestSuite) TestValidConfiguration() { - cfg := getValidShardedNoSQLConfig() - - storeInterface, err := newShardedNosqlStore(cfg, log.NewNoop(), nil) - s.NoError(err) - store := storeInterface.(*shardedNosqlStoreImpl) - + store := s.newShardedStoreForTest() s.Equal(1, len(store.connectedShards)) s.Contains(store.connectedShards, "shard-1") s.Equal(store.GetDefaultShard(), store.defaultShard) @@ -83,13 +79,18 @@ func (s *shardedNosqlStoreTestSuite) TestValidConfiguration() { func (s *shardedNosqlStoreTestSuite) TestStoreSelectionForHistoryShard() { mockDB1 := nosqlplugin.NewMockDB(s.mockController) + mockDB1.EXPECT().Close().Times(1) mockDB2 := nosqlplugin.NewMockDB(s.mockController) + mockDB2.EXPECT().Close().Times(1) mockPlugin := nosqlplugin.NewMockPlugin(s.mockController) gomock.InOrder( mockPlugin.EXPECT(). CreateDB(gomock.Any(), gomock.Any(), gomock.Any()). Return(mockDB1, nil), + mockPlugin.EXPECT(). + CreateDB(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("error creating db")), mockPlugin.EXPECT(). CreateDB(gomock.Any(), gomock.Any(), gomock.Any()). Return(mockDB2, nil), @@ -97,11 +98,8 @@ func (s *shardedNosqlStoreTestSuite) TestStoreSelectionForHistoryShard() { delete(supportedPlugins, "cassandra") RegisterPlugin("cassandra", mockPlugin) - cfg := getValidShardedNoSQLConfig() - - storeInterface, err := newShardedNosqlStore(cfg, log.NewNoop(), nil) - s.NoError(err) - store := storeInterface.(*shardedNosqlStoreImpl) + store := s.newShardedStoreForTest() + defer store.Close() s.Equal(1, len(store.connectedShards)) s.True(mockDB1 == store.defaultShard.db) @@ -118,8 +116,13 @@ func (s *shardedNosqlStoreTestSuite) TestStoreSelectionForHistoryShard() { s.Equal(1, len(store.connectedShards)) s.True(mockDB1 == storeShard1.db) - // Getting a new shard should create a new connection + // Getting a new shard should create a new connection but it will fail on first attempt storeShard2, err := store.GetStoreShardByHistoryShard(1) + s.Error(err) + s.Equal(1, len(store.connectedShards)) + + // Getting a new shard should create a new connection on second attempt + storeShard2, err = store.GetStoreShardByHistoryShard(1) s.NoError(err) s.Equal(2, len(store.connectedShards)) s.True(mockDB2 == storeShard2.db) @@ -146,9 +149,23 @@ func (s *shardedNosqlStoreTestSuite) TestStoreSelectionForHistoryShard() { s.True(mockDB2 == storeShard2.db) } +func (s *shardedNosqlStoreTestSuite) newShardedStoreForTest() *shardedNosqlStoreImpl { + cfg := getValidShardedNoSQLConfig() + logger := log.NewNoop() + storeInterface, err := newShardedNosqlStore(cfg, logger, nil) + s.NoError(err) + s.Equal("shardedNosql", storeInterface.GetName()) + s.Equal(logger, storeInterface.GetLogger()) + store := storeInterface.(*shardedNosqlStoreImpl) + s.Equal(storeInterface.GetShardingPolicy(), store.shardingPolicy) + return store +} + func (s *shardedNosqlStoreTestSuite) TestStoreSelectionForTasklist() { mockDB1 := nosqlplugin.NewMockDB(s.mockController) + mockDB1.EXPECT().Close().Times(1) mockDB2 := nosqlplugin.NewMockDB(s.mockController) + mockDB2.EXPECT().Close().Times(1) mockPlugin := nosqlplugin.NewMockPlugin(s.mockController) gomock.InOrder( @@ -162,11 +179,8 @@ func (s *shardedNosqlStoreTestSuite) TestStoreSelectionForTasklist() { delete(supportedPlugins, "cassandra") RegisterPlugin("cassandra", mockPlugin) - cfg := getValidShardedNoSQLConfig() - - storeInterface, err := newShardedNosqlStore(cfg, log.NewNoop(), nil) - s.NoError(err) - store := storeInterface.(*shardedNosqlStoreImpl) + store := s.newShardedStoreForTest() + defer store.Close() s.Equal(1, len(store.connectedShards)) s.True(mockDB1 == store.defaultShard.db)