Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/agents/guardrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ class InputGuardrail(Generic[TContext]):
function's name.
"""

block_downstream_calls: bool = True
"""Whether this guardrail should block downstream calls until it completes.
If any input guardrail has this set to True, the initial model call and any
subsequent tool execution will be delayed until all blocking guardrails finish.
Defaults to True for backwards compatibility and safety.
"""

def get_name(self) -> str:
if self.name:
return self.name
Expand Down Expand Up @@ -209,6 +216,7 @@ def input_guardrail(
def input_guardrail(
*,
name: str | None = None,
block_downstream_calls: bool = True,
) -> Callable[
[_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
InputGuardrail[TContext_co],
Expand All @@ -221,6 +229,7 @@ def input_guardrail(
| None = None,
*,
name: str | None = None,
block_downstream_calls: bool = True,
) -> (
InputGuardrail[TContext_co]
| Callable[
Expand All @@ -235,7 +244,7 @@ def input_guardrail(
@input_guardrail
def my_sync_guardrail(...): ...

@input_guardrail(name="guardrail_name")
@input_guardrail(name="guardrail_name", block_downstream_calls=False)
async def my_async_guardrail(...): ...
"""

Expand All @@ -246,6 +255,7 @@ def decorator(
guardrail_function=f,
# If not set, guardrail name uses the function’s name by default.
name=name if name else f.__name__,
block_downstream_calls=block_downstream_calls,
)

if func is not None:
Expand Down
6 changes: 5 additions & 1 deletion src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,11 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
break

try:
item = await self._event_queue.get()
# Avoid blocking forever if the background task errors before enqueuing.
item = await asyncio.wait_for(self._event_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
# No item yet; re-check for stored exceptions and completion conditions.
continue
except asyncio.CancelledError:
break

Expand Down
266 changes: 250 additions & 16 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,45 @@ async def run(
)

if current_turn == 1:
input_guardrail_results, turn_result = await asyncio.gather(
self._run_input_guardrails(
# Separate blocking and non-blocking guardrails for first turn
all_guardrails = starting_agent.input_guardrails + (
run_config.input_guardrails or []
)
blocking_guardrails, non_blocking_guardrails = (
self._separate_blocking_guardrails(all_guardrails)
)

if blocking_guardrails:
# Gate the initial model call on blocking guardrails completing.
# If a tripwire is triggered, this will raise BEFORE any LLM call.
blocking_results = await self._run_input_guardrails(
starting_agent,
starting_agent.input_guardrails
+ (run_config.input_guardrails or []),
blocking_guardrails,
_copy_str_or_list(prepared_input),
context_wrapper,
),
self._run_single_turn(
)
input_guardrail_results.extend(blocking_results)

# Run non-blocking guardrails in parallel with the model call and tools.
non_blocking_task = (
asyncio.create_task(
self._run_input_guardrails(
starting_agent,
non_blocking_guardrails,
_copy_str_or_list(prepared_input),
context_wrapper,
)
)
if non_blocking_guardrails
else None
)

# Now that blocking guardrails completed, get the model response.
(
model_response,
output_schema,
handoffs,
) = await self._get_model_response_only(
agent=current_agent,
all_tools=all_tools,
original_input=original_input,
Expand All @@ -498,8 +528,71 @@ async def run(
tool_use_tracker=tool_use_tracker,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
),
)
)

# Execute tools after receiving the model response.
turn_result = await self._execute_tools_from_model_response(
agent=current_agent,
all_tools=all_tools,
original_input=original_input,
generated_items=generated_items,
new_response=model_response,
output_schema=output_schema,
handoffs=handoffs,
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
)

# Collect non-blocking guardrail results if any.
if non_blocking_task is not None:
input_guardrail_results.extend(await non_blocking_task)
else:
# No blocking guardrails - run all guardrails in parallel
# with model/tools.
all_guardrails_task = asyncio.create_task(
self._run_input_guardrails(
starting_agent,
all_guardrails,
_copy_str_or_list(prepared_input),
context_wrapper,
)
)

(
model_response,
output_schema,
handoffs,
) = await self._get_model_response_only(
agent=current_agent,
all_tools=all_tools,
original_input=original_input,
generated_items=generated_items,
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
should_run_agent_start_hooks=should_run_agent_start_hooks,
tool_use_tracker=tool_use_tracker,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
)

turn_result = await self._execute_tools_from_model_response(
agent=current_agent,
all_tools=all_tools,
original_input=original_input,
generated_items=generated_items,
new_response=model_response,
output_schema=output_schema,
handoffs=handoffs,
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
)

input_guardrail_results.extend(await all_guardrails_task)
else:
turn_result = await self._run_single_turn(
agent=current_agent,
Expand Down Expand Up @@ -750,7 +843,10 @@ async def _run_input_guardrails_with_queue(
t.cancel()
raise

streamed_result.input_guardrail_results = guardrail_results
# Append to any existing guardrail results to avoid overwriting
streamed_result.input_guardrail_results = (
streamed_result.input_guardrail_results + guardrail_results
)

@classmethod
async def _start_streaming(
Expand Down Expand Up @@ -825,17 +921,52 @@ async def _start_streaming(
break

if current_turn == 1:
# Run the input guardrails in the background and put the results on the queue
streamed_result._input_guardrails_task = asyncio.create_task(
cls._run_input_guardrails_with_queue(
# On the first turn, separate blocking and non-blocking guardrails.
all_guardrails = starting_agent.input_guardrails + (
run_config.input_guardrails or []
)
(
blocking_guardrails,
non_blocking_guardrails,
) = cls._separate_blocking_guardrails(all_guardrails)

if blocking_guardrails:
# Gate the model streaming by running blocking guardrails to
# completion first. If a tripwire is triggered, this will raise
# BEFORE any LLM call.
blocking_results = await cls._run_input_guardrails(
starting_agent,
starting_agent.input_guardrails + (run_config.input_guardrails or []),
blocking_guardrails,
ItemHelpers.input_to_new_input_list(prepared_input),
context_wrapper,
streamed_result,
current_span,
)
)
# Push blocking results to the guardrail queue for consumers.
for r in blocking_results:
streamed_result._input_guardrail_queue.put_nowait(r)
# Start any non-blocking guardrails in the background.
if non_blocking_guardrails:
streamed_result._input_guardrails_task = asyncio.create_task(
cls._run_input_guardrails_with_queue(
starting_agent,
non_blocking_guardrails,
ItemHelpers.input_to_new_input_list(prepared_input),
context_wrapper,
streamed_result,
current_span,
)
)
else:
# No blocking guardrails - run all guardrails in the background.
streamed_result._input_guardrails_task = asyncio.create_task(
cls._run_input_guardrails_with_queue(
starting_agent,
all_guardrails,
ItemHelpers.input_to_new_input_list(prepared_input),
context_wrapper,
streamed_result,
current_span,
)
)
try:
turn_result = await cls._run_single_turn_streamed(
streamed_result,
Expand Down Expand Up @@ -1260,6 +1391,109 @@ async def _get_single_step_result_from_streamed_response(

return single_step_result

@classmethod
def _separate_blocking_guardrails(
    cls,
    guardrails: list[InputGuardrail[TContext]],
) -> tuple[list[InputGuardrail[TContext]], list[InputGuardrail[TContext]]]:
    """Partition *guardrails* by their ``block_downstream_calls`` flag.

    Returns a ``(blocking, non_blocking)`` pair; each half preserves the
    relative order of the input list.
    """
    blocking = [g for g in guardrails if g.block_downstream_calls]
    non_blocking = [g for g in guardrails if not g.block_downstream_calls]
    return blocking, non_blocking

@classmethod
async def _get_model_response_only(
    cls,
    *,
    agent: Agent[TContext],
    all_tools: list[Tool],
    original_input: str | list[TResponseInputItem],
    generated_items: list[RunItem],
    hooks: RunHooks[TContext],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
    should_run_agent_start_hooks: bool,
    tool_use_tracker: AgentToolUseTracker,
    previous_response_id: str | None,
    conversation_id: str | None,
) -> tuple[ModelResponse, AgentOutputSchemaBase | None, list[Handoff]]:
    """Call the model for a single turn and return the raw response.

    Unlike the single-turn helpers, this does NOT execute any tools or
    other side effects from the response; it only produces the model
    output together with the output schema and handoffs needed to
    interpret it.
    """
    # Agent-start lifecycle hooks must complete before any model work.
    if should_run_agent_start_hooks:
        runner_hook = hooks.on_agent_start(context_wrapper, agent)
        agent_hook = (
            agent.hooks.on_start(context_wrapper, agent)
            if agent.hooks
            else _coro.noop_coroutine()
        )
        await asyncio.gather(runner_hook, agent_hook)

    # Resolve the system prompt and prompt config concurrently.
    sys_prompt, prompt_cfg = await asyncio.gather(
        agent.get_system_prompt(context_wrapper),
        agent.get_prompt(context_wrapper),
    )

    output_schema = cls._get_output_schema(agent)
    handoffs = await cls._get_handoffs(agent, context_wrapper)

    # Model input = original input followed by the items generated so far.
    model_input = ItemHelpers.input_to_new_input_list(original_input)
    model_input.extend(item.to_input_item() for item in generated_items)

    response = await cls._get_new_response(
        agent,
        sys_prompt,
        model_input,
        output_schema,
        all_tools,
        handoffs,
        context_wrapper,
        run_config,
        tool_use_tracker,
        previous_response_id,
        conversation_id,
        prompt_cfg,
    )

    return response, output_schema, handoffs

@classmethod
async def _execute_tools_from_model_response(
    cls,
    *,
    agent: Agent[TContext],
    all_tools: list[Tool],
    original_input: str | list[TResponseInputItem],
    generated_items: list[RunItem],
    new_response: ModelResponse,
    output_schema: AgentOutputSchemaBase | None,
    handoffs: list[Handoff],
    hooks: RunHooks[TContext],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
    tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
    """Run the tools and side effects requested by *new_response*.

    Thin wrapper that forwards everything to
    ``_get_single_step_result_from_response``; ``generated_items`` is
    passed through as this turn's pre-step items.
    """
    step_kwargs = dict(
        agent=agent,
        original_input=original_input,
        pre_step_items=generated_items,
        new_response=new_response,
        output_schema=output_schema,
        all_tools=all_tools,
        handoffs=handoffs,
        hooks=hooks,
        context_wrapper=context_wrapper,
        run_config=run_config,
        tool_use_tracker=tool_use_tracker,
    )
    return await cls._get_single_step_result_from_response(**step_kwargs)

@classmethod
async def _run_input_guardrails(
cls,
Expand Down
Loading
Loading