diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 9bec205c..f52de03d 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -282,6 +282,7 @@ def __init__( self._connection_kwargs = connection_kwargs or {} self._lock = threading.Lock() + self._validated_client = False self._owns_redis_client = redis_client is None if self._owns_redis_client: weakref.finalize(self, self.disconnect) @@ -361,6 +362,12 @@ def _redis_client(self) -> Optional[redis.Redis]: redis_url=self._redis_url, **self._connection_kwargs, ) + if not self._validated_client: + RedisConnectionFactory.validate_sync_redis( + self.__redis_client, + self._lib_name, + ) + self._validated_client = True return self.__redis_client @deprecated_function("connect", "Pass connection parameters in __init__.") @@ -858,6 +865,7 @@ def __init__( self._connection_kwargs = connection_kwargs or {} self._lock = asyncio.Lock() + self._validated_client = False self._owns_redis_client = redis_client is None if self._owns_redis_client: weakref.finalize(self, sync_wrapper(self.disconnect)) @@ -954,9 +962,12 @@ async def _get_client(self) -> aredis.Redis: self._redis_client = ( await RedisConnectionFactory._get_aredis_connection(**kwargs) ) + if not self._validated_client: await RedisConnectionFactory.validate_async_redis( - self._redis_client, self._lib_name + self._redis_client, + self._lib_name, ) + self._validated_client = True return self._redis_client async def _validate_client( diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index a558fe4b..9690fb56 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -159,7 +159,7 @@ def validate_modules( required_modules: List of required modules. Raises: - ValueError: If required Redis modules are not installed. + RedisModuleVersionError: If required Redis modules are not installed. """ required_modules = required_modules or DEFAULT_REQUIRED_MODULES diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 2dd97e36..d1b42235 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,10 +1,11 @@ import warnings +from unittest import mock import pytest from redis import Redis as SyncRedis -from redis.asyncio import Redis +from redis.asyncio import Redis as AsyncRedis -from redisvl.exceptions import RedisSearchError +from redisvl.exceptions import RedisModuleVersionError, RedisSearchError from redisvl.index import AsyncSearchIndex from redisvl.query import VectorQuery from redisvl.redis.utils import convert_bytes @@ -172,12 +173,12 @@ async def test_search_index_set_client(client, redis_url, index_schema): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) await async_index.create(overwrite=True, drop=True) - assert isinstance(async_index.client, Redis) + assert isinstance(async_index.client, AsyncRedis) # Tests deprecated sync -> async conversion behavior assert isinstance(client, SyncRedis) await async_index.set_client(client) - assert isinstance(async_index.client, Redis) + assert isinstance(async_index.client, AsyncRedis) await async_index.disconnect() assert async_index.client is None @@ -410,3 +411,28 @@ async def test_search_index_that_owns_client_disconnect_sync(index_schema, redis await async_index.create(overwrite=True, drop=True) await async_index.disconnect() assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_async_search_index_validates_redis_modules(redis_url): + """ + A regression test for RAAE-694: we should validate that a passed-in + Redis client has the correct modules installed. + """ + client = AsyncRedis.from_url(redis_url) + with mock.patch( + "redisvl.index.index.RedisConnectionFactory.validate_async_redis" + ) as mock_validate_async_redis: + mock_validate_async_redis.side_effect = RedisModuleVersionError( + "Required modules not installed" + ) + with pytest.raises(RedisModuleVersionError): + index = AsyncSearchIndex( + schema=IndexSchema.from_dict( + {"index": {"name": "my_index"}, "fields": fields} + ), + redis_client=client, + ) + await index.create(overwrite=True, drop=True) + + mock_validate_async_redis.assert_called_once() diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index fab4a591..368c048a 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -1,8 +1,10 @@ import warnings +from unittest import mock import pytest +from redis import Redis -from redisvl.exceptions import RedisSearchError +from redisvl.exceptions import RedisModuleVersionError, RedisSearchError from redisvl.index import SearchIndex from redisvl.query import VectorQuery from redisvl.redis.utils import convert_bytes @@ -363,3 +365,27 @@ def test_search_index_that_owns_client_disconnect(index_schema, redis_url): index.create(overwrite=True, drop=True) index.disconnect() assert index.client is None + + +def test_search_index_validates_redis_modules(redis_url): + """ + A regression test for RAAE-694: we should validate that a passed-in + Redis client has the correct modules installed. + """ + client = Redis.from_url(redis_url) + with mock.patch( + "redisvl.index.index.RedisConnectionFactory.validate_sync_redis" + ) as mock_validate_sync_redis: + mock_validate_sync_redis.side_effect = RedisModuleVersionError( + "Required modules not installed" + ) + with pytest.raises(RedisModuleVersionError): + index = SearchIndex( + schema=IndexSchema.from_dict( + {"index": {"name": "my_index"}, "fields": fields} + ), + redis_client=client, + ) + index.create(overwrite=True, drop=True) + + mock_validate_sync_redis.assert_called_once()