Routing function does not appear to work with custom state #1146
-
import os
import asyncio
from typing import Annotated, Literal, List
from typing_extensions import TypedDict
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from functools import partial
from langchain_core.runnables import RunnableConfig
class State(TypedDict):
messages: Annotated[list, add_messages]
status: str = ""
# Define the function that calls the model
async def call_model(state: State, model: ChatOpenAI):
print("Status: ", state["status"])
messages = state["messages"]
response = await model.ainvoke(messages)
return {"messages": [response]}
async def run_app(app, inputs, config = None):
if config is None:
config = RunnableConfig(stream_mode="updates", recursion_limit=8)
all_values = [inputs]
async for output in app.astream(inputs, config):
# stream_mode="updates" yields dictionaries with output keyed by node name
for key, value in output.items():
print(f"Output from node '{key}':")
print("---")
print(value["messages"][-1].pretty_print())
all_values.append(value)
print("\n---\n")
print("All values:")
print("---")
print(all_values)
def route(state: State) -> Literal["call_tools", "continue"]:
messages = state["messages"]
print("Status: ", state["status"]) # I get: KeyError: 'status'
last_message = messages[-1]
if last_message.tool_calls:
return "call_tools"
else:
return "continue"
@tool
def multiply(a: int, b: int) -> int:
"""Multiply two numbers"""
return a * b
if __name__ == "__main__":
openai_api_key = os.getenv("OPENAI_API_KEY")
model = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key)
tools = [multiply]
tool_node = ToolNode(tools)
if len(tools) > 0:
model = model.bind_tools(tools)
workflow = StateGraph(State)
# Define the nodes we will cycle between
workflow.add_node("agent", partial(call_model, model=model))
workflow.add_node("action", tool_node)
workflow.add_edge(START, "agent")
workflow.add_edge("action", "agent")
workflow.add_conditional_edges("agent", route, {
"call_tools": "action",
"continue": END})
app = workflow.compile()
inputs = {"messages": [
("user",
"What is 1 + 1 = 2?")]}
config = RunnableConfig(stream_mode="updates", recursion_limit=8)
all_values = [inputs]
asyncio.run(run_app(app, inputs, config)) In this example, I tried to define a custom state, that has a field called "status". I am able to print its value inside call_model, but when I try to access its value in route, I see that the status field has disappeared and my state only has "messages". Did I do something incorrectly? Thanks. |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
@fedshyvana the issue is that you're trying to access the value for so in your example, you can verify that it works by changing input to inputs = {"messages": [
("user",
"What is 1 + 1 = 2?")], "status": "my_status"} hope this helps! |
Beta Was this translation helpful? Give feedback.
@fedshyvana the issue is that you're trying to access the value for
status
key but it's never actually set. TypedDict doesn't support defaults, so the assignment in the state schema won't work. so you need to either pass status when invoking / streaming from the graph (i.e. passstatus
as part of the initial state the graph receives) or a node in the graph needs to send an update forstatus
key before it's accessed by a subsequent nodeso in your example, you can verify that it works by changing input to
hope this helps!