diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 249ee9ecfb..3c40004362 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -878,6 +878,9 @@ async def run( user=user, additional_properties=merged_additional_options, # type: ignore[arg-type] ) + + # Ensure thread is forwarded in kwargs for tool invocation + kwargs["thread"] = thread # Filter chat_options from kwargs to prevent duplicate keyword argument filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} response = await self.chat_client.get_response( @@ -895,7 +898,12 @@ async def run( # Only notify the thread of new messages if the chatResponse was successful # to avoid inconsistent messages state in the thread. - await self._notify_thread_of_new_messages(thread, input_messages, response.messages) + await self._notify_thread_of_new_messages( + thread, + input_messages, + response.messages, + **{k: v for k, v in kwargs.items() if k != "thread"}, + ) return AgentRunResponse( messages=response.messages, response_id=response.response_id, @@ -1017,6 +1025,8 @@ async def run_stream( additional_properties=merged_additional_options, # type: ignore[arg-type] ) + # Ensure thread is forwarded in kwargs for tool invocation + kwargs["thread"] = thread # Filter chat_options from kwargs to prevent duplicate keyword argument filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} response_updates: list[ChatResponseUpdate] = [] @@ -1043,7 +1053,13 @@ async def run_stream( response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format) await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) - await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs) + + await self._notify_thread_of_new_messages( + thread, + input_messages, + response.messages, + **{k: v for k, v in kwargs.items() if k != "thread"}, + ) @override def get_new_thread( diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index bbbf07ab7a..bc16d9edb9 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -627,6 +627,12 @@ def __init__( self._invocation_duration_histogram = _default_histogram() self.type: Literal["ai_function"] = "ai_function" self._forward_runtime_kwargs: bool = False + if self.func: + sig = inspect.signature(self.func) + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + self._forward_runtime_kwargs = True + break @property def declaration_only(self) -> bool: @@ -915,6 +921,7 @@ def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[B ) for pname, param in sig.parameters.items() if pname not in {"self", "cls"} + and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} } return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return] @@ -1744,7 +1751,9 @@ async def function_invocation_wrapper( break _replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results) - response = await func(self, messages=prepped_messages, **kwargs) + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + response = await func(self, messages=prepped_messages, **filtered_kwargs) # if there are function calls, we will handle them first function_results = { it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent) @@ -1833,7 +1842,10 @@ async def function_invocation_wrapper( # Failsafe: give up on tools, ask model for plain answer kwargs["tool_choice"] = "none" - response = await func(self, messages=prepped_messages, **kwargs) + + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + response = await func(self, messages=prepped_messages, **filtered_kwargs) if fcc_messages: for msg in reversed(fcc_messages): response.messages.insert(0, msg) @@ -1920,7 +1932,9 @@ async def streaming_function_invocation_wrapper( _replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results) all_updates: list["ChatResponseUpdate"] = [] - async for update in func(self, messages=prepped_messages, **kwargs): + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + async for update in func(self, messages=prepped_messages, **filtered_kwargs): all_updates.append(update) yield update @@ -2031,7 +2045,9 @@ async def streaming_function_invocation_wrapper( # Failsafe: give up on tools, ask model for plain answer kwargs["tool_choice"] = "none" - async for update in func(self, messages=prepped_messages, **kwargs): + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + async for update in func(self, messages=prepped_messages, **filtered_kwargs): yield update return streaming_function_invocation_wrapper diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 77d5911865..a6df07cbbe 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -21,9 +21,11 @@ ChatResponse, Context, ContextProvider, + FunctionCallContent, HostedCodeInterpreterTool, Role, TextContent, + ai_function, ) from agent_framework._mcp import MCPTool from agent_framework.exceptions import AgentExecutionException @@ -595,3 +597,38 @@ async def test_chat_agent_with_local_mcp_tools(chat_client: ChatClientProtocol) # Test async context manager with MCP tools async with agent: pass + + +async def test_agent_tool_receives_thread_in_kwargs(chat_client_base: Any) -> None: + """Verify tool execution receives 'thread' inside **kwargs when function is called by client.""" + + captured: dict[str, Any] = {} + + @ai_function(name="echo_thread_info") + def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnknownParameterType] + thread = kwargs.get("thread") + captured["has_thread"] = thread is not None + captured["has_message_store"] = thread.message_store is not None if isinstance(thread, AgentThread) else False + return f"echo: {text}" + + # Make the base client emit a function call for our tool + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[FunctionCallContent(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}')], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + agent = ChatAgent( + chat_client=chat_client_base, tools=[echo_thread_info], chat_message_store_factory=ChatMessageStore + ) + thread = agent.get_new_thread() + + result = await agent.run("hello", thread=thread) + + assert result.text == "done" + assert captured.get("has_thread") is True + assert captured.get("has_message_store") is True diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index c1cc0f119b..4beee1fb7d 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1334,3 +1334,37 @@ async def mock_get_streaming_response(self, messages, **kwargs): assert updates[2].role == Role.ASSISTANT assert len(updates[2].contents) == 2 assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents) + + +async def test_ai_function_with_kwargs_injection(): + """Test that ai_function correctly handles kwargs injection and hides them from schema.""" + + @ai_function + def tool_with_kwargs(x: int, **kwargs: Any) -> str: + """A tool that accepts kwargs.""" + user_id = kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + # Verify schema does not include kwargs + assert tool_with_kwargs.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_kwargs_input", + "type": "object", + } + + # Verify direct invocation works + assert tool_with_kwargs(1, user_id="user1") == "x=1, user=user1" + + # Verify invoke works with injected args + result = await tool_with_kwargs.invoke( + arguments=tool_with_kwargs.input_model(x=5), + user_id="user2", + ) + assert result == "x=5, user=user2" + + # Verify invoke works without injected args (uses default) + result_default = await tool_with_kwargs.invoke( + arguments=tool_with_kwargs.input_model(x=10), + ) + assert result_default == "x=10, user=unknown" diff --git a/python/samples/getting_started/tools/README.md b/python/samples/getting_started/tools/README.md index 66ca227da6..7daf2c6b49 100644 --- a/python/samples/getting_started/tools/README.md +++ b/python/samples/getting_started/tools/README.md @@ -11,6 +11,8 @@ This folder contains examples demonstrating how to use AI functions (tools) with | [`ai_function_recover_from_failures.py`](ai_function_recover_from_failures.py) | Demonstrates graceful error handling when tools raise exceptions. Shows how agents receive error information and can recover from failures, deciding whether to retry or respond differently based on the exception. | | [`ai_function_with_approval.py`](ai_function_with_approval.py) | Shows how to implement user approval workflows for function calls without using threads. Demonstrates both streaming and non-streaming approval patterns where users can approve or reject function executions before they run. | | [`ai_function_with_approval_and_threads.py`](ai_function_with_approval_and_threads.py) | Demonstrates tool approval workflows using threads for automatic conversation history management. Shows how threads simplify approval workflows by automatically storing and retrieving conversation context. Includes both approval and rejection examples. | +| [`ai_function_with_kwargs.py`](ai_function_with_kwargs.py) | Demonstrates how to inject custom arguments (context) into an AI function from the agent's run method. Useful for passing runtime information like access tokens or user IDs that the tool needs but the model shouldn't see. | +| [`ai_function_with_thread_injection.py`](ai_function_with_thread_injection.py) | Shows how to access the current `thread` object inside an AI function via `**kwargs`. | | [`ai_function_with_max_exceptions.py`](ai_function_with_max_exceptions.py) | Shows how to limit the number of times a tool can fail with exceptions using `max_invocation_exceptions`. Useful for preventing expensive tools from being called repeatedly when they keep failing. | | [`ai_function_with_max_invocations.py`](ai_function_with_max_invocations.py) | Demonstrates limiting the total number of times a tool can be invoked using `max_invocations`. Useful for rate-limiting expensive operations or ensuring tools are only called a specific number of times per conversation. | | [`ai_functions_in_class.py`](ai_functions_in_class.py) | Shows how to use `ai_function` decorator with class methods to create stateful tools. Demonstrates how class state can control tool behavior dynamically, allowing you to adjust tool functionality at runtime by modifying class properties. | diff --git a/python/samples/getting_started/tools/ai_function_with_kwargs.py b/python/samples/getting_started/tools/ai_function_with_kwargs.py new file mode 100644 index 0000000000..ff75aa7aa4 --- /dev/null +++ b/python/samples/getting_started/tools/ai_function_with_kwargs.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Annotated, Any + +from agent_framework import ai_function +from agent_framework.openai import OpenAIResponsesClient +from pydantic import Field + +""" +AI Function with kwargs Example + +This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function +from the agent's run method, without exposing them to the AI model. + +This is useful for passing runtime information like access tokens, user IDs, or +request-specific context that the tool needs but the model shouldn't know about +or provide. +""" + + +# Define the function tool with **kwargs to accept injected arguments +@ai_function +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], + **kwargs: Any, +) -> str: + """Get the weather for a given location.""" + # Extract the injected argument from kwargs + user_id = kwargs.get("user_id", "unknown") + + # Simulate using the user_id for logging or personalization + print(f"Getting weather for user: {user_id}") + + return f"The weather in {location} is cloudy with a high of 15°C." + + +async def main() -> None: + agent = OpenAIResponsesClient().create_agent( + name="WeatherAgent", + instructions="You are a helpful weather assistant.", + tools=[get_weather], + ) + + # Pass the injected argument when running the agent + # The 'user_id' kwarg will be passed down to the tool execution via **kwargs + response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123") + + print(f"Agent: {response.text}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/tools/ai_function_with_thread_injection.py b/python/samples/getting_started/tools/ai_function_with_thread_injection.py new file mode 100644 index 0000000000..d3e5d0f808 --- /dev/null +++ b/python/samples/getting_started/tools/ai_function_with_thread_injection.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Annotated, Any + +from agent_framework import AgentThread, ai_function +from agent_framework.openai import OpenAIChatClient +from pydantic import Field + +""" +AI Function with Thread Injection Example + +This example demonstrates the behavior when passing 'thread' to agent.run() +and accessing that thread in AI function. +""" + + +# Define the function tool with **kwargs +@ai_function +async def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], + **kwargs: Any, +) -> str: + """Get the weather for a given location.""" + # Get thread object from kwargs + thread = kwargs.get("thread") + if thread and isinstance(thread, AgentThread): + if thread.message_store: + messages = await thread.message_store.list_messages() + print(f"Thread contains {len(messages)} messages.") + elif thread.service_thread_id: + print(f"Thread ID: {thread.service_thread_id}.") + + return f"The weather in {location} is cloudy." + + +async def main() -> None: + agent = OpenAIChatClient().create_agent( + name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather] + ) + + # Create a thread + thread = agent.get_new_thread() + + # Run the agent with the thread + print(f"Agent: {await agent.run('What is the weather in London?', thread=thread)}") + print(f"Agent: {await agent.run('What is the weather in Amsterdam?', thread=thread)}") + print(f"Agent: {await agent.run('What cities did I ask about?', thread=thread)}") + + +if __name__ == "__main__": + asyncio.run(main())