Skip to content

Commit

Permalink
Fix issue when cond edge visited after multiple executions of Send
Browse files Browse the repository at this point in the history
- cond edge will run for each execution of Send, so target channels need to support multiple publishes
  • Loading branch information
nfcampos committed Jul 29, 2024
1 parent ea07193 commit 466cb8a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def branch_writer(packets: list[Union[str, Send]]) -> Optional[ChannelWrite]:
if branch.then and branch.then != END:
writes.append(
ChannelWriteEntry(
f"branch:{start}:{name}:then",
f"branch:{start}:{name}::then",
WaitForNames(
{p.node if isinstance(p, Send) else p for p in filtered}
),
Expand All @@ -630,12 +630,12 @@ def branch_writer(packets: list[Union[str, Send]]) -> Optional[ChannelWrite]:
for end in ends:
if end != END:
channel_name = f"branch:{start}:{name}:{end}"
self.channels[channel_name] = EphemeralValue(Any)
self.channels[channel_name] = EphemeralValue(Any, guard=False)
self.nodes[end].triggers.append(channel_name)

# attach then subscriber
if branch.then and branch.then != END:
channel_name = f"branch:{start}:{name}:then"
channel_name = f"branch:{start}:{name}::then"
self.channels[channel_name] = DynamicBarrierValue(str)
self.nodes[branch.then].triggers.append(channel_name)
for end in ends:
Expand Down
14 changes: 7 additions & 7 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,23 +1134,23 @@ def __init__(self, name: str):
setattr(self, "__name__", name)

def __call__(self, state):
return state + [self.name]
return [self.name]

def send_for_fun(state):
return [Send("2", state)]
return [Send("2", state), Send("2", state)]

def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(list)
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert graph.invoke(["0"]) == ["0", "1", "2", "3"]
assert graph.invoke(["0"]) == ["0", "1", "2", "2", "3"]


async def test_checkpointer_null_pending_writes() -> None:
Expand Down Expand Up @@ -6536,18 +6536,18 @@ class State(TypedDict):
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f",
"id": "9b590c54-15ef-54b1-83a7-140d27b0bc52",
"name": "finish",
"input": {"my_key": "value prepared slow", "market": "DE"},
"triggers": ["branch:prepare:condition:then"],
"triggers": ["branch:prepare:condition::then"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f",
"id": "9b590c54-15ef-54b1-83a7-140d27b0bc52",
"name": "finish",
"result": [("my_key", " finished")],
},
Expand Down
14 changes: 7 additions & 7 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,15 +1257,15 @@ def __init__(self, name: str):
setattr(self, "__name__", name)

async def __call__(self, state):
return state + [self.name]
return [self.name]

async def send_for_fun(state):
return [Send("2", state)]
return [Send("2", state), Send("2", state)]

async def route_to_three(state) -> Literal["3"]:
return "3"

builder = StateGraph(list)
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
Expand All @@ -1274,7 +1274,7 @@ async def route_to_three(state) -> Literal["3"]:
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()

assert await graph.ainvoke(["0"]) == ["0", "1", "2", "3"]
assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"]


async def test_invoke_checkpoint_aiosqlite(mocker: MockerFixture) -> None:
Expand Down Expand Up @@ -5119,18 +5119,18 @@ class State(TypedDict):
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f",
"id": "9b590c54-15ef-54b1-83a7-140d27b0bc52",
"name": "finish",
"input": {"my_key": "value prepared slow", "market": "DE"},
"triggers": ["branch:prepare:condition:then"],
"triggers": ["branch:prepare:condition::then"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": "ceada3c5-5f25-59e4-9ea5-544599ce1d2f",
"id": "9b590c54-15ef-54b1-83a7-140d27b0bc52",
"name": "finish",
"result": [("my_key", " finished")],
},
Expand Down

0 comments on commit 466cb8a

Please sign in to comment.