diff --git a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java index 671b24aed3..36c511d522 100644 --- a/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java +++ b/src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java @@ -128,6 +128,14 @@ public Set getChannels() { return unwrap(this.channels); } + public boolean hasShardChannelSubscriptions() { + return !shardChannels.isEmpty(); + } + + public Set getShardChannels() { + return unwrap(this.shardChannels); + } + public boolean hasPatternSubscriptions() { return !patterns.isEmpty(); } diff --git a/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java b/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java index 09c802e80d..6e012f4328 100644 --- a/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java +++ b/src/main/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImpl.java @@ -130,6 +130,10 @@ protected List> resubscribe() { result.add(async().subscribe(toArray(endpoint.getChannels()))); } + if (endpoint.hasShardChannelSubscriptions()) { + result.add(async().ssubscribe(toArray(endpoint.getShardChannels()))); + } + if (endpoint.hasPatternSubscriptions()) { result.add(async().psubscribe(toArray(endpoint.getPatterns()))); } diff --git a/src/test/java/io/lettuce/core/pubsub/PubSubCommandIntegrationTests.java b/src/test/java/io/lettuce/core/pubsub/PubSubCommandIntegrationTests.java index fdd7780558..9936ba4283 100644 --- a/src/test/java/io/lettuce/core/pubsub/PubSubCommandIntegrationTests.java +++ b/src/test/java/io/lettuce/core/pubsub/PubSubCommandIntegrationTests.java @@ -86,6 +86,8 @@ class PubSubCommandIntegrationTests extends AbstractRedisClientTest { BlockingQueue counts = listener.getCounts(); + BlockingQueue shardCounts = listener.getShardCounts(); + String channel = "channel0"; String shardChannel = "shard-channel"; @@ -523,6 +525,24 @@ void resubscribePatternsOnReconnect() throws Exception { assertThat(messages.take()).isEqualTo(message); } + @Test + void resubscribeShardChannelsOnReconnect() throws Exception { + pubsub.ssubscribe(shardChannel); + assertThat(shardChannels.take()).isEqualTo(shardChannel); + assertThat((long) shardCounts.take()).isEqualTo(1); + + pubsub.quit(); + + assertThat(shardChannels.take()).isEqualTo(shardChannel); + assertThat((long) shardCounts.take()).isEqualTo(1); + + Wait.untilTrue(pubsub::isOpen).waitOrTimeout(); + + redis.spublish(shardChannel, shardMessage); + assertThat(shardChannels.take()).isEqualTo(shardChannel); + assertThat(messages.take()).isEqualTo(shardMessage); + } + @Test void adapter() throws Exception { final BlockingQueue localCounts = LettuceFactories.newBlockingQueue(); diff --git a/src/test/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImplUnitTests.java b/src/test/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImplUnitTests.java index 8cd91df26a..9eb5528244 100644 --- a/src/test/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImplUnitTests.java +++ b/src/test/java/io/lettuce/core/pubsub/StatefulRedisPubSubConnectionImplUnitTests.java @@ -12,7 +12,7 @@ import org.junit.jupiter.api.Test; import static io.lettuce.TestTags.UNIT_TEST; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.mockito.Mockito.*; @@ -81,6 +81,7 @@ void resubscribeChannelSubscription() { when(mockedEndpoint.hasChannelSubscriptions()).thenReturn(true); when(mockedEndpoint.getChannels()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "channel1", "channel2" }))); when(mockedEndpoint.hasPatternSubscriptions()).thenReturn(false); + when(mockedEndpoint.hasShardChannelSubscriptions()).thenReturn(false); List> subscriptions = connection.resubscribe(); RedisFuture commandFuture = subscriptions.get(0); @@ -90,17 +91,35 @@ void resubscribeChannelSubscription() { } @Test - void resubscribeChannelAndPatternSubscription() { + void resubscribeShardChannelSubscription() { + when(mockedEndpoint.hasShardChannelSubscriptions()).thenReturn(true); + when(mockedEndpoint.getShardChannels()) + .thenReturn(new HashSet<>(Arrays.asList(new String[] { "shard_channel1", "shard_channel2" }))); + when(mockedEndpoint.hasChannelSubscriptions()).thenReturn(false); + when(mockedEndpoint.hasPatternSubscriptions()).thenReturn(false); + + List> subscriptions = connection.resubscribe(); + RedisFuture commandFuture = subscriptions.get(0); + + assertEquals(1, subscriptions.size()); + assertInstanceOf(AsyncCommand.class, commandFuture); + } + + @Test + void resubscribeChannelAndPatternAndShardChanelSubscription() { when(mockedEndpoint.hasChannelSubscriptions()).thenReturn(true); - when(mockedEndpoint.getChannels()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "channel1", "channel2" }))); when(mockedEndpoint.hasPatternSubscriptions()).thenReturn(true); + when(mockedEndpoint.hasShardChannelSubscriptions()).thenReturn(true); + when(mockedEndpoint.getChannels()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "channel1", "channel2" }))); when(mockedEndpoint.getPatterns()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "bcast*", "echo" }))); - + when(mockedEndpoint.getShardChannels()) + .thenReturn(new HashSet<>(Arrays.asList(new String[] { "shard_channel1", "shard_channel2" }))); List> subscriptions = connection.resubscribe(); - assertEquals(2, subscriptions.size()); + assertEquals(3, subscriptions.size()); assertInstanceOf(AsyncCommand.class, subscriptions.get(0)); assertInstanceOf(AsyncCommand.class, subscriptions.get(1)); + assertInstanceOf(AsyncCommand.class, subscriptions.get(1)); } }