Skip to content

Commit

Permalink
Merge pull request #2166 from langchain-ai/nc/23oct/remote-graph-interop
Browse files Browse the repository at this point in the history
Interop of RemoteGraph w core lib
  • Loading branch information
nfcampos authored Oct 24, 2024
2 parents bdc75a2 + 05f008c commit 83238f5
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 163 deletions.
8 changes: 8 additions & 0 deletions libs/langgraph/langgraph/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,14 @@ def branch_writer(
self.nodes[end].triggers.append(channel_name)
cast(list[str], self.nodes[end].channels).append(channel_name)

async def aget_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
return self.get_graph(config, xray=xray)

def get_graph(
self,
config: Optional[RunnableConfig] = None,
Expand Down
15 changes: 13 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from langchain_core.globals import get_debug
from langchain_core.runnables import (
Runnable,
RunnableSequence,
)
from langchain_core.runnables.base import Input, Output
Expand All @@ -34,6 +33,7 @@
get_async_callback_manager_for_config,
get_callback_manager_for_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
Expand Down Expand Up @@ -86,6 +86,7 @@
from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.messages import StreamMessagesHandler
from langgraph.pregel.protocol import PregelProtocol
from langgraph.pregel.read import PregelNode
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.runner import PregelRunner
Expand Down Expand Up @@ -179,7 +180,7 @@ def write_to(
)


class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]):
class Pregel(PregelProtocol):
nodes: dict[str, PregelNode]

channels: dict[str, Union[BaseChannel, ManagedValueSpec]]
Expand Down Expand Up @@ -259,6 +260,16 @@ def __init__(
if auto_validate:
self.validate()

def get_graph(
self, config: RunnableConfig | None = None, *, xray: int | bool = False
) -> Graph:
raise NotImplementedError

async def aget_graph(
self, config: RunnableConfig | None = None, *, xray: int | bool = False
) -> Graph:
raise NotImplementedError

def copy(self, update: dict[str, Any] | None = None) -> Self:
attrs = {**self.__dict__, **(update or {})}
return self.__class__(**attrs)
Expand Down
31 changes: 18 additions & 13 deletions libs/langgraph/langgraph/pregel/protocol.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,55 @@
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Iterator,
Optional,
Protocol,
Sequence,
Union,
runtime_checkable,
)

from langchain_core.runnables import RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph as DrawableGraph
from typing_extensions import Self

from langgraph.pregel.types import All, StateSnapshot, StreamMode


@runtime_checkable
class PregelProtocol(Protocol):
class PregelProtocol(
Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], ABC
):
@abstractmethod
def with_config(
self, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Self: ...

@abstractmethod
def get_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph: ...

@abstractmethod
async def aget_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph: ...

def get_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> Iterator[tuple[str, "PregelProtocol"]]: ...

def aget_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> AsyncIterator[tuple[str, "PregelProtocol"]]: ...

@abstractmethod
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...

@abstractmethod
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...

@abstractmethod
def get_state_history(
self,
config: RunnableConfig,
Expand All @@ -61,6 +59,7 @@ def get_state_history(
limit: Optional[int] = None,
) -> Iterator[StateSnapshot]: ...

@abstractmethod
def aget_state_history(
self,
config: RunnableConfig,
Expand All @@ -70,20 +69,23 @@ def aget_state_history(
limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]: ...

@abstractmethod
def update_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig: ...

@abstractmethod
async def aupdate_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig: ...

@abstractmethod
def stream(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -95,6 +97,7 @@ def stream(
subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]: ...

@abstractmethod
def astream(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -106,6 +109,7 @@ def astream(
subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]: ...

@abstractmethod
def invoke(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -115,6 +119,7 @@ def invoke(
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]: ...

@abstractmethod
async def ainvoke(
self,
input: Union[dict[str, Any], Any],
Expand Down
81 changes: 41 additions & 40 deletions libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
Any,
AsyncIterator,
Iterator,
Literal,
Optional,
Sequence,
Union,
cast,
)

import orjson
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.graph import (
Edge as DrawableEdge,
)
Expand Down Expand Up @@ -44,10 +45,14 @@ class RemoteException(Exception):
pass


class RemoteGraph(PregelProtocol, Runnable):
class RemoteGraph(PregelProtocol):
name: str

def __init__(
self,
graph_id: str,
name: str, # graph_id
/,
*,
url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
Expand All @@ -60,7 +65,7 @@ def __init__(
If `client` or `sync_client` are provided, they will be used instead of the default clients.
See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients.
"""
self.graph_id = graph_id
self.name = name
self.config = config
self.client = client or get_client(url=url, api_key=api_key, headers=headers)
self.sync_client = sync_client or get_sync_client(
Expand All @@ -69,7 +74,7 @@ def __init__(

def copy(self, update: dict[str, Any]) -> Self:
attrs = {**self.__dict__, **update}
return self.__class__(**attrs)
return self.__class__(attrs.pop("name"), **attrs)

def with_config(
self, config: Optional[RunnableConfig] = None, **kwargs: Any
Expand Down Expand Up @@ -99,7 +104,7 @@ def get_graph(
xray: Union[int, bool] = False,
) -> DrawableGraph:
graph = self.sync_client.assistants.get_graph(
assistant_id=self.graph_id,
assistant_id=self.name,
xray=xray,
)
return DrawableGraph(
Expand All @@ -114,38 +119,14 @@ async def aget_graph(
xray: Union[int, bool] = False,
) -> DrawableGraph:
graph = await self.client.assistants.get_graph(
assistant_id=self.graph_id,
assistant_id=self.name,
xray=xray,
)
return DrawableGraph(
nodes=self._get_drawable_nodes(graph),
edges=[DrawableEdge(**edge) for edge in graph["edges"]],
)

def get_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> Iterator[tuple[str, "PregelProtocol"]]:
subgraphs = self.sync_client.assistants.get_subgraphs(
assistant_id=self.graph_id,
namespace=namespace,
recurse=recurse,
)
for namespace, graph_schema in subgraphs.items():
remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]})
yield (namespace, remote_subgraph)

async def aget_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> AsyncIterator[tuple[str, "PregelProtocol"]]:
subgraphs = await self.client.assistants.get_subgraphs(
assistant_id=self.graph_id,
namespace=namespace,
recurse=recurse,
)
for namespace, graph_schema in subgraphs.items():
remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]})
yield (namespace, remote_subgraph)

def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot:
tasks = []
for task in state["tasks"]:
Expand Down Expand Up @@ -258,7 +239,11 @@ def _sanitize_obj(obj: Any) -> Any:
if k not in reserved_configurable_keys and not k.startswith("__pregel_")
}

return {"configurable": new_configurable}
return {
"tags": config.get("tags"),
"metadata": config.get("metadata"),
"configurable": new_configurable,
}

def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
Expand Down Expand Up @@ -402,8 +387,8 @@ def stream(
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

for chunk in self.sync_client.runs.stream(
thread_id=cast(str, sanitized_config["configurable"]["thread_id"]),
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
stream_mode=stream_modes,
Expand Down Expand Up @@ -449,8 +434,8 @@ async def astream(
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

async for chunk in self.client.runs.stream(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
stream_mode=stream_modes,
Expand Down Expand Up @@ -481,6 +466,22 @@ async def astream(
else:
yield chunk

async def astream_events(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1", "v2"],
include_names: Optional[Sequence[All]] = None,
include_types: Optional[Sequence[All]] = None,
include_tags: Optional[Sequence[All]] = None,
exclude_names: Optional[Sequence[All]] = None,
exclude_types: Optional[Sequence[All]] = None,
exclude_tags: Optional[Sequence[All]] = None,
**kwargs: Any,
) -> AsyncIterator[dict[str, Any]]:
raise NotImplementedError

def invoke(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -493,8 +494,8 @@ def invoke(
sanitized_config = self._sanitize_config(merged_config)

return self.sync_client.runs.wait(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
interrupt_before=interrupt_before,
Expand All @@ -514,8 +515,8 @@ async def ainvoke(
sanitized_config = self._sanitize_config(merged_config)

return await self.client.runs.wait(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
interrupt_before=interrupt_before,
Expand Down
5 changes: 3 additions & 2 deletions libs/langgraph/langgraph/pregel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.runnables.utils import get_function_nonlocals

from langgraph.checkpoint.base import ChannelVersions
from langgraph.pregel.protocol import PregelProtocol
from langgraph.utils.runnable import Runnable, RunnableCallable, RunnableSeq


Expand Down Expand Up @@ -32,9 +33,9 @@ def find_subgraph_pregel(candidate: Runnable) -> Optional[Runnable]:

for c in candidates:
if (
isinstance(c, Pregel)
isinstance(c, PregelProtocol)
# subgraphs that disabled checkpointing are not considered
and c.checkpointer is not False
and (not isinstance(c, Pregel) or c.checkpointer is not False)
):
return c
elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq):
Expand Down
Loading

0 comments on commit 83238f5

Please sign in to comment.