diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 0a08f3f5..e1c63875 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -28,7 +28,7 @@ def __init__( vectorizer: Optional[BaseVectorizer] = None, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", - connection_args: Dict[str, Any] = {}, + connection_kwargs: Dict[str, Any] = {}, **kwargs, ): """Semantic Cache for Large Language Models. @@ -43,14 +43,13 @@ def __init__( cache. Defaults to 0.1. ttl (Optional[int], optional): The time-to-live for records cached in Redis. Defaults to None. - vectorizer (BaseVectorizer, optional): The vectorizer for the cache. + vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache. Defaults to HFTextVectorizer. - redis_client(Redis, optional): A redis client connection instance. + redis_client(Optional[Redis], optional): A redis client connection instance. Defaults to None. - redis_url (str, optional): The redis url. Defaults to - "redis://localhost:6379". - connection_args (Dict[str, Any], optional): The connection arguments - for the redis client. Defaults to None. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. Raises: TypeError: If an invalid vectorizer is provided. @@ -96,8 +95,8 @@ def __init__( # handle redis connection if redis_client: self._index.set_client(redis_client) - else: - self._index.connect(redis_url=redis_url, **connection_args) + elif redis_url: + self._index.connect(redis_url=redis_url, **connection_kwargs) # initialize other components self.default_return_fields = [ diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index ee78b9d3..1c34202a 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -86,8 +86,9 @@ def __init__( vectorizer: Optional[BaseVectorizer] = None, routing_config: Optional[RoutingConfig] = None, redis_client: Optional[Redis] = None, - redis_url: Optional[str] = None, + redis_url: str = "redis://localhost:6379", overwrite: bool = False, + connection_kwargs: Dict[str, Any] = {}, **kwargs, ): """Initialize the SemanticRouter. @@ -98,9 +99,10 @@ def __init__( vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to default HFTextVectorizer. routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to the default RoutingConfig. redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. - redis_url (Optional[str], optional): Redis URL for connection. Defaults to None. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. - **kwargs: Additional arguments. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. """ # Set vectorizer default if vectorizer is None: @@ -115,12 +117,12 @@ def __init__( vectorizer=vectorizer, routing_config=routing_config, ) - self._initialize_index(redis_client, redis_url, overwrite) + self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs) def _initialize_index( self, redis_client: Optional[Redis] = None, - redis_url: Optional[str] = None, + redis_url: str = "redis://localhost:6379", overwrite: bool = False, **connection_kwargs, ): @@ -132,8 +134,6 @@ def _initialize_index( self._index.set_client(redis_client) elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) - else: - raise ValueError("Must provide either a redis client or redis url string.") existed = self._index.exists() self._index.create(overwrite=overwrite) @@ -479,19 +479,12 @@ def clear(self) -> None: def from_dict( cls, data: Dict[str, Any], - redis_client: Optional[Redis] = None, - redis_url: Optional[str] = None, - overwrite: bool = False, **kwargs, ) -> "SemanticRouter": """Create a SemanticRouter from a dictionary. Args: data (Dict[str, Any]): The dictionary containing the semantic router data. - redis_client (Optional[Redis]): Redis client for connection. - redis_url (Optional[str]): Redis URL for connection. - overwrite (bool): Whether to overwrite existing index. - **kwargs: Additional arguments. Returns: SemanticRouter: The semantic router instance. @@ -533,9 +526,6 @@ def from_dict( routes=routes, vectorizer=vectorizer, routing_config=routing_config, - redis_client=redis_client, - redis_url=redis_url, - overwrite=overwrite, **kwargs, ) @@ -565,19 +555,12 @@ def to_dict(self) -> Dict[str, Any]: def from_yaml( cls, file_path: str, - redis_client: Optional[Redis] = None, - redis_url: Optional[str] = None, - overwrite: bool = False, **kwargs, ) -> "SemanticRouter": """Create a SemanticRouter from a YAML file. Args: file_path (str): The path to the YAML file. - redis_client (Optional[Redis]): Redis client for connection. - redis_url (Optional[str]): Redis URL for connection. - overwrite (bool): Whether to overwrite existing index. - **kwargs: Additional arguments. Returns: SemanticRouter: The semantic router instance. @@ -603,9 +586,6 @@ def from_yaml( yaml_data = yaml.safe_load(f) return cls.from_dict( yaml_data, - redis_client=redis_client, - redis_url=redis_url, - overwrite=overwrite, **kwargs, ) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 6ce253d5..e91f4754 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -1,5 +1,5 @@ from time import time -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from redis import Redis @@ -27,6 +27,8 @@ def __init__( distance_threshold: float = 0.3, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", + connection_kwargs: Dict[str, Any] = {}, + **kwargs, ): """Initialize session memory with index @@ -43,12 +45,14 @@ def __init__( user_tag (str): Tag to be added to entries to link to a specific user. prefix (Optional[str]): Prefix for the keys for this session data. Defaults to None and will be replaced with the index name. - vectorizer (Vectorizer): The vectorizer to create embeddings with. + vectorizer (Optional[BaseVectorizer]): The vectorizer used to create embeddings. distance_threshold (float): The maximum semantic distance to be included in the context. Defaults to 0.3. redis_client (Optional[Redis]): A Redis client instance. Defaults to None. - redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. The proposed schema will support a single vector embedding constructed from either the prompt or response in a single string. @@ -89,10 +93,11 @@ def __init__( self._index = SearchIndex(schema=schema) + # handle redis connection if redis_client: self._index.set_client(redis_client) - else: - self._index.connect(redis_url=redis_url) + elif redis_url: + self._index.connect(redis_url=redis_url, **connection_kwargs) self._index.create(overwrite=False) diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index d0b76898..8617a3ab 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -1,10 +1,11 @@ import json from time import time -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager +from redisvl.redis.connection import RedisConnectionFactory class StandardSessionManager(BaseSessionManager): @@ -16,6 +17,8 @@ def __init__( user_tag: str, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", + connection_kwargs: Dict[str, Any] = {}, + **kwargs, ): """Initialize session memory @@ -31,7 +34,9 @@ def __init__( user_tag (str): Tag to be added to entries to link to a specific user. redis_client (Optional[Redis]): A Redis client instance. Defaults to None. - redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. The proposed schema will support a single combined vector embedding constructed from the prompt & response in a single string. @@ -39,10 +44,14 @@ def __init__( """ super().__init__(name, session_tag, user_tag) + # handle redis connection if redis_client: self._client = redis_client - else: - self._client = Redis.from_url(redis_url) + elif redis_url: + self._client = RedisConnectionFactory.get_redis_connection( + redis_url, **connection_kwargs + ) + RedisConnectionFactory.validate_redis(self._client) self.set_scope(session_tag, user_tag) @@ -51,7 +60,7 @@ def set_scope( session_tag: Optional[str] = None, user_tag: Optional[str] = None, ) -> None: - """Set the filter to apply to querries based on the desired scope. + """Set the filter to apply to queries based on the desired scope. This new scope persists until another call to set_scope is made, or if scope is specified in calls to get_recent. diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index e9ad8b5e..cc152291 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -2,6 +2,7 @@ from time import sleep import pytest +from redis.exceptions import ConnectionError from redisvl.extensions.llmcache import SemanticCache from redisvl.index.index import SearchIndex @@ -40,19 +41,17 @@ def cache_with_ttl(vectorizer, redis_url): @pytest.fixture -def cache_with_redis_client(vectorizer, client, redis_url): +def cache_with_redis_client(vectorizer, client): cache_instance = SemanticCache( vectorizer=vectorizer, redis_client=client, distance_threshold=0.2, - redis_url=redis_url, ) yield cache_instance cache_instance.clear() # Clear cache after each test cache_instance._index.delete(True) # Clean up index -# # Test handling invalid input for check method def test_bad_ttl(cache): with pytest.raises(ValueError): cache.set_ttl(2.5) @@ -76,7 +75,6 @@ def test_reset_ttl(cache): assert cache.ttl is None -# Test basic store and check functionality def test_store_and_check(cache, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -91,7 +89,6 @@ def test_store_and_check(cache, vectorizer): assert "metadata" not in check_result[0] -# Test clearing the cache def test_clear(cache, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -139,7 +136,6 @@ def test_check_no_match(cache, vectorizer): assert len(check_result) == 0 -# Test handling invalid input for check method def test_check_invalid_input(cache): with pytest.raises(ValueError): cache.check() @@ -148,7 +144,15 @@ def test_check_invalid_input(cache): cache.check(prompt="test", return_fields="bad value") -# Test storing with metadata +def test_bad_connection_info(vectorizer): + with pytest.raises(ConnectionError): + SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + redis_url="redis://localhost:6389", + ) + + def test_store_with_metadata(cache, vectorizer): prompt = "This is another test prompt." response = "This is another test response." @@ -165,7 +169,6 @@ def test_store_with_metadata(cache, vectorizer): assert check_result[0]["prompt"] == prompt -# Test storing with invalid metadata def test_store_with_invalid_metadata(cache, vectorizer): prompt = "This is another test prompt." response = "This is another test response." @@ -179,7 +182,6 @@ def test_store_with_invalid_metadata(cache, vectorizer): cache.store(prompt, response, vector=vector, metadata=metadata) -# Test setting and getting the distance threshold def test_distance_threshold(cache): initial_threshold = cache.distance_threshold new_threshold = 0.1 @@ -189,14 +191,12 @@ def test_distance_threshold(cache): assert cache.distance_threshold != initial_threshold -# Test out of range distance threshold def test_distance_threshold_out_of_range(cache): out_of_range_threshold = -1 with pytest.raises(ValueError): cache.set_threshold(out_of_range_threshold) -# Test storing and retrieving multiple items def test_multiple_items(cache, vectorizer): prompts_responses = { "prompt1": "response1", @@ -217,12 +217,10 @@ def test_multiple_items(cache, vectorizer): assert "metadata" not in check_result[0] -# Test retrieving underlying SearchIndex for the cache. def test_get_index(cache): assert isinstance(cache.index, SearchIndex) -# Test basic functionality with cache created with user-provided Redis client def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -237,13 +235,11 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize assert "metadata" not in check_result[0] -# Test deleting the cache def test_delete(cache_no_cleanup): cache_no_cleanup.delete() assert not cache_no_cleanup.index.exists() -# Test we can only store and check vectors of correct dimensions def test_vector_size(cache, vectorizer): prompt = "This is test prompt." response = "This is a test response." diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 6c54a2b9..b2a7c716 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -1,6 +1,7 @@ import pathlib import pytest +from redis.exceptions import ConnectionError from redisvl.extensions.router import SemanticRouter from redisvl.extensions.router.schema import Route, RoutingConfig @@ -225,3 +226,14 @@ def test_idempotent_to_dict(semantic_router): router_dict, redis_client=semantic_router._index.client ) assert new_router.to_dict() == router_dict + + +def test_bad_connection_info(routes): + with pytest.raises(ConnectionError): + SemanticRouter( + name="test-router", + routes=routes, + routing_config=RoutingConfig(distance_threshold=0.3, max_k=2), + redis_url="redis://localhost:6389", # bad connection url + overwrite=False, + ) diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 0a236aaf..c1f277c8 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -2,6 +2,7 @@ import time import pytest +from redis.exceptions import ConnectionError from redisvl.extensions.session_manager import ( SemanticSessionManager, @@ -10,9 +11,9 @@ @pytest.fixture -def standard_session(app_name, user_tag, session_tag): +def standard_session(app_name, user_tag, session_tag, client): session = StandardSessionManager( - app_name, session_tag=session_tag, user_tag=user_tag + app_name, session_tag=session_tag, user_tag=user_tag, redis_client=client ) yield session session.clear() @@ -20,9 +21,9 @@ def standard_session(app_name, user_tag, session_tag): @pytest.fixture -def semantic_session(app_name, user_tag, session_tag): +def semantic_session(app_name, user_tag, session_tag, client): session = SemanticSessionManager( - app_name, session_tag=session_tag, user_tag=user_tag + app_name, session_tag=session_tag, user_tag=user_tag, redis_client=client ) yield session session.clear() @@ -30,9 +31,11 @@ def semantic_session(app_name, user_tag, session_tag): # test standard session manager -def test_key_creation(): +def test_key_creation(client): # test default key creation - session = StandardSessionManager(name="test_app", session_tag="123", user_tag="abc") + session = StandardSessionManager( + name="test_app", session_tag="123", user_tag="abc", redis_client=client + ) assert session.key == "test_app:abc:123" @@ -43,6 +46,26 @@ def test_specify_redis_client(client): assert isinstance(session._client, type(client)) +def test_specify_redis_url(client): + session = StandardSessionManager( + name="test_app", + session_tag="abc", + user_tag="123", + redis_url="redis://localhost:6379", + ) + assert isinstance(session._client, type(client)) + + +def test_standard_bad_connection_info(): + with pytest.raises(ConnectionError): + StandardSessionManager( + name="test_app", + session_tag="abc", + user_tag="123", + redis_url="redis://localhost:6389", # bad url + ) + + def test_standard_store_and_get(standard_session): context = standard_session.get_recent() assert len(context) == 0 @@ -367,6 +390,16 @@ def test_semantic_specify_client(client): assert isinstance(session._index.client, type(client)) +def test_semantic_bad_connection_info(): + with pytest.raises(ConnectionError): + SemanticSessionManager( + name="test_app", + session_tag="abc", + user_tag="123", + redis_url="redis://localhost:6389", + ) + + def test_semantic_set_scope(semantic_session, app_name, user_tag, session_tag): # test calling set_scope with no params does not change scope semantic_session.store("some prompt", "some response")