Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interop of RemoteGraph w core lib #2166

Merged
merged 5 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions libs/langgraph/langgraph/graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio

Check notice on line 1 in libs/langgraph/langgraph/graph/graph.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 48.0 ms +- 0.7 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 46.2 ms +- 3.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 76.1 ms +- 1.4 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 84.6 ms +- 0.6 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 473 ms +- 12 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 426 ms +- 5 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 790 ms +- 45 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 829 ms +- 17 ms ......................................... react_agent_10x: Mean +- std dev: 28.9 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.2 ms +- 1.6 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 47.5 ms +- 3.5 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 37.2 ms +- 3.3 ms ......................................... react_agent_100x: Mean +- std dev: 328 ms +- 14 ms ......................................... react_agent_100x_sync: Mean +- std dev: 260 ms +- 12 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 914 ms +- 8 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 817 ms +- 7 ms ......................................... wide_state_25x300: Mean +- std dev: 18.2 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 10.8 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 272 ms +- 5 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 261 ms +- 5 ms ......................................... wide_state_15x600: Mean +- std dev: 21.1 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 12.5 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 471 ms +- 6 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 466 ms +- 13 ms ......................................... wide_state_9x1200: Mean +- std dev: 21.1 ms +- 0.4 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 12.5 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 305 ms +- 3 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 302 ms +- 13 ms

Check notice on line 1 in libs/langgraph/langgraph/graph/graph.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_100x_sync | 439 ms | 426 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 10.9 ms | 10.8 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 12.6 ms | 12.5 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x | 29.2 ms | 28.9 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 12.6 ms | 12.5 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 837 ms | 829 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 21.3 ms | 21.1 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 18.3 ms | 18.2 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 21.2 ms | 21.1 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 76.6 ms | 76.1 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 85.1 ms | 84.6 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 304 ms | 305 ms: 1.00x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint_sync | 460 ms | 466 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 292 ms | 302 ms: 1.03x slower | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x faster | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (14): fanout_to_subgraph_100x_checkpoint, react_agent_10x_checkpoint_sync, react_agent_100x_sync, wide_state_25x300_checkpoint, react_agent_10x_sync, wide_state_15x600_checkpoint, wide_state_25x300_checkpoint_sync, fanout_to_subgraph_10x, fanout_to_subgraph_100x, react_agent_10x_checkpoint, fanout_to_subgraph_10x_sync, react_agent_100x_checkpoint_sync, react_agent_100x_checkpoint, react_agent_100x
import logging
from collections import defaultdict
from typing import (
Expand Down Expand Up @@ -514,6 +514,14 @@
self.nodes[end].triggers.append(channel_name)
cast(list[str], self.nodes[end].channels).append(channel_name)

async def aget_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
return self.get_graph(config, xray=xray)

def get_graph(
self,
config: Optional[RunnableConfig] = None,
Expand Down
15 changes: 13 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from langchain_core.globals import get_debug
from langchain_core.runnables import (
Runnable,
RunnableSequence,
)
from langchain_core.runnables.base import Input, Output
Expand All @@ -34,6 +33,7 @@
get_async_callback_manager_for_config,
get_callback_manager_for_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
Expand Down Expand Up @@ -86,6 +86,7 @@
from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.messages import StreamMessagesHandler
from langgraph.pregel.protocol import PregelProtocol
from langgraph.pregel.read import PregelNode
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.runner import PregelRunner
Expand Down Expand Up @@ -179,7 +180,7 @@ def write_to(
)


class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]):
class Pregel(PregelProtocol):
nodes: dict[str, PregelNode]

channels: dict[str, Union[BaseChannel, ManagedValueSpec]]
Expand Down Expand Up @@ -259,6 +260,16 @@ def __init__(
if auto_validate:
self.validate()

def get_graph(
self, config: RunnableConfig | None = None, *, xray: int | bool = False
) -> Graph:
raise NotImplementedError

async def aget_graph(
self, config: RunnableConfig | None = None, *, xray: int | bool = False
) -> Graph:
raise NotImplementedError

def copy(self, update: dict[str, Any] | None = None) -> Self:
attrs = {**self.__dict__, **(update or {})}
return self.__class__(**attrs)
Expand Down
31 changes: 18 additions & 13 deletions libs/langgraph/langgraph/pregel/protocol.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,55 @@
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Iterator,
Optional,
Protocol,
Sequence,
Union,
runtime_checkable,
)

from langchain_core.runnables import RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph as DrawableGraph
from typing_extensions import Self

from langgraph.pregel.types import All, StateSnapshot, StreamMode


@runtime_checkable
class PregelProtocol(Protocol):
class PregelProtocol(
Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], ABC
):
@abstractmethod
def with_config(
self, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Self: ...

@abstractmethod
def get_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph: ...

@abstractmethod
async def aget_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph: ...

def get_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> Iterator[tuple[str, "PregelProtocol"]]: ...

def aget_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> AsyncIterator[tuple[str, "PregelProtocol"]]: ...

@abstractmethod
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...

@abstractmethod
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...

@abstractmethod
def get_state_history(
self,
config: RunnableConfig,
Expand All @@ -61,6 +59,7 @@ def get_state_history(
limit: Optional[int] = None,
) -> Iterator[StateSnapshot]: ...

@abstractmethod
def aget_state_history(
self,
config: RunnableConfig,
Expand All @@ -70,20 +69,23 @@ def aget_state_history(
limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]: ...

@abstractmethod
def update_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig: ...

@abstractmethod
async def aupdate_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig: ...

@abstractmethod
def stream(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -95,6 +97,7 @@ def stream(
subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]: ...

@abstractmethod
def astream(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -106,6 +109,7 @@ def astream(
subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]: ...

@abstractmethod
def invoke(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -115,6 +119,7 @@ def invoke(
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]: ...

@abstractmethod
async def ainvoke(
self,
input: Union[dict[str, Any], Any],
Expand Down
81 changes: 41 additions & 40 deletions libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
Any,
AsyncIterator,
Iterator,
Literal,
Optional,
Sequence,
Union,
cast,
)

import orjson
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.graph import (
Edge as DrawableEdge,
)
Expand Down Expand Up @@ -44,10 +45,14 @@ class RemoteException(Exception):
pass


class RemoteGraph(PregelProtocol, Runnable):
class RemoteGraph(PregelProtocol):
name: str

def __init__(
self,
graph_id: str,
name: str, # graph_id
/,
*,
url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
Expand All @@ -60,7 +65,7 @@ def __init__(
If `client` or `sync_client` are provided, they will be used instead of the default clients.
See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients.
"""
self.graph_id = graph_id
self.name = name
self.config = config
self.client = client or get_client(url=url, api_key=api_key, headers=headers)
self.sync_client = sync_client or get_sync_client(
Expand All @@ -69,7 +74,7 @@ def __init__(

def copy(self, update: dict[str, Any]) -> Self:
attrs = {**self.__dict__, **update}
return self.__class__(**attrs)
return self.__class__(attrs.pop("name"), **attrs)

def with_config(
self, config: Optional[RunnableConfig] = None, **kwargs: Any
Expand Down Expand Up @@ -99,7 +104,7 @@ def get_graph(
xray: Union[int, bool] = False,
) -> DrawableGraph:
graph = self.sync_client.assistants.get_graph(
assistant_id=self.graph_id,
assistant_id=self.name,
xray=xray,
)
return DrawableGraph(
Expand All @@ -114,38 +119,14 @@ async def aget_graph(
xray: Union[int, bool] = False,
) -> DrawableGraph:
graph = await self.client.assistants.get_graph(
assistant_id=self.graph_id,
assistant_id=self.name,
xray=xray,
)
return DrawableGraph(
nodes=self._get_drawable_nodes(graph),
edges=[DrawableEdge(**edge) for edge in graph["edges"]],
)

def get_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> Iterator[tuple[str, "PregelProtocol"]]:
subgraphs = self.sync_client.assistants.get_subgraphs(
assistant_id=self.graph_id,
namespace=namespace,
recurse=recurse,
)
for namespace, graph_schema in subgraphs.items():
remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]})
yield (namespace, remote_subgraph)

async def aget_subgraphs(
self, namespace: Optional[str] = None, recurse: bool = False
) -> AsyncIterator[tuple[str, "PregelProtocol"]]:
subgraphs = await self.client.assistants.get_subgraphs(
assistant_id=self.graph_id,
namespace=namespace,
recurse=recurse,
)
for namespace, graph_schema in subgraphs.items():
remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]})
yield (namespace, remote_subgraph)

def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot:
tasks = []
for task in state["tasks"]:
Expand Down Expand Up @@ -258,7 +239,11 @@ def _sanitize_obj(obj: Any) -> Any:
if k not in reserved_configurable_keys and not k.startswith("__pregel_")
}

return {"configurable": new_configurable}
return {
"tags": config.get("tags"),
"metadata": config.get("metadata"),
"configurable": new_configurable,
}

def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
Expand Down Expand Up @@ -402,8 +387,8 @@ def stream(
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

for chunk in self.sync_client.runs.stream(
thread_id=cast(str, sanitized_config["configurable"]["thread_id"]),
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
stream_mode=stream_modes,
Expand Down Expand Up @@ -449,8 +434,8 @@ async def astream(
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

async for chunk in self.client.runs.stream(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
stream_mode=stream_modes,
Expand Down Expand Up @@ -481,6 +466,22 @@ async def astream(
else:
yield chunk

async def astream_events(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1", "v2"],
include_names: Optional[Sequence[All]] = None,
include_types: Optional[Sequence[All]] = None,
include_tags: Optional[Sequence[All]] = None,
exclude_names: Optional[Sequence[All]] = None,
exclude_types: Optional[Sequence[All]] = None,
exclude_tags: Optional[Sequence[All]] = None,
**kwargs: Any,
) -> AsyncIterator[dict[str, Any]]:
raise NotImplementedError

def invoke(
self,
input: Union[dict[str, Any], Any],
Expand All @@ -493,8 +494,8 @@ def invoke(
sanitized_config = self._sanitize_config(merged_config)

return self.sync_client.runs.wait(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
interrupt_before=interrupt_before,
Expand All @@ -514,8 +515,8 @@ async def ainvoke(
sanitized_config = self._sanitize_config(merged_config)

return await self.client.runs.wait(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
interrupt_before=interrupt_before,
Expand Down
5 changes: 3 additions & 2 deletions libs/langgraph/langgraph/pregel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.runnables.utils import get_function_nonlocals

from langgraph.checkpoint.base import ChannelVersions
from langgraph.pregel.protocol import PregelProtocol
from langgraph.utils.runnable import Runnable, RunnableCallable, RunnableSeq


Expand Down Expand Up @@ -32,9 +33,9 @@ def find_subgraph_pregel(candidate: Runnable) -> Optional[Runnable]:

for c in candidates:
if (
isinstance(c, Pregel)
isinstance(c, PregelProtocol)
# subgraphs that disabled checkpointing are not considered
and c.checkpointer is not False
and (not isinstance(c, Pregel) or c.checkpointer is not False)
):
return c
elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq):
Expand Down
Loading
Loading