diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py
index 99e287675..bdda7b9f6 100644
--- a/src/agents/guardrail.py
+++ b/src/agents/guardrail.py
@@ -97,6 +97,13 @@ class InputGuardrail(Generic[TContext]):
     function's name.
     """
 
+    block_downstream_calls: bool = True
+    """Whether this guardrail should block downstream calls until it completes.
+    If any input guardrail has this set to True, the initial model call and any
+    subsequent tool execution will be delayed until all blocking guardrails finish.
+    Defaults to True for backwards compatibility and safety.
+    """
+
     def get_name(self) -> str:
         if self.name:
             return self.name
@@ -209,6 +216,7 @@ def input_guardrail(
 def input_guardrail(
     *,
     name: str | None = None,
+    block_downstream_calls: bool = True,
 ) -> Callable[
     [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
     InputGuardrail[TContext_co],
@@ -221,6 +229,7 @@ def input_guardrail(
     | None = None,
     *,
     name: str | None = None,
+    block_downstream_calls: bool = True,
 ) -> (
     InputGuardrail[TContext_co]
     | Callable[
@@ -235,7 +244,7 @@ def input_guardrail(
         @input_guardrail
         def my_sync_guardrail(...): ...
 
-        @input_guardrail(name="guardrail_name")
+        @input_guardrail(name="guardrail_name", block_downstream_calls=False)
         async def my_async_guardrail(...): ...
     """
 
@@ -246,6 +255,7 @@ def decorator(
             guardrail_function=f,
             # If not set, guardrail name uses the function's name by default.
             name=name if name else f.__name__,
+            block_downstream_calls=block_downstream_calls,
         )
 
     if func is not None:
diff --git a/src/agents/result.py b/src/agents/result.py
index 5cf0e74c8..8d35e0d56 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -196,7 +196,11 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
                 break
 
             try:
-                item = await self._event_queue.get()
+                # Avoid blocking forever if the background task errors before enqueuing.
+                item = await asyncio.wait_for(self._event_queue.get(), timeout=0.1)
+            except asyncio.TimeoutError:
+                # No item yet; re-check for stored exceptions and completion conditions.
+                continue
             except asyncio.CancelledError:
                 break
 
diff --git a/src/agents/run.py b/src/agents/run.py
index 742917b87..6168f0d7b 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -478,15 +478,45 @@ async def run(
                     )
 
                     if current_turn == 1:
-                        input_guardrail_results, turn_result = await asyncio.gather(
-                            self._run_input_guardrails(
+                        # Separate blocking and non-blocking guardrails for first turn
+                        all_guardrails = starting_agent.input_guardrails + (
+                            run_config.input_guardrails or []
+                        )
+                        blocking_guardrails, non_blocking_guardrails = (
+                            self._separate_blocking_guardrails(all_guardrails)
+                        )
+
+                        if blocking_guardrails:
+                            # Gate the initial model call on blocking guardrails completing.
+                            # If a tripwire is triggered, this will raise BEFORE any LLM call.
+                            blocking_results = await self._run_input_guardrails(
                                 starting_agent,
-                                starting_agent.input_guardrails
-                                + (run_config.input_guardrails or []),
+                                blocking_guardrails,
                                 _copy_str_or_list(prepared_input),
                                 context_wrapper,
-                            ),
-                            self._run_single_turn(
+                            )
+                            input_guardrail_results.extend(blocking_results)
+
+                            # Run non-blocking guardrails in parallel with the model call and tools.
+                            non_blocking_task = (
+                                asyncio.create_task(
+                                    self._run_input_guardrails(
+                                        starting_agent,
+                                        non_blocking_guardrails,
+                                        _copy_str_or_list(prepared_input),
+                                        context_wrapper,
+                                    )
+                                )
+                                if non_blocking_guardrails
+                                else None
+                            )
+
+                            # Now that blocking guardrails completed, get the model response.
+                            (
+                                model_response,
+                                output_schema,
+                                handoffs,
+                            ) = await self._get_model_response_only(
                                 agent=current_agent,
                                 all_tools=all_tools,
                                 original_input=original_input,
@@ -498,8 +528,71 @@ async def run(
                                 tool_use_tracker=tool_use_tracker,
                                 previous_response_id=previous_response_id,
                                 conversation_id=conversation_id,
-                            ),
-                        )
+                            )
+
+                            # Execute tools after receiving the model response.
+                            turn_result = await self._execute_tools_from_model_response(
+                                agent=current_agent,
+                                all_tools=all_tools,
+                                original_input=original_input,
+                                generated_items=generated_items,
+                                new_response=model_response,
+                                output_schema=output_schema,
+                                handoffs=handoffs,
+                                hooks=hooks,
+                                context_wrapper=context_wrapper,
+                                run_config=run_config,
+                                tool_use_tracker=tool_use_tracker,
+                            )
+
+                            # Collect non-blocking guardrail results if any.
+                            if non_blocking_task is not None:
+                                input_guardrail_results.extend(await non_blocking_task)
+                        else:
+                            # No blocking guardrails - run all guardrails in parallel
+                            # with model/tools.
+                            all_guardrails_task = asyncio.create_task(
+                                self._run_input_guardrails(
+                                    starting_agent,
+                                    all_guardrails,
+                                    _copy_str_or_list(prepared_input),
+                                    context_wrapper,
+                                )
+                            )
+
+                            (
+                                model_response,
+                                output_schema,
+                                handoffs,
+                            ) = await self._get_model_response_only(
+                                agent=current_agent,
+                                all_tools=all_tools,
+                                original_input=original_input,
+                                generated_items=generated_items,
+                                hooks=hooks,
+                                context_wrapper=context_wrapper,
+                                run_config=run_config,
+                                should_run_agent_start_hooks=should_run_agent_start_hooks,
+                                tool_use_tracker=tool_use_tracker,
+                                previous_response_id=previous_response_id,
+                                conversation_id=conversation_id,
+                            )
+
+                            turn_result = await self._execute_tools_from_model_response(
+                                agent=current_agent,
+                                all_tools=all_tools,
+                                original_input=original_input,
+                                generated_items=generated_items,
+                                new_response=model_response,
+                                output_schema=output_schema,
+                                handoffs=handoffs,
+                                hooks=hooks,
+                                context_wrapper=context_wrapper,
+                                run_config=run_config,
+                                tool_use_tracker=tool_use_tracker,
+                            )
+
+                            input_guardrail_results.extend(await all_guardrails_task)
                     else:
                         turn_result = await self._run_single_turn(
                             agent=current_agent,
@@ -750,7 +843,10 @@ async def _run_input_guardrails_with_queue(
                 t.cancel()
             raise
 
-        streamed_result.input_guardrail_results = guardrail_results
+        # Append to any existing guardrail results to avoid overwriting
+        streamed_result.input_guardrail_results = (
+            streamed_result.input_guardrail_results + guardrail_results
+        )
 
     @classmethod
     async def _start_streaming(
@@ -825,17 +921,52 @@ async def _start_streaming(
                     break
 
                 if current_turn == 1:
-                    # Run the input guardrails in the background and put the results on the queue
-                    streamed_result._input_guardrails_task = asyncio.create_task(
-                        cls._run_input_guardrails_with_queue(
+                    # On the first turn, separate blocking and non-blocking guardrails.
+                    all_guardrails = starting_agent.input_guardrails + (
+                        run_config.input_guardrails or []
+                    )
+                    (
+                        blocking_guardrails,
+                        non_blocking_guardrails,
+                    ) = cls._separate_blocking_guardrails(all_guardrails)
+
+                    if blocking_guardrails:
+                        # Gate the model streaming by running blocking guardrails to
+                        # completion first. If a tripwire is triggered, this will raise
+                        # BEFORE any LLM call.
+                        blocking_results = await cls._run_input_guardrails(
                             starting_agent,
-                            starting_agent.input_guardrails + (run_config.input_guardrails or []),
+                            blocking_guardrails,
                             ItemHelpers.input_to_new_input_list(prepared_input),
                             context_wrapper,
-                            streamed_result,
-                            current_span,
                         )
-                    )
+                        # Push blocking results to the guardrail queue for consumers.
+                        for r in blocking_results:
+                            streamed_result._input_guardrail_queue.put_nowait(r)
+                        # Start any non-blocking guardrails in the background.
+                        if non_blocking_guardrails:
+                            streamed_result._input_guardrails_task = asyncio.create_task(
+                                cls._run_input_guardrails_with_queue(
+                                    starting_agent,
+                                    non_blocking_guardrails,
+                                    ItemHelpers.input_to_new_input_list(prepared_input),
+                                    context_wrapper,
+                                    streamed_result,
+                                    current_span,
+                                )
+                            )
+                    else:
+                        # No blocking guardrails - run all guardrails in the background.
+                        streamed_result._input_guardrails_task = asyncio.create_task(
+                            cls._run_input_guardrails_with_queue(
+                                starting_agent,
+                                all_guardrails,
+                                ItemHelpers.input_to_new_input_list(prepared_input),
+                                context_wrapper,
+                                streamed_result,
+                                current_span,
+                            )
+                        )
                 try:
                     turn_result = await cls._run_single_turn_streamed(
                         streamed_result,
@@ -1260,6 +1391,109 @@ async def _get_single_step_result_from_streamed_response(
 
         return single_step_result
 
+    @classmethod
+    def _separate_blocking_guardrails(
+        cls,
+        guardrails: list[InputGuardrail[TContext]],
+    ) -> tuple[list[InputGuardrail[TContext]], list[InputGuardrail[TContext]]]:
+        """Separate guardrails into blocking and non-blocking lists."""
+        blocking: list[InputGuardrail[TContext]] = []
+        non_blocking: list[InputGuardrail[TContext]] = []
+
+        for guardrail in guardrails:
+            if guardrail.block_downstream_calls:
+                blocking.append(guardrail)
+            else:
+                non_blocking.append(guardrail)
+
+        return blocking, non_blocking
+
+    @classmethod
+    async def _get_model_response_only(
+        cls,
+        *,
+        agent: Agent[TContext],
+        all_tools: list[Tool],
+        original_input: str | list[TResponseInputItem],
+        generated_items: list[RunItem],
+        hooks: RunHooks[TContext],
+        context_wrapper: RunContextWrapper[TContext],
+        run_config: RunConfig,
+        should_run_agent_start_hooks: bool,
+        tool_use_tracker: AgentToolUseTracker,
+        previous_response_id: str | None,
+        conversation_id: str | None,
+    ) -> tuple[ModelResponse, AgentOutputSchemaBase | None, list[Handoff]]:
+        """Get model response and metadata without executing tools."""
+        # Ensure we run the hooks before anything else
+        if should_run_agent_start_hooks:
+            await asyncio.gather(
+                hooks.on_agent_start(context_wrapper, agent),
+                (
+                    agent.hooks.on_start(context_wrapper, agent)
+                    if agent.hooks
+                    else _coro.noop_coroutine()
+                ),
+            )
+
+        system_prompt, prompt_config = await asyncio.gather(
+            agent.get_system_prompt(context_wrapper),
+            agent.get_prompt(context_wrapper),
+        )
+
+        output_schema = cls._get_output_schema(agent)
+        handoffs = await cls._get_handoffs(agent, context_wrapper)
+        input_items = ItemHelpers.input_to_new_input_list(original_input)
+        input_items.extend([gi.to_input_item() for gi in generated_items])
+
+        new_response = await cls._get_new_response(
+            agent,
+            system_prompt,
+            input_items,
+            output_schema,
+            all_tools,
+            handoffs,
+            context_wrapper,
+            run_config,
+            tool_use_tracker,
+            previous_response_id,
+            conversation_id,
+            prompt_config,
+        )
+
+        return new_response, output_schema, handoffs
+
+    @classmethod
+    async def _execute_tools_from_model_response(
+        cls,
+        *,
+        agent: Agent[TContext],
+        all_tools: list[Tool],
+        original_input: str | list[TResponseInputItem],
+        generated_items: list[RunItem],
+        new_response: ModelResponse,
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        hooks: RunHooks[TContext],
+        context_wrapper: RunContextWrapper[TContext],
+        run_config: RunConfig,
+        tool_use_tracker: AgentToolUseTracker,
+    ) -> SingleStepResult:
+        """Execute tools and side effects from a model response."""
+        return await cls._get_single_step_result_from_response(
+            agent=agent,
+            original_input=original_input,
+            pre_step_items=generated_items,
+            new_response=new_response,
+            output_schema=output_schema,
+            all_tools=all_tools,
+            handoffs=handoffs,
+            hooks=hooks,
+            context_wrapper=context_wrapper,
+            run_config=run_config,
+            tool_use_tracker=tool_use_tracker,
+        )
+
     @classmethod
     async def _run_input_guardrails(
         cls,
diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py
index c8ae5b5f2..c3e808432 100644
--- a/tests/test_agent_runner.py
+++ b/tests/test_agent_runner.py
@@ -526,11 +526,13 @@ def guardrail_function(
             tripwire_triggered=True,
         )
 
-    agent = Agent(
-        name="test", input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)]
-    )
     model = FakeModel()
     model.set_next_output([get_text_message("user_message")])
+    agent = Agent(
+        name="test",
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+        model=model,
+    )
 
     with pytest.raises(InputGuardrailTripwireTriggered):
         await Runner.run(agent, input="user_message")
@@ -780,3 +782,213 @@ async def add_tool() -> str:
 
     assert executed["called"] is True
     assert result.final_output == "done"
+
+
+@pytest.mark.asyncio
+async def test_blocking_input_guardrail_tripwire_prevents_model_call():
+    """Blocking input guardrails should prevent the initial model call if tripped."""
+    def guardrail_function(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)
+
+    model = FakeModel()
+    model.set_next_output([get_text_message("should_not_be_called")])
+    agent = Agent(
+        name="test",
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+        model=model,
+    )
+
+    with pytest.raises(InputGuardrailTripwireTriggered):
+        await Runner.run(agent, input="user_message")
+
+    # Ensure model was never invoked
+    assert model.last_turn_args == {}
+
+
+@pytest.mark.asyncio
+async def test_blocking_input_guardrail_delays_tool_execution():
+    """Test that blocking input guardrails delay tool execution until they complete."""
+    import asyncio
+
+    execution_order = []
+
+    # Create a slow blocking guardrail
+    async def slow_blocking_guardrail(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        execution_order.append("guardrail_start")
+        await asyncio.sleep(0.1)  # Simulate slow guardrail
+        execution_order.append("guardrail_end")
+        return GuardrailFunctionOutput(
+            output_info="blocking_completed",
+            tripwire_triggered=False,
+        )
+
+    # Create a tool that tracks when it's called
+    @function_tool
+    async def test_tool() -> str:
+        execution_order.append("tool_executed")
+        return "tool_result"
+
+    # Create agent with blocking guardrail and tool
+    blocking_guardrail = InputGuardrail(
+        guardrail_function=slow_blocking_guardrail, block_downstream_calls=True
+    )
+
+    model = FakeModel()
+    agent = Agent(
+        name="test",
+        input_guardrails=[blocking_guardrail],
+        tools=[test_tool],
+        model=model,
+    )
+
+    # Model output that calls the tool
+    model.add_multiple_turn_outputs(
+        [[get_function_tool_call("test_tool", "{}")], [get_text_message("completed")]]
+    )
+
+    result = await Runner.run(agent, input="test input")
+
+    # Verify execution order: guardrail must complete before tool is executed
+    assert execution_order == ["guardrail_start", "guardrail_end", "tool_executed"]
+    assert result.final_output == "completed"
+    assert len(result.input_guardrail_results) == 1
+    assert result.input_guardrail_results[0].output.output_info == "blocking_completed"
+
+
+@pytest.mark.asyncio
+async def test_non_blocking_input_guardrail_allows_parallel_tool_execution():
+    """Test that non-blocking input guardrails allow tool execution in parallel."""
+    import asyncio
+
+    execution_order = []
+    tool_started = asyncio.Event()
+
+    # Create a slow non-blocking guardrail
+    async def slow_non_blocking_guardrail(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        execution_order.append("guardrail_start")
+        # Wait for tool to start before finishing
+        await tool_started.wait()
+        execution_order.append("guardrail_end")
+        return GuardrailFunctionOutput(
+            output_info="non_blocking_completed",
+            tripwire_triggered=False,
+        )
+
+    # Create a tool that signals when it starts
+    @function_tool
+    async def test_tool() -> str:
+        execution_order.append("tool_start")
+        tool_started.set()
+        await asyncio.sleep(0.05)  # Small delay
+        execution_order.append("tool_end")
+        return "tool_result"
+
+    # Create agent with non-blocking guardrail and tool
+    non_blocking_guardrail = InputGuardrail(
+        guardrail_function=slow_non_blocking_guardrail, block_downstream_calls=False
+    )
+
+    model = FakeModel()
+    agent = Agent(
+        name="test",
+        input_guardrails=[non_blocking_guardrail],
+        tools=[test_tool],
+        model=model,
+    )
+
+    # Model output that calls the tool
+    model.add_multiple_turn_outputs(
+        [[get_function_tool_call("test_tool", "{}")], [get_text_message("completed")]]
+    )
+
+    result = await Runner.run(agent, input="test input")
+
+    # Verify execution order: tool should start before guardrail finishes
+    assert execution_order == ["guardrail_start", "tool_start", "guardrail_end", "tool_end"]
+    assert result.final_output == "completed"
+    assert len(result.input_guardrail_results) == 1
+    assert result.input_guardrail_results[0].output.output_info == "non_blocking_completed"
+
+
+@pytest.mark.asyncio
+async def test_mixed_blocking_and_non_blocking_guardrails():
+    """Test behavior when both blocking and non-blocking guardrails are present."""
+    import asyncio
+
+    execution_order = []
+
+    # Create blocking and non-blocking guardrails
+    async def blocking_guardrail(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        execution_order.append("blocking_start")
+        await asyncio.sleep(0.1)
+        execution_order.append("blocking_end")
+        return GuardrailFunctionOutput(
+            output_info="blocking_done",
+            tripwire_triggered=False,
+        )
+
+    async def non_blocking_guardrail(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        execution_order.append("non_blocking_start")
+        await asyncio.sleep(0.15)  # Takes longer than blocking
+        execution_order.append("non_blocking_end")
+        return GuardrailFunctionOutput(
+            output_info="non_blocking_done",
+            tripwire_triggered=False,
+        )
+
+    @function_tool
+    async def test_tool() -> str:
+        execution_order.append("tool_executed")
+        return "tool_result"
+
+    # Create agent with both types of guardrails
+    blocking_gr = InputGuardrail(
+        guardrail_function=blocking_guardrail, block_downstream_calls=True
+    )
+    non_blocking_gr = InputGuardrail(
+        guardrail_function=non_blocking_guardrail, block_downstream_calls=False
+    )
+
+    model = FakeModel()
+    agent = Agent(
+        name="test",
+        input_guardrails=[blocking_gr, non_blocking_gr],
+        tools=[test_tool],
+        model=model,
+    )
+
+    model.add_multiple_turn_outputs(
+        [[get_function_tool_call("test_tool", "{}")], [get_text_message("completed")]]
+    )
+
+    result = await Runner.run(agent, input="test input")
+
+    # Expected execution order:
+    # the blocking guardrail runs to completion first, since everything
+    # downstream waits on it,
+    # then the non-blocking guardrail starts alongside the model call,
+    # then the tool runs,
+    # then the non-blocking guardrail finishes.
+    expected_execution_order = [
+        "blocking_start",
+        "blocking_end",
+        "non_blocking_start",
+        "tool_executed",
+        "non_blocking_end",
+    ]
+
+    # Check that the execution order is as expected
+    assert execution_order == expected_execution_order
+
+    assert result.final_output == "completed"
+    assert len(result.input_guardrail_results) == 2
diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py
index d4afbd2e0..c794c2b44 100644
--- a/tests/test_agent_runner_streamed.py
+++ b/tests/test_agent_runner_streamed.py
@@ -521,6 +521,32 @@ def guardrail_function(
         pass
 
 
+@pytest.mark.asyncio
+async def test_blocking_input_guardrail_tripwire_prevents_model_stream_start():
+    """Blocking input guardrails should prevent starting model streaming if tripped."""
+    def guardrail_function(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)
+
+    model = FakeModel()
+    model.set_next_output([get_text_message("should_not_stream")])
+    agent = Agent(
+        name="test",
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+        model=model,
+    )
+
+    result = Runner.run_streamed(agent, input="user_message")
+
+    with pytest.raises(InputGuardrailTripwireTriggered):
+        async for _ in result.stream_events():
+            pass
+
+    # Ensure model streaming was never invoked
+    assert model.last_turn_args == {}
+
+
 @pytest.mark.asyncio
 async def test_output_guardrail_tripwire_triggered_causes_exception_streamed():
     def guardrail_function(
diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py
index c9f318c32..8a8dce812 100644
--- a/tests/test_guardrails.py
+++ b/tests/test_guardrails.py
@@ -260,3 +260,64 @@ async def test_output_guardrail_decorators():
     assert not result.output.tripwire_triggered
     assert result.output.output_info == "test_4"
     assert guardrail.get_name() == "Custom name"
+
+
+@input_guardrail(block_downstream_calls=False)
+def non_blocking_input_guardrail(
+    context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
+) -> GuardrailFunctionOutput:
+    return GuardrailFunctionOutput(
+        output_info="non_blocking",
+        tripwire_triggered=False,
+    )
+
+
+@input_guardrail(block_downstream_calls=True)
+def blocking_input_guardrail(
+    context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
+) -> GuardrailFunctionOutput:
+    return GuardrailFunctionOutput(
+        output_info="blocking",
+        tripwire_triggered=False,
+    )
+
+
+@pytest.mark.asyncio
+async def test_input_guardrail_block_downstream_calls_parameter():
+    """Test that the block_downstream_calls parameter is properly set."""
+    # Test decorator with block_downstream_calls=False
+    guardrail = non_blocking_input_guardrail
+    assert guardrail.block_downstream_calls is False
+
+    # Test decorator with block_downstream_calls=True (explicit)
+    guardrail = blocking_input_guardrail
+    assert guardrail.block_downstream_calls is True
+
+    # Test default behavior (should be True)
+    @input_guardrail
+    def default_guardrail(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
+    ) -> GuardrailFunctionOutput:
+        return GuardrailFunctionOutput(output_info="default", tripwire_triggered=False)
+
+    assert default_guardrail.block_downstream_calls is True
+
+
+@pytest.mark.asyncio
+async def test_input_guardrail_manual_creation_with_block_downstream_calls():
+    """Test creating InputGuardrail manually with block_downstream_calls parameter."""
+
+    def test_func(context, agent, input):
+        return GuardrailFunctionOutput(output_info="test", tripwire_triggered=False)
+
+    # Test explicit True
+    guardrail = InputGuardrail(guardrail_function=test_func, block_downstream_calls=True)
+    assert guardrail.block_downstream_calls is True
+
+    # Test explicit False
+    guardrail = InputGuardrail(guardrail_function=test_func, block_downstream_calls=False)
+    assert guardrail.block_downstream_calls is False
+
+    # Test default (should be True)
+    guardrail = InputGuardrail(guardrail_function=test_func)
+    assert guardrail.block_downstream_calls is True
diff --git a/tests/test_tracing_errors.py b/tests/test_tracing_errors.py
index 72bd39eda..db52250bb 100644
--- a/tests/test_tracing_errors.py
+++ b/tests/test_tracing_errors.py
@@ -506,11 +506,13 @@ def guardrail_function(
 
 @pytest.mark.asyncio
 async def test_guardrail_error():
-    agent = Agent(
-        name="test", input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)]
-    )
     model = FakeModel()
     model.set_next_output([get_text_message("some_message")])
+    agent = Agent(
+        name="test",
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+        model=model,
+    )
 
     with pytest.raises(InputGuardrailTripwireTriggered):
         await Runner.run(agent, input="user_message")
diff --git a/tests/test_tracing_errors_streamed.py b/tests/test_tracing_errors_streamed.py
index 40efef3fa..ec79dae63 100644
--- a/tests/test_tracing_errors_streamed.py
+++ b/tests/test_tracing_errors_streamed.py
@@ -543,10 +543,7 @@ async def test_input_guardrail_error():
                     "type": "agent",
                     "error": {
                         "message": "Guardrail tripwire triggered",
-                        "data": {
-                            "guardrail": "input_guardrail_function",
-                            "type": "input_guardrail",
-                        },
+                        "data": {"guardrail": "input_guardrail_function"},
                     },
                     "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
                     "children": [
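
Usage sketch (illustrative, not part of the patch): how a caller would combine a blocking and a non-blocking input guardrail with the new flag. It assumes only the public agents exports already exercised by the tests above (Agent, Runner, input_guardrail, GuardrailFunctionOutput, InputGuardrailTripwireTriggered, RunContextWrapper); the substring check inside safety_check is a stand-in for a real moderation call.

import asyncio
from typing import Any

from agents import (
    Agent,
    GuardrailFunctionOutput,
    InputGuardrailTripwireTriggered,
    RunContextWrapper,
    Runner,
    input_guardrail,
)


@input_guardrail  # block_downstream_calls defaults to True: the first model call waits for this.
async def safety_check(
    context: RunContextWrapper[Any], agent: Agent[Any], input: Any
) -> GuardrailFunctionOutput:
    flagged = "forbidden" in str(input)  # placeholder for a real moderation check
    return GuardrailFunctionOutput(output_info={"flagged": flagged}, tripwire_triggered=flagged)


@input_guardrail(block_downstream_calls=False)  # runs alongside the model call and any tools
async def usage_logger(
    context: RunContextWrapper[Any], agent: Agent[Any], input: Any
) -> GuardrailFunctionOutput:
    return GuardrailFunctionOutput(output_info={"logged": True}, tripwire_triggered=False)


async def main() -> None:
    agent = Agent(
        name="assistant",
        instructions="Answer briefly.",
        input_guardrails=[safety_check, usage_logger],
    )
    try:
        result = await Runner.run(agent, input="hello")
        print(result.final_output)
    except InputGuardrailTripwireTriggered:
        # Raised before any LLM call when a blocking guardrail trips.
        print("blocked by input guardrail")


asyncio.run(main())

With this wiring, a tripped safety_check raises before the first model call, while usage_logger runs concurrently with the model call and tool execution, matching the behavior exercised by the tests in this patch.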