Skip to content

Make the reset behavior on tool use configurable #335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,6 @@ Supplying a list of tools doesn't always mean the LLM will use a tool. You can f

!!! note

To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios:

1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none")
2. When `tool_choice` is set to "required" AND there is only one tool available

This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases.

To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop would otherwise arise because tool results are sent back to the LLM, which — still bound by `tool_choice` — generates another tool call, ad infinitum.

If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing.
73 changes: 34 additions & 39 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast

from openai.types.responses import (
Expand Down Expand Up @@ -77,6 +77,23 @@ class QueueCompleteSentinel:
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)


@dataclass
class AgentToolUseTracker:
    """Tracks which tools each agent has used during a run.

    Used to decide whether an agent's `tool_choice` should be reset after a tool
    call (to avoid infinite tool-call loops).
    """

    agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list)
    """Tuple of (agent, list of tools used). Can't use a dict because agents aren't hashable."""

    def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
        """Record that `agent` used the given tools this turn.

        Args:
            agent: The agent that made the tool calls.
            tool_names: Names of the tools that were invoked (may be empty).
        """
        existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
        if existing_data:
            existing_data[1].extend(tool_names)
        else:
            # Store a copy so later extend() calls don't mutate the caller's list.
            self.agent_to_tools.append((agent, list(tool_names)))

    def has_used_tools(self, agent: Agent[Any]) -> bool:
        """Return True if `agent` has used at least one tool so far."""
        existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
        return existing_data is not None and len(existing_data[1]) > 0


@dataclass
class ToolRunHandoff:
handoff: Handoff
Expand All @@ -101,6 +118,7 @@ class ProcessedResponse:
handoffs: list[ToolRunHandoff]
functions: list[ToolRunFunction]
computer_actions: list[ToolRunComputerAction]
tools_used: list[str] # Names of all tools used, including hosted tools

def has_tools_to_run(self) -> bool:
# Handoffs, functions and computer actions need local processing
Expand Down Expand Up @@ -208,29 +226,6 @@ async def execute_tools_and_side_effects(
new_step_items.extend([result.run_item for result in function_results])
new_step_items.extend(computer_results)

# Reset tool_choice to "auto" after tool execution to prevent infinite loops
if processed_response.functions or processed_response.computer_actions:
tools = agent.tools

if (
run_config.model_settings and
cls._should_reset_tool_choice(run_config.model_settings, tools)
):
# update the run_config model settings with a copy
new_run_config_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto"
)
run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings)

if cls._should_reset_tool_choice(agent.model_settings, tools):
# Create a modified copy instead of modifying the original agent
new_model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto"
)
agent = dataclasses.replace(agent, model_settings=new_model_settings)

# Second, check if there are any handoffs
if run_handoffs := processed_response.handoffs:
return await cls.execute_handoffs(
Expand Down Expand Up @@ -322,22 +317,16 @@ async def execute_tools_and_side_effects(
)

@classmethod
def maybe_reset_tool_choice(
    cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
) -> ModelSettings:
    """Resets tool choice to None if the agent has used tools and the agent's
    reset_tool_choice flag is True.

    Args:
        agent: The agent whose `reset_tool_choice` flag controls the behavior.
        tool_use_tracker: Tracker recording which agents have used tools this run.
        model_settings: The already-resolved settings for the upcoming model call.

    Returns:
        A copy of `model_settings` with `tool_choice=None` when the reset applies,
        otherwise the original `model_settings` unchanged.
    """
    if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent):
        # Copy rather than mutate: callers may share the settings object.
        return dataclasses.replace(model_settings, tool_choice=None)

    return model_settings

@classmethod
def process_model_response(
Expand All @@ -354,7 +343,7 @@ def process_model_response(
run_handoffs = []
functions = []
computer_actions = []

tools_used: list[str] = []
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
Expand All @@ -364,12 +353,15 @@ def process_model_response(
items.append(MessageOutputItem(raw_item=output, agent=agent))
elif isinstance(output, ResponseFileSearchToolCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("file_search")
elif isinstance(output, ResponseFunctionWebSearch):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("web_search")
elif isinstance(output, ResponseReasoningItem):
items.append(ReasoningItem(raw_item=output, agent=agent))
elif isinstance(output, ResponseComputerToolCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("computer_use")
if not computer_tool:
_error_tracing.attach_error_to_current_span(
SpanError(
Expand All @@ -391,6 +383,8 @@ def process_model_response(
if not isinstance(output, ResponseFunctionToolCall):
continue

tools_used.append(output.name)

# Handoffs
if output.name in handoff_map:
items.append(HandoffCallItem(raw_item=output, agent=agent))
Expand Down Expand Up @@ -422,6 +416,7 @@ def process_model_response(
handoffs=run_handoffs,
functions=functions,
computer_actions=computer_actions,
tools_used=tools_used,
)

@classmethod
Expand Down
4 changes: 4 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ class Agent(Generic[TContext]):
web search, etc are always processed by the LLM.
"""

reset_tool_choice: bool = True
"""Whether to reset the tool choice to the default value after a tool has been called. Defaults
to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""

def clone(self, **kwargs: Any) -> Agent[TContext]:
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
```
Expand Down
6 changes: 4 additions & 2 deletions src/agents/models/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,10 @@ async def _fetch_response(
list_input = ItemHelpers.input_to_new_input_list(input)

parallel_tool_calls = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be easier to read as an if/elif/else block rather than this expression

True if model_settings.parallel_tool_calls and tools and len(tools) > 0
else False if model_settings.parallel_tool_calls is False
True
if model_settings.parallel_tool_calls and tools and len(tools) > 0
else False
if model_settings.parallel_tool_calls is False
else NOT_GIVEN
)

Expand Down
21 changes: 21 additions & 0 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from agents.tool import Tool

from ._run_impl import (
AgentToolUseTracker,
NextStepFinalOutput,
NextStepHandoff,
NextStepRunAgain,
Expand Down Expand Up @@ -151,6 +152,8 @@ async def run(
if run_config is None:
run_config = RunConfig()

tool_use_tracker = AgentToolUseTracker()

with TraceCtxManager(
workflow_name=run_config.workflow_name,
trace_id=run_config.trace_id,
Expand Down Expand Up @@ -227,6 +230,7 @@ async def run(
context_wrapper=context_wrapper,
run_config=run_config,
should_run_agent_start_hooks=should_run_agent_start_hooks,
tool_use_tracker=tool_use_tracker,
),
)
else:
Expand All @@ -239,6 +243,7 @@ async def run(
context_wrapper=context_wrapper,
run_config=run_config,
should_run_agent_start_hooks=should_run_agent_start_hooks,
tool_use_tracker=tool_use_tracker,
)
should_run_agent_start_hooks = False

Expand Down Expand Up @@ -486,6 +491,7 @@ async def _run_streamed_impl(
current_agent = starting_agent
current_turn = 0
should_run_agent_start_hooks = True
tool_use_tracker = AgentToolUseTracker()

streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))

Expand Down Expand Up @@ -546,6 +552,7 @@ async def _run_streamed_impl(
context_wrapper,
run_config,
should_run_agent_start_hooks,
tool_use_tracker,
)
should_run_agent_start_hooks = False

Expand Down Expand Up @@ -613,6 +620,7 @@ async def _run_single_turn_streamed(
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
should_run_agent_start_hooks: bool,
tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
if should_run_agent_start_hooks:
await asyncio.gather(
Expand All @@ -635,6 +643,8 @@ async def _run_single_turn_streamed(
all_tools = await cls._get_all_tools(agent)
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)

final_response: ModelResponse | None = None

input = ItemHelpers.input_to_new_input_list(streamed_result.input)
Expand Down Expand Up @@ -687,6 +697,7 @@ async def _run_single_turn_streamed(
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
)

RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
Expand All @@ -704,6 +715,7 @@ async def _run_single_turn(
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
should_run_agent_start_hooks: bool,
tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
# Ensure we run the hooks before anything else
if should_run_agent_start_hooks:
Expand Down Expand Up @@ -732,6 +744,7 @@ async def _run_single_turn(
handoffs,
context_wrapper,
run_config,
tool_use_tracker,
)

return await cls._get_single_step_result_from_response(
Expand All @@ -745,6 +758,7 @@ async def _run_single_turn(
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
)

@classmethod
Expand All @@ -761,6 +775,7 @@ async def _get_single_step_result_from_response(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
processed_response = RunImpl.process_model_response(
agent=agent,
Expand All @@ -769,6 +784,9 @@ async def _get_single_step_result_from_response(
output_schema=output_schema,
handoffs=handoffs,
)

tool_use_tracker.add_tool_use(agent, processed_response.tools_used)

return await RunImpl.execute_tools_and_side_effects(
agent=agent,
original_input=original_input,
Expand Down Expand Up @@ -868,9 +886,12 @@ async def _get_new_response(
handoffs: list[Handoff],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
) -> ModelResponse:
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)

new_response = await model.get_response(
system_instructions=system_prompt,
input=input,
Expand Down
10 changes: 10 additions & 0 deletions tests/fake_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

from openai.types.responses import Response, ResponseCompletedEvent

Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(
[initial_output] if initial_output else []
)
self.tracing_enabled = tracing_enabled
self.last_turn_args: dict[str, Any] = {}

def set_next_output(self, output: list[TResponseOutputItem] | Exception):
    """Queue the output (or exception to raise) for the next model turn."""
    self.turn_outputs.append(output)
Expand All @@ -53,6 +55,14 @@ async def get_response(
handoffs: list[Handoff],
tracing: ModelTracing,
) -> ModelResponse:
self.last_turn_args = {
"system_instructions": system_instructions,
"input": input,
"model_settings": model_settings,
"tools": tools,
"output_schema": output_schema,
}

with generation_span(disabled=not self.tracing_enabled) as span:
output = self.get_next_output()

Expand Down
Loading