diff --git a/tutorials/tracing/langgraph_agent_tracing_tutorial.ipynb b/tutorials/tracing/langgraph_agent_tracing_tutorial.ipynb new file mode 100644 index 0000000000..f40c391475 --- /dev/null +++ b/tutorials/tracing/langgraph_agent_tracing_tutorial.ipynb @@ -0,0 +1,656 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f853e403eabd4f8", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "
\n", + "

\n", + " \"phoenix\n", + "
\n", + " Docs\n", + " |\n", + " GitHub\n", + " |\n", + " Community\n", + "

\n", + "
\n", + "

Tracing a LangGraph Agent

\n", + "\n", + "LangGraph provides tools to easily define a structured AI Agent. However, it can be challenging to understand what is going on under the hood and to pinpoint the cause of issues. Phoenix makes your LLM applications *observable* by visualizing the underlying structure of each call to your query engine and surfacing problematic \"spans\" of execution based on latency, token count, or other evaluation metrics.\n", + "\n", + "In this tutorial, you will:\n", + "- Build a simple SQL database agent using LangGraph,\n", + "- Record trace data in OpenInference format,\n", + "- Inspect the traces and spans of your application to identify sources of latency and cost\n", + "\n", + "ℹ️ This notebook requires an OpenAI API key." + ] + }, + { + "cell_type": "markdown", + "id": "b5a87813ffe7e4d2", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "## Setup\n", + "\n", + "First let's install our required packages and set our API keys" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a4be247", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -U langgraph langchain_openai langchain_community arize-phoenix openinference-instrumentation-langchain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c05a600f1afb5b6", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "\n", + "def _set_env(key: str):\n", + " if key not in os.environ:\n", + " os.environ[key] = getpass.getpass(f\"{key}:\")\n", + "\n", + "\n", + "_set_env(\"OPENAI_API_KEY\")" + ] + }, + { + "cell_type": "markdown", + "id": "877d8c85825089d8", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "## Configure the database\n", + "\n", + "We will be creating a SQLite database for this tutorial. SQLite is a lightweight database that is easy to set up and use. 
We will load the `Chinook` database, a sample database that represents a digital media store.\n", + "Find more information about the database [here](https://www.sqlitetutorial.net/sqlite-sample-database/).\n", + "\n", + "For convenience, we have hosted the database (`Chinook.db`) on a public GCS bucket." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64b0bf1b14c2e902", + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "url = \"https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db\"\n", + "\n", + "response = requests.get(url)\n", + "\n", + "if response.status_code == 200:\n", + "    # Open a local file in binary write mode\n", + "    with open(\"Chinook.db\", \"wb\") as file:\n", + "        # Write the content of the response (the file) to the local file\n", + "        file.write(response.content)\n", + "    print(\"File downloaded and saved as Chinook.db\")\n", + "else:\n", + "    print(f\"Failed to download the file. Status code: {response.status_code}\")" + ] + }, + { + "cell_type": "markdown", + "id": "61c8304aa5ceb6a5", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "We will use a handy SQL database wrapper available in the `langchain_community` package to interact with the database. The wrapper provides a simple interface to execute SQL queries and fetch results. We will also use the `langchain_openai` package to interact with the OpenAI API for language models later in the tutorial."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f1e1f4f86ed54", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.utilities import SQLDatabase\n", + "\n", + "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n", + "print(db.dialect)\n", + "print(db.get_usable_table_names())\n", + "db.run(\"SELECT * FROM Artist LIMIT 10;\")" + ] + }, + { + "cell_type": "markdown", + "id": "6959e93141d8099c", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "## Utility functions\n", + "\n", + "We will define a few utility functions to help us with the agent implementation. Specifically, we will wrap a `ToolNode` with a fallback to handle errors and surface them to the agent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "deae8460e4cf72b1", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "from langchain_core.messages import ToolMessage\n", + "from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks\n", + "from langgraph.prebuilt import ToolNode\n", + "\n", + "\n", + "def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:\n", + " \"\"\"\n", + " Create a ToolNode with a fallback to handle errors and surface them to the agent.\n", + " \"\"\"\n", + " return ToolNode(tools).with_fallbacks(\n", + " [RunnableLambda(handle_tool_error)], exception_key=\"error\"\n", + " )\n", + "\n", + "\n", + "def handle_tool_error(state) -> dict:\n", + " error = state.get(\"error\")\n", + " tool_calls = state[\"messages\"][-1].tool_calls\n", + " return {\n", + " \"messages\": [\n", + " ToolMessage(\n", + " content=f\"Error: {repr(error)}\\n please fix your mistakes.\",\n", + " tool_call_id=tc[\"id\"],\n", + " )\n", + " for tc in tool_calls\n", + " ]\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "d0196604f8cbb07b", + "metadata": { + "collapsed": false, + "jupyter": { + 
"outputs_hidden": false + } + }, + "source": [ + "## Define tools for the agent\n", + "\n", + "We will define a few tools that the agent will use to interact with the database.\n", + "\n", + "1. `list_tables_tool`: Fetch the available tables from the database\n", + "2. `get_schema_tool`: Fetch the DDL for a table\n", + "3. `db_query_tool`: Execute the query and fetch the results OR return an error message if the query fails\n", + "\n", + "For the first two tools, we will grab them from the `SQLDatabaseToolkit`, also available in the `langchain_community` package." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "452d049a3d2a4406", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n", + "from langchain_openai import ChatOpenAI\n", + "\n", + "toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model=\"gpt-4o\"))\n", + "tools = toolkit.get_tools()\n", + "\n", + "list_tables_tool = next(tool for tool in tools if tool.name == \"sql_db_list_tables\")\n", + "get_schema_tool = next(tool for tool in tools if tool.name == \"sql_db_schema\")\n", + "\n", + "print(list_tables_tool.invoke(\"\"))\n", + "\n", + "print(get_schema_tool.invoke(\"Artist\"))" + ] + }, + { + "cell_type": "markdown", + "id": "c16359edada327fa", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "The third tool, `db_query_tool`, will be defined manually: it executes a given query against the database and returns the result (or an error message if the query fails)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7eb708ecb4c7cfc", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.tools import tool\n", + "\n", + "\n", + "@tool\n", + "def db_query_tool(query: str) -> str:\n", + " \"\"\"\n", + " Execute a SQL query against the database and get back the result.\n", + " If the query is not correct, an error message will be returned.\n", + " If an error is returned, rewrite the query, check the query, and try again.\n", + " \"\"\"\n", + " result = db.run_no_throw(query)\n", + " if not result:\n", + " return \"Error: Query failed. Please rewrite your query and try again.\"\n", + " return result\n", + "\n", + "\n", + "print(db_query_tool.invoke(\"SELECT * FROM Artist LIMIT 10;\"))" + ] + }, + { + "cell_type": "markdown", + "id": "f1d66db8b8621639", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "While not strictly a tool, we will prompt an LLM to check for common mistakes in the query and later add this as a node in the workflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "293017e8f05ac2b3", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "query_check_system = \"\"\"You are a SQL expert with a strong attention to detail.\n", + "Double check the SQLite query for common mistakes, including:\n", + "- Using NOT IN with NULL values\n", + "- Using UNION when UNION ALL should have been used\n", + "- Using BETWEEN for exclusive ranges\n", + "- Data type mismatch in predicates\n", + "- Properly quoting identifiers\n", + "- Using the correct number of arguments for functions\n", + "- Casting to the correct data type\n", + "- Using the proper columns for joins\n", + "\n", + "If there are any of the above mistakes, rewrite the query. 
If there are no mistakes, just reproduce the original query.\n", + "\n", + "You will call the appropriate tool to execute the query after running this check.\"\"\"\n", + "\n", + "query_check_prompt = ChatPromptTemplate.from_messages(\n", + " [(\"system\", query_check_system), (\"placeholder\", \"{messages}\")]\n", + ")\n", + "query_check = query_check_prompt | ChatOpenAI(model=\"gpt-4o\", temperature=0).bind_tools(\n", + " [db_query_tool], tool_choice=\"required\"\n", + ")\n", + "\n", + "query_check.invoke({\"messages\": [(\"user\", \"SELECT * FROM Artist LIMIT 10;\")]})" + ] + }, + { + "cell_type": "markdown", + "id": "66f88452151e8188", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "## Define the workflow\n", + "\n", + "We will then define the workflow for the agent. The agent will first force-call the `list_tables_tool` to fetch the available tables from the database, then follow the steps mentioned at the beginning of the tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "2fd9e41c-95c3-47aa-9a12-80b78cc7ac2d", + "metadata": {}, + "source": [ + "
\n", + "

Using Pydantic with LangChain

\n", + "

\n", + " This notebook uses Pydantic v2 BaseModels, which require langchain-core >= 0.3. Using langchain-core < 0.3 will result in errors from mixing Pydantic v1 and v2 BaseModels.\n", + "

\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90d04ceea7b6b010", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Annotated, Literal\n", + "\n", + "from langchain_core.messages import AIMessage\n", + "from langchain_openai import ChatOpenAI\n", + "from langgraph.graph import END, START, StateGraph\n", + "from langgraph.graph.message import AnyMessage, add_messages\n", + "from pydantic import BaseModel, Field\n", + "from typing_extensions import TypedDict\n", + "\n", + "\n", + "# Define the state for the agent\n", + "class State(TypedDict):\n", + " messages: Annotated[list[AnyMessage], add_messages]\n", + "\n", + "\n", + "# Define a new graph\n", + "workflow = StateGraph(State)\n", + "\n", + "\n", + "# Add a node for the first tool call\n", + "def first_tool_call(state: State) -> dict[str, list[AIMessage]]:\n", + " return {\n", + " \"messages\": [\n", + " AIMessage(\n", + " content=\"\",\n", + " tool_calls=[\n", + " {\n", + " \"name\": \"sql_db_list_tables\",\n", + " \"args\": {},\n", + " \"id\": \"tool_abcd123\",\n", + " }\n", + " ],\n", + " )\n", + " ]\n", + " }\n", + "\n", + "\n", + "def model_check_query(state: State) -> dict[str, list[AIMessage]]:\n", + " \"\"\"\n", + " Use this tool to double-check if your query is correct before executing it.\n", + " \"\"\"\n", + " return {\"messages\": [query_check.invoke({\"messages\": [state[\"messages\"][-1]]})]}\n", + "\n", + "\n", + "workflow.add_node(\"first_tool_call\", first_tool_call)\n", + "\n", + "# Add nodes for the first two tools\n", + "workflow.add_node(\"list_tables_tool\", create_tool_node_with_fallback([list_tables_tool]))\n", + "workflow.add_node(\"get_schema_tool\", create_tool_node_with_fallback([get_schema_tool]))\n", + "\n", + "# Add a node for a model to choose the relevant tables based on the question and available tables\n", + "model_get_schema = ChatOpenAI(model=\"gpt-4o\", temperature=0).bind_tools([get_schema_tool])\n", + "workflow.add_node(\n", + 
" \"model_get_schema\",\n", + " lambda state: {\n", + " \"messages\": [model_get_schema.invoke(state[\"messages\"])],\n", + " },\n", + ")\n", + "\n", + "\n", + "# Describe a tool to represent the end state\n", + "class SubmitFinalAnswer(BaseModel):\n", + " \"\"\"Submit the final answer to the user based on the query results.\"\"\"\n", + "\n", + " final_answer: str = Field(..., description=\"The final answer to the user\")\n", + "\n", + "\n", + "# Add a node for a model to generate a query based on the question and schema\n", + "query_gen_system = \"\"\"You are a SQL expert with a strong attention to detail.\n", + "\n", + "Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n", + "\n", + "DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.\n", + "\n", + "When generating the query:\n", + "\n", + "Output the SQL query that answers the input question without a tool call.\n", + "\n", + "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n", + "You can order the results by a relevant column to return the most interesting examples in the database.\n", + "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", + "\n", + "If you get an error while executing a query, rewrite the query and try again.\n", + "\n", + "If you get an empty result set, you should try to rewrite the query to get a non-empty result set.\n", + "NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.\n", + "\n", + "If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.\n", + "\n", + "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) 
to the database.\"\"\"\n", + "query_gen_prompt = ChatPromptTemplate.from_messages(\n", + " [(\"system\", query_gen_system), (\"placeholder\", \"{messages}\")]\n", + ")\n", + "query_gen = query_gen_prompt | ChatOpenAI(model=\"gpt-4o\", temperature=0).bind_tools(\n", + " [SubmitFinalAnswer]\n", + ")\n", + "\n", + "\n", + "def query_gen_node(state: State):\n", + " message = query_gen.invoke(state)\n", + "\n", + " # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.\n", + " tool_messages = []\n", + " if message.tool_calls:\n", + " for tc in message.tool_calls:\n", + " if tc[\"name\"] != \"SubmitFinalAnswer\":\n", + " tool_messages.append(\n", + " ToolMessage(\n", + " content=f\"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.\",\n", + " tool_call_id=tc[\"id\"],\n", + " )\n", + " )\n", + " else:\n", + " tool_messages = []\n", + " return {\"messages\": [message] + tool_messages}\n", + "\n", + "\n", + "workflow.add_node(\"query_gen\", query_gen_node)\n", + "\n", + "# Add a node for the model to check the query before executing it\n", + "workflow.add_node(\"correct_query\", model_check_query)\n", + "\n", + "# Add node for executing the query\n", + "workflow.add_node(\"execute_query\", create_tool_node_with_fallback([db_query_tool]))\n", + "\n", + "\n", + "# Define a conditional edge to decide whether to continue or end the workflow\n", + "def should_continue(state: State) -> Literal[END, \"correct_query\", \"query_gen\"]:\n", + " messages = state[\"messages\"]\n", + " last_message = messages[-1]\n", + " # If there is a tool call, then we finish\n", + " if getattr(last_message, \"tool_calls\", None):\n", + " return END\n", + " if last_message.content.startswith(\"Error:\"):\n", + " return \"query_gen\"\n", + " else:\n", + " return \"correct_query\"\n", + "\n", + 
"\n", + "# Specify the edges between the nodes\n", + "workflow.add_edge(START, \"first_tool_call\")\n", + "workflow.add_edge(\"first_tool_call\", \"list_tables_tool\")\n", + "workflow.add_edge(\"list_tables_tool\", \"model_get_schema\")\n", + "workflow.add_edge(\"model_get_schema\", \"get_schema_tool\")\n", + "workflow.add_edge(\"get_schema_tool\", \"query_gen\")\n", + "workflow.add_conditional_edges(\n", + " \"query_gen\",\n", + " should_continue,\n", + ")\n", + "workflow.add_edge(\"correct_query\", \"execute_query\")\n", + "workflow.add_edge(\"execute_query\", \"query_gen\")\n", + "\n", + "# Compile the workflow into a runnable\n", + "app = workflow.compile()" + ] + }, + { + "cell_type": "markdown", + "id": "6c344ae086ba8d22", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "## Visualize the graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f200d1813897000", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image, display\n", + "from langchain_core.runnables.graph import MermaidDrawMethod\n", + "\n", + "display(\n", + " Image(\n", + " app.get_graph().draw_mermaid_png(\n", + " draw_method=MermaidDrawMethod.API,\n", + " )\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7679154e", + "metadata": {}, + "source": [ + "### Auto-Instrument Agent\n", + "\n", + "This example uses a hosted Phoenix instance. If you don't have one already, create one for free [here](https://phoenix.arize.com) to get your API key.\n", + "\n", + "If you'd rather self-host Phoenix, you can do so by following the instructions [here](https://docs.arize.com/phoenix/deployment)." 
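, + "\n", + "If you go the self-hosted route, the only change needed is the collector endpoint that the instrumentation below uses. A minimal sketch, assuming a local Phoenix instance running on its default port (6006):\n", + "\n", + "```python\n", + "import os\n", + "\n", + "# Point the OpenTelemetry exporter at a self-hosted Phoenix collector\n", + "# instead of the hosted app (adjust host/port to your deployment)\n", + "os.environ[\"PHOENIX_COLLECTOR_ENDPOINT\"] = \"http://localhost:6006\"\n", + "```"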
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "294df67b", + "metadata": {}, + "outputs": [], + "source": [ + "_set_env(\"PHOENIX_API_KEY\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fb84dee", + "metadata": {}, + "outputs": [], + "source": [ + "from openinference.instrumentation.langchain import LangChainInstrumentor\n", + "\n", + "from phoenix.otel import register\n", + "\n", + "os.environ[\"PHOENIX_COLLECTOR_ENDPOINT\"] = \"https://app.phoenix.arize.com\"\n", + "os.environ[\"PHOENIX_CLIENT_HEADERS\"] = f\"api_key={os.getenv('PHOENIX_API_KEY')}\"\n", + "\n", + "tracer_provider = register()\n", + "LangChainInstrumentor().instrument(tracer_provider=tracer_provider)" + ] + }, + { + "cell_type": "markdown", + "id": "bdf78dc68548522c", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "## Run the agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85958809-03c5-4e52-97cc-e7c0ae986f60", + "metadata": {}, + "outputs": [], + "source": [ + "messages = app.invoke({\"messages\": [(\"user\", \"Which sales agent made the most in sales in 2009?\")]})\n", + "# Extract the final answer submitted via the SubmitFinalAnswer tool call\n", + "final_answer = messages[\"messages\"][-1].tool_calls[0][\"args\"][\"final_answer\"]\n", + "final_answer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bf7709f-500c-4f28-bb85-dda317286c63", + "metadata": {}, + "outputs": [], + "source": [ + "for event in app.stream(\n", + "    {\"messages\": [(\"user\", \"Which sales agent made the most in sales in 2009?\")]}\n", + "):\n", + "    print(event)" + ] + }, + { + "cell_type": "markdown", + "id": "b50378cc", + "metadata": {}, + "source": [ + "## Next Steps\n", + "\n", + "You should now see traces in the [Phoenix UI](https://app.phoenix.arize.com)!\n", + "\n", + "These traces contain information about the execution of the agent, including the tool calls, the LLM calls, and the overall execution path.\n", + "\n", + "From here you can
use the traces you've captured to:\n", + "- Visualize the execution path of your agent and track token usage, latency, and errors\n", + "- [Evaluate various metrics of your agent's execution](https://docs.arize.com/phoenix/evaluation/how-to-evals/evaluating-phoenix-traces), such as the quality of its tool outputs and overall path\n", + "- [Add annotations](https://docs.arize.com/phoenix/tracing/how-to-tracing/capture-feedback#send-annotations-to-phoenix) to traces to capture feedback or additional context on each execution\n", + "- Use your annotations and evaluations to [create datasets](https://docs.arize.com/phoenix/datasets-and-experiments/how-to-datasets/creating-datasets) to power [experiments](https://docs.arize.com/phoenix/datasets-and-experiments/how-to-experiments/run-experiments), or export for [fine-tuning](https://docs.arize.com/phoenix/datasets-and-experiments/how-to-datasets/exporting-datasets)." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}