diff --git a/docs/agents.md b/docs/agents.md index 208d92a583..5b332e0b27 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -826,7 +826,7 @@ with capture_run_messages() as messages: # (2)! result = agent.run_sync('Please get me the volume of a box with size 6.') except UnexpectedModelBehavior as e: print('An error occurred:', e) - #> An error occurred: Tool exceeded max retries count of 1 + #> An error occurred: Tool 'calc_volume' exceeded max retries count of 1 print('cause:', repr(e.__cause__)) #> cause: ModelRetry('Please try again.') print('messages:', messages) diff --git a/docs/api/ext.md b/docs/api/ext.md new file mode 100644 index 0000000000..7f01b44d45 --- /dev/null +++ b/docs/api/ext.md @@ -0,0 +1,5 @@ +# `pydantic_ai.ext` + +::: pydantic_ai.ext.langchain + +::: pydantic_ai.ext.aci diff --git a/docs/api/output.md b/docs/api/output.md index 135ff597bc..bb584608c7 100644 --- a/docs/api/output.md +++ b/docs/api/output.md @@ -10,3 +10,4 @@ - PromptedOutput - TextOutput - StructuredDict + - DeferredToolCalls diff --git a/docs/api/toolsets.md b/docs/api/toolsets.md new file mode 100644 index 0000000000..8146864076 --- /dev/null +++ b/docs/api/toolsets.md @@ -0,0 +1,14 @@ +# `pydantic_ai.toolsets` + +::: pydantic_ai.toolsets + options: + members: + - AbstractToolset + - CombinedToolset + - DeferredToolset + - FilteredToolset + - FunctionToolset + - PrefixedToolset + - RenamedToolset + - PreparedToolset + - WrapperToolset diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 7f8c5fdd6a..15ef46f2e2 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -16,42 +16,54 @@ pip/uv-add "pydantic-ai-slim[mcp]" ## Usage -PydanticAI comes with two ways to connect to MCP servers: +PydanticAI comes with three ways to connect to MCP servers: -- [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport - [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] which connects to an MCP server using the [Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport +- [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport - [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] which runs the server as a subprocess and connects to it using the [stdio](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) transport -Examples of both are shown below; [mcp-run-python](run-python.md) is used as the MCP server in both examples. +Examples of all three are shown below; [mcp-run-python](run-python.md) is used as the MCP server in all examples. -### SSE Client +Each MCP server instance is a [toolset](../toolsets.md) and can be registered with an [`Agent`][pydantic_ai.Agent] using the `toolsets` argument. -[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. +You can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager to open and close connections to all registered servers (and in the case of stdio servers, start and stop the subprocesses) around the context where they'll be used in agent runs. You can also use [`async with server`][pydantic_ai.mcp.MCPServer.__aenter__] to manage the connection or subprocess of a specific server, for example if you'd like to use it with multiple agents. If you don't explicitly enter one of these context managers to set up the server, this will be done automatically when it's needed (e.g. to list the available tools or call a specific tool), but it's more efficient to do so around the entire context where you expect the servers to be used. + +### Streamable HTTP Client + +[`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] connects over HTTP using the +[Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport to a server. !!! note - [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not managed by PydanticAI. + [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be + running and accepting HTTP connections before running the agent. Running the server is not + managed by Pydantic AI. -The name "HTTP" is used since this implementation will be adapted in future to use the new -[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. +Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. -Before creating the SSE client, we need to run the server (docs [here](run-python.md)): +```python {title="streamable_http_server.py" py="3.10" dunder_name="not_main"} +from mcp.server.fastmcp import FastMCP -```bash {title="terminal (run sse server)"} -deno run \ - -N -R=node_modules -W=node_modules --node-modules-dir=auto \ - jsr:@pydantic/mcp-run-python sse +app = FastMCP() + +@app.tool() +def add(a: int, b: int) -> int: + return a + b + +if __name__ == '__main__': + app.run(transport='streamable-http') ``` -```python {title="mcp_sse_client.py" py="3.10"} -from pydantic_ai import Agent -from pydantic_ai.mcp import MCPServerSSE +Then we can create the client: -server = MCPServerSSE(url='http://localhost:3001/sse') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! +```python {title="mcp_streamable_http_client.py" py="3.10"} +from pydantic_ai import Agent +from pydantic_ai.mcp import MCPServerStreamableHTTP +server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -85,43 +97,34 @@ Will display as follows: ![Logfire run python code](../img/logfire-run-python-code.png) -### Streamable HTTP Client +### SSE Client -[`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] connects over HTTP using the -[Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport to a server. +[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. !!! note - [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be - running and accepting HTTP connections before calling - [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not - managed by PydanticAI. - -Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. - -```python {title="streamable_http_server.py" py="3.10" dunder_name="not_main"} -from mcp.server.fastmcp import FastMCP + [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI. -app = FastMCP() +The name "HTTP" is used since this implementation will be adapted in future to use the new +[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. -@app.tool() -def add(a: int, b: int) -> int: - return a + b +Before creating the SSE client, we need to run the server (docs [here](run-python.md)): -if __name__ == '__main__': - app.run(transport='streamable-http') +```bash {title="terminal (run sse server)"} +deno run \ + -N -R=node_modules -W=node_modules --node-modules-dir=auto \ + jsr:@pydantic/mcp-run-python sse ``` -Then we can create the client: - -```python {title="mcp_streamable_http_client.py" py="3.10"} +```python {title="mcp_sse_client.py" py="3.10"} from pydantic_ai import Agent -from pydantic_ai.mcp import MCPServerStreamableHTTP +from pydantic_ai.mcp import MCPServerSSE + +server = MCPServerSSE(url='http://localhost:3001/sse') # (1)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! -server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -137,9 +140,6 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. -!!! note - When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers] context manager is responsible for starting and stopping the server. - ```python {title="mcp_stdio_client.py" py="3.10"} from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio @@ -156,11 +156,11 @@ server = MCPServerStdio( # (1)! 'stdio', ] ) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -188,23 +188,23 @@ from pydantic_ai.tools import RunContext async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, - tool_name: str, - args: dict[str, Any], + name: str, + tool_args: dict[str, Any], ) -> ToolResult: """A tool call processor that passes along the deps.""" - return await call_tool(tool_name, args, metadata={'deps': ctx.deps}) + return await call_tool(name, tool_args, {'deps': ctx.deps}) server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_tool_call) agent = Agent( model=TestModel(call_tools=['echo_deps']), deps_type=int, - mcp_servers=[server] + toolsets=[server] ) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Echo with deps set to 42', deps=42) print(result.output) #> {"echo_deps":{"echo":"This is an echo message","deps":42}} @@ -214,15 +214,7 @@ async def main(): When connecting to multiple MCP servers that might provide tools with the same name, you can use the `tool_prefix` parameter to avoid naming conflicts. This parameter adds a prefix to all tool names from a specific server. -### How It Works - -- If `tool_prefix` is set, all tools from that server will be prefixed with `{tool_prefix}_` -- When listing tools, the prefixed names are shown to the model -- When calling tools, the prefix is automatically removed before sending the request to the server - -This allows you to use multiple servers that might have overlapping tool names without conflicts. - -### Example with HTTP Server +This allows you to use multiple servers that might have overlapping tool names without conflicts: ```python {title="mcp_tool_prefix_http_client.py" py="3.10"} from pydantic_ai import Agent @@ -242,41 +234,9 @@ calculator_server = MCPServerSSE( # Both servers might have a tool named 'get_data', but they'll be exposed as: # - 'weather_get_data' # - 'calc_get_data' -agent = Agent('openai:gpt-4o', mcp_servers=[weather_server, calculator_server]) -``` - -### Example with Stdio Server - -```python {title="mcp_tool_prefix_stdio_client.py" py="3.10"} -from pydantic_ai import Agent -from pydantic_ai.mcp import MCPServerStdio - -python_server = MCPServerStdio( - 'deno', - args=[ - 'run', - '-N', - 'jsr:@pydantic/mcp-run-python', - 'stdio', - ], - tool_prefix='py' # Tools will be prefixed with 'py_' -) - -js_server = MCPServerStdio( - 'node', - args=[ - 'run', - 'mcp-js-server.js', - 'stdio', - ], - tool_prefix='js' # Tools will be prefixed with 'js_' -) - -agent = Agent('openai:gpt-4o', mcp_servers=[python_server, js_server]) +agent = Agent('openai:gpt-4o', toolsets=[weather_server, calculator_server]) ``` -When the model interacts with these servers, it will see the prefixed tool names, but the prefixes will be automatically handled when making tool calls. - ## MCP Sampling !!! info "What is MCP Sampling?" @@ -312,6 +272,8 @@ Pydantic AI supports sampling as both a client and server. See the [server](./se Sampling is automatically supported by Pydantic AI agents when they act as a client. +To be able to use sampling, an MCP server instance needs to have a [`sampling_model`][pydantic_ai.mcp.MCPServerStdio.sampling_model] set. This can be done either directly on the server using the constructor keyword argument or the property, or by using [`agent.set_mcp_sampling_model()`][pydantic_ai.Agent.set_mcp_sampling_model] to set the agent's model or one specified as an argument as the sampling model on all MCP servers registered with that agent. + Let's say we have an MCP server that wants to use sampling (in this case to generate an SVG as per the tool arguments). ??? example "Sampling MCP Server" @@ -359,11 +321,12 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio(command='python', args=['generate_svg.py']) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: + agent.set_mcp_sampling_model() result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index 61f8eef35f..4e6e2ee795 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -69,7 +69,7 @@ agent = Agent(model) ## Custom Hugging Face client [`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] also accepts a custom -[`AsyncInferenceClient`][huggingface_hub.AsyncInferenceClient] client via the `hf_client` parameter, so you can customise +[`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) client via the `hf_client` parameter, so you can customise the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url` etc. as defined in the [Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client). diff --git a/docs/output.md b/docs/output.md index caa0c14b0f..2391a5efde 100644 --- a/docs/output.md +++ b/docs/output.md @@ -199,8 +199,8 @@ async def hand_off_to_sql_agent(ctx: RunContext, query: str) -> list[Row]: return output except UnexpectedModelBehavior as e: # Bubble up potentially retryable errors to the router agent - if (cause := e.__cause__) and hasattr(cause, 'tool_retry'): - raise ModelRetry(f'SQL agent failed: {cause.tool_retry.content}') from e + if (cause := e.__cause__) and isinstance(cause, ModelRetry): + raise ModelRetry(f'SQL agent failed: {cause.message}') from e else: raise @@ -276,6 +276,8 @@ In the default Tool Output mode, the output JSON schema of each output type (or If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.output.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. +To dynamically modify or filter the available output tools during an agent run, you can define an agent-wide `prepare_output_tools` function that will be called ahead of each step of a run. This function should be of type [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc], which takes the [`RunContext`][pydantic_ai.tools.RunContext] and a list of [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and returns a new list of tool definitions (or `None` to disable all tools for that step). This is analogous to the [`prepare_tools` function](tools.md#prepare-tools) for non-output tools. + ```python {title="tool_output.py"} from pydantic import BaseModel diff --git a/docs/testing.md b/docs/testing.md index ac7d2ea249..b40bb1dc9c 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -10,7 +10,7 @@ Unless you're really sure you know better, you'll probably want to follow roughl * If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/) * Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures * Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the usage, latency and variability of real LLM calls -* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace your model inside your application logic +* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace an agent's model, dependencies, or toolsets inside your application logic * Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models accidentally ### Unit testing with `TestModel` diff --git a/docs/tools.md b/docs/tools.md index 44133f5759..134a8f96ea 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -1,25 +1,30 @@ # Function Tools -Function tools provide a mechanism for models to retrieve extra information to help them generate a response. +Function tools provide a mechanism for models to perform actions and retrieve extra information to help them generate a response. -They're useful when you want to enable the model to take some action and use the result, when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. +They're useful when you want to enable the model to take some action and use the result, when it is impractical or impossible to put all the context an agent might need into the instructions, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. If you want a model to be able to call a function as its final action, without the result being sent back to the model, you can use an [output function](output.md#output-functions) instead. -!!! info "Function tools vs. RAG" - Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. - - The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58)) - There are a number of ways to register tools with an agent: * via the [`@agent.tool`][pydantic_ai.Agent.tool] decorator — for tools that need access to the agent [context][pydantic_ai.tools.RunContext] * via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext] * via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent` which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool] -## Registering Function Tools via Decorator +For more advanced use cases, the [toolsets](toolsets.md) feature lets you manage collections of tools (built by you or providd by an [MCP server](mcp/client.md) or other [third party](#third-party-tools)) and register them with an agent in one go via the [`toolsets`][pydantic_ai.Agent.__init__] keyword argument to `Agent`. + +!!! info "Function tools vs. RAG" + Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. + + The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58)) -`@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent context. +!!! info "Function Tools vs. Structured Outputs" + As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for [structured output](output.md) when using the default [tool output mode](output.md#tool-output), thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output. + +## Registering via Decorator {#registering-function-tools-via-decorator} + +`@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent [context][pydantic_ai.tools.RunContext]. Here's an example using both: @@ -58,7 +63,7 @@ print(dice_result.output) 1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model. 2. We pass the user's name as the dependency, to keep things simple we use just the name as a string as the dependency. -3. This tool doesn't need any context, it just returns a random number. You could probably use a dynamic system prompt in this case. +3. This tool doesn't need any context, it just returns a random number. You could probably use dynamic instructions in this case. 4. This tool needs the player's name, so it uses `RunContext` to access dependencies which are just the player's name in this case. 5. Run the agent, passing the player's name as the dependency. @@ -176,7 +181,7 @@ sequenceDiagram Note over Agent: Game session complete ``` -## Registering Function Tools via Agent Argument +## Registering via Agent Argument {#registering-function-tools-via-agent-argument} As well as using the decorators, we can register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to reuse tools, and can also give more fine-grained control over the tools. @@ -232,7 +237,7 @@ print(dice_result['b'].output) _(This example is complete, it can be run "as is")_ -## Function Tool Output +## Tool Output {#function-tool-output} Tools can return anything that Pydantic can serialize to JSON, as well as audio, video, image or document content depending on the types of [multi-modal input](input.md) the model supports: @@ -353,11 +358,7 @@ print(result.output) This separation allows you to provide rich context to the model while maintaining clean, structured return values for your application logic. -## Function Tools vs. Structured Outputs - -As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output. - -## Function tools and schema +## Tool Schema {#function-tools-and-schema} Function parameters are extracted from the function signature, and all parameters except `RunContext` are used to build the schema for that tool call. @@ -469,7 +470,9 @@ print(test_model.last_model_request_parameters.function_tools) _(This example is complete, it can be run "as is")_ -If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the `Tool.from_schema` function. With this you provide the name, description and JSON schema for the function directly: +### Custom Tool Schema + +If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the [`Tool.from_schema`][pydantic_ai.Tool.from_schema] function. With this you provide the name, description and JSON schema for the function directly: ```python from pydantic_ai import Agent, Tool @@ -505,7 +508,7 @@ print(result.output) Please note that validation of the tool arguments will not be performed, and this will pass all arguments as keyword arguments. -## Dynamic Function tools {#tool-prepare} +## Dynamic Tools {#tool-prepare} Tools can optionally be defined with another function: `prepare`, which is called at each step of a run to customize the definition of the tool passed to the model, or omit the tool completely from that step. @@ -606,14 +609,15 @@ print(test_model.last_model_request_parameters.function_tools) _(This example is complete, it can be run "as is")_ -## Agent-wide Dynamic Tool Preparation {#prepare-tools} +### Agent-wide Dynamic Tools {#prepare-tools} In addition to per-tool `prepare` methods, you can also define an agent-wide `prepare_tools` function. This function is called at each step of a run and allows you to filter or modify the list of all tool definitions available to the agent for that step. This is especially useful if you want to enable or disable multiple tools at once, or apply global logic based on the current context. The `prepare_tools` function should be of type [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc], which takes the [`RunContext`][pydantic_ai.tools.RunContext] and a list of [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and returns a new list of tool definitions (or `None` to disable all tools for that step). !!! note - The list of tool definitions passed to `prepare_tools` includes both regular tools and tools from any MCP servers attached to the agent. + The list of tool definitions passed to `prepare_tools` includes both regular function tools and tools from any [toolsets](toolsets.md) registered to the agent, but not [output tools](output.md#tool-output). + To modify output tools, you can set a `prepare_output_tools` function instead. Here's an example that makes all tools strict if the model is an OpenAI model: @@ -724,11 +728,11 @@ Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception ### MCP Tools {#mcp-tools} -See the [MCP Client](./mcp/client.md) documentation for how to use MCP servers with Pydantic AI. +See the [MCP Client](./mcp/client.md) documentation for how to use MCP servers with Pydantic AI as [toolsets](toolsets.md). ### LangChain Tools {#langchain-tools} -If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with Pydantic AI, you can use the `pydancic_ai.ext.langchain.tool_from_langchain` convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. +If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with Pydantic AI, you can use the [`tool_from_langchain`][pydantic_ai.ext.langchain.tool_from_langchain] convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. You will need to install the `langchain-community` package and any others required by the tool in question. @@ -740,6 +744,7 @@ from langchain_community.tools import DuckDuckGoSearchRun from pydantic_ai import Agent from pydantic_ai.ext.langchain import tool_from_langchain + search = DuckDuckGoSearchRun() search_tool = tool_from_langchain(search) @@ -755,9 +760,25 @@ print(result.output) 1. The release date of this game is the 30th of May 2025, which is after the knowledge cutoff for Gemini 2.0 (August 2024). +If you'd like to use multiple LangChain tools or a LangChain [toolkit](https://python.langchain.com/docs/concepts/tools/#toolkits), you can use the [`LangChainToolset`][pydantic_ai.ext.langchain.LangChainToolset] [toolset](toolsets.md) which takes a list of LangChain tools: + +```python {test="skip"} +from langchain_community.agent_toolkits import SlackToolkit + +from pydantic_ai import Agent +from pydantic_ai.ext.langchain import LangChainToolset + + +toolkit = SlackToolkit() +toolset = LangChainToolset(toolkit.get_tools()) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +# ... +``` + ### ACI.dev Tools {#aci-tools} -If you'd like to use a tool from the [ACI.dev tool library](https://www.aci.dev/tools) with Pydantic AI, you can use the `pydancic_ai.ext.aci.tool_from_aci` convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the ACI tool, and up to the ACI tool to raise an error if the arguments are invalid. +If you'd like to use a tool from the [ACI.dev tool library](https://www.aci.dev/tools) with Pydantic AI, you can use the [`tool_from_aci`][pydantic_ai.ext.aci.tool_from_aci] convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the ACI tool, and up to the ACI tool to raise an error if the arguments are invalid. You will need to install the `aci-sdk` package, set your ACI API key in the `ACI_API_KEY` environment variable, and pass your ACI "linked account owner ID" to the function. @@ -769,14 +790,15 @@ import os from pydantic_ai import Agent from pydantic_ai.ext.aci import tool_from_aci + tavily_search = tool_from_aci( 'TAVILY__SEARCH', - linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID') + linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), ) agent = Agent( 'google-gla:gemini-2.0-flash', - tools=[tavily_search] + tools=[tavily_search], ) result = agent.run_sync('What is the release date of Elden Ring Nightreign?') # (1)! @@ -785,3 +807,23 @@ print(result.output) ``` 1. The release date of this game is the 30th of May 2025, which is after the knowledge cutoff for Gemini 2.0 (August 2024). + +If you'd like to use multiple ACI.dev tools, you can use the [`ACIToolset`][pydantic_ai.ext.aci.ACIToolset] [toolset](toolsets.md) which takes a list of ACI tool names as well as the `linked_account_owner_id`: + +```python {test="skip"} +import os + +from pydantic_ai import Agent +from pydantic_ai.ext.aci import ACIToolset + + +toolset = ACIToolset( + [ + 'OPEN_WEATHER_MAP__CURRENT_WEATHER', + 'OPEN_WEATHER_MAP__FORECAST', + ], + linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), +) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +``` diff --git a/docs/toolsets.md b/docs/toolsets.md new file mode 100644 index 0000000000..fa7073798b --- /dev/null +++ b/docs/toolsets.md @@ -0,0 +1,633 @@ + +# Toolsets + +A toolset represents a collection of [tools](tools.md) that can be registered with an agent in one go. They can be reused by different agents, swapped out at runtime or during testing, and composed in order to dynamically filter which tools are available, modify tool definitions, or change tool execution behavior. A toolset can contain locally defined functions, depend on an external service to provide them, or implement custom logic to list available tools and handle them being called. + +Toolsets are used (among many other things) to define [MCP servers](mcp/client.md) available to an agent. Pydantic AI includes many kinds of toolsets which are described below, and you can define a [custom toolset](#building-a-custom-toolset) by inheriting from the [`AbstractToolset`][pydantic_ai.toolsets.AbstractToolset] class. + +The toolsets that will be available during an agent run can be specified in three different ways: + +* at agent construction time, via the [`toolsets`][pydantic_ai.Agent.__init__] keyword argument to `Agent` +* at agent run time, via the `toolsets` keyword argument to [`agent.run()`][pydantic_ai.Agent.run], [`agent.run_sync()`][pydantic_ai.Agent.run_sync], [`agent.run_stream()`][pydantic_ai.Agent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. These toolsets will be additional to those provided to the `Agent` constructor +* as a contextual override, via the `toolsets` keyword argument to the [`agent.override()`][pydantic_ai.Agent.iter] context manager. These toolsets will replace those provided at agent construction or run time during the life of the context manager + +```python {title="toolsets.py"} +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import FunctionToolset + + +def agent_tool(): + return "I'm registered directly on the agent" + + +def extra_tool(): + return "I'm passed as an extra tool for a specific run" + + +def override_tool(): + return "I override all other tools" + + +agent_toolset = FunctionToolset(tools=[agent_tool]) # (1)! +extra_toolset = FunctionToolset(tools=[extra_tool]) +override_toolset = FunctionToolset(tools=[override_tool]) + +test_model = TestModel() # (2)! +agent = Agent(test_model, toolsets=[agent_toolset]) + +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['agent_tool'] + +result = agent.run_sync('What tools are available?', toolsets=[extra_toolset]) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['agent_tool', 'extra_tool'] + +with agent.override(toolsets=[override_toolset]): + result = agent.run_sync('What tools are available?', toolsets=[extra_toolset]) # (3)! + print([t.name for t in test_model.last_model_request_parameters.function_tools]) + #> ['override_tool'] +``` + +1. The [`FunctionToolset`][pydantic_ai.toolsets.FunctionToolset] will be explained in detail in the next section. +2. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. +3. This `extra_toolset` will be ignored because we're inside an override context. + +_(This example is complete, it can be run "as is")_ + +## Function Toolset + +As the name suggests, a [`FunctionToolset`][pydantic_ai.toolsets.FunctionToolset] makes locally defined functions available as tools. + +Functions can be added as tools in three different ways: + +* via the [`@toolset.tool`][pydantic_ai.toolsets.FunctionToolset.tool] decorator +* via the [`tools`][pydantic_ai.toolsets.FunctionToolset.__init__] keyword argument to the constructor which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool] +* via the [`toolset.add_function()`][pydantic_ai.toolsets.FunctionToolset.add_function] and [`toolset.add_tool()`][pydantic_ai.toolsets.FunctionToolset.add_tool] methods which can take a plain function or an instance of [`Tool`][pydantic_ai.tools.Tool] respectively + +Functions registered in any of these ways can define an initial `ctx: RunContext` argument in order to receive the agent [context][pydantic_ai.tools.RunContext]. The `add_function()` and `add_tool()` methods can also be used from a tool function to dynamically register new tools during a run to be available in future run steps. + +```python {title="function_toolset.py"} +from datetime import datetime + +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import FunctionToolset + + +def temperature_celsius(city: str) -> float: + return 21.0 + + +def temperature_fahrenheit(city: str) -> float: + return 69.8 + + +weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit]) + + +@weather_toolset.tool +def conditions(ctx: RunContext, city: str) -> str: + if ctx.run_step % 2 == 0: + return "It's sunny" + else: + return "It's raining" + + +datetime_toolset = FunctionToolset() +datetime_toolset.add_function(lambda: datetime.now(), name='now') + +test_model = TestModel() # (1)! +agent = Agent(test_model) + +result = agent.run_sync('What tools are available?', toolsets=[weather_toolset]) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['temperature_celsius', 'temperature_fahrenheit', 'conditions'] + +result = agent.run_sync('What tools are available?', toolsets=[datetime_toolset]) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['now'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +## Toolset Composition + +Toolsets can be composed to dynamically filter which tools are available, modify tool definitions, or change tool execution behavior. Multiple toolsets can also be combined into one. + +### Combining Toolsets + +[`CombinedToolset`][pydantic_ai.toolsets.CombinedToolset] takes a list of toolsets and lets them be used as one. + +```python {title="combined_toolset.py" requires="function_toolset.py"} +from function_toolset import weather_toolset, datetime_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import CombinedToolset + + +combined_toolset = CombinedToolset([weather_toolset, datetime_toolset]) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[combined_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['temperature_celsius', 'temperature_fahrenheit', 'conditions', 'now'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Filtering Tools + +[`FilteredToolset`][pydantic_ai.toolsets.FilteredToolset] wraps a toolset and filters available tools ahead of each step of the run based on a user-defined function that is passed the agent [run context][pydantic_ai.tools.RunContext] and each tool's [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] and returns a boolean to indicate whether or not a given tool should be available. + +To easily chain different modifications, you can also call [`filtered()`][pydantic_ai.toolsets.AbstractToolset.filtered] on any toolset instead of directly constructing a `FilteredToolset`. + +```python {title="filtered_toolset.py" requires="function_toolset.py,combined_toolset.py"} +from combined_toolset import combined_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + +filtered_toolset = combined_toolset.filtered(lambda ctx, tool_def: 'fahrenheit' not in tool_def.name) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[filtered_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['weather_temperature_celsius', 'weather_conditions', 'datetime_now'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Prefixing Tool Names + +[`PrefixedToolset`][pydantic_ai.toolsets.PrefixedToolset] wraps a toolset and adds a prefix to each tool name to prevent tool name conflicts between different toolsets. + +To easily chain different modifications, you can also call [`prefixed()`][pydantic_ai.toolsets.AbstractToolset.prefixed] on any toolset instead of directly constructing a `PrefixedToolset`. + +```python {title="combined_toolset.py" requires="function_toolset.py"} +from function_toolset import weather_toolset, datetime_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import CombinedToolset + + +combined_toolset = CombinedToolset( + [ + weather_toolset.prefixed('weather'), + datetime_toolset.prefixed('datetime') + ] +) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[combined_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +""" +[ + 'weather_temperature_celsius', + 'weather_temperature_fahrenheit', + 'weather_conditions', + 'datetime_now', +] +""" +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Renaming Tools + +[`RenamedToolset`][pydantic_ai.toolsets.RenamedToolset] wraps a toolset and lets you rename tools using a dictionary mapping new names to original names. This is useful when the names provided by a toolset are ambiguous or would conflict with tools defined by other toolsets, but [prefixing them](#prefixing-tool-names) creates a name that is unnecessarily long or could be confusing to the model. + +To easily chain different modifications, you can also call [`renamed()`][pydantic_ai.toolsets.AbstractToolset.renamed] on any toolset instead of directly constructing a `RenamedToolset`. + +```python {title="renamed_toolset.py" requires="function_toolset.py,combined_toolset.py"} +from combined_toolset import combined_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + + +renamed_toolset = combined_toolset.renamed( + { + 'current_time': 'datetime_now', + 'temperature_celsius': 'weather_temperature_celsius', + 'temperature_fahrenheit': 'weather_temperature_fahrenheit' + } +) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[renamed_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +""" +['temperature_celsius', 'temperature_fahrenheit', 'weather_conditions', 'current_time'] +""" +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Preparing Tool Definitions + +[`PreparedToolset`][pydantic_ai.toolsets.PreparedToolset] lets you modify the entire list of available tools ahead of each step of the agent run using a user-defined function that takes the agent [run context][pydantic_ai.tools.RunContext] and a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition] and returns a list of modified `ToolDefinition`s. + +This is the toolset-specific equivalent of the [`prepare_tools`](tools.md#prepare-tools) argument to `Agent` that prepares all tool definitions registered to an agent across toolsets. + +Note that it is not possible to add or rename tools using `PreparedToolset`. Instead, you can use [`FunctionToolset.add_function()`](#function-toolset) or [`RenamedToolset`](#renaming-tools). + +To easily chain different modifications, you can also call [`prepared()`][pydantic_ai.toolsets.AbstractToolset.prepared] on any toolset instead of directly constructing a `PreparedToolset`. + +```python {title="prepared_toolset.py" requires="function_toolset.py,combined_toolset.py,renamed_toolset.py"} +from dataclasses import replace +from typing import Union + +from renamed_toolset import renamed_toolset + +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import ToolDefinition + +descriptions = { + 'temperature_celsius': 'Get the temperature in degrees Celsius', + 'temperature_fahrenheit': 'Get the temperature in degrees Fahrenheit', + 'weather_conditions': 'Get the current weather conditions', + 'current_time': 'Get the current time', +} + +async def add_descriptions(ctx: RunContext, tool_defs: list[ToolDefinition]) -> Union[list[ToolDefinition], None]: + return [ + replace(tool_def, description=description) + if (description := descriptions.get(tool_def.name, None)) + else tool_def + for tool_def + in tool_defs + ] + +prepared_toolset = renamed_toolset.prepared(add_descriptions) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[prepared_toolset]) +result = agent.run_sync('What tools are available?') +print(test_model.last_model_request_parameters.function_tools) +""" +[ + ToolDefinition( + name='temperature_celsius', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + description='Get the temperature in degrees Celsius', + ), + ToolDefinition( + name='temperature_fahrenheit', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + description='Get the temperature in degrees Fahrenheit', + ), + ToolDefinition( + name='weather_conditions', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + description='Get the current weather conditions', + ), + ToolDefinition( + name='current_time', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {}, + 'type': 'object', + }, + description='Get the current time', + ), +] +""" +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +### Wrapping a Toolset + +[`WrapperToolset`][pydantic_ai.toolsets.WrapperToolset] wraps another toolset and delegates all responsibility to it. + +To easily chain different modifications, you can also call [`wrap()`][pydantic_ai.toolsets.AbstractToolset.wrap] on any toolset instead of directly constructing an instance of (a subclass of) `WrapperToolset`. + +`WrapperToolset` is a no-op by default, but enables some useful abilities: + +#### Changing Tool Execution + +You can subclass `WrapperToolset` to change the wrapped toolset's tool execution behavior by overriding the [`call_tool()`][pydantic_ai.toolsets.AbstractToolset.call_tool] method. + +```python {title="logging_toolset.py" requires="function_toolset.py,combined_toolset.py,renamed_toolset.py,prepared_toolset.py"} +from typing_extensions import Any + +from prepared_toolset import prepared_toolset + +from pydantic_ai.agent import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import RunContext +from pydantic_ai.toolsets import WrapperToolset, ToolsetTool + +LOG = [] + +class LoggingToolset(WrapperToolset): + async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: + LOG.append(f'Calling tool {name!r} with args: {tool_args!r}') + try: + result = await super().call_tool(name, tool_args, ctx, tool) + LOG.append(f'Finished calling tool {name!r} with result: {result!r}') + except Exception as e: + LOG.append(f'Error calling tool {name!r}: {e}') + raise e + else: + return result + + +logging_toolset = prepared_toolset.wrap(LoggingToolset) + +agent = Agent(TestModel(), toolsets=[logging_toolset]) # (1)! +result = agent.run_sync('Call all the tools') +print(LOG) +""" +[ + "Calling tool 'temperature_celsius' with args: {'city': 'a'}", + "Calling tool 'temperature_fahrenheit' with args: {'city': 'a'}", + "Calling tool 'weather_conditions' with args: {'city': 'a'}", + "Calling tool 'current_time' with args: {}", + "Finished calling tool 'temperature_celsius' with result: 21.0", + "Finished calling tool 'temperature_fahrenheit' with result: 69.8", + 'Finished calling tool \'weather_conditions\' with result: "It\'s raining"', + "Finished calling tool 'current_time' with result: datetime.datetime(...)", +] +""" +``` + +1. We use [`TestModel`][pydantic_ai.models.test.TestModel] here as it will automatically call each tool. + +_(This example is complete, it can be run "as is")_ + +#### Modifying Toolsets During a Run + +You can change the `WrapperToolset`'s `wrapped` property during an agent run to swap out one toolset for another starting at the next run step. + +To add or remove available toolsets, you can wrap a [`CombinedToolset`](#combining-toolsets) and replace it during the run with one that can include fewer, more, or entirely different toolsets. + +```python {title="wrapper_toolset.py" requires="function_toolset.py"} +from function_toolset import weather_toolset, datetime_toolset + +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import WrapperToolset, FunctionToolset + +togglable_toolset = WrapperToolset(weather_toolset) + +def toggle(ctx: RunContext[WrapperToolset]): + if ctx.deps.wrapped == weather_toolset: + ctx.deps.wrapped = datetime_toolset + else: + ctx.deps.wrapped = weather_toolset + +test_model = TestModel() # (1)! +agent = Agent( + test_model, + deps_type=WrapperToolset, # (2)! + toolsets=[togglable_toolset, FunctionToolset([toggle])] +) +result = agent.run_sync('Toggle the toolset', deps=togglable_toolset) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) # (3)! +#> ['now', 'toggle'] + +result = agent.run_sync('Toggle the toolset', deps=togglable_toolset) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['temperature_celsius', 'temperature_fahrenheit', 'conditions', 'toggle'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. +2. We're using the agent's dependencies to give the `toggle` tool access to the `togglable_toolset` via the `RunContext` argument. +3. This shows the available tools _after_ the `toggle` tool was executed, as the "last model request" was the one that returned the `toggle` tool result to the model. + +## Building a Custom Toolset + +To define a fully custom toolset with its own logic to list available tools and handle them being called, you can subclass [`AbstractToolset`][pydantic_ai.toolsets.AbstractToolset] and implement the [`get_tools()`][pydantic_ai.toolsets.AbstractToolset.get_tools] and [`call_tool()`][pydantic_ai.toolsets.AbstractToolset.call_tool] methods. + +If you want to reuse a network connection or session across tool listings and calls during an agent run step, you can implement [`__aenter__()`][pydantic_ai.toolsets.AbstractToolset.__aenter__] and [`__aexit__()`][pydantic_ai.toolsets.AbstractToolset.__aexit__], which will be called when the agent that uses the toolset is itself entered using the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager. + +### Deferred Toolset + +A deferred tool is one that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via a protocol like [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). + +!!! note + This is not typically something you need to bother with, unless you are implementing support for such a protocol between an upstream tool provider and Pydantic AI. + +When the model calls a deferred tool, the agent run ends with a [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] object containing the deferred tool call names and arguments, which is expected to be returned to the upstream tool provider. This upstream service is then expected to generate a response for each tool call and start a new Pydantic AI agent run with the message history and new [`ToolReturnPart`s][pydantic_ai.messages.ToolReturnPart] corresponding to each deferred call, after which the run will continue. + +To enable an agent to call deferred tools, you create a [`DeferredToolset`][pydantic_ai.toolsets.DeferredToolset], pass it a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition], and provide it to the agent using one of the methods described above. Additionally, you need to add `DeferredToolCalls` to the `Agent`'s [output types](output.md#structured-output) so that the agent run's output type is correctly inferred. Finally, you should handle the possible `DeferredToolCalls` result by returning it to the upstream tool provider. + +If your agent can also be used in a context where no deferred tools are available, you will not want to include `DeferredToolCalls` in the `output_type` passed to the `Agent` constructor as you'd have to deal with that type everywhere you use the agent. Instead, you can pass the `toolsets` and `output_type` keyword arguments when you run the agent using [`agent.run()`][pydantic_ai.Agent.run], [`agent.run_sync()`][pydantic_ai.Agent.run_sync], [`agent.run_stream()`][pydantic_ai.Agent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. Note that while `toolsets` provided at this stage are additional to the toolsets provided to the constructor, the `output_type` overrides the one specified at construction time (for type inference reasons), so you'll need to include the original output types explicitly. + +To demonstrate, let us first define a simple agent _without_ deferred tools: + +```python {title="deferred_toolset_agent.py"} +from pydantic import BaseModel + +from pydantic_ai import Agent +from pydantic_ai.toolsets.function import FunctionToolset + +toolset = FunctionToolset() + + +@toolset.tool +def get_default_language(): + return 'en-US' + + +@toolset.tool +def get_user_name(): + return 'David' + + +class PersonalizedGreeting(BaseModel): + greeting: str + language_code: str + + +agent = Agent('openai:gpt-4o', toolsets=[toolset], output_type=PersonalizedGreeting) + +result = agent.run_sync('Greet the user in a personalized way') +print(repr(result.output)) +#> PersonalizedGreeting(greeting='Hello, David!', language_code='en-US') +``` + +Next, let's define an function for a hypothetical "run agent" API endpoint that can be called by the frontend and takes a list of messages to send to the model plus a dict of frontend tool names and descriptions. This is where `DeferredToolset` and `DeferredToolCalls` come in: + +```python {title="deferred_toolset_api.py" requires="deferred_toolset_agent.py"} +from deferred_toolset_agent import agent, PersonalizedGreeting + +from typing import Union + +from pydantic_ai.output import DeferredToolCalls +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets import DeferredToolset +from pydantic_ai.messages import ModelMessage + +def run_agent( + messages: list[ModelMessage] = [], frontend_tools: list[ToolDefinition] = {} +) -> tuple[Union[PersonalizedGreeting, DeferredToolCalls], list[ModelMessage]]: + deferred_toolset = DeferredToolset(frontend_tools) + result = agent.run_sync( + toolsets=[deferred_toolset], # (1)! + output_type=[agent.output_type, DeferredToolCalls], # (2)! + message_history=messages, # (3)! + ) + return result.output, result.new_messages() +``` + +1. As mentioned above, these `toolsets` are additional to those provided to the `Agent` constructor +2. As mentioned above, this `output_type` overrides the one provided to the `Agent` constructor, so we have to make sure to not lose it +3. We don't include an `user_prompt` keyword argument as we expect the frontend to provide it via `messages` + +Now, imagine that the code below is implemented on the frontend, and `run_agent` stands in for an API call to the backend that runs the agent. This is where we actually execute the deferred tool calls and start a new run with the new result included: + +```python {title="deferred_tools.py" requires="deferred_toolset_agent.py,deferred_toolset_api.py"} +from deferred_toolset_api import run_agent + +from pydantic_ai.messages import ModelMessage, ModelRequest, RetryPromptPart, ToolReturnPart, UserPromptPart +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.output import DeferredToolCalls + +frontend_tool_definitions = [ + ToolDefinition( + name='get_preferred_language', + parameters_json_schema={'type': 'object', 'properties': {'default_language': {'type': 'string'}}}, + description="Get the user's preferred language from their browser", + ) +] +def get_preferred_language(default_language: str) -> str: + return 'es-MX' # (1)! +frontend_tool_functions = {'get_preferred_language': get_preferred_language} + +messages: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart(content='Greet the user in a personalized way') + ] + ) +] + +final_output = None +while True: + output, new_messages = run_agent(messages, frontend_tool_definitions) + messages += new_messages + + if not isinstance(output, DeferredToolCalls): + final_output = output + break + + print(output.tool_calls) + """ + [ + ToolCallPart( + tool_name='get_preferred_language', + args={'default_language': 'en-US'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + """ + for tool_call in output.tool_calls: + if function := frontend_tool_functions.get(tool_call.tool_name): + part = ToolReturnPart( + tool_name=tool_call.tool_name, + content=function(**tool_call.args_as_dict()), + tool_call_id=tool_call.tool_call_id, + ) + else: + part = RetryPromptPart( + tool_name=tool_call.tool_name, + content=f'Unknown tool {tool_call.tool_name!r}', + tool_call_id=tool_call.tool_call_id, + ) + messages.append(ModelRequest(parts=[part])) + +print(repr(final_output)) +""" +PersonalizedGreeting(greeting='Hola, David! Espero que tengas un gran día!', language_code='es-MX') +""" +``` + +1. Imagine that this returns [`navigator.language`](https://developer.mozilla.org/en-US/docs/Web/API/Navigator/language) + +_(This example is complete, it can be run "as is")_ + +## Third-Party Toolsets + +### MCP Servers + +See the [MCP Client](./mcp/client.md) documentation for how to use MCP servers with Pydantic AI. + +### LangChain Tools {#langchain-tools} + +If you'd like to use tools or a [toolkit](https://python.langchain.com/docs/concepts/tools/#toolkits) from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with Pydantic AI, you can use the [`LangChainToolset`][pydantic_ai.ext.langchain.LangChainToolset] which takes a list of LangChain tools. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. + +You will need to install the `langchain-community` package and any others required by the tools in question. + +```python {test="skip"} +from langchain_community.agent_toolkits import SlackToolkit + +from pydantic_ai import Agent +from pydantic_ai.ext.langchain import LangChainToolset + + +toolkit = SlackToolkit() +toolset = LangChainToolset(toolkit.get_tools()) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +# ... +``` + +### ACI.dev Tools {#aci-tools} + +If you'd like to use tools from the [ACI.dev tool library](https://www.aci.dev/tools) with Pydantic AI, you can use the [`ACIToolset`][pydantic_ai.ext.aci.ACIToolset] [toolset](toolsets.md) which takes a list of ACI tool names as well as the `linked_account_owner_id`. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the ACI tool, and up to the ACI tool to raise an error if the arguments are invalid. + +You will need to install the `aci-sdk` package, set your ACI API key in the `ACI_API_KEY` environment variable, and pass your ACI "linked account owner ID" to the function. + +```python {test="skip"} +import os + +from pydantic_ai import Agent +from pydantic_ai.ext.aci import ACIToolset + + +toolset = ACIToolset( + [ + 'OPEN_WEATHER_MAP__CURRENT_WEATHER', + 'OPEN_WEATHER_MAP__FORECAST', + ], + linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), +) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +``` diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index 360ca23471..edd84ddb88 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -52,11 +52,11 @@ server = MCPServerStdio('deno', 'jsr:@pydantic/mcp-run-python', 'stdio', ]) -agent = Agent('claude-3-5-haiku-latest', mcp_servers=[server]) +agent = Agent('claude-3-5-haiku-latest', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025.w diff --git a/mkdocs.yml b/mkdocs.yml index a950d52c0c..fc6cd27999 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,9 +28,10 @@ nav: - models/cohere.md - models/groq.md - models/mistral.md + - models/huggingface.md - dependencies.md - tools.md - - common-tools.md + - toolsets.md - output.md - message-history.md - testing.md @@ -41,6 +42,7 @@ nav: - input.md - thinking.md - direct.md + - common-tools.md - MCP: - mcp/index.md - mcp/client.md @@ -64,6 +66,7 @@ nav: - API Reference: - api/agent.md - api/tools.md + - api/toolsets.md - api/common_tools.md - api/output.md - api/result.md @@ -75,6 +78,7 @@ nav: - api/format_as_xml.md - api/format_prompt.md - api/direct.md + - api/ext.md - api/models/base.md - api/models/openai.md - api/models/anthropic.md diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index fda19acda4..f6b4a51c3f 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -3,6 +3,7 @@ import asyncio import dataclasses import hashlib +from collections import defaultdict, deque from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar @@ -13,17 +14,18 @@ from typing_extensions import TypeGuard, TypeVar, assert_never from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore +from pydantic_ai._tool_manager import ToolManager from pydantic_ai._utils import is_async_callable, run_in_executor from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage +from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings -from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc +from .tools import RunContext, ToolDefinition, ToolKind if TYPE_CHECKING: - from .mcp import MCPServer from .models.instrumented import InstrumentationSettings __all__ = ( @@ -77,11 +79,13 @@ class GraphAgentState: retries: int run_step: int - def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None: + def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None: self.retries += 1 if self.retries > max_result_retries: - message = f'Exceeded maximum retries ({max_result_retries}) for result validation' + message = f'Exceeded maximum retries ({max_result_retries}) for output validation' if error: + if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None: + error = error.__cause__ raise exceptions.UnexpectedModelBehavior(message) from error else: raise exceptions.UnexpectedModelBehavior(message) @@ -108,15 +112,11 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): history_processors: Sequence[HistoryProcessor[DepsT]] - function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) - mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - default_retries: int + tool_manager: ToolManager[DepsT] tracer: Tracer instrumentation_settings: InstrumentationSettings | None = None - prepare_tools: ToolsPrepareFunc[DepsT] | None = None - class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): """The base class for all agent nodes. @@ -248,59 +248,27 @@ async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" - function_tool_defs_map: dict[str, ToolDefinition] = {} - run_context = build_run_context(ctx) - - async def add_tool(tool: Tool[DepsT]) -> None: - ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name) - if tool_def := await tool.prepare_tool_def(ctx): - # prepare_tool_def may change tool_def.name - if tool_def.name in function_tool_defs_map: - if tool_def.name != tool.name: - # Prepare tool def may have renamed the tool - raise exceptions.UserError( - f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool." - ) - else: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.') - function_tool_defs_map[tool_def.name] = tool_def - - async def add_mcp_server_tools(server: MCPServer) -> None: - if not server.is_running: - raise exceptions.UserError(f'MCP server is not running: {server}') - tool_defs = await server.list_tools() - for tool_def in tool_defs: - if tool_def.name in function_tool_defs_map: - raise exceptions.UserError( - f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts." - ) - function_tool_defs_map[tool_def.name] = tool_def - - await asyncio.gather( - *map(add_tool, ctx.deps.function_tools.values()), - *map(add_mcp_server_tools, ctx.deps.mcp_servers), - ) - function_tool_defs = list(function_tool_defs_map.values()) - if ctx.deps.prepare_tools: - # Prepare the tools using the provided function - # This also acts over tool definitions pulled from MCP servers - function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] + ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) output_schema = ctx.deps.output_schema - - output_tools = [] output_object = None - if isinstance(output_schema, _output.ToolOutputSchema): - output_tools = output_schema.tool_defs() - elif isinstance(output_schema, _output.NativeOutputSchema): + if isinstance(output_schema, _output.NativeOutputSchema): output_object = output_schema.object_def # ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema allow_text_output = isinstance(output_schema, _output.TextOutputSchema) + function_tools: list[ToolDefinition] = [] + output_tools: list[ToolDefinition] = [] + for tool_def in ctx.deps.tool_manager.tool_defs: + if tool_def.kind == 'output': + output_tools.append(tool_def) + else: + function_tools.append(tool_def) + return models.ModelRequestParameters( - function_tools=function_tool_defs, + function_tools=function_tools, output_mode=output_schema.mode, output_tools=output_tools, output_object=output_object, @@ -341,8 +309,8 @@ async def stream( ctx.deps.output_schema, ctx.deps.output_validators, build_run_context(ctx), - _output.build_trace_context(ctx), ctx.deps.usage_limits, + ctx.deps.tool_manager, ) yield agent_stream # In case the user didn't manually consume the full stream, ensure it is fully consumed here, @@ -438,7 +406,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]): _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) - _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] @@ -520,47 +487,30 @@ async def _handle_tool_calls( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: - output_schema = ctx.deps.output_schema run_context = build_run_context(ctx) - final_result: result.FinalResult[NodeRunEndT] | None = None - parts: list[_messages.ModelRequestPart] = [] - - # first, look for the output tool call - if isinstance(output_schema, _output.ToolOutputSchema): - for call, output_tool in output_schema.find_tool(tool_calls): - try: - trace_context = _output.build_trace_context(ctx) - result_data = await output_tool.process(call, run_context, trace_context) - result_data = await _validate_output(result_data, ctx, call) - except _output.ToolRetryError as e: - # TODO: Should only increment retry stuff once per node execution, not for each tool call - # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries, e) - parts.append(e.tool_retry) - else: - final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - break + output_parts: list[_messages.ModelRequestPart] = [] + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) - # Then build the other request parts based on end strategy - tool_responses: list[_messages.ModelRequestPart] = self._tool_responses async for event in process_function_tools( - tool_calls, - final_result and final_result.tool_name, - final_result and final_result.tool_call_id, - ctx, - tool_responses, + ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result ): yield event - if final_result: - self._next_node = self._handle_final_result(ctx, final_result, tool_responses) + if output_final_result: + final_result = output_final_result[0] + self._next_node = self._handle_final_result(ctx, final_result, output_parts) + elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls): + if not ctx.deps.output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + ) + final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None) + self._next_node = self._handle_final_result(ctx, final_result, output_parts) else: - if tool_responses: - parts.extend(tool_responses) instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest(parts=parts, instructions=instructions) + _messages.ModelRequest(parts=output_parts, instructions=instructions) ) def _handle_final_result( @@ -586,18 +536,18 @@ async def _handle_text_response( text = '\n\n'.join(texts) try: + run_context = build_run_context(ctx) if isinstance(output_schema, _output.TextOutputSchema): - run_context = build_run_context(ctx) - trace_context = _output.build_trace_context(ctx) - result_data = await output_schema.process(text, run_context, trace_context) + result_data = await output_schema.process(text, run_context) else: m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', ) - raise _output.ToolRetryError(m) + raise ToolRetryError(m) - result_data = await _validate_output(result_data, ctx, None) - except _output.ToolRetryError as e: + for validator in ctx.deps.output_validators: + result_data = await validator.validate(result_data, run_context) + except ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: @@ -612,6 +562,9 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT usage=ctx.state.usage, prompt=ctx.deps.prompt, messages=ctx.state.message_history, + tracer=ctx.deps.tracer, + trace_include_content=ctx.deps.instrumentation_settings is not None + and ctx.deps.instrumentation_settings.include_content, run_step=ctx.state.run_step, ) @@ -623,269 +576,210 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str: return hashlib.sha1(identifier).hexdigest()[:6] -async def process_function_tools( # noqa C901 +async def process_function_tools( # noqa: C901 + tool_manager: ToolManager[DepsT], tool_calls: list[_messages.ToolCallPart], - output_tool_name: str | None, - output_tool_call_id: str | None, + final_result: result.FinalResult[NodeRunEndT] | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], output_parts: list[_messages.ModelRequestPart], + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1), ) -> AsyncIterator[_messages.HandleResponseEvent]: """Process function (i.e., non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. - Because async iterators can't have return values, we use `output_parts` as an output argument. + Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments. """ - stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early' - output_schema = ctx.deps.output_schema - - # we rely on the fact that if we found a result, it's the first output tool in the last - found_used_output_tool = False - run_context = build_run_context(ctx) - - calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = [] + tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list) for call in tool_calls: - if ( - call.tool_name == output_tool_name - and call.tool_call_id == output_tool_call_id - and not found_used_output_tool - ): - found_used_output_tool = True - output_parts.append( - _messages.ToolReturnPart( + tool_def = tool_manager.get_tool_def(call.tool_name) + kind = tool_def.kind if tool_def else 'unknown' + tool_calls_by_kind[kind].append(call) + + # First, we handle output tool calls + for call in tool_calls_by_kind['output']: + if final_result: + if final_result.tool_call_id == call.tool_call_id: + part = _messages.ToolReturnPart( tool_name=call.tool_name, content='Final result processed.', tool_call_id=call.tool_call_id, ) - ) - elif tool := ctx.deps.function_tools.get(call.tool_name): - if stub_function_tools: - output_parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) else: - event = _messages.FunctionToolCallEvent(call) - yield event - calls_to_run.append((tool, call)) - elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx): - if stub_function_tools: - # TODO(Marcelo): We should add coverage for this part of the code. - output_parts.append( # pragma: no cover - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - event = _messages.FunctionToolCallEvent(call) - yield event - calls_to_run.append((mcp_tool, call)) - elif call.tool_name in output_schema.tools: - # if tool_name is in output_schema, it means we found a output tool but an error occurred in - # validation, we don't add another part here - if output_tool_name is not None: yield _messages.FunctionToolCallEvent(call) - if found_used_output_tool: - content = 'Output tool not used - a final result was already processed.' - else: - # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part - content = 'Output tool not used - result failed validation.' part = _messages.ToolReturnPart( tool_name=call.tool_name, - content=content, + content='Output tool not used - a final result was already processed.', tool_call_id=call.tool_call_id, ) yield _messages.FunctionToolResultEvent(part) - output_parts.append(part) - else: - yield _messages.FunctionToolCallEvent(call) - part = _unknown_tool(call.tool_name, call.tool_call_id, ctx) - yield _messages.FunctionToolResultEvent(part) output_parts.append(part) + else: + try: + result_data = await tool_manager.handle_call(call) + except exceptions.UnexpectedModelBehavior as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + raise e # pragma: no cover + except ToolRetryError as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + yield _messages.FunctionToolCallEvent(call) + output_parts.append(e.tool_retry) + yield _messages.FunctionToolResultEvent(e.tool_retry) + else: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Final result processed.', + tool_call_id=call.tool_call_id, + ) + output_parts.append(part) + final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - if not calls_to_run: - return - - user_parts: list[_messages.UserPromptPart] = [] + # Then, we handle function tool calls + calls_to_run: list[_messages.ToolCallPart] = [] + if final_result and ctx.deps.end_strategy == 'early': + output_parts.extend( + [ + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + for call in tool_calls_by_kind['function'] + ] + ) + else: + calls_to_run.extend(tool_calls_by_kind['function']) - include_content = ( - ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content - ) + # Then, we handle unknown tool calls + if tool_calls_by_kind['unknown']: + ctx.state.increment_retries(ctx.deps.max_result_retries) + calls_to_run.extend(tool_calls_by_kind['unknown']) - # Run all tool tasks in parallel - results_by_index: dict[int, _messages.ModelRequestPart] = {} - with ctx.deps.tracer.start_as_current_span( - 'running tools', - attributes={ - 'tools': [call.tool_name for _, call in calls_to_run], - 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', - }, - ): - tasks = [ - asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer, include_content), name=call.tool_name) - for tool, call in calls_to_run - ] - - pending = tasks - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - index = tasks.index(task) - result = task.result() - yield _messages.FunctionToolResultEvent(result) - - if isinstance(result, _messages.RetryPromptPart): - results_by_index[index] = result - elif isinstance(result, _messages.ToolReturnPart): - if isinstance(result.content, _messages.ToolReturn): - tool_return = result.content - if ( - isinstance(tool_return.return_value, _messages.MultiModalContentTypes) - or isinstance(tool_return.return_value, list) - and any( - isinstance(content, _messages.MultiModalContentTypes) - for content in tool_return.return_value # type: ignore - ) - ): - raise exceptions.UserError( - f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. " - f'Please use `content` instead.' - ) - result.content = tool_return.return_value # type: ignore - result.metadata = tool_return.metadata - if tool_return.content: - user_parts.append( - _messages.UserPromptPart( - content=list(tool_return.content), - timestamp=result.timestamp, - part_kind='user-prompt', - ) - ) - contents: list[Any] - single_content: bool - if isinstance(result.content, list): - contents = result.content # type: ignore - single_content = False - else: - contents = [result.content] - single_content = True - - processed_contents: list[Any] = [] - for content in contents: - if isinstance(content, _messages.ToolReturn): - raise exceptions.UserError( - f"{result.tool_name}'s return contains invalid nested ToolReturn objects. " - f'ToolReturn should be used directly.' - ) - elif isinstance(content, _messages.MultiModalContentTypes): - # Handle direct multimodal content - if isinstance(content, _messages.BinaryContent): - identifier = multi_modal_content_identifier(content.data) - else: - identifier = multi_modal_content_identifier(content.url) - - user_parts.append( - _messages.UserPromptPart( - content=[f'This is file {identifier}:', content], - timestamp=result.timestamp, - part_kind='user-prompt', - ) - ) - processed_contents.append(f'See file {identifier}') - else: - # Handle regular content - processed_contents.append(content) - - if single_content: - result.content = processed_contents[0] - else: - result.content = processed_contents + for call in calls_to_run: + yield _messages.FunctionToolCallEvent(call) - results_by_index[index] = result - else: - assert_never(result) + user_parts: list[_messages.UserPromptPart] = [] - # We append the results at the end, rather than as they are received, to retain a consistent ordering - # This is mostly just to simplify testing - for k in sorted(results_by_index): - output_parts.append(results_by_index[k]) + if calls_to_run: + # Run all tool tasks in parallel + parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} + with ctx.deps.tracer.start_as_current_span( + 'running tools', + attributes={ + 'tools': [call.tool_name for call in calls_to_run], + 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', + }, + ): + tasks = [ + asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name) + for call in calls_to_run + ] + + pending = tasks + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.index(task) + tool_result_part, extra_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_result_part) + + parts_by_index[index] = [tool_result_part, *extra_parts] + + # We append the results at the end, rather than as they are received, to retain a consistent ordering + # This is mostly just to simplify testing + for k in sorted(parts_by_index): + output_parts.extend(parts_by_index[k]) + + # Finally, we handle deferred tool calls + for call in tool_calls_by_kind['deferred']: + if final_result: + output_parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + yield _messages.FunctionToolCallEvent(call) output_parts.extend(user_parts) + if final_result: + output_final_result.append(final_result) -async def _tool_from_mcp_server( - tool_name: str, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> Tool[DepsT] | None: - """Call each MCP server to find the tool with the given name. - - Args: - tool_name: The name of the tool to find. - ctx: The current run context. - Returns: - The tool with the given name, or `None` if no tool with the given name is found. - """ - - async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any: - # There's no normal situation where the server will not be running at this point, we check just in case - # some weird edge case occurs. - if not server.is_running: # pragma: no cover - raise exceptions.UserError(f'MCP server is not running: {server}') - - if server.process_tool_call is not None: - result = await server.process_tool_call(ctx, server.call_tool, tool_name, args) - else: - result = await server.call_tool(tool_name, args) - - return result - - for server in ctx.deps.mcp_servers: - tools = await server.list_tools() - if tool_name in {tool.name for tool in tools}: # pragma: no branch - return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries) - return None +async def _call_function_tool( + tool_manager: ToolManager[DepsT], + tool_call: _messages.ToolCallPart, +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: + try: + tool_result = await tool_manager.handle_call(tool_call) + except ToolRetryError as e: + return (e.tool_retry, []) + + part = _messages.ToolReturnPart( + tool_name=tool_call.tool_name, + content=tool_result, + tool_call_id=tool_call.tool_call_id, + ) + extra_parts: list[_messages.ModelRequestPart] = [] + if isinstance(tool_result, _messages.ToolReturn): + if ( + isinstance(tool_result.return_value, _messages.MultiModalContentTypes) + or isinstance(tool_result.return_value, list) + and any( + isinstance(content, _messages.MultiModalContentTypes) + for content in tool_result.return_value # type: ignore + ) + ): + raise exceptions.UserError( + f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. ' + f'Please use `content` instead.' + ) -def _unknown_tool( - tool_name: str, - tool_call_id: str, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> _messages.RetryPromptPart: - ctx.state.increment_retries(ctx.deps.max_result_retries) - tool_names = list(ctx.deps.function_tools.keys()) + part.content = tool_result.return_value # type: ignore + part.metadata = tool_result.metadata + if tool_result.content: + extra_parts.append( + _messages.UserPromptPart( + content=list(tool_result.content), + part_kind='user-prompt', + ) + ) + else: - output_schema = ctx.deps.output_schema - if isinstance(output_schema, _output.ToolOutputSchema): - tool_names.extend(output_schema.tool_names()) + def process_content(content: Any) -> Any: + if isinstance(content, _messages.ToolReturn): + raise exceptions.UserError( + f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. ' + f'`ToolReturn` should be used directly.' + ) + elif isinstance(content, _messages.MultiModalContentTypes): + if isinstance(content, _messages.BinaryContent): + identifier = multi_modal_content_identifier(content.data) + else: + identifier = multi_modal_content_identifier(content.url) - if tool_names: - msg = f'Available tools: {", ".join(tool_names)}' - else: - msg = 'No tools available.' + extra_parts.append( + _messages.UserPromptPart( + content=[f'This is file {identifier}:', content], + part_kind='user-prompt', + ) + ) + return f'See file {identifier}' - return _messages.RetryPromptPart( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=f'Unknown tool name: {tool_name!r}. {msg}', - ) + return content + if isinstance(tool_result, list): + contents = cast(list[Any], tool_result) + part.content = [process_content(content) for content in contents] + else: + part.content = process_content(tool_result) -async def _validate_output( - result_data: T, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], - tool_call: _messages.ToolCallPart | None, -) -> T: - for validator in ctx.deps.output_validators: - run_context = build_run_context(ctx) - result_data = await validator.validate(result_data, tool_call, run_context) - return result_data + return (part, extra_parts) @dataclasses.dataclass diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index c3199dd95c..c8925b0678 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,24 +1,21 @@ from __future__ import annotations as _annotations -import dataclasses import inspect import json from abc import ABC, abstractmethod -from collections.abc import Awaitable, Iterable, Iterator, Sequence +from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload -from opentelemetry.trace import Tracer from pydantic import TypeAdapter, ValidationError -from pydantic_core import SchemaValidator -from typing_extensions import TypedDict, TypeVar, assert_never - -from pydantic_graph.nodes import GraphRunContext +from pydantic_core import SchemaValidator, to_json +from typing_extensions import Self, TypedDict, TypeVar, assert_never from . import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UserError +from .exceptions import ModelRetry, ToolRetryError, UserError from .output import ( + DeferredToolCalls, NativeOutput, OutputDataT, OutputMode, @@ -29,12 +26,12 @@ TextOutput, TextOutputFunc, ToolOutput, + _OutputSpecItem, # type: ignore[reportPrivateUsage] ) from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition +from .toolsets.abstract import AbstractToolset, ToolsetTool if TYPE_CHECKING: - from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState - from .profiles import ModelProfile T = TypeVar('T') @@ -72,77 +69,45 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' -@dataclass(frozen=True) -class TraceContext: - """A context for tracing output processing.""" +async def execute_output_function_with_span( + function_schema: _function_schema.FunctionSchema, + run_context: RunContext[AgentDepsT], + args: dict[str, Any] | Any, +) -> Any: + """Execute a function call within a traced span, automatically recording the response.""" + # Set up span attributes + tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function') + attributes = { + 'gen_ai.tool.name': tool_name, + 'logfire.msg': f'running output function: {tool_name}', + } + if run_context.tool_call_id: + attributes['gen_ai.tool.call.id'] = run_context.tool_call_id + if run_context.trace_include_content: + attributes['tool_arguments'] = to_json(args).decode() + attributes['logfire.json_schema'] = json.dumps( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) - tracer: Tracer - include_content: bool - call: _messages.ToolCallPart | None = None + with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span: + output = await function_schema.call(args, run_context) - def with_call(self, call: _messages.ToolCallPart): - return dataclasses.replace(self, call=call) + # Record response if content inclusion is enabled + if run_context.trace_include_content and span.is_recording(): + from .models.instrumented import InstrumentedModel - async def execute_function_with_span( - self, - function_schema: _function_schema.FunctionSchema, - run_context: RunContext[AgentDepsT], - args: dict[str, Any] | Any, - call: _messages.ToolCallPart, - include_tool_call_id: bool = True, - ) -> Any: - """Execute a function call within a traced span, automatically recording the response.""" - # Set up span attributes - attributes = { - 'gen_ai.tool.name': call.tool_name, - 'logfire.msg': f'running output function: {call.tool_name}', - } - if include_tool_call_id: - attributes['gen_ai.tool.call.id'] = call.tool_call_id - if self.include_content: - attributes['tool_arguments'] = call.args_as_json_str() - attributes['logfire.json_schema'] = json.dumps( - { - 'type': 'object', - 'properties': { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - }, - } + span.set_attribute( + 'tool_response', + output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), ) - # Execute function within span - with self.tracer.start_as_current_span('running output function', attributes=attributes) as span: - output = await function_schema.call(args, run_context) - - # Record response if content inclusion is enabled - if self.include_content and span.is_recording(): - from .models.instrumented import InstrumentedModel - - span.set_attribute( - 'tool_response', - output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), - ) - - return output - - -def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext: - """Build a `TraceContext` from the current agent graph run context.""" - return TraceContext( - tracer=ctx.deps.tracer, - include_content=( - ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content - ), - ) - - -class ToolRetryError(Exception): - """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" - - def __init__(self, tool_retry: _messages.RetryPromptPart): - self.tool_retry = tool_retry - super().__init__() + return output @dataclass @@ -158,23 +123,21 @@ def __post_init__(self): async def validate( self, result: T, - tool_call: _messages.ToolCallPart | None, run_context: RunContext[AgentDepsT], + wrap_validation_errors: bool = True, ) -> T: """Validate a result but calling the function. Args: result: The result data after Pydantic validation the message content. - tool_call: The original tool call message, `None` if there was no tool call. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. + wrap_validation_errors: If true, wrap the validation errors in a retry message. Returns: Result of either the validated result data (ok) or a retry message (Err). """ if self._takes_ctx: - ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None) - args = ctx, result + args = run_context, result else: args = (result,) @@ -186,24 +149,32 @@ async def validate( function = cast(Callable[[Any], T], self.function) result_data = await _utils.run_in_executor(function, *args) except ModelRetry as r: - m = _messages.RetryPromptPart(content=r.message) - if tool_call is not None: - m.tool_name = tool_call.tool_name - m.tool_call_id = tool_call.tool_call_id - raise ToolRetryError(m) from r + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + tool_name=run_context.tool_name, + ) + if run_context.tool_call_id: # pragma: no cover + m.tool_call_id = run_context.tool_call_id + raise ToolRetryError(m) from r + else: + raise r else: return result_data +@dataclass class BaseOutputSchema(ABC, Generic[OutputDataT]): + allows_deferred_tool_calls: bool + @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: raise NotImplementedError() @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return {} + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return None @dataclass(init=False) @@ -235,7 +206,7 @@ def build( ) -> BaseOutputSchema[OutputDataT]: ... @classmethod - def build( + def build( # noqa: C901 cls, output_spec: OutputSpec[OutputDataT], *, @@ -245,117 +216,93 @@ def build( strict: bool | None = None, ) -> BaseOutputSchema[OutputDataT]: """Build an OutputSchema dataclass from an output type.""" - if output_spec is str: - return PlainTextOutputSchema() + raw_outputs = _flatten_output_spec(output_spec) + + outputs = [output for output in raw_outputs if output is not DeferredToolCalls] + allows_deferred_tool_calls = len(outputs) < len(raw_outputs) + if len(outputs) == 0 and allows_deferred_tool_calls: + raise UserError('At least one output type must be provided other than `DeferredToolCalls`.') + + if output := next((output for output in outputs if isinstance(output, NativeOutput)), None): + if len(outputs) > 1: + raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover - if isinstance(output_spec, NativeOutput): return NativeOutputSchema( - cls._build_processor( - _flatten_output_spec(output_spec.outputs), - name=output_spec.name, - description=output_spec.description, - strict=output_spec.strict, - ) + processor=cls._build_processor( + _flatten_output_spec(output.outputs), + name=output.name, + description=output.description, + strict=output.strict, + ), + allows_deferred_tool_calls=allows_deferred_tool_calls, ) - elif isinstance(output_spec, PromptedOutput): + elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None): + if len(outputs) > 1: + raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover + return PromptedOutputSchema( - cls._build_processor( - _flatten_output_spec(output_spec.outputs), - name=output_spec.name, - description=output_spec.description, + processor=cls._build_processor( + _flatten_output_spec(output.outputs), + name=output.name, + description=output.description, ), - template=output_spec.template, + template=output.template, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] - for output in _flatten_output_spec(output_spec): + for output in outputs: if output is str: text_outputs.append(cast(type[str], output)) elif isinstance(output, TextOutput): text_outputs.append(output) elif isinstance(output, ToolOutput): tool_outputs.append(output) + elif isinstance(output, NativeOutput): + # We can never get here because this is checked for above. + raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover + elif isinstance(output, PromptedOutput): + # We can never get here because this is checked for above. + raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover else: other_outputs.append(output) - tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) + toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict) if len(text_outputs) > 0: if len(text_outputs) > 1: - raise UserError('Only one text output is allowed.') + raise UserError('Only one `str` or `TextOutput` is allowed.') text_output = text_outputs[0] text_output_schema = None if isinstance(text_output, TextOutput): text_output_schema = PlainTextOutputProcessor(text_output.output_function) - if len(tools) == 0: - return PlainTextOutputSchema(text_output_schema) + if toolset: + return ToolOrTextOutputSchema( + processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls + ) else: - return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools) + return PlainTextOutputSchema( + processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls + ) if len(tool_outputs) > 0: - return ToolOutputSchema(tools) + return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) if len(other_outputs) > 0: schema = OutputSchemaWithoutMode( processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), - tools=tools, + toolset=toolset, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) if default_mode: schema = schema.with_default_mode(default_mode) return schema - raise UserError('No output type provided.') # pragma: no cover - - @staticmethod - def _build_tools( - outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], - name: str | None = None, - description: str | None = None, - strict: bool | None = None, - ) -> dict[str, OutputTool[OutputDataT]]: - tools: dict[str, OutputTool[OutputDataT]] = {} - - default_name = name or DEFAULT_OUTPUT_TOOL_NAME - default_description = description - default_strict = strict - - multiple = len(outputs) > 1 - for output in outputs: - name = None - description = None - strict = None - if isinstance(output, ToolOutput): - # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads - name = output.name - description = output.description - strict = output.strict - - output = output.output - - description = description or default_description - if strict is None: - strict = default_strict - - processor = ObjectOutputProcessor(output=output, description=description, strict=strict) - - if name is None: - name = default_name - if multiple: - name += f'_{processor.object_def.name}' - - i = 1 - original_name = name - while name in tools: - i += 1 - name = f'{original_name}_{i}' - - tools[name] = OutputTool(name=name, processor=processor, multiple=multiple) - - return tools + raise UserError('At least one output type must be provided.') @staticmethod def _build_processor( @@ -387,32 +334,39 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDa @dataclass(init=False) class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None def __init__( self, processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], - tools: dict[str, OutputTool[OutputDataT]], + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): + super().__init__(allows_deferred_tool_calls) self.processor = processor - self._tools = tools + self._toolset = toolset def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'native': - return NativeOutputSchema(self.processor) + return NativeOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'prompted': - return PromptedOutputSchema(self.processor) + return PromptedOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'tool': - return ToolOutputSchema(self.tools) + return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls) else: assert_never(mode) @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - # We return tools here as they're checked in Agent._register_tool. - # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time. - return self._tools + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + # We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor. + # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time, + # but we cover ourselves just in case we end up using the tool output mode. + return self._toolset class TextOutputSchema(OutputSchema[OutputDataT], ABC): @@ -421,7 +375,6 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -444,7 +397,6 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -453,7 +405,6 @@ async def process( Args: text: The output text to validate. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -464,7 +415,7 @@ async def process( return cast(OutputDataT, text) return await self.processor.process( - text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -486,13 +437,12 @@ def mode(self) -> OutputMode: def raise_if_unsupported(self, profile: ModelProfile) -> None: """Raise an error if the mode is not supported by the model.""" if not profile.supports_json_schema_output: - raise UserError('Structured output is not supported by the model.') + raise UserError('Native structured output is not supported by the model.') async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -501,7 +451,6 @@ async def process( Args: text: The output text to validate. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -509,7 +458,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ return await self.processor.process( - text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -545,7 +494,6 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -554,7 +502,6 @@ async def process( Args: text: The output text to validate. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -564,16 +511,17 @@ async def process( text = _utils.strip_markdown_fences(text) return await self.processor.process( - text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @dataclass(init=False) class ToolOutputSchema(OutputSchema[OutputDataT]): - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None - def __init__(self, tools: dict[str, OutputTool[OutputDataT]]): - self._tools = tools + def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool): + super().__init__(allows_deferred_tool_calls) + self._toolset = toolset @property def mode(self) -> OutputMode: @@ -585,36 +533,9 @@ def raise_if_unsupported(self, profile: ModelProfile) -> None: raise UserError('Output tools are not supported by the model.') @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return self._tools - - def tool_names(self) -> list[str]: - """Return the names of the tools.""" - return list(self.tools.keys()) - - def tool_defs(self) -> list[ToolDefinition]: - """Get tool definitions to register with the model.""" - return [t.tool_def for t in self.tools.values()] - - def find_named_tool( - self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: - """Find a tool that matches one of the calls, with a specific name.""" - for part in parts: # pragma: no branch - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if part.tool_name == tool_name: - return part, self.tools[tool_name] - - def find_tool( - self, - parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: - """Find a tool that matches one of the calls.""" - for part in parts: - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if result := self.tools.get(part.tool_name): - yield part, result + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return self._toolset @dataclass(init=False) @@ -622,10 +543,11 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem def __init__( self, processor: PlainTextOutputProcessor[OutputDataT] | None, - tools: dict[str, OutputTool[OutputDataT]], + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): + super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) self.processor = processor - self._tools = tools @property def mode(self) -> OutputMode: @@ -647,7 +569,6 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -659,7 +580,7 @@ async def process( class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]): object_def: OutputObjectDefinition outer_typed_dict_key: str | None = None - _validator: SchemaValidator + validator: SchemaValidator _function_schema: _function_schema.FunctionSchema | None = None def __init__( @@ -672,7 +593,7 @@ def __init__( ): if inspect.isfunction(output) or inspect.ismethod(output): self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema) - self._validator = self._function_schema.validator + self.validator = self._function_schema.validator json_schema = self._function_schema.json_schema json_schema['description'] = self._function_schema.description else: @@ -688,7 +609,7 @@ def __init__( type_adapter = TypeAdapter(response_data_typed_dict) # Really a PluggableSchemaValidator, but it's API-compatible - self._validator = cast(SchemaValidator, type_adapter.validator) + self.validator = cast(SchemaValidator, type_adapter.validator) json_schema = _utils.check_object_json_schema( type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) @@ -717,7 +638,6 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -726,7 +646,6 @@ async def process( Args: data: The output data to validate. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -734,11 +653,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(data, str): - output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) - else: - output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + output = self.validate(data, allow_partial) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -748,30 +663,40 @@ async def process( else: raise + try: + output = await self.call(output, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: no cover + + return output + + def validate( + self, + data: str | dict[str, Any] | None, + allow_partial: bool = False, + ) -> dict[str, Any]: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) + else: + return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + + async def call( + self, + output: Any, + run_context: RunContext[AgentDepsT], + ): if k := self.outer_typed_dict_key: output = output[k] if self._function_schema: - # Wraps the output function call in an OpenTelemetry span. - if trace_context.call: - call = trace_context.call - include_tool_call_id = True - else: - function_name = getattr(self._function_schema.function, '__name__', 'output_function') - call = _messages.ToolCallPart(tool_name=function_name, args=data) - include_tool_call_id = False - try: - output = await trace_context.execute_function_with_span( - self._function_schema, run_context, output, call, include_tool_call_id - ) - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - content=r.message, - ) - raise ToolRetryError(m) from r - else: - raise + output = await execute_output_function_with_span(self._function_schema, run_context, output) return output @@ -876,12 +801,11 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: union_object = await self._union_processor.process( - data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) result = union_object.result @@ -897,7 +821,7 @@ async def process( raise return await processor.process( - data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -928,20 +852,12 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: args = {self._str_argument_name: data} - # Wraps the output function call in an OpenTelemetry span. - # Note: PlainTextOutputProcessor is used for text responses (not tool calls), - # so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id - function_name = getattr(self._function_schema.function, '__name__', 'text_output_function') - call = _messages.ToolCallPart(tool_name=function_name, args=args) try: - output = await trace_context.execute_function_with_span( - self._function_schema, run_context, args, call, include_tool_call_id=False - ) + output = await execute_output_function_with_span(self._function_schema, run_context, args) except ModelRetry as r: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -955,91 +871,139 @@ async def process( @dataclass(init=False) -class OutputTool(Generic[OutputDataT]): - processor: ObjectOutputProcessor[OutputDataT] - tool_def: ToolDefinition +class OutputToolset(AbstractToolset[AgentDepsT]): + """A toolset that contains contains output tools for agent output types.""" - def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool): - self.processor = processor - object_def = processor.object_def + _tool_defs: list[ToolDefinition] + """The tool definitions for the output tools in this toolset.""" + processors: dict[str, ObjectOutputProcessor[Any]] + """The processors for the output tools in this toolset.""" + max_retries: int + output_validators: list[OutputValidator[AgentDepsT, Any]] - description = object_def.description - if not description: - description = DEFAULT_OUTPUT_TOOL_DESCRIPTION - if multiple: - description = f'{object_def.name}: {description}' + @classmethod + def build( + cls, + outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> Self | None: + if len(outputs) == 0: + return None - self.tool_def = ToolDefinition( - name=name, - description=description, - parameters_json_schema=object_def.json_schema, - strict=object_def.strict, - outer_typed_dict_key=processor.outer_typed_dict_key, - ) + processors: dict[str, ObjectOutputProcessor[Any]] = {} + tool_defs: list[ToolDefinition] = [] - async def process( - self, - tool_call: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - trace_context: TraceContext, - allow_partial: bool = False, - wrap_validation_errors: bool = True, - ) -> OutputDataT: - """Process an output message. + default_name = name or DEFAULT_OUTPUT_TOOL_NAME + default_description = description + default_strict = strict - Args: - tool_call: The tool call from the LLM to validate. - run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. - allow_partial: If true, allow partial validation. - wrap_validation_errors: If true, wrap the validation errors in a retry message. + multiple = len(outputs) > 1 + for output in outputs: + name = None + description = None + strict = None + if isinstance(output, ToolOutput): + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads + name = output.name + description = output.description + strict = output.strict - Returns: - Either the validated output data (left) or a retry message (right). - """ - try: - output = await self.processor.process( - tool_call.args, - run_context, - trace_context.with_call(tool_call), - allow_partial=allow_partial, - wrap_validation_errors=False, + output = output.output + + description = description or default_description + if strict is None: + strict = default_strict + + processor = ObjectOutputProcessor(output=output, description=description, strict=strict) + object_def = processor.object_def + + if name is None: + name = default_name + if multiple: + name += f'_{object_def.name}' + + i = 1 + original_name = name + while name in processors: + i += 1 + name = f'{original_name}_{i}' + + description = object_def.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION + if multiple: + description = f'{object_def.name}: {description}' + + tool_def = ToolDefinition( + name=name, + description=description, + parameters_json_schema=object_def.json_schema, + strict=object_def.strict, + outer_typed_dict_key=processor.outer_typed_dict_key, + kind='output', ) - except ValidationError as e: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from e - else: - raise # pragma: no cover - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=r.message, - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from r - else: - raise # pragma: no cover - else: - return output + processors[name] = processor + tool_defs.append(tool_def) + + return cls(processors=processors, tool_defs=tool_defs) + + def __init__( + self, + tool_defs: list[ToolDefinition], + processors: dict[str, ObjectOutputProcessor[Any]], + max_retries: int = 1, + output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None, + ): + self.processors = processors + self._tool_defs = tool_defs + self.max_retries = max_retries + self.output_validators = output_validators or [] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + tool_def.name: ToolsetTool( + toolset=self, + tool_def=tool_def, + max_retries=self.max_retries, + args_validator=self.processors[tool_def.name].validator, + ) + for tool_def in self._tool_defs + } + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + output = await self.processors[name].call(tool_args, ctx) + for validator in self.output_validators: + output = await validator.validate(output, ctx, wrap_validation_errors=False) + return output + + +@overload +def _flatten_output_spec( + output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]], +) -> Sequence[OutputTypeOrFunction[T]]: ... + + +@overload +def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ... -def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: - outputs: Sequence[T] +def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: + outputs: Sequence[OutputSpec[T]] if isinstance(output_spec, Sequence): outputs = output_spec else: outputs = (output_spec,) - outputs_flat: list[T] = [] + outputs_flat: list[_OutputSpecItem[T]] = [] for output in outputs: - if union_types := _utils.get_union_args(output): + if isinstance(output, Sequence): + outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output))) + elif union_types := _utils.get_union_args(output): outputs_flat.extend(union_types) else: - outputs_flat.append(output) + outputs_flat.append(cast(_OutputSpecItem[T], output)) return outputs_flat diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index bb7f474201..afad0e60e6 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -5,6 +5,7 @@ from dataclasses import field from typing import TYPE_CHECKING, Generic +from opentelemetry.trace import NoOpTracer, Tracer from typing_extensions import TypeVar from . import _utils, messages as _messages @@ -27,10 +28,16 @@ class RunContext(Generic[AgentDepsT]): """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" - prompt: str | Sequence[_messages.UserContent] | None + prompt: str | Sequence[_messages.UserContent] | None = None """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) """Messages exchanged in the conversation so far.""" + tracer: Tracer = field(default_factory=NoOpTracer) + """The tracer to use for tracing the run.""" + trace_include_content: bool = False + """Whether to include the content of the messages in the trace.""" + retries: dict[str, int] = field(default_factory=dict) + """Number of retries for each tool so far.""" tool_call_id: str | None = None """The ID of the tool call.""" tool_name: str | None = None @@ -40,17 +47,4 @@ class RunContext(Generic[AgentDepsT]): run_step: int = 0 """The current step in the run.""" - def replace_with( - self, - retry: int | None = None, - tool_name: str | None | _utils.Unset = _utils.UNSET, - ) -> RunContext[AgentDepsT]: - # Create a new `RunContext` a new `retry` value and `tool_name`. - kwargs = {} - if retry is not None: - kwargs['retry'] = retry - if tool_name is not _utils.UNSET: # pragma: no branch - kwargs['tool_name'] = tool_name - return dataclasses.replace(self, **kwargs) - __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py new file mode 100644 index 0000000000..bea4103896 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import json +from collections.abc import Iterable +from dataclasses import dataclass, replace +from typing import Any, Generic + +from pydantic import ValidationError +from typing_extensions import assert_never + +from pydantic_ai.output import DeferredToolCalls + +from . import messages as _messages +from ._run_context import AgentDepsT, RunContext +from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior +from .messages import ToolCallPart +from .tools import ToolDefinition +from .toolsets.abstract import AbstractToolset, ToolsetTool + + +@dataclass +class ToolManager(Generic[AgentDepsT]): + """Manages tools for an agent run step. It caches the agent run's toolset's tool definitions and handles calling tools and retries.""" + + ctx: RunContext[AgentDepsT] + """The agent run context for a specific run step.""" + toolset: AbstractToolset[AgentDepsT] + """The toolset that provides the tools for this run step.""" + tools: dict[str, ToolsetTool[AgentDepsT]] + """The cached tools for this run step.""" + + @classmethod + async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]: + """Build a new tool manager for a specific run step.""" + return cls( + ctx=ctx, + toolset=toolset, + tools=await toolset.get_tools(ctx), + ) + + async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]: + """Build a new tool manager for the next run step, carrying over the retries from the current run step.""" + return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries)) + + @property + def tool_defs(self) -> list[ToolDefinition]: + """The tool definitions for the tools in this tool manager.""" + return [tool.tool_def for tool in self.tools.values()] + + def get_tool_def(self, name: str) -> ToolDefinition | None: + """Get the tool definition for a given tool name, or `None` if the tool is unknown.""" + try: + return self.tools[name].tool_def + except KeyError: + return None + + async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any: + """Handle a tool call by validating the arguments, calling the tool, and handling retries. + + Args: + call: The tool call part to handle. + allow_partial: Whether to allow partial validation of the tool arguments. + """ + if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output': + # Output tool calls are not traced + return await self._call_tool(call, allow_partial) + else: + return await self._call_tool_traced(call, allow_partial) + + async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any: + name = call.tool_name + tool = self.tools.get(name) + try: + if tool is None: + if self.tools: + msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tools.keys())}' + else: + msg = 'No tools available.' + raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') + + ctx = replace( + self.ctx, + tool_name=name, + tool_call_id=call.tool_call_id, + retry=self.ctx.retries.get(name, 0), + ) + + pyd_allow_partial = 'trailing-strings' if allow_partial else 'off' + validator = tool.args_validator + if isinstance(call.args, str): + args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial) + else: + args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial) + + output = await self.toolset.call_tool(name, args_dict, ctx, tool) + except (ValidationError, ModelRetry) as e: + max_retries = tool.max_retries if tool is not None else 1 + current_retry = self.ctx.retries.get(name, 0) + + if current_retry == max_retries: + raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e + else: + if isinstance(e, ValidationError): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.errors(include_url=False, include_context=False), + tool_call_id=call.tool_call_id, + ) + e = ToolRetryError(m) + elif isinstance(e, ModelRetry): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.message, + tool_call_id=call.tool_call_id, + ) + e = ToolRetryError(m) + else: + assert_never(e) + + self.ctx.retries[name] = current_retry + 1 + raise e + else: + self.ctx.retries.pop(name, None) + return output + + async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any: + """See .""" + span_attributes = { + 'gen_ai.tool.name': call.tool_name, + # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai + 'gen_ai.tool.call.id': call.tool_call_id, + **({'tool_arguments': call.args_as_json_str()} if self.ctx.trace_include_content else {}), + 'logfire.msg': f'running tool: {call.tool_name}', + # add the JSON schema so these attributes are formatted nicely in Logfire + 'logfire.json_schema': json.dumps( + { + 'type': 'object', + 'properties': { + **( + { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + } + if self.ctx.trace_include_content + else {} + ), + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ), + } + with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span: + try: + tool_result = await self._call_tool(call, allow_partial) + except ToolRetryError as e: + part = e.tool_retry + if self.ctx.trace_include_content and span.is_recording(): + span.set_attribute('tool_response', part.model_response()) + raise e + + if self.ctx.trace_include_content and span.is_recording(): + span.set_attribute( + 'tool_response', + tool_result + if isinstance(tool_result, str) + else _messages.tool_return_ta.dump_json(tool_result).decode(), + ) + + return tool_result + + def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None: + """Get the deferred tool calls from the model response parts.""" + deferred_calls_and_defs = [ + (part, tool_def) + for part in parts + if isinstance(part, _messages.ToolCallPart) + and (tool_def := self.get_tool_def(part.tool_name)) + and tool_def.kind == 'deferred' + ] + if not deferred_calls_and_defs: + return None + + deferred_calls: list[_messages.ToolCallPart] = [] + deferred_tool_defs: dict[str, ToolDefinition] = {} + for part, tool_def in deferred_calls_and_defs: + deferred_calls.append(part) + deferred_tool_defs[part.tool_name] = tool_def + + return DeferredToolCalls(deferred_calls, deferred_tool_defs) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index d3f42a7ee9..88bb30ebe3 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -4,8 +4,10 @@ import functools import inspect import re +import sys import time import uuid +import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator from contextlib import asynccontextmanager, suppress from dataclasses import dataclass, fields, is_dataclass @@ -29,7 +31,7 @@ from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin -from pydantic_graph._utils import AbstractSpan +from pydantic_graph._utils import AbstractSpan, get_event_loop from . import exceptions @@ -461,3 +463,18 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return get_args(tp) else: return () + + +# The `asyncio.Lock` `loop` argument was deprecated in 3.8 and removed in 3.10, +# but 3.9 still needs it to have the intended behavior. + +if sys.version_info < (3, 10): + + def get_async_lock() -> asyncio.Lock: # pragma: lax no cover + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + return asyncio.Lock(loop=get_event_loop()) +else: + + def get_async_lock() -> asyncio.Lock: # pragma: lax no cover + return asyncio.Lock() diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3ff881294c..2b0aeb597e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -4,6 +4,7 @@ import inspect import json import warnings +from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar @@ -15,7 +16,6 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated -from pydantic_ai.profiles import ModelProfile from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -31,8 +31,11 @@ usage as _usage, ) from ._agent_graph import HistoryProcessor +from ._output import OutputToolset +from ._tool_manager import ToolManager from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from .output import OutputDataT, OutputSpec +from .profiles import ModelProfile from .result import FinalResult, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -48,6 +51,10 @@ ToolPrepareFunc, ToolsPrepareFunc, ) +from .toolsets import AbstractToolset +from .toolsets.combined import CombinedToolset +from .toolsets.function import FunctionToolset +from .toolsets.prepared import PreparedToolset # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -153,12 +160,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( repr=False ) + _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) + _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) + _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False) - _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - _default_retries: int = dataclasses.field(repr=False) + _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _enter_lock: Lock = dataclasses.field(repr=False) + _entered_count: int = dataclasses.field(repr=False) + _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) + @overload def __init__( self, @@ -177,7 +189,8 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -186,7 +199,7 @@ def __init__( @overload @deprecated( - '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' + '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' ) def __init__( self, @@ -207,6 +220,36 @@ def __init__( result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + ) -> None: ... + + @overload + @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', @@ -232,7 +275,8 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -258,14 +302,16 @@ def __init__( when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow before raising an error. - output_retries: The maximum number of retries to allow for result validation, defaults to `retries`. + output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. - prepare_tools: custom method to prepare the tool definition of all tools for each step. + prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] - for each server you want the agent to connect to. + prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. + This is useful if you want to customize the definition of multiple output tools or you want to register + a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + toolsets: Toolsets to register with the agent, including MCP servers. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` @@ -329,10 +375,17 @@ def __init__( ) output_retries = result_retries + if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): + if toolsets is not None: # pragma: no cover + raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') + warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) + toolsets = mcp_servers + + _utils.validate_empty_kwargs(_deprecated_kwargs) + default_output_mode = ( self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) - _utils.validate_empty_kwargs(_deprecated_kwargs) self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, @@ -357,21 +410,28 @@ def __init__( self._system_prompt_functions = [] self._system_prompt_dynamic_functions = {} - self._function_tools = {} - - self._default_retries = retries self._max_result_retries = output_retries if output_retries is not None else retries - self._mcp_servers = mcp_servers self._prepare_tools = prepare_tools + self._prepare_output_tools = prepare_output_tools + + self._output_toolset = self._output_schema.toolset + if self._output_toolset: + self._output_toolset.max_retries = self._max_result_retries + + self._function_toolset = FunctionToolset(tools, max_retries=retries) + self._user_toolsets = toolsets or () + self.history_processors = history_processors or [] - for tool in tools: - if isinstance(tool, Tool): - self._register_tool(tool) - else: - self._register_tool(Tool(tool)) self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) + self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( + '_override_toolsets', default=None + ) + + self._enter_lock = _utils.get_async_lock() + self._entered_count = 0 + self._exit_stack = None @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: @@ -391,6 +451,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -406,6 +467,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -422,6 +484,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -436,6 +499,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -466,6 +530,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. @@ -490,6 +555,7 @@ async def main(): model_settings=model_settings, usage_limits=usage_limits, usage=usage, + toolsets=toolsets, ) as agent_run: async for _ in agent_run: pass @@ -510,6 +576,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -526,6 +593,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -543,6 +611,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager @@ -558,6 +627,7 @@ async def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -632,6 +702,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. @@ -655,6 +726,18 @@ async def main(): output_type_ = output_type or self.output_type + # We consider it a user error if a user tries to restrict the result type while having an output validator that + # may change the result type from the restricted type to something else. Therefore, we consider the following + # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. + output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + + output_toolset = self._output_toolset + if output_schema != self._output_schema or output_validators: + output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) + if output_toolset: + output_toolset.max_retries = self._max_result_retries + output_toolset.output_validators = output_validators + # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) @@ -669,22 +752,32 @@ async def main(): run_step=0, ) - # We consider it a user error if a user tries to restrict the result type while having an output validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) - - # Merge model settings in order of precedence: run > agent > model - merged_settings = merge_model_settings(model_used.settings, self.model_settings) - model_settings = merge_model_settings(merged_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - if isinstance(model_used, InstrumentedModel): instrumentation_settings = model_used.instrumentation_settings tracer = model_used.instrumentation_settings.tracer else: instrumentation_settings = None tracer = NoOpTracer() + + run_context = RunContext[AgentDepsT]( + deps=deps, + model=model_used, + usage=usage, + prompt=user_prompt, + messages=state.message_history, + tracer=tracer, + trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, + run_step=state.run_step, + ) + + toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) + # This will raise errors for any name conflicts + run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context) + + # Merge model settings in order of precedence: run > agent > model + merged_settings = merge_model_settings(model_used.settings, self.model_settings) + model_settings = merge_model_settings(merged_settings, model_settings) + usage_limits = usage_limits or _usage.UsageLimits() agent_name = self.name or 'agent' run_span = tracer.start_span( 'agent run', @@ -711,10 +804,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: return None return '\n\n'.join(parts).strip() - # Copy the function tools so that retry state is agent-run-specific - # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`. - run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()} - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( user_deps=deps, prompt=user_prompt, @@ -727,11 +816,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: output_schema=output_schema, output_validators=output_validators, history_processors=self.history_processors, - function_tools=run_function_tools, - mcp_servers=self._mcp_servers, - default_retries=self._default_retries, + tool_manager=run_toolset, tracer=tracer, - prepare_tools=self._prepare_tools, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, ) @@ -801,6 +887,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -816,6 +903,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -832,6 +920,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -846,6 +935,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -875,6 +965,7 @@ def run_sync( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. @@ -901,6 +992,7 @@ def run_sync( usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, ) ) @@ -916,6 +1008,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -931,6 +1024,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @overload @@ -947,6 +1041,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -962,6 +1057,7 @@ async def run_stream( # noqa C901 usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -989,6 +1085,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. @@ -1019,6 +1116,7 @@ async def main(): usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node @@ -1039,15 +1137,17 @@ async def stream_to_final( output_schema, _output.TextOutputSchema ): return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart) and isinstance( - output_schema, _output.ToolOutputSchema - ): # pragma: no branch - for call, _ in output_schema.find_tool([new_part]): - return FinalResult(s, call.tool_name, call.tool_call_id) + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return FinalResult(s, new_part.tool_name, new_part.tool_call_id) + elif tool_def.kind == 'deferred': + return FinalResult(s, None, None) return None - final_result_details = await stream_to_final(streamed_response) - if final_result_details is not None: + final_result = await stream_to_final(streamed_response) + if final_result is not None: if yielded: raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover yielded = True @@ -1068,17 +1168,13 @@ async def on_complete() -> None: parts: list[_messages.ModelRequestPart] = [] async for _event in _agent_graph.process_function_tools( + graph_ctx.deps.tool_manager, tool_calls, - final_result_details.tool_name, - final_result_details.tool_call_id, + final_result, graph_ctx, parts, ): pass - # TODO: Should we do something here related to the retry count? - # Maybe we should move the incrementing of the retry count to where we actually make a request? - # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - # ctx.state.increment_retries(ctx.deps.max_result_retries) if parts: messages.append(_messages.ModelRequest(parts)) @@ -1089,10 +1185,10 @@ async def on_complete() -> None: streamed_response, graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), - _output.build_trace_context(graph_ctx), graph_ctx.deps.output_validators, - final_result_details.tool_name, + final_result.tool_name, on_complete, + graph_ctx.deps.tool_manager, ) break next_node = await agent_run.next(node) @@ -1111,8 +1207,9 @@ def override( *, deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies and model. + """Context manager to temporarily override agent dependencies, model, or toolsets. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -1120,6 +1217,7 @@ def override( Args: deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -1131,6 +1229,11 @@ def override( else: model_token = None + if _utils.is_set(toolsets): + toolsets_token = self._override_toolsets.set(_utils.Some(toolsets)) + else: + toolsets_token = None + try: yield finally: @@ -1138,6 +1241,8 @@ def override( self._override_deps.reset(deps_token) if model_token is not None: self._override_model.reset(model_token) + if toolsets_token is not None: + self._override_toolsets.reset(toolsets_token) @overload def instructions( @@ -1423,30 +1528,13 @@ async def spam(ctx: RunContext[str], y: float) -> float: strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. """ - if func is None: - def tool_decorator( - func_: ToolFuncContext[AgentDepsT, ToolParams], - ) -> ToolFuncContext[AgentDepsT, ToolParams]: - # noinspection PyTypeChecker - self._register_function( - func_, - True, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func_ - - return tool_decorator - else: + def tool_decorator( + func_: ToolFuncContext[AgentDepsT, ToolParams], + ) -> ToolFuncContext[AgentDepsT, ToolParams]: # noinspection PyTypeChecker - self._register_function( - func, + self._function_toolset.add_function( + func_, True, name, retries, @@ -1456,7 +1544,9 @@ def tool_decorator( schema_generator, strict, ) - return func + return func_ + + return tool_decorator if func is None else tool_decorator(func) @overload def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ... @@ -1532,27 +1622,11 @@ async def spam(ctx: RunContext[str]) -> float: strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. """ - if func is None: - - def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: - # noinspection PyTypeChecker - self._register_function( - func_, - False, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func_ - return tool_decorator - else: - self._register_function( - func, + def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: + # noinspection PyTypeChecker + self._function_toolset.add_function( + func_, False, name, retries, @@ -1562,48 +1636,9 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams schema_generator, strict, ) - return func - - def _register_function( - self, - func: ToolFuncEither[AgentDepsT, ToolParams], - takes_ctx: bool, - name: str | None, - retries: int | None, - prepare: ToolPrepareFunc[AgentDepsT] | None, - docstring_format: DocstringFormat, - require_parameter_descriptions: bool, - schema_generator: type[GenerateJsonSchema], - strict: bool | None, - ) -> None: - """Private utility to register a function as a tool.""" - retries_ = retries if retries is not None else self._default_retries - tool = Tool[AgentDepsT]( - func, - takes_ctx=takes_ctx, - name=name, - max_retries=retries_, - prepare=prepare, - docstring_format=docstring_format, - require_parameter_descriptions=require_parameter_descriptions, - schema_generator=schema_generator, - strict=strict, - ) - self._register_tool(tool) - - def _register_tool(self, tool: Tool[AgentDepsT]) -> None: - """Private utility to register a tool instance.""" - if tool.max_retries is None: - # noinspection PyTypeChecker - tool = dataclasses.replace(tool, max_retries=self._default_retries) + return func_ - if tool.name in self._function_tools: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') - - if tool.name in self._output_schema.tools: - raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}') - - self._function_tools[tool.name] = tool + return tool_decorator if func is None else tool_decorator(func) def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model: """Create a model configured for this agent. @@ -1649,6 +1684,37 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps + def _get_toolset( + self, + output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, + additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractToolset[AgentDepsT]: + """Get the complete toolset. + + Args: + output_toolset: The output toolset to use instead of the one built at agent construction time. + additional_toolsets: Additional toolsets to add. + """ + if some_user_toolsets := self._override_toolsets.get(): + user_toolsets = some_user_toolsets.value + elif additional_toolsets is not None: + user_toolsets = [*self._user_toolsets, *additional_toolsets] + else: + user_toolsets = self._user_toolsets + + all_toolsets = [self._function_toolset, *user_toolsets] + + if self._prepare_tools: + all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)] + + output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset + if output_toolset is not None: + if self._prepare_output_tools: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + all_toolsets = [output_toolset, *all_toolsets] + + return CombinedToolset(all_toolsets) + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. @@ -1734,28 +1800,68 @@ def is_end_node( """ return isinstance(node, End) + async def __aenter__(self) -> Self: + """Enter the agent context. + + This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used. + + This is a no-op if the agent has already been entered. + """ + async with self._enter_lock: + if self._entered_count == 0: + self._exit_stack = AsyncExitStack() + toolset = self._get_toolset() + await self._exit_stack.enter_async_context(toolset) + self._entered_count += 1 + return self + + async def __aexit__(self, *args: Any) -> bool | None: + async with self._enter_lock: + self._entered_count -= 1 + if self._entered_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + + def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None: + """Set the sampling model on all MCP servers registered with the agent. + + If no sampling model is provided, the agent's model will be used. + """ + try: + sampling_model = models.infer_model(model) if model else self._get_model(None) + except exceptions.UserError as e: + raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e + + from .mcp import MCPServer + + def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None: + if isinstance(toolset, MCPServer): + toolset.sampling_model = sampling_model + + self._get_toolset().apply(_set_sampling_model) + @asynccontextmanager + @deprecated( + '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.' + ) async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. + Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead. + If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model()`][pydantic_ai.agent.Agent.set_mcp_sampling_model]. + Returns: a context manager to start and shutdown the servers. """ try: - sampling_model: models.Model | None = self._get_model(model) - except exceptions.UserError: # pragma: no cover - sampling_model = None + self.set_mcp_sampling_model(model) + except exceptions.UserError: + if model is not None: + raise - exit_stack = AsyncExitStack() - try: - for mcp_server in self._mcp_servers: - if sampling_model is not None: # pragma: no branch - mcp_server.sampling_model = sampling_model - await exit_stack.enter_async_context(mcp_server) + async with self: yield - finally: - await exit_stack.aclose() def to_a2a( self, diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 3f57faaf8d..344ab94daf 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -2,12 +2,16 @@ import json import sys +from typing import TYPE_CHECKING if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup else: ExceptionGroup = ExceptionGroup +if TYPE_CHECKING: + from .messages import RetryPromptPart + __all__ = ( 'ModelRetry', 'UserError', @@ -113,3 +117,11 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None class FallbackExceptionGroup(ExceptionGroup): """A group of exceptions that can be raised when all fallback models fail.""" + + +class ToolRetryError(Exception): + """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" + + def __init__(self, tool_retry: RetryPromptPart): + self.tool_retry = tool_retry + super().__init__() diff --git a/pydantic_ai_slim/pydantic_ai/ext/aci.py b/pydantic_ai_slim/pydantic_ai/ext/aci.py index 5e5dc49366..6cd43402a1 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/aci.py +++ b/pydantic_ai_slim/pydantic_ai/ext/aci.py @@ -4,11 +4,13 @@ except ImportError as _import_error: raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error +from collections.abc import Sequence from typing import Any from aci import ACI -from pydantic_ai import Tool +from pydantic_ai.tools import Tool +from pydantic_ai.toolsets.function import FunctionToolset def _clean_schema(schema): @@ -22,10 +24,10 @@ def _clean_schema(schema): def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool: - """Creates a Pydantic AI tool proxy from an ACI function. + """Creates a Pydantic AI tool proxy from an ACI.dev function. Args: - aci_function: The ACI function to wrao. + aci_function: The ACI.dev function to wrap. linked_account_owner_id: The ACI user ID to execute the function on behalf of. Returns: @@ -64,3 +66,10 @@ def implementation(*args: Any, **kwargs: Any) -> str: description=function_description, json_schema=json_schema, ) + + +class ACIToolset(FunctionToolset): + """A toolset that wraps ACI.dev tools.""" + + def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str): + super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions]) diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 60db763f9d..3fb4079386 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -3,6 +3,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai.tools import Tool +from pydantic_ai.toolsets.function import FunctionToolset class LangChainTool(Protocol): @@ -23,7 +24,7 @@ def description(self) -> str: ... def run(self, *args: Any, **kwargs: Any) -> str: ... -__all__ = ('tool_from_langchain',) +__all__ = ('tool_from_langchain', 'LangChainToolset') def tool_from_langchain(langchain_tool: LangChainTool) -> Tool: @@ -59,3 +60,10 @@ def proxy(*args: Any, **kwargs: Any) -> str: description=function_description, json_schema=schema, ) + + +class LangChainToolset(FunctionToolset): + """A toolset that wraps LangChain tools.""" + + def __init__(self, tools: list[LangChainTool]): + super().__init__([tool_from_langchain(tool) for tool in tools]) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index e1fc10f29d..2ca7950b3e 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -3,11 +3,11 @@ import base64 import functools from abc import ABC, abstractmethod +from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field, replace from pathlib import Path -from types import TracebackType from typing import Any, Callable import anyio @@ -16,6 +16,11 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from typing_extensions import Self, assert_never, deprecated +from pydantic_ai._run_context import RunContext +from pydantic_ai.tools import ToolDefinition + +from .toolsets.abstract import AbstractToolset, ToolsetTool + try: from mcp import types as mcp_types from mcp.client.session import ClientSession, LoggingFnT @@ -32,12 +37,18 @@ ) from _import_error # after mcp imports so any import error maps to this file, not _mcp.py -from . import _mcp, exceptions, messages, models, tools +from . import _mcp, exceptions, messages, models __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP' +TOOL_SCHEMA_VALIDATOR = pydantic_core.SchemaValidator( + schema=pydantic_core.core_schema.dict_schema( + pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema() + ) +) + -class MCPServer(ABC): +class MCPServer(AbstractToolset[Any], ABC): """Base class for attaching agents to MCP servers. See for more information. @@ -50,15 +61,22 @@ class MCPServer(ABC): timeout: float = 5 process_tool_call: ProcessToolCallback | None = None allow_sampling: bool = True + max_retries: int = 1 + sampling_model: models.Model | None = None # } end of "abstract fields" - _running_count: int = 0 + _enter_lock: Lock = field(compare=False) + _running_count: int + _exit_stack: AsyncExitStack | None _client: ClientSession _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] - _exit_stack: AsyncExitStack - sampling_model: models.Model | None = None + + def __post_init__(self): + self._enter_lock = Lock() + self._running_count = 0 + self._exit_stack = None @abstractmethod @asynccontextmanager @@ -74,47 +92,36 @@ async def client_streams( raise NotImplementedError('MCP Server subclasses must implement this method.') yield - def get_prefixed_tool_name(self, tool_name: str) -> str: - """Get the tool name with prefix if `tool_prefix` is set.""" - return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name - - def get_unprefixed_tool_name(self, tool_name: str) -> str: - """Get original tool name without prefix for calling tools.""" - return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name + @property + def name(self) -> str: + return repr(self) @property - def is_running(self) -> bool: - """Check if the MCP server is running.""" - return bool(self._running_count) + def tool_name_conflict_hint(self) -> str: + return 'Consider setting `tool_prefix` to avoid name conflicts.' - async def list_tools(self) -> list[tools.ToolDefinition]: + async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. Note: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ - mcp_tools = await self._client.list_tools() - return [ - tools.ToolDefinition( - name=self.get_prefixed_tool_name(tool.name), - description=tool.description, - parameters_json_schema=tool.inputSchema, - ) - for tool in mcp_tools.tools - ] + async with self: # Ensure server is running + result = await self._client.list_tools() + return result.tools - async def call_tool( + async def direct_call_tool( self, - tool_name: str, - arguments: dict[str, Any], + name: str, + args: dict[str, Any], metadata: dict[str, Any] | None = None, ) -> ToolResult: """Call a tool on the server. Args: - tool_name: The name of the tool to call. - arguments: The arguments to pass to the tool. + name: The name of the tool to call. + args: The arguments to pass to the tool. metadata: Request-level metadata (optional) Returns: @@ -123,23 +130,23 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ - try: - # meta param is not provided by session yet, so build and can send_request directly. - result = await self._client.send_request( - mcp_types.ClientRequest( - mcp_types.CallToolRequest( - method='tools/call', - params=mcp_types.CallToolRequestParams( - name=self.get_unprefixed_tool_name(tool_name), - arguments=arguments, - _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, - ), - ) - ), - mcp_types.CallToolResult, - ) - except McpError as e: - raise exceptions.ModelRetry(e.error.message) + async with self: # Ensure server is running + try: + result = await self._client.send_request( + mcp_types.ClientRequest( + mcp_types.CallToolRequest( + method='tools/call', + params=mcp_types.CallToolRequestParams( + name=name, + arguments=args, + _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + ), + ) + ), + mcp_types.CallToolResult, + ) + except McpError as e: + raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] @@ -149,36 +156,80 @@ async def call_tool( else: return content[0] if len(content) == 1 else content - async def __aenter__(self) -> Self: - if self._running_count == 0: - self._exit_stack = AsyncExitStack() - - self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams()) - client = ClientSession( - read_stream=self._read_stream, - write_stream=self._write_stream, - sampling_callback=self._sampling_callback if self.allow_sampling else None, - logging_callback=self.log_handler, + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: RunContext[Any], + tool: ToolsetTool[Any], + ) -> ToolResult: + if self.tool_prefix: + name = name.removeprefix(f'{self.tool_prefix}_') + ctx = replace(ctx, tool_name=name) + + if self.process_tool_call is not None: + return await self.process_tool_call(ctx, self.direct_call_tool, name, tool_args) + else: + return await self.direct_call_tool(name, tool_args) + + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + return { + name: ToolsetTool( + toolset=self, + tool_def=ToolDefinition( + name=name, + description=mcp_tool.description, + parameters_json_schema=mcp_tool.inputSchema, + ), + max_retries=self.max_retries, + args_validator=TOOL_SCHEMA_VALIDATOR, ) - self._client = await self._exit_stack.enter_async_context(client) + for mcp_tool in await self.list_tools() + if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name) + } + + async def __aenter__(self) -> Self: + """Enter the MCP server context. + + This will initialize the connection to the server. + If this server is an [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio], the server will first be started as a subprocess. - with anyio.fail_after(self.timeout): - await self._client.initialize() + This is a no-op if the MCP server has already been entered. + """ + async with self._enter_lock: + if self._running_count == 0: + self._exit_stack = AsyncExitStack() + + self._read_stream, self._write_stream = await self._exit_stack.enter_async_context( + self.client_streams() + ) + client = ClientSession( + read_stream=self._read_stream, + write_stream=self._write_stream, + sampling_callback=self._sampling_callback if self.allow_sampling else None, + logging_callback=self.log_handler, + ) + self._client = await self._exit_stack.enter_async_context(client) + + with anyio.fail_after(self.timeout): + await self._client.initialize() - if log_level := self.log_level: - await self._client.set_logging_level(log_level) - self._running_count += 1 + if log_level := self.log_level: + await self._client.set_logging_level(log_level) + self._running_count += 1 return self - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> bool | None: - self._running_count -= 1 - if self._running_count <= 0: - await self._exit_stack.aclose() + async def __aexit__(self, *args: Any) -> bool | None: + async with self._enter_lock: + self._running_count -= 1 + if self._running_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + + @property + def is_running(self) -> bool: + """Check if the MCP server is running.""" + return bool(self._running_count) async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams @@ -271,10 +322,10 @@ class MCPServerStdio(MCPServer): 'stdio', ] ) - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -327,6 +378,12 @@ async def main(): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + + sampling_model: models.Model | None = None + """The model to use for sampling.""" + @asynccontextmanager async def client_streams( self, @@ -422,6 +479,12 @@ class _MCPServerHTTP(MCPServer): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + + sampling_model: models.Model | None = None + """The model to use for sampling.""" + @property @abstractmethod def _transport_client( @@ -503,10 +566,10 @@ class MCPServerSSE(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -537,10 +600,10 @@ class MCPServerHTTP(MCPServerSSE): from pydantic_ai.mcp import MCPServerHTTP server = MCPServerHTTP('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -566,10 +629,10 @@ class MCPServerStreamableHTTP(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` """ @@ -586,14 +649,14 @@ def _transport_client(self): | list[Any] | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] ) -"""The result type of a tool call.""" +"""The result type of an MCP tool call.""" CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]] """A function type that represents a tool call.""" ProcessToolCallback = Callable[ [ - tools.RunContext[Any], + RunContext[Any], CallToolFunc, str, dict[str, Any], diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 9dc7d2ef6b..9c41b535db 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -10,7 +10,8 @@ from typing_extensions import TypeAliasType, TypeVar from . import _utils -from .tools import RunContext +from .messages import ToolCallPart +from .tools import RunContext, ToolDefinition __all__ = ( # classes @@ -330,15 +331,17 @@ def __get_pydantic_json_schema__( return _StructuredDict +_OutputSpecItem = TypeAliasType( + '_OutputSpecItem', + Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]], + type_params=(T_co,), +) + OutputSpec = TypeAliasType( 'OutputSpec', Union[ - OutputTypeOrFunction[T_co], - ToolOutput[T_co], - NativeOutput[T_co], - PromptedOutput[T_co], - TextOutput[T_co], - Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], + _OutputSpecItem[T_co], + Sequence['OutputSpec[T_co]'], ], type_params=(T_co,), ) @@ -354,3 +357,11 @@ def __get_pydantic_json_schema__( See [output docs](../output.md) for more information. """ + + +@dataclass +class DeferredToolCalls: + """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.""" + + tool_calls: list[ToolCallPart] + tool_defs: dict[str, ToolDefinition] diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index f700482662..163189ac0b 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,11 +5,13 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import Generic +from typing import Generic, cast from pydantic import ValidationError from typing_extensions import TypeVar, deprecated, overload +from pydantic_ai._tool_manager import ToolManager + from . import _utils, exceptions, messages as _messages, models from ._output import ( OutputDataT_inv, @@ -19,7 +21,6 @@ PlainTextOutputSchema, TextOutputSchema, ToolOutputSchema, - TraceContext, ) from ._run_context import AgentDepsT, RunContext from .messages import AgentStreamEvent, FinalResultEvent @@ -47,8 +48,8 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _output_schema: OutputSchema[OutputDataT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] - _trace_ctx: TraceContext _usage_limits: UsageLimits | None + _toolset: ToolManager[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) _final_result_event: FinalResultEvent | None = field(default=None, init=False) @@ -97,37 +98,40 @@ async def _validate_response( self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + f'Invalid response, unable to find tool call for {output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, - self._run_ctx, - self._trace_ctx, - allow_partial=allow_partial, - wrap_validation_errors=False, - ) + return await self._toolset.handle_call(tool_call, allow_partial=allow_partial) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( # pragma: no cover + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, self._run_ctx) + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data - def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -145,13 +149,19 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" if isinstance(e, _messages.PartStartEvent): new_part = e.part - if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema): - for call, _ in output_schema.find_tool([new_part]): # pragma: no branch - return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id) - elif isinstance(new_part, _messages.TextPart) and isinstance( + if isinstance(new_part, _messages.TextPart) and isinstance( output_schema, TextOutputSchema ): # pragma: no branch return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := self._toolset.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return _messages.FinalResultEvent( + tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id + ) + elif tool_def.kind == 'deferred': + return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) usage_checking_stream = _get_usage_checking_stream_response( self._raw_stream_response, self._usage_limits, self.usage @@ -183,10 +193,10 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _stream_response: models.StreamedResponse _output_schema: OutputSchema[OutputDataT] _run_ctx: RunContext[AgentDepsT] - _trace_ctx: TraceContext _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] + _toolset: ToolManager[AgentDepsT] _initial_run_ctx_usage: Usage = field(init=False) is_complete: bool = field(default=False, init=False) @@ -420,40 +430,43 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + f'Invalid response, unable to find tool call for {self._output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, - self._run_ctx, - self._trace_ctx, - allow_partial=allow_partial, - wrap_validation_errors=False, - ) + return await self._toolset.handle_call(tool_call, allow_partial=allow_partial) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, self._run_ctx) # pragma: no cover + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover - return result_data - async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: - text = await validator.validate(text, None, self._run_ctx) # pragma: no cover + text = await validator.validate(text, self._run_ctx) # pragma: no cover return text async def _marked_completed(self, message: _messages.ModelResponse) -> None: diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index bbc8d83209..4243c02971 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,20 +1,15 @@ from __future__ import annotations as _annotations -import dataclasses -import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, Union -from opentelemetry.trace import Tracer -from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import SchemaValidator, core_schema from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar -from . import _function_schema, _utils, messages as _messages +from . import _function_schema, _utils from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UnexpectedModelBehavior __all__ = ( 'AgentDepsT', @@ -32,7 +27,6 @@ 'ToolDefinition', ) -from .messages import ToolReturnPart ToolParams = ParamSpec('ToolParams', default=...) """Retrieval function param spec.""" @@ -173,12 +167,6 @@ class Tool(Generic[AgentDepsT]): This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request. """ - # TODO: Consider moving this current_retry state to live on something other than the tool. - # We've worked around this for now by copying instances of the tool when creating new runs, - # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things - # up, though is also likely a larger effort to refactor. - current_retry: int = field(default=0, init=False) - def __init__( self, function: ToolFuncEither[AgentDepsT], @@ -303,6 +291,15 @@ def from_schema( function_schema=function_schema, ) + @property + def tool_def(self): + return ToolDefinition( + name=self.name, + description=self.description, + parameters_json_schema=self.function_schema.json_schema, + strict=self.strict, + ) + async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. @@ -312,113 +309,11 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition Returns: return a `ToolDefinition` or `None` if the tools should not be registered for this run. """ - tool_def = ToolDefinition( - name=self.name, - description=self.description, - parameters_json_schema=self.function_schema.json_schema, - strict=self.strict, - ) + base_tool_def = self.tool_def if self.prepare is not None: - return await self.prepare(ctx, tool_def) + return await self.prepare(ctx, base_tool_def) else: - return tool_def - - async def run( - self, - message: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - tracer: Tracer, - include_content: bool = False, - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - """Run the tool function asynchronously. - - This method wraps `_run` in an OpenTelemetry span. - - See . - """ - span_attributes = { - 'gen_ai.tool.name': self.name, - # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai - 'gen_ai.tool.call.id': message.tool_call_id, - **({'tool_arguments': message.args_as_json_str()} if include_content else {}), - 'logfire.msg': f'running tool: {self.name}', - # add the JSON schema so these attributes are formatted nicely in Logfire - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - **( - { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - } - if include_content - else {} - ), - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ), - } - with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: - response = await self._run(message, run_context) - if include_content and span.is_recording(): - span.set_attribute( - 'tool_response', - response.model_response_str() - if isinstance(response, ToolReturnPart) - else response.model_response(), - ) - - return response - - async def _run( - self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT] - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - try: - validator = self.function_schema.validator - if isinstance(message.args, str): - args_dict = validator.validate_json(message.args or '{}') - else: - args_dict = validator.validate_python(message.args or {}) - except ValidationError as e: - return self._on_error(e, message) - - ctx = dataclasses.replace( - run_context, - retry=self.current_retry, - tool_name=message.tool_name, - tool_call_id=message.tool_call_id, - ) - try: - response_content = await self.function_schema.call(args_dict, ctx) - except ModelRetry as e: - return self._on_error(e, message) - - self.current_retry = 0 - return _messages.ToolReturnPart( - tool_name=message.tool_name, - content=response_content, - tool_call_id=message.tool_call_id, - ) - - def _on_error( - self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart - ) -> _messages.RetryPromptPart: - self.current_retry += 1 - if self.max_retries is None or self.current_retry > self.max_retries: - raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc - else: - if isinstance(exc, ValidationError): - content = exc.errors(include_url=False, include_context=False) - else: - content = exc.message - return _messages.RetryPromptPart( - tool_name=call_message.tool_name, - content=content, - tool_call_id=call_message.tool_call_id, - ) + return base_tool_def ObjectJsonSchema: TypeAlias = dict[str, Any] @@ -429,6 +324,9 @@ def _on_error( With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any` """ +ToolKind: TypeAlias = Literal['function', 'output', 'deferred'] +"""Kind of tool.""" + @dataclass(repr=False) class ToolDefinition: @@ -440,7 +338,7 @@ class ToolDefinition: name: str """The name of the tool.""" - parameters_json_schema: ObjectJsonSchema + parameters_json_schema: ObjectJsonSchema = field(default_factory=lambda: {'type': 'object', 'properties': {}}) """The JSON schema for the tool's parameters.""" description: str | None = None @@ -464,4 +362,13 @@ class ToolDefinition: Note: this is currently only supported by OpenAI models. """ + kind: ToolKind = field(default='function') + """The kind of tool: + + - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model + - `'output'`: a tool that passes through an output value that ends the run + - `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). + When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call. + """ + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py new file mode 100644 index 0000000000..f3d3d362dc --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -0,0 +1,22 @@ +from .abstract import AbstractToolset, ToolsetTool +from .combined import CombinedToolset +from .deferred import DeferredToolset +from .filtered import FilteredToolset +from .function import FunctionToolset +from .prefixed import PrefixedToolset +from .prepared import PreparedToolset +from .renamed import RenamedToolset +from .wrapper import WrapperToolset + +__all__ = ( + 'AbstractToolset', + 'ToolsetTool', + 'CombinedToolset', + 'DeferredToolset', + 'FilteredToolset', + 'FunctionToolset', + 'PrefixedToolset', + 'RenamedToolset', + 'PreparedToolset', + 'WrapperToolset', +) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py new file mode 100644 index 0000000000..0f19eec3bc --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, TypeVar + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition, ToolsPrepareFunc + +if TYPE_CHECKING: + from .filtered import FilteredToolset + from .prefixed import PrefixedToolset + from .prepared import PreparedToolset + from .renamed import RenamedToolset + from .wrapper import WrapperToolset + +WrapperT = TypeVar('WrapperT', bound='WrapperToolset[Any]') + + +class SchemaValidatorProt(Protocol): + """Protocol for a Pydantic Core `SchemaValidator` or `PluggableSchemaValidator` (which is private but API-compatible).""" + + def validate_json( + self, + input: str | bytes | bytearray, + *, + allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, + **kwargs: Any, + ) -> Any: ... + + def validate_python( + self, input: Any, *, allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, **kwargs: Any + ) -> Any: ... + + +@dataclass +class ToolsetTool(Generic[AgentDepsT]): + """Definition of a tool available on a toolset. + + This is a wrapper around a plain tool definition that includes information about: + + - the toolset that provided it, for use in error messages + - the maximum number of retries to attempt if the tool call fails + - the validator for the tool's arguments + """ + + toolset: AbstractToolset[AgentDepsT] + """The toolset that provided this tool, for use in error messages.""" + tool_def: ToolDefinition + """The tool definition for this tool, including the name, description, and parameters.""" + max_retries: int + """The maximum number of retries to attempt if the tool call fails.""" + args_validator: SchemaValidator | SchemaValidatorProt + """The Pydantic Core validator for the tool's arguments. + + For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator). + """ + + +class AbstractToolset(ABC, Generic[AgentDepsT]): + """A toolset is a collection of tools that can be used by an agent. + + It is responsible for: + + - Listing the tools it contains + - Validating the arguments of the tools + - Calling the tools + + See [toolset docs](../toolsets.md) for more information. + """ + + @property + def name(self) -> str: + """The name of the toolset for use in error messages.""" + return self.__class__.__name__.replace('Toolset', ' toolset') + + @property + def tool_name_conflict_hint(self) -> str: + """A hint for how to avoid name conflicts with other toolsets for use in error messages.""" + return 'Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.' + + async def __aenter__(self) -> Self: + """Enter the toolset context. + + This is where you can set up network connections in a concrete implementation. + """ + return self + + async def __aexit__(self, *args: Any) -> bool | None: + """Exit the toolset context. + + This is where you can tear down network connections in a concrete implementation. + """ + return None + + @abstractmethod + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + """The tools that are available in this toolset.""" + raise NotImplementedError() + + @abstractmethod + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + """Call a tool with the given arguments. + + Args: + name: The name of the tool to call. + tool_args: The arguments to pass to the tool. + ctx: The run context. + tool: The tool definition returned by [`get_tools`][pydantic_ai.toolsets.AbstractToolset.get_tools] that was called. + """ + raise NotImplementedError() + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + """Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling).""" + return visitor(self) + + def filtered( + self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] + ) -> FilteredToolset[AgentDepsT]: + """Returns a new toolset that filters this toolset's tools using a filter function that takes the agent context and the tool definition. + + See [toolset docs](../toolsets.md#filtering-tools) for more information. + """ + from .filtered import FilteredToolset + + return FilteredToolset(self, filter_func) + + def prefixed(self, prefix: str) -> PrefixedToolset[AgentDepsT]: + """Returns a new toolset that prefixes the names of this toolset's tools. + + See [toolset docs](../toolsets.md#prefixing-tool-names) for more information. + """ + from .prefixed import PrefixedToolset + + return PrefixedToolset(self, prefix) + + def prepared(self, prepare_func: ToolsPrepareFunc[AgentDepsT]) -> PreparedToolset[AgentDepsT]: + """Returns a new toolset that prepares this toolset's tools using a prepare function that takes the agent context and the original tool definitions. + + See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information. + """ + from .prepared import PreparedToolset + + return PreparedToolset(self, prepare_func) + + def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]: + """Returns a new toolset that renames this toolset's tools using a dictionary mapping new names to original names. + + See [toolset docs](../toolsets.md#renaming-tools) for more information. + """ + from .renamed import RenamedToolset + + return RenamedToolset(self, name_map) + + def wrap(self, wrapper_cls: type[WrapperT], *args: Any, **kwargs: Any) -> WrapperT: + """Returns an instance of the provided wrapper class wrapping this toolset, with all arguments passed to the wrapper class constructor. + + See [toolset docs](../toolsets.md#wrapping-a-toolset) for more information. + """ + return wrapper_cls(self, *args, **kwargs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py new file mode 100644 index 0000000000..a083477196 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from contextlib import AsyncExitStack +from dataclasses import dataclass, field +from typing import Any, Callable + +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from .._utils import get_async_lock +from ..exceptions import UserError +from .abstract import AbstractToolset, ToolsetTool + + +@dataclass +class _CombinedToolsetTool(ToolsetTool[AgentDepsT]): + """A tool definition for a combined toolset tools that keeps track of the source toolset and tool.""" + + source_toolset: AbstractToolset[AgentDepsT] + source_tool: ToolsetTool[AgentDepsT] + + +@dataclass +class CombinedToolset(AbstractToolset[AgentDepsT]): + """A toolset that combines multiple toolsets. + + See [toolset docs](../toolsets.md#combining-toolsets) for more information. + """ + + toolsets: Sequence[AbstractToolset[AgentDepsT]] + + _enter_lock: asyncio.Lock = field(compare=False, init=False) + _entered_count: int = field(init=False) + _exit_stack: AsyncExitStack | None = field(init=False) + + def __post_init__(self): + self._enter_lock = get_async_lock() + self._entered_count = 0 + self._exit_stack = None + + async def __aenter__(self) -> Self: + async with self._enter_lock: + if self._entered_count == 0: + self._exit_stack = AsyncExitStack() + for toolset in self.toolsets: + await self._exit_stack.enter_async_context(toolset) + self._entered_count += 1 + return self + + async def __aexit__(self, *args: Any) -> bool | None: + async with self._enter_lock: + self._entered_count -= 1 + if self._entered_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets)) + all_tools: dict[str, ToolsetTool[AgentDepsT]] = {} + + for toolset, tools in zip(self.toolsets, toolsets_tools): + for name, tool in tools.items(): + if existing_tools := all_tools.get(name): + raise UserError( + f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + ) + + all_tools[name] = _CombinedToolsetTool( + toolset=tool.toolset, + tool_def=tool.tool_def, + max_retries=tool.max_retries, + args_validator=tool.args_validator, + source_toolset=toolset, + source_tool=tool, + ) + return all_tools + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + assert isinstance(tool, _CombinedToolsetTool) + return await tool.source_toolset.call_tool(name, tool_args, ctx, tool.source_tool) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + for toolset in self.toolsets: + toolset.apply(visitor) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py new file mode 100644 index 0000000000..29964e9333 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from pydantic_core import SchemaValidator, core_schema + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from .abstract import AbstractToolset, ToolsetTool + +TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema()) + + +@dataclass +class DeferredToolset(AbstractToolset[AgentDepsT]): + """A toolset that holds deferred tools that will be called by the upstream service that called the agent. + + See [toolset docs](../toolsets.md#deferred-toolset), [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind], and [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] for more information. + """ + + tool_defs: list[ToolDefinition] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + tool_def.name: ToolsetTool( + toolset=self, + tool_def=replace(tool_def, kind='deferred'), + max_retries=0, + args_validator=TOOL_SCHEMA_VALIDATOR, + ) + for tool_def in self.tool_defs + } + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + raise NotImplementedError('Deferred tools cannot be called') diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py new file mode 100644 index 0000000000..3ff98c8ec5 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class FilteredToolset(WrapperToolset[AgentDepsT]): + """A toolset that filters the tools it contains using a filter function that takes the agent context and the tool definition. + + See [toolset docs](../toolsets.md#filtering-tools) for more information. + """ + + filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + name: tool for name, tool in (await super().get_tools(ctx)).items() if self.filter_func(ctx, tool.tool_def) + } diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py new file mode 100644 index 0000000000..63f44a1f0c --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass, field, replace +from typing import Any, Callable, overload + +from pydantic.json_schema import GenerateJsonSchema + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ( + DocstringFormat, + GenerateToolJsonSchema, + Tool, + ToolFuncEither, + ToolParams, + ToolPrepareFunc, +) +from .abstract import AbstractToolset, ToolsetTool + + +@dataclass +class _FunctionToolsetTool(ToolsetTool[AgentDepsT]): + """A tool definition for a function toolset tool that keeps track of the function to call.""" + + call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]] + + +@dataclass(init=False) +class FunctionToolset(AbstractToolset[AgentDepsT]): + """A toolset that lets Python functions be used as tools. + + See [toolset docs](../toolsets.md#function-toolset) for more information. + """ + + max_retries: int = field(default=1) + tools: dict[str, Tool[Any]] = field(default_factory=dict) + + def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + """Build a new function toolset. + + Args: + tools: The tools to add to the toolset. + max_retries: The maximum number of retries for each tool during a run. + """ + self.max_retries = max_retries + self.tools = {} + for tool in tools: + if isinstance(tool, Tool): + self.add_tool(tool) + else: + self.add_function(tool) + + @overload + def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... + + @overload + def tool( + self, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ... + + def tool( + self, + func: ToolFuncEither[AgentDepsT, ToolParams] | None = None, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Any: + """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. + + Can decorate a sync or async functions. + + The docstring is inspected to extract both the tool description and description of each parameter, + [learn more](../tools.md#function-tools-and-schema). + + We can't add overloads for every possible signature of tool, since the return type is a recursive union + so the signature of functions decorated with `@toolset.tool` is obscured. + + Example: + ```python + from pydantic_ai import Agent, RunContext + from pydantic_ai.toolsets.function import FunctionToolset + + toolset = FunctionToolset() + + @toolset.tool + def foobar(ctx: RunContext[int], x: int) -> int: + return ctx.deps + x + + @toolset.tool(retries=2) + async def spam(ctx: RunContext[str], y: float) -> float: + return ctx.deps + y + + agent = Agent('test', toolsets=[toolset], deps_type=int) + result = agent.run_sync('foobar', deps=1) + print(result.output) + #> {"foobar":1,"spam":1.0} + ``` + + Args: + func: The tool function to register. + name: The name of the tool, defaults to the function name. + retries: The number of retries to allow for this tool, defaults to the agent's default retries, + which defaults to 1. + prepare: custom method to prepare the tool definition for each step, return `None` to omit this + tool from a given step. This is useful if you want to customise a tool at call time, + or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. + docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. + Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. + require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. + schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. + strict: Whether to enforce JSON schema compliance (only affects OpenAI). + See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + """ + + def tool_decorator( + func_: ToolFuncEither[AgentDepsT, ToolParams], + ) -> ToolFuncEither[AgentDepsT, ToolParams]: + # noinspection PyTypeChecker + self.add_function( + func_, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func_ + + return tool_decorator if func is None else tool_decorator(func) + + def add_function( + self, + func: ToolFuncEither[AgentDepsT, ToolParams], + takes_ctx: bool | None = None, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> None: + """Add a function as a tool to the toolset. + + Can take a sync or async function. + + The docstring is inspected to extract both the tool description and description of each parameter, + [learn more](../tools.md#function-tools-and-schema). + + Args: + func: The tool function to register. + takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. If `None`, this is inferred from the function signature. + name: The name of the tool, defaults to the function name. + retries: The number of retries to allow for this tool, defaults to the agent's default retries, + which defaults to 1. + prepare: custom method to prepare the tool definition for each step, return `None` to omit this + tool from a given step. This is useful if you want to customise a tool at call time, + or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. + docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. + Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. + require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. + schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. + strict: Whether to enforce JSON schema compliance (only affects OpenAI). + See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + """ + tool = Tool[AgentDepsT]( + func, + takes_ctx=takes_ctx, + name=name, + max_retries=retries, + prepare=prepare, + docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions, + schema_generator=schema_generator, + strict=strict, + ) + self.add_tool(tool) + + def add_tool(self, tool: Tool[AgentDepsT]) -> None: + """Add a tool to the toolset. + + Args: + tool: The tool to add. + """ + if tool.name in self.tools: + raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}') + if tool.max_retries is None: + tool.max_retries = self.max_retries + self.tools[tool.name] = tool + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + tools: dict[str, ToolsetTool[AgentDepsT]] = {} + for original_name, tool in self.tools.items(): + run_context = replace(ctx, tool_name=original_name, retry=ctx.retries.get(original_name, 0)) + tool_def = await tool.prepare_tool_def(run_context) + if not tool_def: + continue + + new_name = tool_def.name + if new_name in tools: + if new_name != original_name: + raise UserError(f'Renaming tool {original_name!r} to {new_name!r} conflicts with existing tool.') + else: + raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') + + tools[new_name] = _FunctionToolsetTool( + toolset=self, + tool_def=tool_def, + max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries, + args_validator=tool.function_schema.validator, + call_func=tool.function_schema.call, + ) + return tools + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + assert isinstance(tool, _FunctionToolsetTool) + return await tool.call_func(tool_args, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py new file mode 100644 index 0000000000..be70ed4f0f --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from .._run_context import AgentDepsT, RunContext +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class PrefixedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prefixes the names of the tools it contains. + + See [toolset docs](../toolsets.md#prefixing-tool-names) for more information. + """ + + prefix: str + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + new_name: replace( + tool, + toolset=self, + tool_def=replace(tool.tool_def, name=new_name), + ) + for name, tool in (await super().get_tools(ctx)).items() + if (new_name := f'{self.prefix}_{name}') + } + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + original_name = name.removeprefix(self.prefix + '_') + ctx = replace(ctx, tool_name=original_name) + tool = replace(tool, tool_def=replace(tool.tool_def, name=original_name)) + return await super().call_tool(original_name, tool_args, ctx, tool) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py new file mode 100644 index 0000000000..af604d4328 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ToolsPrepareFunc +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class PreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a prepare function that takes the agent context and the original tool definitions. + + See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information. + """ + + prepare_func: ToolsPrepareFunc[AgentDepsT] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + original_tools = await super().get_tools(ctx) + original_tool_defs = [tool.tool_def for tool in original_tools.values()] + prepared_tool_defs_by_name = { + tool_def.name: tool_def for tool_def in (await self.prepare_func(ctx, original_tool_defs) or []) + } + + if len(prepared_tool_defs_by_name.keys() - original_tools.keys()) > 0: + raise UserError( + 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' + ) + + return { + name: replace(original_tools[name], tool_def=tool_def) + for name, tool_def in prepared_tool_defs_by_name.items() + } diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/renamed.py b/pydantic_ai_slim/pydantic_ai/toolsets/renamed.py new file mode 100644 index 0000000000..c0d8aff7a0 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/renamed.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from .._run_context import AgentDepsT, RunContext +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class RenamedToolset(WrapperToolset[AgentDepsT]): + """A toolset that renames the tools it contains using a dictionary mapping new names to original names. + + See [toolset docs](../toolsets.md#renaming-tools) for more information. + """ + + name_map: dict[str, str] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + original_to_new_name_map = {v: k for k, v in self.name_map.items()} + original_tools = await super().get_tools(ctx) + tools: dict[str, ToolsetTool[AgentDepsT]] = {} + for original_name, tool in original_tools.items(): + new_name = original_to_new_name_map.get(original_name, None) + if new_name: + tools[new_name] = replace( + tool, + toolset=self, + tool_def=replace(tool.tool_def, name=new_name), + ) + else: + tools[original_name] = tool + return tools + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + original_name = self.name_map.get(name, name) + ctx = replace(ctx, tool_name=original_name) + tool = replace(tool, tool_def=replace(tool.tool_def, name=original_name)) + return await super().call_tool(original_name, tool_args, ctx, tool) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py new file mode 100644 index 0000000000..1dddd96a51 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from .abstract import AbstractToolset, ToolsetTool + + +@dataclass +class WrapperToolset(AbstractToolset[AgentDepsT]): + """A toolset that wraps another toolset and delegates to it. + + See [toolset docs](../toolsets.md#wrapping-a-toolset) for more information. + """ + + wrapped: AbstractToolset[AgentDepsT] + + async def __aenter__(self) -> Self: + await self.wrapped.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool | None: + return await self.wrapped.__aexit__(*args) + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return await self.wrapped.get_tools(ctx) + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + return await self.wrapped.call_tool(name, tool_args, ctx, tool) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + return self.wrapped.apply(visitor) diff --git a/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml new file mode 100644 index 0000000000..e33e36f96e --- /dev/null +++ b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml @@ -0,0 +1,391 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2501' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1086' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '420' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + created: 1751491994 + id: chatcmpl-BozMoBhgfC5D8QBjkiOwz5OxxrwQK + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 18 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 268 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 286 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2748' + content-type: + - application/json + cookie: + - __cf_bm=JOV7WG2Y48FZrZxdh0IZvA9mCj_ljIN3DhGMuC1pw6M-1751491995-1.0.1.1-zGPrLbzYx7y3iZT28xogbHO1KAIej60kPEwQ8ZxGMxv1r.ICtqI0T8WCnlyUccKfLSXB6ZTNQT05xCma8LSvq2pk4X2eEuSkYC1sPqbuLU8; + _cfuvid=LdoyX0uKYwM98NSSSvySlZAiJHCVHz_1krUGKbWmNHg-1751491995391-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + - content: '32.0' + role: tool + tool_call_id: call_hS0oexgCNI6TneJuPPuwn9jQ + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '849' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '520' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: 0 degrees Celsius is 32.0 degrees Fahrenheit. + refusal: null + role: assistant + created: 1751491998 + id: chatcmpl-BozMsevK8quJblNOyNCaDQpdtDwI5 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 300 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 312 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/ext/test_langchain.py b/tests/ext/test_langchain.py index 73e7cc0504..926a228194 100644 --- a/tests/ext/test_langchain.py +++ b/tests/ext/test_langchain.py @@ -6,7 +6,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai import Agent -from pydantic_ai.ext.langchain import tool_from_langchain +from pydantic_ai.ext.langchain import LangChainToolset, tool_from_langchain @dataclass @@ -49,24 +49,26 @@ def get_input_jsonschema(self) -> JsonSchemaValue: } -def test_langchain_tool_conversion(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, +langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', }, - ) + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, +) + + +def test_langchain_tool_conversion(): pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) @@ -74,6 +76,13 @@ def test_langchain_tool_conversion(): assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") +def test_langchain_toolset(): + toolset = LangChainToolset([langchain_tool]) + agent = Agent('test', toolsets=[toolset], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") + + def test_langchain_tool_no_additional_properties(): langchain_tool = SimulatedLangChainTool( name='file_search', diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3891c5108c..77857e8821 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1700,7 +1700,7 @@ class CityLocation(BaseModel): agent = Agent(m, output_type=NativeOutput(CityLocation)) - with pytest.raises(UserError, match='Structured output is not supported by the model.'): + with pytest.raises(UserError, match='Native structured output is not supported by the model.'): await agent.run('What is the largest city in the user country?') diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a84c49d869..0b95bedeb5 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -974,12 +974,47 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): @agent.tool_plain() def get_location(loc_name: str) -> str: - return f'Location for {loc_name}' + return f'Location for {loc_name}' # pragma: no cover async with agent.run_stream('Hello') as result: data = await result.get_output() assert data == 'Hello foo' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content='Hello foo'), + ToolCallPart( + tool_name='get_location', + args={'loc_name': 'San Fransisco'}, + tool_call_id=IsStr(), + ), + ], + usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + model_name='gemini-1.5-flash', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='Tool not executed - a final result was already processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) async def test_empty_text_ignored(): diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 31635c080d..02aafd259f 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -4,6 +4,7 @@ import asyncio import dataclasses +import re from datetime import timezone from typing import Annotated, Any, Literal @@ -157,7 +158,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel: call_count += 1 raise ModelRetry('Fail') - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match=re.escape('Exceeded maximum retries (2) for output validation')): agent.run_sync('Hello', model=TestModel()) assert call_count == 3 @@ -200,7 +201,7 @@ class ResultModel(BaseModel): agent = Agent('test', output_type=ResultModel, retries=2) - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for output validation'): agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1})) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index c3d56728f4..535e3b1e91 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -800,8 +800,15 @@ async def test_a2a_multiple_messages(): } ) - await anyio.sleep(0.1) - task = await a2a_client.get_task(task_id) + task = None + tries = 0 + while tries < 10: # pragma: no branch + await anyio.sleep(0.1) + task = await a2a_client.get_task(task_id) + tries += 1 + if 'result' in task and task['result']['status']['state'] == 'completed': # pragma: no branch + break + assert task == snapshot( { 'jsonrpc': '2.0', diff --git a/tests/test_agent.py b/tests/test_agent.py index 47a9c01a12..c1893fdcb4 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,6 +1,7 @@ import json import re import sys +from dataclasses import dataclass from datetime import timezone from typing import Any, Callable, Union @@ -45,6 +46,9 @@ from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.combined import CombinedToolset +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -396,6 +400,7 @@ def test_response_tuple(): 'type': 'object', }, outer_typed_dict_key='response', + kind='output', ) ] ) @@ -469,6 +474,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ) ] ) @@ -548,6 +554,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Bar', @@ -558,6 +565,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Bar', 'type': 'object', }, + kind='output', ), ] ) @@ -589,6 +597,7 @@ class MyOutput(BaseModel): 'title': 'MyOutput', 'type': 'object', }, + kind='output', ) ] ) @@ -635,6 +644,7 @@ class Bar(BaseModel): }, outer_typed_dict_key='response', strict=False, + kind='output', ) ] ) @@ -673,6 +683,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -712,6 +723,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -752,6 +764,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -793,6 +806,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -943,7 +957,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], ) def test_output_type_multiple_text_output(output_type: OutputSpec[str]): - with pytest.raises(UserError, match='Only one text output is allowed.'): + with pytest.raises(UserError, match='Only one `str` or `TextOutput` is allowed.'): Agent('test', output_type=output_type) @@ -989,6 +1003,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1027,6 +1042,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1065,6 +1081,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Weather', @@ -1075,6 +1092,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -1251,6 +1269,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='return_weather', @@ -1261,6 +1280,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -1322,6 +1342,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'type': 'object', }, description='A person', + kind='output', ), ToolDefinition( name='final_result_Animal', @@ -1332,6 +1353,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'type': 'object', }, description='An animal', + kind='output', ), ] ) @@ -1998,7 +2020,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent(FunctionModel(empty)) with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): agent.run_sync('Hello') assert messages == snapshot( [ @@ -2350,12 +2372,6 @@ def another_tool(y: int) -> int: tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), - RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, @@ -2365,6 +2381,12 @@ def another_tool(y: int) -> int: ToolReturnPart( tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ), ] ), ] @@ -2428,16 +2450,16 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='another_tool', @@ -2447,7 +2469,7 @@ def another_tool(y: int) -> int: # pragma: no cover ), RetryPromptPart( tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), @@ -2494,11 +2516,13 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: # Verify we got appropriate tool returns assert result.new_messages()[-1].parts == snapshot( [ - ToolReturnPart( + RetryPromptPart( + content=[ + {'type': 'missing', 'loc': ('value',), 'msg': 'Field required', 'input': {'bad_value': 'first'}} + ], tool_name='final_result', tool_call_id='first', - content='Output tool not used - result failed validation.', - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='final_result', @@ -3247,7 +3271,7 @@ def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent(model, output_type=NativeOutput(Foo)) - with pytest.raises(UserError, match='Structured output is not supported by the model.'): + with pytest.raises(UserError, match='Native structured output is not supported by the model.'): agent.run_sync('Hello') agent = Agent(model, output_type=ToolOutput(Foo)) @@ -3435,7 +3459,7 @@ def analyze_data() -> list[Any]: with pytest.raises( UserError, - match="analyze_data's return contains invalid nested ToolReturn objects. ToolReturn should be used directly.", + match="The return value of tool 'analyze_data' contains invalid nested `ToolReturn` objects. `ToolReturn` should be used directly.", ): agent.run_sync('Please analyze the data') @@ -3469,7 +3493,7 @@ def analyze_data() -> ToolReturn: with pytest.raises( UserError, - match="analyze_data's `return_value` contains invalid nested MultiModalContentTypes objects. Please use `content` instead.", + match="The `return_value` of tool 'analyze_data' contains invalid nested `MultiModalContentTypes` objects. Please use `content` instead.", ): agent.run_sync('Please analyze the data') @@ -3534,6 +3558,19 @@ def test_deprecated_kwargs_still_work(): assert issubclass(w[0].category, DeprecationWarning) assert '`result_retries` is deprecated' in str(w[0].message) + try: + from pydantic_ai.mcp import MCPServerStdio + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + agent = Agent('test', mcp_servers=[MCPServerStdio('python', ['-m', 'tests.mcp_server'])]) # type: ignore[call-arg] + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert '`mcp_servers` is deprecated' in str(w[0].message) + except ImportError: + pass + def test_deprecated_kwargs_mixed_valid_invalid(): """Test that mix of valid deprecated and invalid kwargs raises error for invalid ones.""" @@ -3548,3 +3585,272 @@ def test_deprecated_kwargs_mixed_valid_invalid(): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) # Ignore the deprecation warning for result_tool_name Agent('test', result_tool_name='test', foo='value1', bar='value2') # type: ignore[call-arg] + + +def test_override_toolsets(): + foo_toolset = FunctionToolset() + + @foo_toolset.tool + def foo() -> str: + return 'Hello from foo' + + available_tools: list[list[str]] = [] + + async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + nonlocal available_tools + available_tools.append([tool_def.name for tool_def in tool_defs]) + return tool_defs + + agent = Agent('test', toolsets=[foo_toolset], prepare_tools=prepare_tools) + + @agent.tool_plain + def baz() -> str: + return 'Hello from baz' + + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz', 'foo']) + assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo"}') + + bar_toolset = FunctionToolset() + + @bar_toolset.tool + def bar() -> str: + return 'Hello from bar' + + with agent.override(toolsets=[bar_toolset]): + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz', 'bar']) + assert result.output == snapshot('{"baz":"Hello from baz","bar":"Hello from bar"}') + + with agent.override(toolsets=[]): + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz']) + assert result.output == snapshot('{"baz":"Hello from baz"}') + + result = agent.run_sync('Hello', toolsets=[bar_toolset]) + assert available_tools[-1] == snapshot(['baz', 'foo', 'bar']) + assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo","bar":"Hello from bar"}') + + with agent.override(toolsets=[]): + result = agent.run_sync('Hello', toolsets=[bar_toolset]) + assert available_tools[-1] == snapshot(['baz']) + assert result.output == snapshot('{"baz":"Hello from baz"}') + + +def test_adding_tools_during_run(): + toolset = FunctionToolset() + + def foo() -> str: + return 'Hello from foo' + + @toolset.tool + def add_foo_tool() -> str: + toolset.add_function(foo) + return 'foo tool added' + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart('add_foo_tool')]) + elif len(messages) == 3: + return ModelResponse(parts=[ToolCallPart('foo')]) + else: + return ModelResponse(parts=[TextPart('Done')]) + + agent = Agent(FunctionModel(respond), toolsets=[toolset]) + result = agent.run_sync('Add the foo tool and run it') + assert result.output == snapshot('Done') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Add the foo tool and run it', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='add_foo_tool', tool_call_id=IsStr())], + usage=Usage(requests=1, request_tokens=57, response_tokens=2, total_tokens=59), + model_name='function:respond:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='add_foo_tool', + content='foo tool added', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='foo', tool_call_id=IsStr())], + usage=Usage(requests=1, request_tokens=60, response_tokens=4, total_tokens=64), + model_name='function:respond:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='foo', + content='Hello from foo', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Done')], + usage=Usage(requests=1, request_tokens=63, response_tokens=5, total_tokens=68), + model_name='function:respond:', + timestamp=IsDatetime(), + ), + ] + ) + + +def test_prepare_output_tools(): + @dataclass + class AgentDeps: + plan_presented: bool = False + + async def present_plan(ctx: RunContext[AgentDeps], plan: str) -> str: + """ + Present the plan to the user. + """ + ctx.deps.plan_presented = True + return plan + + async def run_sql(ctx: RunContext[AgentDeps], purpose: str, query: str) -> str: + """ + Run an SQL query. + """ + return 'SQL query executed successfully' + + async def only_if_plan_presented( + ctx: RunContext[AgentDeps], tool_defs: list[ToolDefinition] + ) -> list[ToolDefinition]: + return tool_defs if ctx.deps.plan_presented else [] + + agent = Agent( + model='test', + deps_type=AgentDeps, + tools=[present_plan], + output_type=[ToolOutput(run_sql, name='run_sql')], + prepare_output_tools=only_if_plan_presented, + ) + + result = agent.run_sync('Hello', deps=AgentDeps()) + assert result.output == snapshot('SQL query executed successfully') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='present_plan', + args={'plan': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='present_plan', + content='a', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='run_sql', + args={'purpose': 'a', 'query': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=52, response_tokens=12, total_tokens=64), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='run_sql', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +async def test_context_manager(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + agent = Agent('test', toolsets=[toolset]) + + async with agent: + assert server1.is_running + assert server2.is_running + + async with agent: + assert server1.is_running + assert server2.is_running + + +def test_set_mcp_sampling_model(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + test_model = TestModel() + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + agent = Agent(None, toolsets=[toolset]) + + with pytest.raises(UserError, match='No sampling model provided and no model set on the agent.'): + agent.set_mcp_sampling_model() + assert server1.sampling_model is None + assert server2.sampling_model is test_model + + agent.model = test_model + agent.set_mcp_sampling_model() + assert server1.sampling_model is test_model + assert server2.sampling_model is test_model + + function_model = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Hello')])) + with agent.override(model=function_model): + agent.set_mcp_sampling_model() + assert server1.sampling_model is function_model + assert server2.sampling_model is function_model + + function_model2 = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Goodbye')])) + agent.set_mcp_sampling_model(function_model2) + assert server1.sampling_model is function_model2 + assert server2.sampling_model is function_model2 diff --git a/tests/test_examples.py b/tests/test_examples.py index 5735a11ffa..0fbe64bc3a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,6 +21,7 @@ from rich.console import Console from pydantic_ai import ModelHTTPError +from pydantic_ai._run_context import RunContext from pydantic_ai._utils import group_by_temporal from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( @@ -36,6 +37,8 @@ from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import AbstractToolset +from pydantic_ai.toolsets.abstract import ToolsetTool from .conftest import ClientWithHandler, TestEnv, try_import @@ -259,18 +262,20 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: raise ValueError(f'Unexpected prompt: {prompt}') -class MockMCPServer: - is_running = True - +class MockMCPServer(AbstractToolset[Any]): async def __aenter__(self) -> MockMCPServer: return self async def __aexit__(self, *args: Any) -> None: pass - @staticmethod - async def list_tools() -> list[None]: - return [] + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + return {} + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[Any], tool: ToolsetTool[Any] + ) -> Any: + return None # pragma: lax no cover text_responses: dict[str, str | ToolCallPart] = { @@ -553,6 +558,21 @@ async def model_logic( # noqa: C901 ) ] ) + elif m.content == 'Greet the user in a personalized way': + if any(t.name == 'get_preferred_language' for t in info.function_tools): + part = ToolCallPart( + tool_name='get_preferred_language', + args={'default_language': 'en-US'}, + tool_call_id='pyd_ai_tool_call_id', + ) + else: + part = ToolCallPart( + tool_name='final_result', + args={'greeting': 'Hello, David!', 'language_code': 'en-US'}, + tool_call_id='pyd_ai_tool_call_id', + ) + + return ModelResponse(parts=[part]) elif response := text_responses.get(m.content): if isinstance(response, str): return ModelResponse(parts=[TextPart(response)]) @@ -697,6 +717,16 @@ async def model_logic( # noqa: C901 ) elif isinstance(m, ToolReturnPart) and m.tool_name == 'image_generator': return ModelResponse(parts=[TextPart('Image file written to robot_punk.svg.')]) + elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_preferred_language': + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'greeting': 'Hola, David! Espero que tengas un gran día!', 'language_code': 'es-MX'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 97ba871cc9..799724179a 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -11,6 +11,7 @@ from pydantic_ai import Agent from pydantic_ai._utils import get_traceparent +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel @@ -294,6 +295,7 @@ async def my_ret(x: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ], 'output_mode': 'text', @@ -641,10 +643,11 @@ async def test_feedback(capfire: CaptureLogfire) -> None: @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') -@pytest.mark.parametrize('include_content', [True, False]) +@pytest.mark.parametrize('include_content,tool_error', [(True, False), (True, True), (False, False), (False, True)]) def test_include_tool_args_span_attributes( get_logfire_summary: Callable[[], LogfireSummary], include_content: bool, + tool_error: bool, ) -> None: """Test that tool arguments are included/excluded in span attributes based on instrumentation settings.""" @@ -655,61 +658,119 @@ def test_include_tool_args_span_attributes( @my_agent.tool_plain async def add_numbers(x: int, y: int) -> int: """Add two numbers together.""" + if tool_error: + raise ModelRetry('Tool error') return x + y - result = my_agent.run_sync('Add 42 and 42') - assert result.output == snapshot('{"add_numbers":84}') + try: + result = my_agent.run_sync('Add 42 and 42') + assert result.output == snapshot('{"add_numbers":84}') + except UnexpectedModelBehavior: + if not tool_error: + raise # pragma: no cover summary = get_logfire_summary() - [tool_attributes] = [ + tool_attributes = next( attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'add_numbers' - ] + ) if include_content: - assert tool_attributes == snapshot( - { - 'gen_ai.tool.name': 'add_numbers', - 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"x":42,"y":42}', - 'tool_response': '84', - 'logfire.msg': 'running tool: add_numbers', - 'logfire.json_schema': IsJson( - snapshot( - { - 'type': 'object', - 'properties': { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ) - ), - 'logfire.span_type': 'span', - } - ) + if tool_error: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"x":42,"y":42}', + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': """\ +Tool error + +Fix the errors and try again.\ +""", + 'logfire.level_num': 17, + } + ) + else: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"x":42,"y":42}', + 'tool_response': '84', + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + } + ) else: - assert tool_attributes == snapshot( - { - 'gen_ai.tool.name': 'add_numbers', - 'gen_ai.tool.call.id': IsStr(), - 'logfire.msg': 'running tool: add_numbers', - 'logfire.json_schema': IsJson( - snapshot( - { - 'type': 'object', - 'properties': { - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ) - ), - 'logfire.span_type': 'span', - } - ) + if tool_error: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + } + ) + else: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + } + ) class WeatherInfo(BaseModel): @@ -750,7 +811,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -811,7 +872,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -881,7 +942,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'logfire.msg': 'running output function: final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "New York City"}', + 'tool_arguments': '{"city":"New York City"}', 'logfire.json_schema': IsJson( snapshot( { @@ -900,7 +961,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'logfire.msg': 'running output function: final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.json_schema': IsJson( snapshot( { @@ -968,7 +1029,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'get_weather', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: get_weather', 'logfire.json_schema': IsJson( snapshot( @@ -1034,7 +1095,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -1101,7 +1162,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -1163,7 +1224,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -1299,7 +1360,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert output_function_attributes == snapshot( { 'gen_ai.tool.name': 'upcase_text', - 'tool_arguments': '{"text": "hello world"}', + 'tool_arguments': '{"text":"hello world"}', 'logfire.msg': 'running output function: upcase_text', 'logfire.json_schema': IsJson( snapshot( diff --git a/tests/test_mcp.py b/tests/test_mcp.py index fe092d9dd7..e2c5bd0989 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,5 +1,7 @@ """Tests for the MCP (Model Context Protocol) server implementation.""" +from __future__ import annotations + import base64 import re from datetime import timezone @@ -23,6 +25,7 @@ ToolReturnPart, UserPromptPart, ) +from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext from pydantic_ai.usage import Usage @@ -48,23 +51,36 @@ @pytest.fixture -def agent(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - return Agent(model, mcp_servers=[server]) +def mcp_server() -> MCPServerStdio: + return MCPServerStdio('python', ['-m', 'tests.mcp_server']) + + +@pytest.fixture +def model(openai_api_key: str) -> Model: + return OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + +@pytest.fixture +def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: + return Agent(model, toolsets=[mcp_server]) -async def test_stdio_server(): + +@pytest.fixture +def run_context(model: Model) -> RunContext[int]: + return RunContext(deps=0, model=model, usage=Usage()) + + +async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - tools = await server.list_tools() + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] assert len(tools) == snapshot(13) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') # Test calling the temperature conversion tool - result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 0}) + result = await server.direct_call_tool('celsius_to_fahrenheit', {'celsius': 0}) assert result == snapshot('32.0') @@ -75,38 +91,43 @@ async def test_reentrant_context_manager(): pass -async def test_stdio_server_with_tool_prefix(): +async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: - tools = await server.list_tools() - assert all(tool.name.startswith('foo_') for tool in tools) + tools = await server.get_tools(run_context) + assert all(name.startswith('foo_') for name in tools.keys()) + + result = await server.call_tool( + 'foo_celsius_to_fahrenheit', {'celsius': 0}, run_context, tools['foo_celsius_to_fahrenheit'] + ) + assert result == snapshot('32.0') -async def test_stdio_server_with_cwd(): +async def test_stdio_server_with_cwd(run_context: RunContext[int]): test_dir = Path(__file__).parent server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: - tools = await server.list_tools() + tools = await server.get_tools(run_context) assert len(tools) == snapshot(13) -async def test_process_tool_call() -> None: +async def test_process_tool_call(run_context: RunContext[int]) -> int: called: bool = False async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, - tool_name: str, - args: dict[str, Any], + name: str, + tool_args: dict[str, Any], ) -> ToolResult: """A process_tool_call that sets a flag and sends deps as metadata.""" nonlocal called called = True - return await call_tool(tool_name, args, {'deps': ctx.deps}) + return await call_tool(name, tool_args, {'deps': ctx.deps}) server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) async with server: - agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), mcp_servers=[server]) + agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), toolsets=[server]) result = await agent.run('Echo with deps set to 42', deps=42) assert result.output == snapshot('{"echo_deps":{"echo":"This is an echo message","deps":42}}') assert called, 'process_tool_call should have been called' @@ -135,7 +156,7 @@ def test_sse_server_with_header_and_timeout(): @pytest.mark.vcr() async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') assert result.output == snapshot('0 degrees Celsius is equal to 32 degrees Fahrenheit.') assert result.all_messages() == snapshot( @@ -212,11 +233,11 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent: with pytest.raises( UserError, match=re.escape( - "MCP Server 'MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None)' defines a tool whose name conflicts with existing tool: 'get_none'. Consider using `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from Function toolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." ), ): await agent.run('Get me a conflict') @@ -227,7 +248,7 @@ async def test_agent_with_prefix_tool_name(openai_api_key: str): model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) agent = Agent( model, - mcp_servers=[server], + toolsets=[server], ) @agent.tool_plain @@ -235,43 +256,41 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent: # This means that we passed the _prepare_request_parameters check and there is no conflict in the tool name with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): await agent.run('No conflict') -async def test_agent_with_server_not_running(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - agent = Agent(model, mcp_servers=[server]) - with pytest.raises(UserError, match='MCP server is not running'): - await agent.run('What is 0 degrees Celsius in Fahrenheit?') +@pytest.mark.vcr() +async def test_agent_with_server_not_running(agent: Agent, allow_model_requests: None): + result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') + assert result.output == snapshot('0 degrees Celsius is 32.0 degrees Fahrenheit.') -async def test_log_level_unset(): +async def test_log_level_unset(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) assert server.log_level is None async with server: - tools = await server.list_tools() + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] assert len(tools) == snapshot(13) assert tools[10].name == 'get_log_level' - result = await server.call_tool('get_log_level', {}) + result = await server.direct_call_tool('get_log_level', {}) assert result == snapshot('unset') -async def test_log_level_set(): +async def test_log_level_set(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], log_level='info') assert server.log_level == 'info' async with server: - result = await server.call_tool('get_log_level', {}) + result = await server.direct_call_tool('get_log_level', {}) assert result == snapshot('info') @pytest.mark.vcr() async def test_tool_returning_str(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('What is the weather in Mexico City?') assert result.output == snapshot( 'The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' @@ -350,7 +369,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me the product name') assert result.output == snapshot('The product name is "PydanticAI".') assert result.all_messages() == snapshot( @@ -423,7 +442,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A @pytest.mark.vcr() async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me the image resource') assert result.output == snapshot( 'This is an image of a sliced kiwi with a vibrant green interior and black seeds.' @@ -506,7 +525,7 @@ async def test_tool_returning_audio_resource( allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str ): model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key)) - async with agent.run_mcp_servers(): + async with agent: result = await agent.run("What's the content of the audio resource?", model=model) assert result.output == snapshot('The audio resource contains a voice saying "Hello, my name is Marcelo."') assert result.all_messages() == snapshot( @@ -557,7 +576,7 @@ async def test_tool_returning_audio_resource( @pytest.mark.vcr() async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me an image') assert result.output == snapshot('Here is an image of a sliced kiwi on a white background.') assert result.all_messages() == snapshot( @@ -637,7 +656,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im @pytest.mark.vcr() async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me a dict, respond on one line') assert result.output == snapshot('{"foo":"bar","baz":123}') assert result.all_messages() == snapshot( @@ -704,7 +723,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_error(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me an error, pass False as a value, unless the tool tells you otherwise') assert result.output == snapshot( 'I called the tool with the correct parameter, and it returned: "This is not an error."' @@ -818,7 +837,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_none(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Call the none tool and say Hello') assert result.output == snapshot('Hello! How can I assist you today?') assert result.all_messages() == snapshot( @@ -885,7 +904,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me multiple items and summarize in one sentence') assert result.output == snapshot( 'The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.' @@ -974,11 +993,11 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ) -async def test_client_sampling(): +async def test_client_sampling(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: - result = await server.call_tool('use_sampling', {'foo': 'bar'}) + result = await server.direct_call_tool('use_sampling', {'foo': 'bar'}) assert result == snapshot( { 'meta': None, @@ -990,27 +1009,27 @@ async def test_client_sampling(): ) -async def test_client_sampling_disabled(): +async def test_client_sampling_disabled(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], allow_sampling=False) server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: with pytest.raises(ModelRetry, match='Error executing tool use_sampling: Sampling not supported'): - await server.call_tool('use_sampling', {'foo': 'bar'}) - + await server.direct_call_tool('use_sampling', {'foo': 'bar'}) -async def test_mcp_server_raises_mcp_error(allow_model_requests: None, agent: Agent) -> None: - server = agent._mcp_servers[0] # pyright: ignore[reportPrivateUsage] +async def test_mcp_server_raises_mcp_error( + allow_model_requests: None, mcp_server: MCPServerStdio, agent: Agent, run_context: RunContext[int] +) -> None: mcp_error = McpError(error=ErrorData(code=400, message='Test MCP error conversion')) - async with agent.run_mcp_servers(): + async with agent: with patch.object( - server._client, # pyright: ignore[reportPrivateUsage] + mcp_server._client, # pyright: ignore[reportPrivateUsage] 'send_request', new=AsyncMock(side_effect=mcp_error), ): with pytest.raises(ModelRetry, match='Test MCP error conversion'): - await server.call_tool('test_tool', {}) + await mcp_server.direct_call_tool('test_tool', {}) def test_map_from_mcp_params_model_request(): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c7753f1fad..dbdcd71f32 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,6 +5,7 @@ import re from collections.abc import AsyncIterator from copy import deepcopy +from dataclasses import replace from datetime import timezone from typing import Any, Union @@ -12,14 +13,16 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai import Agent, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( + FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, ModelMessage, ModelRequest, ModelResponse, + PartStartEvent, RetryPromptPart, TextPart, ToolCallPart, @@ -28,8 +31,9 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import PromptedOutput, TextOutput +from pydantic_ai.output import DeferredToolCalls, PromptedOutput, TextOutput from pydantic_ai.result import AgentStream, FinalResult, Usage +from pydantic_ai.tools import ToolDefinition from pydantic_graph import End from .conftest import IsInt, IsNow, IsStr @@ -272,7 +276,7 @@ async def text_stream(_messages: list[ModelMessage], _: AgentInfo) -> AsyncItera agent = Agent(FunctionModel(stream_function=text_stream), output_type=tuple[str, str]) - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): async with agent.run_stream(''): pass @@ -407,7 +411,7 @@ async def ret_a(x: str) -> str: # pragma: no cover return x with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for output validation'): async with agent.run_stream('hello'): pass @@ -613,18 +617,18 @@ def another_tool(y: int) -> int: timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), - RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), ToolReturnPart( tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), ] ), ] @@ -712,15 +716,15 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', @@ -733,10 +737,7 @@ def another_tool(y: int) -> int: # pragma: no cover part_kind='tool-return', ), RetryPromptPart( - content='Unknown tool name: ' - "'unknown_tool'. Available tools: " - 'regular_tool, another_tool, ' - 'final_result', + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), @@ -975,6 +976,13 @@ def known_tool(x: int) -> int: assert event_parts == snapshot( [ + FunctionToolCallEvent( + part=ToolCallPart( + tool_name='known_tool', + args={'x': 5}, + tool_call_id=IsStr(), + ) + ), FunctionToolCallEvent( part=ToolCallPart( tool_name='unknown_tool', @@ -984,14 +992,11 @@ def known_tool(x: int) -> int: ), FunctionToolResultEvent( result=RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: known_tool", + content="Unknown tool name: 'unknown_tool'. Available tools: 'known_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), - ) - ), - FunctionToolCallEvent( - part=ToolCallPart(tool_name='known_tool', args={'x': 5}, tool_call_id=IsStr()), + ), ), FunctionToolResultEvent( result=ToolReturnPart( @@ -999,13 +1004,6 @@ def known_tool(x: int) -> int: content=10, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), - ) - ), - FunctionToolCallEvent( - part=ToolCallPart( - tool_name='unknown_tool', - args={'arg': 'value'}, - tool_call_id=IsStr(), ), ), ] @@ -1027,15 +1025,15 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf agent = Agent(FunctionModel(call_final_result_with_bad_data), output_type=OutputType) - event_parts: list[Any] = [] + events: list[Any] = [] async with agent.iter('test') as agent_run: async for node in agent_run: if Agent.is_call_tools_node(node): async with node.stream(agent_run.ctx) as event_stream: async for event in event_stream: - event_parts.append(event) + events.append(event) - assert event_parts == snapshot( + assert events == snapshot( [ FunctionToolCallEvent( part=ToolCallPart( @@ -1045,9 +1043,16 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf ), ), FunctionToolResultEvent( - result=ToolReturnPart( + result=RetryPromptPart( + content=[ + { + 'type': 'missing', + 'loc': ('value',), + 'msg': 'Field required', + 'input': {'bad_value': 'invalid'}, + } + ], tool_name='final_result', - content='Output tool not used - result failed validation.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) @@ -1118,3 +1123,93 @@ def test_function_tool_event_tool_call_id_properties(): # The event should expose the same `tool_call_id` as the result part assert result_event.tool_call_id == return_part.tool_call_id == 'return_id_456' + + +async def test_deferred_tool(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 # pragma: no cover + + async with agent.run_stream('Hello') as result: + assert not result.is_complete + output = await result.get_output() + assert output == snapshot( + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ) + assert result.is_complete + + +async def test_deferred_tool_iter(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 # pragma: no cover + + outputs: list[str | DeferredToolCalls] = [] + events: list[Any] = [] + + async with agent.iter('test') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + events.append(event) + async for output in stream.stream_output(debounce_by=None): + outputs.append(output) + if agent.is_call_tools_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + events.append(event) + + assert outputs == snapshot( + [ + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ] + ) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr()), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + FunctionToolCallEvent(part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())), + ] + ) diff --git a/tests/test_tools.py b/tests/test_tools.py index 00f3f5bcc0..c72cd1e086 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,11 +12,17 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext, Tool, UserError +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import ToolOutput +from pydantic_ai.output import DeferredToolCalls, ToolOutput from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.deferred import DeferredToolset +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset + +from .conftest import IsStr def test_tool_no_ctx(): @@ -105,6 +111,7 @@ def test_docstring_google(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -136,6 +143,7 @@ def test_docstring_sphinx(docstring_format: Literal['sphinx', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -175,6 +183,7 @@ def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -214,6 +223,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -251,6 +261,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -294,6 +305,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -325,6 +337,7 @@ def test_only_returns_type(): 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -347,6 +360,7 @@ def test_docstring_unknown(): 'parameters_json_schema': {'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -387,6 +401,7 @@ def test_docstring_google_no_body(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -420,6 +435,7 @@ def takes_just_model(model: Foo) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -462,6 +478,7 @@ def takes_just_model(model: Foo, z: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -481,15 +498,15 @@ def plain_tool(x: int) -> int: result = agent.run_sync('foobar') assert result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0]) - assert agent._function_tools['plain_tool'].takes_ctx is False - assert agent._function_tools['plain_tool'].max_retries == 7 + assert agent._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent._function_toolset.tools['plain_tool'].max_retries == 7 agent_infer = Agent('test', tools=[plain_tool], retries=7) result = agent_infer.run_sync('foobar') assert result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0, 0]) - assert agent_infer._function_tools['plain_tool'].takes_ctx is False - assert agent_infer._function_tools['plain_tool'].max_retries == 7 + assert agent_infer._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent_infer._function_toolset.tools['plain_tool'].max_retries == 7 def ctx_tool(ctx: RunContext[int], x: int) -> int: @@ -501,13 +518,13 @@ def test_init_tool_ctx(): agent = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7) result = agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') - assert agent._function_tools['ctx_tool'].takes_ctx is True - assert agent._function_tools['ctx_tool'].max_retries == 3 + assert agent._function_toolset.tools['ctx_tool'].takes_ctx is True + assert agent._function_toolset.tools['ctx_tool'].max_retries == 3 agent_infer = Agent('test', tools=[ctx_tool], deps_type=int) result = agent_infer.run_sync('foobar', deps=6) assert result.output == snapshot('{"ctx_tool":6}') - assert agent_infer._function_tools['ctx_tool'].takes_ctx is True + assert agent_infer._function_toolset.tools['ctx_tool'].takes_ctx is True def test_repeat_tool_by_rename(): @@ -557,18 +574,40 @@ def foo(x: int, y: str) -> str: # pragma: no cover def bar(x: int, y: str) -> str: # pragma: no cover return f'{x} {y}' - with pytest.raises(UserError, match=r"Tool name conflicts with existing tool: 'bar'."): + with pytest.raises(UserError, match="Tool name conflicts with previously renamed tool: 'bar'."): agent.run_sync('') def test_tool_return_conflict(): # this is okay - Agent('test', tools=[ctx_tool], deps_type=int) + Agent('test', tools=[ctx_tool], deps_type=int).run_sync('', deps=0) # this is also okay - Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) + Agent('test', tools=[ctx_tool], deps_type=int, output_type=int).run_sync('', deps=0) # this raises an error - with pytest.raises(UserError, match="Tool name conflicts with output tool name: 'ctx_tool'"): - Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) + with pytest.raises( + UserError, + match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + ): + Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')).run_sync( + '', deps=0 + ) + + +def test_tool_name_conflict_hint(): + with pytest.raises( + UserError, + match="Prefixed toolset defines a tool whose name conflicts with existing tool from Function toolset: 'foo_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + ): + + def tool(x: int) -> int: + return x + 1 # pragma: no cover + + def foo_tool(x: str) -> str: + return x + 'foo' # pragma: no cover + + function_toolset = FunctionToolset([tool]) + prefixed_toolset = PrefixedToolset(function_toolset, 'foo') + Agent('test', tools=[foo_tool], toolsets=[prefixed_toolset]).run_sync('') def test_init_ctx_tool_invalid(): @@ -798,6 +837,7 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture): 'outer_typed_dict_key': None, 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'strict': None, + 'kind': 'function', } ) @@ -867,6 +907,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': None, @@ -879,6 +920,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -963,6 +1005,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': None, @@ -973,6 +1016,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -1008,6 +1052,7 @@ def get_score(data: Data) -> int: ... # pragma: no branch }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -1039,7 +1084,7 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: with agent.override(model=FunctionModel(get_json_schema)): result = agent.run_sync('', deps=21) json_schema = json.loads(result.output) - assert agent._function_tools['foobar'].strict is None + assert agent._function_toolset.tools['foobar'].strict is None assert json_schema['strict'] is True result = agent.run_sync('', deps=1) @@ -1066,8 +1111,8 @@ def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 def test_function_tool_inconsistent_with_schema(): @@ -1113,5 +1158,146 @@ async def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 + + +def test_tool_retries(): + prepare_tools_retries: list[int] = [] + prepare_retries: list[int] = [] + call_retries: list[int] = [] + + async def prepare_tool_defs( + ctx: RunContext[None], tool_defs: list[ToolDefinition] + ) -> Union[list[ToolDefinition], None]: + nonlocal prepare_tools_retries + retry = ctx.retries.get('infinite_retry_tool', 0) + prepare_tools_retries.append(retry) + return tool_defs + + agent = Agent(TestModel(), retries=3, prepare_tools=prepare_tool_defs) + + async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> Union[ToolDefinition, None]: + nonlocal prepare_retries + prepare_retries.append(ctx.retry) + return tool_def + + @agent.tool(retries=5, prepare=prepare_tool_def) + def infinite_retry_tool(ctx: RunContext[None]) -> int: + nonlocal call_retries + call_retries.append(ctx.retry) + raise ModelRetry('Please try again.') + + with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"): + agent.run_sync('Begin infinite retry loop!') + + # There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in. + assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5] + assert prepare_retries == [0, 0, 1, 2, 3, 4, 5] + assert call_retries == [0, 1, 2, 3, 4, 5] + + +def test_deferred_tool(): + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls], toolsets=[deferred_toolset]) + + result = agent.run_sync('Hello') + assert result.output == snapshot( + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'type': 'object', + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + }, + kind='deferred', + ) + }, + ) + ) + + +def test_deferred_tool_with_output_type(): + class MyModel(BaseModel): + foo: str + + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(call_tools=[]), output_type=[MyModel, DeferredToolCalls], toolsets=[deferred_toolset]) + + result = agent.run_sync('Hello') + assert result.output == snapshot(MyModel(foo='a')) + + +def test_deferred_tool_with_tool_output_type(): + class MyModel(BaseModel): + foo: str + + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent( + TestModel(call_tools=[]), + output_type=[[ToolOutput(MyModel), ToolOutput(MyModel)], DeferredToolCalls], + toolsets=[deferred_toolset], + ) + + result = agent.run_sync('Hello') + assert result.output == snapshot(MyModel(foo='a')) + + +async def test_deferred_tool_without_output_type(): + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(), toolsets=[deferred_toolset]) + + msg = 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + + with pytest.raises(UserError, match=msg): + await agent.run('Hello') + + with pytest.raises(UserError, match=msg): + async with agent.run_stream('Hello') as result: + await result.get_output() + + +def test_output_type_deferred_tool_calls_by_itself(): + with pytest.raises(UserError, match='At least one output type must be provided other than `DeferredToolCalls`.'): + Agent(TestModel(), output_type=DeferredToolCalls) + + +def test_output_type_empty(): + with pytest.raises(UserError, match='At least one output type must be provided.'): + Agent(TestModel(), output_type=[]) diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py new file mode 100644 index 0000000000..eac0dc78a7 --- /dev/null +++ b/tests/test_toolsets.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, replace +from typing import TypeVar + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai._run_context import RunContext +from pydantic_ai._tool_manager import ToolManager +from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import ToolCallPart +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.combined import CombinedToolset +from pydantic_ai.toolsets.filtered import FilteredToolset +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset +from pydantic_ai.toolsets.prepared import PreparedToolset +from pydantic_ai.usage import Usage + +pytestmark = pytest.mark.anyio + +T = TypeVar('T') + + +def build_run_context(deps: T) -> RunContext[T]: + return RunContext( + deps=deps, + model=TestModel(), + usage=Usage(), + prompt=None, + messages=[], + run_step=0, + ) + + +async def test_function_toolset(): + @dataclass + class PrefixDeps: + prefix: str | None = None + + toolset = FunctionToolset[PrefixDeps]() + + async def prepare_add_prefix(ctx: RunContext[PrefixDeps], tool_def: ToolDefinition) -> ToolDefinition | None: + if ctx.deps.prefix is None: + return tool_def + + return replace(tool_def, name=f'{ctx.deps.prefix}_{tool_def.name}') + + @toolset.tool(prepare=prepare_add_prefix) + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + no_prefix_context = build_run_context(PrefixDeps()) + no_prefix_toolset = await ToolManager[PrefixDeps].build(toolset, no_prefix_context) + assert no_prefix_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='add', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + description='Add two numbers', + ) + ] + ) + assert await no_prefix_toolset.handle_call(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2})) == 3 + + foo_context = build_run_context(PrefixDeps(prefix='foo')) + foo_toolset = await ToolManager[PrefixDeps].build(toolset, foo_context) + assert foo_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='foo_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ) + ] + ) + assert await foo_toolset.handle_call(ToolCallPart(tool_name='foo_add', args={'a': 1, 'b': 2})) == 3 + + @toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: lax no cover + + bar_context = build_run_context(PrefixDeps(prefix='bar')) + bar_toolset = await ToolManager[PrefixDeps].build(toolset, bar_context) + assert bar_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='bar_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='subtract', + description='Subtract two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ] + ) + assert await bar_toolset.handle_call(ToolCallPart(tool_name='bar_add', args={'a': 1, 'b': 2})) == 3 + + +async def test_prepared_toolset_user_error_add_new_tools(): + """Test that PreparedToolset raises UserError when prepare function tries to add new tools.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b # pragma: no cover + + async def prepare_add_new_tool(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Try to add a new tool that wasn't in the original set + new_tool = ToolDefinition( + name='new_tool', + description='A new tool', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + ) + return tool_defs + [new_tool] + + prepared_toolset = PreparedToolset(base_toolset, prepare_add_new_tool) + + with pytest.raises( + UserError, + match=re.escape( + 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' + ), + ): + await ToolManager[None].build(prepared_toolset, context) + + +async def test_prepared_toolset_user_error_change_tool_names(): + """Test that PreparedToolset raises UserError when prepare function tries to change tool names.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b # pragma: no cover + + @base_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: no cover + + async def prepare_change_names(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Try to change the name of an existing tool + modified_tool_defs: list[ToolDefinition] = [] + for tool_def in tool_defs: + if tool_def.name == 'add': + modified_tool_defs.append(replace(tool_def, name='modified_add')) + else: + modified_tool_defs.append(tool_def) + return modified_tool_defs + + prepared_toolset = PreparedToolset(base_toolset, prepare_change_names) + + with pytest.raises( + UserError, + match=re.escape( + 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' + ), + ): + await ToolManager[None].build(prepared_toolset, context) + + +async def test_comprehensive_toolset_composition(): + """Test that all toolsets can be composed together and work correctly.""" + + @dataclass + class TestDeps: + user_role: str = 'user' + enable_advanced: bool = True + + # Create first FunctionToolset with basic math operations + math_toolset = FunctionToolset[TestDeps]() + + @math_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + @math_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: no cover + + @math_toolset.tool + def multiply(a: int, b: int) -> int: + """Multiply two numbers""" + return a * b # pragma: no cover + + # Create second FunctionToolset with string operations + string_toolset = FunctionToolset[TestDeps]() + + @string_toolset.tool + def concat(s1: str, s2: str) -> str: + """Concatenate two strings""" + return s1 + s2 + + @string_toolset.tool + def uppercase(text: str) -> str: + """Convert text to uppercase""" + return text.upper() # pragma: no cover + + @string_toolset.tool + def reverse(text: str) -> str: + """Reverse a string""" + return text[::-1] # pragma: no cover + + # Create third FunctionToolset with advanced operations + advanced_toolset = FunctionToolset[TestDeps]() + + @advanced_toolset.tool + def power(base: int, exponent: int) -> int: + """Calculate base raised to the power of exponent""" + return base**exponent # pragma: no cover + + # Step 1: Prefix each FunctionToolset individually + prefixed_math = PrefixedToolset(math_toolset, 'math') + prefixed_string = PrefixedToolset(string_toolset, 'str') + prefixed_advanced = PrefixedToolset(advanced_toolset, 'adv') + + # Step 2: Combine the prefixed toolsets + combined_prefixed_toolset = CombinedToolset([prefixed_math, prefixed_string, prefixed_advanced]) + + # Step 3: Filter tools based on user role and advanced flag, now using prefixed names + def filter_tools(ctx: RunContext[TestDeps], tool_def: ToolDefinition) -> bool: + # Only allow advanced tools if enable_advanced is True + if tool_def.name.startswith('adv_') and not ctx.deps.enable_advanced: + return False + # Only allow string operations for admin users (simulating role-based access) + if tool_def.name.startswith('str_') and ctx.deps.user_role != 'admin': + return False + return True + + filtered_toolset = FilteredToolset[TestDeps](combined_prefixed_toolset, filter_tools) + + # Step 4: Apply prepared toolset to modify descriptions (add user role annotation) + async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Annotate each tool description with the user role + role = ctx.deps.user_role + return [replace(td, description=f'{td.description} (role: {role})') for td in tool_defs] + + prepared_toolset = PreparedToolset(filtered_toolset, prepare_add_context) + + # Step 5: Test the fully composed toolset + # Test with regular user context + regular_deps = TestDeps(user_role='user', enable_advanced=True) + regular_context = build_run_context(regular_deps) + final_toolset = await ToolManager[TestDeps].build(prepared_toolset, regular_context) + # Tool definitions should have role annotation + assert final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_power', + description='Calculate base raised to the power of exponent (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'base': {'type': 'integer'}, 'exponent': {'type': 'integer'}}, + 'required': ['base', 'exponent'], + 'type': 'object', + }, + ), + ] + ) + # Call a tool and check result + result = await final_toolset.handle_call(ToolCallPart(tool_name='math_add', args={'a': 5, 'b': 3})) + assert result == 8 + + # Test with admin user context (should have string tools) + admin_deps = TestDeps(user_role='admin', enable_advanced=True) + admin_context = build_run_context(admin_deps) + admin_final_toolset = await ToolManager[TestDeps].build(prepared_toolset, admin_context) + assert admin_final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_concat', + description='Concatenate two strings (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'s1': {'type': 'string'}, 's2': {'type': 'string'}}, + 'required': ['s1', 's2'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_uppercase', + description='Convert text to uppercase (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'text': {'type': 'string'}}, + 'required': ['text'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_reverse', + description='Reverse a string (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'text': {'type': 'string'}}, + 'required': ['text'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_power', + description='Calculate base raised to the power of exponent (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'base': {'type': 'integer'}, 'exponent': {'type': 'integer'}}, + 'required': ['base', 'exponent'], + 'type': 'object', + }, + ), + ] + ) + result = await admin_final_toolset.handle_call( + ToolCallPart(tool_name='str_concat', args={'s1': 'Hello', 's2': 'World'}) + ) + assert result == 'HelloWorld' + + # Test with advanced features disabled + basic_deps = TestDeps(user_role='user', enable_advanced=False) + basic_context = build_run_context(basic_deps) + basic_final_toolset = await ToolManager[TestDeps].build(prepared_toolset, basic_context) + assert basic_final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ] + ) + + +async def test_context_manager(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + + async with toolset: + assert server1.is_running + assert server2.is_running + + async with toolset: + assert server1.is_running + assert server2.is_running diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 6ea4c4c223..3e1171c076 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -10,7 +10,7 @@ from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.agent import AgentRunResult -from pydantic_ai.output import StructuredDict, TextOutput, ToolOutput +from pydantic_ai.output import DeferredToolCalls, StructuredDict, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition # Define here so we can check `if MYPY` below. This will not be executed, MYPY will always set it to True @@ -222,6 +222,14 @@ def my_method(self) -> bool: assert_type( complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) + + complex_deferred_output_agent = Agent[ + None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls + ](output_type=[complex_output_agent.output_type, DeferredToolCalls]) + assert_type( + complex_deferred_output_agent, + Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls], + ) else: # pyright is able to correctly infer the type here async_int_function_agent = Agent(output_type=foobar_plain) @@ -241,6 +249,12 @@ def my_method(self) -> bool: complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) + complex_deferred_output_agent = Agent(output_type=[complex_output_agent.output_type, DeferredToolCalls]) + assert_type( + complex_deferred_output_agent, + Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls], + ) + Tool(foobar_ctx, takes_ctx=True) Tool(foobar_ctx)