diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 83389a50..2271f5ee 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -585,6 +585,8 @@ def tool( *, retries: int | None = None, prepare: ToolPrepareFunc[AgentDeps] | None = None, + docstring_format: Literal["google", "numpy", "sphinx", "auto"] = "auto", + require_parameter_descriptions: bool = False ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ... def tool( @@ -594,6 +596,8 @@ def tool( *, retries: int | None = None, prepare: ToolPrepareFunc[AgentDeps] | None = None, + docstring_format: Literal["google", "numpy", "sphinx", "auto"] = "auto", + require_parameter_descriptions: bool = False ) -> Any: """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. @@ -638,13 +642,13 @@ def tool_decorator( func_: ToolFuncContext[AgentDeps, ToolParams], ) -> ToolFuncContext[AgentDeps, ToolParams]: # noinspection PyTypeChecker - self._register_function(func_, True, retries, prepare) + self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions) return func_ return tool_decorator else: # noinspection PyTypeChecker - self._register_function(func, True, retries, prepare) + self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions) return func @overload @@ -722,10 +726,13 @@ def _register_function( takes_ctx: bool, retries: int | None, prepare: ToolPrepareFunc[AgentDeps] | None, + docstring_format: Literal["google", "numpy", "sphinx", "auto"] = "auto", + require_parameter_descriptions: bool = False ) -> None: """Private utility to register a function as a tool.""" retries_ = retries if retries is not None else self._default_retries - tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare) + tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare, docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions) self._register_tool(tool) def _register_tool(self, tool: Tool[AgentDeps]) -> None: diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 9f9b4fe8..92aac7ff 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -3,7 +3,7 @@ import inspect from collections.abc import Awaitable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, Literal from pydantic import ValidationError from pydantic_core import SchemaValidator @@ -144,6 +144,8 @@ class Tool(Generic[AgentDeps]): _validator: SchemaValidator = field(init=False, repr=False) _parameters_json_schema: ObjectJsonSchema = field(init=False) current_retry: int = field(default=0, init=False) + docstring_format: Literal["google", "numpy", "sphinx", "auto"] = "auto" + require_parameter_descriptions: bool = False def __init__( self, @@ -154,6 +156,8 @@ def __init__( name: str | None = None, description: str | None = None, prepare: ToolPrepareFunc[AgentDeps] | None = None, + docstring_format: Literal["google", "numpy", "sphinx", "auto"] = "auto", + require_parameter_descriptions: bool = False ): """Create a new tool instance. @@ -211,6 +215,8 @@ async def prep_my_tool( self.name = name or function.__name__ self.description = description or f['description'] self.prepare = prepare + self.docstring_format = docstring_format + self.require_parameter_descriptions = require_parameter_descriptions self._is_async = inspect.iscoroutinefunction(self.function) self._single_arg_name = f['single_arg_name'] self._positional_fields = f['positional_fields']