From 4cc6d789edf5be08f5925e2c3cd34a366f501e88 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 30 Oct 2024 19:35:28 -0700 Subject: [PATCH] Fix stream_mode=messages used together with subgraphs=True --- libs/langgraph/langgraph/pregel/remote.py | 68 +++++++++++++---------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 69fbb60a4..bf7a76bd5 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -504,8 +504,11 @@ async def aupdate_state( def _get_stream_modes( self, stream_mode: Optional[Union[StreamMode, list[StreamMode]]], + config: Optional[RunnableConfig], default: StreamMode = "updates", - ) -> tuple[list[StreamModeSDK], bool, bool]: + ) -> tuple[ + list[StreamModeSDK], list[StreamModeSDK], bool, Optional[StreamProtocol] + ]: """Return a tuple of the final list of stream modes sent to the remote graph and a boolean flag indicating if stream mode 'updates' was present in the original list of stream modes. @@ -514,7 +517,6 @@ def _get_stream_modes( can be detected in the remote graph. """ updated_stream_modes: list[StreamModeSDK] = [] - req_updates = False req_single = True # coerce to list, or add default stream mode if stream_mode: @@ -525,16 +527,21 @@ def _get_stream_modes( updated_stream_modes.extend(stream_mode) else: updated_stream_modes.append(default) + requested_stream_modes = updated_stream_modes.copy() + # add any from parent graph + stream: Optional[StreamProtocol] = ( + (config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM) + ) + if stream: + updated_stream_modes.extend(stream.modes) # map "messages" to "messages-tuple" if "messages" in updated_stream_modes: updated_stream_modes.remove("messages") updated_stream_modes.append("messages-tuple") # add 'updates' mode if not present - if "updates" in updated_stream_modes: - req_updates = True - else: + if "updates" not in updated_stream_modes: updated_stream_modes.append("updates") - return (updated_stream_modes, req_updates, req_single) + return (updated_stream_modes, requested_stream_modes, req_single, stream) def stream( self, @@ -566,12 +573,8 @@ def stream( sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) - stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) - stream: Optional[StreamProtocol] = ( - (config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM) - ) - stream_modes_ext: list[StreamModeSDK] = ( - [*stream_modes, *stream.modes] if stream else stream_modes + stream_modes, requested, req_single, stream = self._get_stream_modes( + stream_mode, config ) for chunk in sync_client.runs.stream( @@ -579,31 +582,35 @@ def stream( assistant_id=self.name, input=input, config=sanitized_config, - stream_mode=stream_modes_ext, + stream_mode=stream_modes, interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_subgraphs=subgraphs or stream is not None, if_not_exists="create", ): + # split mode and ns if NS_SEP in chunk.event: mode, ns_ = chunk.event.split(NS_SEP, 1) ns = tuple(ns_.split(NS_SEP)) else: mode, ns = chunk.event, () + # prepend caller ns (as it is not passed to remote graph) if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS): caller_ns = tuple(caller_ns.split(NS_SEP)) ns = caller_ns + ns - if stream is not None and chunk.event in stream.modes: + # stream to parent stream + if stream is not None and mode in stream.modes: stream((ns, mode, chunk.data)) + # raise interrupt or errors if chunk.event.startswith("updates"): if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: raise GraphInterrupt(chunk.data[INTERRUPT]) - if not req_updates: - continue elif chunk.event.startswith("error"): raise RemoteException(chunk.data) - if chunk.event.split(NS_SEP, 1)[0] not in stream_modes: + # filter for what was actually requested + if mode not in requested: continue + # emit chunk if subgraphs: if NS_SEP in chunk.event: mode, ns_ = chunk.event.split(NS_SEP, 1) @@ -649,12 +656,8 @@ async def astream( client = self._validate_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) - stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) - stream: Optional[StreamProtocol] = ( - (config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM) - ) - stream_modes_ext: list[StreamModeSDK] = ( - [*stream_modes, *stream.modes] if stream else stream_modes + stream_modes, requested, req_single, stream = self._get_stream_modes( + stream_mode, config ) async for chunk in client.runs.stream( @@ -662,32 +665,41 @@ async def astream( assistant_id=self.name, input=input, config=sanitized_config, - stream_mode=stream_modes_ext, + stream_mode=stream_modes, interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_subgraphs=subgraphs or stream is not None, if_not_exists="create", ): + # split mode and ns if NS_SEP in chunk.event: mode, ns_ = chunk.event.split(NS_SEP, 1) ns = tuple(ns_.split(NS_SEP)) else: mode, ns = chunk.event, () + # prepend caller ns (as it is not passed to remote graph) if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS): caller_ns = tuple(caller_ns.split(NS_SEP)) ns = caller_ns + ns - if stream is not None and chunk.event in stream.modes: + # stream to parent stream + if stream is not None and mode in stream.modes: stream((ns, mode, chunk.data)) + # raise interrupt or errors if chunk.event.startswith("updates"): if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: raise GraphInterrupt(chunk.data[INTERRUPT]) - if not req_updates: - continue elif chunk.event.startswith("error"): raise RemoteException(chunk.data) - if chunk.event.split(NS_SEP, 1)[0] not in stream_modes: + # filter for what was actually requested + if mode not in requested: continue + # emit chunk if subgraphs: + if NS_SEP in chunk.event: + mode, ns_ = chunk.event.split(NS_SEP, 1) + ns = tuple(ns_.split(NS_SEP)) + else: + mode, ns = chunk.event, () if req_single: yield ns, chunk.data else: