Skip to content

fix: Support cluster clients in saver #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 50 additions & 12 deletions langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import json
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast
import logging
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand All @@ -14,6 +15,7 @@
)
from langgraph.constants import TASKS
from redis import Redis
from redis.cluster import RedisCluster
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
Expand All @@ -32,15 +34,21 @@
)
from langgraph.checkpoint.redis.version import __lib_name__, __version__

logger = logging.getLogger(__name__)

class RedisSaver(BaseRedisSaver[Redis, SearchIndex]):

class RedisSaver(BaseRedisSaver[Union[Redis, RedisCluster], SearchIndex]):
"""Standard Redis implementation for checkpoint saving."""

_redis: Union[Redis, RedisCluster] # Support both standalone and cluster clients
# Whether to assume the Redis server is a cluster; None triggers auto-detection
cluster_mode: Optional[bool] = None

def __init__(
self,
redis_url: Optional[str] = None,
*,
redis_client: Optional[Redis] = None,
redis_client: Optional[Union[Redis, RedisCluster]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
) -> None:
Expand All @@ -54,7 +62,7 @@ def __init__(
def configure_client(
self,
redis_url: Optional[str] = None,
redis_client: Optional[Redis] = None,
redis_client: Optional[Union[Redis, RedisCluster]] = None,
connection_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Configure the Redis client."""
Expand All @@ -74,6 +82,27 @@ def create_indexes(self) -> None:
self.SCHEMAS[2], redis_client=self._redis
)

def setup(self) -> None:
    """Initialize the indices in Redis and detect cluster mode.

    Cluster detection runs first so that index creation and any later
    key-level operations can account for cluster semantics (e.g. avoiding
    cross-slot pipelines when deleting keys).
    """
    self._detect_cluster_mode()
    super().setup()

def _detect_cluster_mode(self) -> None:
    """Detect if the Redis client is a cluster client by inspecting its class.

    If ``cluster_mode`` was explicitly set by the caller (not ``None``),
    that value is honored and detection is skipped. Otherwise the mode is
    inferred purely from the client's type: a ``RedisCluster`` instance is
    treated as a cluster client, anything else as standalone. No server
    round-trip is performed.
    """
    if self.cluster_mode is not None:
        # Lazy %-style arguments defer string formatting until the record
        # is actually emitted (cheaper when INFO logging is disabled).
        logger.info(
            "Redis cluster_mode explicitly set to %s, skipping detection.",
            self.cluster_mode,
        )
        return

    # Determine cluster mode based on client class
    if isinstance(self._redis, RedisCluster):
        logger.info("Redis client is a cluster client")
        self.cluster_mode = True
    else:
        logger.info("Redis client is a standalone client")
        self.cluster_mode = False

def list(
self,
config: Optional[RunnableConfig],
Expand Down Expand Up @@ -458,7 +487,7 @@ def from_conn_string(
cls,
redis_url: Optional[str] = None,
*,
redis_client: Optional[Redis] = None,
redis_client: Optional[Union[Redis, RedisCluster]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
) -> Iterator[RedisSaver]:
Expand Down Expand Up @@ -592,8 +621,8 @@ def delete_thread(self, thread_id: str) -> None:

checkpoint_results = self.checkpoints_index.search(checkpoint_query)

# Delete all checkpoint-related keys
pipeline = self._redis.pipeline()
# Collect all keys to delete
keys_to_delete = []

for doc in checkpoint_results.docs:
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
Expand All @@ -603,7 +632,7 @@ def delete_thread(self, thread_id: str) -> None:
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id
)
pipeline.delete(checkpoint_key)
keys_to_delete.append(checkpoint_key)

# Delete all blobs for this thread
blob_query = FilterQuery(
Expand All @@ -622,7 +651,7 @@ def delete_thread(self, thread_id: str) -> None:
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
storage_safe_thread_id, checkpoint_ns, channel, version
)
pipeline.delete(blob_key)
keys_to_delete.append(blob_key)

# Delete all writes for this thread
writes_query = FilterQuery(
Expand All @@ -642,10 +671,19 @@ def delete_thread(self, thread_id: str) -> None:
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
pipeline.delete(write_key)
keys_to_delete.append(write_key)

# Execute all deletions
pipeline.execute()
# Execute all deletions based on cluster mode
if self.cluster_mode:
# For cluster mode, delete keys individually
for key in keys_to_delete:
self._redis.delete(key)
else:
# For non-cluster mode, use pipeline for efficiency
pipeline = self._redis.pipeline()
for key in keys_to_delete:
pipeline.delete(key)
pipeline.execute()


__all__ = [
Expand Down
Loading