From 62596048f5e42fa1c9bc8365f5a19e43f97647fb Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 12 Feb 2025 16:06:55 -0800 Subject: [PATCH 01/22] WIP --- redisvl/index/index.py | 223 ++++++++++++++++++++--------------------- 1 file changed, 108 insertions(+), 115 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 8914b4c9..82bc3406 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -21,6 +21,7 @@ from redis.commands.search.document import Document from redis.commands.search.result import Result from redisvl.query.query import BaseQuery + import redis.asyncio import redis import redis.asyncio as aredis @@ -793,16 +794,29 @@ class AsyncSearchIndex(BaseSearchIndex): """ + # TODO: The `aredis.Redis` type is not working for type checks. + _redis_client: Optional[redis.asyncio.Redis] = None + _redis_url: Optional[str] = None + _redis_kwargs: Dict[str, Any] = {} + def __init__( self, schema: IndexSchema, + *, + redis_url: Optional[str] = None, + redis_client: Optional[aredis.Redis] = None, + redis_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): """Initialize the RedisVL async search index with a schema. Args: schema (IndexSchema): Index schema object. - connection_args (Dict[str, Any], optional): Redis client connection + redis_url (Optional[str], optional): The URL of the Redis server to + connect to. + redis_client (Optional[aredis.Redis], optional): An + instantiated redis client. + redis_kwargs (Dict[str, Any], optional): Redis client connection args. """ # final validation on schema object @@ -813,39 +827,21 @@ def __init__( self._lib_name: Optional[str] = kwargs.pop("lib_name", None) - # set up empty redis connection - self._redis_client: Optional[aredis.Redis] = None - - if "redis_client" in kwargs or "redis_url" in kwargs: - logger.warning( - "Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex" - ) - - atexit.register(self._cleanup_connection) - - def _cleanup_connection(self): - if self._redis_client: - - def run_in_thread(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._redis_client.aclose()) - loop.close() - except RuntimeError: - pass - - # Run cleanup in a background thread to avoid event loop issues - thread = threading.Thread(target=run_in_thread) - thread.start() - thread.join() + # Store connection parameters + if redis_client and redis_url: + raise ValueError("Cannot provide both redis_client and redis_url") + self._redis_client = redis_client + self._redis_url = redis_url + self._redis_kwargs = redis_kwargs or {} + self._lock = asyncio.Lock() + + async def disconnect(self): + """Asynchronously disconnect and cleanup the underlying async redis connection.""" + if self._redis_client is not None: + await self._redis_client.aclose() # type: ignore self._redis_client = None - def disconnect(self): - """Disconnect and cleanup the underlying async redis connection.""" - self._cleanup_connection() - @classmethod async def from_existing( cls, @@ -902,69 +898,59 @@ def client(self) -> Optional[aredis.Redis]: return self._redis_client async def connect(self, redis_url: Optional[str] = None, **kwargs): - """Connect to a Redis instance using the provided `redis_url`, falling - back to the `REDIS_URL` environment variable (if available). + """[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__.""" + import warnings - Note: Additional keyword arguments (`**kwargs`) can be used to provide - extra options specific to the Redis connection. - - Args: - redis_url (Optional[str], optional): The URL of the Redis server to - connect to. If not provided, the method defaults to using the - `REDIS_URL` environment variable. - - Raises: - redis.exceptions.ConnectionError: If the connection to the Redis - server fails. - ValueError: If the Redis URL is not provided nor accessible - through the `REDIS_URL` environment variable. - - .. code-block:: python - - index.connect(redis_url="redis://localhost:6379") - - """ + warnings.warn( + "connect() is deprecated; pass connection parameters in __init__", + DeprecationWarning, + ) client = RedisConnectionFactory.connect( redis_url=redis_url, use_async=True, **kwargs ) return await self.set_client(client) - @setup_async_redis() - async def set_client(self, redis_client: aredis.Redis): - """Manually set the Redis client to use with the search index. - - This method configures the search index to use a specific - Async Redis client. It is useful for cases where an external, - custom-configured client is preferred instead of creating a new one. - - Args: - redis_client (aredis.Redis): An Async Redis - client instance to be used for the connection. - - Raises: - TypeError: If the provided client is not valid. - - .. code-block:: python + async def set_client(self, redis_client: Optional[aredis.Redis]): + """[DEPRECATED] Manually set the Redis client to use with the search index. + This method is deprecated; please provide connection parameters in __init__. + """ + import warnings - import redis.asyncio as aredis - from redisvl.index import AsyncSearchIndex + warnings.warn( + "set_client() is deprecated; pass connection parameters in __init__", + DeprecationWarning, + ) + return await self._set_client(redis_client) - # async Redis client and index - client = aredis.Redis.from_url("redis://localhost:6379") - index = AsyncSearchIndex.from_yaml("schemas/schema.yaml") - await index.set_client(client) + async def _set_client(self, redis_client: Optional[redis.asyncio.Redis]): + """ + Set the Redis client to use with the search index. + NOTE: Remove this method once the deprecation period is over. """ - if isinstance(redis_client, redis.Redis): - print("Setting client and converting from async", flush=True) - self._redis_client = RedisConnectionFactory.sync_to_async_redis( - redis_client - ) - else: + if self._redis_client is not None: + await self._redis_client.aclose() # type: ignore + async with self._lock: self._redis_client = redis_client - return self + async def _get_client(self) -> aredis.Redis: + """Lazily instantiate and return the async Redis client.""" + if self._redis_client is None: + async with self._lock: + # Double-check to protect against concurrent access + if self._redis_client is None: + kwargs = self._redis_kwargs + if self._redis_url: + kwargs["redis_url"] = self._redis_url + self._redis_client = ( + RedisConnectionFactory.get_async_redis_connection(**kwargs) + ) + await RedisConnectionFactory.validate_async_redis( + self._redis_client, self._lib_name + ) + return self._redis_client + async def create(self, overwrite: bool = False, drop: bool = False) -> None: """Asynchronously create an index in Redis with the current schema and properties. @@ -990,6 +976,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: # overwrite an index in Redis; drop associated data (clean slate) await index.create(overwrite=True, drop=True) """ + client = await self._get_client() redis_fields = self.schema.redis_fields if not redis_fields: @@ -1005,7 +992,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: await self.delete(drop) try: - await self._redis_client.ft(self.schema.index.name).create_index( # type: ignore + await client.ft(self.schema.index.name).create_index( fields=redis_fields, definition=IndexDefinition( prefix=[self.schema.index.prefix], index_type=self._storage.type @@ -1025,10 +1012,9 @@ async def delete(self, drop: bool = True): Raises: redis.exceptions.ResponseError: If the index does not exist. """ + client = await self._get_client() try: - await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore - delete_documents=drop - ) + await client.ft(self.schema.index.name).dropindex(delete_documents=drop) except Exception as e: raise RedisSearchError(f"Error while deleting index: {str(e)}") from e @@ -1039,16 +1025,15 @@ async def clear(self) -> int: Returns: int: Count of records deleted from Redis. """ - # Track deleted records + client = await self._get_client() total_records_deleted: int = 0 - # Paginate using queries and delete in batches async for batch in self.paginate( FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500 ): batch_keys = [record["id"] for record in batch] - records_deleted = await self._redis_client.delete(*batch_keys) # type: ignore - total_records_deleted += records_deleted # type: ignore + records_deleted = await client.delete(*batch_keys) + total_records_deleted += records_deleted return total_records_deleted @@ -1061,10 +1046,11 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int: Returns: int: Count of records deleted from Redis. """ - if isinstance(keys, List): - return await self._redis_client.delete(*keys) # type: ignore + client = await self._get_client() + if isinstance(keys, list): + return await client.delete(*keys) else: - return await self._redis_client.delete(keys) # type: ignore + return await client.delete(keys) async def load( self, @@ -1124,9 +1110,10 @@ async def add_field(d): keys = await index.load(data, preprocess=add_field) """ + client = await self._get_client() try: return await self._storage.awrite( - self._redis_client, # type: ignore + client, objects=data, id_field=id_field, keys=keys, @@ -1150,7 +1137,8 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]: Returns: Dict[str, Any]: The fetched object. """ - obj = await self._storage.aget(self._redis_client, [self.key(id)]) # type: ignore + client = await self._get_client() + obj = await self._storage.aget(client, [self.key(id)]) if obj: return convert_bytes(obj[0]) return None @@ -1165,10 +1153,10 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult": Returns: Result: Raw Redis aggregation results. """ + client = await self._get_client() try: - return await self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore - *args, **kwargs - ) + # TODO: Typing + return await client.ft(self.schema.index.name).aggregate(*args, **kwargs) except Exception as e: raise RedisSearchError(f"Error while aggregating: {str(e)}") from e @@ -1182,10 +1170,10 @@ async def search(self, *args, **kwargs) -> "Result": Returns: Result: Raw Redis search results. """ + client = await self._get_client() try: - return await self._redis_client.ft(self.schema.index.name).search( # type: ignore - *args, **kwargs - ) + # TODO: Typing + return await client.ft(self.schema.index.name).search(*args, **kwargs) except Exception as e: raise RedisSearchError(f"Error while searching: {str(e)}") from e @@ -1256,7 +1244,7 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato """ if not isinstance(page_size, int): - raise TypeError("page_size must be an integer") + raise TypeError("page_size must be of type int") if page_size <= 0: raise ValueError("page_size must be greater than 0") @@ -1268,7 +1256,6 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato if not results: break yield results - # increment the pagination tracker first += page_size async def listall(self) -> List[str]: @@ -1277,9 +1264,8 @@ async def listall(self) -> List[str]: Returns: List[str]: The list of indices in the database. """ - return convert_bytes( - await self._redis_client.execute_command("FT._LIST") # type: ignore - ) + client: aredis.Redis = await self._get_client() + return convert_bytes(await client.execute_command("FT._LIST")) async def exists(self) -> bool: """Check if the index exists in Redis. @@ -1289,15 +1275,6 @@ async def exists(self) -> bool: """ return self.schema.index.name in await self.listall() - @staticmethod - async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: - try: - return convert_bytes(await redis_client.ft(name).info()) # type: ignore - except Exception as e: - raise RedisSearchError( - f"Error while fetching {name} index info: {str(e)}" - ) from e - async def info(self, name: Optional[str] = None) -> Dict[str, Any]: """Get information about the index. @@ -1308,5 +1285,21 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]: Returns: dict: A dictionary containing the information about the index. """ + client: aredis.Redis = await self._get_client() index_name = name or self.schema.index.name - return await self._info(index_name, self._redis_client) # type: ignore + return await type(self)._info(index_name, client) + + @staticmethod + async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: + try: + return convert_bytes(await redis_client.ft(name).info()) # type: ignore + except Exception as e: + raise RedisSearchError( + f"Error while fetching {name} index info: {str(e)}" + ) from e + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.disconnect() From af1c2d0a0ba7b61f363be743de0d99f3e9de90d0 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 12 Feb 2025 17:17:09 -0800 Subject: [PATCH 02/22] WIP --- redisvl/extensions/llmcache/semantic.py | 16 ++-- redisvl/index/index.py | 80 ++++++++++---------- redisvl/utils/utils.py | 26 ++++++- tests/integration/test_async_search_index.py | 4 +- tests/unit/test_utils.py | 21 ++++- 5 files changed, 94 insertions(+), 53 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 5553bfa4..ad45eafa 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -174,17 +174,16 @@ def _modify_schema( async def _get_async_index(self) -> AsyncSearchIndex: """Lazily construct the async search index class.""" + # Construct async index if necessary if not self._aindex: - # Construct async index if necessary - self._aindex = AsyncSearchIndex(schema=self._index.schema) - # Connect Redis async client redis_client = self.redis_kwargs["redis_client"] redis_url = self.redis_kwargs["redis_url"] connection_kwargs = self.redis_kwargs["connection_kwargs"] - if redis_client is not None: - await self._aindex.set_client(redis_client) - elif redis_url: - await self._aindex.connect(redis_url, **connection_kwargs) # type: ignore + + self._aindex = AsyncSearchIndex(schema=self._index.schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs) return self._aindex @property @@ -290,7 +289,8 @@ async def _async_refresh_ttl(self, key: str) -> None: """Async refresh the time-to-live for the specified key.""" aindex = await self._get_async_index() if self._ttl: - await aindex.client.expire(key, self._ttl) # type: ignore + client = await aindex.get_client() + await client.expire(key, self._ttl) # type: ignore def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 82bc3406..e11dff7d 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -15,6 +15,9 @@ Optional, Union, ) +import warnings + +from redisvl.utils.utils import deprecated_function if TYPE_CHECKING: from redis.commands.search.aggregation import AggregateResult @@ -839,6 +842,7 @@ def __init__( async def disconnect(self): """Asynchronously disconnect and cleanup the underlying async redis connection.""" if self._redis_client is not None: + print(self._redis_client) await self._redis_client.aclose() # type: ignore self._redis_client = None @@ -860,12 +864,13 @@ async def from_existing( redis_url (Optional[str]): The URL of the Redis server to connect to. """ - if redis_url: - redis_client = RedisConnectionFactory.connect( - redis_url=redis_url, use_async=True, **kwargs - ) - - if not redis_client: + if redis_client and redis_url: + raise ValueError("Cannot provide both redis_client and redis_url") + elif redis_url: + redis_client = RedisConnectionFactory.get_async_redis_connection(url=redis_url, **kwargs) + elif redis_client: + pass + else: raise ValueError( "Must provide either a redis_url or redis_client to fetch Redis index info." ) @@ -888,14 +893,21 @@ async def from_existing( index_info = await cls._info(name, redis_client) schema_dict = convert_index_info_to_schema(index_info) schema = IndexSchema.from_dict(schema_dict) - index = cls(schema, **kwargs) - await index.set_client(redis_client) - return index + return cls(schema, redis_client=redis_client, **kwargs) + @deprecated_function("client", "Use await self.get_client()") @property - def client(self) -> Optional[aredis.Redis]: + def client(self) -> aredis.Redis: """The underlying redis-py client object.""" - return self._redis_client + if self._redis_client is None: + if asyncio.current_task() is not None: + warnings.warn("Risk of deadlock! Use await self.get_client() if you are " + "in an async context.") + redis_client = asyncio.run_coroutine_threadsafe(self.get_client(), asyncio.get_event_loop()) + client = redis_client.result() + else: + client = self._redis_client + return client async def connect(self, redis_url: Optional[str] = None, **kwargs): """[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__.""" @@ -910,23 +922,11 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs): ) return await self.set_client(client) + @deprecated_function("set_client", "Pass connection parameters in __init__") async def set_client(self, redis_client: Optional[aredis.Redis]): - """[DEPRECATED] Manually set the Redis client to use with the search index. - This method is deprecated; please provide connection parameters in __init__. """ - import warnings - - warnings.warn( - "set_client() is deprecated; pass connection parameters in __init__", - DeprecationWarning, - ) - return await self._set_client(redis_client) - - async def _set_client(self, redis_client: Optional[redis.asyncio.Redis]): - """ - Set the Redis client to use with the search index. - - NOTE: Remove this method once the deprecation period is over. + [DEPRECATED] Manually set the Redis client to use with the search index. + This method is deprecated; please provide connection parameters in __init__. """ if self._redis_client is not None: await self._redis_client.aclose() # type: ignore @@ -934,7 +934,7 @@ async def _set_client(self, redis_client: Optional[redis.asyncio.Redis]): self._redis_client = redis_client return self - async def _get_client(self) -> aredis.Redis: + async def get_client(self) -> aredis.Redis: """Lazily instantiate and return the async Redis client.""" if self._redis_client is None: async with self._lock: @@ -976,7 +976,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: # overwrite an index in Redis; drop associated data (clean slate) await index.create(overwrite=True, drop=True) """ - client = await self._get_client() + client = await self.get_client() redis_fields = self.schema.redis_fields if not redis_fields: @@ -1012,7 +1012,7 @@ async def delete(self, drop: bool = True): Raises: redis.exceptions.ResponseError: If the index does not exist. """ - client = await self._get_client() + client = await self.get_client() try: await client.ft(self.schema.index.name).dropindex(delete_documents=drop) except Exception as e: @@ -1025,7 +1025,7 @@ async def clear(self) -> int: Returns: int: Count of records deleted from Redis. """ - client = await self._get_client() + client = await self.get_client() total_records_deleted: int = 0 async for batch in self.paginate( @@ -1046,7 +1046,7 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int: Returns: int: Count of records deleted from Redis. """ - client = await self._get_client() + client = await self.get_client() if isinstance(keys, list): return await client.delete(*keys) else: @@ -1110,7 +1110,7 @@ async def add_field(d): keys = await index.load(data, preprocess=add_field) """ - client = await self._get_client() + client = await self.get_client() try: return await self._storage.awrite( client, @@ -1137,7 +1137,7 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]: Returns: Dict[str, Any]: The fetched object. """ - client = await self._get_client() + client = await self.get_client() obj = await self._storage.aget(client, [self.key(id)]) if obj: return convert_bytes(obj[0]) @@ -1153,10 +1153,9 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult": Returns: Result: Raw Redis aggregation results. """ - client = await self._get_client() + client = await self.get_client() try: - # TODO: Typing - return await client.ft(self.schema.index.name).aggregate(*args, **kwargs) + return client.ft(self.schema.index.name).aggregate(*args, **kwargs) except Exception as e: raise RedisSearchError(f"Error while aggregating: {str(e)}") from e @@ -1170,10 +1169,9 @@ async def search(self, *args, **kwargs) -> "Result": Returns: Result: Raw Redis search results. """ - client = await self._get_client() + client = await self.get_client() try: - # TODO: Typing - return await client.ft(self.schema.index.name).search(*args, **kwargs) + return client.ft(self.schema.index.name).search(*args, **kwargs) except Exception as e: raise RedisSearchError(f"Error while searching: {str(e)}") from e @@ -1264,7 +1262,7 @@ async def listall(self) -> List[str]: Returns: List[str]: The list of indices in the database. """ - client: aredis.Redis = await self._get_client() + client: aredis.Redis = await self.get_client() return convert_bytes(await client.execute_command("FT._LIST")) async def exists(self) -> bool: @@ -1285,7 +1283,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]: Returns: dict: A dictionary containing the information about the index. """ - client: aredis.Redis = await self._get_client() + client: aredis.Redis = await self.get_client() index_name = name or self.schema.index.name return await type(self)._info(index_name, client) diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index da1067d8..02ded426 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -52,7 +52,7 @@ def validate_vector_dims(v1: int, v2: int) -> None: def serialize(data: Dict[str, Any]) -> str: - """Serlize the input into a string.""" + """Serialize the input into a string.""" return json.dumps(data) @@ -88,3 +88,27 @@ def inner(*args, **kwargs): return inner return wrapper + + +def deprecated_function(name: Optional[str] = None, replacement: Optional[str] = None): + """ + Decorator to mark a function as deprecated. + + When the wrapped function is called, the decorator will log a deprecation + warning. + """ + def decorator(func): + fn_name = name or func.__name__ + warning_message = f"Function {fn_name} is deprecated and will be " \ + "removed in the next major release." + if replacement: + warning_message += replacement + + @wraps(func) + def wrapper(*args, **kwargs): + warn(warning_message, category=DeprecationWarning, stacklevel=3) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 9dc8460d..c40790a4 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -137,7 +137,7 @@ async def test_search_index_redis_url(redis_url, index_schema): ) assert async_index.client - async_index.disconnect() + await async_index.disconnect() assert async_index.client == None @@ -155,7 +155,7 @@ async def test_search_index_set_client(async_client, client, async_index): assert async_index.client == async_client await async_index.set_client(client) - async_index.disconnect() + await async_index.disconnect() assert async_index.client == None diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index cd8a3c9f..36531109 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -7,7 +7,7 @@ convert_bytes, make_dict, ) -from redisvl.utils.utils import deprecated_argument +from redisvl.utils.utils import deprecated_argument, deprecated_function def test_even_number_of_elements(): @@ -237,3 +237,22 @@ async def test_func(dtype=None, vectorizer=None): with pytest.warns(DeprecationWarning): result = await test_func(dtype="float32") assert result == 1 + + + +class TestDeprecatedFunction: + def test_deprecated_function_warning(self): + @deprecated_function("new_func", "Use new_func2") + def old_func(): + pass + + with pytest.warns(DeprecationWarning): + old_func() + + def test_deprecated_function_warning_with_name(self): + @deprecated_function("new_func", "Use new_func2") + def old_func(): + pass + + with pytest.warns(DeprecationWarning): + old_func() From c2bf2802f206f512c7fccc49221652792b434dd8 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 12 Feb 2025 22:45:28 -0800 Subject: [PATCH 03/22] More cleanup --- redisvl/extensions/llmcache/semantic.py | 14 ++--- redisvl/index/index.py | 65 ++++++++++---------- redisvl/redis/connection.py | 8 ++- redisvl/utils/utils.py | 2 +- tests/integration/test_async_search_index.py | 8 +-- 5 files changed, 50 insertions(+), 47 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index ad45eafa..abda1f27 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -176,14 +176,12 @@ async def _get_async_index(self) -> AsyncSearchIndex: """Lazily construct the async search index class.""" # Construct async index if necessary if not self._aindex: - redis_client = self.redis_kwargs["redis_client"] - redis_url = self.redis_kwargs["redis_url"] - connection_kwargs = self.redis_kwargs["connection_kwargs"] - - self._aindex = AsyncSearchIndex(schema=self._index.schema, - redis_client=redis_client, - redis_url=redis_url, - **connection_kwargs) + self._aindex = AsyncSearchIndex( + schema=self._index.schema, + redis_client=self.redis_kwargs["redis_client"], + redis_url=self.redis_kwargs["redis_url"], + **self.redis_kwargs["connection_kwargs"] + ) return self._aindex @property diff --git a/redisvl/index/index.py b/redisvl/index/index.py index e11dff7d..d579154e 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -14,9 +14,9 @@ List, Optional, Union, + cast, ) import warnings - from redisvl.utils.utils import deprecated_function if TYPE_CHECKING: @@ -797,11 +797,6 @@ class AsyncSearchIndex(BaseSearchIndex): """ - # TODO: The `aredis.Redis` type is not working for type checks. - _redis_client: Optional[redis.asyncio.Redis] = None - _redis_url: Optional[str] = None - _redis_kwargs: Dict[str, Any] = {} - def __init__( self, schema: IndexSchema, @@ -842,7 +837,6 @@ def __init__( async def disconnect(self): """Asynchronously disconnect and cleanup the underlying async redis connection.""" if self._redis_client is not None: - print(self._redis_client) await self._redis_client.aclose() # type: ignore self._redis_client = None @@ -867,7 +861,9 @@ async def from_existing( if redis_client and redis_url: raise ValueError("Cannot provide both redis_client and redis_url") elif redis_url: - redis_client = RedisConnectionFactory.get_async_redis_connection(url=redis_url, **kwargs) + redis_client = RedisConnectionFactory.get_async_redis_connection( + url=redis_url, **kwargs + ) elif redis_client: pass else: @@ -895,46 +891,36 @@ async def from_existing( schema = IndexSchema.from_dict(schema_dict) return cls(schema, redis_client=redis_client, **kwargs) - @deprecated_function("client", "Use await self.get_client()") @property - def client(self) -> aredis.Redis: + def client(self) -> Optional[aredis.Redis]: """The underlying redis-py client object.""" - if self._redis_client is None: - if asyncio.current_task() is not None: - warnings.warn("Risk of deadlock! Use await self.get_client() if you are " - "in an async context.") - redis_client = asyncio.run_coroutine_threadsafe(self.get_client(), asyncio.get_event_loop()) - client = redis_client.result() - else: - client = self._redis_client - return client + return self._redis_client + @deprecated_function("connect", "Pass connection parameters in __init__.") async def connect(self, redis_url: Optional[str] = None, **kwargs): """[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__.""" - import warnings - warnings.warn( "connect() is deprecated; pass connection parameters in __init__", DeprecationWarning, ) - client = RedisConnectionFactory.connect( + client: redis.asyncio.Redis = RedisConnectionFactory.connect( redis_url=redis_url, use_async=True, **kwargs - ) + ) # type: ignore return await self.set_client(client) - @deprecated_function("set_client", "Pass connection parameters in __init__") - async def set_client(self, redis_client: Optional[aredis.Redis]): + @deprecated_function("set_client", "Pass connection parameters in __init__.") + async def set_client(self, redis_client: aredis.Redis): """ [DEPRECATED] Manually set the Redis client to use with the search index. This method is deprecated; please provide connection parameters in __init__. """ - if self._redis_client is not None: - await self._redis_client.aclose() # type: ignore + redis_client = await self._validate_client(redis_client) + await self.disconnect() async with self._lock: self._redis_client = redis_client return self - async def get_client(self) -> aredis.Redis: + async def _get_client(self): """Lazily instantiate and return the async Redis client.""" if self._redis_client is None: async with self._lock: @@ -951,6 +937,23 @@ async def get_client(self) -> aredis.Redis: ) return self._redis_client + async def get_client(self) -> aredis.Redis: + """Return this index's async Redis client.""" + return await self._get_client() + + async def _validate_client(self, redis_client: aredis.Redis) -> aredis.Redis: + if isinstance(redis_client, redis.Redis): + warnings.warn( + "Converting sync Redis client to async client is deprecated " + "and will be removed in the next major version. Please use an " + "async Redis client instead.", + DeprecationWarning, + ) + redis_client = RedisConnectionFactory.sync_to_async_redis(redis_client) + elif not isinstance(redis_client, aredis.Redis): + raise ValueError("Invalid client type: must be redis.asyncio.Redis") + return redis_client + async def create(self, overwrite: bool = False, drop: bool = False) -> None: """Asynchronously create an index in Redis with the current schema and properties. @@ -1171,7 +1174,7 @@ async def search(self, *args, **kwargs) -> "Result": """ client = await self.get_client() try: - return client.ft(self.schema.index.name).search(*args, **kwargs) + return await client.ft(self.schema.index.name).search(*args, **kwargs) except Exception as e: raise RedisSearchError(f"Error while searching: {str(e)}") from e @@ -1262,7 +1265,7 @@ async def listall(self) -> List[str]: Returns: List[str]: The list of indices in the database. """ - client: aredis.Redis = await self.get_client() + client = await self.get_client() return convert_bytes(await client.execute_command("FT._LIST")) async def exists(self) -> bool: @@ -1283,7 +1286,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]: Returns: dict: A dictionary containing the information about the index. """ - client: aredis.Redis = await self.get_client() + client = await self.get_client() index_name = name or self.schema.index.name return await type(self)._info(index_name, client) diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 6cf6cd4f..693c10d9 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union from redis import Redis from redis.asyncio import Connection as AsyncConnection @@ -12,6 +12,7 @@ from redisvl.exceptions import RedisModuleVersionError from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES from redisvl.redis.utils import convert_bytes +from redisvl.utils.utils import deprecated_function from redisvl.version import __version__ @@ -191,7 +192,7 @@ class RedisConnectionFactory: @classmethod def connect( cls, redis_url: Optional[str] = None, use_async: bool = False, **kwargs - ) -> None: + ) -> Union[Redis, AsyncRedis]: """Create a connection to the Redis database based on a URL and some connection kwargs. @@ -260,6 +261,7 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi # fallback to env var REDIS_URL return AsyncRedis.from_url(get_address_from_env(), **kwargs) + @deprecated_function("sync_to_async_redis", "Please use an async Redis client instead.") @staticmethod def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: # pick the right connection class @@ -267,7 +269,7 @@ def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: AsyncSSLConnection if redis_client.connection_pool.connection_class == SSLConnection else AsyncConnection - ) + ) # type: ignore # make async client return AsyncRedis.from_pool( # type: ignore AsyncConnectionPool( diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 02ded426..8af4aa8b 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -100,7 +100,7 @@ def deprecated_function(name: Optional[str] = None, replacement: Optional[str] = def decorator(func): fn_name = name or func.__name__ warning_message = f"Function {fn_name} is deprecated and will be " \ - "removed in the next major release." + "removed in the next major release. " if replacement: warning_message += replacement diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index c40790a4..5815fbec 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -33,7 +33,7 @@ def test_search_index_properties(index_schema, async_index): assert async_index.schema == index_schema # custom settings assert async_index.name == index_schema.index.name == "my_index" - assert async_index.client == None + assert async_index.client is None # default settings assert async_index.prefix == index_schema.index.prefix == "rvl" assert async_index.key_separator == index_schema.index.key_separator == ":" @@ -45,7 +45,7 @@ def test_search_index_properties(index_schema, async_index): def test_search_index_from_yaml(async_index_from_yaml): assert async_index_from_yaml.name == "json-test" - assert async_index_from_yaml.client == None + assert async_index_from_yaml.client is None assert async_index_from_yaml.prefix == "json" assert async_index_from_yaml.key_separator == ":" assert async_index_from_yaml.storage_type == StorageType.JSON @@ -54,7 +54,7 @@ def test_search_index_from_yaml(async_index_from_yaml): def test_search_index_from_dict(async_index_from_dict): assert async_index_from_dict.name == "my_index" - assert async_index_from_dict.client == None + assert async_index_from_dict.client is None assert async_index_from_dict.prefix == "rvl" assert async_index_from_dict.key_separator == ":" assert async_index_from_dict.storage_type == StorageType.HASH @@ -156,7 +156,7 @@ async def test_search_index_set_client(async_client, client, async_index): await async_index.set_client(client) await async_index.disconnect() - assert async_index.client == None + assert async_index.client is None @pytest.mark.asyncio From c1999680fe7b83755886bf53307db15a33d51c05 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 12 Feb 2025 22:47:27 -0800 Subject: [PATCH 04/22] Format, lint --- redisvl/extensions/llmcache/semantic.py | 2 +- redisvl/index/index.py | 3 ++- redisvl/redis/connection.py | 6 ++++-- redisvl/utils/utils.py | 5 ++++- tests/unit/test_utils.py | 1 - 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index abda1f27..4559179b 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -180,7 +180,7 @@ async def _get_async_index(self) -> AsyncSearchIndex: schema=self._index.schema, redis_client=self.redis_kwargs["redis_client"], redis_url=self.redis_kwargs["redis_url"], - **self.redis_kwargs["connection_kwargs"] + **self.redis_kwargs["connection_kwargs"], ) return self._aindex diff --git a/redisvl/index/index.py b/redisvl/index/index.py index d579154e..79f0cfe2 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -2,6 +2,7 @@ import atexit import json import threading +import warnings from functools import wraps from typing import ( TYPE_CHECKING, @@ -16,7 +17,7 @@ Union, cast, ) -import warnings + from redisvl.utils.utils import deprecated_function if TYPE_CHECKING: diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 693c10d9..414edef5 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -261,7 +261,9 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi # fallback to env var REDIS_URL return AsyncRedis.from_url(get_address_from_env(), **kwargs) - @deprecated_function("sync_to_async_redis", "Please use an async Redis client instead.") + @deprecated_function( + "sync_to_async_redis", "Please use an async Redis client instead." + ) @staticmethod def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: # pick the right connection class @@ -269,7 +271,7 @@ def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: AsyncSSLConnection if redis_client.connection_pool.connection_class == SSLConnection else AsyncConnection - ) # type: ignore + ) # type: ignore # make async client return AsyncRedis.from_pool( # type: ignore AsyncConnectionPool( diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 8af4aa8b..252ba6bf 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -97,10 +97,13 @@ def deprecated_function(name: Optional[str] = None, replacement: Optional[str] = When the wrapped function is called, the decorator will log a deprecation warning. """ + def decorator(func): fn_name = name or func.__name__ - warning_message = f"Function {fn_name} is deprecated and will be " \ + warning_message = ( + f"Function {fn_name} is deprecated and will be " "removed in the next major release. " + ) if replacement: warning_message += replacement diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 36531109..a82431a1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -239,7 +239,6 @@ async def test_func(dtype=None, vectorizer=None): assert result == 1 - class TestDeprecatedFunction: def test_deprecated_function_warning(self): @deprecated_function("new_func", "Use new_func2") From d27c36ea848f56798c1b6889156aa55c9870bb19 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 13 Feb 2025 12:29:52 -0800 Subject: [PATCH 05/22] ignore our own deprecation warnings --- tests/integration/test_async_search_index.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 5815fbec..c6b37d6b 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -8,6 +8,19 @@ fields = [{"name": "test", "type": "tag"}] +# Remove deprecation warnings after the next major release +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:connect\\(\\) is deprecated; pass connection parameters in __init__:DeprecationWarning" + ), + pytest.mark.filterwarnings( + "ignore:Converting sync Redis client to async client is deprecated.*:DeprecationWarning" + ), + pytest.mark.filterwarnings( + "ignore:Function .* is deprecated and will be removed in the next major release.*:DeprecationWarning" + ), +] + @pytest.fixture def index_schema(): From 49bf7c4c48ca8c41878fcfa0cb05e4d32defc8fe Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 13 Feb 2025 17:30:38 -0800 Subject: [PATCH 06/22] Put key expiration behind an interface --- redisvl/extensions/llmcache/semantic.py | 5 +- redisvl/index/index.py | 62 ++++++++++++++++++------- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 4559179b..0062ef75 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -281,14 +281,13 @@ async def adrop( def _refresh_ttl(self, key: str) -> None: """Refresh the time-to-live for the specified key.""" if self._ttl: - self._index.client.expire(key, self._ttl) # type: ignore + self._index.expire_keys(key, self._ttl) async def _async_refresh_ttl(self, key: str) -> None: """Async refresh the time-to-live for the specified key.""" aindex = await self._get_async_index() if self._ttl: - client = await aindex.get_client() - await client.expire(key, self._ttl) # type: ignore + await aindex.expire_keys(key, self._ttl) def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 79f0cfe2..72d7511c 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,7 +1,5 @@ import asyncio -import atexit import json -import threading import warnings from functools import wraps from typing import ( @@ -15,7 +13,6 @@ List, Optional, Union, - cast, ) from redisvl.utils.utils import deprecated_function @@ -524,6 +521,23 @@ def drop_keys(self, keys: Union[str, List[str]]) -> int: else: return self._redis_client.delete(keys) # type: ignore + def expire_keys( + self, keys: Union[str, List[str]], ttl: int + ) -> Union[int, List[int]]: + """Set the expiration time for a specific entry or entries in Redis. + + Args: + keys (Union[str, List[str]]): The document ID or IDs to set the expiration for. + ttl (int): The time-to-live in seconds. + """ + if isinstance(keys, list): + pipe = self._redis_client.pipeline() # type: ignore + for key in keys: + pipe.expire(key, ttl) + return pipe.execute() + else: + return self._redis_client.expire(keys, ttl) # type: ignore + def load( self, data: Iterable[Any], @@ -938,10 +952,6 @@ async def _get_client(self): ) return self._redis_client - async def get_client(self) -> aredis.Redis: - """Return this index's async Redis client.""" - return await self._get_client() - async def _validate_client(self, redis_client: aredis.Redis) -> aredis.Redis: if isinstance(redis_client, redis.Redis): warnings.warn( @@ -980,7 +990,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: # overwrite an index in Redis; drop associated data (clean slate) await index.create(overwrite=True, drop=True) """ - client = await self.get_client() + client = await self._get_client() redis_fields = self.schema.redis_fields if not redis_fields: @@ -1016,7 +1026,7 @@ async def delete(self, drop: bool = True): Raises: redis.exceptions.ResponseError: If the index does not exist. """ - client = await self.get_client() + client = await self._get_client() try: await client.ft(self.schema.index.name).dropindex(delete_documents=drop) except Exception as e: @@ -1029,7 +1039,7 @@ async def clear(self) -> int: Returns: int: Count of records deleted from Redis. """ - client = await self.get_client() + client = await self._get_client() total_records_deleted: int = 0 async for batch in self.paginate( @@ -1050,12 +1060,30 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int: Returns: int: Count of records deleted from Redis. """ - client = await self.get_client() + client = await self._get_client() if isinstance(keys, list): return await client.delete(*keys) else: return await client.delete(keys) + async def expire_keys( + self, keys: Union[str, List[str]], ttl: int + ) -> Union[int, List[int]]: + """Set the expiration time for a specific entry or entries in Redis. + + Args: + keys (Union[str, List[str]]): The document ID or IDs to set the expiration for. + ttl (int): The time-to-live in seconds. + """ + client = await self._get_client() + if isinstance(keys, list): + pipe = client.pipeline() + for key in keys: + pipe.expire(key, ttl) + return await pipe.execute() + else: + return await client.expire(keys, ttl) + async def load( self, data: Iterable[Any], @@ -1114,7 +1142,7 @@ async def add_field(d): keys = await index.load(data, preprocess=add_field) """ - client = await self.get_client() + client = await self._get_client() try: return await self._storage.awrite( client, @@ -1141,7 +1169,7 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]: Returns: Dict[str, Any]: The fetched object. """ - client = await self.get_client() + client = await self._get_client() obj = await self._storage.aget(client, [self.key(id)]) if obj: return convert_bytes(obj[0]) @@ -1157,7 +1185,7 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult": Returns: Result: Raw Redis aggregation results. """ - client = await self.get_client() + client = await self._get_client() try: return client.ft(self.schema.index.name).aggregate(*args, **kwargs) except Exception as e: @@ -1173,7 +1201,7 @@ async def search(self, *args, **kwargs) -> "Result": Returns: Result: Raw Redis search results. """ - client = await self.get_client() + client = await self._get_client() try: return await client.ft(self.schema.index.name).search(*args, **kwargs) except Exception as e: @@ -1266,7 +1294,7 @@ async def listall(self) -> List[str]: Returns: List[str]: The list of indices in the database. """ - client = await self.get_client() + client = await self._get_client() return convert_bytes(await client.execute_command("FT._LIST")) async def exists(self) -> bool: @@ -1287,7 +1315,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]: Returns: dict: A dictionary containing the information about the index. """ - client = await self.get_client() + client = await self._get_client() index_name = name or self.schema.index.name return await type(self)._info(index_name, client) From b5af151b59e627a14f072478e0bdfce9f7289121 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 13 Feb 2025 17:32:27 -0800 Subject: [PATCH 07/22] todo: fix the deprecation warnings instead of hiding --- tests/integration/test_async_search_index.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index c6b37d6b..5815fbec 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -8,19 +8,6 @@ fields = [{"name": "test", "type": "tag"}] -# Remove deprecation warnings after the next major release -pytestmark = [ - pytest.mark.filterwarnings( - "ignore:connect\\(\\) is deprecated; pass connection parameters in __init__:DeprecationWarning" - ), - pytest.mark.filterwarnings( - "ignore:Converting sync Redis client to async client is deprecated.*:DeprecationWarning" - ), - pytest.mark.filterwarnings( - "ignore:Function .* is deprecated and will be removed in the next major release.*:DeprecationWarning" - ), -] - @pytest.fixture def index_schema(): From ac42d74657859e30b92c118a95ed42800a26a156 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Fri, 14 Feb 2025 14:43:02 -0800 Subject: [PATCH 08/22] normalize sync and async index interfaces --- docs/user_guide/01_getting_started.ipynb | 151 +++++---- docs/user_guide/02_hybrid_queries.ipynb | 305 ++++++++++++++---- docs/user_guide/04_vectorizers.ipynb | 104 ++++-- docs/user_guide/05_hash_vs_json.ipynb | 10 +- redisvl/cli/index.py | 7 +- redisvl/extensions/llmcache/semantic.py | 15 +- redisvl/extensions/router/semantic.py | 12 +- .../session_manager/semantic_session.py | 13 +- .../session_manager/standard_session.py | 12 +- redisvl/index/index.py | 223 +++++++------ redisvl/redis/connection.py | 78 ++++- redisvl/utils/utils.py | 2 +- tests/conftest.py | 2 +- tests/integration/test_async_search_index.py | 75 ++--- tests/integration/test_connection.py | 67 ++-- tests/integration/test_flow.py | 4 +- tests/integration/test_flow_async.py | 4 +- tests/integration/test_query.py | 6 +- tests/integration/test_search_index.py | 69 ++-- tests/integration/test_search_results.py | 5 +- 20 files changed, 740 insertions(+), 424 deletions(-) diff --git a/docs/user_guide/01_getting_started.ipynb b/docs/user_guide/01_getting_started.ipynb index dfa2b581..2a9bae16 100644 --- a/docs/user_guide/01_getting_started.ipynb +++ b/docs/user_guide/01_getting_started.ipynb @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -195,7 +195,7 @@ "Now we also need to facilitate a Redis connection. There are a few ways to do this:\n", "\n", "- Create & manage your own client connection (recommended)\n", - "- Provide a simple Redis URL and let RedisVL connect on your behalf" + "- Provide a Redis URL and let RedisVL connect on your behalf (by default, it will connect to \"redis://localhost:6379\")" ] }, { @@ -209,16 +209,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from redis import Redis\n", "\n", "client = Redis.from_url(\"redis://localhost:6379\")\n", + "index = SearchIndex.from_dict(schema, redis_client=client)\n", "\n", - "index.set_client(client)\n", - "# optionally provide an async Redis client object to enable async index operations" + "# alternatively, provide an async Redis client object to enable async index operations\n", + "# from redis.asyncio import Redis\n", + "# from redisvl.index import AsyncSearchIndex\n", + "# client = Redis.from_url(\"redis://localhost:6379\")\n", + "# index = AsyncSearchIndex.from_dict(schema, redis_client=client)\n" ] }, { @@ -232,23 +236,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "index.connect(\"redis://localhost:6379\")\n", - "# optionally use an async client by passing use_async=True" + "index = SearchIndex.from_dict(schema, redis_url=\"redis://localhost:6379\")\n", + "\n", + "# If you don't specify a client or Redis URL, the index will attempt to\n", + "# connect to Redis at the default address (\"redis://localhost:6379\")." ] }, { @@ -262,7 +257,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -286,15 +281,15 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m11:53:23\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m11:53:23\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_simple\n" + "\u001b[32m11:28:30\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m11:28:30\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_simple\n" ] } ], @@ -304,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -320,15 +315,15 @@ "│ user_simple │ HASH │ ['user_simple_docs'] │ [] │ 0 │\n", "╰──────────────┴────────────────┴──────────────────────┴─────────────────┴────────────╯\n", "Index Fields:\n", - "╭────────────────┬────────────────┬─────────┬────────────────┬────────────────╮\n", - "│ Name │ Attribute │ Type │ Field Option │ Option Value │\n", - "├────────────────┼────────────────┼─────────┼────────────────┼────────────────┤\n", - "│ user │ user │ TAG │ SEPARATOR │ , │\n", - "│ credit_score │ credit_score │ TAG │ SEPARATOR │ , │\n", - "│ job │ job │ TEXT │ WEIGHT │ 1 │\n", - "│ age │ age │ NUMERIC │ │ │\n", - "│ user_embedding │ user_embedding │ VECTOR │ │ │\n", - "╰────────────────┴────────────────┴─────────┴────────────────┴────────────────╯\n" + "╭────────────────┬────────────────┬─────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", + "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", + "├────────────────┼────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", + "│ user │ user │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ credit_score │ credit_score │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ job │ job │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ age │ age │ NUMERIC │ │ │ │ │ │ │ │ │\n", + "│ user_embedding │ user_embedding │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 3 │ distance_metric │ COSINE │\n", + "╰────────────────┴────────────────┴─────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" ] } ], @@ -347,14 +342,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['user_simple_docs:d424b73c516442f7919cc11ed3bb1882', 'user_simple_docs:6da16f88342048e79b3500bec5448805', 'user_simple_docs:ef5a590ef85e4d4888fd8ebe79ae1e8c']\n" + "['user_simple_docs:01JM2YY8AYEH44NDKDE657ARKK', 'user_simple_docs:01JM2YY8AY3GNT0YYFWBSAK4N7', 'user_simple_docs:01JM2YY8AY4FTMMTQANJR3M4RE']\n" ] } ], @@ -381,14 +376,14 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['user_simple_docs:9806a362604f4700b17513cc94fcf10d']\n" + "['user_simple_docs:01JM2YY8B4RVK4A3MC80PFP2D8']\n" ] } ], @@ -418,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -443,9 +438,16 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "*=>[KNN 3 @user_embedding $vector AS vector_distance] RETURN 6 user age job credit_score vector_distance vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 3\n" + ] + }, { "data": { "text/html": [ @@ -476,7 +478,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -485,13 +487,12 @@ "\n", "client = Redis.from_url(\"redis://localhost:6379\")\n", "\n", - "index = AsyncSearchIndex.from_dict(schema)\n", - "await index.set_client(client)" + "index = AsyncSearchIndex.from_dict(schema, redis_client=client)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -532,7 +533,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -557,14 +558,14 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "11:53:25 redisvl.index.index INFO Index already exists, overwriting.\n" + "11:28:32 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], @@ -575,13 +576,13 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceuseragejobcredit_score
0john1engineerhigh
0mary2doctorlow
0.0566299557686tyler9engineerhigh
" + "
vector_distanceuseragejobcredit_score
0mary2doctorlow
0john1engineerhigh
0.0566299557686tyler9engineerhigh
" ], "text/plain": [ "" @@ -607,7 +608,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -626,16 +627,16 @@ "│ percent_indexed │ 1 │\n", "│ hash_indexing_failures │ 0 │\n", "│ number_of_uses │ 2 │\n", - "│ bytes_per_record_avg │ 1 │\n", - "│ doc_table_size_mb │ 0.00044632 │\n", - "│ inverted_sz_mb │ 1.90735e-05 │\n", - "│ key_table_size_mb │ 0.000138283 │\n", + "│ bytes_per_record_avg │ 47.8 │\n", + "│ doc_table_size_mb │ 0.000423431 │\n", + "│ inverted_sz_mb │ 0.000911713 │\n", + "│ key_table_size_mb │ 0.000165939 │\n", "│ offset_bits_per_record_avg │ nan │\n", "│ offset_vectors_sz_mb │ 0 │\n", "│ offsets_per_term_avg │ 0 │\n", "│ records_per_doc_avg │ 5 │\n", "│ sortable_values_size_mb │ 0 │\n", - "│ total_indexing_time │ 1.796 │\n", + "│ total_indexing_time │ 0.239 │\n", "│ total_inverted_index_blocks │ 11 │\n", "│ vector_index_sz_mb │ 0.235603 │\n", "╰─────────────────────────────┴─────────────╯\n" @@ -666,9 +667,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# (optionally) clear all data from Redis associated with the index\n", "await index.clear()" @@ -676,9 +688,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# but the index is still in place\n", "await index.exists()" @@ -686,7 +709,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -697,7 +720,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "env", "language": "python", "name": "python3" }, @@ -711,7 +734,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.11" }, "orig_nbformat": 4 }, diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index 4a89ffb3..e2a9b0d5 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -83,10 +83,7 @@ "from redisvl.index import SearchIndex\n", "\n", "# construct a search index from the schema\n", - "index = SearchIndex.from_dict(schema)\n", - "\n", - "# connect to local redis instance\n", - "index.connect(\"redis://localhost:6379\")\n", + "index = SearchIndex.from_dict(schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "index.create(overwrite=True)" @@ -101,8 +98,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" + "\u001b[32m13:03:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m13:03:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" ] } ], @@ -142,9 +139,16 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 6, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@credit_score:{high}=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -180,6 +184,13 @@ "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(-@credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -206,6 +217,13 @@ "execution_count": 8, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@credit_score:{high|medium}=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -232,6 +250,13 @@ "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@credit_score:{high|medium}=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -269,6 +294,13 @@ "execution_count": 10, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "*=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -304,6 +336,13 @@ "execution_count": 11, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@age:[(15 +inf]=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -331,6 +370,13 @@ "execution_count": 12, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@age:[14 14]=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -357,6 +403,13 @@ "execution_count": 13, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(-@age:[14 14])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -392,6 +445,13 @@ "execution_count": 14, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@job:(\"doctor\")=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -420,6 +480,13 @@ "execution_count": 15, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(-@job:\"doctor\")=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -446,6 +513,13 @@ "execution_count": 16, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@job:(doct*)=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -472,6 +546,13 @@ "execution_count": 17, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@job:(%%engine%%)=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -498,6 +579,13 @@ "execution_count": 18, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@job:(engineer|doctor)=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -524,6 +612,13 @@ "execution_count": 19, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "*=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -554,21 +649,28 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 20, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(~(@job:engineer))=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/plain": [ - "[{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5',\n", - " 'score': 0.9090908893868948,\n", + "[{'id': 'user_queries_docs:01JM34BQDEBYP2ZQTRK6SEAHHT',\n", + " 'score': 1.8181817787737895,\n", " 'vector_distance': '0',\n", " 'user': 'john',\n", " 'credit_score': 'high',\n", " 'age': '18',\n", " 'job': 'engineer',\n", " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:69cb262c303a4147b213dfdec8bd4b01',\n", + " {'id': 'user_queries_docs:01JM34BQDEEV60RV4VXK0JBBBK',\n", " 'score': 0.0,\n", " 'vector_distance': '0',\n", " 'user': 'derrick',\n", @@ -576,15 +678,15 @@ " 'age': '14',\n", " 'job': 'doctor',\n", " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b',\n", - " 'score': 0.9090908893868948,\n", + " {'id': 'user_queries_docs:01JM34BQDE6QKS9N8G49JN4V71',\n", + " 'score': 1.8181817787737895,\n", " 'vector_distance': '0.109129190445',\n", " 'user': 'tyler',\n", " 'credit_score': 'high',\n", " 'age': '100',\n", " 'job': 'engineer',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e',\n", + " {'id': 'user_queries_docs:01JM34BQDEZ9AXAQRE6GYAQPSP',\n", " 'score': 0.0,\n", " 'vector_distance': '0.158808946609',\n", " 'user': 'tim',\n", @@ -592,7 +694,7 @@ " 'age': '12',\n", " 'job': 'dermatologist',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:d0bcf6842862410583901004b6b3aeba',\n", + " {'id': 'user_queries_docs:01JM34BQDEM8JT6XC3YJRN72R4',\n", " 'score': 0.0,\n", " 'vector_distance': '0.217882037163',\n", " 'user': 'taimur',\n", @@ -600,7 +702,7 @@ " 'age': '15',\n", " 'job': 'CEO',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c',\n", + " {'id': 'user_queries_docs:01JM34BQDEJK2B70X6B11JBQ82',\n", " 'score': 0.0,\n", " 'vector_distance': '0.266666650772',\n", " 'user': 'nancy',\n", @@ -608,7 +710,7 @@ " 'age': '94',\n", " 'job': 'doctor',\n", " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:93ee6c0e4ccb42f6b7af7858ea6a6408',\n", + " {'id': 'user_queries_docs:01JM34BQDEBCVNY5BZSVD16N9R',\n", " 'score': 0.0,\n", " 'vector_distance': '0.653301358223',\n", " 'user': 'joe',\n", @@ -618,7 +720,7 @@ " 'office_location': '-122.0839,37.3861'}]" ] }, - "execution_count": 32, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -641,9 +743,16 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 21, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@office_location:[-122.4194 37.7749 10 km]=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -669,9 +778,16 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 22, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@office_location:[-122.4194 37.7749 100 km]=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -695,9 +811,16 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 23, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(-@office_location:[-122.4194 37.7749 10 km])=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -732,9 +855,16 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 24, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -775,9 +905,16 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 25, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(@age:[-inf (18] | @age:[(93 +inf])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -819,7 +956,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -834,9 +971,16 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 27, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "((@age:[(18 +inf] @credit_score:{high}) @job:(engineer))=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -859,9 +1003,16 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 28, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(@age:[(18 +inf] @credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -884,9 +1035,16 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 29, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@age:[(18 +inf]=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -909,9 +1067,16 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 30, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "*=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -943,9 +1108,16 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 31, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@credit_score:{low} RETURN 5 user credit_score age job location DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -985,13 +1157,14 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "@credit_score:{low} NOCONTENT DIALECT 2 LIMIT 0 0\n", "2 records match the filter expression @credit_score:{low} for the given index.\n" ] } @@ -1019,9 +1192,16 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 33, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@user_embedding:[VECTOR_RANGE $distance_threshold $vector]=>{$yield_distance_as: vector_distance} RETURN 6 user credit_score age job location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -1060,9 +1240,16 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 34, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@user_embedding:[VECTOR_RANGE $distance_threshold $vector]=>{$yield_distance_as: vector_distance} RETURN 6 user credit_score age job location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -1091,9 +1278,16 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 35, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(@user_embedding:[VECTOR_RANGE $distance_threshold $vector]=>{$yield_distance_as: vector_distance} @job:(\"engineer\")) RETURN 6 user credit_score age job location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" + ] + }, { "data": { "text/html": [ @@ -1131,9 +1325,16 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 36, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5\n" + ] + }, { "data": { "text/html": [ @@ -1172,7 +1373,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -1181,7 +1382,7 @@ "'@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5'" ] }, - "execution_count": 49, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1193,7 +1394,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 38, "metadata": {}, "outputs": [ { @@ -1202,7 +1403,7 @@ "'@credit_score:{high}'" ] }, - "execution_count": 50, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -1215,7 +1416,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -1224,7 +1425,7 @@ "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])'" ] }, - "execution_count": 51, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -1249,17 +1450,18 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "\n", + "{'id': 'user_queries_docs:01JM34BQDEBYP2ZQTRK6SEAHHT', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JM34BQDEJK2B70X6B11JBQ82', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JM34BQDE6QKS9N8G49JN4V71', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JM34BQDEZ9AXAQRE6GYAQPSP', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], @@ -1271,7 +1473,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -1282,7 +1484,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('redisvl2')", + "display_name": "env", "language": "python", "name": "python3" }, @@ -1296,14 +1498,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.11.11" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316" - } - } + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/user_guide/04_vectorizers.ipynb b/docs/user_guide/04_vectorizers.ipynb index d5870b88..63f8d705 100644 --- a/docs/user_guide/04_vectorizers.ipynb +++ b/docs/user_guide/04_vectorizers.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -80,9 +80,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vector dimensions: 1536\n" + ] + }, + { + "data": { + "text/plain": [ + "[-0.0011391325388103724,\n", + " -0.003206387162208557,\n", + " 0.002380132209509611,\n", + " -0.004501554183661938,\n", + " -0.010328996926546097,\n", + " 0.012922565452754498,\n", + " -0.005491119809448719,\n", + " -0.0029864837415516376,\n", + " -0.007327961269766092,\n", + " -0.03365817293524742]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from redisvl.utils.vectorize import OpenAITextVectorizer\n", "\n", @@ -99,9 +126,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[-0.017466850578784943,\n", + " 1.8471690054866485e-05,\n", + " 0.00129731057677418,\n", + " -0.02555876597762108,\n", + " -0.019842341542243958,\n", + " 0.01603139191865921,\n", + " -0.0037347301840782166,\n", + " 0.0009670283179730177,\n", + " 0.006618348415941,\n", + " -0.02497442066669464]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Create many embeddings at once\n", "sentences = [\n", @@ -116,9 +163,17 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of Embeddings: 3\n" + ] + } + ], "source": [ "# openai also supports asyncronous requests, which we can use to speed up the vectorization process.\n", "embeddings = await oai.aembed_many(sentences)\n", @@ -138,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -151,9 +206,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ValueError", + "evalue": "AzureOpenAI API endpoint is required. Provide it in api_config or set the AZURE_OPENAI_ENDPOINT environment variable.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mredisvl\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvectorize\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AzureOpenAITextVectorizer\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# create a vectorizer\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m az_oai \u001b[38;5;241m=\u001b[39m \u001b[43mAzureOpenAITextVectorizer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdeployment_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Must be your CUSTOM deployment name\u001b[39;49;00m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mapi_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mapi_key\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mapi_version\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mapi_version\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mazure_endpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mazure_endpoint\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m test \u001b[38;5;241m=\u001b[39m az_oai\u001b[38;5;241m.\u001b[39membed(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis is a test sentence.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVector dimensions: \u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mlen\u001b[39m(test))\n", + "File \u001b[0;32m~/src/redis-vl-python/redisvl/utils/vectorize/text/azureopenai.py:78\u001b[0m, in \u001b[0;36mAzureOpenAITextVectorizer.__init__\u001b[0;34m(self, model, api_config, dtype)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 56\u001b[0m model: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext-embedding-ada-002\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 57\u001b[0m api_config: Optional[Dict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 58\u001b[0m dtype: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 59\u001b[0m ):\n\u001b[1;32m 60\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Initialize the AzureOpenAI vectorizer.\u001b[39;00m\n\u001b[1;32m 61\u001b[0m \n\u001b[1;32m 62\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124;03m ValueError: If an invalid dtype is provided.\u001b[39;00m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_initialize_clients\u001b[49m\u001b[43m(\u001b[49m\u001b[43mapi_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(model\u001b[38;5;241m=\u001b[39mmodel, dims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_set_model_dims(model), dtype\u001b[38;5;241m=\u001b[39mdtype)\n", + "File \u001b[0;32m~/src/redis-vl-python/redisvl/utils/vectorize/text/azureopenai.py:106\u001b[0m, in \u001b[0;36mAzureOpenAITextVectorizer._initialize_clients\u001b[0;34m(self, api_config)\u001b[0m\n\u001b[1;32m 99\u001b[0m azure_endpoint \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 100\u001b[0m api_config\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mazure_endpoint\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m api_config\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAZURE_OPENAI_ENDPOINT\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 103\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m azure_endpoint:\n\u001b[0;32m--> 106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAzureOpenAI API endpoint is required. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mProvide it in api_config or set the AZURE_OPENAI_ENDPOINT\u001b[39m\u001b[38;5;130;01m\\\u001b[39;00m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;124m environment variable.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 110\u001b[0m )\n\u001b[1;32m 112\u001b[0m api_version \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 113\u001b[0m api_config\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mapi_version\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m api_config\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOPENAI_API_VERSION\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 116\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m api_version:\n", + "\u001b[0;31mValueError\u001b[0m: AzureOpenAI API endpoint is required. Provide it in api_config or set the AZURE_OPENAI_ENDPOINT environment variable." + ] + } + ], "source": [ "from redisvl.utils.vectorize import AzureOpenAITextVectorizer\n", "\n", @@ -589,10 +658,7 @@ "from redisvl.index import SearchIndex\n", "\n", "# construct a search index from the schema\n", - "index = SearchIndex.from_yaml(\"./schema.yaml\")\n", - "\n", - "# connect to local redis instance\n", - "index.connect(\"redis://localhost:6379\")\n", + "index = SearchIndex.from_yaml(\"./schema.yaml\", redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "index.create(overwrite=True)" @@ -726,7 +792,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "env", "language": "python", "name": "python3" }, @@ -740,10 +806,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.11" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/docs/user_guide/05_hash_vs_json.ipynb b/docs/user_guide/05_hash_vs_json.ipynb index 071cff5c..2ab3b01d 100644 --- a/docs/user_guide/05_hash_vs_json.ipynb +++ b/docs/user_guide/05_hash_vs_json.ipynb @@ -139,10 +139,7 @@ "outputs": [], "source": [ "# construct a search index from the hash schema\n", - "hindex = SearchIndex.from_dict(hash_schema)\n", - "\n", - "# connect to local redis instance\n", - "hindex.connect(\"redis://localhost:6379\")\n", + "hindex = SearchIndex.from_dict(hash_schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "hindex.create(overwrite=True)" @@ -382,10 +379,7 @@ "outputs": [], "source": [ "# construct a search index from the json schema\n", - "jindex = SearchIndex.from_dict(json_schema)\n", - "\n", - "# connect to local redis instance\n", - "jindex.connect(\"redis://localhost:6379\")\n", + "jindex = SearchIndex.from_dict(json_schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "jindex.create(overwrite=True)" diff --git a/redisvl/cli/index.py b/redisvl/cli/index.py index c5b350b3..bbb0d3c7 100644 --- a/redisvl/cli/index.py +++ b/redisvl/cli/index.py @@ -59,9 +59,7 @@ def create(self, args: Namespace): """ if not args.schema: logger.error("Schema must be provided to create an index") - index = SearchIndex.from_yaml(args.schema) - redis_url = create_redis_url(args) - index.connect(redis_url) + index = SearchIndex.from_yaml(args.schema, redis_url=create_redis_url(args)) index.create() logger.info("Index created successfully") @@ -120,8 +118,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex: schema = IndexSchema.from_dict({"index": {"name": args.index}}) index = SearchIndex(schema=schema, redis_url=redis_url) elif args.schema: - index = SearchIndex.from_yaml(args.schema) - index.set_client(conn) + index = SearchIndex.from_yaml(args.schema, redis_client=conn) else: logger.error("Index name or schema must be provided") exit(0) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 0062ef75..176ad722 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -128,13 +128,12 @@ def __init__( name, prefix, vectorizer.dims, vectorizer.dtype ) schema = self._modify_schema(schema, filterable_fields) - self._index = SearchIndex(schema=schema) - - # Handle redis connection - if redis_client: - self._index.set_client(redis_client) - elif redis_url: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) # Check for existing cache index if not overwrite and self._index.exists(): @@ -180,7 +179,7 @@ async def _get_async_index(self) -> AsyncSearchIndex: schema=self._index.schema, redis_client=self.redis_kwargs["redis_client"], redis_url=self.redis_kwargs["redis_url"], - **self.redis_kwargs["connection_kwargs"], + connection_kwargs=self.redis_kwargs["connection_kwargs"], ) return self._aindex diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 8a349f46..5454e0ee 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -111,12 +111,12 @@ def _initialize_index( schema = SemanticRouterIndexSchema.from_params( self.name, self.vectorizer.dims, self.vectorizer.dtype ) - self._index = SearchIndex(schema=schema) - - if redis_client: - self._index.set_client(redis_client) - elif redis_url: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) # Check for existing router index existed = self._index.exists() diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 74d8f4ab..9262795e 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -97,13 +97,12 @@ def __init__( name, prefix, self._vectorizer.dims, vectorizer.dtype ) - self._index = SearchIndex(schema=schema) - - # handle redis connection - if redis_client: - self._index.set_client(redis_client) - elif redis_url: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) # Check for existing session index if not overwrite and self._index.exists(): diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 42b86628..4e46010c 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -61,12 +61,12 @@ def __init__( schema = StandardSessionIndexSchema.from_params(name, prefix) - self._index = SearchIndex(schema=schema) - - if redis_client: - self._index.set_client(redis_client) - else: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) self._index.create(overwrite=False) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 72d7511c..cffb57a2 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,5 +1,7 @@ import asyncio import json +from os import replace +import threading import warnings from functools import wraps from typing import ( @@ -15,7 +17,7 @@ Union, ) -from redisvl.utils.utils import deprecated_function +from redisvl.utils.utils import deprecated_argument, deprecated_function if TYPE_CHECKING: from redis.commands.search.aggregation import AggregateResult @@ -96,36 +98,6 @@ def _process(doc: "Document") -> Dict[str, Any]: return [_process(doc) for doc in results.docs] -def setup_redis(): - def decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - result = func(self, *args, **kwargs) - RedisConnectionFactory.validate_sync_redis( - self._redis_client, self._lib_name - ) - return result - - return wrapper - - return decorator - - -def setup_async_redis(): - def decorator(func): - @wraps(func) - async def wrapper(self, *args, **kwargs): - result = await func(self, *args, **kwargs) - await RedisConnectionFactory.validate_async_redis( - self._redis_client, self._lib_name - ) - return result - - return wrapper - - return decorator - - class BaseSearchIndex: """Base search engine class""" @@ -220,8 +192,7 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs): def disconnect(self): """Disconnect from the Redis database.""" - self._redis_client = None - return self + raise NotImplementedError("This method should be implemented by subclasses.") def key(self, id: str) -> str: """Construct a redis key as a combination of an index key prefix (optional) @@ -257,8 +228,7 @@ class SearchIndex(BaseSearchIndex): from redisvl.index import SearchIndex # initialize the index object with schema from file - index = SearchIndex.from_yaml("schemas/schema.yaml") - index.connect(redis_url="redis://localhost:6379") + index = SearchIndex.from_yaml("schemas/schema.yaml", redis_url="redis://localhost:6379") # create the index index.create(overwrite=True) @@ -271,12 +241,18 @@ class SearchIndex(BaseSearchIndex): """ + required_modules = [ + {"name": "search", "ver": 20810}, + {"name": "searchlight", "ver": 20810}, + ] + + @deprecated_argument("connection_args", "Use connection_kwargs instead.") def __init__( self, schema: IndexSchema, redis_client: Optional[redis.Redis] = None, redis_url: Optional[str] = None, - connection_args: Dict[str, Any] = {}, + connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): """Initialize the RedisVL search index with a schema, Redis client @@ -289,10 +265,12 @@ def __init__( instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. - connection_args (Dict[str, Any], optional): Redis client connection + connection_kwargs (Dict[str, Any], optional): Redis client connection args. """ - # final validation on schema object + if "connection_args" in kwargs: + connection_kwargs = kwargs.pop("connection_args") + if not isinstance(schema, IndexSchema): raise ValueError("Must provide a valid IndexSchema object") @@ -300,13 +278,17 @@ def __init__( self._lib_name: Optional[str] = kwargs.pop("lib_name", None) - # set up redis connection - self._redis_client: Optional[redis.Redis] = None + # Store connection parameters + self.__redis_client = redis_client + self._redis_url = redis_url + self._connection_kwargs = connection_kwargs or {} + self._lock = threading.Lock() - if redis_client is not None: - self.set_client(redis_client) - elif redis_url is not None: - self.connect(redis_url, **connection_args) + def disconnect(self): + """Disconnect from the Redis database.""" + if self.__redis_client: + self.__redis_client.close() + self.__redis_client = None @classmethod def from_existing( @@ -314,6 +296,7 @@ def from_existing( name: str, redis_client: Optional[redis.Redis] = None, redis_url: Optional[str] = None, + required_modules: Optional[List[Dict[str, Any]]] = None, **kwargs, ): """ @@ -325,31 +308,30 @@ def from_existing( instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. - """ - # Handle redis instance - if redis_url: - redis_client = RedisConnectionFactory.connect( - redis_url=redis_url, use_async=False, **kwargs - ) - if not redis_client: - raise ValueError( - "Must provide either a redis_url or redis_client to fetch Redis index info." - ) - - # Validate modules - installed_modules = RedisConnectionFactory.get_modules(redis_client) + required_modules (Optional[List[Dict[str, Any]]]): List of required + Redis modules with version requirements. + Raises: + ValueError: If redis_url or redis_client is not provided. + RedisModuleVersionError: If required Redis modules are not installed. + """ try: - required_modules = [ - {"name": "search", "ver": 20810}, - {"name": "searchlight", "ver": 20810}, - ] - validate_modules(installed_modules, required_modules) + if redis_url: + redis_client = RedisConnectionFactory.get_redis_connection( + redis_url=redis_url, required_modules=required_modules, **kwargs + ) + elif redis_client: + RedisConnectionFactory.validate_sync_redis( + redis_client, required_modules=required_modules + ) except RedisModuleVersionError as e: raise RedisModuleVersionError( f"Loading from existing index failed. {str(e)}" ) + if not redis_client: + raise ValueError("Must provide either a redis_url or redis_client") + # Fetch index info and convert to schema index_info = cls._info(name, redis_client) schema_dict = convert_index_info_to_schema(index_info) @@ -359,8 +341,26 @@ def from_existing( @property def client(self) -> Optional[redis.Redis]: """The underlying redis-py client object.""" - return self._redis_client + return self.__redis_client + + @property + def _redis_client(self) -> Optional[redis.Redis]: + """ + Get a Redis client instance. + + Lazily creates a Redis client instance if it doesn't exist. + """ + if self.__redis_client is None: + with self._lock: + if self.__redis_client is None: + self.__redis_client = RedisConnectionFactory.get_redis_connection( + url=self._redis_url, + **self._connection_kwargs, + ) + return self.__redis_client + + @deprecated_function("connect", "Pass connection parameters in __init__.") def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance using the provided `redis_url`, falling back to the `REDIS_URL` environment variable (if available). @@ -378,18 +378,18 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): server fails. ValueError: If the Redis URL is not provided nor accessible through the `REDIS_URL` environment variable. + ModuleNotFoundError: If required Redis modules are not installed. .. code-block:: python index.connect(redis_url="redis://localhost:6379") """ - client = RedisConnectionFactory.connect( - redis_url=redis_url, use_async=False, **kwargs + self.__redis_client = RedisConnectionFactory.get_redis_connection( + redis_url=redis_url, required_modules=self.required_modules, **kwargs ) - return self.set_client(client) - @setup_redis() + @deprecated_function("set_client", "Pass connection parameters in __init__.") def set_client(self, redis_client: redis.Redis, **kwargs): """Manually set the Redis client to use with the search index. @@ -414,10 +414,10 @@ def set_client(self, redis_client: redis.Redis, **kwargs): index.set_client(client) """ - if not isinstance(redis_client, redis.Redis): - raise TypeError("Invalid Redis client instance") - - self._redis_client = redis_client + RedisConnectionFactory.validate_sync_redis( + redis_client, required_modules=self.required_modules + ) + self.__redis_client = redis_client return self def create(self, overwrite: bool = False, drop: bool = False) -> None: @@ -527,7 +527,7 @@ def expire_keys( """Set the expiration time for a specific entry or entries in Redis. Args: - keys (Union[str, List[str]]): The document ID or IDs to set the expiration for. + keys (Union[str, List[str]]): The entry ID or IDs to set the expiration for. ttl (int): The time-to-live in seconds. """ if isinstance(keys, list): @@ -798,8 +798,10 @@ class AsyncSearchIndex(BaseSearchIndex): from redisvl.index import AsyncSearchIndex # initialize the index object with schema from file - index = AsyncSearchIndex.from_yaml("schemas/schema.yaml") - await index.connect(redis_url="redis://localhost:6379") + index = AsyncSearchIndex.from_yaml( + "schemas/schema.yaml", + redis_url="redis://localhost:6379" + ) # create the index await index.create(overwrite=True) @@ -812,13 +814,19 @@ class AsyncSearchIndex(BaseSearchIndex): """ + required_modules = [ + {"name": "search", "ver": 20810}, + {"name": "searchlight", "ver": 20810}, + ] + + @deprecated_argument("redis_kwargs", "Use connection_kwargs instead.") def __init__( self, schema: IndexSchema, *, redis_url: Optional[str] = None, redis_client: Optional[aredis.Redis] = None, - redis_kwargs: Optional[Dict[str, Any]] = None, + connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): """Initialize the RedisVL async search index with a schema. @@ -829,9 +837,12 @@ def __init__( connect to. redis_client (Optional[aredis.Redis], optional): An instantiated redis client. - redis_kwargs (Dict[str, Any], optional): Redis client connection + connection_kwargs (Dict[str, Any], optional): Redis client connection args. """ + if "redis_kwargs" in kwargs: + connection_kwargs = kwargs.pop("redis_kwargs") + # final validation on schema object if not isinstance(schema, IndexSchema): raise ValueError("Must provide a valid IndexSchema object") @@ -841,12 +852,9 @@ def __init__( self._lib_name: Optional[str] = kwargs.pop("lib_name", None) # Store connection parameters - if redis_client and redis_url: - raise ValueError("Cannot provide both redis_client and redis_url") - self._redis_client = redis_client self._redis_url = redis_url - self._redis_kwargs = redis_kwargs or {} + self._connection_kwargs = connection_kwargs or {} self._lock = asyncio.Lock() async def disconnect(self): @@ -873,33 +881,31 @@ async def from_existing( redis_url (Optional[str]): The URL of the Redis server to connect to. """ - if redis_client and redis_url: - raise ValueError("Cannot provide both redis_client and redis_url") - elif redis_url: - redis_client = RedisConnectionFactory.get_async_redis_connection( - url=redis_url, **kwargs - ) - elif redis_client: - pass - else: + if not redis_url and not redis_client: raise ValueError( "Must provide either a redis_url or redis_client to fetch Redis index info." ) - # Validate modules - installed_modules = await RedisConnectionFactory.get_modules_async(redis_client) - try: - required_modules = [ - {"name": "search", "ver": 20810}, - {"name": "searchlight", "ver": 20810}, - ] - validate_modules(installed_modules, required_modules) + if redis_url: + redis_client = await RedisConnectionFactory._get_aredis_connection( + url=redis_url, required_modules=cls.required_modules, **kwargs + ) + elif redis_client: + await RedisConnectionFactory.validate_async_redis( + redis_client, required_modules=cls.required_modules + ) except RedisModuleVersionError as e: raise RedisModuleVersionError( f"Loading from existing index failed. {str(e)}" ) from e + if redis_client is None: + raise ValueError( + "Failed to obtain a valid Redis client. " + "Please provide a valid redis_client or redis_url." + ) + # Fetch index info and convert to schema index_info = await cls._info(name, redis_client) schema_dict = convert_index_info_to_schema(index_info) @@ -918,10 +924,10 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs): "connect() is deprecated; pass connection parameters in __init__", DeprecationWarning, ) - client: redis.asyncio.Redis = RedisConnectionFactory.connect( - redis_url=redis_url, use_async=True, **kwargs - ) # type: ignore - return await self.set_client(client) + client = await RedisConnectionFactory._get_aredis_connection( + redis_url=redis_url, required_modules=self.required_modules, **kwargs + ) + await self.set_client(client) @deprecated_function("set_client", "Pass connection parameters in __init__.") async def set_client(self, redis_client: aredis.Redis): @@ -935,17 +941,20 @@ async def set_client(self, redis_client: aredis.Redis): self._redis_client = redis_client return self - async def _get_client(self): + async def _get_client(self) -> aredis.Redis: """Lazily instantiate and return the async Redis client.""" if self._redis_client is None: async with self._lock: # Double-check to protect against concurrent access if self._redis_client is None: - kwargs = self._redis_kwargs + kwargs = self._connection_kwargs if self._redis_url: - kwargs["redis_url"] = self._redis_url + kwargs["url"] = self._redis_url self._redis_client = ( - RedisConnectionFactory.get_async_redis_connection(**kwargs) + await RedisConnectionFactory._get_aredis_connection( + required_modules=self.required_modules, + **kwargs + ) ) await RedisConnectionFactory.validate_async_redis( self._redis_client, self._lib_name @@ -1072,7 +1081,7 @@ async def expire_keys( """Set the expiration time for a specific entry or entries in Redis. Args: - keys (Union[str, List[str]]): The document ID or IDs to set the expiration for. + keys (Union[str, List[str]]): The entry ID or IDs to set the expiration for. ttl (int): The time-to-live in seconds. """ client = await self._get_client() @@ -1203,7 +1212,7 @@ async def search(self, *args, **kwargs) -> "Result": """ client = await self._get_client() try: - return await client.ft(self.schema.index.name).search(*args, **kwargs) + return await client.ft(self.schema.index.name).search(*args, **kwargs) # type: ignore except Exception as e: raise RedisSearchError(f"Error while searching: {str(e)}") from e diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 414edef5..99a273c5 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,5 +1,6 @@ import os from typing import Any, Dict, List, Optional, Type, Union +from warnings import warn from redis import Redis from redis.asyncio import Connection as AsyncConnection @@ -190,6 +191,7 @@ class RedisConnectionFactory: """ @classmethod + @deprecated_function("connect", "Please use `get_redis_connection` or `get_async_redis_connection`.") def connect( cls, redis_url: Optional[str] = None, use_async: bool = False, **kwargs ) -> Union[Redis, AsyncRedis]: @@ -218,12 +220,18 @@ def connect( return connection_func(redis_url, **kwargs) # type: ignore @staticmethod - def get_redis_connection(url: Optional[str] = None, **kwargs) -> Redis: + def get_redis_connection( + url: Optional[str] = None, + required_modules: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> Redis: """Creates and returns a synchronous Redis client. Args: url (Optional[str]): The URL of the Redis server. If not provided, the environment variable REDIS_URL is used. + required_modules (Optional[List[Dict[str, Any]]]): List of required + Redis modules with version requirements. **kwargs: Additional keyword arguments to be passed to the Redis client constructor. @@ -233,14 +241,57 @@ def get_redis_connection(url: Optional[str] = None, **kwargs) -> Redis: Raises: ValueError: If url is not provided and REDIS_URL environment variable is not set. + RedisModuleVersionError: If required Redis modules are not installed. + """ + url = url or get_address_from_env() + client = Redis.from_url(url, **kwargs) + + RedisConnectionFactory.validate_sync_redis( + client, required_modules=required_modules + ) + + return client + + @staticmethod + async def _get_aredis_connection( + url: Optional[str] = None, + required_modules: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> AsyncRedis: + """Creates and returns an asynchronous Redis client. + + NOTE: This method is the future form of `get_async_redis_connection` but is + only used internally by the library now. + + Args: + url (Optional[str]): The URL of the Redis server. If not provided, + the environment variable REDIS_URL is used. + required_modules (Optional[List[Dict[str, Any]]]): List of required + Redis modules with version requirements. + **kwargs: Additional keyword arguments to be passed to the async + Redis client constructor. + + Returns: + AsyncRedis: An asynchronous Redis client instance. + + Raises: + ValueError: If url is not provided and REDIS_URL environment + variable is not set. + RedisModuleVersionError: If required Redis modules are not installed. """ - if url: - return Redis.from_url(url, **kwargs) - # fallback to env var REDIS_URL - return Redis.from_url(get_address_from_env(), **kwargs) + url = url or get_address_from_env() + client = AsyncRedis.from_url(url, **kwargs) + + await RedisConnectionFactory.validate_async_redis( + client, required_modules=required_modules + ) + return client @staticmethod - def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedis: + def get_async_redis_connection( + url: Optional[str] = None, + **kwargs, + ) -> AsyncRedis: """Creates and returns an asynchronous Redis client. Args: @@ -256,13 +307,15 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi ValueError: If url is not provided and REDIS_URL environment variable is not set. """ - if url: - return AsyncRedis.from_url(url, **kwargs) - # fallback to env var REDIS_URL - return AsyncRedis.from_url(get_address_from_env(), **kwargs) + warn( + "get_async_redis_connection will become async in the next major release.", + DeprecationWarning, + ) + url = url or get_address_from_env() + return AsyncRedis.from_url(url, **kwargs) @deprecated_function( - "sync_to_async_redis", "Please use an async Redis client instead." + "sync_to_async_redis", "Please use an async Redis client." ) @staticmethod def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: @@ -295,6 +348,9 @@ def validate_sync_redis( required_modules: Optional[List[Dict[str, Any]]] = None, ) -> None: """Validates the sync Redis client.""" + if not isinstance(redis_client, Redis): + raise TypeError("Invalid Redis client instance") + # Set client library name _lib_name = make_lib_name(lib_name) try: diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 252ba6bf..ece10cb9 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -10,7 +10,7 @@ def create_ulid() -> str: - """Generate a unique indentifier to group related Redis documents.""" + """Generate a unique identifier to group related Redis documents.""" return str(ULID()) diff --git a/tests/conftest.py b/tests/conftest.py index de41b4b2..1708b35e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,7 +54,7 @@ async def async_client(redis_url): """ An async Redis client that uses the dynamic `redis_url`. """ - async with await RedisConnectionFactory.get_async_redis_connection( + async with await RedisConnectionFactory._get_aredis_connection( redis_url ) as client: yield client diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 5815fbec..a8afc8fd 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,4 +1,6 @@ import pytest +import warnings +from redis.asyncio import Redis from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex @@ -15,8 +17,8 @@ def index_schema(): @pytest.fixture -def async_index(index_schema): - return AsyncSearchIndex(schema=index_schema) +def async_index(index_schema, async_client): + return AsyncSearchIndex(schema=index_schema, redis_client=async_client) @pytest.fixture @@ -33,7 +35,7 @@ def test_search_index_properties(index_schema, async_index): assert async_index.schema == index_schema # custom settings assert async_index.name == index_schema.index.name == "my_index" - assert async_index.client is None + assert async_index.client # default settings assert async_index.prefix == index_schema.index.prefix == "rvl" assert async_index.key_separator == index_schema.index.key_separator == ":" @@ -64,7 +66,6 @@ def test_search_index_from_dict(async_index_from_dict): @pytest.mark.asyncio async def test_search_index_from_existing(async_client, async_index): - await async_index.set_client(async_client) await async_index.create(overwrite=True) try: @@ -107,9 +108,7 @@ async def test_search_index_from_existing_complex(async_client): }, ], } - async_index = await AsyncSearchIndex.from_dict(schema).set_client( - redis_client=async_client - ) + async_index = AsyncSearchIndex.from_dict(schema, redis_client=async_client) await async_index.create(overwrite=True) try: @@ -132,36 +131,41 @@ def test_search_index_no_prefix(index_schema): @pytest.mark.asyncio async def test_search_index_redis_url(redis_url, index_schema): - async_index = await AsyncSearchIndex(schema=index_schema).connect( - redis_url=redis_url - ) + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + # Client is None until a command is run + assert async_index.client is None + + # Lazily create the client by running a command + await async_index.create(overwrite=True, drop=True) assert async_index.client await async_index.disconnect() - assert async_index.client == None + assert async_index.client is None @pytest.mark.asyncio async def test_search_index_client(async_client, index_schema): - async_index = await AsyncSearchIndex(schema=index_schema).set_client( - redis_client=async_client - ) + async_index = AsyncSearchIndex(schema=index_schema, redis_client=async_client) assert async_index.client == async_client @pytest.mark.asyncio async def test_search_index_set_client(async_client, client, async_index): - await async_index.set_client(async_client) - assert async_index.client == async_client - await async_index.set_client(client) + # Ignore deprecation warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + assert async_index.client == async_client + + # Tests deprecated sync -> async conversation behavior + await async_index.set_client(client) + assert isinstance(async_index.client, Redis) await async_index.disconnect() assert async_index.client is None @pytest.mark.asyncio -async def test_search_index_create(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_create(async_index): await async_index.create(overwrite=True, drop=True) assert await async_index.exists() assert async_index.name in convert_bytes( @@ -170,8 +174,7 @@ async def test_search_index_create(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_delete(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_delete(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) assert not await async_index.exists() @@ -181,8 +184,7 @@ async def test_search_index_delete(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_clear(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_clear(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] await async_index.load(data, id_field="id") @@ -193,8 +195,7 @@ async def test_search_index_clear(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_drop_key(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_drop_key(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] keys = await async_index.load(data, id_field="id") @@ -206,8 +207,7 @@ async def test_search_index_drop_key(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_drop_keys(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_drop_keys(async_index): await async_index.create(overwrite=True, drop=True) data = [ {"id": "1", "test": "foo"}, @@ -226,8 +226,7 @@ async def test_search_index_drop_keys(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_load_and_fetch(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_load_and_fetch(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] await async_index.load(data, id_field="id") @@ -245,8 +244,7 @@ async def test_search_index_load_and_fetch(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_load_preprocess(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_load_preprocess(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] @@ -270,15 +268,13 @@ async def bad_preprocess(record): @pytest.mark.asyncio -async def test_search_index_load_empty(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_load_empty(async_index): await async_index.create(overwrite=True, drop=True) await async_index.load([]) @pytest.mark.asyncio -async def test_no_id_field(async_client, async_index): - await async_index.set_client(async_client) +async def test_no_id_field(async_index): await async_index.create(overwrite=True, drop=True) bad_data = [{"wrong_key": "1", "value": "test"}] @@ -288,8 +284,7 @@ async def test_no_id_field(async_client, async_index): @pytest.mark.asyncio -async def test_check_index_exists_before_delete(async_client, async_index): - await async_index.set_client(async_client) +async def test_check_index_exists_before_delete(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) with pytest.raises(RedisSearchError): @@ -297,8 +292,7 @@ async def test_check_index_exists_before_delete(async_client, async_index): @pytest.mark.asyncio -async def test_check_index_exists_before_search(async_client, async_index): - await async_index.set_client(async_client) +async def test_check_index_exists_before_search(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) @@ -313,8 +307,7 @@ async def test_check_index_exists_before_search(async_client, async_index): @pytest.mark.asyncio -async def test_check_index_exists_before_info(async_client, async_index): - await async_index.set_client(async_client) +async def test_check_index_exists_before_info(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index aa1f4e2a..8e2d2ea8 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -10,7 +10,6 @@ RedisConnectionFactory, compare_versions, convert_index_info_to_schema, - get_address_from_env, unpack_redis_modules, validate_modules, ) @@ -19,6 +18,9 @@ EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})" +# Remove after we remove connect() method from RedisConnectionFactory +pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") + def test_unpack_redis_modules(): module_list = [ @@ -129,41 +131,38 @@ def test_validate_modules_not_exist(): ) -def test_sync_redis_connect(redis_url): - client = RedisConnectionFactory.connect(redis_url) - assert client is not None - assert isinstance(client, Redis) - # Perform a simple operation - assert client.ping() - - -@pytest.mark.asyncio -async def test_async_redis_connect(redis_url): - client = RedisConnectionFactory.connect(redis_url, use_async=True) - assert client is not None - assert isinstance(client, AsyncRedis) - # Perform a simple operation - assert await client.ping() - - -def test_missing_env_var(): - redis_url = os.getenv("REDIS_URL") - if redis_url: - del os.environ["REDIS_URL"] +class TestConnect: + def test_sync_redis_connect(self, redis_url): + client = RedisConnectionFactory.connect(redis_url) + assert client is not None + assert isinstance(client, Redis) + # Perform a simple operation + assert client.ping() + + @pytest.mark.asyncio + async def test_async_redis_connect(self, redis_url): + client = RedisConnectionFactory.connect(redis_url, use_async=True) + assert client is not None + assert isinstance(client, AsyncRedis) + # Perform a simple operation + assert await client.ping() + + def test_missing_env_var(self): + redis_url = os.getenv("REDIS_URL") + if redis_url: + del os.environ["REDIS_URL"] + with pytest.raises(ValueError): + RedisConnectionFactory.connect() + os.environ["REDIS_URL"] = redis_url + + def test_invalid_url_format(self): with pytest.raises(ValueError): - RedisConnectionFactory.connect() - os.environ["REDIS_URL"] = redis_url - - -def test_invalid_url_format(): - with pytest.raises(ValueError): - RedisConnectionFactory.connect(redis_url="invalid_url_format") - + RedisConnectionFactory.connect(redis_url="invalid_url_format") -def test_unknown_redis(): - bad_client = RedisConnectionFactory.connect(redis_url="redis://fake:1234") - with pytest.raises(ConnectionError): - bad_client.ping() + def test_unknown_redis(self): + with pytest.raises(ConnectionError): + bad_client = RedisConnectionFactory.connect(redis_url="redis://fake:1234") + bad_client.ping() def test_validate_redis(client): diff --git a/tests/integration/test_flow.py b/tests/integration/test_flow.py index b448a636..7542528a 100644 --- a/tests/integration/test_flow.py +++ b/tests/integration/test_flow.py @@ -43,9 +43,7 @@ @pytest.mark.parametrize("schema", [hash_schema, json_schema]) def test_simple(client, schema, sample_data): - index = SearchIndex.from_dict(schema) - # assign client (only for testing) - index.set_client(client) + index = SearchIndex.from_dict(schema, redis_client=client) # create the index index.create(overwrite=True, drop=True) diff --git a/tests/integration/test_flow_async.py b/tests/integration/test_flow_async.py index fbfa7d22..a368f677 100644 --- a/tests/integration/test_flow_async.py +++ b/tests/integration/test_flow_async.py @@ -47,9 +47,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("schema", [hash_schema, json_schema]) async def test_simple(async_client, schema, sample_data): - index = AsyncSearchIndex.from_dict(schema) - # assign client (only for testing) - await index.set_client(async_client) + index = AsyncSearchIndex.from_dict(schema, redis_client=async_client) # create the index await index.create(overwrite=True, drop=True) diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 2c6dc376..271d36da 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -92,12 +92,10 @@ def index(sample_data, redis_url): }, }, ], - } + }, + redis_url=redis_url, ) - # connect to local redis instance - index.connect(redis_url) - # create the index (no data yet) index.create(overwrite=True) diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index 4f4392d3..4d09c846 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -1,9 +1,9 @@ +import warnings import pytest from redisvl.exceptions import RedisSearchError from redisvl.index import SearchIndex from redisvl.query import VectorQuery -from redisvl.redis.connection import RedisConnectionFactory, validate_modules from redisvl.redis.utils import convert_bytes from redisvl.schema import IndexSchema, StorageType @@ -27,8 +27,8 @@ def index_schema(): @pytest.fixture -def index(index_schema): - return SearchIndex(schema=index_schema) +def index(index_schema, client): + return SearchIndex(schema=index_schema, redis_client=client) @pytest.fixture @@ -45,7 +45,7 @@ def test_search_index_properties(index_schema, index): assert index.schema == index_schema # custom settings assert index.name == index_schema.index.name == "my_index" - assert index.client == None + # default settings assert index.prefix == index_schema.index.prefix == "rvl" assert index.key_separator == index_schema.index.key_separator == ":" @@ -73,7 +73,6 @@ def test_search_index_from_dict(index_from_dict): def test_search_index_from_existing(client, index): - index.set_client(client) index.create(overwrite=True) try: @@ -134,10 +133,14 @@ def test_search_index_no_prefix(index_schema): def test_search_index_redis_url(redis_url, index_schema): index = SearchIndex(schema=index_schema, redis_url=redis_url) + # Client is not set until a command runs + assert index.client is None + + index.create(overwrite=True) assert index.client index.disconnect() - assert index.client == None + assert index.client is None def test_search_index_client(client, index_schema): @@ -146,33 +149,31 @@ def test_search_index_client(client, index_schema): def test_search_index_set_client(async_client, client, index): - index.set_client(client) - assert index.client == client - # should not be able to set the sync client here - with pytest.raises(TypeError): - index.set_client(async_client) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + assert index.client == client + # should not be able to set an async client here + with pytest.raises(TypeError): + index.set_client(async_client) - index.disconnect() - assert index.client == None + index.disconnect() + assert index.client is None -def test_search_index_create(client, index): - index.set_client(client) +def test_search_index_create(index): index.create(overwrite=True, drop=True) assert index.exists() assert index.name in convert_bytes(index.client.execute_command("FT._LIST")) -def test_search_index_delete(client, index): - index.set_client(client) +def test_search_index_delete(index): index.create(overwrite=True, drop=True) index.delete(drop=True) assert not index.exists() assert index.name not in convert_bytes(index.client.execute_command("FT._LIST")) -def test_search_index_clear(client, index): - index.set_client(client) +def test_search_index_clear(index): index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] index.load(data, id_field="id") @@ -182,8 +183,7 @@ def test_search_index_clear(client, index): assert index.exists() -def test_search_index_drop_key(client, index): - index.set_client(client) +def test_search_index_drop_key(index): index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] keys = index.load(data, id_field="id") @@ -195,8 +195,7 @@ def test_search_index_drop_key(client, index): assert index.fetch(keys[1]) is not None # still have all other entries -def test_search_index_drop_keys(client, index): - index.set_client(client) +def test_search_index_drop_keys(index): index.create(overwrite=True, drop=True) data = [ {"id": "1", "test": "foo"}, @@ -215,22 +214,20 @@ def test_search_index_drop_keys(client, index): assert index.exists() -def test_search_index_load_and_fetch(client, index): - index.set_client(client) +def test_search_index_load_and_fetch(index): index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] index.load(data, id_field="id") res = index.fetch("1") - assert res["test"] == convert_bytes(client.hget("rvl:1", "test")) == "foo" + assert res["test"] == convert_bytes(index.client.hget("rvl:1", "test")) == "foo" index.delete(drop=True) assert not index.exists() assert not index.fetch("1") -def test_search_index_load_preprocess(client, index): - index.set_client(client) +def test_search_index_load_preprocess(index): index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] @@ -240,7 +237,7 @@ def preprocess(record): index.load(data, id_field="id", preprocess=preprocess) res = index.fetch("1") - assert res["test"] == convert_bytes(client.hget("rvl:1", "test")) == "bar" + assert res["test"] == convert_bytes(index.client.hget("rvl:1", "test")) == "bar" def bad_preprocess(record): return 1 @@ -249,8 +246,7 @@ def bad_preprocess(record): index.load(data, id_field="id", preprocess=bad_preprocess) -def test_no_id_field(client, index): - index.set_client(client) +def test_no_id_field(index): index.create(overwrite=True, drop=True) bad_data = [{"wrong_key": "1", "value": "test"}] @@ -259,16 +255,14 @@ def test_no_id_field(client, index): index.load(bad_data, id_field="key") -def test_check_index_exists_before_delete(client, index): - index.set_client(client) +def test_check_index_exists_before_delete(index): index.create(overwrite=True, drop=True) index.delete(drop=True) with pytest.raises(RedisSearchError): index.delete() -def test_check_index_exists_before_search(client, index): - index.set_client(client) +def test_check_index_exists_before_search(index): index.create(overwrite=True, drop=True) index.delete(drop=True) @@ -282,8 +276,7 @@ def test_check_index_exists_before_search(client, index): index.search(query.query, query_params=query.params) -def test_check_index_exists_before_info(client, index): - index.set_client(client) +def test_check_index_exists_before_info(index): index.create(overwrite=True, drop=True) index.delete(drop=True) @@ -293,4 +286,4 @@ def test_check_index_exists_before_info(client, index): def test_index_needs_valid_schema(): with pytest.raises(ValueError, match=r"Must provide a valid IndexSchema object"): - index = SearchIndex(schema="Not A Valid Schema") + SearchIndex(schema="Not A Valid Schema") # type: ignore diff --git a/tests/integration/test_search_results.py b/tests/integration/test_search_results.py index 83efad53..c451f039 100644 --- a/tests/integration/test_search_results.py +++ b/tests/integration/test_search_results.py @@ -44,10 +44,7 @@ def index(sample_data, redis_url): } # construct a search index from the schema - index = SearchIndex.from_dict(json_schema) - - # connect to local redis instance - index.connect(redis_url=redis_url) + index = SearchIndex.from_dict(json_schema, redis_url=redis_url) # create the index (no data yet) index.create(overwrite=True) From 9114941b1995fd164b2a0c85a8c9121fe22c53b5 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Fri, 14 Feb 2025 15:11:21 -0800 Subject: [PATCH 09/22] retain support for sync -> async conversion in SemanticCache --- redisvl/extensions/llmcache/semantic.py | 10 +++++++--- redisvl/redis/connection.py | 5 ++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 176ad722..245bfd85 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional from redis import Redis - +from redis.asyncio import Redis as AsyncRedis from redisvl.extensions.constants import ( CACHE_VECTOR_FIELD_NAME, ENTRY_ID_FIELD_NAME, @@ -22,6 +22,7 @@ from redisvl.index import AsyncSearchIndex, SearchIndex from redisvl.query import RangeQuery from redisvl.query.filter import FilterExpression +from redisvl.redis.connection import RedisConnectionFactory from redisvl.utils.utils import ( current_timestamp, deprecated_argument, @@ -175,11 +176,14 @@ async def _get_async_index(self) -> AsyncSearchIndex: """Lazily construct the async search index class.""" # Construct async index if necessary if not self._aindex: + client = self.redis_kwargs.get("redis_client") + if client and isinstance(client, Redis): + client = RedisConnectionFactory.sync_to_async_redis(client) self._aindex = AsyncSearchIndex( schema=self._index.schema, - redis_client=self.redis_kwargs["redis_client"], + redis_client=client, redis_url=self.redis_kwargs["redis_url"], - connection_kwargs=self.redis_kwargs["connection_kwargs"], + **self.redis_kwargs["connection_kwargs"], ) return self._aindex diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 99a273c5..8924e58f 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -314,11 +314,10 @@ def get_async_redis_connection( url = url or get_address_from_env() return AsyncRedis.from_url(url, **kwargs) - @deprecated_function( - "sync_to_async_redis", "Please use an async Redis client." - ) + @staticmethod def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: + """Convert a synchronous Redis client to an asynchronous one.""" # pick the right connection class connection_class: Type[AbstractConnection] = ( AsyncSSLConnection From 3eb1d435a47c129e5f7d28dc218568d887495176 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Fri, 14 Feb 2025 17:34:38 -0800 Subject: [PATCH 10/22] add finalizers, explicit disconnect methods, ctx mgrs --- redisvl/extensions/llmcache/semantic.py | 42 +++++++++++++++++++- redisvl/index/index.py | 41 ++++++++++++++++--- tests/integration/test_async_search_index.py | 30 ++++++++++++++ tests/integration/test_llmcache.py | 33 +++++++++++++++ tests/integration/test_search_index.py | 21 ++++++++++ 5 files changed, 161 insertions(+), 6 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 422a2e0a..5ea389c1 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -1,8 +1,8 @@ import asyncio from typing import Any, Dict, List, Optional +import weakref from redis import Redis -from redis.asyncio import Redis as AsyncRedis from redisvl.extensions.constants import ( CACHE_VECTOR_FIELD_NAME, ENTRY_ID_FIELD_NAME, @@ -23,6 +23,7 @@ from redisvl.query import RangeQuery from redisvl.query.filter import FilterExpression from redisvl.redis.connection import RedisConnectionFactory +from redisvl.utils.log import get_logger from redisvl.utils.utils import ( current_timestamp, deprecated_argument, @@ -32,6 +33,9 @@ from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +logger = get_logger("[RedisVL]") + + class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" @@ -149,6 +153,8 @@ def __init__( # Create the search index in Redis self._index.create(overwrite=overwrite, drop=False) + + weakref.finalize(self, self._finalize_async) def _modify_schema( self, @@ -702,3 +708,37 @@ async def aupdate(self, key: str, **kwargs) -> None: await aindex.load(data=[kwargs], keys=[key]) await self._async_refresh_ttl(key) + + def _finalize_async(self): + if self._index: + self._index.disconnect() + if self._aindex: + try: + loop = None + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._aindex.disconnect()) + except Exception as e: + logger.info(f"Error disconnecting from index: {e}") + + def disconnect(self): + if self._index: + self._index.disconnect() + if self._aindex: + asyncio.run(self._aindex.disconnect()) + + async def adisconnect(self): + if self._index: + self._index.disconnect() + if self._aindex: + await self._aindex.disconnect() + self._aindex = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.adisconnect() diff --git a/redisvl/index/index.py b/redisvl/index/index.py index cffb57a2..3d2fed58 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -16,6 +16,7 @@ Optional, Union, ) +import weakref from redisvl.utils.utils import deprecated_argument, deprecated_function @@ -784,6 +785,12 @@ def info(self, name: Optional[str] = None) -> Dict[str, Any]: index_name = name or self.schema.index.name return self._info(index_name, self._redis_client) # type: ignore + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disconnect() + class AsyncSearchIndex(BaseSearchIndex): """A search index class for interacting with Redis as a vector database in @@ -857,11 +864,8 @@ def __init__( self._connection_kwargs = connection_kwargs or {} self._lock = asyncio.Lock() - async def disconnect(self): - """Asynchronously disconnect and cleanup the underlying async redis connection.""" - if self._redis_client is not None: - await self._redis_client.aclose() # type: ignore - self._redis_client = None + # Close connections when the object is garbage collected + weakref.finalize(self, self._finalize_disconnect) @classmethod async def from_existing( @@ -1336,9 +1340,36 @@ async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: raise RedisSearchError( f"Error while fetching {name} index info: {str(e)}" ) from e + + async def disconnect(self): + """Asynchronously disconnect and cleanup the underlying async redis connection.""" + if self._redis_client is not None: + await self._redis_client.aclose() # type: ignore + self._redis_client = None + + def disconnect_sync(self): + """Synchronously disconnect and cleanup the underlying async redis connection.""" + if self._redis_client is None: + return + loop = asyncio.get_running_loop() + if loop is None or not loop.is_running(): + asyncio.run(self._redis_client.aclose()) # type: ignore + else: + loop.create_task(self.disconnect()) + self._redis_client = None async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.disconnect() + + def _finalize_disconnect(self): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop is None or not loop.is_running(): + asyncio.run(self.disconnect()) + else: + loop.create_task(self.disconnect()) diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index a8afc8fd..0d7d8244 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -313,3 +313,33 @@ async def test_check_index_exists_before_info(async_index): with pytest.raises(RedisSearchError): await async_index.info() + + +@pytest.mark.asyncio +async def test_search_index_async_context_manager(async_index): + async with async_index: + await async_index.create(overwrite=True, drop=True) + assert async_index._redis_client + assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_search_index_context_manager_with_exception(async_index): + async with async_index: + await async_index.create(overwrite=True, drop=True) + raise ValueError("test") + assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_search_index_disconnect(async_index): + await async_index.create(overwrite=True, drop=True) + await async_index.disconnect() + assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_search_engine_disconnect_sync(async_index): + await async_index.create(overwrite=True, drop=True) + async_index.disconnect_sync() + assert async_index._redis_client is None diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 5b32b918..5babebe2 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -941,3 +941,36 @@ def test_deprecated_dtype_argument(redis_url): redis_url=redis_url, overwrite=True, ) + + + +@pytest.mark.asyncio +async def test_cache_async_context_manager(redis_url): + async with SemanticCache(name="test_cache", redis_url=redis_url) as cache: + await cache.astore("test prompt", "test response") + assert cache._aindex + assert cache._aindex is None + + +@pytest.mark.asyncio +async def test_cache_async_context_manager_with_exception(redis_url): + async with SemanticCache(name="test_cache", redis_url=redis_url) as cache: + await cache.astore("test prompt", "test response") + raise ValueError("test") + assert cache._aindex is None + + +@pytest.mark.asyncio +async def test_cache_async_disconnect(redis_url): + cache = SemanticCache(name="test_cache", redis_url=redis_url) + await cache.astore("test prompt", "test response") + await cache.adisconnect() + assert cache._aindex is None + + +def test_cache_disconnect(redis_url): + cache = SemanticCache(name="test_cache", redis_url=redis_url) + cache.store("test prompt", "test response") + cache.disconnect() + # We keep this index object around because it isn't lazily created + assert cache._index.client is None diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index 4d09c846..2b803d8a 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -287,3 +287,24 @@ def test_check_index_exists_before_info(index): def test_index_needs_valid_schema(): with pytest.raises(ValueError, match=r"Must provide a valid IndexSchema object"): SearchIndex(schema="Not A Valid Schema") # type: ignore + + +def test_search_index_context_manager(index): + with index: + index.create(overwrite=True, drop=True) + assert index.client + assert index.client is None + + +def test_search_index_context_manager_with_exception(index): + with pytest.raises(ValueError): + with index: + index.create(overwrite=True, drop=True) + raise ValueError("test") + assert index.client is None + + +def test_search_index_disconnect(index): + index.create(overwrite=True, drop=True) + index.disconnect() + assert index.client is None From e9470d34a2ccc106078d092bbc742e7be0ca162e Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Feb 2025 16:43:00 -0800 Subject: [PATCH 11/22] Rejigger the finalizers and more --- redisvl/extensions/llmcache/semantic.py | 42 ++--- redisvl/index/index.py | 69 ++++--- redisvl/utils/utils.py | 38 +++- tests/integration/test_async_search_index.py | 76 ++++++-- tests/integration/test_llmcache.py | 186 +++++++++++-------- tests/integration/test_search_index.py | 50 ++++- 6 files changed, 300 insertions(+), 161 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 5ea389c1..56df24bc 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -25,6 +25,7 @@ from redisvl.redis.connection import RedisConnectionFactory from redisvl.utils.log import get_logger from redisvl.utils.utils import ( + sync_wrapper, current_timestamp, deprecated_argument, serialize, @@ -133,6 +134,12 @@ def __init__( name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore ) schema = self._modify_schema(schema, filterable_fields) + + if redis_client: + self._owns_redis_client = False + else: + self._owns_redis_client = True + self._index = SearchIndex( schema=schema, redis_client=redis_client, @@ -153,8 +160,6 @@ def __init__( # Create the search index in Redis self._index.create(overwrite=overwrite, drop=False) - - weakref.finalize(self, self._finalize_async) def _modify_schema( self, @@ -317,7 +322,9 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it doesn't match the search index vector dimensions.""" - schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims # type: ignore + schema_vector_dims = self._index.schema.fields[ + CACHE_VECTOR_FIELD_NAME + ].attrs.dims # type: ignore validate_vector_dims(len(vector), schema_vector_dims) def check( @@ -392,7 +399,8 @@ def check( # Search the cache! cache_search_results = self._index.query(query) redis_keys, cache_hits = self._process_cache_results( - cache_search_results, return_fields # type: ignore + cache_search_results, + return_fields, # type: ignore ) # Extend TTL on keys for key in redis_keys: @@ -473,7 +481,8 @@ async def acheck( # Search the cache! cache_search_results = await aindex.query(query) redis_keys, cache_hits = self._process_cache_results( - cache_search_results, return_fields # type: ignore + cache_search_results, + return_fields, # type: ignore ) # Extend TTL on keys await asyncio.gather(*[self._async_refresh_ttl(key) for key in redis_keys]) @@ -646,7 +655,6 @@ def update(self, key: str, **kwargs) -> None: """ if kwargs: for k, v in kwargs.items(): - # Make sure the item is in the index schema if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): raise ValueError(f"{k} is not a valid field within the cache entry") @@ -689,7 +697,6 @@ async def aupdate(self, key: str, **kwargs) -> None: if kwargs: for k, v in kwargs.items(): - # Make sure the item is in the index schema if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): raise ValueError(f"{k} is not a valid field within the cache entry") @@ -708,29 +715,18 @@ async def aupdate(self, key: str, **kwargs) -> None: await aindex.load(data=[kwargs], keys=[key]) await self._async_refresh_ttl(key) - - def _finalize_async(self): - if self._index: - self._index.disconnect() - if self._aindex: - try: - loop = None - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._aindex.disconnect()) - except Exception as e: - logger.info(f"Error disconnecting from index: {e}") def disconnect(self): + if self._owns_redis_client is False: + return if self._index: self._index.disconnect() if self._aindex: - asyncio.run(self._aindex.disconnect()) + self._aindex.disconnect_sync() async def adisconnect(self): + if not self._owns_redis_client: + return if self._index: self._index.disconnect() if self._aindex: diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 3d2fed58..2e8ba1d0 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,9 +1,12 @@ import asyncio import json +import logging from os import replace +from re import S import threading import warnings from functools import wraps + from typing import ( TYPE_CHECKING, Any, @@ -18,14 +21,13 @@ ) import weakref -from redisvl.utils.utils import deprecated_argument, deprecated_function +from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper if TYPE_CHECKING: from redis.commands.search.aggregation import AggregateResult from redis.commands.search.document import Document from redis.commands.search.result import Result from redisvl.query.query import BaseQuery - import redis.asyncio import redis import redis.asyncio as aredis @@ -38,7 +40,6 @@ from redisvl.redis.connection import ( RedisConnectionFactory, convert_index_info_to_schema, - validate_modules, ) from redisvl.redis.utils import convert_bytes from redisvl.schema import IndexSchema, StorageType @@ -279,14 +280,21 @@ def __init__( self._lib_name: Optional[str] = kwargs.pop("lib_name", None) - # Store connection parameters + # Store connection parameters self.__redis_client = redis_client self._redis_url = redis_url self._connection_kwargs = connection_kwargs or {} - self._lock = threading.Lock() + self._lock = threading.Lock() + + self._owns_redis_client = redis_client is None + if self._owns_redis_client: + weakref.finalize(self, self.disconnect) def disconnect(self): """Disconnect from the Redis database.""" + if self._owns_redis_client is False: + print("Index does not own client, not disconnecting") + return if self.__redis_client: self.__redis_client.close() self.__redis_client = None @@ -343,12 +351,12 @@ def from_existing( def client(self) -> Optional[redis.Redis]: """The underlying redis-py client object.""" return self.__redis_client - + @property def _redis_client(self) -> Optional[redis.Redis]: """ Get a Redis client instance. - + Lazily creates a Redis client instance if it doesn't exist. """ if self.__redis_client is None: @@ -359,7 +367,6 @@ def _redis_client(self) -> Optional[redis.Redis]: **self._connection_kwargs, ) return self.__redis_client - @deprecated_function("connect", "Pass connection parameters in __init__.") def connect(self, redis_url: Optional[str] = None, **kwargs): @@ -371,8 +378,7 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): Args: redis_url (Optional[str], optional): The URL of the Redis server to - connect to. If not provided, the method defaults to using the - `REDIS_URL` environment variable. + connect to. Raises: redis.exceptions.ConnectionError: If the connection to the Redis @@ -842,9 +848,9 @@ def __init__( schema (IndexSchema): Index schema object. redis_url (Optional[str], optional): The URL of the Redis server to connect to. - redis_client (Optional[aredis.Redis], optional): An + redis_client (Optional[aredis.Redis]): An instantiated redis client. - connection_kwargs (Dict[str, Any], optional): Redis client connection + connection_kwargs (Optional[Dict[str, Any]]): Redis client connection args. """ if "redis_kwargs" in kwargs: @@ -864,8 +870,9 @@ def __init__( self._connection_kwargs = connection_kwargs or {} self._lock = asyncio.Lock() - # Close connections when the object is garbage collected - weakref.finalize(self, self._finalize_disconnect) + self._owns_redis_client = redis_client is None + if self._owns_redis_client: + weakref.finalize(self, sync_wrapper(self.disconnect)) @classmethod async def from_existing( @@ -934,7 +941,7 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs): await self.set_client(client) @deprecated_function("set_client", "Pass connection parameters in __init__.") - async def set_client(self, redis_client: aredis.Redis): + async def set_client(self, redis_client: Union[aredis.Redis, redis.Redis]): """ [DEPRECATED] Manually set the Redis client to use with the search index. This method is deprecated; please provide connection parameters in __init__. @@ -956,8 +963,7 @@ async def _get_client(self) -> aredis.Redis: kwargs["url"] = self._redis_url self._redis_client = ( await RedisConnectionFactory._get_aredis_connection( - required_modules=self.required_modules, - **kwargs + required_modules=self.required_modules, **kwargs ) ) await RedisConnectionFactory.validate_async_redis( @@ -965,7 +971,9 @@ async def _get_client(self) -> aredis.Redis: ) return self._redis_client - async def _validate_client(self, redis_client: aredis.Redis) -> aredis.Redis: + async def _validate_client( + self, redis_client: Union[aredis.Redis, redis.Redis] + ) -> aredis.Redis: if isinstance(redis_client, redis.Redis): warnings.warn( "Converting sync Redis client to async client is deprecated " @@ -1340,36 +1348,21 @@ async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: raise RedisSearchError( f"Error while fetching {name} index info: {str(e)}" ) from e - + async def disconnect(self): - """Asynchronously disconnect and cleanup the underlying async redis connection.""" + if self._owns_redis_client is False: + return if self._redis_client is not None: await self._redis_client.aclose() # type: ignore self._redis_client = None def disconnect_sync(self): - """Synchronously disconnect and cleanup the underlying async redis connection.""" - if self._redis_client is None: + if self._redis_client is None or self._owns_redis_client is False: return - loop = asyncio.get_running_loop() - if loop is None or not loop.is_running(): - asyncio.run(self._redis_client.aclose()) # type: ignore - else: - loop.create_task(self.disconnect()) - self._redis_client = None + sync_wrapper(self.disconnect)() async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.disconnect() - - def _finalize_disconnect(self): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - if loop is None or not loop.is_running(): - asyncio.run(self.disconnect()) - else: - loop.create_task(self.disconnect()) diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 11f2fb8f..578b30fa 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,19 +1,23 @@ +import asyncio import inspect import json +import logging import warnings from contextlib import contextmanager from enum import Enum from functools import wraps from time import time -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Coroutine, Dict, Optional from warnings import warn from pydantic import BaseModel from ulid import ULID +from redisvl.utils.log import get_logger + def create_ulid() -> str: - """Generate a unique indentifier to group related Redis documents.""" + """Generate a unique identifier to group related Redis documents.""" return str(ULID()) @@ -159,3 +163,33 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def sync_wrapper(fn: Callable[[], Coroutine[Any, Any, Any]]) -> Callable[[], None]: + def wrapper(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + try: + if loop is None or not loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + task = loop.create_task(fn()) + loop.run_until_complete(task) + except RuntimeError: + # This could happen if an object stored an event loop and now + # that event loop is closed. There's nothing we can do other than + # advise the user to use explicit cleanup methods. + # + # Uses logging module instead of get_logger() to avoid I/O errors + # if the wrapped function is called as a finalizer. + logging.info( + f"Could not run the async function {fn.__name__} because the event loop is closed. " + "This usually means the object was not properly cleaned up. Please use explicit " + "cleanup methods (e.g., disconnect(), close()) or use the object as an async " + "context manager.", + ) + return + + return wrapper diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 0d7d8244..ea3eb452 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,6 +1,7 @@ import pytest import warnings from redis.asyncio import Redis +from redis import Redis as SyncRedis from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex @@ -150,13 +151,16 @@ async def test_search_index_client(async_client, index_schema): @pytest.mark.asyncio -async def test_search_index_set_client(async_client, client, async_index): +async def test_search_index_set_client(client, redis_url, index_schema): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) # Ignore deprecation warnings with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - assert async_index.client == async_client + await async_index.create(overwrite=True, drop=True) + assert isinstance(async_index.client, Redis) # Tests deprecated sync -> async conversation behavior + assert isinstance(client, SyncRedis) await async_index.set_client(client) assert isinstance(async_index.client, Redis) @@ -316,30 +320,78 @@ async def test_check_index_exists_before_info(async_index): @pytest.mark.asyncio -async def test_search_index_async_context_manager(async_index): +async def test_search_index_that_does_not_own_client_context_manager(async_index): async with async_index: await async_index.create(overwrite=True, drop=True) assert async_index._redis_client - assert async_index._redis_client is None + client = async_index._redis_client + assert async_index._redis_client == client @pytest.mark.asyncio -async def test_search_index_context_manager_with_exception(async_index): +async def test_search_index_that_does_not_own_client_context_manager_with_exception( + async_index, +): + try: + async with async_index: + await async_index.create(overwrite=True, drop=True) + client = async_index._redis_client + raise ValueError("test") + except ValueError: + pass + assert async_index._redis_client == client + + +@pytest.mark.asyncio +async def test_search_index_that_does_not_own_client_disconnect(async_index): + await async_index.create(overwrite=True, drop=True) + client = async_index._redis_client + await async_index.disconnect() + assert async_index._redis_client == client + + +@pytest.mark.asyncio +async def test_search_index_that_does_not_own_client_disconnect_sync(async_index): + await async_index.create(overwrite=True, drop=True) + client = async_index._redis_client + async_index.disconnect_sync() + assert async_index._redis_client == client + + +@pytest.mark.asyncio +async def test_search_index_that_owns_client_context_manager(index_schema, redis_url): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) async with async_index: await async_index.create(overwrite=True, drop=True) - raise ValueError("test") + assert async_index._redis_client + assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_search_index_that_owns_client_context_manager_with_exception( + index_schema, redis_url +): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + try: + async with async_index: + await async_index.create(overwrite=True, drop=True) + raise ValueError("test") + except ValueError: + pass assert async_index._redis_client is None - - + + @pytest.mark.asyncio -async def test_search_index_disconnect(async_index): +async def test_search_index_that_owns_client_disconnect(index_schema, redis_url): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) await async_index.create(overwrite=True, drop=True) await async_index.disconnect() assert async_index._redis_client is None - + @pytest.mark.asyncio -async def test_search_engine_disconnect_sync(async_index): +async def test_search_index_that_owns_client_disconnect_sync(index_schema, redis_url): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) await async_index.create(overwrite=True, drop=True) - async_index.disconnect_sync() + await async_index.disconnect() assert async_index._redis_client is None diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 5babebe2..b489caf9 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -108,15 +108,21 @@ def test_get_index(cache): @pytest.mark.asyncio async def test_get_async_index(cache): - aindex = await cache._get_async_index() - assert isinstance(aindex, AsyncSearchIndex) + async with cache: + aindex = await cache._get_async_index() + assert isinstance(aindex, AsyncSearchIndex) @pytest.mark.asyncio async def test_get_async_index_from_provided_client(cache_with_redis_client): - aindex = await cache_with_redis_client._get_async_index() - assert isinstance(aindex, AsyncSearchIndex) - assert aindex == cache_with_redis_client.aindex + async with cache_with_redis_client: + aindex = await cache_with_redis_client._get_async_index() + # Shouldn't have to do this because it already was done + await aindex.create(overwrite=True, drop=True) + assert await aindex.exists() + assert isinstance(aindex, AsyncSearchIndex) + assert aindex == cache_with_redis_client.aindex + assert await cache_with_redis_client.aindex.exists() def test_delete(cache_no_cleanup): @@ -126,8 +132,9 @@ def test_delete(cache_no_cleanup): @pytest.mark.asyncio async def test_async_delete(cache_no_cleanup): - await cache_no_cleanup.adelete() - assert not cache_no_cleanup.index.exists() + async with cache_no_cleanup: + await cache_no_cleanup.adelete() + assert not cache_no_cleanup.index.exists() def test_store_and_check(cache, vectorizer): @@ -150,8 +157,9 @@ async def test_async_store_and_check(cache, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache.astore(prompt, response, vector=vector) - check_result = await cache.acheck(vector=vector, distance_threshold=0.4) + async with cache: + await cache.astore(prompt, response, vector=vector) + check_result = await cache.acheck(vector=vector, distance_threshold=0.4) assert len(check_result) == 1 print(check_result, flush=True) @@ -202,36 +210,37 @@ async def test_async_return_fields(cache, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache.astore(prompt, response, vector=vector) + async with cache: + await cache.astore(prompt, response, vector=vector) - # check default return fields - check_result = await cache.acheck(vector=vector) - assert set(check_result[0].keys()) == { - "key", - "entry_id", - "prompt", - "response", - "vector_distance", - "inserted_at", - "updated_at", - } - - # check specific return fields - fields = [ - "key", - "entry_id", - "prompt", - "response", - "vector_distance", - ] - check_result = await cache.acheck(vector=vector, return_fields=fields) - assert set(check_result[0].keys()) == set(fields) - - # check only some return fields - fields = ["inserted_at", "updated_at"] - check_result = await cache.acheck(vector=vector, return_fields=fields) - fields.append("key") - assert set(check_result[0].keys()) == set(fields) + # check default return fields + check_result = await cache.acheck(vector=vector) + assert set(check_result[0].keys()) == { + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + "inserted_at", + "updated_at", + } + + # check specific return fields + fields = [ + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + ] + check_result = await cache.acheck(vector=vector, return_fields=fields) + assert set(check_result[0].keys()) == set(fields) + + # check only some return fields + fields = ["inserted_at", "updated_at"] + check_result = await cache.acheck(vector=vector, return_fields=fields) + fields.append("key") + assert set(check_result[0].keys()) == set(fields) # Test clearing the cache @@ -253,9 +262,10 @@ async def test_async_clear(cache, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache.astore(prompt, response, vector=vector) - await cache.aclear() - check_result = await cache.acheck(vector=vector) + async with cache: + await cache.astore(prompt, response, vector=vector) + await cache.aclear() + check_result = await cache.acheck(vector=vector) assert len(check_result) == 0 @@ -279,10 +289,11 @@ async def test_async_ttl_expiration(cache_with_ttl, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache_with_ttl.astore(prompt, response, vector=vector) - await asyncio.sleep(3) + async with cache_with_ttl: + await cache_with_ttl.astore(prompt, response, vector=vector) + await asyncio.sleep(3) - check_result = await cache_with_ttl.acheck(vector=vector) + check_result = await cache_with_ttl.acheck(vector=vector) assert len(check_result) == 0 @@ -305,10 +316,11 @@ async def test_async_custom_ttl(cache_with_ttl, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache_with_ttl.astore(prompt, response, vector=vector, ttl=5) - await asyncio.sleep(3) + async with cache_with_ttl: + await cache_with_ttl.astore(prompt, response, vector=vector, ttl=5) + await asyncio.sleep(3) + check_result = await cache_with_ttl.acheck(vector=vector) - check_result = await cache_with_ttl.acheck(vector=vector) assert len(check_result) != 0 assert cache_with_ttl.ttl == 2 @@ -333,11 +345,12 @@ async def test_async_ttl_refresh(cache_with_ttl, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache_with_ttl.astore(prompt, response, vector=vector) + async with cache_with_ttl: + await cache_with_ttl.astore(prompt, response, vector=vector) - for _ in range(3): - await asyncio.sleep(1) - check_result = await cache_with_ttl.acheck(vector=vector) + for _ in range(3): + await asyncio.sleep(1) + check_result = await cache_with_ttl.acheck(vector=vector) assert len(check_result) == 1 @@ -362,11 +375,13 @@ async def test_async_drop_document(cache, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache.astore(prompt, response, vector=vector) - check_result = await cache.acheck(vector=vector) + async with cache: + await cache.astore(prompt, response, vector=vector) + check_result = await cache.acheck(vector=vector) + + await cache.adrop(ids=[check_result[0]["entry_id"]]) + recheck_result = await cache.acheck(vector=vector) - await cache.adrop(ids=[check_result[0]["entry_id"]]) - recheck_result = await cache.acheck(vector=vector) assert len(recheck_result) == 0 @@ -411,12 +426,14 @@ async def test_async_drop_documents(cache, vectorizer): vector = vectorizer.embed(prompt) await cache.astore(prompt, response, vector=vector) - check_result = await cache.acheck(vector=vector, num_results=3) - print(check_result, flush=True) - ids = [r["entry_id"] for r in check_result[0:2]] # drop first 2 entries - await cache.adrop(ids=ids) + async with cache: + check_result = await cache.acheck(vector=vector, num_results=3) + print(check_result, flush=True) + ids = [r["entry_id"] for r in check_result[0:2]] # drop first 2 entries + await cache.adrop(ids=ids) + + recheck_result = await cache.acheck(vector=vector, num_results=3) - recheck_result = await cache.acheck(vector=vector, num_results=3) assert len(recheck_result) == 1 @@ -445,19 +462,22 @@ def test_updating_document(cache): async def test_async_updating_document(cache): prompt = "This is a test prompt." response = "This is a test response." - await cache.astore(prompt=prompt, response=response) - check_result = await cache.acheck(prompt=prompt, return_fields=["updated_at"]) - key = check_result[0]["key"] + async with cache: + await cache.astore(prompt=prompt, response=response) - await asyncio.sleep(1) + check_result = await cache.acheck(prompt=prompt, return_fields=["updated_at"]) + key = check_result[0]["key"] - metadata = {"foo": "bar"} - await cache.aupdate(key=key, metadata=metadata) + await asyncio.sleep(1) + + metadata = {"foo": "bar"} + await cache.aupdate(key=key, metadata=metadata) + + updated_result = await cache.acheck( + prompt=prompt, return_fields=["updated_at", "metadata"] + ) - updated_result = await cache.acheck( - prompt=prompt, return_fields=["updated_at", "metadata"] - ) assert updated_result[0]["metadata"] == metadata assert updated_result[0]["updated_at"] > check_result[0]["updated_at"] @@ -486,10 +506,12 @@ async def test_async_ttl_expiration_after_update(cache_with_ttl, vectorizer): assert cache_with_ttl.ttl == 4 - await cache_with_ttl.astore(prompt, response, vector=vector) - await asyncio.sleep(5) + async with cache_with_ttl: + await cache_with_ttl.astore(prompt, response, vector=vector) + await asyncio.sleep(5) + + check_result = await cache_with_ttl.acheck(vector=vector) - check_result = await cache_with_ttl.acheck(vector=vector) assert len(check_result) == 0 @@ -943,10 +965,11 @@ def test_deprecated_dtype_argument(redis_url): ) - @pytest.mark.asyncio async def test_cache_async_context_manager(redis_url): - async with SemanticCache(name="test_cache", redis_url=redis_url) as cache: + async with SemanticCache( + name="test_cache_async_context_manager", redis_url=redis_url + ) as cache: await cache.astore("test prompt", "test response") assert cache._aindex assert cache._aindex is None @@ -954,22 +977,27 @@ async def test_cache_async_context_manager(redis_url): @pytest.mark.asyncio async def test_cache_async_context_manager_with_exception(redis_url): - async with SemanticCache(name="test_cache", redis_url=redis_url) as cache: - await cache.astore("test prompt", "test response") - raise ValueError("test") + try: + async with SemanticCache( + name="test_cache_async_context_manager_with_exception", redis_url=redis_url + ) as cache: + await cache.astore("test prompt", "test response") + raise ValueError("test") + except ValueError: + pass assert cache._aindex is None @pytest.mark.asyncio async def test_cache_async_disconnect(redis_url): - cache = SemanticCache(name="test_cache", redis_url=redis_url) + cache = SemanticCache(name="test_cache_async_disconnect", redis_url=redis_url) await cache.astore("test prompt", "test response") await cache.adisconnect() assert cache._aindex is None def test_cache_disconnect(redis_url): - cache = SemanticCache(name="test_cache", redis_url=redis_url) + cache = SemanticCache(name="test_cache_disconnect", redis_url=redis_url) cache.store("test prompt", "test response") cache.disconnect() # We keep this index object around because it isn't lazily created diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index 2b803d8a..b3f317cb 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -148,13 +148,17 @@ def test_search_index_client(client, index_schema): assert index.client == client -def test_search_index_set_client(async_client, client, index): +def test_search_index_set_client(async_client, redis_url, index_schema): + index = SearchIndex(schema=index_schema, redis_url=redis_url) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - assert index.client == client + index.create(overwrite=True, drop=True) + assert index.client # should not be able to set an async client here with pytest.raises(TypeError): index.set_client(async_client) + assert index.client is not async_client index.disconnect() assert index.client is None @@ -289,14 +293,45 @@ def test_index_needs_valid_schema(): SearchIndex(schema="Not A Valid Schema") # type: ignore -def test_search_index_context_manager(index): +def test_search_index_that_does_not_own_client_context_manager(index): + with index: + index.create(overwrite=True, drop=True) + assert index.client + client = index.client + # Client should not have changed outside of the context manager + assert index.client == client + + +def test_search_index_that_does_not_own_client_context_manager_with_exception(index): + with pytest.raises(ValueError): + with index: + index.create(overwrite=True, drop=True) + client = index.client + raise ValueError("test") + # Client should not have changed outside of the context manager + assert index.client == client + + +def test_search_index_that_does_not_own_client_disconnect(index): + index.create(overwrite=True, drop=True) + client = index.client + index.disconnect() + # Client should not have changed after disconnecting + assert index.client == client + + +def test_search_index_that_owns_client_context_manager(index_schema, redis_url): + index = SearchIndex(schema=index_schema, redis_url=redis_url) with index: index.create(overwrite=True, drop=True) assert index.client assert index.client is None - - -def test_search_index_context_manager_with_exception(index): + + +def test_search_index_that_owns_client_context_manager_with_exception( + index_schema, redis_url +): + index = SearchIndex(schema=index_schema, redis_url=redis_url) with pytest.raises(ValueError): with index: index.create(overwrite=True, drop=True) @@ -304,7 +339,8 @@ def test_search_index_context_manager_with_exception(index): assert index.client is None -def test_search_index_disconnect(index): +def test_search_index_that_owns_client_disconnect(index_schema, redis_url): + index = SearchIndex(schema=index_schema, redis_url=redis_url) index.create(overwrite=True, drop=True) index.disconnect() assert index.client is None From aaa2091cc69477975bc0b131ff283ab05b9ff3a3 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Feb 2025 16:51:35 -0800 Subject: [PATCH 12/22] lint --- redisvl/extensions/llmcache/semantic.py | 6 +++--- redisvl/index/index.py | 7 +++---- redisvl/redis/connection.py | 7 ++++--- tests/conftest.py | 4 +--- tests/integration/test_async_search_index.py | 5 +++-- tests/integration/test_search_index.py | 1 + tests/unit/test_utils.py | 9 +++++++-- 7 files changed, 22 insertions(+), 17 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 56df24bc..b5ffaa09 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -1,8 +1,9 @@ import asyncio -from typing import Any, Dict, List, Optional import weakref +from typing import Any, Dict, List, Optional from redis import Redis + from redisvl.extensions.constants import ( CACHE_VECTOR_FIELD_NAME, ENTRY_ID_FIELD_NAME, @@ -25,15 +26,14 @@ from redisvl.redis.connection import RedisConnectionFactory from redisvl.utils.log import get_logger from redisvl.utils.utils import ( - sync_wrapper, current_timestamp, deprecated_argument, serialize, + sync_wrapper, validate_vector_dims, ) from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer - logger = get_logger("[RedisVL]") diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 2e8ba1d0..2d834c16 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,12 +1,12 @@ import asyncio import json import logging -from os import replace -from re import S import threading import warnings +import weakref from functools import wraps - +from os import replace +from re import S from typing import ( TYPE_CHECKING, Any, @@ -19,7 +19,6 @@ Optional, Union, ) -import weakref from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 8924e58f..cc95c479 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -191,7 +191,9 @@ class RedisConnectionFactory: """ @classmethod - @deprecated_function("connect", "Please use `get_redis_connection` or `get_async_redis_connection`.") + @deprecated_function( + "connect", "Please use `get_redis_connection` or `get_async_redis_connection`." + ) def connect( cls, redis_url: Optional[str] = None, use_async: bool = False, **kwargs ) -> Union[Redis, AsyncRedis]: @@ -251,7 +253,7 @@ def get_redis_connection( ) return client - + @staticmethod async def _get_aredis_connection( url: Optional[str] = None, @@ -314,7 +316,6 @@ def get_async_redis_connection( url = url or get_address_from_env() return AsyncRedis.from_url(url, **kwargs) - @staticmethod def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: """Convert a synchronous Redis client to an asynchronous one.""" diff --git a/tests/conftest.py b/tests/conftest.py index f3c069cb..8657cb44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,9 +54,7 @@ async def async_client(redis_url): """ An async Redis client that uses the dynamic `redis_url`. """ - async with await RedisConnectionFactory._get_aredis_connection( - redis_url - ) as client: + async with await RedisConnectionFactory._get_aredis_connection(redis_url) as client: yield client diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index ea3eb452..ddd14bc5 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,7 +1,8 @@ -import pytest import warnings -from redis.asyncio import Redis + +import pytest from redis import Redis as SyncRedis +from redis.asyncio import Redis from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index b3f317cb..9d649cc8 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -1,4 +1,5 @@ import warnings + import pytest from redisvl.exceptions import RedisSearchError diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index a87e6443..a6e3dd02 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,6 +1,7 @@ +from functools import wraps + import numpy as np import pytest -from functools import wraps from redisvl.redis.utils import ( array_to_buffer, @@ -8,7 +9,11 @@ convert_bytes, make_dict, ) -from redisvl.utils.utils import assert_no_warnings, deprecated_argument, deprecated_function +from redisvl.utils.utils import ( + assert_no_warnings, + deprecated_argument, + deprecated_function, +) def test_even_number_of_elements(): From 8f8642f16e97946e0b818a306efc6565aecf901b Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Feb 2025 20:41:11 -0800 Subject: [PATCH 13/22] Scale back module validation to match existing behavior --- redisvl/index/index.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 2d834c16..86410810 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -391,8 +391,10 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): index.connect(redis_url="redis://localhost:6379") """ + # TODO: Intentionally not including required modules to match existing + # behavior, but we need to review. self.__redis_client = RedisConnectionFactory.get_redis_connection( - redis_url=redis_url, required_modules=self.required_modules, **kwargs + redis_url=redis_url, **kwargs ) @deprecated_function("set_client", "Pass connection parameters in __init__.") @@ -420,6 +422,8 @@ def set_client(self, redis_client: redis.Redis, **kwargs): index.set_client(client) """ + # TODO: Including required modules to match existing behavior, but we + # need to review. RedisConnectionFactory.validate_sync_redis( redis_client, required_modules=self.required_modules ) @@ -934,8 +938,10 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs): "connect() is deprecated; pass connection parameters in __init__", DeprecationWarning, ) + # TODO: Intentionally not including required modules to match existing + # behavior, but we need to review. client = await RedisConnectionFactory._get_aredis_connection( - redis_url=redis_url, required_modules=self.required_modules, **kwargs + redis_url=redis_url, **kwargs ) await self.set_client(client) @@ -961,9 +967,9 @@ async def _get_client(self) -> aredis.Redis: if self._redis_url: kwargs["url"] = self._redis_url self._redis_client = ( - await RedisConnectionFactory._get_aredis_connection( - required_modules=self.required_modules, **kwargs - ) + # TODO: Intentionally not including required modules to match existing + # behavior, but we need to review. + await RedisConnectionFactory._get_aredis_connection(**kwargs) ) await RedisConnectionFactory.validate_async_redis( self._redis_client, self._lib_name From 3198840c57dcf5fc8cc2577fcd98f2d8f294aa82 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Feb 2025 21:06:42 -0800 Subject: [PATCH 14/22] remove required_modules param from `from_existing` --- redisvl/index/index.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 86410810..f32a4695 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -304,7 +304,6 @@ def from_existing( name: str, redis_client: Optional[redis.Redis] = None, redis_url: Optional[str] = None, - required_modules: Optional[List[Dict[str, Any]]] = None, **kwargs, ): """ @@ -316,8 +315,6 @@ def from_existing( instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. - required_modules (Optional[List[Dict[str, Any]]]): List of required - Redis modules with version requirements. Raises: ValueError: If redis_url or redis_client is not provided. @@ -326,11 +323,13 @@ def from_existing( try: if redis_url: redis_client = RedisConnectionFactory.get_redis_connection( - redis_url=redis_url, required_modules=required_modules, **kwargs + redis_url=redis_url, + required_modules=cls.required_modules, + **kwargs, ) elif redis_client: RedisConnectionFactory.validate_sync_redis( - redis_client, required_modules=required_modules + redis_client, required_modules=cls.required_modules ) except RedisModuleVersionError as e: raise RedisModuleVersionError( From cd37fe9f413356a9c0cb454f2ad3aac9789d60cf Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 19 Feb 2025 21:30:47 -0800 Subject: [PATCH 15/22] old man yells at mypy emoji --- redisvl/extensions/llmcache/semantic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index b5ffaa09..41e6e214 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -186,13 +186,14 @@ def _modify_schema( async def _get_async_index(self) -> AsyncSearchIndex: """Lazily construct the async search index class.""" # Construct async index if necessary - if not self._aindex: + async_client = None + if self._aindex is None: client = self.redis_kwargs.get("redis_client") - if client and isinstance(client, Redis): - client = RedisConnectionFactory.sync_to_async_redis(client) + if isinstance(client, Redis): + async_client = RedisConnectionFactory.sync_to_async_redis(client) self._aindex = AsyncSearchIndex( schema=self._index.schema, - redis_client=client, + redis_client=async_client, redis_url=self.redis_kwargs["redis_url"], **self.redis_kwargs["connection_kwargs"], ) From 8dab1e4f8f6ba0f736ea5cb6c98564536ec6637a Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 20 Feb 2025 11:12:11 -0800 Subject: [PATCH 16/22] Update docs --- README.md | 11 +++++------ tests/integration/test_async_search_index.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 1d138377..22765c27 100644 --- a/README.md +++ b/README.md @@ -121,19 +121,18 @@ Choose from multiple Redis deployment options: }) ``` -2. [Create a SearchIndex](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#create-a-searchindex) class with an input schema and client connection in order to perform admin and search operations on your index in Redis: +2. [Create a SearchIndex](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#create-a-searchindex) class with an input schema to perform admin and search operations on your index in Redis: ```python from redis import Redis from redisvl.index import SearchIndex - # Establish Redis connection and define index - client = Redis.from_url("redis://localhost:6379") - index = SearchIndex(schema, client) + # Define the index + index = SearchIndex(schema, redis_url="redis://localhost:6379") # Create the index in Redis index.create() ``` - > Async compliant search index class also available: [AsyncSearchIndex](https://docs.redisvl.com/en/stable/api/searchindex.html#redisvl.index.AsyncSearchIndex). + > An async-compatible index class also available: [AsyncSearchIndex](https://docs.redisvl.com/en/stable/api/searchindex.html#redisvl.index.AsyncSearchIndex). 3. [Load](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#load-data-to-searchindex) and [fetch](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#fetch-an-object-from-redis) data to/from your Redis instance: @@ -346,7 +345,7 @@ Commands: stats Obtain statistics about an index ``` -> Read more about [using the CLI](https://docs.redisvl.com/en/stable/user_guide/cli.html). +> Read more about [using the CLI](https://docs.redisvl.com/en/latest/overview/cli.html). ## 🚀 Why RedisVL? diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index ddd14bc5..2aa7ae0f 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -160,7 +160,7 @@ async def test_search_index_set_client(client, redis_url, index_schema): await async_index.create(overwrite=True, drop=True) assert isinstance(async_index.client, Redis) - # Tests deprecated sync -> async conversation behavior + # Tests deprecated sync -> async conversion behavior assert isinstance(client, SyncRedis) await async_index.set_client(client) assert isinstance(async_index.client, Redis) From 33cf7c5e010300e09bc4fbcd83dd59dc8e817d5c Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 20 Feb 2025 12:37:55 -0800 Subject: [PATCH 17/22] Use default required mods unless needing introspection --- redisvl/index/index.py | 30 ++++++++++++------------------ redisvl/utils/utils.py | 3 +-- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index f32a4695..10fa510e 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,12 +1,8 @@ import asyncio import json -import logging import threading import warnings import weakref -from functools import wraps -from os import replace -from re import S from typing import ( TYPE_CHECKING, Any, @@ -47,6 +43,12 @@ logger = get_logger(__name__) +REQUIRED_MODULES_FOR_INTROSPECTION = [ + {"name": "search", "ver": 20810}, + {"name": "searchlight", "ver": 20810}, +] + + def process_results( results: "Result", query: BaseQuery, storage_type: StorageType ) -> List[Dict[str, Any]]: @@ -242,11 +244,6 @@ class SearchIndex(BaseSearchIndex): """ - required_modules = [ - {"name": "search", "ver": 20810}, - {"name": "searchlight", "ver": 20810}, - ] - @deprecated_argument("connection_args", "Use connection_kwargs instead.") def __init__( self, @@ -324,12 +321,12 @@ def from_existing( if redis_url: redis_client = RedisConnectionFactory.get_redis_connection( redis_url=redis_url, - required_modules=cls.required_modules, + required_modules=REQUIRED_MODULES_FOR_INTROSPECTION, **kwargs, ) elif redis_client: RedisConnectionFactory.validate_sync_redis( - redis_client, required_modules=cls.required_modules + redis_client, required_modules=REQUIRED_MODULES_FOR_INTROSPECTION ) except RedisModuleVersionError as e: raise RedisModuleVersionError( @@ -829,11 +826,6 @@ class AsyncSearchIndex(BaseSearchIndex): """ - required_modules = [ - {"name": "search", "ver": 20810}, - {"name": "searchlight", "ver": 20810}, - ] - @deprecated_argument("redis_kwargs", "Use connection_kwargs instead.") def __init__( self, @@ -902,11 +894,13 @@ async def from_existing( try: if redis_url: redis_client = await RedisConnectionFactory._get_aredis_connection( - url=redis_url, required_modules=cls.required_modules, **kwargs + url=redis_url, + required_modules=REQUIRED_MODULES_FOR_INTROSPECTION, + **kwargs, ) elif redis_client: await RedisConnectionFactory.validate_async_redis( - redis_client, required_modules=cls.required_modules + redis_client, required_modules=REQUIRED_MODULES_FOR_INTROSPECTION ) except RedisModuleVersionError as e: raise RedisModuleVersionError( diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 578b30fa..4341ac6d 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -13,8 +13,6 @@ from pydantic import BaseModel from ulid import ULID -from redisvl.utils.log import get_logger - def create_ulid() -> str: """Generate a unique identifier to group related Redis documents.""" @@ -157,6 +155,7 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + print("???") warn(warning_message, category=DeprecationWarning, stacklevel=3) return func(*args, **kwargs) From 35c22772c715c82a2530d6e8026f310c0138b41e Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 20 Feb 2025 12:48:02 -0800 Subject: [PATCH 18/22] finish the job on module variable shuffling --- redisvl/index/index.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 10fa510e..42e9c74f 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -418,11 +418,7 @@ def set_client(self, redis_client: redis.Redis, **kwargs): index.set_client(client) """ - # TODO: Including required modules to match existing behavior, but we - # need to review. - RedisConnectionFactory.validate_sync_redis( - redis_client, required_modules=self.required_modules - ) + RedisConnectionFactory.validate_sync_redis(redis_client) self.__redis_client = redis_client return self @@ -931,8 +927,6 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs): "connect() is deprecated; pass connection parameters in __init__", DeprecationWarning, ) - # TODO: Intentionally not including required modules to match existing - # behavior, but we need to review. client = await RedisConnectionFactory._get_aredis_connection( redis_url=redis_url, **kwargs ) @@ -960,8 +954,6 @@ async def _get_client(self) -> aredis.Redis: if self._redis_url: kwargs["url"] = self._redis_url self._redis_client = ( - # TODO: Intentionally not including required modules to match existing - # behavior, but we need to review. await RedisConnectionFactory._get_aredis_connection(**kwargs) ) await RedisConnectionFactory.validate_async_redis( From a9b014b31a9853f1594abbcc89ecfc83fc3ea48a Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 20 Feb 2025 15:48:42 -0500 Subject: [PATCH 19/22] use logger --- redisvl/index/index.py | 2 +- redisvl/utils/utils.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 42e9c74f..0565a9b7 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -289,7 +289,7 @@ def __init__( def disconnect(self): """Disconnect from the Redis database.""" if self._owns_redis_client is False: - print("Index does not own client, not disconnecting") + logger.info("Index does not own client, not disconnecting") return if self.__redis_client: self.__redis_client.close() diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 4341ac6d..4c40d41a 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -155,7 +155,6 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - print("???") warn(warning_message, category=DeprecationWarning, stacklevel=3) return func(*args, **kwargs) From fd2b417cd2605dac3d3092f1604ab6db05c627c9 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 20 Feb 2025 12:51:29 -0800 Subject: [PATCH 20/22] remove unnecessary comment --- redisvl/index/index.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 0565a9b7..59d20f28 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -387,8 +387,6 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): index.connect(redis_url="redis://localhost:6379") """ - # TODO: Intentionally not including required modules to match existing - # behavior, but we need to review. self.__redis_client = RedisConnectionFactory.get_redis_connection( redis_url=redis_url, **kwargs ) From 3abd8ee77e2555bcf15cfefe6cd7d79dd2c0c4f3 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 20 Feb 2025 13:02:47 -0800 Subject: [PATCH 21/22] remove print output --- docs/user_guide/02_hybrid_queries.ipynb | 361 +++++++----------------- 1 file changed, 101 insertions(+), 260 deletions(-) diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index e2a9b0d5..9568669d 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -78,7 +78,15 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13:02:18 redisvl.index.index INFO Index already exists, overwriting.\n" + ] + } + ], "source": [ "from redisvl.index import SearchIndex\n", "\n", @@ -98,8 +106,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m13:03:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m13:03:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. float64_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. float64_session\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 3. float16_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 4. float16_session\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 5. float32_session\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 6. float32_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 7. bfloat_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 8. user_queries\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 9. student tutor\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 10. tutor\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 11. bfloat_session\n" ] } ], @@ -142,17 +160,10 @@ "execution_count": 6, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@credit_score:{high}=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -184,17 +195,10 @@ "execution_count": 7, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(-@credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0.217882037163taimurlow15CEO-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -217,17 +221,10 @@ "execution_count": 8, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@credit_score:{high|medium}=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -250,17 +247,10 @@ "execution_count": 9, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@credit_score:{high|medium}=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -294,17 +284,10 @@ "execution_count": 10, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "*=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
" ], "text/plain": [ "" @@ -336,17 +319,10 @@ "execution_count": 11, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@age:[(15 +inf]=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -370,17 +346,10 @@ "execution_count": 12, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@age:[14 14]=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -403,17 +372,10 @@ "execution_count": 13, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(-@age:[14 14])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -445,17 +407,10 @@ "execution_count": 14, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@job:(\"doctor\")=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -480,17 +435,10 @@ "execution_count": 15, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(-@job:\"doctor\")=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -513,17 +461,10 @@ "execution_count": 16, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@job:(doct*)=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -546,17 +487,10 @@ "execution_count": 17, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@job:(%%engine%%)=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
" ], "text/plain": [ "" @@ -579,17 +513,10 @@ "execution_count": 18, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@job:(engineer|doctor)=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -612,17 +539,10 @@ "execution_count": 19, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "*=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
" ], "text/plain": [ "" @@ -652,17 +572,10 @@ "execution_count": 20, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(~(@job:engineer))=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/plain": [ - "[{'id': 'user_queries_docs:01JM34BQDEBYP2ZQTRK6SEAHHT',\n", + "[{'id': 'user_queries_docs:01JMJJHE28ZW4F33ZNRKXRHYCS',\n", " 'score': 1.8181817787737895,\n", " 'vector_distance': '0',\n", " 'user': 'john',\n", @@ -670,7 +583,7 @@ " 'age': '18',\n", " 'job': 'engineer',\n", " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:01JM34BQDEEV60RV4VXK0JBBBK',\n", + " {'id': 'user_queries_docs:01JMJJHE2899024DYPXT6424N9',\n", " 'score': 0.0,\n", " 'vector_distance': '0',\n", " 'user': 'derrick',\n", @@ -678,7 +591,31 @@ " 'age': '14',\n", " 'job': 'doctor',\n", " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:01JM34BQDE6QKS9N8G49JN4V71',\n", + " {'id': 'user_queries_docs:01JMJJPEYCQ89ZQW6QR27J72WT',\n", + " 'score': 1.8181817787737895,\n", + " 'vector_distance': '0',\n", + " 'user': 'john',\n", + " 'credit_score': 'high',\n", + " 'age': '18',\n", + " 'job': 'engineer',\n", + " 'office_location': '-122.4194,37.7749'},\n", + " {'id': 'user_queries_docs:01JMJJPEYD544WB1TKDBJ3Z3J9',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0',\n", + " 'user': 'derrick',\n", + " 'credit_score': 'low',\n", + " 'age': '14',\n", + " 'job': 'doctor',\n", + " 'office_location': '-122.4194,37.7749'},\n", + " {'id': 'user_queries_docs:01JMJJHE28B5R6T00DH37A7KSJ',\n", + " 'score': 1.8181817787737895,\n", + " 'vector_distance': '0.109129190445',\n", + " 'user': 'tyler',\n", + " 'credit_score': 'high',\n", + " 'age': '100',\n", + " 'job': 'engineer',\n", + " 'office_location': '-122.0839,37.3861'},\n", + " {'id': 'user_queries_docs:01JMJJPEYDPF9S5328WHCQN0ND',\n", " 'score': 1.8181817787737895,\n", " 'vector_distance': '0.109129190445',\n", " 'user': 'tyler',\n", @@ -686,7 +623,15 @@ " 'age': '100',\n", " 'job': 'engineer',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:01JM34BQDEZ9AXAQRE6GYAQPSP',\n", + " {'id': 'user_queries_docs:01JMJJHE28G5F943YGWMB1ZX1V',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0.158808946609',\n", + " 'user': 'tim',\n", + " 'credit_score': 'high',\n", + " 'age': '12',\n", + " 'job': 'dermatologist',\n", + " 'office_location': '-122.0839,37.3861'},\n", + " {'id': 'user_queries_docs:01JMJJPEYDKA9ARKHRK1D7KPXQ',\n", " 'score': 0.0,\n", " 'vector_distance': '0.158808946609',\n", " 'user': 'tim',\n", @@ -694,7 +639,7 @@ " 'age': '12',\n", " 'job': 'dermatologist',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:01JM34BQDEM8JT6XC3YJRN72R4',\n", + " {'id': 'user_queries_docs:01JMJJHE28NR7KF0EZEA433T2J',\n", " 'score': 0.0,\n", " 'vector_distance': '0.217882037163',\n", " 'user': 'taimur',\n", @@ -702,21 +647,13 @@ " 'age': '15',\n", " 'job': 'CEO',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:01JM34BQDEJK2B70X6B11JBQ82',\n", - " 'score': 0.0,\n", - " 'vector_distance': '0.266666650772',\n", - " 'user': 'nancy',\n", - " 'credit_score': 'high',\n", - " 'age': '94',\n", - " 'job': 'doctor',\n", - " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:01JM34BQDEBCVNY5BZSVD16N9R',\n", + " {'id': 'user_queries_docs:01JMJJPEYD9EAVGJ2AZ8K9VX7Q',\n", " 'score': 0.0,\n", - " 'vector_distance': '0.653301358223',\n", - " 'user': 'joe',\n", - " 'credit_score': 'medium',\n", - " 'age': '35',\n", - " 'job': 'dentist',\n", + " 'vector_distance': '0.217882037163',\n", + " 'user': 'taimur',\n", + " 'credit_score': 'low',\n", + " 'age': '15',\n", + " 'job': 'CEO',\n", " 'office_location': '-122.0839,37.3861'}]" ] }, @@ -746,17 +683,10 @@ "execution_count": 21, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@office_location:[-122.4194 37.7749 10 km]=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
scorevector_distanceusercredit_scoreagejoboffice_location
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
scorevector_distanceusercredit_scoreagejoboffice_location
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740.266666650772nancyhigh94doctor-122.4194,37.7749
0.45454544469344740.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -781,17 +711,10 @@ "execution_count": 22, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@office_location:[-122.4194 37.7749 100 km]=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
scorevector_distanceusercredit_scoreagejoboffice_location
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740.109129190445tylerhigh100engineer-122.0839,37.3861
0.45454544469344740.158808946609timhigh12dermatologist-122.0839,37.3861
0.45454544469344740.217882037163taimurlow15CEO-122.0839,37.3861
0.45454544469344740.266666650772nancyhigh94doctor-122.4194,37.7749
0.45454544469344740.653301358223joemedium35dentist-122.0839,37.3861
" + "
scorevector_distanceusercredit_scoreagejoboffice_location
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740.109129190445tylerhigh100engineer-122.0839,37.3861
0.45454544469344740.109129190445tylerhigh100engineer-122.0839,37.3861
0.45454544469344740.158808946609timhigh12dermatologist-122.0839,37.3861
0.45454544469344740.158808946609timhigh12dermatologist-122.0839,37.3861
0.45454544469344740.217882037163taimurlow15CEO-122.0839,37.3861
0.45454544469344740.217882037163taimurlow15CEO-122.0839,37.3861
" ], "text/plain": [ "" @@ -814,17 +737,10 @@ "execution_count": 23, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(-@office_location:[-122.4194 37.7749 10 km])=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25 WITHSCORES RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
scorevector_distanceusercredit_scoreagejoboffice_location
0.00.109129190445tylerhigh100engineer-122.0839,37.3861
0.00.158808946609timhigh12dermatologist-122.0839,37.3861
0.00.217882037163taimurlow15CEO-122.0839,37.3861
0.00.653301358223joemedium35dentist-122.0839,37.3861
" + "
scorevector_distanceusercredit_scoreagejoboffice_location
0.00.109129190445tylerhigh100engineer-122.0839,37.3861
0.00.109129190445tylerhigh100engineer-122.0839,37.3861
0.00.158808946609timhigh12dermatologist-122.0839,37.3861
0.00.158808946609timhigh12dermatologist-122.0839,37.3861
0.00.217882037163taimurlow15CEO-122.0839,37.3861
0.00.217882037163taimurlow15CEO-122.0839,37.3861
0.00.653301358223joemedium35dentist-122.0839,37.3861
0.00.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -858,17 +774,10 @@ "execution_count": 24, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -908,17 +817,10 @@ "execution_count": 25, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(@age:[-inf (18] | @age:[(93 +inf])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -974,17 +876,10 @@ "execution_count": 27, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "((@age:[(18 +inf] @credit_score:{high}) @job:(engineer))=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
" ], "text/plain": [ "" @@ -1006,17 +901,10 @@ "execution_count": 28, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(@age:[(18 +inf] @credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -1038,17 +926,10 @@ "execution_count": 29, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@age:[(18 +inf]=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -1070,17 +951,10 @@ "execution_count": 30, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "*=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
" ], "text/plain": [ "" @@ -1111,17 +985,10 @@ "execution_count": 31, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@credit_score:{low} RETURN 5 user credit_score age job location DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
usercredit_scoreagejob
derricklow14doctor
taimurlow15CEO
" + "
usercredit_scoreagejob
derricklow14doctor
taimurlow15CEO
derricklow14doctor
taimurlow15CEO
" ], "text/plain": [ "" @@ -1164,8 +1031,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "@credit_score:{low} NOCONTENT DIALECT 2 LIMIT 0 0\n", - "2 records match the filter expression @credit_score:{low} for the given index.\n" + "4 records match the filter expression @credit_score:{low} for the given index.\n" ] } ], @@ -1195,17 +1061,10 @@ "execution_count": 33, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@user_embedding:[VECTOR_RANGE $distance_threshold $vector]=>{$yield_distance_as: vector_distance} RETURN 6 user credit_score age job location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejob
0johnhigh18engineer
0derricklow14doctor
0.109129190445tylerhigh100engineer
0.158808946609timhigh12dermatologist
" + "
vector_distanceusercredit_scoreagejob
0johnhigh18engineer
0derricklow14doctor
0johnhigh18engineer
0derricklow14doctor
0.109129190445tylerhigh100engineer
0.109129190445tylerhigh100engineer
0.158808946609timhigh12dermatologist
0.158808946609timhigh12dermatologist
" ], "text/plain": [ "" @@ -1243,17 +1102,10 @@ "execution_count": 34, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@user_embedding:[VECTOR_RANGE $distance_threshold $vector]=>{$yield_distance_as: vector_distance} RETURN 6 user credit_score age job location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejob
0johnhigh18engineer
0derricklow14doctor
" + "
vector_distanceusercredit_scoreagejob
0johnhigh18engineer
0derricklow14doctor
0johnhigh18engineer
0derricklow14doctor
" ], "text/plain": [ "" @@ -1281,17 +1133,10 @@ "execution_count": 35, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(@user_embedding:[VECTOR_RANGE $distance_threshold $vector]=>{$yield_distance_as: vector_distance} @job:(\"engineer\")) RETURN 6 user credit_score age job location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejob
0johnhigh18engineer
" + "
vector_distanceusercredit_scoreagejob
0johnhigh18engineer
0johnhigh18engineer
" ], "text/plain": [ "" @@ -1328,17 +1173,10 @@ "execution_count": 36, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5\n" - ] - }, { "data": { "text/html": [ - "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
" + "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
0.109129190445100tylerhighengineer-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
018johnhighengineer-122.4194,37.7749
" ], "text/plain": [ "" @@ -1457,11 +1295,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "{'id': 'user_queries_docs:01JM34BQDEBYP2ZQTRK6SEAHHT', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:01JM34BQDEJK2B70X6B11JBQ82', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:01JM34BQDE6QKS9N8G49JN4V71', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:01JM34BQDEZ9AXAQRE6GYAQPSP', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'user_queries_docs:01JMJJHE28G5F943YGWMB1ZX1V', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJHE28ZW4F33ZNRKXRHYCS', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJHE28B5R6T00DH37A7KSJ', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJHE28EX13NEE7BGBM8FH3', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYCQ89ZQW6QR27J72WT', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYDAN0M3V7EQEVPS6HX', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYDPF9S5328WHCQN0ND', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYDKA9ARKHRK1D7KPXQ', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], From 2ec7402e6c2dffc3ee48c7368b67e8dab7186e22 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 20 Feb 2025 16:26:10 -0500 Subject: [PATCH 22/22] reference redis_url as a kwarg --- docs/user_guide/05_hash_vs_json.ipynb | 34 +++++++++++++-------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/docs/user_guide/05_hash_vs_json.ipynb b/docs/user_guide/05_hash_vs_json.ipynb index 04c75f60..7918949d 100644 --- a/docs/user_guide/05_hash_vs_json.ipynb +++ b/docs/user_guide/05_hash_vs_json.ipynb @@ -283,10 +283,12 @@ "\n", "t = (Tag(\"credit_score\") == \"high\") & (Text(\"job\") % \"enginee*\") & (Num(\"age\") > 17)\n", "\n", - "v = VectorQuery([0.1, 0.1, 0.5],\n", - " \"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", - " filter_expression=t)\n", + "v = VectorQuery(\n", + " vector=[0.1, 0.1, 0.5],\n", + " vector_field_name=\"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " filter_expression=t\n", + ")\n", "\n", "\n", "results = hindex.query(v)\n", @@ -395,8 +397,6 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "\n", "json_data = data.copy()\n", "\n", "for d in json_data:\n", @@ -593,10 +593,7 @@ "outputs": [], "source": [ "# construct a search index from the json schema\n", - "bike_index = SearchIndex.from_dict(bike_schema)\n", - "\n", - "# connect to local redis instance\n", - "bike_index.connect(\"redis://localhost:6379\")\n", + "bike_index = SearchIndex.from_dict(bike_schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "bike_index.create(overwrite=True)" @@ -633,14 +630,15 @@ "\n", "vec = emb_model.embed(\"I'd like a bike for aggressive riding\")\n", "\n", - "v = VectorQuery(vector=vec,\n", - " vector_field_name=\"bike_embedding\",\n", - " return_fields=[\n", - " \"brand\",\n", - " \"name\",\n", - " \"$.metadata.type\"\n", - " ]\n", - " )\n", + "v = VectorQuery(\n", + " vector=vec,\n", + " vector_field_name=\"bike_embedding\",\n", + " return_fields=[\n", + " \"brand\",\n", + " \"name\",\n", + " \"$.metadata.type\"\n", + " ]\n", + ")\n", "\n", "\n", "results = bike_index.query(v)"