From 52126f6e6d9934906ae74b5fc0c055904914b7fe Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Mar 2025 16:48:12 -0700 Subject: [PATCH 1/4] Validate passed-in Redis clients --- redisvl/index/index.py | 18 ++++++++++- redisvl/redis/connection.py | 2 +- tests/integration/test_async_search_index.py | 34 +++++++++++++++++--- tests/integration/test_search_index.py | 26 ++++++++++++++- 4 files changed, 73 insertions(+), 7 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 9bec205c..f0898191 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -272,6 +272,11 @@ def __init__( if not isinstance(schema, IndexSchema): raise ValueError("Must provide a valid IndexSchema object") + if redis_client: + RedisConnectionFactory.validate_sync_redis( + redis_client, required_modules=REQUIRED_MODULES_FOR_INTROSPECTION + ) + self.schema = schema self._lib_name: Optional[str] = kwargs.pop("lib_name", None) @@ -282,6 +287,7 @@ def __init__( self._connection_kwargs = connection_kwargs or {} self._lock = threading.Lock() + self._validated_client = False self._owns_redis_client = redis_client is None if self._owns_redis_client: weakref.finalize(self, self.disconnect) @@ -361,6 +367,12 @@ def _redis_client(self) -> Optional[redis.Redis]: redis_url=self._redis_url, **self._connection_kwargs, ) + if not self._validated_client: + RedisConnectionFactory.validate_sync_redis( + self.__redis_client, + self._lib_name, + ) + self._validated_client = True return self.__redis_client @deprecated_function("connect", "Pass connection parameters in __init__.") @@ -851,6 +863,7 @@ def __init__( self.schema = schema self._lib_name: Optional[str] = kwargs.pop("lib_name", None) + self._validated_client = False # Store connection parameters self._redis_client = redis_client @@ -954,9 +967,12 @@ async def _get_client(self) -> aredis.Redis: self._redis_client = ( await RedisConnectionFactory._get_aredis_connection(**kwargs) ) + if not self._validated_client: await RedisConnectionFactory.validate_async_redis( - self._redis_client, self._lib_name + self._redis_client, + self._lib_name, ) + self._validated_client = True return self._redis_client async def _validate_client( diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index a558fe4b..9690fb56 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -159,7 +159,7 @@ def validate_modules( required_modules: List of required modules. Raises: - ValueError: If required Redis modules are not installed. + RedisModuleVersionError: If required Redis modules are not installed. """ required_modules = required_modules or DEFAULT_REQUIRED_MODULES diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 2dd97e36..d1b42235 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,10 +1,11 @@ import warnings +from unittest import mock import pytest from redis import Redis as SyncRedis -from redis.asyncio import Redis +from redis.asyncio import Redis as AsyncRedis -from redisvl.exceptions import RedisSearchError +from redisvl.exceptions import RedisModuleVersionError, RedisSearchError from redisvl.index import AsyncSearchIndex from redisvl.query import VectorQuery from redisvl.redis.utils import convert_bytes @@ -172,12 +173,12 @@ async def test_search_index_set_client(client, redis_url, index_schema): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) await async_index.create(overwrite=True, drop=True) - assert isinstance(async_index.client, Redis) + assert isinstance(async_index.client, AsyncRedis) # Tests deprecated sync -> async conversion behavior assert isinstance(client, SyncRedis) await async_index.set_client(client) - assert isinstance(async_index.client, Redis) + assert isinstance(async_index.client, AsyncRedis) await async_index.disconnect() assert async_index.client is None @@ -410,3 +411,28 @@ async def test_search_index_that_owns_client_disconnect_sync(index_schema, redis await async_index.create(overwrite=True, drop=True) await async_index.disconnect() assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_async_search_index_validates_redis_modules(redis_url): + """ + A regression test for RAAE-694: we should validate that a passed-in + Redis client has the correct modules installed. + """ + client = AsyncRedis.from_url(redis_url) + with mock.patch( + "redisvl.index.index.RedisConnectionFactory.validate_async_redis" + ) as mock_validate_async_redis: + mock_validate_async_redis.side_effect = RedisModuleVersionError( + "Required modules not installed" + ) + with pytest.raises(RedisModuleVersionError): + index = AsyncSearchIndex( + schema=IndexSchema.from_dict( + {"index": {"name": "my_index"}, "fields": fields} + ), + redis_client=client, + ) + await index.create(overwrite=True, drop=True) + + mock_validate_async_redis.assert_called_once() diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index fab4a591..219af206 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -1,8 +1,10 @@ import warnings +from unittest import mock import pytest +from redis import Redis -from redisvl.exceptions import RedisSearchError +from redisvl.exceptions import RedisModuleVersionError, RedisSearchError from redisvl.index import SearchIndex from redisvl.query import VectorQuery from redisvl.redis.utils import convert_bytes @@ -363,3 +365,25 @@ def test_search_index_that_owns_client_disconnect(index_schema, redis_url): index.create(overwrite=True, drop=True) index.disconnect() assert index.client is None + + +def test_search_index_validates_redis_modules(redis_url): + """ + A regression test for RAAE-694: we should validate that a passed-in + Redis client has the correct modules installed. + """ + client = Redis.from_url(redis_url) + with mock.patch( + "redisvl.index.index.RedisConnectionFactory.validate_sync_redis" + ) as mock_validate_sync_redis: + mock_validate_sync_redis.side_effect = RedisModuleVersionError( + "Required modules not installed" + ) + with pytest.raises(RedisModuleVersionError): + SearchIndex( + schema=IndexSchema.from_dict( + {"index": {"name": "my_index"}, "fields": fields} + ), + redis_client=client, + ) + mock_validate_sync_redis.assert_called_once() From 4b7713f4c8b381a39be2e61ec17e96cc9901f62b Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Mar 2025 16:50:40 -0700 Subject: [PATCH 2/4] Remove validation from __init__ --- redisvl/index/index.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index f0898191..878a3671 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -272,11 +272,6 @@ def __init__( if not isinstance(schema, IndexSchema): raise ValueError("Must provide a valid IndexSchema object") - if redis_client: - RedisConnectionFactory.validate_sync_redis( - redis_client, required_modules=REQUIRED_MODULES_FOR_INTROSPECTION - ) - self.schema = schema self._lib_name: Optional[str] = kwargs.pop("lib_name", None) From dc381692c3eb70b14c7af6a343125fc749e3e3bd Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Mar 2025 16:52:49 -0700 Subject: [PATCH 3/4] Initialize variables in the same place --- redisvl/index/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 878a3671..f52de03d 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -858,7 +858,6 @@ def __init__( self.schema = schema self._lib_name: Optional[str] = kwargs.pop("lib_name", None) - self._validated_client = False # Store connection parameters self._redis_client = redis_client @@ -866,6 +865,7 @@ def __init__( self._connection_kwargs = connection_kwargs or {} self._lock = asyncio.Lock() + self._validated_client = False self._owns_redis_client = redis_client is None if self._owns_redis_client: weakref.finalize(self, sync_wrapper(self.disconnect)) From 232747b0841006bed96ad5816e69c455ff9a36fa Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Mar 2025 17:03:04 -0700 Subject: [PATCH 4/4] Fix test --- tests/integration/test_search_index.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index 219af206..368c048a 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -380,10 +380,12 @@ def test_search_index_validates_redis_modules(redis_url): "Required modules not installed" ) with pytest.raises(RedisModuleVersionError): - SearchIndex( + index = SearchIndex( schema=IndexSchema.from_dict( {"index": {"name": "my_index"}, "fields": fields} ), redis_client=client, ) + index.create(overwrite=True, drop=True) + mock_validate_sync_redis.assert_called_once()