From 466cb8acb5168912d28aad2f28a53a458fc8ff12 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 29 Jul 2024 12:38:48 -0700 Subject: [PATCH] Fix issue when cond edge visited after multiple executions of Send - cond edge will run for each execution of Send, so target channels need to support multiple publishes --- libs/langgraph/langgraph/graph/state.py | 6 +++--- libs/langgraph/tests/test_pregel.py | 14 +++++++------- libs/langgraph/tests/test_pregel_async.py | 14 +++++++------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 0657e5654..b67d4e9d1 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -610,7 +610,7 @@ def branch_writer(packets: list[Union[str, Send]]) -> Optional[ChannelWrite]: if branch.then and branch.then != END: writes.append( ChannelWriteEntry( - f"branch:{start}:{name}:then", + f"branch:{start}:{name}::then", WaitForNames( {p.node if isinstance(p, Send) else p for p in filtered} ), @@ -630,12 +630,12 @@ def branch_writer(packets: list[Union[str, Send]]) -> Optional[ChannelWrite]: for end in ends: if end != END: channel_name = f"branch:{start}:{name}:{end}" - self.channels[channel_name] = EphemeralValue(Any) + self.channels[channel_name] = EphemeralValue(Any, guard=False) self.nodes[end].triggers.append(channel_name) # attach then subscriber if branch.then and branch.then != END: - channel_name = f"branch:{start}:{name}:then" + channel_name = f"branch:{start}:{name}::then" self.channels[channel_name] = DynamicBarrierValue(str) self.nodes[branch.then].triggers.append(channel_name) for end in ends: diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 17d8d66d2..87a73647e 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1134,15 +1134,15 @@ def __init__(self, name: str): setattr(self, "__name__", name) def __call__(self, state): - return state + [self.name] + return [self.name] def send_for_fun(state): - return [Send("2", state)] + return [Send("2", state), Send("2", state)] def route_to_three(state) -> Literal["3"]: return "3" - builder = StateGraph(list) + builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) @@ -1150,7 +1150,7 @@ def route_to_three(state) -> Literal["3"]: builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) graph = builder.compile() - assert graph.invoke(["0"]) == ["0", "1", "2", "3"] + assert graph.invoke(["0"]) == ["0", "1", "2", "2", "3"] async def test_checkpointer_null_pending_writes() -> None: @@ -6536,10 +6536,10 @@ class State(TypedDict): "timestamp": AnyStr(), "step": 3, "payload": { - "id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f", + "id": "9b590c54-15ef-54b1-83a7-140d27b0bc52", "name": "finish", "input": {"my_key": "value prepared slow", "market": "DE"}, - "triggers": ["branch:prepare:condition:then"], + "triggers": ["branch:prepare:condition::then"], }, }, { @@ -6547,7 +6547,7 @@ class State(TypedDict): "timestamp": AnyStr(), "step": 3, "payload": { - "id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f", + "id": "9b590c54-15ef-54b1-83a7-140d27b0bc52", "name": "finish", "result": [("my_key", " finished")], }, diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 84d597c43..0aeb1546a 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1257,15 +1257,15 @@ def __init__(self, name: str): setattr(self, "__name__", name) async def __call__(self, state): - return state + [self.name] + return [self.name] async def send_for_fun(state): - return [Send("2", state)] + return [Send("2", state), Send("2", state)] async def route_to_three(state) -> Literal["3"]: return "3" - builder = StateGraph(list) + builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) @@ -1274,7 +1274,7 @@ async def route_to_three(state) -> Literal["3"]: builder.add_conditional_edges("2", route_to_three) graph = builder.compile() - assert await graph.ainvoke(["0"]) == ["0", "1", "2", "3"] + assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"] async def test_invoke_checkpoint_aiosqlite(mocker: MockerFixture) -> None: @@ -5119,10 +5119,10 @@ class State(TypedDict): "timestamp": AnyStr(), "step": 3, "payload": { - "id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f", + "id": "9b590c54-15ef-54b1-83a7-140d27b0bc52", "name": "finish", "input": {"my_key": "value prepared slow", "market": "DE"}, - "triggers": ["branch:prepare:condition:then"], + "triggers": ["branch:prepare:condition::then"], }, }, { @@ -5130,7 +5130,7 @@ class State(TypedDict): "timestamp": AnyStr(), "step": 3, "payload": { - "id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f", + "id": "9b590c54-15ef-54b1-83a7-140d27b0bc52", "name": "finish", "result": [("my_key", " finished")], },