diff --git a/docs/_static/imgs/_langgraph_agent_evaluation_28_0.jpg b/docs/_static/imgs/_langgraph_agent_evaluation_28_0.jpg new file mode 100644 index 000000000..e580db4d4 Binary files /dev/null and b/docs/_static/imgs/_langgraph_agent_evaluation_28_0.jpg differ diff --git a/docs/howtos/integrations/_langgraph_agent_evaluation.md b/docs/howtos/integrations/_langgraph_agent_evaluation.md new file mode 100644 index 000000000..8ffa47a10 --- /dev/null +++ b/docs/howtos/integrations/_langgraph_agent_evaluation.md @@ -0,0 +1,424 @@ +# Building and Evaluating a ReAct Agent for Fetching Metal Prices + +AI agents are becoming increasingly valuable in domains like finance, e-commerce, and customer support. These agents can autonomously interact with APIs, retrieve real-time data, and perform tasks that align with user goals. Evaluating these agents is crucial to ensure they are effective, accurate, and responsive to different inputs. + +In this tutorial, we'll: + +1. Build a [ReAct agent](https://arxiv.org/abs/2210.03629) to fetch metal prices. +2. Set up an evaluation pipeline to track key performance metrics. +3. Run and assess the agent's effectiveness with different queries. + +Click the [link](https://colab.research.google.com/github/explodinggradients/ragas/blob/main/docs/howtos/integrations/langgraph_agent_evaluation.ipynb) to open the notebook in Google Colab. + +## Prerequisites +- Python 3.8+ +- Basic understanding of LangGraph, LangChain and LLMs + +## Installing Ragas and Other Dependencies +Install Ragas and Langgraph with pip: + + +```python +%pip install langgraph==0.2.44 +%pip install ragas +%pip install nltk +``` + +## Building the ReAct Agent + +### Initializing External Components +To begin, you have two options for setting up the external components: + +1. Use a Live API Key: + + - Sign up for an account on [metals.dev](https://metals.dev/) to get your API key. + +2. Simulate the API Response: + + - Alternatively, you can use a predefined JSON object to simulate the API response. This allows you to get started more quickly without needing a live API key. + + +Choose the method that best fits your needs to proceed with the setup. + +### Predefined JSON Object to simulate API response +If you would like to quickly get started without creating an account, you can bypass the setup process and use the predefined JSON object given below that simulates the API response. + + +```python +metal_price = { + "gold": 88.1553, + "silver": 1.0523, + "platinum": 32.169, + "palladium": 35.8252, + "lbma_gold_am": 88.3294, + "lbma_gold_pm": 88.2313, + "lbma_silver": 1.0545, + "lbma_platinum_am": 31.99, + "lbma_platinum_pm": 32.2793, + "lbma_palladium_am": 36.0088, + "lbma_palladium_pm": 36.2017, + "mcx_gold": 93.2689, + "mcx_gold_am": 94.281, + "mcx_gold_pm": 94.1764, + "mcx_silver": 1.125, + "mcx_silver_am": 1.1501, + "mcx_silver_pm": 1.1483, + "ibja_gold": 93.2713, + "copper": 0.0098, + "aluminum": 0.0026, + "lead": 0.0021, + "nickel": 0.0159, + "zinc": 0.0031, + "lme_copper": 0.0096, + "lme_aluminum": 0.0026, + "lme_lead": 0.002, + "lme_nickel": 0.0158, + "lme_zinc": 0.0031, +} +``` + +### Define the get_metal_price Tool + +The get_metal_price tool will be used by the agent to fetch the price of a specified metal. We'll create this tool using the @tool decorator from LangChain. + +If you want to use real-time data from the metals.dev API, you can modify the function to make a live request to the API. 
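+
+For reference, a live variant of the lookup might look roughly like the sketch below. This is not part of the tutorial's code path: it assumes the metals.dev `/v1/latest` endpoint, an `API_KEY` placeholder you supply, and a response that nests prices under a `"metals"` key, so check the metals.dev documentation for the exact parameters and response shape.
+
+
+```python
+import requests
+
+API_KEY = "your-metals-dev-api-key"  # hypothetical placeholder
+
+
+def fetch_live_metal_prices() -> dict:
+    """Fetch current per-gram metal prices in USD from metals.dev (sketch)."""
+    response = requests.get(
+        "https://api.metals.dev/v1/latest",
+        params={"api_key": API_KEY, "currency": "USD", "unit": "g"},
+        timeout=10,
+    )
+    response.raise_for_status()
+    # Assumed response shape: {"metals": {"gold": ..., "silver": ..., ...}}
+    return response.json()["metals"]
+```
+
+The rest of this tutorial sticks with the simulated `metal_price` dictionary defined above, wrapped in the tool below.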
+
+
+```python
+from langchain_core.tools import tool
+
+
+# Define the tools for the agent to use
+@tool
+def get_metal_price(metal_name: str) -> float:
+    """Fetches the current per gram price of the specified metal.
+
+    Args:
+        metal_name : The name of the metal (e.g., 'gold', 'silver', 'platinum').
+
+    Returns:
+        float: The current price of the metal in dollars per gram.
+
+    Raises:
+        KeyError: If the specified metal is not found in the data source.
+    """
+    try:
+        metal_name = metal_name.lower().strip()
+        if metal_name not in metal_price:
+            raise KeyError(
+                f"Metal '{metal_name}' not found. Available metals: {', '.join(metal_price.keys())}"
+            )
+        return metal_price[metal_name]
+    except KeyError:
+        raise
+    except Exception as e:
+        raise Exception(f"Error fetching metal price: {str(e)}")
+```
+
+### Binding the Tool to the LLM
+With the get_metal_price tool defined, the next step is to bind it to the ChatOpenAI model. This enables the agent to invoke the tool during its execution based on the user's requests, allowing it to interact with external data and perform actions beyond its native capabilities.
+
+
+```python
+from langchain_openai import ChatOpenAI
+
+tools = [get_metal_price]
+llm = ChatOpenAI(model="gpt-4o-mini")
+llm_with_tools = llm.bind_tools(tools)
+```
+
+In LangGraph, state plays a crucial role in tracking and updating information as the graph executes. As different parts of the graph run, the state evolves to reflect the changes and contains the information that is passed between nodes.
+
+For example, in a conversational system like this one, the state is used to track the exchanged messages. Each time a new message is generated, it is added to the state, and the updated state is passed through the nodes, ensuring the conversation progresses logically.
+
+### Defining the State
+To implement this in LangGraph, we define a state class that maintains a list of messages. Whenever a new message is produced, it is appended to this list, ensuring that the conversation history is continuously updated.
+
+
+```python
+from langgraph.graph import END
+from langchain_core.messages import AnyMessage
+from langgraph.graph.message import add_messages
+from typing import Annotated
+from typing_extensions import TypedDict
+
+
+class GraphState(TypedDict):
+    messages: Annotated[list[AnyMessage], add_messages]
+```
+
+### Defining the should_continue Function
+The `should_continue` function determines whether the conversation should proceed with further tool interactions or end. Specifically, it checks if the last message contains any tool calls (e.g., a request for metal prices).
+
+- If the last message includes tool calls, indicating that the agent has invoked an external tool, the conversation continues and moves to the "tools" node.
+- If there are no tool calls, the conversation ends, represented by the END state.
+
+
+```python
+# Define the function that determines whether to continue or not
+def should_continue(state: GraphState):
+    messages = state["messages"]
+    last_message = messages[-1]
+    if last_message.tool_calls:
+        return "tools"
+    return END
+```
+
+### Calling the Model
+The `call_model` function interacts with the LLM to generate a response based on the current state of the conversation. It takes the updated state as input, processes it, and returns a model-generated response. (It is functionally equivalent to the `assistant` node defined next; only `assistant` is wired into the graph below.)
+
+
+```python
+# Define the function that calls the model
+def call_model(state: GraphState):
+    messages = state["messages"]
+    response = llm_with_tools.invoke(messages)
+    return {"messages": [response]}
+```
+
+### Creating the Assistant Node
+The `assistant` node is responsible for processing the current state of the conversation and using the LLM to generate a relevant response. It evaluates the state, determines the appropriate course of action, and invokes the LLM to produce a response that aligns with the ongoing dialogue.
+
+
+```python
+# Node
+def assistant(state: GraphState):
+    response = llm_with_tools.invoke(state["messages"])
+    return {"messages": [response]}
+```
+
+### Creating the Tool Node
+The `tool_node` is responsible for managing interactions with external tools, such as fetching metal prices or performing other actions beyond the LLM's native capabilities. The tools themselves were defined earlier, and the tool_node invokes them based on the tool calls present in the current state.
+
+
+```python
+from langgraph.prebuilt import ToolNode
+
+# Node
+tools = [get_metal_price]
+tool_node = ToolNode(tools)
+```
+
+### Building the Graph
+The graph structure is the backbone of the agentic workflow, consisting of interconnected nodes and edges. To construct this graph, we use the StateGraph builder, which allows us to define and connect the various nodes. Each node represents a step in the process (e.g., the assistant node or the tool node), and the edges dictate the flow of execution between these steps.
+
+
+```python
+from langgraph.graph import START, StateGraph
+from IPython.display import Image, display
+
+# Define a new graph for the agent
+builder = StateGraph(GraphState)
+
+# Define the two nodes we will cycle between
+builder.add_node("assistant", assistant)
+builder.add_node("tools", tool_node)
+
+# Set the entrypoint as `assistant`
+builder.add_edge(START, "assistant")
+
+# Make a conditional edge:
+# should_continue determines which node is called next.
+builder.add_conditional_edges("assistant", should_continue, ["tools", END])
+
+# Make a normal edge from `tools` back to `assistant`:
+# the `assistant` node is called again after the tools run.
+builder.add_edge("tools", "assistant")
+
+# Compile and display the graph for a visual overview
+react_graph = builder.compile()
+display(Image(react_graph.get_graph(xray=True).draw_mermaid_png()))
+```
+
+
+
+![jpeg](../../_static/imgs/_langgraph_agent_evaluation_28_0.jpg)
+
+
+
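+As an aside, LangGraph also ships a prebuilt helper that assembles essentially the same ReAct loop in one call. The snippet below is a sketch assuming `create_react_agent` from `langgraph.prebuilt` (available in recent LangGraph releases); this tutorial keeps the explicit `StateGraph` above so that each moving part stays visible.
+
+
+```python
+from langgraph.prebuilt import create_react_agent
+
+# Prebuilt equivalent of the hand-assembled graph (sketch).
+prebuilt_react_graph = create_react_agent(llm, [get_metal_price])
+
+# It is invoked the same way, with a dict containing a "messages" list, e.g.:
+# prebuilt_react_graph.invoke({"messages": [("user", "What is the price of gold?")]})
+```
+
+Returning to the hand-built graph: to test our setup, we will run the agent with a query. The agent will look up the price of copper using the `get_metal_price` tool (backed here by the simulated `metal_price` data).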
+ + +```python +from langchain_core.messages import HumanMessage + +messages = [HumanMessage(content="What is the price of copper?")] +result = react_graph.invoke({"messages": messages}) +``` + + +```python +result["messages"] +``` + + + + + [HumanMessage(content='What is the price of copper?', id='4122f5d4-e298-49e8-a0e0-c98adda78c6c'), + AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_DkVQBK4UMgiXrpguUS2qC4mA', 'function': {'arguments': '{"metal_name":"copper"}', 'name': 'get_metal_price'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 116, 'total_tokens': 134, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-0f77b156-e43e-4c1e-bd3a-307333eefb68-0', tool_calls=[{'name': 'get_metal_price', 'args': {'metal_name': 'copper'}, 'id': 'call_DkVQBK4UMgiXrpguUS2qC4mA', 'type': 'tool_call'}], usage_metadata={'input_tokens': 116, 'output_tokens': 18, 'total_tokens': 134}), + ToolMessage(content='0.0098', name='get_metal_price', id='422c089a-6b76-4e48-952f-8925c3700ae3', tool_call_id='call_DkVQBK4UMgiXrpguUS2qC4mA'), + AIMessage(content='The price of copper is $0.0098 per gram.', response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 148, 'total_tokens': 162, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'stop', 'logprobs': None}, id='run-67cbf98b-4fa6-431e-9ce4-58697a76c36e-0', usage_metadata={'input_tokens': 148, 'output_tokens': 14, 'total_tokens': 162})] + + + +### Converting Messages to Ragas Evaluation Format + +In the current implementation, the GraphState stores messages exchanged between the human user, the AI (LLM's responses), and any external tools (APIs or services the AI uses) in a list. Each message is an object in LangChain's format + +```python +# Implementation of Graph State +class GraphState(TypedDict): + messages: Annotated[list[AnyMessage], add_messages] +``` + +Each time a message is exchanged during agent execution, it gets added to the messages list in the GraphState. However, Ragas requires a specific message format for evaluating interactions. + +Ragas uses its own format to evaluate agent interactions. So, if you're using LangGraph, you will need to convert the LangChain message objects into Ragas message objects. This allows you to evaluate your AI agents with Ragas’ built-in evaluation tools. + +**Goal:** Convert the list of LangChain messages (e.g., HumanMessage, AIMessage, and ToolMessage) into the format expected by Ragas, so the evaluation framework can understand and process them properly. + +To convert a list of LangChain messages into a format suitable for Ragas evaluation, Ragas provides the function [convert_to_ragas_messages][ragas.integrations.langgraph.convert_to_ragas_messages], which can be used to transform LangChain messages into the format expected by Ragas. 
+ +Here's how you can use the function: + + +```python +from ragas.integrations.langgraph import convert_to_ragas_messages + +# Assuming 'result["messages"]' contains the list of LangChain messages +ragas_trace = convert_to_ragas_messages(result["messages"]) +``` + + +```python +ragas_trace # List of Ragas messages +``` + + + + + [HumanMessage(content='What is the price of copper?', metadata=None, type='human'), + AIMessage(content='', metadata=None, type='ai', tool_calls=[ToolCall(name='get_metal_price', args={'metal_name': 'copper'})]), + ToolMessage(content='0.0098', metadata=None, type='tool'), + AIMessage(content='The price of copper is $0.0098 per gram.', metadata=None, type='ai', tool_calls=None)] + + + +## Evaluating the Agent's Performance + +For this tutorial, let us evaluate the Agent with the following metrics: + +- [Tool call Accuracy](https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/agents/#tool-call-accuracy):ToolCallAccuracy is a metric that can be used to evaluate the performance of the LLM in identifying and calling the required tools to complete a given task. + +- [Agent Goal accuracy](https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/agents/#agent-goal-accuracy): Agent goal accuracy is a metric that can be used to evaluate the performance of the LLM in identifying and achieving the goals of the user. This is a binary metric, with 1 indicating that the AI has achieved the goal and 0 indicating that the AI has not achieved the goal. + + +First, let us actually run our Agent with a couple of queries, and make sure we have the ground truth labels for these queries. + +### Tool Call Accuracy + + +```python +from ragas.metrics import ToolCallAccuracy +from ragas.dataset_schema import MultiTurnSample +from ragas.integrations.langgraph import convert_to_ragas_messages +import ragas.messages as r + + +ragas_trace = convert_to_ragas_messages( + messages=result["messages"] +) # List of Ragas messages converted using the Ragas function + +sample = MultiTurnSample( + user_input=ragas_trace, + reference_tool_calls=[ + r.ToolCall(name="get_metal_price", args={"metal_name": "copper"}) + ], +) + +tool_accuracy_scorer = ToolCallAccuracy() +tool_accuracy_scorer.llm = ChatOpenAI(model="gpt-4o-mini") +await tool_accuracy_scorer.multi_turn_ascore(sample) +``` + + + + + 1.0 + + + +Tool Call Accuracy: 1, because the LLM correctly identified and used the necessary tool (get_metal_price) with the correct parameters (i.e., metal name as "copper"). 
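+
+The example above scores a single trace by awaiting the metric directly. If you collect several traces, you can wrap the corresponding `MultiTurnSample` objects in an `EvaluationDataset` and score them in one call with `ragas.evaluate`. The snippet below is a minimal sketch, assuming the ragas 0.2+ evaluation API and reusing the `tool_accuracy_scorer` configured above.
+
+
+```python
+from ragas import evaluate
+from ragas.dataset_schema import EvaluationDataset
+
+# Wrap one or more MultiTurnSample objects and score them in a single call
+# instead of awaiting each metric individually (sketch).
+dataset = EvaluationDataset(samples=[sample])
+eval_result = evaluate(dataset=dataset, metrics=[tool_accuracy_scorer])
+eval_result.to_pandas()
+```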
+ +### Agent Goal Accuracy + + +```python +messages = [HumanMessage(content="What is the price of 10 grams of silver?")] + +result = react_graph.invoke({"messages": messages}) +``` + + +```python +result["messages"] # List of Langchain messages +``` + + + + + [HumanMessage(content='What is the price of 10 grams of silver?', id='51a469de-5b7c-4d01-ab71-f8db64c8da49'), + AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_rdplOo95CRwo3mZcPu4dmNxG', 'function': {'arguments': '{"metal_name":"silver"}', 'name': 'get_metal_price'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 120, 'total_tokens': 137, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-3bb60e27-1275-41f1-a46e-03f77984c9d8-0', tool_calls=[{'name': 'get_metal_price', 'args': {'metal_name': 'silver'}, 'id': 'call_rdplOo95CRwo3mZcPu4dmNxG', 'type': 'tool_call'}], usage_metadata={'input_tokens': 120, 'output_tokens': 17, 'total_tokens': 137}), + ToolMessage(content='1.0523', name='get_metal_price', id='0b5f9260-df26-4164-b042-6df2e869adfb', tool_call_id='call_rdplOo95CRwo3mZcPu4dmNxG'), + AIMessage(content='The current price of silver is approximately $1.0523 per gram. Therefore, the price of 10 grams of silver would be about $10.52.', response_metadata={'token_usage': {'completion_tokens': 34, 'prompt_tokens': 151, 'total_tokens': 185, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'stop', 'logprobs': None}, id='run-93e38f71-cc9d-41d6-812a-bfad9f9231b2-0', usage_metadata={'input_tokens': 151, 'output_tokens': 34, 'total_tokens': 185})] + + + + +```python +from ragas.integrations.langgraph import convert_to_ragas_messages + +ragas_trace = convert_to_ragas_messages( + result["messages"] +) # List of Ragas messages converted using the Ragas function +ragas_trace +``` + + + + + [HumanMessage(content='What is the price of 10 grams of silver?', metadata=None, type='human'), + AIMessage(content='', metadata=None, type='ai', tool_calls=[ToolCall(name='get_metal_price', args={'metal_name': 'silver'})]), + ToolMessage(content='1.0523', metadata=None, type='tool'), + AIMessage(content='The current price of silver is approximately $1.0523 per gram. Therefore, the price of 10 grams of silver would be about $10.52.', metadata=None, type='ai', tool_calls=None)] + + + + +```python +from ragas.dataset_schema import MultiTurnSample +from ragas.metrics import AgentGoalAccuracyWithReference +from ragas.llms import LangchainLLMWrapper + + +sample = MultiTurnSample( + user_input=ragas_trace, + reference="Price of 10 grams of silver", +) + +scorer = AgentGoalAccuracyWithReference() + +evaluator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini")) +scorer.llm = evaluator_llm +await scorer.multi_turn_ascore(sample) +``` + + + + + 1.0 + + + +Agent Goal Accuracy: 1, because the LLM correctly achieved the user’s goal of retrieving the price of 10 grams of silver. + +## What’s next +🎉 Congratulations! 
We have learned how to evaluate an agent using the Ragas evaluation framework. diff --git a/docs/howtos/integrations/langgraph_agent_evaluation.ipynb b/docs/howtos/integrations/langgraph_agent_evaluation.ipynb new file mode 100644 index 000000000..3f2b59698 --- /dev/null +++ b/docs/howtos/integrations/langgraph_agent_evaluation.ipynb @@ -0,0 +1,783 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "t1ub1OLYZQvz" + }, + "source": [ + "# Building and Evaluating a ReAct Agent for Fetching Metal Prices\n", + "\n", + "AI agents are becoming increasingly valuable in domains like finance, e-commerce, and customer support. These agents can autonomously interact with APIs, retrieve real-time data, and perform tasks that align with user goals. Evaluating these agents is crucial to ensure they are effective, accurate, and responsive to different inputs.\n", + "\n", + "In this tutorial, we'll:\n", + "\n", + "1. Build a [ReAct agent](https://arxiv.org/abs/2210.03629) to fetch metal prices.\n", + "2. Set up an evaluation pipeline to track key performance metrics.\n", + "3. Run and assess the agent's effectiveness with different queries.\n", + "\n", + "Click the [link](https://colab.research.google.com/github/explodinggradients/ragas/blob/main/docs/howtos/integrations/langgraph_agent_evaluation.ipynb) to open the notebook in Google Colab." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "- Python 3.8+\n", + "- Basic understanding of LangGraph, LangChain and LLMs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q8Ms4ngAZQv1" + }, + "source": [ + "## Installing Ragas and Other Dependencies\n", + "Install Ragas and Langgraph with pip:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "vQk4aWbpZQv1", + "outputId": "4af0ac60-3d1a-4e41-de6e-d33f74921845" + }, + "outputs": [], + "source": [ + "%pip install langgraph==0.2.44\n", + "%pip install ragas\n", + "%pip install nltk" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eJJ-WKWMZQv2" + }, + "source": [ + "## Building the ReAct Agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lAXAIbo7ZQv2" + }, + "source": [ + "### Initializing External Components\n", + "To begin, you have two options for setting up the external components:\n", + "\n", + "1. Use a Live API Key: \n", + "\n", + " - Sign up for an account on [metals.dev](https://metals.dev/) to get your API key. \n", + " \n", + "2. Simulate the API Response: \n", + "\n", + " - Alternatively, you can use a predefined JSON object to simulate the API response. This allows you to get started more quickly without needing a live API key. \n", + "\n", + "\n", + "Choose the method that best fits your needs to proceed with the setup." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PNZijyBXZQv3" + }, + "source": [ + "### Predefined JSON Object to simulate API response\n", + "If you would like to quickly get started without creating an account, you can bypass the setup process and use the predefined JSON object given below that simulates the API response." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "puMC36BPZQv3" + }, + "outputs": [], + "source": [ + "metal_price = {\n", + " \"gold\": 88.1553,\n", + " \"silver\": 1.0523,\n", + " \"platinum\": 32.169,\n", + " \"palladium\": 35.8252,\n", + " \"lbma_gold_am\": 88.3294,\n", + " \"lbma_gold_pm\": 88.2313,\n", + " \"lbma_silver\": 1.0545,\n", + " \"lbma_platinum_am\": 31.99,\n", + " \"lbma_platinum_pm\": 32.2793,\n", + " \"lbma_palladium_am\": 36.0088,\n", + " \"lbma_palladium_pm\": 36.2017,\n", + " \"mcx_gold\": 93.2689,\n", + " \"mcx_gold_am\": 94.281,\n", + " \"mcx_gold_pm\": 94.1764,\n", + " \"mcx_silver\": 1.125,\n", + " \"mcx_silver_am\": 1.1501,\n", + " \"mcx_silver_pm\": 1.1483,\n", + " \"ibja_gold\": 93.2713,\n", + " \"copper\": 0.0098,\n", + " \"aluminum\": 0.0026,\n", + " \"lead\": 0.0021,\n", + " \"nickel\": 0.0159,\n", + " \"zinc\": 0.0031,\n", + " \"lme_copper\": 0.0096,\n", + " \"lme_aluminum\": 0.0026,\n", + " \"lme_lead\": 0.002,\n", + " \"lme_nickel\": 0.0158,\n", + " \"lme_zinc\": 0.0031,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2SduQYJbZQv3" + }, + "source": [ + "### Define the get_metal_price Tool\n", + "\n", + "The get_metal_price tool will be used by the agent to fetch the price of a specified metal. We'll create this tool using the @tool decorator from LangChain.\n", + "\n", + "If you want to use real-time data from the metals.dev API, you can modify the function to make a live request to the API." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "1X2TsFLfZQv3" + }, + "outputs": [], + "source": [ + "from langchain_core.tools import tool\n", + "\n", + "\n", + "# Define the tools for the agent to use\n", + "@tool\n", + "def get_metal_price(metal_name: str) -> float:\n", + " \"\"\"Fetches the current per gram price of the specified metal.\n", + "\n", + " Args:\n", + " metal_name : The name of the metal (e.g., 'gold', 'silver', 'platinum').\n", + "\n", + " Returns:\n", + " float: The current price of the metal in dollars per gram.\n", + "\n", + " Raises:\n", + " KeyError: If the specified metal is not found in the data source.\n", + " \"\"\"\n", + " try:\n", + " metal_name = metal_name.lower().strip()\n", + " if metal_name not in metal_price:\n", + " raise KeyError(\n", + " f\"Metal '{metal_name}' not found. Available metals: {', '.join(metal_price['metals'].keys())}\"\n", + " )\n", + " return metal_price[metal_name]\n", + " except Exception as e:\n", + " raise Exception(f\"Error fetching metal price: {str(e)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j85XikcLZQv4" + }, + "source": [ + "### Binding the Tool to the LLM\n", + "With the get_metal_price tool defined, the next step is to bind it to the ChatOpenAI model. This enables the agent to invoke the tool during its execution based on the user's requests allowing it to interact with external data and perform actions beyond its native capabilities." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "lsxVT0lUZQv4" + }, + "outputs": [], + "source": [ + "from langchain_openai import ChatOpenAI\n", + "\n", + "tools = [get_metal_price]\n", + "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", + "llm_with_tools = llm.bind_tools(tools)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yuDuSrmQZQv4" + }, + "source": [ + "In LangGraph, state plays a crucial role in tracking and updating information as the graph executes. 
As different parts of the graph run, the state evolves to reflect the changes and contains information that is passed between nodes.\n", + "\n", + "For example, in a conversational system like this one, the state is used to track the exchanged messages. Each time a new message is generated, it is added to the state and the updated state is passed through the nodes, ensuring the conversation progresses logically.\n", + "\n", + "### Defining the State\n", + "To implement this in LangGraph, we define a state class that maintains a list of messages. Whenever a new message is produced it gets appended to this list, ensuring that the conversation history is continuously updated." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "JHHXxYT1ZQv4" + }, + "outputs": [], + "source": [ + "from langgraph.graph import END\n", + "from langchain_core.messages import AnyMessage\n", + "from langgraph.graph.message import add_messages\n", + "from typing import Annotated\n", + "from typing_extensions import TypedDict\n", + "\n", + "\n", + "class GraphState(TypedDict):\n", + " messages: Annotated[list[AnyMessage], add_messages]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1KGbjrAOZQv4" + }, + "source": [ + "### Defining the should_continue Function\n", + "The `should_continue` function determines whether the conversation should proceed with further tool interactions or end. Specifically, it checks if the last message contains any tool calls (e.g., a request for metal prices).\n", + "\n", + "- If the last message includes tool calls, indicating that the agent has invoked an external tool, the conversation continues and moves to the \"tools\" node.\n", + "- If there are no tool calls, the conversation ends, represented by the END state." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "KjppKPRDZQv4" + }, + "outputs": [], + "source": [ + "# Define the function that determines whether to continue or not\n", + "def should_continue(state: GraphState):\n", + " messages = state[\"messages\"]\n", + " last_message = messages[-1]\n", + " if last_message.tool_calls:\n", + " return \"tools\"\n", + " return END" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZbyJRNRvZQv4" + }, + "source": [ + "### Calling the Model\n", + "The `call_model` function interacts with the Language Model (LLM) to generate a response based on the current state of the conversation. It takes the updated state as input, processes it and returns a model-generated response." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "ZYflc7eZZQv4" + }, + "outputs": [], + "source": [ + "# Define the function that calls the model\n", + "def call_model(state: GraphState):\n", + " messages = state[\"messages\"]\n", + " response = llm_with_tools.invoke(messages)\n", + " return {\"messages\": [response]}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VzxIHVa2ZQv4" + }, + "source": [ + "### Creating the Assistant Node\n", + "The `assistant` node is a key component responsible for processing the current state of the conversation and using the Language Model (LLM) to generate a relevant response. It evaluates the state, determines the appropriate course of action, and invokes the LLM to produce a response that aligns with the ongoing dialogue." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "_fPD6W2SZQv4" + }, + "outputs": [], + "source": [ + "# Node\n", + "def assistant(state: GraphState):\n", + " response = llm_with_tools.invoke(state[\"messages\"])\n", + " return {\"messages\": [response]}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Vc3No3agZQv5" + }, + "source": [ + "### Creating the Tool Node\n", + "The `tool_node` is responsible for managing interactions with external tools, such as fetching metal prices or performing other actions beyond the LLM's native capabilities. The tools themselves are defined earlier in the code, and the tool_node invokes these tools based on the current state and the needs of the conversation." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "vz2qlceBZQv5" + }, + "outputs": [], + "source": [ + "from langgraph.prebuilt import ToolNode\n", + "\n", + "# Node\n", + "tools = [get_metal_price]\n", + "tool_node = ToolNode(tools)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M2FWZfGFZQv5" + }, + "source": [ + "### Building the Graph\n", + "The graph structure is the backbone of the agentic workflow, consisting of interconnected nodes and edges. To construct this graph, we use the StateGraph builder which allows us to define and connect various nodes. Each node represents a step in the process (e.g., the assistant node, tool node) and the edges dictate the flow of execution between these steps." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 266 + }, + "id": "FeGI8G3KZQv5", + "outputId": "4575b3ed-e162-4419-f44f-ff0086aaf546" + }, + "outputs": [ + { + "data": { + "image/jpeg": 
"/9j/4AAQSkZJRgABAQAAAQABAAD/4gHYSUNDX1BST0ZJTEUAAQEAAAHIAAAAAAQwAABtbnRyUkdCIFhZWiAH4AABAAEAAAAAAABhY3NwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAA9tYAAQAAAADTLQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAlkZXNjAAAA8AAAACRyWFlaAAABFAAAABRnWFlaAAABKAAAABRiWFlaAAABPAAAABR3dHB0AAABUAAAABRyVFJDAAABZAAAAChnVFJDAAABZAAAAChiVFJDAAABZAAAAChjcHJ0AAABjAAAADxtbHVjAAAAAAAAAAEAAAAMZW5VUwAAAAgAAAAcAHMAUgBHAEJYWVogAAAAAAAAb6IAADj1AAADkFhZWiAAAAAAAABimQAAt4UAABjaWFlaIAAAAAAAACSgAAAPhAAAts9YWVogAAAAAAAA9tYAAQAAAADTLXBhcmEAAAAAAAQAAAACZmYAAPKnAAANWQAAE9AAAApbAAAAAAAAAABtbHVjAAAAAAAAAAEAAAAMZW5VUwAAACAAAAAcAEcAbwBvAGcAbABlACAASQBuAGMALgAgADIAMAAxADb/2wBDAAMCAgMCAgMDAwMEAwMEBQgFBQQEBQoHBwYIDAoMDAsKCwsNDhIQDQ4RDgsLEBYQERMUFRUVDA8XGBYUGBIUFRT/2wBDAQMEBAUEBQkFBQkUDQsNFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBT/wAARCAD5ANYDASIAAhEBAxEB/8QAHQABAAIDAQEBAQAAAAAAAAAAAAUGAwQHCAECCf/EAFEQAAEEAQIDAgYLDAcGBwAAAAEAAgMEBQYRBxIhEzEVFiJBUZQIFBcyVVZhdNHS0yM1NlRxdYGRk5WytCU3QkNSgpIYJGRylqEzNFNiscHw/8QAGwEBAQADAQEBAAAAAAAAAAAAAAECAwUEBgf/xAAzEQEAAQIBCQUJAQADAAAAAAAAAQIRAwQSITFBUVKR0RQzYXGhBRMVI2KSscHhgSLw8f/aAAwDAQACEQMRAD8A/qmiIgIiICIiAsNq5XpR89ieOuz/ABSvDR+sqDu37uevz47FTGlVrnkt5NrQ5zX/APpQhwLS4d7nuBa3cNAc4u5Ptbh/p+F5llxcF+ydua1fb7ZmcR5y9+5/V0W+KKae8n/IW29u+NWF+F6HrLPpTxqwvwxQ9ZZ9KeKuF+B6HqzPoTxVwvwPQ9WZ9CvyfH0XQeNWF+GKHrLPpTxqwvwxQ9ZZ9KeKuF+B6HqzPoTxVwvwPQ9WZ9CfJ8fQ0HjVhfhih6yz6U8asL8MUPWWfSnirhfgeh6sz6E8VcL8D0PVmfQnyfH0NB41YX4Yoess+lblTIVb7S6rZhstHeYZA4D9S0/FXC/A9D1Zn0LUtaB05bkErsNThnad22K0QhmafkkZs4foKfJnbPp/E0J9FWI7NzSM8MN+1NksPK4RsvT8va1XE7NbKQAHMPQB+24O3NvuXCzrXXRm+MEwIiLWgiIgIiICIiAiIgIiICIiAojV2Yfp/S+VyMQDpq1Z8kTXdxft5IP6dlLqvcQqct7ROZjhaZJm13SsY0blzmeWAB6SW7LbgxE4lMVarwsa0hp/Dx4DDVKEZ5uxZ5cnnkkJ3e8/K5xc4n0kqRWGnaivVILMDueGZjZGO9LSNwf1FZlhVMzVM1a0FUuIHFbS3C6LHv1JkzSfkJHRVIIa01madzW8z+SKFj3kNHUnbYbjchW1cU9krQqPg07k48frBupMc+zJiM5o7HG7NQldG0OZNEA4Ojl6Atc0tPL1LehWI2cp7JjT+N4q6b0m2tetUc3hfC8OTq463ODzyQthaGxwu8lzZHOdISAzZodylwVgtcftBUdct0hZz3tfOvtNotilpzthNhw3bCJzH2XaHcbN59zuBsuUx5fWendd8Ltfax0nlrtuxpGzicxDp6g+4+neklrTDnij3LWu7J43G4aehPnVA4t4/Wep5tTDMYbX+W1Bj9VwW8fUxsEwwsOJguRSRyRtjIjsSGJpJGz5ec9GgDoHpi3x20TT1je0ocpYsahozR17VCnjbVh8DpI2yMLzHE4NYWvb5ZPLuSN9wQIvgLx7xvHPBWblWjdx1yvYsxyV56VlkYjZYkijc2aSJjHuc1gc5jSSwktcAQtbhLp+7jOMXGnJWsbYqQZLLY91W3NA5jbUbMdA0ljiNnta/nb03APMO/dRfsY7GQ0vh8poTMaezWNyWLymUte3rFF7aFmGW9JLG6GxtyPLmzNPKDuOV24GyDuCIiDXyFCvlaFmlbibPVsxuhlif3PY4bOB/KCVEaGvz39Nwi1L29upLNRmlO+8j4ZXRF53/wAXJzfpU+qzw8b2mn5Lg35L921cj5htvHJO90Z2+VnKf0r0U9zVffH7XYsyIi86CIiAiIgIiICIiAiIgIiICIiCqU52aDeaNvaLAOeXU7fXkqbncwynuY3cnkf0btsw7EN7THqvhFobX+RjyWo9JYTP3mxCFlrIUYp5BGCSGhzgTy7ucdvlKtr2NkY5j2h7HDYtcNwR6Cq0/h9joSTjbOQwoP8AdY62+OIejaI7xt/Q0f8AYL0TVRiaa5tPO/8A3/WWiVePsbeFBaG+5vpblBJA8EwbA+f+z8gVm0fw70tw9hsxaY09jNPxWXNdOzG1GQCUjcAuDQN9tz3+lYfEmx8as9+2h+yTxJsfGrPftofsk93h8fpKWjetCKr+JNj41Z79tD9kqnex2Wr8VcHp5mqcx4OuYW/flJlh7TtYZ6bGbfc/e8tiTfp38vUed7vD4/SS0b3VFC6s0XgNd4xuO1HhaGdx7ZBM2rka7Z4w8AgO5XAjcBxG/wApWj4k2PjVnv20P2SeJNj41Z79tD9knu8Pj9JLRvQDfY3cKWBwbw40u0PGzgMTB1G4Ox8n0gfqUnpngroDRmXiyuA0XgcNk4g5sdyjj4oZWhw2cA5rQRuCQVueJNj41Z79tD9kvviBTsO/pDIZXKs337G1deIj+VjOVrh8jgQmZhxrr5R/4Wh+crkPG7t8Nipeeo/mhyGRhd5ELOodFG4d8p7unvBu4kHla6ywQR1oI4YWNiijaGMYwbBrQNgAPMF8q1YaVeOvXhjrwRtDWRRNDWtA7gAOgCyrCuuJjNp1QSIiLUgiIgIiICIiAiIgIiICIiAiIgIiICIiAufZYt937SwJPN4sZfYebb21jd/P+TzfpHn6Cuf5Xf3ftLdW7eLGX6EDf/zWN7vPt+Tp3b+ZB0BERAREQEREBERAREQEREBERAREQEREBERAREQEREBERAXPcsB/tA6VPM0HxXzHk7dT/veM677d36fOP0dCXPctt/tBaV6nm8V8xsOX/i8Z5/8A9/2QdCREQEREBERAREQEREBERAREQEREBERAREQERaeXy1fB46a7aLhDEBuGNLnOJIDWtA7ySQAPOS
FYiaptGsbiKlP1Dquby4cVia7HdRHYuyOkaP8A3cse2/pAJHylfnw7rD8Qwfrc32a9fZa98c4Wy7oqR4d1h+IYP1ub7NPDusPxDB+tzfZp2WvfHOCy7rwHrH2e2V097IivibXCud2ocTHc06MfFmA7t5Z7FZzXsd7X35T7XG2w8oPB8wXsXw7rD8Qwfrc32a5BnvY/zah9kHh+LVjH4YZnHVexNQWJDFPM0csU7j2e/Oxp2H/Kz/D1dlr3xzgs9LIqR4d1h+IYP1ub7NPDusPxDB+tzfZp2WvfHOCy7oqR4d1h+IYP1ub7NPDusPxDB+tzfZp2WvfHOCy7oqUzPaua7d+NwsjR3tbdmaT+nsjt+pWPAZyHP0PbEbHwSMeYpq8u3PDI33zHbdOnpG4IIIJBBWqvArw4zp1eE3LJJERaEEREBERAREQEREBERAREQFUuJh2wVEeY5ahuD85jVtVR4m/eKh+dqH8zGvTk3f0ecMqdcNtERepiIiICKJy2qsXgsthsbesmG7mJn16MXZvd2r2RukcNwCG7Ma47uIHTbv6KRt24KFWazZmjr1oWOklmlcGsY0DcucT0AAG5JUGVFr43I1cxjqt+lPHapWomTwTxO5mSRuAc1zT5wQQR+VbCoItXKZWng8bayORtQ0aFWJ009mw8MjijaN3Oc49AAASSVmrzx2oI5oXiSKRoex7e5zSNwQgyLR0Af6V1kPMMszYAf8DVK3lo6A++2s/zvH/I1VZ7uvy/cMo1SuKIi5bEREQEREBERAREQEREBERAVR4m/eKh+dqH8zGrcqjxN+8VD87UP5mNenJu/o84ZU64bapHGvU1PSPDDOZG7NlIIuSOux2EkbHddLLI2KJsTndGuc97W8x6DffzK7qK1TpbFa10/dwecpR5HFXGdnPWl32eNwR1BBBBAIIIIIBBBC9M6mLzLpStxPZkOJvD+nl7uIzEunamSw5y2ddl5aU0kk0bh7adG1zecRjps4MPVpO6zxRal1Nw/vYHS1rWdfUWAz1eTUun8rqD+k3V3QbmCpf3I5H7tla7mbzbOG7AQF2Cn7HXh9RZkBFgXOfkaRx92aW/ZkltQl7X8skjpC55BY3lc4lzQNmkAkL432OfD5mAfhm4KVtR91uQfK3I2hadYawxtkNjte1JDCWjd/QEha82RzHEanhzWq+BGS03qLU8mOyFrK421WzN6UvkMNS04stRc3LJJHKzbmIJ8huzj0Kr+GqZbFaV17ozXuX1W/XE+mb1508uZfNjclCwnexU5SDAQSxrotmbNdts4EleiMZwk0jhYdMQ0MNHUi00+aTFMhlkaK75Y3xyu995Zc2R+5fzHdxPf1WnovgZofh9dtW8HgmVrFisab3z2JrPLXJ5jCwSvcGRk7Esbs07Dp0VzZHG8fVq6V9jvw0wuOvatv5fVUdD2lXx+oJYZnymkJHsFmQuNes1jHOLY9tthyjqVWqWqtaxcOsjp/IagymPyWN4lY7AMuw5Q27UVSZ9ZzojZdG0zbdu8cz2dRsCDsu8wexv4eVdPHBw4KWPGCyy5FE3JWg6tKwODHQP7Xmg2D3DaMtGziNtlt47gHoLEV5IKWAbWhkyFTKvjjtThr7dYh0M5HP1eCAXE+/I8vmUzZHCeKWPt4vTXsgdFSZ3OZLCUtKVszT9v5OaeeCR7LPaR9s5xe6JxgYSxxLdi4bbOIXonhXputpfQmIq1bmQvRSV45+1yV+W5Ju5jTsHyucQ30NB2HmC3Z9AaftZjN5SfGxz3M1RjxuQdK5z2WKzO05Y3MJ5dvusm+wBPN136L8aD4eYHhnhXYnTtSaljzJ2vYy25rHKeVrdmmV7i1oa1oDQQBt0CyiLSLGtHQH321n+d4/5Gqt5aOgPvtrP87x/yNVbJ7uvy/cMo1SuKIi5bEREQEREBERAREQEREBERAVR4m/eKh+dqH8zGrcorU2D8YcPLTbN7WmD45oZuXm7OWN4ewkbjcczRuNxuNxuN1vwKooxaaqtUTCxoloooZ9/UVfyJdJ2rEg6OfSuVnRH5WmSRjtvytB+RanjPmDfbTbo3LvmLXOcWTVHMZy8m4e8TcrXESNIaSCRuQCGkjoZn1R90dSyyIoTwtnviZlfWqX26eFs98TMr61S+3TM+qPujqtk2ihPC2e+JmV9apfbqr3eMdbH8Qsfoexg78WqshUfdrY4z1eaSFm/M7m7blHc47E7kNJA2BTM+qPujqWdDRQnhbPfEzK+tUvt08LZ74mZX1ql9umZ9UfdHUsm0UJ4Wz3xMyvrVL7dPC2e+JmV9apfbpmfVH3R1LJtaOgPvtrP87x/yNVRGP1RlcpI+GHSmRgsNBJiuWK0TmgPczmLe1Lw0ljtncpDgNwSCFbdKYObC0rDrcrJb92c2rJi37Nry1rQ1m/Xla1jW7nbfbfYb7DXiTFGHVEzGnRomJ2xOzyNUJtERcxiIiICIiAiIgIiICIiAiIgIvjnBjS5xDWgbknuCgY32NT2GyRyTUsRBOfeiNzcpGYuhDtyWxczz3crnOiBB7M/dA/M+Qs6lE1bEyy06ZjhlZnIuykilBk8uOEbkl3I07vLeUdowt5yHBstjcVTw8MkNGrFUikmksPbEwNDpJHl8jzt3uc5xJPnJKzVq0NKtFXrxMggiYI44omhrWNA2DQB0AA6bLKgIiIC/njxB9jLxuz3suqmsq2otK1c/OZszi43XbRigqVJYIhA8iv5xYjBABB3fufT/Q5c/wAhyzcfMByhpdX0zkec7nmaJLVHl6d2x7J3+n8qDoCIiAiIgis3p2vmWPla99DJivJWr5WqyP21Va8tLuzc9rhtzMjcWuBa4sbzNcBstV+opcRekhzcUNKpLahq0L0cjntsukb0bIOUdi/nBYASWu5o9ncz+Rs+iAirIqy6Jqh1NktrT9WCxNNWHbWrjHc3aNEI3c57QC9oiAJADGsGwDVYoJ47MLJoniSJ7Q5rm9xB7igyIiICIiAiIgIiICIiAiLFan9q1ppuR8vZsL+SMbudsN9gPOUEBZEOsr1zHu5J8JUdJTyVK5j+eO690bHBjXv8l0bQ883K1wL9m8wMcjDZFA6Dj5NF4R3a5SYyVI5i/Nn/AH3d7Q4iYDoHjm2LR0BGw6AKeQEREBERAXPuHBOq9Q6g1xvzUciIsdiHb7h9GAvInHXbaWWWZwI99G2E+jb96ltS8QsrY0pjJnR4iu8Mz+Qhc5ruXYO9pROHdI8Edo4Hdkbths+RrmXqvXiqQRwQRshhiaGMjjaGtY0DYAAdwA8yDIiIgIiICIiAoG7RfgbdrK0Ws7CeT2xkoXNlke8Nj5eeJrOby+VrByhp5+UDoepnkQa2OyNXMY+rfo2I7dK1E2eCxC4OZLG4BzXNI6EEEEH5Vsqv4WWSjqTMYuR+UtMcGZGGzbiBrxtlLmmvFKO8sdEXlrurRMzYkbBtgQEREBERAREQERQuY1tp7T9oVsnnMdj7JHN2Nm0xj9vTyk77LOmiqubUxeVtdNIqt7qWjvjTiPXY/pVZ4l3+G3FfQmZ0ln9R4qbFZ
SDsZQy/G17SCHMe07++a9rXDfpu0bgjotvZ8bgnlK5s7kjoXiBpeGWpow6k31NSdLSGKzuQidmJxCXDtnx83O8PjYJWv28qNzXnvKvy/nF7CngvR4K+yJ1ff1Hm8XJj8PTNbE5T2ywRXDM4fdIzvtuI2uDh3tL9j8vvT3UtHfGnEeux/SnZ8bgnlJmzuWlFVvdS0d8acR67H9Ke6lo7404j12P6U7PjcE8pM2dy0qm57O5DUGXk05puXsJIi0ZXM8vM3HsI37KLccr7Lm9zTuImuEjwd445ojJcRqus86zS+ls5UgfLHz28vFPG50LCPeVmu3Esx9OxZGOrtzysdesHg6Gm8XDjsbWbVpw8xbG0kkuc4ue9zjuXOc5znOc4lznOJJJJK1VUVUTauLJaz5gcDQ0xiK2MxlcVqVcEMZzFxJJLnOc5xLnvc4lznuJc5ziSSSSpBEWCCIiAiIgIiICIiCu2yG8Q8UN8yS/F3OkX3tHLNW/8b0Tnm+5+lgn9CsS45k/ZFcKq/EbFQy8T8LE9mNvtfEzO1Bjw4TVBtP8AdOk469mP8Ptj0LsaAiIgIiICIiDSzVx2Pw960wAvggklaD6WtJH/AMKo6SqR1sBSkA5p7MTJ55ndXzSOaC57iepJJ/R3dwVn1V+DGY+ZzfwFV7TX4OYr5pF/AF0MDRhT5rsSSIizQREQEREGrksbWy1OStajEkT/AJdi0jqHNI6tcDsQ4dQQCOq39B5SfNaLwd60/tbM9OJ8sm23O7lG7tvNueu3yrEsPCz+rnTnzGL+FY4unBnwmPxPRdi0oiLnIIiICIq3rrWcGisQLDoxZuTv7KrV5uXtX95JPma0bkn0DYbkgHZh4dWLXFFEXmRM5PLUcJUdbyNyvQqt99PalbGwflc4gKsS8YdHQvLTnIXEdN445Hj9YaQuH5O1azuR8IZWw6/e68skg8mIb+9jb3Mb0HQdTsCST1WNfW4XsPDin5tc38P7cvDuPuzaN+Gm+ry/UT3ZtG/DTfV5fqLhyLd8Dybiq5x0Lw4FxI9jppPVPsxsdqSvcjPD3JSeGMq4RSBsdhh3fBy7c33V/Keg2Ae70L3d7s2jfhpvq8v1Fw5E+B5NxVc46F4dx92bRvw031eX6i+s4yaNe7bw3G35XwyNH6y1cNRPgeTcVXOOheHpbD6gxmoa7p8XkKuQiaeVzq0rZA0+g7HofkKkF5YgMlK9HepTyUb8fvLVchr2/IehDh0HkuBB26gruvDfXw1jSmr22sgy9MNE8bPeytPdKweZpIII72kEdRsTxcu9l1ZLT7yib0+sLr1LkiIuEiL1V+DGY+ZzfwFV7TX4OYr5pF/AFYdVfgxmPmc38BVe01+DmK+aRfwBdHB7mfP9Lsb1h0jIJHQsbLMGksY53KHO26AnY7dfPsV524W8etUYzgrmNZ68xUVivUvW4Ks2Puiazdn8ISV46wh7GNrNnckbXcx5gOYhvVejV57h4Baul0DqXQU+RwsWAdfmy+By0Jldchsm8LkTZ4i0M5WvLmkteSRt0Ck32IsDfZCT6WtZmpxD0wdIWqGFlz8XtXINyEdmtE4Nla14YzaVrnMHJtsecbOIWCvxvzs9iriNT6Om0dNqDF27WEsx5Ntpz3xQ9q6KUNY0wyhh5wAXDyXeVuFG5ngRqji5kM3e4i3MNRdPp2xp+hU086WaOHt3NdJZe+VrCXbxx7MA2AB3J71u47hRrrV+qtNZHX9/BMqaap2oajMCZnvuWJ4DXdPL2jWiMCMv2Y3m6vPldAp/yEHpLjjmNNcMOC2MixbtV6o1XhGTNnyuWFRkj4oInSc072vL5XmQbN2Jds4kjZehMfNPZoVprNY07MkTXy1y8P7J5AJZzDodjuNx0Oy8/WOC2vncEMDw9sUdC6ir4+pJjpJMr7ZaOzY1rKtiPlY4smaA4uA8+3K8Ltmg9P29KaJwGFv5KTMXsdQgqT5CbfnsvZGGukO5J3cQT1JPXqSrTfaJ1YeFn9XOnPmMX8KzLDws/q5058xi/hVxe5nzj8SuxaURFzkEREBcC4s5J2S4iWIHOJixtWOCNp7muk+6PI/KOyB/5Au+rgXFnGuxnEOedzSIsnVjnjee5z4/ubwPyDsj/nC73sXN7Vp12m3p+rrslVkWvkb8WLoz25xKYYWF7xDC+V+w9DGAucfkAJVVHFvT5/us5/07kPsF9vViUUaKpiGtcnODWkkgAdST5lxOl7KDD3chUeyDHnCW7bKkU7M1A695T+RsjqY8sMLiD74uDTuWhXtnFHT997avY5o9uez2fp++xp36dXGAADr3k7KvcPtCau0HFj9Ptfp+9pmhI5sV6Zsovur7ktYWAcnMNwOfm7h73deTErrrqp9zVo22tO637Vin43X68OUyUmli3T2LzMmHuX/CDe0aW2BCJWRcnlN3c0kFzSNyBzAbnX4mcUMxNh9c0dL4Sa5BhaM8V3NNvisas5gL9oRsS98bXNcdi3Y9Ad1nyPCbL2+HWsMAyzSFzMZ2bJ13ue/s2xPtsmAeeTcO5WkbAEb+fzrBqHhprCv484/TlnCyYTVQmmkGTdMyarYlgEUhbyNIe13K09dtj6fPoqnKM2030x4X2/wdH0XPLa0dgpppHzTSUIHvkkcXOc4xtJJJ7yT51MKi4/W+K0bjKGDvtykl3H1oa0zqeFvTxFzY2glsjIS1w+UFZ/dd08f7rO/9O5D7Be2nFw4iImqL+aLmpbRWSdh9e4CyxxaJpzSlA/tslaQB/rEbv8qreFzVbP46O7UFhsDyQBarS15Oh2O7JGtcO7zjqrJonGuzOvcBWY3mbBObspH9hkbSQf8AWYx/mUyiaJwK5q1Wn8Mqdb0giIvzBUXqr8GMx8zm/gKr2mvwcxXzSL+AK05mm7I4i9UYQHzwSRAnzFzSP/tVDSVyOxgacIPJZrQsgsQO6Phka0BzHA9QQf1jYjoQuhgacKY8V2JhERZoIiICIiAsPCz+rnTnzGL+FY8nlK2IqPs2pRHG3oB3ue49A1rR1c4kgBo3JJAHUqQ0Ji58JozCUbTOzswU4mSx778j+Ubt38+x6b/IscXRgz4zH4nquxOoiLnIIiICrmudGQa1w4rPkFa3C/tatrl5jE/u6jpu0jcEb9x6EEAixotmHiVYVcV0TaYHl3K1LWn8h7Qy1c4+515WvO7JR/ijf3PHd3dRuNw09FjXpzJYulmaj6t+pBerP99DZibIw/laQQqxLwg0dK4uOBrtJ67RuewfqBAX1uF7cw5p+bRN/D+locKRdy9xvRvwHF+1k+snuN6N+A4v2sn1lu+OZNw1co6locNRdy9xvRvwHF+1k+snuN6N+A4v2sn1k+OZNw1co6locNRdy9xvRvwHF+1k+svrODujWO38BQO+R73uH6i7ZPjmTcNXKOpaN7hdYS5C8yjRgkv33+9q1wHPPynrs0dR5TiAN+pXduHGgho2jNPaeyfL2+UzyM95G0e9iYe8tBJO56uJJ2A2a2xYjBY3AVzBjKFbHwk7llaJsYcfSdh1Pylb64mXe1Ksrp93RFqf
WV1ahERcNBQuY0Vp/UNgWMpg8bkZwOUS2qkcjwPRu4E7KaRZU11UTembSalW9yvRnxTwn7vi+qnuV6M+KeE/d8X1VaUW7tGNxzzlbzvVb3K9GfFPCfu+L6qe5Xoz4p4T93xfVVpRO0Y3HPOS871W9yvRnxTwn7vi+qnuV6M+KeE/d8X1VaUTtGNxzzkvO9B4rQ2nMFZbZx2AxlCw3flmrVI43t379iBuN1OIi1VV1VzeqbprERFgCIiAiIgIiICIiAiIgIiICIiAiIg//9k=", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from langgraph.graph import START, StateGraph\n", + "from IPython.display import Image, display\n", + "\n", + "# Define a new graph for the agent\n", + "builder = StateGraph(GraphState)\n", + "\n", + "# Define the two nodes we will cycle between\n", + "builder.add_node(\"assistant\", assistant)\n", + "builder.add_node(\"tools\", tool_node)\n", + "\n", + "# Set the entrypoint as `agent`\n", + "builder.add_edge(START, \"assistant\")\n", + "\n", + "# Making a conditional edge\n", + "# should_continue will determine which node is called next.\n", + "builder.add_conditional_edges(\"assistant\", should_continue, [\"tools\", END])\n", + "\n", + "# Making a normal edge from `tools` to `agent`.\n", + "# The `agent` node will be called after the `tool`.\n", + "builder.add_edge(\"tools\", \"assistant\")\n", + "\n", + "# Compile and display the graph for a visual overview\n", + "react_graph = builder.compile()\n", + "display(Image(react_graph.get_graph(xray=True).draw_mermaid_png()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wlNB4fI4ZQv5" + }, + "source": [ + "To test our setup, we will run the agent with a query. The agent will fetch the price of copper using the metals.dev API." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "rzt0I-n2ZQv5" + }, + "outputs": [], + "source": [ + "from langchain_core.messages import HumanMessage\n", + "\n", + "messages = [HumanMessage(content=\"What is the price of copper?\")]\n", + "result = react_graph.invoke({\"messages\": messages})" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "esoHsop8ZQv5", + "outputId": "0d52f2db-f2da-4f5a-943e-e549b731f01e" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='What is the price of copper?', id='4122f5d4-e298-49e8-a0e0-c98adda78c6c'),\n", + " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_DkVQBK4UMgiXrpguUS2qC4mA', 'function': {'arguments': '{\"metal_name\":\"copper\"}', 'name': 'get_metal_price'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 116, 'total_tokens': 134, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-0f77b156-e43e-4c1e-bd3a-307333eefb68-0', tool_calls=[{'name': 'get_metal_price', 'args': {'metal_name': 'copper'}, 'id': 'call_DkVQBK4UMgiXrpguUS2qC4mA', 'type': 'tool_call'}], usage_metadata={'input_tokens': 116, 'output_tokens': 18, 'total_tokens': 134}),\n", + " ToolMessage(content='0.0098', name='get_metal_price', id='422c089a-6b76-4e48-952f-8925c3700ae3', tool_call_id='call_DkVQBK4UMgiXrpguUS2qC4mA'),\n", + " AIMessage(content='The price of copper is $0.0098 per gram.', response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 148, 'total_tokens': 162, 
'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'stop', 'logprobs': None}, id='run-67cbf98b-4fa6-431e-9ce4-58697a76c36e-0', usage_metadata={'input_tokens': 148, 'output_tokens': 14, 'total_tokens': 162})]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result[\"messages\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wsK_VEDSZQv6" + }, + "source": [ + "### Converting Messages to Ragas Evaluation Format\n", + "\n", + "In the current implementation, the GraphState stores messages exchanged between the human user, the AI (LLM's responses), and any external tools (APIs or services the AI uses) in a list. Each message is an object in LangChain's format\n", + "\n", + "```python\n", + "# Implementation of Graph State\n", + "class GraphState(TypedDict):\n", + " messages: Annotated[list[AnyMessage], add_messages]\n", + "```\n", + "\n", + "Each time a message is exchanged during agent execution, it gets added to the messages list in the GraphState. However, Ragas requires a specific message format for evaluating interactions.\n", + "\n", + "Ragas uses its own format to evaluate agent interactions. So, if you're using LangGraph, you will need to convert the LangChain message objects into Ragas message objects. This allows you to evaluate your AI agents with Ragas’ built-in evaluation tools.\n", + "\n", + "**Goal:** Convert the list of LangChain messages (e.g., HumanMessage, AIMessage, and ToolMessage) into the format expected by Ragas, so the evaluation framework can understand and process them properly." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To convert a list of LangChain messages into a format suitable for Ragas evaluation, Ragas provides the function [convert_to_ragas_messages][ragas.integrations.langgraph.convert_to_ragas_messages], which can be used to transform LangChain messages into the format expected by Ragas.\n", + "\n", + "Here's how you can use the function:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from ragas.integrations.langgraph import convert_to_ragas_messages\n", + "\n", + "# Assuming 'result[\"messages\"]' contains the list of LangChain messages\n", + "ragas_trace = convert_to_ragas_messages(result[\"messages\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='What is the price of copper?', metadata=None, type='human'),\n", + " AIMessage(content='', metadata=None, type='ai', tool_calls=[ToolCall(name='get_metal_price', args={'metal_name': 'copper'})]),\n", + " ToolMessage(content='0.0098', metadata=None, type='tool'),\n", + " AIMessage(content='The price of copper is $0.0098 per gram.', metadata=None, type='ai', tool_calls=None)]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ragas_trace # List of Ragas messages" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n5mbTp5aZQv6" + }, + "source": [ + "## Evaluating the Agent's Performance" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H885v5sxZQv6" + }, + "source": [ + "For this tutorial, let us evaluate the Agent with the following metrics:\n", + "\n", + "- [Tool call Accuracy](https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/agents/#tool-call-accuracy):ToolCallAccuracy is a metric that can be used to evaluate the performance of the LLM in identifying and calling the required tools to complete a given task. \n", + "\n", + "- [Agent Goal accuracy](https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/agents/#agent-goal-accuracy): Agent goal accuracy is a metric that can be used to evaluate the performance of the LLM in identifying and achieving the goals of the user. This is a binary metric, with 1 indicating that the AI has achieved the goal and 0 indicating that the AI has not achieved the goal.\n", + "\n", + "\n", + "First, let us actually run our Agent with a couple of queries, and make sure we have the ground truth labels for these queries." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7kRRIyTAZQv6" + }, + "source": [ + "### Tool Call Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CC973Yq1ZQv6", + "outputId": "d5bf508d-f3ba-4f2e-a4c6-e6efbf229603" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from ragas.metrics import ToolCallAccuracy\n", + "from ragas.dataset_schema import MultiTurnSample\n", + "from ragas.integrations.langgraph import convert_to_ragas_messages\n", + "import ragas.messages as r\n", + "\n", + "\n", + "ragas_trace = convert_to_ragas_messages(\n", + " messages=result[\"messages\"]\n", + ") # List of Ragas messages converted using the Ragas function\n", + "\n", + "sample = MultiTurnSample(\n", + " user_input=ragas_trace,\n", + " reference_tool_calls=[\n", + " r.ToolCall(name=\"get_metal_price\", args={\"metal_name\": \"copper\"})\n", + " ],\n", + ")\n", + "\n", + "tool_accuracy_scorer = ToolCallAccuracy()\n", + "tool_accuracy_scorer.llm = ChatOpenAI(model=\"gpt-4o-mini\")\n", + "await tool_accuracy_scorer.multi_turn_ascore(sample)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Tool Call Accuracy: 1, because the LLM correctly identified and used the necessary tool (get_metal_price) with the correct parameters (i.e., metal name as \"copper\")." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rGOL1CBsZQv6" + }, + "source": [ + "### Agent Goal Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "FA0kMvTfZQwB" + }, + "outputs": [], + "source": [ + "messages = [HumanMessage(content=\"What is the price of 10 grams of silver?\")]\n", + "\n", + "result = react_graph.invoke({\"messages\": messages})" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YJr4Hxn8ZQwB", + "outputId": "9797c93b-47a2-4264-b535-f182effb396b" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='What is the price of 10 grams of silver?', id='51a469de-5b7c-4d01-ab71-f8db64c8da49'),\n", + " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_rdplOo95CRwo3mZcPu4dmNxG', 'function': {'arguments': '{\"metal_name\":\"silver\"}', 'name': 'get_metal_price'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 120, 'total_tokens': 137, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-3bb60e27-1275-41f1-a46e-03f77984c9d8-0', tool_calls=[{'name': 'get_metal_price', 'args': {'metal_name': 'silver'}, 'id': 'call_rdplOo95CRwo3mZcPu4dmNxG', 'type': 'tool_call'}], usage_metadata={'input_tokens': 120, 'output_tokens': 17, 'total_tokens': 137}),\n", + " ToolMessage(content='1.0523', name='get_metal_price', id='0b5f9260-df26-4164-b042-6df2e869adfb', tool_call_id='call_rdplOo95CRwo3mZcPu4dmNxG'),\n", + " AIMessage(content='The current price of silver is approximately $1.0523 per gram. 
Therefore, the price of 10 grams of silver would be about $10.52.', response_metadata={'token_usage': {'completion_tokens': 34, 'prompt_tokens': 151, 'total_tokens': 185, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0ba0d124f1', 'finish_reason': 'stop', 'logprobs': None}, id='run-93e38f71-cc9d-41d6-812a-bfad9f9231b2-0', usage_metadata={'input_tokens': 151, 'output_tokens': 34, 'total_tokens': 185})]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result[\"messages\"] # List of Langchain messages" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "StDNqR2vZQwB", + "outputId": "47e914a4-3e48-4932-8b20-752441b42fd4" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='What is the price of 10 grams of silver?', metadata=None, type='human'),\n", + " AIMessage(content='', metadata=None, type='ai', tool_calls=[ToolCall(name='get_metal_price', args={'metal_name': 'silver'})]),\n", + " ToolMessage(content='1.0523', metadata=None, type='tool'),\n", + " AIMessage(content='The current price of silver is approximately $1.0523 per gram. Therefore, the price of 10 grams of silver would be about $10.52.', metadata=None, type='ai', tool_calls=None)]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from ragas.integrations.langgraph import convert_to_ragas_messages\n", + "\n", + "ragas_trace = convert_to_ragas_messages(\n", + " result[\"messages\"]\n", + ") # List of Ragas messages converted using the Ragas function\n", + "ragas_trace" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c6u9-RYdZQwB", + "outputId": "ebf8fdd8-88fc-47c3-e1e2-b401956c0633" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from ragas.dataset_schema import MultiTurnSample\n", + "from ragas.metrics import AgentGoalAccuracyWithReference\n", + "from ragas.llms import LangchainLLMWrapper\n", + "\n", + "\n", + "sample = MultiTurnSample(\n", + " user_input=ragas_trace,\n", + " reference=\"Price of 10 grams of silver\",\n", + ")\n", + "\n", + "scorer = AgentGoalAccuracyWithReference()\n", + "\n", + "evaluator_llm = LangchainLLMWrapper(ChatOpenAI(model=\"gpt-4o-mini\"))\n", + "scorer.llm = evaluator_llm\n", + "await scorer.multi_turn_ascore(sample)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Agent Goal Accuracy: 1, because the LLM correctly achieved the user’s goal of retrieving the price of 10 grams of silver." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "18wmDI0xZQwB" + }, + "source": [ + "## What’s next\n", + "🎉 Congratulations! We have learned how to evaluate an agent using the Ragas evaluation framework." 
+ ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "ragas", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/references/integrations.md b/docs/references/integrations.md index a1f9a996c..dd8069153 100644 --- a/docs/references/integrations.md +++ b/docs/references/integrations.md @@ -16,3 +16,7 @@ ::: ragas.integrations.helicone options: show_root_heading: true + +::: ragas.integrations.langgraph + options: + show_root_heading: true diff --git a/mkdocs.yml b/mkdocs.yml index 3874ab187..f08987112 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -90,6 +90,7 @@ nav: - Integrations: - howtos/integrations/index.md - LlamaIndex: howtos/integrations/_llamaindex.md + - LangGraph: howtos/integrations/_langgraph_agent_evaluation.md - Migrations: - From v0.1 to v0.2: howtos/migrations/migrate_from_v01_to_v02.md - 📖 References: diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py index 94faa877a..e9e5b5bb8 100644 --- a/src/ragas/evaluation.py +++ b/src/ragas/evaluation.py @@ -7,9 +7,8 @@ from langchain_core.callbacks import BaseCallbackHandler, BaseCallbackManager from langchain_core.embeddings import Embeddings as LangchainEmbeddings from langchain_core.language_models import BaseLanguageModel as LangchainLLM - -from llama_index.core.base.llms.base import BaseLLM as LlamaIndexLLM from llama_index.core.base.embeddings.base import BaseEmbedding as LlamaIndexEmbedding +from llama_index.core.base.llms.base import BaseLLM as LlamaIndexLLM from ragas._analytics import EvaluationEvent, track, track_was_completed from ragas.callbacks import ChainType, RagasTracer, new_group @@ -61,7 +60,9 @@ def evaluate( dataset: t.Union[Dataset, EvaluationDataset], metrics: t.Optional[t.Sequence[Metric]] = None, llm: t.Optional[BaseRagasLLM | LangchainLLM | LlamaIndexLLM] = None, - embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings | LlamaIndexEmbedding] = None, + embeddings: t.Optional[ + BaseRagasEmbeddings | LangchainEmbeddings | LlamaIndexEmbedding + ] = None, callbacks: Callbacks = None, in_ci: bool = False, run_config: RunConfig = RunConfig(), diff --git a/src/ragas/integrations/langgraph.py b/src/ragas/integrations/langgraph.py new file mode 100644 index 000000000..9e9828f2e --- /dev/null +++ b/src/ragas/integrations/langgraph.py @@ -0,0 +1,85 @@ +import json +from typing import List, Union + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage + +import ragas.messages as r + + +def convert_to_ragas_messages( + messages: List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]] +) -> List[Union[r.HumanMessage, r.AIMessage, r.ToolMessage]]: + """ + Convert LangChain messages into Ragas messages for agent evaluation. + + Parameters + ---------- + messages : List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]] + List of LangChain message objects to be converted. + + Returns + ------- + List[Union[r.HumanMessage, r.AIMessage, r.ToolMessage]] + List of corresponding Ragas message objects. + + Raises + ------ + ValueError + If an unsupported message type is encountered. + TypeError + If message content is not a string. 
+ + Notes + ----- + SystemMessages are skipped in the conversion process. + """ + + def _validate_string_content(message, message_type: str) -> str: + if not isinstance(message.content, str): + raise TypeError( + f"{message_type} content must be a string, got {type(message.content).__name__}. " + f"Content: {message.content}" + ) + return message.content + + MESSAGE_TYPE_MAP = { + HumanMessage: lambda m: r.HumanMessage( + content=_validate_string_content(m, "HumanMessage") + ), + ToolMessage: lambda m: r.ToolMessage( + content=_validate_string_content(m, "ToolMessage") + ), + } + + def _extract_tool_calls(message: AIMessage) -> List[r.ToolCall]: + tool_calls = message.additional_kwargs.get("tool_calls", []) + return [ + r.ToolCall( + name=tool_call["function"]["name"], + args=json.loads(tool_call["function"]["arguments"]), + ) + for tool_call in tool_calls + ] + + def _convert_ai_message(message: AIMessage) -> r.AIMessage: + tool_calls = _extract_tool_calls(message) if message.additional_kwargs else None + return r.AIMessage( + content=_validate_string_content(message, "AIMessage"), + tool_calls=tool_calls, + ) + + def _convert_message(message): + if isinstance(message, SystemMessage): + return None # Skip SystemMessages + if isinstance(message, AIMessage): + return _convert_ai_message(message) + converter = MESSAGE_TYPE_MAP.get(type(message)) + if converter is None: + raise ValueError(f"Unsupported message type: {type(message).__name__}") + return converter(message) + + return [ + converted + for message in messages + if (converted := _convert_message(message)) is not None + ] diff --git a/src/ragas/metrics/_topic_adherence.py b/src/ragas/metrics/_topic_adherence.py index ae55ddffb..1737f7f5a 100644 --- a/src/ragas/metrics/_topic_adherence.py +++ b/src/ragas/metrics/_topic_adherence.py @@ -48,9 +48,7 @@ class TopicClassificationOutput(BaseModel): class TopicClassificationPrompt( PydanticPrompt[TopicClassificationInput, TopicClassificationOutput] ): - instruction = ( - "Given a set of topics classify if the topic falls into any of the given reference topics." - ) + instruction = "Given a set of topics classify if the topic falls into any of the given reference topics." 
input_model = TopicClassificationInput output_model = TopicClassificationOutput examples = [ @@ -149,10 +147,14 @@ class TopicAdherenceScore(MetricWithLLM, MultiTurnMetric): topic_classification_prompt: PydanticPrompt = TopicClassificationPrompt() topic_refused_prompt: PydanticPrompt = TopicRefusedPrompt() - async def _multi_turn_ascore(self, sample: MultiTurnSample, callbacks: Callbacks) -> float: + async def _multi_turn_ascore( + self, sample: MultiTurnSample, callbacks: Callbacks + ) -> float: assert self.llm is not None, "LLM must be set" assert isinstance(sample.user_input, list), "Sample user_input must be a list" - assert isinstance(sample.reference_topics, list), "Sample reference_topics must be a list" + assert isinstance( + sample.reference_topics, list + ), "Sample reference_topics must be a list" user_input = sample.pretty_repr() prompt_input = TopicExtractionInput(user_input=user_input) @@ -168,7 +170,9 @@ async def _multi_turn_ascore(self, sample: MultiTurnSample, callbacks: Callbacks data=prompt_input, llm=self.llm, callbacks=callbacks ) topic_answered_verdict.append(response.refused_to_answer) - topic_answered_verdict = np.array([not answer for answer in topic_answered_verdict]) + topic_answered_verdict = np.array( + [not answer for answer in topic_answered_verdict] + ) prompt_input = TopicClassificationInput( reference_topics=sample.reference_topics, topics=topics diff --git a/tests/unit/test_langgraph.py b/tests/unit/test_langgraph.py new file mode 100644 index 000000000..9d94080e4 --- /dev/null +++ b/tests/unit/test_langgraph.py @@ -0,0 +1,129 @@ +import json + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage + +import ragas.messages as r +from ragas.integrations.langgraph import convert_to_ragas_messages + + +def test_human_message_conversion(): + """Test conversion of HumanMessage with valid string content""" + messages = [ + HumanMessage(content="Hello, add 4 and 5"), + ToolMessage(content="9", tool_call_id="1"), + ] + result = convert_to_ragas_messages(messages) + + assert len(result) == 2 + assert isinstance(result[0], r.HumanMessage) + assert result[0].content == "Hello, add 4 and 5" + + +def test_human_message_invalid_content(): + """Test HumanMessage with invalid content type raises TypeError""" + messages = [HumanMessage(content=["invalid", "content"])] + + with pytest.raises(TypeError) as exc_info: + convert_to_ragas_messages(messages) + assert "HumanMessage content must be a string" in str(exc_info.value) + + +def test_ai_message_conversion(): + """Test conversion of AIMessage with valid string content""" + messages = [AIMessage(content="I'm doing well, thanks!")] + result = convert_to_ragas_messages(messages) + + assert len(result) == 1 + assert isinstance(result[0], r.AIMessage) + assert result[0].content == "I'm doing well, thanks!" 
+ assert result[0].tool_calls is None + + +def test_ai_message_with_tool_calls(): + """Test conversion of AIMessage with tool calls""" + + tool_calls = [ + { + "function": { + "arguments": '{"metal_name": "gold"}', + "name": "get_metal_price", + } + }, + { + "function": { + "arguments": '{"metal_name": "silver"}', + "name": "get_metal_price", + } + }, + ] + + messages = [ + AIMessage( + content="Find the difference in the price of gold and silver?", + additional_kwargs={"tool_calls": tool_calls}, + ) + ] + + result = convert_to_ragas_messages(messages) + assert len(result) == 1 + assert isinstance(result[0], r.AIMessage) + assert result[0].content == "Find the difference in the price of gold and silver?" + assert len(result[0].tool_calls) == 2 + assert result[0].tool_calls[0].name == "get_metal_price" + assert result[0].tool_calls[0].args == {"metal_name": "gold"} + assert result[0].tool_calls[1].name == "get_metal_price" + assert result[0].tool_calls[1].args == {"metal_name": "silver"} + + +def test_tool_message_conversion(): + """Test conversion of ToolMessage with valid string content""" + messages = [ + HumanMessage(content="Hello, add 4 and 5"), + ToolMessage(content="9", tool_call_id="2"), + ] + result = convert_to_ragas_messages(messages) + + assert len(result) == 2 + assert isinstance(result[1], r.ToolMessage) + assert result[1].content == "9" + + +def test_system_message_skipped(): + """Test that SystemMessages are properly skipped""" + messages = [SystemMessage(content="System prompt"), HumanMessage(content="Hello")] + result = convert_to_ragas_messages(messages) + + assert len(result) == 1 + assert isinstance(result[0], r.HumanMessage) + assert result[0].content == "Hello" + + +def test_unsupported_message_type(): + """Test that unsupported message types raise ValueError""" + + class CustomMessage: + content = "test" + + messages = [CustomMessage()] + + with pytest.raises(ValueError) as exc_info: + convert_to_ragas_messages(messages) + assert "Unsupported message type: CustomMessage" in str(exc_info.value) + + +def test_empty_message_list(): + """Test conversion of empty message list""" + messages = [] + result = convert_to_ragas_messages(messages) + assert result == [] + + +def test_invalid_tool_calls_json(): + """Test handling of invalid JSON in tool calls""" + tool_calls = [{"function": {"name": "search", "arguments": "invalid json"}}] + + messages = [AIMessage(content="Test", additional_kwargs={"tool_calls": tool_calls})] + + with pytest.raises(json.JSONDecodeError): + convert_to_ragas_messages(messages)
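For readers who want to try the new `ragas.integrations.langgraph` module outside the notebook, here is a minimal standalone sketch. It hand-builds a LangChain trace in the same shape the unit tests use (tool calls under `additional_kwargs["tool_calls"]` with JSON-encoded arguments), converts it with `convert_to_ragas_messages`, and scores it with `AgentGoalAccuracyWithReference`. It assumes an `OPENAI_API_KEY` is configured and uses `gpt-4o-mini` as the evaluator; the trace itself is illustrative, not captured agent output.

```python
# Minimal sketch: convert a hand-built LangChain trace into Ragas messages
# and score it with AgentGoalAccuracyWithReference.
# Assumes OPENAI_API_KEY is set; the trace below is illustrative, not real output.
import asyncio

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI

from ragas.dataset_schema import MultiTurnSample
from ragas.integrations.langgraph import convert_to_ragas_messages
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import AgentGoalAccuracyWithReference

# LangChain messages mirroring the shape used in the unit tests:
# tool calls live under additional_kwargs["tool_calls"] with JSON-encoded arguments.
trace = [
    HumanMessage(content="What is the price of 10 grams of silver?"),
    AIMessage(
        content="",
        additional_kwargs={
            "tool_calls": [
                {
                    "function": {
                        "name": "get_metal_price",
                        "arguments": '{"metal_name": "silver"}',
                    }
                }
            ]
        },
    ),
    ToolMessage(content="1.0523", tool_call_id="1"),
    AIMessage(content="10 grams of silver would cost about $10.52."),
]

# SystemMessages, if present, are dropped during conversion.
ragas_trace = convert_to_ragas_messages(trace)

sample = MultiTurnSample(
    user_input=ragas_trace,
    reference="Price of 10 grams of silver",
)

scorer = AgentGoalAccuracyWithReference()
scorer.llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"))

# In a script, run the async scorer with asyncio; in a notebook, await it directly.
print(asyncio.run(scorer.multi_turn_ascore(sample)))
```

Note that, as the tests above exercise, the converter raises `TypeError` for non-string message content and `json.JSONDecodeError` for tool-call arguments that are not valid JSON, so callers feeding it arbitrary traces may want to handle those cases explicitly.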