Skip to content

Commit

Permalink
langgraph: fix edge case with string enums as node names (#1926)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Sep 30, 2024
1 parent 6d1d705 commit da935e7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Protocol,
Sequence,
Union,
cast,
overload,
)
from uuid import UUID
Expand Down Expand Up @@ -471,7 +472,7 @@ def prepare_single_task(
else:
return PregelTask(task_id, packet.node, task_path)
elif task_path[0] == PULL:
name = str(task_path[1])
name = cast(str, task_path[1])
if name not in processes:
return
proc = processes[name]
Expand Down
21 changes: 21 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import json
import operator
import re
Expand Down Expand Up @@ -11517,3 +11518,23 @@ def node(input: State, config: RunnableConfig, store: BaseStore):
"some_val": 0,
} # Overwrites the whole doc
assert len(the_store.search(("foo", "bar"))) == 1 # still overwriting the same one


def test_enum_node_names():
class NodeName(str, enum.Enum):
BAZ = "baz"

class State(TypedDict):
foo: str
bar: str

def baz(state: State):
return {"bar": state["foo"] + "!"}

graph = StateGraph(State)
graph.add_node(NodeName.BAZ, baz)
graph.add_edge(START, NodeName.BAZ)
graph.add_edge(NodeName.BAZ, END)
graph = graph.compile()

assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"}

0 comments on commit da935e7

Please sign in to comment.