Skip to content

Validate passed-in Redis clients #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def __init__(
self._connection_kwargs = connection_kwargs or {}
self._lock = threading.Lock()

self._validated_client = False
self._owns_redis_client = redis_client is None
if self._owns_redis_client:
weakref.finalize(self, self.disconnect)
Expand Down Expand Up @@ -361,6 +362,12 @@ def _redis_client(self) -> Optional[redis.Redis]:
redis_url=self._redis_url,
**self._connection_kwargs,
)
if not self._validated_client:
RedisConnectionFactory.validate_sync_redis(
self.__redis_client,
self._lib_name,
)
self._validated_client = True
return self.__redis_client

@deprecated_function("connect", "Pass connection parameters in __init__.")
Expand Down Expand Up @@ -858,6 +865,7 @@ def __init__(
self._connection_kwargs = connection_kwargs or {}
self._lock = asyncio.Lock()

self._validated_client = False
self._owns_redis_client = redis_client is None
if self._owns_redis_client:
weakref.finalize(self, sync_wrapper(self.disconnect))
Expand Down Expand Up @@ -954,9 +962,12 @@ async def _get_client(self) -> aredis.Redis:
self._redis_client = (
await RedisConnectionFactory._get_aredis_connection(**kwargs)
)
if not self._validated_client:
await RedisConnectionFactory.validate_async_redis(
self._redis_client, self._lib_name
self._redis_client,
self._lib_name,
)
self._validated_client = True
return self._redis_client

async def _validate_client(
Expand Down
2 changes: 1 addition & 1 deletion redisvl/redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def validate_modules(
required_modules: List of required modules.

Raises:
ValueError: If required Redis modules are not installed.
RedisModuleVersionError: If required Redis modules are not installed.
"""
required_modules = required_modules or DEFAULT_REQUIRED_MODULES

Expand Down
34 changes: 30 additions & 4 deletions tests/integration/test_async_search_index.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import warnings
from unittest import mock

import pytest
from redis import Redis as SyncRedis
from redis.asyncio import Redis
from redis.asyncio import Redis as AsyncRedis

from redisvl.exceptions import RedisSearchError
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
from redisvl.index import AsyncSearchIndex
from redisvl.query import VectorQuery
from redisvl.redis.utils import convert_bytes
Expand Down Expand Up @@ -172,12 +173,12 @@ async def test_search_index_set_client(client, redis_url, index_schema):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
await async_index.create(overwrite=True, drop=True)
assert isinstance(async_index.client, Redis)
assert isinstance(async_index.client, AsyncRedis)

# Tests deprecated sync -> async conversion behavior
assert isinstance(client, SyncRedis)
await async_index.set_client(client)
assert isinstance(async_index.client, Redis)
assert isinstance(async_index.client, AsyncRedis)

await async_index.disconnect()
assert async_index.client is None
Expand Down Expand Up @@ -410,3 +411,28 @@ async def test_search_index_that_owns_client_disconnect_sync(index_schema, redis
await async_index.create(overwrite=True, drop=True)
await async_index.disconnect()
assert async_index._redis_client is None


@pytest.mark.asyncio
async def test_async_search_index_validates_redis_modules(redis_url):
"""
A regression test for RAAE-694: we should validate that a passed-in
Redis client has the correct modules installed.
"""
client = AsyncRedis.from_url(redis_url)
with mock.patch(
"redisvl.index.index.RedisConnectionFactory.validate_async_redis"
) as mock_validate_async_redis:
mock_validate_async_redis.side_effect = RedisModuleVersionError(
"Required modules not installed"
)
with pytest.raises(RedisModuleVersionError):
index = AsyncSearchIndex(
schema=IndexSchema.from_dict(
{"index": {"name": "my_index"}, "fields": fields}
),
redis_client=client,
)
await index.create(overwrite=True, drop=True)

mock_validate_async_redis.assert_called_once()
28 changes: 27 additions & 1 deletion tests/integration/test_search_index.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import warnings
from unittest import mock

import pytest
from redis import Redis

from redisvl.exceptions import RedisSearchError
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.redis.utils import convert_bytes
Expand Down Expand Up @@ -363,3 +365,27 @@ def test_search_index_that_owns_client_disconnect(index_schema, redis_url):
index.create(overwrite=True, drop=True)
index.disconnect()
assert index.client is None


def test_search_index_validates_redis_modules(redis_url):
"""
A regression test for RAAE-694: we should validate that a passed-in
Redis client has the correct modules installed.
"""
client = Redis.from_url(redis_url)
with mock.patch(
"redisvl.index.index.RedisConnectionFactory.validate_sync_redis"
) as mock_validate_sync_redis:
mock_validate_sync_redis.side_effect = RedisModuleVersionError(
"Required modules not installed"
)
with pytest.raises(RedisModuleVersionError):
index = SearchIndex(
schema=IndexSchema.from_dict(
{"index": {"name": "my_index"}, "fields": fields}
),
redis_client=client,
)
index.create(overwrite=True, drop=True)

mock_validate_sync_redis.assert_called_once()