diff --git a/langgraph/checkpoint/redis/__init__.py b/langgraph/checkpoint/redis/__init__.py index 5851172..2fbb151 100644 --- a/langgraph/checkpoint/redis/__init__.py +++ b/langgraph/checkpoint/redis/__init__.py @@ -2,7 +2,8 @@ import json from contextlib import contextmanager -from typing import Any, Dict, Iterator, List, Optional, Tuple, cast +import logging +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( @@ -14,6 +15,7 @@ ) from langgraph.constants import TASKS from redis import Redis +from redis.cluster import RedisCluster from redisvl.index import SearchIndex from redisvl.query import FilterQuery from redisvl.query.filter import Num, Tag @@ -32,15 +34,21 @@ ) from langgraph.checkpoint.redis.version import __lib_name__, __version__ +logger = logging.getLogger(__name__) -class RedisSaver(BaseRedisSaver[Redis, SearchIndex]): + +class RedisSaver(BaseRedisSaver[Union[Redis, RedisCluster], SearchIndex]): """Standard Redis implementation for checkpoint saving.""" + _redis: Union[Redis, RedisCluster] # Support both standalone and cluster clients + # Whether to assume the Redis server is a cluster; None triggers auto-detection + cluster_mode: Optional[bool] = None + def __init__( self, redis_url: Optional[str] = None, *, - redis_client: Optional[Redis] = None, + redis_client: Optional[Union[Redis, RedisCluster]] = None, connection_args: Optional[Dict[str, Any]] = None, ttl: Optional[Dict[str, Any]] = None, ) -> None: @@ -54,7 +62,7 @@ def __init__( def configure_client( self, redis_url: Optional[str] = None, - redis_client: Optional[Redis] = None, + redis_client: Optional[Union[Redis, RedisCluster]] = None, connection_args: Optional[Dict[str, Any]] = None, ) -> None: """Configure the Redis client.""" @@ -74,6 +82,27 @@ def create_indexes(self) -> None: self.SCHEMAS[2], redis_client=self._redis ) + def setup(self) -> None: + """Initialize the indices 
in Redis and detect cluster mode.""" + self._detect_cluster_mode() + super().setup() + + def _detect_cluster_mode(self) -> None: + """Detect if the Redis client is a cluster client by inspecting its class.""" + if self.cluster_mode is not None: + logger.info( + f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection." + ) + return + + # Determine cluster mode based on client class + if isinstance(self._redis, RedisCluster): + logger.info("Redis client is a cluster client") + self.cluster_mode = True + else: + logger.info("Redis client is a standalone client") + self.cluster_mode = False + def list( self, config: Optional[RunnableConfig], @@ -458,7 +487,7 @@ def from_conn_string( cls, redis_url: Optional[str] = None, *, - redis_client: Optional[Redis] = None, + redis_client: Optional[Union[Redis, RedisCluster]] = None, connection_args: Optional[Dict[str, Any]] = None, ttl: Optional[Dict[str, Any]] = None, ) -> Iterator[RedisSaver]: @@ -592,8 +621,8 @@ def delete_thread(self, thread_id: str) -> None: checkpoint_results = self.checkpoints_index.search(checkpoint_query) - # Delete all checkpoint-related keys - pipeline = self._redis.pipeline() + # Collect all keys to delete + keys_to_delete = [] for doc in checkpoint_results.docs: checkpoint_ns = getattr(doc, "checkpoint_ns", "") @@ -603,7 +632,7 @@ def delete_thread(self, thread_id: str) -> None: checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( storage_safe_thread_id, checkpoint_ns, checkpoint_id ) - pipeline.delete(checkpoint_key) + keys_to_delete.append(checkpoint_key) # Delete all blobs for this thread blob_query = FilterQuery( @@ -622,7 +651,7 @@ def delete_thread(self, thread_id: str) -> None: blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key( storage_safe_thread_id, checkpoint_ns, channel, version ) - pipeline.delete(blob_key) + keys_to_delete.append(blob_key) # Delete all writes for this thread writes_query = FilterQuery( @@ -642,10 +671,19 @@ def delete_thread(self, 
thread_id: str) -> None: write_key = BaseRedisSaver._make_redis_checkpoint_writes_key( storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx ) - pipeline.delete(write_key) + keys_to_delete.append(write_key) - # Execute all deletions - pipeline.execute() + # Execute all deletions based on cluster mode + if self.cluster_mode: + # For cluster mode, delete keys individually + for key in keys_to_delete: + self._redis.delete(key) + else: + # For non-cluster mode, use pipeline for efficiency + pipeline = self._redis.pipeline() + for key in keys_to_delete: + pipeline.delete(key) + pipeline.execute() __all__ = [ diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 46a4d25..a6c7fe8 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -4,11 +4,23 @@ import asyncio import json +import logging import os from contextlib import asynccontextmanager from functools import partial from types import TracebackType -from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( @@ -23,6 +35,7 @@ from langgraph.constants import TASKS from redis.asyncio import Redis as AsyncRedis from redis.asyncio.client import Pipeline +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster from redisvl.index import AsyncSearchIndex from redisvl.query import FilterQuery from redisvl.query.filter import Num, Tag @@ -38,6 +51,8 @@ to_storage_safe_str, ) +logger = logging.getLogger(__name__) + async def _write_obj_tx( pipe: Pipeline, @@ -58,7 +73,9 @@ async def _write_obj_tx( await pipe.json().set(key, "$", write_obj) -class AsyncRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]): +class AsyncRedisSaver( + BaseRedisSaver[Union[AsyncRedis, AsyncRedisCluster], 
AsyncSearchIndex] +): """Async Redis implementation for checkpoint saver.""" _redis_url: str @@ -66,13 +83,17 @@ class AsyncRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]): checkpoint_blobs_index: AsyncSearchIndex checkpoint_writes_index: AsyncSearchIndex - _redis: AsyncRedis # Override the type from the base class + _redis: Union[ + AsyncRedis, AsyncRedisCluster + ] # Support both standalone and cluster clients + # Whether to assume the Redis server is a cluster; None triggers auto-detection + cluster_mode: Optional[bool] = None def __init__( self, redis_url: Optional[str] = None, *, - redis_client: Optional[AsyncRedis] = None, + redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None, connection_args: Optional[Dict[str, Any]] = None, ttl: Optional[Dict[str, Any]] = None, ) -> None: @@ -87,7 +108,7 @@ def __init__( def configure_client( self, redis_url: Optional[str] = None, - redis_client: Optional[AsyncRedis] = None, + redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None, connection_args: Optional[Dict[str, Any]] = None, ) -> None: """Configure the Redis client.""" @@ -144,12 +165,78 @@ async def __aexit__( self.checkpoint_writes_index._redis_client = None async def asetup(self) -> None: - """Initialize Redis indexes asynchronously.""" - # Create indexes in Redis asynchronously + """Set up the checkpoint saver.""" + self.create_indexes() await self.checkpoints_index.create(overwrite=False) await self.checkpoint_blobs_index.create(overwrite=False) await self.checkpoint_writes_index.create(overwrite=False) + # Detect cluster mode if not explicitly set + await self._detect_cluster_mode() + + async def _detect_cluster_mode(self) -> None: + """Detect if the Redis client is a cluster client by inspecting its class.""" + if self.cluster_mode is not None: + logger.info( + f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection." 
+ ) + return + + # Determine cluster mode based on client class + if isinstance(self._redis, AsyncRedisCluster): + logger.info("Redis client is a cluster client") + self.cluster_mode = True + else: + logger.info("Redis client is a standalone client") + self.cluster_mode = False + + async def _apply_ttl_to_keys( + self, + main_key: str, + related_keys: Optional[list[str]] = None, + ttl_minutes: Optional[float] = None, + ) -> Any: + """Apply Redis native TTL to keys asynchronously. + + Args: + main_key: The primary Redis key + related_keys: Additional Redis keys that should expire at the same time + ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided + + Returns: + Result of the Redis operation + """ + if ttl_minutes is None: + # Check if there's a default TTL in config + if self.ttl_config and "default_ttl" in self.ttl_config: + ttl_minutes = self.ttl_config.get("default_ttl") + + if ttl_minutes is not None: + ttl_seconds = int(ttl_minutes * 60) + + if self.cluster_mode: + # For cluster mode, execute TTL operations individually + await self._redis.expire(main_key, ttl_seconds) + + if related_keys: + for key in related_keys: + await self._redis.expire(key, ttl_seconds) + + return True + else: + # For non-cluster mode, use pipeline for efficiency + pipeline = self._redis.pipeline() + + # Set TTL for main key + pipeline.expire(main_key, ttl_seconds) + + # Set TTL for related keys + if related_keys: + for key in related_keys: + pipeline.expire(key, ttl_seconds) + + return await pipeline.execute() + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from Redis asynchronously.""" thread_id = config["configurable"]["thread_id"] @@ -223,18 +310,10 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: write_keys = [safely_decode(key) for key in write_keys] # Apply TTL to checkpoint, blob keys, and write keys - ttl_minutes = self.ttl_config.get("default_ttl") - if 
ttl_minutes is not None: - ttl_seconds = int(ttl_minutes * 60) - pipeline = self._redis.pipeline() - pipeline.expire(checkpoint_key, ttl_seconds) - - # Combine blob keys and write keys for TTL refresh - all_related_keys = blob_keys + write_keys - for key in all_related_keys: - pipeline.expire(key, ttl_seconds) - - await pipeline.execute() + all_related_keys = blob_keys + write_keys + await self._apply_ttl_to_keys( + checkpoint_key, all_related_keys if all_related_keys else None + ) # Fetch channel_values channel_values = await self.aget_channel_values( @@ -474,11 +553,8 @@ async def aput( } } - # Store checkpoint data with transaction handling + # Store checkpoint data with cluster-aware handling try: - # Create a pipeline with transaction=True for atomicity - pipeline = self._redis.pipeline(transaction=True) - # Store checkpoint data checkpoint_data = { "thread_id": storage_safe_thread_id, @@ -501,9 +577,6 @@ async def aput( storage_safe_checkpoint_id, ) - # Add checkpoint data to Redis - await pipeline.json().set(checkpoint_key, "$", checkpoint_data) - # Store blob values blobs = self._dump_blobs( storage_safe_thread_id, @@ -512,29 +585,56 @@ async def aput( new_versions, ) - if blobs: - # Add all blob operations to the pipeline - for key, data in blobs: - await pipeline.json().set(key, "$", data) + if self.cluster_mode: + # For cluster mode, execute operations individually + await self._redis.json().set(checkpoint_key, "$", checkpoint_data) - # Execute all operations atomically - await pipeline.execute() + if blobs: + for key, data in blobs: + await self._redis.json().set(key, "$", data) - # Apply TTL to checkpoint and blob keys if configured - if self.ttl_config and "default_ttl" in self.ttl_config: - all_keys = ( - [checkpoint_key] + [key for key, _ in blobs] - if blobs - else [checkpoint_key] - ) - ttl_minutes = self.ttl_config.get("default_ttl") - ttl_seconds = int(ttl_minutes * 60) + # Apply TTL if configured + if self.ttl_config and "default_ttl" in 
self.ttl_config: + all_keys = ( + [checkpoint_key] + [key for key, _ in blobs] + if blobs + else [checkpoint_key] + ) + ttl_minutes = self.ttl_config.get("default_ttl") + ttl_seconds = int(ttl_minutes * 60) + + for key in all_keys: + await self._redis.expire(key, ttl_seconds) + else: + # For non-cluster mode, use pipeline with transaction for atomicity + pipeline = self._redis.pipeline(transaction=True) - # Use a new pipeline for TTL operations - ttl_pipeline = self._redis.pipeline() - for key in all_keys: - ttl_pipeline.expire(key, ttl_seconds) - await ttl_pipeline.execute() + # Add checkpoint data to pipeline + await pipeline.json().set(checkpoint_key, "$", checkpoint_data) + + if blobs: + # Add all blob operations to the pipeline + for key, data in blobs: + await pipeline.json().set(key, "$", data) + + # Execute all operations atomically + await pipeline.execute() + + # Apply TTL to checkpoint and blob keys if configured + if self.ttl_config and "default_ttl" in self.ttl_config: + all_keys = ( + [checkpoint_key] + [key for key, _ in blobs] + if blobs + else [checkpoint_key] + ) + ttl_minutes = self.ttl_config.get("default_ttl") + ttl_seconds = int(ttl_minutes * 60) + + # Use a new pipeline for TTL operations + ttl_pipeline = self._redis.pipeline() + for key in all_keys: + ttl_pipeline.expire(key, ttl_seconds) + await ttl_pipeline.execute() return next_config @@ -544,9 +644,6 @@ async def aput( # For these modes, we want to ensure any partial state is committed # to allow resuming the stream later try: - # Try to commit what we have so far - pipeline = self._redis.pipeline(transaction=True) - # Store minimal checkpoint data checkpoint_data = { "thread_id": storage_safe_thread_id, @@ -570,9 +667,16 @@ async def aput( storage_safe_checkpoint_id, ) - # Add checkpoint data to Redis - await pipeline.json().set(checkpoint_key, "$", checkpoint_data) - await pipeline.execute() + if self.cluster_mode: + # For cluster mode, execute operation directly + await 
self._redis.json().set( + checkpoint_key, "$", checkpoint_data + ) + else: + # For non-cluster mode, use pipeline + pipeline = self._redis.pipeline(transaction=True) + await pipeline.json().set(checkpoint_key, "$", checkpoint_data) + await pipeline.execute() except Exception: # If this also fails, we just propagate the original cancellation pass @@ -630,57 +734,105 @@ async def aput_writes( writes_objects.append(write_obj) try: - # Use a transaction pipeline for atomicity - pipeline = self._redis.pipeline(transaction=True) - # Determine if this is an upsert case upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes) created_keys = [] - # Add all write operations to the pipeline - for write_obj in writes_objects: - key = self._make_redis_checkpoint_writes_key( - thread_id, - checkpoint_ns, - checkpoint_id, - task_id, - write_obj["idx"], # type: ignore[arg-type] - ) + if self.cluster_mode: + # For cluster mode, execute operations individually + for write_obj in writes_objects: + key = self._make_redis_checkpoint_writes_key( + thread_id, + checkpoint_ns, + checkpoint_id, + task_id, + write_obj["idx"], # type: ignore[arg-type] + ) - if upsert_case: - # For upsert case, we need to check if the key exists and update differently - exists = await self._redis.exists(key) - if exists: - # Update existing key - await pipeline.json().set( - key, "$.channel", write_obj["channel"] - ) - await pipeline.json().set(key, "$.type", write_obj["type"]) - await pipeline.json().set(key, "$.blob", write_obj["blob"]) + if upsert_case: + # For upsert case, check if key exists and update differently + exists = await self._redis.exists(key) + if exists: + # Update existing key + await self._redis.json().set( + key, "$.channel", write_obj["channel"] + ) + await self._redis.json().set( + key, "$.type", write_obj["type"] + ) + await self._redis.json().set( + key, "$.blob", write_obj["blob"] + ) + else: + # Create new key + await self._redis.json().set(key, "$", write_obj) + 
created_keys.append(key) else: - # Create new key - await pipeline.json().set(key, "$", write_obj) - created_keys.append(key) - else: - # For non-upsert case, only set if key doesn't exist - exists = await self._redis.exists(key) - if not exists: - await pipeline.json().set(key, "$", write_obj) - created_keys.append(key) + # For non-upsert case, only set if key doesn't exist + exists = await self._redis.exists(key) + if not exists: + await self._redis.json().set(key, "$", write_obj) + created_keys.append(key) + + # Apply TTL to newly created keys + if ( + created_keys + and self.ttl_config + and "default_ttl" in self.ttl_config + ): + ttl_minutes = self.ttl_config.get("default_ttl") + ttl_seconds = int(ttl_minutes * 60) + + for key in created_keys: + await self._redis.expire(key, ttl_seconds) + else: + # For non-cluster mode, use transaction pipeline for atomicity + pipeline = self._redis.pipeline(transaction=True) + + # Add all write operations to the pipeline + for write_obj in writes_objects: + key = self._make_redis_checkpoint_writes_key( + thread_id, + checkpoint_ns, + checkpoint_id, + task_id, + write_obj["idx"], # type: ignore[arg-type] + ) - # Execute all operations atomically - await pipeline.execute() + if upsert_case: + # For upsert case, we need to check if the key exists and update differently + exists = await self._redis.exists(key) + if exists: + # Update existing key + await pipeline.json().set( + key, "$.channel", write_obj["channel"] + ) + await pipeline.json().set(key, "$.type", write_obj["type"]) + await pipeline.json().set(key, "$.blob", write_obj["blob"]) + else: + # Create new key + await pipeline.json().set(key, "$", write_obj) + created_keys.append(key) + else: + # For non-upsert case, only set if key doesn't exist + exists = await self._redis.exists(key) + if not exists: + await pipeline.json().set(key, "$", write_obj) + created_keys.append(key) - # Apply TTL to newly created keys - if created_keys and self.ttl_config and "default_ttl" in 
self.ttl_config: - ttl_minutes = self.ttl_config.get("default_ttl") - ttl_seconds = int(ttl_minutes * 60) + # Execute all operations atomically + await pipeline.execute() - # Use a new pipeline for TTL operations - ttl_pipeline = self._redis.pipeline() - for key in created_keys: - ttl_pipeline.expire(key, ttl_seconds) - await ttl_pipeline.execute() + # Apply TTL to newly created keys + if ( + created_keys + and self.ttl_config + and "default_ttl" in self.ttl_config + ): + await self._apply_ttl_to_keys( + created_keys[0], + created_keys[1:] if len(created_keys) > 1 else None, + ) except asyncio.CancelledError: # Handle cancellation/interruption @@ -780,7 +932,7 @@ async def from_conn_string( cls, redis_url: Optional[str] = None, *, - redis_client: Optional[AsyncRedis] = None, + redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None, connection_args: Optional[Dict[str, Any]] = None, ttl: Optional[Dict[str, Any]] = None, ) -> AsyncIterator[AsyncRedisSaver]: @@ -943,8 +1095,8 @@ async def adelete_thread(self, thread_id: str) -> None: checkpoint_results = await self.checkpoints_index.search(checkpoint_query) - # Delete all checkpoint-related keys - pipeline = self._redis.pipeline() + # Collect all keys to delete + keys_to_delete = [] for doc in checkpoint_results.docs: checkpoint_ns = getattr(doc, "checkpoint_ns", "") @@ -954,7 +1106,7 @@ async def adelete_thread(self, thread_id: str) -> None: checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( storage_safe_thread_id, checkpoint_ns, checkpoint_id ) - pipeline.delete(checkpoint_key) + keys_to_delete.append(checkpoint_key) # Delete all blobs for this thread blob_query = FilterQuery( @@ -973,7 +1125,7 @@ async def adelete_thread(self, thread_id: str) -> None: blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key( storage_safe_thread_id, checkpoint_ns, channel, version ) - pipeline.delete(blob_key) + keys_to_delete.append(blob_key) # Delete all writes for this thread writes_query = FilterQuery( @@ 
-993,7 +1145,16 @@ async def adelete_thread(self, thread_id: str) -> None: write_key = BaseRedisSaver._make_redis_checkpoint_writes_key( storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx ) - pipeline.delete(write_key) + keys_to_delete.append(write_key) - # Execute all deletions - await pipeline.execute() + # Execute all deletions based on cluster mode + if self.cluster_mode: + # For cluster mode, delete keys individually + for key in keys_to_delete: + await self._redis.delete(key) + else: + # For non-cluster mode, use pipeline for efficiency + pipeline = self._redis.pipeline() + for key in keys_to_delete: + pipeline.delete(key) + await pipeline.execute() diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py index b8a7c69..c7645ba 100644 --- a/langgraph/checkpoint/redis/base.py +++ b/langgraph/checkpoint/redis/base.py @@ -241,17 +241,32 @@ def _apply_ttl_to_keys( if ttl_minutes is not None: ttl_seconds = int(ttl_minutes * 60) - pipeline = self._redis.pipeline() - # Set TTL for main key - pipeline.expire(main_key, ttl_seconds) + # Check if cluster mode is detected (for sync checkpoint savers) + cluster_mode = getattr(self, "cluster_mode", False) - # Set TTL for related keys - if related_keys: - for key in related_keys: - pipeline.expire(key, ttl_seconds) + if cluster_mode: + # For cluster mode, execute TTL operations individually + self._redis.expire(main_key, ttl_seconds) - return pipeline.execute() + if related_keys: + for key in related_keys: + self._redis.expire(key, ttl_seconds) + + return True + else: + # For non-cluster mode, use pipeline for efficiency + pipeline = self._redis.pipeline() + + # Set TTL for main key + pipeline.expire(main_key, ttl_seconds) + + # Set TTL for related keys + if related_keys: + for key in related_keys: + pipeline.expire(key, ttl_seconds) + + return pipeline.execute() def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]: """Convert checkpoint to Redis format.""" diff 
--git a/langgraph/checkpoint/redis/types.py b/langgraph/checkpoint/redis/types.py index 167ac9d..52c6ab2 100644 --- a/langgraph/checkpoint/redis/types.py +++ b/langgraph/checkpoint/redis/types.py @@ -2,8 +2,12 @@ from redis import Redis from redis.asyncio import Redis as AsyncRedis +from redis.cluster import RedisCluster +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster from redisvl.index import AsyncSearchIndex, SearchIndex -RedisClientType = TypeVar("RedisClientType", bound=Union[Redis, AsyncRedis]) +RedisClientType = TypeVar( + "RedisClientType", bound=Union[Redis, AsyncRedis, RedisCluster, AsyncRedisCluster] +) IndexType = TypeVar("IndexType", bound=Union[SearchIndex, AsyncSearchIndex]) MetadataInput = Optional[dict[str, Any]] diff --git a/langgraph/store/redis/__init__.py b/langgraph/store/redis/__init__.py index 034d35b..dd6a820 100644 --- a/langgraph/store/redis/__init__.py +++ b/langgraph/store/redis/__init__.py @@ -4,6 +4,7 @@ import asyncio import json +import logging import math from contextlib import contextmanager from datetime import datetime, timezone @@ -41,7 +42,6 @@ _namespace_to_text, _row_to_item, _row_to_search_item, - logger, ) from .token_unescaper import TokenUnescaper @@ -49,6 +49,8 @@ _token_escaper = TokenEscaper() _token_unescaper = TokenUnescaper() +logger = logging.getLogger(__name__) + def _convert_redis_score_to_similarity(score: float, distance_type: str) -> float: """Convert Redis vector distance to similarity score.""" diff --git a/langgraph/store/redis/base.py b/langgraph/store/redis/base.py index af46530..a511417 100644 --- a/langgraph/store/redis/base.py +++ b/langgraph/store/redis/base.py @@ -8,7 +8,6 @@ from datetime import datetime, timedelta, timezone from typing import Any, Generic, Iterable, Optional, Sequence, TypedDict, TypeVar, Union -from langchain_core.embeddings import Embeddings from langgraph.store.base import ( GetOp, IndexConfig, @@ -25,7 +24,6 @@ ) from redis import Redis from redis.asyncio 
import Redis as AsyncRedis -from redis.cluster import RedisCluster as SyncRedisCluster from redis.exceptions import ResponseError from redisvl.index import SearchIndex from redisvl.query.filter import Tag, Text @@ -263,7 +261,6 @@ def __init__( def set_client_info(self) -> None: """Set client info for Redis monitoring.""" - from redis.exceptions import ResponseError from langgraph.checkpoint.redis.version import __redisvl_version__ @@ -283,7 +280,6 @@ def set_client_info(self) -> None: async def aset_client_info(self) -> None: """Set client info for Redis monitoring asynchronously.""" - from redis.exceptions import ResponseError from langgraph.checkpoint.redis.version import __redisvl_version__ diff --git a/tests/test_async_cluster_mode.py b/tests/test_async_cluster_mode.py index 01e3fba..a942ee9 100644 --- a/tests/test_async_cluster_mode.py +++ b/tests/test_async_cluster_mode.py @@ -11,6 +11,7 @@ ) from langgraph.store.redis import AsyncRedisStore +from langgraph.checkpoint.redis.aio import AsyncRedisSaver # Override session-scoped redis_container fixture to prevent Docker operations and provide dummy host/port @@ -93,6 +94,27 @@ def __init__(self, *args, **kwargs): self.expire_calls = [] self.delete_calls = [] + # Add required cluster attributes to prevent AttributeError + self.cluster_error_retry_attempts = 3 + self.connection_pool = AsyncMock() + + # Mock the client_setinfo method that's called during setup + async def client_setinfo(self, *args, **kwargs): + return True + + # Mock execute_command to avoid cluster-specific execution + async def execute_command(self, *args, **kwargs): + command = args[0] if args else "" + if command == "CLIENT SETINFO": + return True + # Add other command responses as needed + return None + + # Mock module_list method for Redis modules check + async def module_list(self): + # Return mock modules that satisfy the validation requirements + return [{"name": "search", "ver": 20600}, {"name": "json", "ver": 20600}] + # Mock pipeline 
to record calls and simulate async behavior def pipeline(self, transaction=True): # print(f"AsyncMockRedisCluster.pipeline called with transaction={transaction}") @@ -204,3 +226,130 @@ async def test_async_cluster_mode_behavior_differs( call.get("transaction") is True for call in mock_async_redis_client.pipeline_calls ), "Transactional pipeline expected for async non-cluster TTL" + + +@pytest.fixture(params=[False, True]) +async def async_checkpoint_saver(request): + """Parameterized fixture for AsyncRedisSaver with regular or cluster client.""" + is_cluster = request.param + client = AsyncMockRedisCluster() if is_cluster else AsyncMockRedis() + + saver = AsyncRedisSaver(redis_client=client) + + # Mock the search indices + saver.checkpoints_index = AsyncMock() + saver.checkpoints_index.create = AsyncMock() + saver.checkpoints_index.search = AsyncMock(return_value=MagicMock(docs=[])) + saver.checkpoints_index.load = AsyncMock() + + saver.checkpoint_blobs_index = AsyncMock() + saver.checkpoint_blobs_index.create = AsyncMock() + saver.checkpoint_blobs_index.search = AsyncMock(return_value=MagicMock(docs=[])) + saver.checkpoint_blobs_index.load = AsyncMock() + + saver.checkpoint_writes_index = AsyncMock() + saver.checkpoint_writes_index.create = AsyncMock() + saver.checkpoint_writes_index.search = AsyncMock(return_value=MagicMock(docs=[])) + saver.checkpoint_writes_index.load = AsyncMock() + + # Skip asetup() to avoid complex RedisVL index creation, just test cluster detection + await saver._detect_cluster_mode() + return saver + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_cluster_detection(async_checkpoint_saver): + """Test that async checkpoint saver cluster_mode is set correctly.""" + is_client_cluster = isinstance(async_checkpoint_saver._redis, AsyncRedisCluster) + assert async_checkpoint_saver.cluster_mode == is_client_cluster + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_aput_ttl_behavior(async_checkpoint_saver): + """Test 
TTL behavior in aput for async checkpoint saver in cluster vs. non-cluster mode.""" + from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata + + client = async_checkpoint_saver._redis + client.expire_calls.clear() + client.pipeline_calls.clear() + + # Set up TTL config + async_checkpoint_saver.ttl_config = {"default_ttl": 5.0} + + # Mock the JSON operations to avoid actual data operations + mock_json = AsyncMock() + mock_json.set = AsyncMock(return_value=True) + client.json = MagicMock(return_value=mock_json) + + # Create mock checkpoint and metadata + config = { + "configurable": { + "thread_id": "test_thread", + "checkpoint_ns": "", + "checkpoint_id": "test_checkpoint", + } + } + checkpoint: Checkpoint = {"channel_values": {}, "version": "1.0"} + metadata: CheckpointMetadata = {"source": "test", "step": 1} + new_versions = {} + + # Call aput which should trigger TTL operations + await async_checkpoint_saver.aput(config, checkpoint, metadata, new_versions) + + if async_checkpoint_saver.cluster_mode: + # In cluster mode, TTL operations should be called directly + assert len(client.expire_calls) >= 1 # At least one TTL call for the checkpoint + # Check that expire was called with correct TTL (5 minutes = 300 seconds) + ttl_calls = [call for call in client.expire_calls if call.get("ttl") == 300] + assert len(ttl_calls) >= 1 + else: + # In non-cluster mode, pipeline should be used for TTL operations + assert len(client.pipeline_calls) > 0 + # Should have pipeline calls for the main operations and potentially TTL operations + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_delete_thread_behavior(async_checkpoint_saver): + """Test delete_thread behavior for async checkpoint saver in cluster vs. 
non-cluster mode.""" + client = async_checkpoint_saver._redis + client.delete_calls.clear() + client.pipeline_calls.clear() + + # Mock search results to simulate existing data + mock_checkpoint_doc = MagicMock() + mock_checkpoint_doc.checkpoint_ns = "test_ns" + mock_checkpoint_doc.checkpoint_id = "test_checkpoint" + + mock_blob_doc = MagicMock() + mock_blob_doc.checkpoint_ns = "test_ns" + mock_blob_doc.channel = "test_channel" + mock_blob_doc.version = "1" + + mock_write_doc = MagicMock() + mock_write_doc.checkpoint_ns = "test_ns" + mock_write_doc.checkpoint_id = "test_checkpoint" + mock_write_doc.task_id = "test_task" + mock_write_doc.idx = 0 + + async_checkpoint_saver.checkpoints_index.search.return_value = MagicMock( + docs=[mock_checkpoint_doc] + ) + async_checkpoint_saver.checkpoint_blobs_index.search.return_value = MagicMock( + docs=[] + ) + async_checkpoint_saver.checkpoint_writes_index.search.return_value = MagicMock( + docs=[] + ) + + await async_checkpoint_saver.adelete_thread("test_thread") + + if async_checkpoint_saver.cluster_mode: + # In cluster mode, delete operations should be called directly + assert len(client.delete_calls) > 0 # At least one checkpoint key deletion + # Pipeline should not be used for deletions in cluster mode + # (it might be called for other reasons but not for delete operations) + else: + # In non-cluster mode, pipeline should be used for deletions + assert len(client.pipeline_calls) > 0 # At least one pipeline used + # Direct delete calls should not be made in non-cluster mode + assert len(client.delete_calls) == 0 diff --git a/tests/test_cluster_mode.py b/tests/test_cluster_mode.py index e699061..a4edfba 100644 --- a/tests/test_cluster_mode.py +++ b/tests/test_cluster_mode.py @@ -18,6 +18,8 @@ STORE_PREFIX, STORE_VECTOR_PREFIX, ) +from langgraph.checkpoint.redis import RedisSaver +from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata # Override session-scoped redis_container fixture to prevent Docker operations 
# and provide dummy host/port
# --- diff hunk context preserved from the original line ---
# @@ -364,3 +366,120 @@ def test_batch_list_namespaces_ops_behavior(store):
#     assert ("test", "documents") in results[1]
#     assert ("test", "images") in results[1]
#     assert ("prod", "documents") in results[1]


@pytest.fixture(params=[False, True])
def checkpoint_saver(request):
    """Build a RedisSaver around a mock standalone or mock cluster client.

    The fixture is parameterized: ``False`` wires in ``MockRedis`` and
    ``True`` wires in ``MockRedisCluster``. The three search indices are
    replaced by ``MagicMock`` objects whose ``search`` returns a result with
    an empty ``docs`` list, and ``setup()`` is called before the saver is
    returned to the test.
    """
    client = MockRedis() if not request.param else MockRedisCluster()
    saver = RedisSaver(redis_client=client)

    # Stand-in mocks for the three search indices; by default a search
    # yields no documents.
    for index_attr in (
        "checkpoints_index",
        "checkpoint_blobs_index",
        "checkpoint_writes_index",
    ):
        index_mock = MagicMock()
        index_mock.create = MagicMock()
        index_mock.search = MagicMock(return_value=MagicMock(docs=[]))
        index_mock.load = MagicMock()
        setattr(saver, index_attr, index_mock)

    saver.setup()
    return saver


def test_checkpoint_saver_cluster_detection(checkpoint_saver):
    """The saver's ``cluster_mode`` must mirror the type of its client."""
    redis_client = checkpoint_saver._redis
    assert checkpoint_saver.cluster_mode == isinstance(
        redis_client, SyncRedisCluster
    )


def test_checkpoint_saver_ttl_behavior(checkpoint_saver):
    """Check how ``_apply_ttl_to_keys`` expires keys in each client mode.

    With a cluster client the test expects one direct ``expire`` call per
    key and no pipeline; otherwise it expects a transactional pipeline that
    carries the ``expire`` calls and no direct ``expire`` calls.
    """
    redis_client = checkpoint_saver._redis
    redis_client.expire_calls.clear()
    redis_client.pipeline_calls.clear()

    # The test treats default_ttl as minutes: 5.0 minutes -> 300 seconds.
    checkpoint_saver.ttl_config = {"default_ttl": 5.0}
    expected_ttl = 300

    main_key = "checkpoint:test:key"
    blob_keys = ["blob:key1", "blob:key2"]
    # Snapshot of every key that should receive a TTL, main key first.
    all_keys = [main_key, *blob_keys]

    checkpoint_saver._apply_ttl_to_keys(main_key, blob_keys)

    if not checkpoint_saver.cluster_mode:
        # Non-cluster: a transactional pipeline carries the expire calls.
        assert len(redis_client.pipeline_calls) > 0
        assert redis_client.pipeline_calls[0]["transaction"] is True
        for key in all_keys:
            redis_client._pipeline.expire.assert_any_call(key, expected_ttl)
        # Direct expire calls should not be made
        assert len(redis_client.expire_calls) == 0
    else:
        # Cluster: each key is expired with its own direct call.
        assert len(redis_client.expire_calls) == 3
        for key in all_keys:
            assert {"key": key, "ttl": expected_ttl} in redis_client.expire_calls
        # Pipeline should not be used
        assert len(redis_client.pipeline_calls) == 0


def test_checkpoint_saver_delete_thread_behavior(checkpoint_saver):
    """Check how ``delete_thread`` removes keys in each client mode.

    Each search index is primed with one mock document so there is data to
    delete. With a cluster client the test expects direct ``delete`` calls
    and no transactional pipeline; otherwise it expects the pipeline's
    ``delete`` to be used and no direct ``delete`` calls.
    """
    redis_client = checkpoint_saver._redis
    redis_client.delete_calls.clear()
    redis_client.pipeline_calls.clear()

    # Mock search results to simulate existing data
    checkpoint_doc = MagicMock(
        checkpoint_ns="test_ns",
        checkpoint_id="test_checkpoint",
    )
    blob_doc = MagicMock(
        checkpoint_ns="test_ns",
        channel="test_channel",
        version="1",
    )
    write_doc = MagicMock(
        checkpoint_ns="test_ns",
        checkpoint_id="test_checkpoint",
        task_id="test_task",
        idx=0,
    )

    for index, doc in (
        (checkpoint_saver.checkpoints_index, checkpoint_doc),
        (checkpoint_saver.checkpoint_blobs_index, blob_doc),
        (checkpoint_saver.checkpoint_writes_index, write_doc),
    ):
        index.search.return_value = MagicMock(docs=[doc])

    checkpoint_saver.delete_thread("test_thread")

    if checkpoint_saver.cluster_mode:
        # Cluster: deletions are issued directly on the client.
        assert len(redis_client.delete_calls) > 0
        # No recorded pipeline call may be transactional.
        assert all(
            not entry.get("transaction") for entry in redis_client.pipeline_calls
        )
    else:
        # Non-cluster: deletions go through the pipeline.
        assert len(redis_client.pipeline_calls) > 0
        redis_client._pipeline.delete.assert_called()
        # Direct delete calls should not be made
        assert len(redis_client.delete_calls) == 0