Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 2540: Synchronise concurrent command calls to single-client to single-client mode #2568

Merged
merged 1 commit into from
Jan 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,22 @@ def __init__(

self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS)

# If using a single connection client, we need to lock creation-of and use-of
# the client in order to avoid race conditions such as using asyncio.gather
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()

def __repr__(self):
return f"{self.__class__.__name__}<{self.connection_pool!r}>"

def __await__(self):
return self.initialize().__await__()

async def initialize(self: _RedisT) -> _RedisT:
if self.single_connection_client and self.connection is None:
self.connection = await self.connection_pool.get_connection("_")
if self.single_connection_client:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
Expand Down Expand Up @@ -501,6 +508,8 @@ async def execute_command(self, *args, **options):
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
Expand All @@ -509,6 +518,8 @@ async def execute_command(self, *args, **options):
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)

Expand Down
45 changes: 45 additions & 0 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import redis
from redis.asyncio import Redis
from redis.asyncio.connection import (
BaseParser,
Connection,
Expand Down Expand Up @@ -41,6 +42,50 @@ async def test_invalid_response(create_redis):
await r.connection.disconnect()


@pytest.mark.onlynoncluster
async def test_single_connection():
"""Test that concurrent requests on a single client are synchronised."""
r = Redis(single_connection_client=True)

init_call_count = 0
command_call_count = 0
in_use = False

class Retry_:
async def call_with_retry(self, _, __):
# If we remove the single-client lock, this error gets raised as two
# coroutines will be vying for the `in_use` flag due to the two
# asymmetric sleep calls
nonlocal command_call_count
nonlocal in_use
if in_use is True:
raise ValueError("Commands should be executed one at a time.")
in_use = True
await asyncio.sleep(0.01)
command_call_count += 1
await asyncio.sleep(0.03)
in_use = False
return "foo"

mock_conn = mock.MagicMock()
mock_conn.retry = Retry_()

async def get_conn(_):
# Validate only one client is created in single-client mode when
# concurrent requests are made
nonlocal init_call_count
await asyncio.sleep(0.01)
init_call_count += 1
return mock_conn

with mock.patch.object(r.connection_pool, "get_connection", get_conn):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only deterministic test to validate the locking does what we want that I could think of required patching the connection and call_with_retry objects. This test is consistent, and would fail without the client changes.

with mock.patch.object(r.connection_pool, "release"):
await asyncio.gather(r.set("a", "b"), r.set("c", "d"))

assert init_call_count == 1
assert command_call_count == 2


@skip_if_server_version_lt("4.0.0")
@pytest.mark.redismod
@pytest.mark.onlynoncluster
Expand Down