Skip to content

Commit

Permalink
Merge pull request #128 from ska-sa/aioredis-1206-workaround
Browse files Browse the repository at this point in the history
Make aio wait_key more robust
  • Loading branch information
bmerry authored Nov 17, 2021
2 parents 0b042ae + 326be8e commit 2d25852
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
22 changes: 17 additions & 5 deletions katsdptelstate/aio/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
_QueueItem = Tuple[bytes, Optional[bytes]]
# Placeholder pub/sub channel. _run_pubsub subscribes to it so that there is
# always at least one active subscription (a workaround for
# https://github.com/aio-libs/aioredis-py/issues/1206), and _handle_message
# silently ignores any message received on it.
# Note: this must be valid UTF-8, because aioredis decodes it if it needs to
# reconnect to the server.
_DUMMY_CHANNEL = b'\0katsdptelstate-internal0\001'


@contextlib.contextmanager
Expand Down Expand Up @@ -225,6 +228,13 @@ def _handle_message(self, message: Dict[str, Any]) -> None:
"""Process a message received via pub/sub."""
logger.debug('Received pub/sub message %s', message)
channel_name = message['channel']
if channel_name == _DUMMY_CHANNEL:
return
if isinstance(channel_name, int):
# Extra workaround for
# https://github.com/aio-libs/aioredis-py/issues/1206,
# although subscribing to _DUMMY_CHANNEL should be sufficient.
return
assert channel_name.startswith(b'update/')
key = channel_name[7:]
channel = self._channels.get(key)
Expand Down Expand Up @@ -289,12 +299,14 @@ async def _run_pubsub(self) -> None:
"""
try:
loop = asyncio.get_event_loop()
# Ensure we are always subscribed to something, as a workaround for
# https://github.com/aio-libs/aioredis-py/issues/1206.
await self._pubsub.subscribe(_DUMMY_CHANNEL)
get_message_task: Optional[asyncio.Task] = None
command_queue_task: asyncio.Task = loop.create_task(self._commands.get())
tasks: Set[asyncio.Future] = {command_queue_task}
while True:
# get_message raises an error if we try this before the first
# connection attempt.
# get_message raises an error if we try this when not connected.
if get_message_task is None and self._pubsub.connection:
# Using a small timeout ensures that health checks get run
get_message_task = loop.create_task(self._pubsub.get_message(timeout=1))
Expand All @@ -304,7 +316,7 @@ async def _run_pubsub(self) -> None:
if get_message_task is not None and get_message_task in done:
try:
message = await get_message_task
except aioredis.ConnectionError as exc:
except (ConnectionError, aioredis.ConnectionError) as exc:
message = None
logger.warning('redis connection error (%s), trying to reconnect', exc)
# aioredis doesn't automatically reconnect
Expand All @@ -313,10 +325,10 @@ async def _run_pubsub(self) -> None:
if self._pubsub.connection is not None:
await self._pubsub.connection.disconnect()
await self._pubsub.connection.connect()
except aioredis.ConnectionError as exc:
except (ConnectionError, aioredis.ConnectionError) as exc:
# Avoid spamming the server with connection attempts
logger.warning('redis reconnect attempt failed (%s), trying in 1s', exc)
await asyncio.sleep(1)
await asyncio.sleep(1)
finally:
# Causes a new task to be created on the next iteration
get_message_task = None
Expand Down
6 changes: 3 additions & 3 deletions katsdptelstate/aio/test/test_telescope_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from katsdptelstate import ImmutableKeyError, encode_value, KeyType, ENCODING_MSGPACK
from katsdptelstate.aio import TelescopeState
from katsdptelstate.aio.memory import MemoryBackend
from katsdptelstate.aio.redis import RedisBackend
from katsdptelstate.aio.redis import RedisBackend, _DUMMY_CHANNEL


class TestTelescopeState(asynctest.TestCase):
Expand Down Expand Up @@ -388,7 +388,7 @@ async def test_wait_key_concurrent(self) -> None:
# so we need to sleep a bit to let them take place.
if isinstance(self.ts.backend, RedisBackend):
await asyncio.sleep(0.1)
self.assertEqual(self.ts.backend._pubsub.channels, {})
self.assertEqual(self.ts.backend._pubsub.channels, {_DUMMY_CHANNEL: None})

async def test_wait_key_concurrent_same(self) -> None:
task1 = asyncio.ensure_future(self.ts.wait_key('key'))
Expand All @@ -403,7 +403,7 @@ async def test_wait_key_concurrent_same(self) -> None:
# so we need to sleep a bit to let them take place.
if isinstance(self.ts.backend, RedisBackend):
await asyncio.sleep(0.1)
self.assertEqual(self.ts.backend._pubsub.channels, {})
self.assertEqual(self.ts.backend._pubsub.channels, {_DUMMY_CHANNEL: None})

async def test_wait_indexed_already_done(self) -> None:
await self.ts.set_indexed('test_key', 'sub_key', 5)
Expand Down

0 comments on commit 2d25852

Please sign in to comment.