Skip to content

Commit

Permalink
Fix stream_mode=messages used together with subgraphs=True
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Oct 31, 2024
1 parent d6d6ab9 commit 4cc6d78
Showing 1 changed file with 40 additions and 28 deletions.
68 changes: 40 additions & 28 deletions libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,11 @@ async def aupdate_state(
def _get_stream_modes(
self,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]],
config: Optional[RunnableConfig],
default: StreamMode = "updates",
) -> tuple[list[StreamModeSDK], bool, bool]:
) -> tuple[
list[StreamModeSDK], list[StreamModeSDK], bool, Optional[StreamProtocol]
]:
"""Return a tuple of the final list of stream modes sent to the
remote graph and a boolean flag indicating if stream mode 'updates'
was present in the original list of stream modes.
Expand All @@ -514,7 +517,6 @@ def _get_stream_modes(
can be detected in the remote graph.
"""
updated_stream_modes: list[StreamModeSDK] = []
req_updates = False
req_single = True
# coerce to list, or add default stream mode
if stream_mode:
Expand All @@ -525,16 +527,21 @@ def _get_stream_modes(
updated_stream_modes.extend(stream_mode)
else:
updated_stream_modes.append(default)
requested_stream_modes = updated_stream_modes.copy()
# add any from parent graph
stream: Optional[StreamProtocol] = (
(config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM)
)
if stream:
updated_stream_modes.extend(stream.modes)
# map "messages" to "messages-tuple"
if "messages" in updated_stream_modes:
updated_stream_modes.remove("messages")
updated_stream_modes.append("messages-tuple")
# add 'updates' mode if not present
if "updates" in updated_stream_modes:
req_updates = True
else:
if "updates" not in updated_stream_modes:
updated_stream_modes.append("updates")
return (updated_stream_modes, req_updates, req_single)
return (updated_stream_modes, requested_stream_modes, req_single, stream)

def stream(
self,
Expand Down Expand Up @@ -566,44 +573,44 @@ def stream(
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)
stream: Optional[StreamProtocol] = (
(config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM)
)
stream_modes_ext: list[StreamModeSDK] = (
[*stream_modes, *stream.modes] if stream else stream_modes
stream_modes, requested, req_single, stream = self._get_stream_modes(
stream_mode, config
)

for chunk in sync_client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
stream_mode=stream_modes_ext,
stream_mode=stream_modes,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_subgraphs=subgraphs or stream is not None,
if_not_exists="create",
):
# split mode and ns
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
# prepend caller ns (as it is not passed to remote graph)
if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
caller_ns = tuple(caller_ns.split(NS_SEP))
ns = caller_ns + ns
if stream is not None and chunk.event in stream.modes:
# stream to parent stream
if stream is not None and mode in stream.modes:
stream((ns, mode, chunk.data))
# raise interrupt or errors
if chunk.event.startswith("updates"):
if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
raise GraphInterrupt(chunk.data[INTERRUPT])
if not req_updates:
continue
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
if chunk.event.split(NS_SEP, 1)[0] not in stream_modes:
# filter for what was actually requested
if mode not in requested:
continue
# emit chunk
if subgraphs:
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
Expand Down Expand Up @@ -649,45 +656,50 @@ async def astream(
client = self._validate_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)
stream: Optional[StreamProtocol] = (
(config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM)
)
stream_modes_ext: list[StreamModeSDK] = (
[*stream_modes, *stream.modes] if stream else stream_modes
stream_modes, requested, req_single, stream = self._get_stream_modes(
stream_mode, config
)

async for chunk in client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
config=sanitized_config,
stream_mode=stream_modes_ext,
stream_mode=stream_modes,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_subgraphs=subgraphs or stream is not None,
if_not_exists="create",
):
# split mode and ns
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
# prepend caller ns (as it is not passed to remote graph)
if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
caller_ns = tuple(caller_ns.split(NS_SEP))
ns = caller_ns + ns
if stream is not None and chunk.event in stream.modes:
# stream to parent stream
if stream is not None and mode in stream.modes:
stream((ns, mode, chunk.data))
# raise interrupt or errors
if chunk.event.startswith("updates"):
if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
raise GraphInterrupt(chunk.data[INTERRUPT])
if not req_updates:
continue
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
if chunk.event.split(NS_SEP, 1)[0] not in stream_modes:
# filter for what was actually requested
if mode not in requested:
continue
# emit chunk
if subgraphs:
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
if req_single:
yield ns, chunk.data
else:
Expand Down

0 comments on commit 4cc6d78

Please sign in to comment.