diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 7c0923751..e3d11c1eb 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -514,6 +514,14 @@ def branch_writer( self.nodes[end].triggers.append(channel_name) cast(list[str], self.nodes[end].channels).append(channel_name) + async def aget_graph( + self, + config: Optional[RunnableConfig] = None, + *, + xray: Union[int, bool] = False, + ) -> DrawableGraph: + return self.get_graph(config, xray=xray) + def get_graph( self, config: Optional[RunnableConfig] = None, diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index b535a06e1..c6018cbc7 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -25,7 +25,6 @@ from langchain_core.globals import get_debug from langchain_core.runnables import ( - Runnable, RunnableSequence, ) from langchain_core.runnables.base import Input, Output @@ -34,6 +33,7 @@ get_async_callback_manager_for_config, get_callback_manager_for_config, ) +from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( ConfigurableFieldSpec, get_unique_config_specs, @@ -86,6 +86,7 @@ from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.messages import StreamMessagesHandler +from langgraph.pregel.protocol import PregelProtocol from langgraph.pregel.read import PregelNode from langgraph.pregel.retry import RetryPolicy from langgraph.pregel.runner import PregelRunner @@ -179,7 +180,7 @@ def write_to( ) -class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]): +class Pregel(PregelProtocol): nodes: dict[str, PregelNode] channels: dict[str, Union[BaseChannel, ManagedValueSpec]] @@ -259,6 +260,16 @@ def __init__( if auto_validate: self.validate() + def get_graph( + self, config: RunnableConfig | None = None, *, xray: int | bool = False + ) -> Graph: + raise NotImplementedError + + async def aget_graph( + self, config: RunnableConfig | None = None, *, xray: int | bool = False + ) -> Graph: + raise NotImplementedError + def copy(self, update: dict[str, Any] | None = None) -> Self: attrs = {**self.__dict__, **(update or {})} return self.__class__(**attrs) diff --git a/libs/langgraph/langgraph/pregel/protocol.py b/libs/langgraph/langgraph/pregel/protocol.py index 34789284b..ac046e949 100644 --- a/libs/langgraph/langgraph/pregel/protocol.py +++ b/libs/langgraph/langgraph/pregel/protocol.py @@ -1,27 +1,29 @@ +from abc import ABC, abstractmethod from typing import ( Any, AsyncIterator, Iterator, Optional, - Protocol, Sequence, Union, - runtime_checkable, ) -from langchain_core.runnables import RunnableConfig +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.graph import Graph as DrawableGraph from typing_extensions import Self from langgraph.pregel.types import All, StateSnapshot, StreamMode -@runtime_checkable -class PregelProtocol(Protocol): +class PregelProtocol( + Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], ABC +): + @abstractmethod def with_config( self, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Self: ... + @abstractmethod def get_graph( self, config: Optional[RunnableConfig] = None, @@ -29,6 +31,7 @@ def get_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: ... + @abstractmethod async def aget_graph( self, config: Optional[RunnableConfig] = None, @@ -36,22 +39,17 @@ async def aget_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: ... - def get_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> Iterator[tuple[str, "PregelProtocol"]]: ... - - def aget_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> AsyncIterator[tuple[str, "PregelProtocol"]]: ... - + @abstractmethod def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: ... + @abstractmethod async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: ... + @abstractmethod def get_state_history( self, config: RunnableConfig, @@ -61,6 +59,7 @@ def get_state_history( limit: Optional[int] = None, ) -> Iterator[StateSnapshot]: ... + @abstractmethod def aget_state_history( self, config: RunnableConfig, @@ -70,6 +69,7 @@ def aget_state_history( limit: Optional[int] = None, ) -> AsyncIterator[StateSnapshot]: ... + @abstractmethod def update_state( self, config: RunnableConfig, @@ -77,6 +77,7 @@ def update_state( as_node: Optional[str] = None, ) -> RunnableConfig: ... + @abstractmethod async def aupdate_state( self, config: RunnableConfig, @@ -84,6 +85,7 @@ async def aupdate_state( as_node: Optional[str] = None, ) -> RunnableConfig: ... + @abstractmethod def stream( self, input: Union[dict[str, Any], Any], @@ -95,6 +97,7 @@ def stream( subgraphs: bool = False, ) -> Iterator[Union[dict[str, Any], Any]]: ... + @abstractmethod def astream( self, input: Union[dict[str, Any], Any], @@ -106,6 +109,7 @@ def astream( subgraphs: bool = False, ) -> AsyncIterator[Union[dict[str, Any], Any]]: ... + @abstractmethod def invoke( self, input: Union[dict[str, Any], Any], @@ -115,6 +119,7 @@ def invoke( interrupt_after: Optional[Union[All, Sequence[str]]] = None, ) -> Union[dict[str, Any], Any]: ... + @abstractmethod async def ainvoke( self, input: Union[dict[str, Any], Any], diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index f0001659b..ae0f5d2f3 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -2,6 +2,7 @@ Any, AsyncIterator, Iterator, + Literal, Optional, Sequence, Union, @@ -9,7 +10,7 @@ ) import orjson -from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.runnables import RunnableConfig from langchain_core.runnables.graph import ( Edge as DrawableEdge, ) @@ -44,10 +45,14 @@ class RemoteException(Exception): pass -class RemoteGraph(PregelProtocol, Runnable): +class RemoteGraph(PregelProtocol): + name: str + def __init__( self, - graph_id: str, + name: str, # graph_id + /, + *, url: Optional[str] = None, api_key: Optional[str] = None, headers: Optional[dict[str, str]] = None, @@ -60,7 +65,7 @@ def __init__( If `client` or `sync_client` are provided, they will be used instead of the default clients. See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients. """ - self.graph_id = graph_id + self.name = name self.config = config self.client = client or get_client(url=url, api_key=api_key, headers=headers) self.sync_client = sync_client or get_sync_client( @@ -69,7 +74,7 @@ def __init__( def copy(self, update: dict[str, Any]) -> Self: attrs = {**self.__dict__, **update} - return self.__class__(**attrs) + return self.__class__(attrs.pop("name"), **attrs) def with_config( self, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -99,7 +104,7 @@ def get_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: graph = self.sync_client.assistants.get_graph( - assistant_id=self.graph_id, + assistant_id=self.name, xray=xray, ) return DrawableGraph( @@ -114,7 +119,7 @@ async def aget_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: graph = await self.client.assistants.get_graph( - assistant_id=self.graph_id, + assistant_id=self.name, xray=xray, ) return DrawableGraph( @@ -122,30 +127,6 @@ async def aget_graph( edges=[DrawableEdge(**edge) for edge in graph["edges"]], ) - def get_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> Iterator[tuple[str, "PregelProtocol"]]: - subgraphs = self.sync_client.assistants.get_subgraphs( - assistant_id=self.graph_id, - namespace=namespace, - recurse=recurse, - ) - for namespace, graph_schema in subgraphs.items(): - remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]}) - yield (namespace, remote_subgraph) - - async def aget_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> AsyncIterator[tuple[str, "PregelProtocol"]]: - subgraphs = await self.client.assistants.get_subgraphs( - assistant_id=self.graph_id, - namespace=namespace, - recurse=recurse, - ) - for namespace, graph_schema in subgraphs.items(): - remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]}) - yield (namespace, remote_subgraph) - def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot: tasks = [] for task in state["tasks"]: @@ -258,7 +239,11 @@ def _sanitize_obj(obj: Any) -> Any: if k not in reserved_configurable_keys and not k.startswith("__pregel_") } - return {"configurable": new_configurable} + return { + "tags": config.get("tags"), + "metadata": config.get("metadata"), + "configurable": new_configurable, + } def get_state( self, config: RunnableConfig, *, subgraphs: bool = False @@ -402,8 +387,8 @@ def stream( stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) for chunk in self.sync_client.runs.stream( - thread_id=cast(str, sanitized_config["configurable"]["thread_id"]), - assistant_id=self.graph_id, + thread_id=sanitized_config["configurable"].get("thread_id"), + assistant_id=self.name, input=input, config=sanitized_config, stream_mode=stream_modes, @@ -449,8 +434,8 @@ async def astream( stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) async for chunk in self.client.runs.stream( - thread_id=sanitized_config["configurable"]["thread_id"], - assistant_id=self.graph_id, + thread_id=sanitized_config["configurable"].get("thread_id"), + assistant_id=self.name, input=input, config=sanitized_config, stream_mode=stream_modes, @@ -481,6 +466,22 @@ async def astream( else: yield chunk + async def astream_events( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + version: Literal["v1", "v2"], + include_names: Optional[Sequence[All]] = None, + include_types: Optional[Sequence[All]] = None, + include_tags: Optional[Sequence[All]] = None, + exclude_names: Optional[Sequence[All]] = None, + exclude_types: Optional[Sequence[All]] = None, + exclude_tags: Optional[Sequence[All]] = None, + **kwargs: Any, + ) -> AsyncIterator[dict[str, Any]]: + raise NotImplementedError + def invoke( self, input: Union[dict[str, Any], Any], @@ -493,8 +494,8 @@ def invoke( sanitized_config = self._sanitize_config(merged_config) return self.sync_client.runs.wait( - thread_id=sanitized_config["configurable"]["thread_id"], - assistant_id=self.graph_id, + thread_id=sanitized_config["configurable"].get("thread_id"), + assistant_id=self.name, input=input, config=sanitized_config, interrupt_before=interrupt_before, @@ -514,8 +515,8 @@ async def ainvoke( sanitized_config = self._sanitize_config(merged_config) return await self.client.runs.wait( - thread_id=sanitized_config["configurable"]["thread_id"], - assistant_id=self.graph_id, + thread_id=sanitized_config["configurable"].get("thread_id"), + assistant_id=self.name, input=input, config=sanitized_config, interrupt_before=interrupt_before, diff --git a/libs/langgraph/langgraph/pregel/utils.py b/libs/langgraph/langgraph/pregel/utils.py index 2b09f8f75..cc7221ea4 100644 --- a/libs/langgraph/langgraph/pregel/utils.py +++ b/libs/langgraph/langgraph/pregel/utils.py @@ -4,6 +4,7 @@ from langchain_core.runnables.utils import get_function_nonlocals from langgraph.checkpoint.base import ChannelVersions +from langgraph.pregel.protocol import PregelProtocol from langgraph.utils.runnable import Runnable, RunnableCallable, RunnableSeq @@ -32,9 +33,9 @@ def find_subgraph_pregel(candidate: Runnable) -> Optional[Runnable]: for c in candidates: if ( - isinstance(c, Pregel) + isinstance(c, PregelProtocol) # subgraphs that disabled checkpointing are not considered - and c.checkpointer is not False + and (not isinstance(c, Pregel) or c.checkpointer is not False) ): return c elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq): diff --git a/libs/langgraph/tests/test_remote_graph.py b/libs/langgraph/tests/test_remote_graph.py index 46b3ab9a4..8ee0a3f9a 100644 --- a/libs/langgraph/tests/test_remote_graph.py +++ b/libs/langgraph/tests/test_remote_graph.py @@ -17,7 +17,7 @@ def test_with_config(): # set up test remote_pregel = RemoteGraph( - graph_id="test_graph_id", + "test_graph_id", config={ "configurable": { "foo": "bar", @@ -64,7 +64,7 @@ def test_get_graph(): ], } - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph("test_graph_id", sync_client=mock_sync_client) # call method / assertions drawable_graph = remote_pregel.get_graph() @@ -111,7 +111,7 @@ async def test_aget_graph(): ], } - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph("test_graph_id", client=mock_async_client) # call method / assertions drawable_graph = await remote_pregel.aget_graph() @@ -135,92 +135,6 @@ async def test_aget_graph(): ] -def test_get_subgraphs(): - # set up test - mock_sync_client = MagicMock() - mock_sync_client.assistants.get_subgraphs.return_value = { - "namespace_1": { - "graph_id": "test_graph_id_2", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - "namespace_2": { - "graph_id": "test_graph_id_3", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - } - - remote_pregel = RemoteGraph( - sync_client=mock_sync_client, graph_id="test_graph_id_1" - ) - - # call method / assertions - subgraphs = list(remote_pregel.get_subgraphs()) - assert len(subgraphs) == 2 - - subgraph_1 = subgraphs[0] - ns_1 = subgraph_1[0] - remote_pregel_1: RemoteGraph = subgraph_1[1] - assert ns_1 == "namespace_1" - assert remote_pregel_1.graph_id == "test_graph_id_2" - - subgraph_2 = subgraphs[1] - ns_2 = subgraph_2[0] - remote_pregel_2: RemoteGraph = subgraph_2[1] - assert ns_2 == "namespace_2" - assert remote_pregel_2.graph_id == "test_graph_id_3" - - -@pytest.mark.anyio -async def test_aget_subgraphs(): - # set up test - mock_async_client = AsyncMock() - mock_async_client.assistants.get_subgraphs.return_value = { - "namespace_1": { - "graph_id": "test_graph_id_2", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - "namespace_2": { - "graph_id": "test_graph_id_3", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - } - - remote_pregel = RemoteGraph( - client=mock_async_client, - graph_id="test_graph_id_1", - ) - - # call method / assertions - subgraphs = [] - async for subgraph in remote_pregel.aget_subgraphs(): - subgraphs.append(subgraph) - assert len(subgraphs) == 2 - - subgraph_1 = subgraphs[0] - ns_1 = subgraph_1[0] - remote_pregel_1: RemoteGraph = subgraph_1[1] - assert ns_1 == "namespace_1" - assert remote_pregel_1.graph_id == "test_graph_id_2" - - subgraph_2 = subgraphs[1] - ns_2 = subgraph_2[0] - remote_pregel_2: RemoteGraph = subgraph_2[1] - assert ns_2 == "namespace_2" - assert remote_pregel_2.graph_id == "test_graph_id_3" - - def test_get_state(): # set up test mock_sync_client = MagicMock() @@ -240,7 +154,10 @@ def test_get_state(): } # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_snapshot = remote_pregel.get_state(config) @@ -287,7 +204,10 @@ async def test_aget_state(): } # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_snapshot = await remote_pregel.aget_state(config) @@ -338,7 +258,10 @@ def test_get_state_history(): ] # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_history_snapshot = list( @@ -386,7 +309,10 @@ async def test_aget_state_history(): ] # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_history_snapshot = [] @@ -427,7 +353,10 @@ def test_update_state(): } # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread1"}} response = remote_pregel.update_state(config, {"key": "value"}) @@ -456,7 +385,10 @@ async def test_aupdate_state(): } # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread1"}} response = await remote_pregel.aupdate_state(config, {"key": "value"}) @@ -483,7 +415,10 @@ def test_stream(): ] # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) # stream modes doesn't include 'updates' stream_parts = [] @@ -583,7 +518,10 @@ async def test_astream(): mock_async_client.runs.stream.return_value = async_iter # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) # stream modes doesn't include 'updates' stream_parts = [] @@ -717,7 +655,10 @@ def test_invoke(): } # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread_1"}} result = remote_pregel.invoke( @@ -736,7 +677,10 @@ async def test_ainvoke(): } # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread_1"}} result = await remote_pregel.ainvoke( @@ -758,7 +702,9 @@ async def test_langgraph_cloud_integration(): client = get_client() sync_client = get_sync_client() remote_pregel = RemoteGraph( - client=client, sync_client=sync_client, graph_id="agent" + "agent", + client=client, + sync_client=sync_client, ) # define graph @@ -836,9 +782,3 @@ async def test_langgraph_cloud_integration(): remote_pregel.graph_id = "fe096781-5601-53d2-b2f6-0d3403f7e9ca" # must be UUID graph = await remote_pregel.aget_graph(xray=True) print("graph:", graph) - - # test get subgraphs - remote_pregel.graph_id = "fe096781-5601-53d2-b2f6-0d3403f7e9ca" # must be UUID - async for name, pregel in remote_pregel.aget_subgraphs(): - print("name:", name) - print("pregel:", pregel) diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index 239f4cf2a..7dabff4b9 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -1202,6 +1202,7 @@ def stream( feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, webhook: Optional[str] = None, after_seconds: Optional[int] = None, ) -> AsyncIterator[StreamPart]: ... @@ -1327,6 +1328,7 @@ async def create( interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... @@ -1529,6 +1531,7 @@ async def wait( webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ... @@ -3280,6 +3283,7 @@ def stream( feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, webhook: Optional[str] = None, after_seconds: Optional[int] = None, ) -> Iterator[StreamPart]: ... @@ -3405,6 +3409,7 @@ def create( interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... @@ -3607,6 +3612,7 @@ def wait( webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ...