Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,9 @@ async def run(
user=user,
additional_properties=merged_additional_options, # type: ignore[arg-type]
)

# Ensure thread is forwarded in kwargs for tool invocation
kwargs["thread"] = thread
# Filter chat_options from kwargs to prevent duplicate keyword argument
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response = await self.chat_client.get_response(
Expand All @@ -895,7 +898,12 @@ async def run(

# Only notify the thread of new messages if the chatResponse was successful
# to avoid inconsistent messages state in the thread.
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
await self._notify_thread_of_new_messages(
thread,
input_messages,
response.messages,
**{k: v for k, v in kwargs.items() if k != "thread"},
)
return AgentRunResponse(
messages=response.messages,
response_id=response.response_id,
Expand Down Expand Up @@ -1017,6 +1025,8 @@ async def run_stream(
additional_properties=merged_additional_options, # type: ignore[arg-type]
)

# Ensure thread is forwarded in kwargs for tool invocation
kwargs["thread"] = thread
# Filter chat_options from kwargs to prevent duplicate keyword argument
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response_updates: list[ChatResponseUpdate] = []
Expand All @@ -1043,7 +1053,13 @@ async def run_stream(

response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format)
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs)

await self._notify_thread_of_new_messages(
thread,
input_messages,
response.messages,
**{k: v for k, v in kwargs.items() if k != "thread"},
)

@override
def get_new_thread(
Expand Down
24 changes: 20 additions & 4 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,12 @@ def __init__(
self._invocation_duration_histogram = _default_histogram()
self.type: Literal["ai_function"] = "ai_function"
self._forward_runtime_kwargs: bool = False
if self.func:
sig = inspect.signature(self.func)
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD:
self._forward_runtime_kwargs = True
break

@property
def declaration_only(self) -> bool:
Expand Down Expand Up @@ -915,6 +921,7 @@ def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[B
)
for pname, param in sig.parameters.items()
if pname not in {"self", "cls"}
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return]

Expand Down Expand Up @@ -1744,7 +1751,9 @@ async def function_invocation_wrapper(
break
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)

response = await func(self, messages=prepped_messages, **kwargs)
# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
response = await func(self, messages=prepped_messages, **filtered_kwargs)
# if there are function calls, we will handle them first
function_results = {
it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent)
Expand Down Expand Up @@ -1833,7 +1842,10 @@ async def function_invocation_wrapper(

# Failsafe: give up on tools, ask model for plain answer
kwargs["tool_choice"] = "none"
response = await func(self, messages=prepped_messages, **kwargs)

# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
response = await func(self, messages=prepped_messages, **filtered_kwargs)
if fcc_messages:
for msg in reversed(fcc_messages):
response.messages.insert(0, msg)
Expand Down Expand Up @@ -1920,7 +1932,9 @@ async def streaming_function_invocation_wrapper(
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)

all_updates: list["ChatResponseUpdate"] = []
async for update in func(self, messages=prepped_messages, **kwargs):
# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
async for update in func(self, messages=prepped_messages, **filtered_kwargs):
all_updates.append(update)
yield update

Expand Down Expand Up @@ -2031,7 +2045,9 @@ async def streaming_function_invocation_wrapper(

# Failsafe: give up on tools, ask model for plain answer
kwargs["tool_choice"] = "none"
async for update in func(self, messages=prepped_messages, **kwargs):
# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
async for update in func(self, messages=prepped_messages, **filtered_kwargs):
yield update

return streaming_function_invocation_wrapper
Expand Down
37 changes: 37 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
ChatResponse,
Context,
ContextProvider,
FunctionCallContent,
HostedCodeInterpreterTool,
Role,
TextContent,
ai_function,
)
from agent_framework._mcp import MCPTool
from agent_framework.exceptions import AgentExecutionException
Expand Down Expand Up @@ -595,3 +597,38 @@ async def test_chat_agent_with_local_mcp_tools(chat_client: ChatClientProtocol)
# Test async context manager with MCP tools
async with agent:
pass


async def test_agent_tool_receives_thread_in_kwargs(chat_client_base: Any) -> None:
"""Verify tool execution receives 'thread' inside **kwargs when function is called by client."""

captured: dict[str, Any] = {}

@ai_function(name="echo_thread_info")
def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnknownParameterType]
thread = kwargs.get("thread")
captured["has_thread"] = thread is not None
captured["has_message_store"] = thread.message_store is not None if isinstance(thread, AgentThread) else False
return f"echo: {text}"

# Make the base client emit a function call for our tool
chat_client_base.run_responses = [
ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[FunctionCallContent(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}')],
)
),
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
]

agent = ChatAgent(
chat_client=chat_client_base, tools=[echo_thread_info], chat_message_store_factory=ChatMessageStore
)
thread = agent.get_new_thread()

result = await agent.run("hello", thread=thread)

assert result.text == "done"
assert captured.get("has_thread") is True
assert captured.get("has_message_store") is True
34 changes: 34 additions & 0 deletions python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,3 +1334,37 @@ async def mock_get_streaming_response(self, messages, **kwargs):
assert updates[2].role == Role.ASSISTANT
assert len(updates[2].contents) == 2
assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents)


async def test_ai_function_with_kwargs_injection():
"""Test that ai_function correctly handles kwargs injection and hides them from schema."""

@ai_function
def tool_with_kwargs(x: int, **kwargs: Any) -> str:
"""A tool that accepts kwargs."""
user_id = kwargs.get("user_id", "unknown")
return f"x={x}, user={user_id}"

# Verify schema does not include kwargs
assert tool_with_kwargs.parameters() == {
"properties": {"x": {"title": "X", "type": "integer"}},
"required": ["x"],
"title": "tool_with_kwargs_input",
"type": "object",
}

# Verify direct invocation works
assert tool_with_kwargs(1, user_id="user1") == "x=1, user=user1"

# Verify invoke works with injected args
result = await tool_with_kwargs.invoke(
arguments=tool_with_kwargs.input_model(x=5),
user_id="user2",
)
assert result == "x=5, user=user2"

# Verify invoke works without injected args (uses default)
result_default = await tool_with_kwargs.invoke(
arguments=tool_with_kwargs.input_model(x=10),
)
assert result_default == "x=10, user=unknown"
2 changes: 2 additions & 0 deletions python/samples/getting_started/tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ This folder contains examples demonstrating how to use AI functions (tools) with
| [`ai_function_recover_from_failures.py`](ai_function_recover_from_failures.py) | Demonstrates graceful error handling when tools raise exceptions. Shows how agents receive error information and can recover from failures, deciding whether to retry or respond differently based on the exception. |
| [`ai_function_with_approval.py`](ai_function_with_approval.py) | Shows how to implement user approval workflows for function calls without using threads. Demonstrates both streaming and non-streaming approval patterns where users can approve or reject function executions before they run. |
| [`ai_function_with_approval_and_threads.py`](ai_function_with_approval_and_threads.py) | Demonstrates tool approval workflows using threads for automatic conversation history management. Shows how threads simplify approval workflows by automatically storing and retrieving conversation context. Includes both approval and rejection examples. |
| [`ai_function_with_kwargs.py`](ai_function_with_kwargs.py) | Demonstrates how to inject custom arguments (context) into an AI function from the agent's run method. Useful for passing runtime information like access tokens or user IDs that the tool needs but the model shouldn't see. |
| [`ai_function_with_thread_injection.py`](ai_function_with_thread_injection.py) | Shows how to access the current `thread` object inside an AI function via `**kwargs`. |
| [`ai_function_with_max_exceptions.py`](ai_function_with_max_exceptions.py) | Shows how to limit the number of times a tool can fail with exceptions using `max_invocation_exceptions`. Useful for preventing expensive tools from being called repeatedly when they keep failing. |
| [`ai_function_with_max_invocations.py`](ai_function_with_max_invocations.py) | Demonstrates limiting the total number of times a tool can be invoked using `max_invocations`. Useful for rate-limiting expensive operations or ensuring tools are only called a specific number of times per conversation. |
| [`ai_functions_in_class.py`](ai_functions_in_class.py) | Shows how to use `ai_function` decorator with class methods to create stateful tools. Demonstrates how class state can control tool behavior dynamically, allowing you to adjust tool functionality at runtime by modifying class properties. |
Expand Down
53 changes: 53 additions & 0 deletions python/samples/getting_started/tools/ai_function_with_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
from typing import Annotated, Any

from agent_framework import ai_function
from agent_framework.openai import OpenAIResponsesClient
from pydantic import Field

"""
AI Function with kwargs Example

This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function
from the agent's run method, without exposing them to the AI model.

This is useful for passing runtime information like access tokens, user IDs, or
request-specific context that the tool needs but the model shouldn't know about
or provide.
"""


# Define the function tool with **kwargs to accept injected arguments
@ai_function
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
**kwargs: Any,
) -> str:
"""Get the weather for a given location."""
# Extract the injected argument from kwargs
user_id = kwargs.get("user_id", "unknown")

# Simulate using the user_id for logging or personalization
print(f"Getting weather for user: {user_id}")

return f"The weather in {location} is cloudy with a high of 15°C."


async def main() -> None:
agent = OpenAIResponsesClient().create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=[get_weather],
)

# Pass the injected argument when running the agent
# The 'user_id' kwarg will be passed down to the tool execution via **kwargs
response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123")

print(f"Agent: {response.text}")


if __name__ == "__main__":
asyncio.run(main())
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
from typing import Annotated, Any

from agent_framework import AgentThread, ai_function
from agent_framework.openai import OpenAIChatClient
from pydantic import Field

"""
AI Function with Thread Injection Example

This example demonstrates the behavior when passing 'thread' to agent.run()
and accessing that thread in AI function.
"""


# Define the function tool with **kwargs
@ai_function
async def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
**kwargs: Any,
) -> str:
"""Get the weather for a given location."""
# Get thread object from kwargs
thread = kwargs.get("thread")
if thread and isinstance(thread, AgentThread):
if thread.message_store:
messages = await thread.message_store.list_messages()
print(f"Thread contains {len(messages)} messages.")
elif thread.service_thread_id:
print(f"Thread ID: {thread.service_thread_id}.")

return f"The weather in {location} is cloudy."


async def main() -> None:
agent = OpenAIChatClient().create_agent(
name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather]
)

# Create a thread
thread = agent.get_new_thread()

# Run the agent with the thread
print(f"Agent: {await agent.run('What is the weather in London?', thread=thread)}")
print(f"Agent: {await agent.run('What is the weather in Amsterdam?', thread=thread)}")
print(f"Agent: {await agent.run('What cities did I ask about?', thread=thread)}")


if __name__ == "__main__":
asyncio.run(main())
Loading