From dc8260bb7218e5dc0afcc425ed3cb3efdc40b185 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 15:37:17 -0700 Subject: [PATCH 1/5] Interop of RemoteGraph w core lib --- libs/langgraph/langgraph/graph/graph.py | 2 +- libs/langgraph/langgraph/pregel/__init__.py | 5 +- libs/langgraph/langgraph/pregel/remote.py | 44 ++++++++++---- libs/langgraph/langgraph/pregel/utils.py | 5 +- libs/langgraph/tests/test_remote_graph.py | 66 +++++++++++++++------ 5 files changed, 87 insertions(+), 35 deletions(-) diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 7c0923751..2071caedc 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -170,7 +170,7 @@ def add_node( def add_node( self, node: Union[str, RunnableLike], - action: Optional[RunnableLike] = None, + action: Optional[Union[RunnableLike]] = None, *, metadata: Optional[dict[str, Any]] = None, ) -> Self: diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index b535a06e1..f894c6420 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -86,6 +86,7 @@ from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.messages import StreamMessagesHandler +from langgraph.pregel.protocol import PregelProtocol from langgraph.pregel.read import PregelNode from langgraph.pregel.retry import RetryPolicy from langgraph.pregel.runner import PregelRunner @@ -179,7 +180,9 @@ def write_to( ) -class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]): +class Pregel( + Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], PregelProtocol +): nodes: dict[str, PregelNode] channels: dict[str, Union[BaseChannel, ManagedValueSpec]] diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index f0001659b..dfd8ccd11 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -47,7 +47,9 @@ class RemoteException(Exception): class RemoteGraph(PregelProtocol, Runnable): def __init__( self, - graph_id: str, + name: str, # graph_id + /, + *, url: Optional[str] = None, api_key: Optional[str] = None, headers: Optional[dict[str, str]] = None, @@ -60,7 +62,7 @@ def __init__( If `client` or `sync_client` are provided, they will be used instead of the default clients. See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients. """ - self.graph_id = graph_id + self.name = name self.config = config self.client = client or get_client(url=url, api_key=api_key, headers=headers) self.sync_client = sync_client or get_sync_client( @@ -69,7 +71,7 @@ def __init__( def copy(self, update: dict[str, Any]) -> Self: attrs = {**self.__dict__, **update} - return self.__class__(**attrs) + return self.__class__(attrs.pop("name"), **attrs) def with_config( self, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -99,7 +101,7 @@ def get_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: graph = self.sync_client.assistants.get_graph( - assistant_id=self.graph_id, + assistant_id=self.name, xray=xray, ) return DrawableGraph( @@ -114,7 +116,7 @@ async def aget_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: graph = await self.client.assistants.get_graph( - assistant_id=self.graph_id, + assistant_id=self.name, xray=xray, ) return DrawableGraph( @@ -126,24 +128,24 @@ def get_subgraphs( self, namespace: Optional[str] = None, recurse: bool = False ) -> Iterator[tuple[str, "PregelProtocol"]]: subgraphs = self.sync_client.assistants.get_subgraphs( - assistant_id=self.graph_id, + assistant_id=self.name, namespace=namespace, recurse=recurse, ) for namespace, graph_schema in subgraphs.items(): - remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]}) + remote_subgraph = self.copy({"name": graph_schema["graph_id"]}) yield (namespace, remote_subgraph) async def aget_subgraphs( self, namespace: Optional[str] = None, recurse: bool = False ) -> AsyncIterator[tuple[str, "PregelProtocol"]]: subgraphs = await self.client.assistants.get_subgraphs( - assistant_id=self.graph_id, + assistant_id=self.name, namespace=namespace, recurse=recurse, ) for namespace, graph_schema in subgraphs.items(): - remote_subgraph = self.copy({"graph_id": graph_schema["graph_id"]}) + remote_subgraph = self.copy({"name": graph_schema["graph_id"]}) yield (namespace, remote_subgraph) def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot: @@ -403,7 +405,7 @@ def stream( for chunk in self.sync_client.runs.stream( thread_id=cast(str, sanitized_config["configurable"]["thread_id"]), - assistant_id=self.graph_id, + assistant_id=self.name, input=input, config=sanitized_config, stream_mode=stream_modes, @@ -450,7 +452,7 @@ async def astream( async for chunk in self.client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], - assistant_id=self.graph_id, + assistant_id=self.name, input=input, config=sanitized_config, stream_mode=stream_modes, @@ -481,6 +483,22 @@ async def astream( else: yield chunk + async def astream_events( + self, + input: Any, + config: RunnableConfig | None = None, + *, + version: All | All, + include_names: Sequence[All] | None = None, + include_types: Sequence[All] | None = None, + include_tags: Sequence[All] | None = None, + exclude_names: Sequence[All] | None = None, + exclude_types: Sequence[All] | None = None, + exclude_tags: Sequence[All] | None = None, + **kwargs: Any, + ) -> AsyncIterator[dict[str, Any]]: + raise NotImplementedError + def invoke( self, input: Union[dict[str, Any], Any], @@ -494,7 +512,7 @@ def invoke( return self.sync_client.runs.wait( thread_id=sanitized_config["configurable"]["thread_id"], - assistant_id=self.graph_id, + assistant_id=self.name, input=input, config=sanitized_config, interrupt_before=interrupt_before, @@ -515,7 +533,7 @@ async def ainvoke( return await self.client.runs.wait( thread_id=sanitized_config["configurable"]["thread_id"], - assistant_id=self.graph_id, + assistant_id=self.name, input=input, config=sanitized_config, interrupt_before=interrupt_before, diff --git a/libs/langgraph/langgraph/pregel/utils.py b/libs/langgraph/langgraph/pregel/utils.py index 2b09f8f75..cc7221ea4 100644 --- a/libs/langgraph/langgraph/pregel/utils.py +++ b/libs/langgraph/langgraph/pregel/utils.py @@ -4,6 +4,7 @@ from langchain_core.runnables.utils import get_function_nonlocals from langgraph.checkpoint.base import ChannelVersions +from langgraph.pregel.protocol import PregelProtocol from langgraph.utils.runnable import Runnable, RunnableCallable, RunnableSeq @@ -32,9 +33,9 @@ def find_subgraph_pregel(candidate: Runnable) -> Optional[Runnable]: for c in candidates: if ( - isinstance(c, Pregel) + isinstance(c, PregelProtocol) # subgraphs that disabled checkpointing are not considered - and c.checkpointer is not False + and (not isinstance(c, Pregel) or c.checkpointer is not False) ): return c elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq): diff --git a/libs/langgraph/tests/test_remote_graph.py b/libs/langgraph/tests/test_remote_graph.py index 46b3ab9a4..dfd4ad57c 100644 --- a/libs/langgraph/tests/test_remote_graph.py +++ b/libs/langgraph/tests/test_remote_graph.py @@ -17,7 +17,7 @@ def test_with_config(): # set up test remote_pregel = RemoteGraph( - graph_id="test_graph_id", + "test_graph_id", config={ "configurable": { "foo": "bar", @@ -64,7 +64,7 @@ def test_get_graph(): ], } - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph("test_graph_id", sync_client=mock_sync_client) # call method / assertions drawable_graph = remote_pregel.get_graph() @@ -111,7 +111,7 @@ async def test_aget_graph(): ], } - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph("test_graph_id", client=mock_async_client) # call method / assertions drawable_graph = await remote_pregel.aget_graph() @@ -155,9 +155,7 @@ def test_get_subgraphs(): }, } - remote_pregel = RemoteGraph( - sync_client=mock_sync_client, graph_id="test_graph_id_1" - ) + remote_pregel = RemoteGraph("test_graph_id_1", sync_client=mock_sync_client) # call method / assertions subgraphs = list(remote_pregel.get_subgraphs()) @@ -198,8 +196,8 @@ async def test_aget_subgraphs(): } remote_pregel = RemoteGraph( + "test_graph_id_1", client=mock_async_client, - graph_id="test_graph_id_1", ) # call method / assertions @@ -240,7 +238,10 @@ def test_get_state(): } # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_snapshot = remote_pregel.get_state(config) @@ -287,7 +288,10 @@ async def test_aget_state(): } # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_snapshot = await remote_pregel.aget_state(config) @@ -338,7 +342,10 @@ def test_get_state_history(): ] # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_history_snapshot = list( @@ -386,7 +393,10 @@ async def test_aget_state_history(): ] # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread1"}} state_history_snapshot = [] @@ -427,7 +437,10 @@ def test_update_state(): } # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread1"}} response = remote_pregel.update_state(config, {"key": "value"}) @@ -456,7 +469,10 @@ async def test_aupdate_state(): } # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread1"}} response = await remote_pregel.aupdate_state(config, {"key": "value"}) @@ -483,7 +499,10 @@ def test_stream(): ] # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) # stream modes doesn't include 'updates' stream_parts = [] @@ -583,7 +602,10 @@ async def test_astream(): mock_async_client.runs.stream.return_value = async_iter # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) # stream modes doesn't include 'updates' stream_parts = [] @@ -717,7 +739,10 @@ def test_invoke(): } # call method / assertions - remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + sync_client=mock_sync_client, + ) config = {"configurable": {"thread_id": "thread_1"}} result = remote_pregel.invoke( @@ -736,7 +761,10 @@ async def test_ainvoke(): } # call method / assertions - remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") + remote_pregel = RemoteGraph( + "test_graph_id", + client=mock_async_client, + ) config = {"configurable": {"thread_id": "thread_1"}} result = await remote_pregel.ainvoke( @@ -758,7 +786,9 @@ async def test_langgraph_cloud_integration(): client = get_client() sync_client = get_sync_client() remote_pregel = RemoteGraph( - client=client, sync_client=sync_client, graph_id="agent" + "agent", + client=client, + sync_client=sync_client, ) # define graph From 69227daff34652ee463d76743f77a958129a4564 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 17:01:58 -0700 Subject: [PATCH 2/5] Lint --- libs/langgraph/langgraph/graph/graph.py | 8 ++ libs/langgraph/langgraph/pregel/__init__.py | 16 +++- libs/langgraph/langgraph/pregel/protocol.py | 31 ++++--- libs/langgraph/langgraph/pregel/remote.py | 30 +------ libs/langgraph/tests/test_remote_graph.py | 90 --------------------- 5 files changed, 42 insertions(+), 133 deletions(-) diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 2071caedc..7947544cb 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -514,6 +514,14 @@ def branch_writer( self.nodes[end].triggers.append(channel_name) cast(list[str], self.nodes[end].channels).append(channel_name) + async def aget_graph( + self, + config: Optional[RunnableConfig] = None, + *, + xray: Union[int, bool] = False, + ) -> DrawableGraph: + return self.get_graph(config, xray=xray) + def get_graph( self, config: Optional[RunnableConfig] = None, diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index f894c6420..c6018cbc7 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -25,7 +25,6 @@ from langchain_core.globals import get_debug from langchain_core.runnables import ( - Runnable, RunnableSequence, ) from langchain_core.runnables.base import Input, Output @@ -34,6 +33,7 @@ get_async_callback_manager_for_config, get_callback_manager_for_config, ) +from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( ConfigurableFieldSpec, get_unique_config_specs, @@ -180,9 +180,7 @@ def write_to( ) -class Pregel( - Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], PregelProtocol -): +class Pregel(PregelProtocol): nodes: dict[str, PregelNode] channels: dict[str, Union[BaseChannel, ManagedValueSpec]] @@ -262,6 +260,16 @@ def __init__( if auto_validate: self.validate() + def get_graph( + self, config: RunnableConfig | None = None, *, xray: int | bool = False + ) -> Graph: + raise NotImplementedError + + async def aget_graph( + self, config: RunnableConfig | None = None, *, xray: int | bool = False + ) -> Graph: + raise NotImplementedError + def copy(self, update: dict[str, Any] | None = None) -> Self: attrs = {**self.__dict__, **(update or {})} return self.__class__(**attrs) diff --git a/libs/langgraph/langgraph/pregel/protocol.py b/libs/langgraph/langgraph/pregel/protocol.py index 34789284b..ac046e949 100644 --- a/libs/langgraph/langgraph/pregel/protocol.py +++ b/libs/langgraph/langgraph/pregel/protocol.py @@ -1,27 +1,29 @@ +from abc import ABC, abstractmethod from typing import ( Any, AsyncIterator, Iterator, Optional, - Protocol, Sequence, Union, - runtime_checkable, ) -from langchain_core.runnables import RunnableConfig +from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.graph import Graph as DrawableGraph from typing_extensions import Self from langgraph.pregel.types import All, StateSnapshot, StreamMode -@runtime_checkable -class PregelProtocol(Protocol): +class PregelProtocol( + Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], ABC +): + @abstractmethod def with_config( self, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Self: ... + @abstractmethod def get_graph( self, config: Optional[RunnableConfig] = None, @@ -29,6 +31,7 @@ def get_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: ... + @abstractmethod async def aget_graph( self, config: Optional[RunnableConfig] = None, @@ -36,22 +39,17 @@ async def aget_graph( xray: Union[int, bool] = False, ) -> DrawableGraph: ... - def get_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> Iterator[tuple[str, "PregelProtocol"]]: ... - - def aget_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> AsyncIterator[tuple[str, "PregelProtocol"]]: ... - + @abstractmethod def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: ... + @abstractmethod async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: ... + @abstractmethod def get_state_history( self, config: RunnableConfig, @@ -61,6 +59,7 @@ def get_state_history( limit: Optional[int] = None, ) -> Iterator[StateSnapshot]: ... + @abstractmethod def aget_state_history( self, config: RunnableConfig, @@ -70,6 +69,7 @@ def aget_state_history( limit: Optional[int] = None, ) -> AsyncIterator[StateSnapshot]: ... + @abstractmethod def update_state( self, config: RunnableConfig, @@ -77,6 +77,7 @@ def update_state( as_node: Optional[str] = None, ) -> RunnableConfig: ... + @abstractmethod async def aupdate_state( self, config: RunnableConfig, @@ -84,6 +85,7 @@ async def aupdate_state( as_node: Optional[str] = None, ) -> RunnableConfig: ... + @abstractmethod def stream( self, input: Union[dict[str, Any], Any], @@ -95,6 +97,7 @@ def stream( subgraphs: bool = False, ) -> Iterator[Union[dict[str, Any], Any]]: ... + @abstractmethod def astream( self, input: Union[dict[str, Any], Any], @@ -106,6 +109,7 @@ def astream( subgraphs: bool = False, ) -> AsyncIterator[Union[dict[str, Any], Any]]: ... + @abstractmethod def invoke( self, input: Union[dict[str, Any], Any], @@ -115,6 +119,7 @@ def invoke( interrupt_after: Optional[Union[All, Sequence[str]]] = None, ) -> Union[dict[str, Any], Any]: ... + @abstractmethod async def ainvoke( self, input: Union[dict[str, Any], Any], diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index dfd8ccd11..80295dae5 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -9,7 +9,7 @@ ) import orjson -from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.runnables import RunnableConfig from langchain_core.runnables.graph import ( Edge as DrawableEdge, ) @@ -44,7 +44,9 @@ class RemoteException(Exception): pass -class RemoteGraph(PregelProtocol, Runnable): +class RemoteGraph(PregelProtocol): + name: str + def __init__( self, name: str, # graph_id @@ -124,30 +126,6 @@ async def aget_graph( edges=[DrawableEdge(**edge) for edge in graph["edges"]], ) - def get_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> Iterator[tuple[str, "PregelProtocol"]]: - subgraphs = self.sync_client.assistants.get_subgraphs( - assistant_id=self.name, - namespace=namespace, - recurse=recurse, - ) - for namespace, graph_schema in subgraphs.items(): - remote_subgraph = self.copy({"name": graph_schema["graph_id"]}) - yield (namespace, remote_subgraph) - - async def aget_subgraphs( - self, namespace: Optional[str] = None, recurse: bool = False - ) -> AsyncIterator[tuple[str, "PregelProtocol"]]: - subgraphs = await self.client.assistants.get_subgraphs( - assistant_id=self.name, - namespace=namespace, - recurse=recurse, - ) - for namespace, graph_schema in subgraphs.items(): - remote_subgraph = self.copy({"name": graph_schema["graph_id"]}) - yield (namespace, remote_subgraph) - def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot: tasks = [] for task in state["tasks"]: diff --git a/libs/langgraph/tests/test_remote_graph.py b/libs/langgraph/tests/test_remote_graph.py index dfd4ad57c..8ee0a3f9a 100644 --- a/libs/langgraph/tests/test_remote_graph.py +++ b/libs/langgraph/tests/test_remote_graph.py @@ -135,90 +135,6 @@ async def test_aget_graph(): ] -def test_get_subgraphs(): - # set up test - mock_sync_client = MagicMock() - mock_sync_client.assistants.get_subgraphs.return_value = { - "namespace_1": { - "graph_id": "test_graph_id_2", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - "namespace_2": { - "graph_id": "test_graph_id_3", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - } - - remote_pregel = RemoteGraph("test_graph_id_1", sync_client=mock_sync_client) - - # call method / assertions - subgraphs = list(remote_pregel.get_subgraphs()) - assert len(subgraphs) == 2 - - subgraph_1 = subgraphs[0] - ns_1 = subgraph_1[0] - remote_pregel_1: RemoteGraph = subgraph_1[1] - assert ns_1 == "namespace_1" - assert remote_pregel_1.graph_id == "test_graph_id_2" - - subgraph_2 = subgraphs[1] - ns_2 = subgraph_2[0] - remote_pregel_2: RemoteGraph = subgraph_2[1] - assert ns_2 == "namespace_2" - assert remote_pregel_2.graph_id == "test_graph_id_3" - - -@pytest.mark.anyio -async def test_aget_subgraphs(): - # set up test - mock_async_client = AsyncMock() - mock_async_client.assistants.get_subgraphs.return_value = { - "namespace_1": { - "graph_id": "test_graph_id_2", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - "namespace_2": { - "graph_id": "test_graph_id_3", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - }, - } - - remote_pregel = RemoteGraph( - "test_graph_id_1", - client=mock_async_client, - ) - - # call method / assertions - subgraphs = [] - async for subgraph in remote_pregel.aget_subgraphs(): - subgraphs.append(subgraph) - assert len(subgraphs) == 2 - - subgraph_1 = subgraphs[0] - ns_1 = subgraph_1[0] - remote_pregel_1: RemoteGraph = subgraph_1[1] - assert ns_1 == "namespace_1" - assert remote_pregel_1.graph_id == "test_graph_id_2" - - subgraph_2 = subgraphs[1] - ns_2 = subgraph_2[0] - remote_pregel_2: RemoteGraph = subgraph_2[1] - assert ns_2 == "namespace_2" - assert remote_pregel_2.graph_id == "test_graph_id_3" - - def test_get_state(): # set up test mock_sync_client = MagicMock() @@ -866,9 +782,3 @@ async def test_langgraph_cloud_integration(): remote_pregel.graph_id = "fe096781-5601-53d2-b2f6-0d3403f7e9ca" # must be UUID graph = await remote_pregel.aget_graph(xray=True) print("graph:", graph) - - # test get subgraphs - remote_pregel.graph_id = "fe096781-5601-53d2-b2f6-0d3403f7e9ca" # must be UUID - async for name, pregel in remote_pregel.aget_subgraphs(): - print("name:", name) - print("pregel:", pregel) From aa245a8e71a928a211807c024de968b9b1bc4503 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 20:22:04 -0700 Subject: [PATCH 3/5] Fix up --- libs/langgraph/langgraph/graph/graph.py | 2 +- libs/langgraph/langgraph/pregel/remote.py | 28 +++++++++++++---------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 7947544cb..e3d11c1eb 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -170,7 +170,7 @@ def add_node( def add_node( self, node: Union[str, RunnableLike], - action: Optional[Union[RunnableLike]] = None, + action: Optional[RunnableLike] = None, *, metadata: Optional[dict[str, Any]] = None, ) -> Self: diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 80295dae5..0d717f9e0 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -238,7 +238,11 @@ def _sanitize_obj(obj: Any) -> Any: if k not in reserved_configurable_keys and not k.startswith("__pregel_") } - return {"configurable": new_configurable} + return { + "tags": config.get("tags"), + "metadata": config.get("metadata"), + "configurable": new_configurable, + } def get_state( self, config: RunnableConfig, *, subgraphs: bool = False @@ -382,7 +386,7 @@ def stream( stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) for chunk in self.sync_client.runs.stream( - thread_id=cast(str, sanitized_config["configurable"]["thread_id"]), + thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, config=sanitized_config, @@ -429,7 +433,7 @@ async def astream( stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) async for chunk in self.client.runs.stream( - thread_id=sanitized_config["configurable"]["thread_id"], + thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, config=sanitized_config, @@ -464,15 +468,15 @@ async def astream( async def astream_events( self, input: Any, - config: RunnableConfig | None = None, + config: Optional[RunnableConfig] = None, *, version: All | All, - include_names: Sequence[All] | None = None, - include_types: Sequence[All] | None = None, - include_tags: Sequence[All] | None = None, - exclude_names: Sequence[All] | None = None, - exclude_types: Sequence[All] | None = None, - exclude_tags: Sequence[All] | None = None, + include_names: Optional[Sequence[All]] = None, + include_types: Optional[Sequence[All]] = None, + include_tags: Optional[Sequence[All]] = None, + exclude_names: Optional[Sequence[All]] = None, + exclude_types: Optional[Sequence[All]] = None, + exclude_tags: Optional[Sequence[All]] = None, **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: raise NotImplementedError @@ -489,7 +493,7 @@ def invoke( sanitized_config = self._sanitize_config(merged_config) return self.sync_client.runs.wait( - thread_id=sanitized_config["configurable"]["thread_id"], + thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, config=sanitized_config, @@ -510,7 +514,7 @@ async def ainvoke( sanitized_config = self._sanitize_config(merged_config) return await self.client.runs.wait( - thread_id=sanitized_config["configurable"]["thread_id"], + thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, config=sanitized_config, From a8ae2a52a341e75f1e3c6e023365f9f901339046 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 20:27:18 -0700 Subject: [PATCH 4/5] Lint --- libs/sdk-py/langgraph_sdk/client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index 239f4cf2a..7dabff4b9 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -1202,6 +1202,7 @@ def stream( feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, webhook: Optional[str] = None, after_seconds: Optional[int] = None, ) -> AsyncIterator[StreamPart]: ... @@ -1327,6 +1328,7 @@ async def create( interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... @@ -1529,6 +1531,7 @@ async def wait( webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ... @@ -3280,6 +3283,7 @@ def stream( feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, webhook: Optional[str] = None, after_seconds: Optional[int] = None, ) -> Iterator[StreamPart]: ... @@ -3405,6 +3409,7 @@ def create( interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... @@ -3607,6 +3612,7 @@ def wait( webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ... From 05f008cbfb32edf002d84e3e4def55826265e31e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 20:30:32 -0700 Subject: [PATCH 5/5] Lint --- libs/langgraph/langgraph/pregel/remote.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 0d717f9e0..ae0f5d2f3 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -2,6 +2,7 @@ Any, AsyncIterator, Iterator, + Literal, Optional, Sequence, Union, @@ -470,7 +471,7 @@ async def astream_events( input: Any, config: Optional[RunnableConfig] = None, *, - version: All | All, + version: Literal["v1", "v2"], include_names: Optional[Sequence[All]] = None, include_types: Optional[Sequence[All]] = None, include_tags: Optional[Sequence[All]] = None,