Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph: fix edge case with string enums as node names #1926

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import defaultdict, deque

Check notice on line 1 in libs/langgraph/langgraph/pregel/algo.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 59.1 ms +- 1.3 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (7.09 ms) is 12% of the mean (57.3 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. fanout_to_subgraph_10x_sync: Mean +- std dev: 57.3 ms +- 7.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 77.0 ms +- 1.5 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 81.5 ms +- 0.8 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 549 ms +- 8 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 503 ms +- 5 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 779 ms +- 26 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 790 ms +- 7 ms ......................................... react_agent_10x: Mean +- std dev: 41.9 ms +- 3.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 29.7 ms +- 0.3 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 52.9 ms +- 1.4 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 42.7 ms +- 3.3 ms ......................................... react_agent_100x: Mean +- std dev: 411 ms +- 8 ms ......................................... react_agent_100x_sync: Mean +- std dev: 330 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 919 ms +- 12 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 816 ms +- 10 ms ......................................... wide_state_25x300: Mean +- std dev: 20.5 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 12.8 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 237 ms +- 7 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 237 ms +- 14 ms ......................................... wide_state_15x600: Mean +- std dev: 23.7 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 14.8 ms +- 0.3 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 414 ms +- 13 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 414 ms +- 19 ms ......................................... wide_state_9x1200: Mean +- std dev: 23.7 ms +- 0.3 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 14.8 ms +- 0.3 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 268 ms +- 7 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 268 ms +- 15 ms

Check notice on line 1 in libs/langgraph/langgraph/pregel/algo.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +========================================+=========+=======================+ | react_agent_100x_checkpoint_sync | 832 ms | 816 ms: 1.02x faster | +----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 241 ms | 237 ms: 1.02x faster | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint_sync | 420 ms | 414 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 929 ms | 919 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 419 ms | 414 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 554 ms | 549 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 53.3 ms | 52.9 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 270 ms | 268 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 332 ms | 330 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 29.9 ms | 29.7 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 14.9 ms | 14.8 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 81.8 ms | 81.5 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 768 ms | 779 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x faster | +----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (15): wide_state_9x1200_checkpoint_sync, wide_state_25x300_checkpoint_sync, react_agent_10x_checkpoint_sync, fanout_to_subgraph_10x, fanout_to_subgraph_10x_checkpoint, react_agent_100x, wide_state_15x600, wide_state_9x1200_sync, wide_state_25x300_sync, fanout_to_subgraph_100x_checkpoint_sync, wide_state_9x1200, fanout_to_subgraph_100x_sync, wide_state_25x300, react_agent_10x, fanout_to_subgraph_10x_sync
from functools import partial
from hashlib import sha1
from typing import (
Expand All @@ -13,6 +13,7 @@
Protocol,
Sequence,
Union,
cast,
overload,
)
from uuid import UUID
Expand Down Expand Up @@ -471,7 +472,7 @@
else:
return PregelTask(task_id, packet.node, task_path)
elif task_path[0] == PULL:
name = str(task_path[1])
name = cast(str, task_path[1])
if name not in processes:
return
proc = processes[name]
Expand Down
21 changes: 21 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import json
import operator
import re
Expand Down Expand Up @@ -11517,3 +11518,23 @@ def node(input: State, config: RunnableConfig, store: BaseStore):
"some_val": 0,
} # Overwrites the whole doc
assert len(the_store.search(("foo", "bar"))) == 1 # still overwriting the same one


def test_enum_node_names():
class NodeName(str, enum.Enum):
BAZ = "baz"

class State(TypedDict):
foo: str
bar: str

def baz(state: State):
return {"bar": state["foo"] + "!"}

graph = StateGraph(State)
graph.add_node(NodeName.BAZ, baz)
graph.add_edge(START, NodeName.BAZ)
graph.add_edge(NodeName.BAZ, END)
graph = graph.compile()

assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"}
Loading