From b4f92e1e5737fe67f869cbc5ee10a3f1d4675c68 Mon Sep 17 00:00:00 2001 From: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> Date: Wed, 17 Nov 2021 17:28:03 -0500 Subject: [PATCH 01/10] Fix health check after unsubscribe --- aioredis/client.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index 13fd7c87f..c44664a94 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -3937,16 +3937,14 @@ def __init__( # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = self.connection_pool.get_encoder() + self.health_check_message_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) if self.encoder.decode_responses: self.health_check_response: Iterable[Union[str, bytes]] = [ "pong", self.HEALTH_CHECK_MESSAGE, ] else: - self.health_check_response = [ - b"pong", - self.encoder.encode(self.HEALTH_CHECK_MESSAGE), - ] + self.health_check_response = [b"pong", self.health_check_message_b] self.channels: Dict[ChannelT, PubSubHandler] = {} self.pending_unsubscribe_channels: Set[ChannelT] = set() self.patterns: Dict[ChannelT, PubSubHandler] = {} @@ -4049,7 +4047,10 @@ async def parse_response(self, block: bool = True, timeout: float = 0): return None response = await self._execute(conn, conn.read_response) - if conn.health_check_interval and response == self.health_check_response: + if ( + conn.health_check_interval + and response in (self.health_check_response, self.health_check_message_b) + ): # ignore the health check message as user might not expect it return None return response From 574649aee7ff2ff4a3035056e7694ada0168a08a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Nov 2021 22:30:13 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aioredis/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index c44664a94..212c6adc1 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -4047,9 +4047,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0): return None response = await self._execute(conn, conn.read_response) - if ( - conn.health_check_interval - and response in (self.health_check_response, self.health_check_message_b) + if conn.health_check_interval and response in ( + self.health_check_response, + self.health_check_message_b, ): # ignore the health check message as user might not expect it return None From 4097b1eba486e851e481feef9eb9f99563326565 Mon Sep 17 00:00:00 2001 From: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> Date: Wed, 17 Nov 2021 17:30:47 -0500 Subject: [PATCH 03/10] Create 1207.bugfix --- CHANGES/1207.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 CHANGES/1207.bugfix diff --git a/CHANGES/1207.bugfix b/CHANGES/1207.bugfix new file mode 100644 index 000000000..928107703 --- /dev/null +++ b/CHANGES/1207.bugfix @@ -0,0 +1 @@ +Fix #1206 health check message after unsubscribing From 5e0ceea80a11e657ce0dc3f2cff90e9dc1c20131 Mon Sep 17 00:00:00 2001 From: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> Date: Mon, 22 Nov 2021 09:48:25 -0500 Subject: [PATCH 04/10] Update aioredis/client.py Co-authored-by: Bruce Merry <1963944+bmerry@users.noreply.github.com> --- aioredis/client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index 212c6adc1..ae22ccc66 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -4047,9 +4047,11 @@ async def parse_response(self, block: bool = True, timeout: float = 0): return None response = await self._execute(conn, conn.read_response) + # The response depends on whether there were any subscriptions + # active at the time the PING was issued. if conn.health_check_interval and response in ( - self.health_check_response, - self.health_check_message_b, + self.health_check_response, # If there was at least one subscription + self.health_check_message_b, # If there wasn't ): # ignore the health check message as user might not expect it return None From 3ff4477da4dd160c044a172f5347f1f05ed8a846 Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Thu, 23 Dec 2021 22:13:02 -0500 Subject: [PATCH 05/10] Update to reflect https://github.com/redis/redis-py/pull/1737 Signed-off-by: Andrew-Chen-Wang --- aioredis/client.py | 78 ++++++++++++++++++++++++++++++++++++++++---- tests/test_pubsub.py | 43 +++++++++++++++++------- 2 files changed, 103 insertions(+), 18 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index ae22ccc66..a25107893 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -28,6 +28,8 @@ cast, ) +import async_timeout + from aioredis.compat import Protocol, TypedDict from aioredis.connection import ( Connection, @@ -3934,17 +3936,18 @@ def __init__( self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages self.connection: Optional[Connection] = None + self.subscribed_event = asyncio.Event() # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = self.connection_pool.get_encoder() - self.health_check_message_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) + self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) if self.encoder.decode_responses: self.health_check_response: Iterable[Union[str, bytes]] = [ "pong", self.HEALTH_CHECK_MESSAGE, ] else: - self.health_check_response = [b"pong", self.health_check_message_b] + self.health_check_response = [b"pong", self.health_check_response_b] self.channels: Dict[ChannelT, PubSubHandler] = {} self.pending_unsubscribe_channels: Set[ChannelT] = set() self.patterns: Dict[ChannelT, PubSubHandler] = {} @@ -3969,9 +3972,11 @@ async def reset(self): await self.connection_pool.release(self.connection) self.connection = None self.channels = {} + self.health_check_response_counter = 0 self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.subscribed_event.clear() def close(self) -> Awaitable[NoReturn]: return self.reset() @@ -3997,7 +4002,7 @@ async def on_connect(self, connection: Connection): @property def subscribed(self): """Indicates if there are subscriptions to any channels or patterns""" - return bool(self.channels or self.patterns) + return self.subscribed_event.is_set() async def execute_command(self, *args: EncodableT): """Execute a publish/subscribe command""" @@ -4015,8 +4020,28 @@ async def execute_command(self, *args: EncodableT): self.connection.register_connect_callback(self.on_connect) connection = self.connection kwargs = {"check_health": not self.subscribed} + if not self.subscribed: + await self.clean_health_check_responses() await self._execute(connection, connection.send_command, *args, **kwargs) + async def clean_health_check_responses(self): + """ + If any health check responses are present, clean them + """ + ttl = 10 + conn = self.connection + while self.health_check_response_counter > 0 and ttl > 0: + if await self._execute(conn, conn.can_read, timeout=conn.socket_timeout): + response = await self._execute(conn, conn.read_response) + if self.is_health_check_response(response): + self.health_check_response_counter -= 1 + else: + raise PubSubError( + "A non health check response was cleaned by " + "execute_command: {}".format(response) + ) + ttl -= 1 + async def _execute(self, connection, command, *args, **kwargs): try: return await command(*args, **kwargs) @@ -4049,14 +4074,23 @@ async def parse_response(self, block: bool = True, timeout: float = 0): # The response depends on whether there were any subscriptions # active at the time the PING was issued. - if conn.health_check_interval and response in ( - self.health_check_response, # If there was at least one subscription - self.health_check_message_b, # If there wasn't - ): + if self.is_health_check_response(response): # ignore the health check message as user might not expect it + self.health_check_response_counter -= 1 return None return response + def is_health_check_response(self, response): + """ + Check if the response is a health check response. + If there are no subscriptions redis responds to PING command with a + bulk response, instead of a multi-bulk with "pong" and the response. + """ + return response in [ + self.health_check_response, # If there was a subscription + self.health_check_response_b, # If there wasn't + ] + async def check_health(self): conn = self.connection if conn is None: @@ -4069,6 +4103,7 @@ async def check_health(self): conn.health_check_interval and asyncio.get_event_loop().time() > conn.next_health_check ): + self.health_check_response_counter += 1 await conn.send_command( "PING", self.HEALTH_CHECK_MESSAGE, check_health=False ) @@ -4101,6 +4136,11 @@ async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): # for the reconnection. new_patterns = self._normalize_keys(new_patterns) self.patterns.update(new_patterns) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val @@ -4137,6 +4177,11 @@ async def subscribe(self, *args: ChannelT, **kwargs: Callable): # for the reconnection. new_channels = self._normalize_keys(new_channels) self.channels.update(new_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val @@ -4171,6 +4216,21 @@ async def get_message( before returning. Timeout should be specified as a floating point number. """ + if not self.subscribed: + # Wait for subscription + start_time = asyncio.get_event_loop().time() + + async with async_timeout.timeout(timeout): + if await self.subscribed_event.wait() is True: + # The connection was subscribed during the timeout time frame. + # The timeout should be adjusted based on the time spent + # waiting for the subscription + time_spent = asyncio.get_event_loop().time() - start_time + timeout = max(0.0, timeout - time_spent) + else: + # The connection isn't subscribed to any channels or patterns, + # so no messages are available + return None response = await self.parse_response(block=False, timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) @@ -4224,6 +4284,10 @@ def handle_message(self, response, ignore_subscribe_messages=False): if channel in self.pending_unsubscribe_channels: self.pending_unsubscribe_channels.remove(channel) self.channels.pop(channel, None) + if not self.channels and not self.patterns: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() if message_type in self.PUBLISH_MESSAGE_TYPES: # if there's a message handler, invoke it diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 935f9cae5..f9dfd42f8 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1,6 +1,5 @@ import asyncio -import threading -import time +from unittest.mock import patch import pytest @@ -343,15 +342,6 @@ async def test_unicode_pattern_message_handler(self, r): "pmessage", channel, "test message", pattern=pattern ) - async def test_get_message_without_subscribe(self, r): - p = r.pubsub() - with pytest.raises(RuntimeError) as info: - await p.get_message() - expect = ( - "connection not set: " "did you forget to call subscribe() or psubscribe()?" - ) - assert expect in info.exconly() - class TestPubSubAutoDecoding: """These tests only validate that we get unicode values back""" @@ -553,6 +543,37 @@ async def test_get_message_with_timeout_returns_none(self, r): assert await wait_for_message(p) == make_message("subscribe", "foo", 1) assert await p.get_message(timeout=0.01) is None + def test_get_message_not_subscribed_return_none(self, r): + p = r.pubsub() + assert p.subscribed is False + assert await p.get_message() is None + assert await p.get_message(timeout=0.1) is None + with patch.object(asyncio.Event, "wait") as mock: + mock.return_value = False + assert await p.get_message(timeout=0.01) is None + assert mock.called + + def test_get_message_subscribe_during_waiting(self, r): + p = r.pubsub() + + async def poll(ps, expected_res): + assert await ps.get_message() is None + message = await ps.get_message(timeout=1) + assert message == expected_res + + subscribe_response = make_message("subscribe", "foo", 1) + asyncio.create_task(poll(p, subscribe_response)) + await asyncio.sleep(0.2) + await p.subscribe("foo") + + def test_get_message_wait_for_subscription_not_being_called(self, r): + p = r.pubsub() + await p.subscribe("foo") + with patch.object(asyncio.Event, "wait") as mock: + assert p.subscribed is True + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert mock.called is False + class TestPubSubRun: async def _subscribe(self, p, *args, **kwargs): From fa99cff8ac043fb5a16d2572b98fd1fa574ba913 Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Thu, 23 Dec 2021 23:35:11 -0500 Subject: [PATCH 06/10] Fix mypy issue Signed-off-by: Andrew-Chen-Wang --- aioredis/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aioredis/client.py b/aioredis/client.py index a25107893..77e0b3ee1 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -4030,6 +4030,8 @@ async def clean_health_check_responses(self): """ ttl = 10 conn = self.connection + if not conn: + return while self.health_check_response_counter > 0 and ttl > 0: if await self._execute(conn, conn.can_read, timeout=conn.socket_timeout): response = await self._execute(conn, conn.read_response) From fec374901b7db4c19dbfb376f7017b3efb91ba04 Mon Sep 17 00:00:00 2001 From: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> Date: Thu, 23 Dec 2021 23:55:33 -0500 Subject: [PATCH 07/10] Fix tests' syntax --- tests/test_pubsub.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index f9dfd42f8..c2be8b5b8 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -543,7 +543,7 @@ async def test_get_message_with_timeout_returns_none(self, r): assert await wait_for_message(p) == make_message("subscribe", "foo", 1) assert await p.get_message(timeout=0.01) is None - def test_get_message_not_subscribed_return_none(self, r): + async def test_get_message_not_subscribed_return_none(self, r): p = r.pubsub() assert p.subscribed is False assert await p.get_message() is None @@ -553,7 +553,7 @@ def test_get_message_not_subscribed_return_none(self, r): assert await p.get_message(timeout=0.01) is None assert mock.called - def test_get_message_subscribe_during_waiting(self, r): + async def test_get_message_subscribe_during_waiting(self, r): p = r.pubsub() async def poll(ps, expected_res): @@ -566,7 +566,7 @@ async def poll(ps, expected_res): await asyncio.sleep(0.2) await p.subscribe("foo") - def test_get_message_wait_for_subscription_not_being_called(self, r): + async def test_get_message_wait_for_subscription_not_being_called(self, r): p = r.pubsub() await p.subscribe("foo") with patch.object(asyncio.Event, "wait") as mock: From a0167c9bb85241287c34428a365ab134cb86e6c1 Mon Sep 17 00:00:00 2001 From: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> Date: Fri, 24 Dec 2021 00:09:26 -0500 Subject: [PATCH 08/10] Use reset in PubSub init --- aioredis/client.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index 77e0b3ee1..eb06feef6 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -3925,6 +3925,11 @@ class PubSub: PUBLISH_MESSAGE_TYPES = ("message", "pmessage") UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") HEALTH_CHECK_MESSAGE = "aioredis-py-health-check" + + channels: Dict[ChannelT, PubSubHandler] + pending_unsubscribe_channels: Set[ChannelT] + patterns: Dict[ChannelT, PubSubHandler] + pending_unsubscribe_patterns: Set[ChannelT] def __init__( self, @@ -3948,11 +3953,8 @@ def __init__( ] else: self.health_check_response = [b"pong", self.health_check_response_b] - self.channels: Dict[ChannelT, PubSubHandler] = {} - self.pending_unsubscribe_channels: Set[ChannelT] = set() - self.patterns: Dict[ChannelT, PubSubHandler] = {} - self.pending_unsubscribe_patterns: Set[ChannelT] = set() self._lock = asyncio.Lock() + await self.reset() async def __aenter__(self): return self From ce155b7398d06b7112a80237c1d3aa55f023c1c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Dec 2021 05:09:44 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aioredis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioredis/client.py b/aioredis/client.py index eb06feef6..cc67b7046 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -3925,7 +3925,7 @@ class PubSub: PUBLISH_MESSAGE_TYPES = ("message", "pmessage") UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") HEALTH_CHECK_MESSAGE = "aioredis-py-health-check" - + channels: Dict[ChannelT, PubSubHandler] pending_unsubscribe_channels: Set[ChannelT] patterns: Dict[ChannelT, PubSubHandler] From d4fe77c7cb470c96427b3b591ad5bffa1f141cbc Mon Sep 17 00:00:00 2001 From: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> Date: Fri, 24 Dec 2021 00:12:05 -0500 Subject: [PATCH 10/10] Add additional reset lines to PubSub init --- aioredis/client.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index cc67b7046..b822586e8 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -3926,11 +3926,6 @@ class PubSub: UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") HEALTH_CHECK_MESSAGE = "aioredis-py-health-check" - channels: Dict[ChannelT, PubSubHandler] - pending_unsubscribe_channels: Set[ChannelT] - patterns: Dict[ChannelT, PubSubHandler] - pending_unsubscribe_patterns: Set[ChannelT] - def __init__( self, connection_pool: ConnectionPool, @@ -3953,8 +3948,13 @@ def __init__( ] else: self.health_check_response = [b"pong", self.health_check_response_b] + self.channels: Dict[ChannelT, PubSubHandler] = {} + self.pending_unsubscribe_channels: Set[ChannelT] = set() + self.patterns: Dict[ChannelT, PubSubHandler] = {} + self.pending_unsubscribe_patterns: Set[ChannelT] = set() self._lock = asyncio.Lock() - await self.reset() + self.health_check_response_counter = 0 + self.subscribed_event.clear() async def __aenter__(self): return self