diff --git a/src/agents/tool.py b/src/agents/tool.py index 8c8d3e988..79309bc22 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -228,6 +228,10 @@ class FunctionTool: and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool based on your context/state.""" + _func: ToolFunction[...] | None = field(default=None, repr=False) + """The function that implements the tool. Ensures that a reference to the + original function exists when @function_tool is used.""" + # Tool-specific guardrails tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None """Optional list of input guardrails to run before invoking this tool.""" @@ -239,6 +243,19 @@ def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) + # Dress the FunctionTool object with the name and docstring of the wrapped function + if self._func: + self.__name__ = self._func.__name__ + self.__doc__ = self._func.__doc__ + + def __call__(self, *args, **kwargs): + if not self._func: + raise AttributeError("""FunctionTool has no attribute `_func` and is not callable. + Likely because it was created directly without the + @function_tool decorator.""") + + return self._func(*args, **kwargs) + @dataclass class FileSearchTool: @@ -845,6 +862,7 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any: on_invoke_tool=_on_invoke_tool, strict_json_schema=strict_mode, is_enabled=is_enabled, + _func=func, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 18107773d..a65fc103f 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -1,4 +1,6 @@ +import inspect import json +from dataclasses import asdict from typing import Any import pytest @@ -81,6 +83,44 @@ async def test_simple_function(): ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" ) + # Direct call + result = tool(2, 2) + assert result == 4 + + +async def async_function(a: int, b: int = 5): + return a + b + + +@pytest.mark.asyncio +async def test_async_function(): + tool = function_tool(async_function, failure_error_function=None) + assert tool.name == "async_function" + + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), + '{"a": 1}', + ) + assert result == 6 + + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'), + '{"a": 1, "b": 2}', + ) + assert result == 3 + + # Missing required argument should raise an error + with pytest.raises(ModelBehaviorError): + await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ) + + # Direct call + result = await tool(2, 2) + assert result == 4 + + assert not inspect.iscoroutinefunction(tool.__call__), "tool.__call__ should sync." + class Foo(BaseModel): a: int @@ -148,6 +188,22 @@ async def test_complex_args_function(): ) +def test_func_tool_name_doc_inheritance(): + tool = function_tool(simple_function) + assert tool.__name__ == simple_function.__name__ + assert tool.__doc__ == simple_function.__doc__ + + +def test_absent_func_tool(): + tool = function_tool(simple_function) + kwargs = asdict(tool) + kwargs.pop("_func") + manually_defined_tool = FunctionTool(**kwargs) + + with pytest.raises(AttributeError, match="not callable"): + manually_defined_tool(1, 1) + + def test_function_config_overrides(): tool = function_tool(simple_function, name_override="custom_name") assert tool.name == "custom_name"