From 39f6c4f009cb378125f22243e3719b2211e9b968 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Sun, 6 Aug 2023 12:11:29 +0300 Subject: [PATCH] cherry-picking #3670 --- redis/asyncio/connection.py | 20 ++++++- redis/client.py | 11 +++- redis/cluster.py | 3 + redis/commands/core.py | 7 +++ redis/connection.py | 5 ++ redis/utils.py | 13 ++++ tests/conftest.py | 5 +- tests/test_asyncio/conftest.py | 12 ++++ tests/test_asyncio/test_cluster.py | 2 +- tests/test_asyncio/test_commands.py | 25 +++++++- tests/test_asyncio/test_connection.py | 86 ++------------------------- tests/test_commands.py | 21 +++++++ tests/test_search.py | 2 +- 13 files changed, 121 insertions(+), 91 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d6195e1801..21d74b9276 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -57,7 +57,7 @@ TimeoutError, ) from redis.typing import EncodableT, EncodedT -from redis.utils import HIREDIS_AVAILABLE, str_if_bytes +from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes hiredis = None if HIREDIS_AVAILABLE: @@ -453,6 +453,8 @@ class AbstractConnection: "db", "username", "client_name", + "lib_name", + "lib_version", "credential_provider", "password", "socket_timeout", @@ -491,6 +493,8 @@ def __init__( socket_read_size: int = 65536, health_check_interval: float = 0, client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, redis_connect_func: Optional[ConnectCallbackT] = None, @@ -507,6 +511,8 @@ def __init__( self.pid = os.getpid() self.db = db self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version self.credential_provider = credential_provider self.password = password self.username = username @@ -654,6 +660,18 @@ async def on_connect(self) -> None: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Error setting client name") + try: + # set the library name and version + if self.lib_name: + await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + await self.read_response() + if self.lib_version: + await self.send_command( + "CLIENT", "SETINFO", "LIB-VER", self.lib_version + ) + await self.read_response() + except ResponseError: + pass # if a database is specified, switch to it if self.db: await self.send_command("SELECT", self.db) diff --git a/redis/client.py b/redis/client.py index ab626ccdf4..825f61a063 100755 --- a/redis/client.py +++ b/redis/client.py @@ -27,7 +27,7 @@ ) from redis.lock import Lock from redis.retry import Retry -from redis.utils import safe_str, str_if_bytes +from redis.utils import get_lib_version, safe_str, str_if_bytes SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -643,7 +643,11 @@ def parse_client_info(value): "key1=value1 key2=value2 key3=value3" """ client_info = {} + value = str_if_bytes(value) + if value[-1] == "\n": + value = value[:-1] infos = str_if_bytes(value).split(" ") + infos = value.split(" ") for info in infos: key, value = info.split("=") client_info[key] = value @@ -754,6 +758,7 @@ class AbstractRedis: "CLIENT SETNAME": bool_ok, "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, "CLIENT PAUSE": bool_ok, + "CLIENT SETINFO": bool_ok, "CLIENT GETREDIR": int, "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), "CLUSTER ADDSLOTS": bool_ok, @@ -949,6 +954,8 @@ def __init__( single_connection_client=False, health_check_interval=0, client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), username=None, retry=None, redis_connect_func=None, @@ -999,6 +1006,8 @@ def __init__( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, } diff --git a/redis/cluster.py b/redis/cluster.py index be8e4623a7..e480a82055 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -137,6 +137,8 @@ def parse_cluster_myshardid(resp, **options): "encoding_errors", "errors", "host", + "lib_name", + "lib_version", "max_connections", "nodes_flag", "redis_connect_func", @@ -220,6 +222,7 @@ class AbstractRedisCluster: "CLIENT LIST", "CLIENT SETNAME", "CLIENT GETNAME", + "CLIENT SETINFO", "CONFIG SET", "CONFIG REWRITE", "CONFIG RESETSTAT", diff --git a/redis/commands/core.py b/redis/commands/core.py index 392ddb542c..36344ef587 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -706,6 +706,13 @@ def client_setname(self, name: str, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT SETNAME", name, **kwargs) + def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: + """ + Sets the current connection library name or version + For mor information see https://redis.io/commands/client-setinfo + """ + return self.execute_command("CLIENT SETINFO", attr, value, **kwargs) + def client_unblock( self, client_id: int, error: bool = False, **kwargs ) -> ResponseT: diff --git a/redis/connection.py b/redis/connection.py index bec456c9ce..ebb8baff02 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -39,6 +39,7 @@ CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, + get_lib_version, str_if_bytes, ) @@ -605,6 +606,8 @@ def __init__( socket_read_size=65536, health_check_interval=0, client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), username=None, retry=None, redis_connect_func=None, @@ -628,6 +631,8 @@ def __init__( self.pid = os.getpid() self.db = db self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version self.credential_provider = credential_provider self.password = password self.username = username diff --git a/redis/utils.py b/redis/utils.py index d95e62c042..4582bd7257 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,3 +1,4 @@ +import sys from contextlib import contextmanager from functools import wraps from typing import Any, Dict, Mapping, Union @@ -12,6 +13,10 @@ HIREDIS_AVAILABLE = False HIREDIS_PACK_AVAILABLE = False +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata try: import cryptography # noqa @@ -110,3 +115,11 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def get_lib_version(): + try: + libver = metadata.version("redis") + except metadata.PackageNotFoundError: + libver = "99.99.99" + return libver diff --git a/tests/conftest.py b/tests/conftest.py index 4cd4c3c160..6056146f7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -141,6 +141,7 @@ def pytest_sessionstart(session): enterprise = info["enterprise"] except redis.ConnectionError: # provide optimistic defaults + info = {} version = "10.0.0" arch_bits = 64 cluster_enabled = False @@ -157,9 +158,7 @@ def pytest_sessionstart(session): redismod_url = session.config.getoption("--redismod-url") info = _get_info(redismod_url) REDIS_INFO["modules"] = info["modules"] - except redis.exceptions.ConnectionError: - pass - except KeyError: + except (KeyError, redis.exceptions.ConnectionError): pass if cluster_enabled: diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 121a13b41b..929ee792d8 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -255,3 +255,15 @@ async def __aexit__(self, exc_type, exc_inst, tb): def asynccontextmanager(func): return _asynccontextmanager(func) + + +# helpers to get the connection arguments for this run +@pytest.fixture() +def redis_url(request): + return request.config.getoption("--redis-url") + + +@pytest.fixture() +def connect_args(request): + url = request.config.getoption("--redis-url") + return parse_url(url) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index c41d4a2168..32929f0864 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2186,7 +2186,7 @@ async def test_acl_log( await user_client.hset("{cache}:0", "hkey", "hval") assert isinstance(await r.acl_log(target_nodes=node), list) - assert len(await r.acl_log(target_nodes=node)) == 2 + assert len(await r.acl_log(target_nodes=node)) == 3 assert len(await r.acl_log(count=1, target_nodes=node)) == 1 assert isinstance((await r.acl_log(target_nodes=node))[0], dict) assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0] diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index c0259680c0..d0e4734670 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -112,7 +112,7 @@ async def test_acl_deluser(self, r_teardown): username = "redis-py-user" r = r_teardown(username) - assert await r.acl_deluser(username) == 0 + assert await r.acl_deluser(username) in [0, 1] assert await r.acl_setuser(username, enabled=False, reset=True) assert await r.acl_deluser(username) == 1 @@ -268,7 +268,7 @@ async def test_acl_log(self, r_teardown, create_redis): await user_client.hset("cache:0", "hkey", "hval") assert isinstance(await r.acl_log(), list) - assert len(await r.acl_log()) == 2 + assert len(await r.acl_log()) == 3 assert len(await r.acl_log(count=1)) == 1 assert isinstance((await r.acl_log())[0], dict) assert "client-info" in (await r.acl_log(count=1))[0] @@ -347,6 +347,27 @@ async def test_client_setname(self, r: redis.Redis): assert await r.client_setname("redis_py_test") assert await r.client_getname() == "redis_py_test" + @skip_if_server_version_lt("7.2.0") + async def test_client_setinfo(self, r: redis.Redis): + await r.ping() + info = await r.client_info() + assert info["lib-name"] == "redis-py" + assert info["lib-ver"] == redis.__version__ + assert await r.client_setinfo("lib-name", "test") + assert await r.client_setinfo("lib-ver", "123") + + info = await r.client_info() + assert info["lib-name"] == "test" + assert info["lib-ver"] == "123" + r2 = redis.asyncio.Redis(lib_name="test2", lib_version="1234") + info = await r2.client_info() + assert info["lib-name"] == "test2" + assert info["lib-ver"] == "1234" + r3 = redis.asyncio.Redis(lib_name=None, lib_version=None) + info = await r3.client_info() + assert info["lib-name"] == "" + assert info["lib-ver"] == "" + @skip_if_server_version_lt("2.6.9") @pytest.mark.onlynoncluster async def test_client_kill(self, r: redis.Redis, r2): diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 158b8545e2..3f9c8f5790 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -10,14 +10,12 @@ from redis.asyncio.connection import ( BaseParser, Connection, - HiredisParser, PythonParser, UnixDomainSocketConnection, ) from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt from .compat import mock @@ -126,9 +124,11 @@ async def test_can_run_concurrent_commands(r): assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) -async def test_connect_retry_on_timeout_error(): +async def test_connect_retry_on_timeout_error(connect_args): """Test that the _connect function is retried in case of a timeout""" - conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3)) + conn = Connection( + retry_on_timeout=True, retry=Retry(NoBackoff(), 3), **connect_args + ) origin_connect = conn._connect conn._connect = mock.AsyncMock() @@ -195,84 +195,6 @@ async def test_connection_parse_response_resume(r: redis.Redis): assert i > 0 -@pytest.mark.onlynoncluster -@pytest.mark.parametrize( - "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] -) -async def test_connection_disconect_race(parser_class): - """ - This test reproduces the case in issue #2349 - where a connection is closed while the parser is reading to feed the - internal buffer.The stream `read()` will succeed, but when it returns, - another task has already called `disconnect()` and is waiting for - close to finish. When we attempts to feed the buffer, we will fail - since the buffer is no longer there. - - This test verifies that a read in progress can finish even - if the `disconnect()` method is called. - """ - if parser_class == HiredisParser and not HIREDIS_AVAILABLE: - pytest.skip("Hiredis not available") - - args = {} - args["parser_class"] = parser_class - - conn = Connection(**args) - - cond = asyncio.Condition() - # 0 == initial - # 1 == reader is reading - # 2 == closer has closed and is waiting for close to finish - state = 0 - - # Mock read function, which wait for a close to happen before returning - # Can either be invoked as two `read()` calls (HiredisParser) - # or as a `readline()` followed by `readexact()` (PythonParser) - chunks = [b"$13\r\n", b"Hello, World!\r\n"] - - async def read(_=None): - nonlocal state - async with cond: - if state == 0: - state = 1 # we are reading - cond.notify() - # wait until the closing task has done - await cond.wait_for(lambda: state == 2) - return chunks.pop(0) - - # function closes the connection while reader is still blocked reading - async def do_close(): - nonlocal state - async with cond: - await cond.wait_for(lambda: state == 1) - state = 2 - cond.notify() - await conn.disconnect() - - async def do_read(): - return await conn.read_response() - - reader = mock.AsyncMock() - writer = mock.AsyncMock() - writer.transport = mock.Mock() - writer.transport.get_extra_info.side_effect = None - - # for HiredisParser - reader.read.side_effect = read - # for PythonParser - reader.readline.side_effect = read - reader.readexactly.side_effect = read - - async def open_connection(*args, **kwargs): - return reader, writer - - with patch.object(asyncio, "open_connection", open_connection): - await conn.connect() - - vals = await asyncio.gather(do_read(), do_close()) - assert vals == [b"Hello, World!", None] - - @pytest.mark.onlynoncluster def test_create_single_connection_client_from_url(): client = Redis.from_url("redis://localhost:6379/0?", single_connection_client=True) diff --git a/tests/test_commands.py b/tests/test_commands.py index 2213e81f72..c5f338cd74 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -533,6 +533,27 @@ def test_client_setname(self, r): assert r.client_setname("redis_py_test") assert r.client_getname() == "redis_py_test" + @skip_if_server_version_lt("7.2.0") + def test_client_setinfo(self, r: redis.Redis): + r.ping() + info = r.client_info() + assert info["lib-name"] == "redis-py" + assert info["lib-ver"] == redis.__version__ + assert r.client_setinfo("lib-name", "test") + + assert r.client_setinfo("lib-ver", "123") + info = r.client_info() + assert info["lib-name"] == "test" + assert info["lib-ver"] == "123" + r2 = redis.Redis(lib_name="test2", lib_version="1234") + info = r2.client_info() + assert info["lib-name"] == "test2" + assert info["lib-ver"] == "1234" + r3 = redis.Redis(lib_name=None, lib_version=None) + info = r3.client_info() + assert info["lib-name"] == "" + assert info["lib-ver"] == "" + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.9") def test_client_kill(self, r, r2): diff --git a/tests/test_search.py b/tests/test_search.py index 7a2428151e..cb03cc79c1 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1550,7 +1550,7 @@ def test_search_commands_in_pipeline(client): @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.4.3", "search") def test_dialect_config(modclient: redis.Redis): - assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "1"} + assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "0"} assert modclient.ft().config_set("DEFAULT_DIALECT", 2) assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} with pytest.raises(redis.ResponseError):