From 25c509651fed672b933974a05f3c908e8339721e Mon Sep 17 00:00:00 2001 From: vbarda Date: Mon, 30 Sep 2024 17:50:28 -0400 Subject: [PATCH 1/2] langgraph: fix edge case with string enums as node names --- libs/langgraph/langgraph/pregel/algo.py | 3 ++- libs/langgraph/tests/test_pregel.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index daa942a39..0cb81b710 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -14,6 +14,7 @@ Sequence, Union, overload, + cast ) from uuid import UUID @@ -471,7 +472,7 @@ def prepare_single_task( else: return PregelTask(task_id, packet.node, task_path) elif task_path[0] == PULL: - name = str(task_path[1]) + name = cast(str, task_path[1]) if name not in processes: return proc = processes[name] diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 0bd5772d4..531a88cf6 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -7,6 +7,7 @@ from collections import Counter from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager +import enum from random import randrange from typing import ( Annotated, @@ -11517,3 +11518,23 @@ def node(input: State, config: RunnableConfig, store: BaseStore): "some_val": 0, } # Overwrites the whole doc assert len(the_store.search(("foo", "bar"))) == 1 # still overwriting the same one + + +def test_enum_node_names(): + class NodeName(str, enum.Enum): + BAZ = "baz" + + class State(TypedDict): + foo: str + bar: str + + def baz(state: State): + return {"bar": state["foo"] + "!"} + + graph = StateGraph(State) + graph.add_node(NodeName.BAZ, baz) + graph.add_edge(START, NodeName.BAZ) + graph.add_edge(NodeName.BAZ, END) + graph = graph.compile() + + assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"} \ No newline at end of file From 96dc4952484ebac07b3f8e1f4239bce58418cd9a Mon Sep 17 00:00:00 2001 From: vbarda Date: Mon, 30 Sep 2024 17:52:16 -0400 Subject: [PATCH 2/2] lint --- libs/langgraph/langgraph/pregel/algo.py | 2 +- libs/langgraph/tests/test_pregel.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 0cb81b710..98ac8576f 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -13,8 +13,8 @@ Protocol, Sequence, Union, + cast, overload, - cast ) from uuid import UUID diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 531a88cf6..29e287da2 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1,3 +1,4 @@ +import enum import json import operator import re @@ -7,7 +8,6 @@ from collections import Counter from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -import enum from random import randrange from typing import ( Annotated, @@ -11537,4 +11537,4 @@ def baz(state: State): graph.add_edge(NodeName.BAZ, END) graph = graph.compile() - assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"} \ No newline at end of file + assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"}