From 724f6ab43824cf254b3b2450a5420105f6537b24 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 13:28:22 +0400 Subject: [PATCH 01/23] Adding fact checking decorator --- src/agents/guardrail.py | 152 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index a96f0f7d..44f207a1 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -67,6 +67,32 @@ class OutputGuardrailResult: output: GuardrailFunctionOutput """The output of the guardrail function.""" +@dataclass +class FactCheckingGuardrailResult: + """The result of a guardrail run.""" + + guardrail: FactCheckingGuardrail[Any] + """ + The guardrail that was run. + """ + + agent_input: Any + """ + The input of the agent that was checked by the guardrail. + """ + + agent_output: Any + """ + The output of the agent that was checked by the guardrail. + """ + + agent: Agent[Any] + """ + The agent that was checked by the guardrail. + """ + + output: GuardrailFunctionOutput + """The output of the guardrail function.""" @dataclass class InputGuardrail(Generic[TContext]): @@ -179,6 +205,63 @@ async def run( ) +@dataclass +class FactCheckingGuardrail(Generic[TContext]): + """Fact checking guardrails are checks that run on the final output and the input of an agent. + They can be used to do check if the output passes certain validation criteria + + You can use the `@fact_checking_guardrail()` decorator to turn a function into an `OutputGuardrail`, + or create an `OutputGuardrail` manually. + + Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, a + `OutputGuardrailTripwireTriggered` exception will be raised. + """ + + guardrail_function: Callable[ + [RunContextWrapper[TContext], Agent[Any], Any, Any], + MaybeAwaitable[GuardrailFunctionOutput], + ] + """A function that receives the final agent, its output, and the context, and returns a + `GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally + include information about the guardrail's output. + """ + + name: str | None = None + """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail + function's name. + """ + + def get_name(self) -> str: + if self.name: + return self.name + + return self.guardrail_function.__name__ + + async def run( + self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any, agent_input: Any + ) -> FactCheckingGuardrailResult: + if not callable(self.guardrail_function): + raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + + output = self.guardrail_function(context, agent, agent_output, agent_input) + if inspect.isawaitable(output): + return FactCheckingGuardrailResult( + guardrail=self, + agent=agent, + agent_input=agent_input, + agent_output=agent_output, + output=await output, + ) + + return FactCheckingGuardrailResult( + guardrail=self, + agent=agent, + agent_input=agent_input, + agent_output=agent_output, + output=output, + ) + + TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) # For InputGuardrail @@ -318,3 +401,72 @@ def decorator( # Decorator used with keyword arguments return decorator + + +_FactCheckingGuardrailFuncSync = Callable[ + [RunContextWrapper[TContext_co], "Agent[Any]", Any, Any], + GuardrailFunctionOutput, +] +_FactCheckingGuardrailAsync = Callable[ + [RunContextWrapper[TContext_co], "Agent[Any]", Any, Any], + Awaitable[GuardrailFunctionOutput], +] + + +@overload +def fact_checking_guardrail( + func: _FactCheckingGuardrailFuncSync[TContext_co], +) -> OutputGuardrail[TContext_co]: ... + + +@overload +def fact_checking_guardrail( + func: _FactCheckingGuardrailAsync[TContext_co], +) -> OutputGuardrail[TContext_co]: ... + + +@overload +def fact_checking_guardrail( + *, + name: str | None = None, +) -> Callable[ + [_FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co]], + OutputGuardrail[TContext_co], +]: ... + + +def fact_checking_guardrail( + func: _FactCheckingGuardrailFuncSync[TContext_co] + | _FactCheckingGuardrailAsync[TContext_co] + | None = None, + *, + name: str | None = None, +) -> ( + OutputGuardrail[TContext_co] + | Callable[ + [_FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co]], + OutputGuardrail[TContext_co], + ] +): + """ + Decorator that transforms a sync or async function into an `OutputGuardrail`. + It can be used directly (no parentheses) or with keyword args, e.g.: + + @output_guardrail + def my_sync_guardrail(...): ... + + @output_guardrail(name="guardrail_name") + async def my_async_guardrail(...): ... + """ + + def decorator( + f: _FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co], + ) -> OutputGuardrail[TContext_co]: + return OutputGuardrail(guardrail_function=f, name=name) + + if func is not None: + # Decorator was used without parentheses + return decorator(func) + + # Decorator used with keyword arguments + return decorator From 3a09da000f4b3930ec6976c1e01b09a5618c5ff8 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 13:42:14 +0400 Subject: [PATCH 02/23] Adding fact checking decorator --- src/agents/guardrail.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 44f207a1..69c2f197 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -442,27 +442,27 @@ def fact_checking_guardrail( *, name: str | None = None, ) -> ( - OutputGuardrail[TContext_co] + FactCheckingGuardrail[TContext_co] | Callable[ [_FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co]], - OutputGuardrail[TContext_co], + FactCheckingGuardrail[TContext_co], ] ): """ Decorator that transforms a sync or async function into an `OutputGuardrail`. It can be used directly (no parentheses) or with keyword args, e.g.: - @output_guardrail + @fact_checking_guardrail def my_sync_guardrail(...): ... - @output_guardrail(name="guardrail_name") + @fact_checking_guardrail(name="guardrail_name") async def my_async_guardrail(...): ... """ def decorator( f: _FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co], - ) -> OutputGuardrail[TContext_co]: - return OutputGuardrail(guardrail_function=f, name=name) + ) -> FactCheckingGuardrail[TContext_co]: + return FactCheckingGuardrail(guardrail_function=f, name=name) if func is not None: # Decorator was used without parentheses From b8e43f3d6e9a72fdc1779276c6ba216d15bb4eb1 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 15:04:12 +0400 Subject: [PATCH 03/23] Updating result.py --- src/agents/_run_impl.py | 16 ++++++++++++- src/agents/exceptions.py | 14 ++++++++++- src/agents/guardrail.py | 2 +- src/agents/result.py | 16 ++++++++++--- src/agents/run.py | 52 +++++++++++++++++++++++++++++++++++++++- 5 files changed, 93 insertions(+), 7 deletions(-) diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 2849538d..9f0c8148 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -31,7 +31,7 @@ from .agent_output import AgentOutputSchema from .computer import AsyncComputer, Computer from .exceptions import AgentsException, ModelBehaviorError, UserError -from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult +from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult, FactCheckingGuardrail, FactCheckingGuardrailResult from .handoffs import Handoff, HandoffInputData from .items import ( HandoffCallItem, @@ -658,6 +658,20 @@ async def run_single_output_guardrail( span_guardrail.span_data.triggered = result.output.tripwire_triggered return result + @classmethod + async def run_single_fact_checking_guardrail( + cls, + guardrail: FactCheckingGuardrail[TContext], + agent: Agent[Any], + agent_output: Any, + context: RunContextWrapper[TContext], + agent_input: Any, + ) -> FactCheckingGuardrailResult: + with guardrail_span(guardrail.get_name()) as span_guardrail: + result = await guardrail.run(agent=agent, agent_output=agent_output, context=context, agent_input=agent_input) + span_guardrail.span_data.triggered = result.output.tripwire_triggered + return result + @classmethod def stream_step_result_to_queue( cls, diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f01..e3f82cea 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .guardrail import InputGuardrailResult, OutputGuardrailResult + from .guardrail import InputGuardrailResult, OutputGuardrailResult, FactCheckingGuardrailResult class AgentsException(Exception): @@ -61,3 +61,15 @@ def __init__(self, guardrail_result: "OutputGuardrailResult"): super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" ) + +class FactCheckingGuardrailTripwireTriggered(AgentsException): + """Exception raised when a guardrail tripwire is triggered.""" + + guardrail_result: "FactCheckingGuardrailResult" + """The result data of the guardrail that was triggered.""" + + def __init__(self, guardrail_result: "FactCheckingGuardrailResult"): + self.guardrail_result = guardrail_result + super().__init__( + f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" + ) diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 69c2f197..463b9cbc 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -449,7 +449,7 @@ def fact_checking_guardrail( ] ): """ - Decorator that transforms a sync or async function into an `OutputGuardrail`. + Decorator that transforms a sync or async function into an `FactCheckingGuardrail`. It can be used directly (no parentheses) or with keyword args, e.g.: @fact_checking_guardrail diff --git a/src/agents/result.py b/src/agents/result.py index 40a64806..9daf615b 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -8,11 +8,9 @@ from typing_extensions import TypeVar -from ._run_impl import QueueCompleteSentinel -from .agent import Agent from .agent_output import AgentOutputSchema from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded -from .guardrail import InputGuardrailResult, OutputGuardrailResult +from .guardrail import InputGuardrailResult, OutputGuardrailResult, FactCheckingGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger from .stream_events import StreamEvent @@ -50,6 +48,9 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + fact_checking_guardrail_results: list[FactCheckingGuardrailResult] + """Guardrail results for the original input and the final output of the agent.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -135,6 +136,7 @@ class RunResultStreaming(RunResultBase): _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False) _input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) + _fact_checking_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) _stored_exception: Exception | None = field(default=None, repr=False) @property @@ -211,6 +213,11 @@ def _check_errors(self): if exc and isinstance(exc, Exception): self._stored_exception = exc + if self._fact_checking_guardrails_task and self._fact_checking_guardrails_task.done(): + exc = self._fact_checking_guardrails_task.exception() + if exc and isinstance(exc, Exception): + self._stored_exception = exc + def _cleanup_tasks(self): if self._run_impl_task and not self._run_impl_task.done(): self._run_impl_task.cancel() @@ -221,5 +228,8 @@ def _cleanup_tasks(self): if self._output_guardrails_task and not self._output_guardrails_task.done(): self._output_guardrails_task.cancel() + if self._fact_checking_guardrails_task and not self._fact_checking_guardrails_task.done(): + self._fact_checking_guardrails_task.cancel() + def __str__(self) -> str: return pretty_print_run_result_streaming(self) diff --git a/src/agents/run.py b/src/agents/run.py index 934400fe..df22f92b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -25,8 +25,16 @@ MaxTurnsExceeded, ModelBehaviorError, OutputGuardrailTripwireTriggered, + FactCheckingGuardrailTripwireTriggered, +) +from .guardrail import ( + InputGuardrail, + InputGuardrailResult, + OutputGuardrail, + OutputGuardrailResult, + FactCheckingGuardrail, + FactCheckingGuardrailResult ) -from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult from .handoffs import Handoff, HandoffInputFilter, handoff from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks @@ -74,6 +82,9 @@ class RunConfig: output_guardrails: list[OutputGuardrail[Any]] | None = None """A list of output guardrails to run on the final output of the run.""" + fact_checking_guardrails: list[FactCheckingGuardrail[Any]] | None = None + """A list of fact checking guardrails to run on the original input and the final output of the run.""" + tracing_disabled: bool = False """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run. """ @@ -846,6 +857,45 @@ async def _run_output_guardrails( return guardrail_results + @classmethod + async def _run_fact_checking_guardrails( + cls, + guardrails: list[FactCheckingGuardrail[TContext]], + agent: Agent[TContext], + agent_output: Any, + agent_input: Any, + context: RunContextWrapper[TContext], + ) -> list[FactCheckingGuardrailResult]: + if not guardrails: + return [] + + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_fact_checking_guardrail(guardrail, agent, agent_output, context, agent_input) + ) + for guardrail in guardrails + ] + + guardrail_results = [] + + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + # Cancel all guardrail tasks if a tripwire is triggered. + for t in guardrail_tasks: + t.cancel() + _error_tracing.attach_error_to_current_span( + SpanError( + message="Guardrail tripwire triggered", + data={"guardrail": result.guardrail.get_name()}, + ) + ) + raise FactCheckingGuardrailTripwireTriggered(result) + else: + guardrail_results.append(result) + + return guardrail_results + @classmethod async def _get_new_response( cls, From b5c3b9132bf9b39d4a482e87cf9b3386e10edbea Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 15:18:09 +0400 Subject: [PATCH 04/23] Updatin run --- src/agents/agent.py | 7 ++++++- src/agents/run.py | 10 +++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/agents/agent.py b/src/agents/agent.py index 2723e678..a71bcedb 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -8,7 +8,7 @@ from typing_extensions import TypeAlias, TypedDict -from .guardrail import InputGuardrail, OutputGuardrail +from .guardrail import InputGuardrail, OutputGuardrail, FactCheckingGuardrail from .handoffs import Handoff from .items import ItemHelpers from .logger import logger @@ -117,6 +117,11 @@ class Agent(Generic[TContext]): Runs only if the agent produces a final output. """ + fact_checking_guardrails: list[FactCheckingGuardrail[TContext]] = field(default_factory=list) + """A list of checks that run on the original input and the final output of the agent, after generating a response. + Runs only if the agent produces a final output. + """ + output_type: type[Any] | None = None """The type of the output object. If not provided, the output will be `str`.""" diff --git a/src/agents/run.py b/src/agents/run.py index df22f92b..bd97fdcf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -259,6 +259,13 @@ async def run( turn_result.next_step.output, context_wrapper, ) + fact_checking_guardrail_results = await cls._run_fact_checking_guardrails( + current_agent.fact_checking_guardrails + (run_config.fact_checking_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + original_input + ) return RunResult( input=original_input, new_items=generated_items, @@ -267,6 +274,7 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + fact_checking_guardrail_results=fact_checking_guardrail_results, ) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -863,8 +871,8 @@ async def _run_fact_checking_guardrails( guardrails: list[FactCheckingGuardrail[TContext]], agent: Agent[TContext], agent_output: Any, - agent_input: Any, context: RunContextWrapper[TContext], + agent_input: Any, ) -> list[FactCheckingGuardrailResult]: if not guardrails: return [] From d7380b51e03b6f2f4d9643230658b865d170d773 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 15:23:11 +0400 Subject: [PATCH 05/23] Updating run --- src/agents/run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/agents/run.py b/src/agents/run.py index bd97fdcf..c2ef25cf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -424,6 +424,7 @@ def run_streamed( max_turns=max_turns, input_guardrail_results=[], output_guardrail_results=[], + fact_checking_guardrail_results=[], _current_agent_output_schema=output_schema, _trace=new_trace, ) From 44495cedd2330a3675be79cfb132296cec476837 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 15:53:21 +0400 Subject: [PATCH 06/23] Updating run --- src/agents/_run_impl.py | 9 ++++++++- src/agents/run.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 9f0c8148..897265ee 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -31,7 +31,14 @@ from .agent_output import AgentOutputSchema from .computer import AsyncComputer, Computer from .exceptions import AgentsException, ModelBehaviorError, UserError -from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult, FactCheckingGuardrail, FactCheckingGuardrailResult +from .guardrail import ( + InputGuardrail, + InputGuardrailResult, + OutputGuardrail, + OutputGuardrailResult, + FactCheckingGuardrail, + FactCheckingGuardrailResult +) from .handoffs import Handoff, HandoffInputData from .items import ( HandoffCallItem, diff --git a/src/agents/run.py b/src/agents/run.py index c2ef25cf..116a9885 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -589,13 +589,31 @@ async def _run_streamed_impl( ) ) + streamed_result._fact_checking_guardrails_task = asyncio.create_task( + cls._run_fact_checking_guardrails( + current_agent.fact_checking_guardrails + + (run_config.fact_checking_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)), + ) + ) + try: output_guardrail_results = await streamed_result._output_guardrails_task except Exception: # Exceptions will be checked in the stream_events loop output_guardrail_results = [] + try: + fact_checking_guardrails_results = await streamed_result._fact_checking_guardrails_task + except Exception: + # Exceptions will be checked in the stream_events loop + fact_checking_guardrails_results = [] + streamed_result.output_guardrail_results = output_guardrail_results + streamed_result.fact_checking_guardrails = fact_checking_guardrails_results streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) From c68d2e4fce9770a5757403f1211a0acfa9dead41 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 16:09:43 +0400 Subject: [PATCH 07/23] Updating __init__ --- src/agents/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 47bb2649..4004f862 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -15,6 +15,7 @@ ModelBehaviorError, OutputGuardrailTripwireTriggered, UserError, + FactCheckingGuardrailTripwireTriggered, ) from .guardrail import ( GuardrailFunctionOutput, @@ -22,8 +23,11 @@ InputGuardrailResult, OutputGuardrail, OutputGuardrailResult, + FactCheckingGuardrail, + FactCheckingGuardrailResult, input_guardrail, output_guardrail, + fact_checking_guardrail, ) from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff from .items import ( @@ -162,6 +166,7 @@ def enable_verbose_stdout_logging(): "AgentsException", "InputGuardrailTripwireTriggered", "OutputGuardrailTripwireTriggered", + "FactCheckingGuardrailTripwireTriggered", "MaxTurnsExceeded", "ModelBehaviorError", "UserError", @@ -169,9 +174,12 @@ def enable_verbose_stdout_logging(): "InputGuardrailResult", "OutputGuardrail", "OutputGuardrailResult", + "FactCheckingGuardrail", + "FactCheckingGuardrailResult", "GuardrailFunctionOutput", "input_guardrail", "output_guardrail", + "fact_checking_guardrail", "handoff", "Handoff", "HandoffInputData", From 2d3023ad3e7333d201665187083f997dadae7466 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 17:24:44 +0400 Subject: [PATCH 08/23] Updating unit tests --- src/agents/guardrail.py | 1 + tests/test_guardrails.py | 41 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 463b9cbc..85842fd5 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -189,6 +189,7 @@ async def run( raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") output = self.guardrail_function(context, agent, agent_output) + if inspect.isawaitable(output): return OutputGuardrailResult( guardrail=self, diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index c9f318c3..0bac8404 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -4,16 +4,17 @@ import pytest -from agents import ( +from src.agents import ( Agent, GuardrailFunctionOutput, InputGuardrail, OutputGuardrail, + FactCheckingGuardrail, RunContextWrapper, TResponseInputItem, UserError, ) -from agents.guardrail import input_guardrail, output_guardrail +from src.agents.guardrail import input_guardrail, output_guardrail, fact_checking_guardrail def get_sync_guardrail(triggers: bool, output_info: Any | None = None): @@ -260,3 +261,39 @@ async def test_output_guardrail_decorators(): assert not result.output.tripwire_triggered assert result.output.output_info == "test_4" assert guardrail.get_name() == "Custom name" + + +def get_sync_fact_checking_guardrail(triggers: bool, output_info: Any | None = None): + def sync_guardrail(context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any): + return GuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return sync_guardrail + + +@pytest.mark.asyncio +async def test_sync_fact_guardrail(): + guardrail = FactCheckingGuardrail(guardrail_function=get_sync_fact_checking_guardrail(triggers=False)) + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert not result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail(guardrail_function=get_sync_fact_checking_guardrail(triggers=True)) + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail( + guardrail_function=get_sync_fact_checking_guardrail(triggers=True, output_info="test") + ) + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert result.output.tripwire_triggered + assert result.output.output_info == "test" \ No newline at end of file From e62ebac239556ce25d8c00dd51c64328bcb5bb7d Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 17:52:35 +0400 Subject: [PATCH 09/23] Updated unit tests --- src/agents/guardrail.py | 6 +-- tests/test_guardrails.py | 89 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 85842fd5..28937bfd 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -417,13 +417,13 @@ def decorator( @overload def fact_checking_guardrail( func: _FactCheckingGuardrailFuncSync[TContext_co], -) -> OutputGuardrail[TContext_co]: ... +) -> FactCheckingGuardrail[TContext_co]: ... @overload def fact_checking_guardrail( func: _FactCheckingGuardrailAsync[TContext_co], -) -> OutputGuardrail[TContext_co]: ... +) -> FactCheckingGuardrail[TContext_co]: ... @overload @@ -432,7 +432,7 @@ def fact_checking_guardrail( name: str | None = None, ) -> Callable[ [_FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co]], - OutputGuardrail[TContext_co], + FactCheckingGuardrail[TContext_co], ]: ... diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index 0bac8404..1fdc7802 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -274,7 +274,7 @@ def sync_guardrail(context: RunContextWrapper[Any], agent: Agent[Any], agent_out @pytest.mark.asyncio -async def test_sync_fact_guardrail(): +async def test_sync_fact_checking_guardrail(): guardrail = FactCheckingGuardrail(guardrail_function=get_sync_fact_checking_guardrail(triggers=False)) result = await guardrail.run( agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) @@ -296,4 +296,89 @@ async def test_sync_fact_guardrail(): agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) ) assert result.output.tripwire_triggered - assert result.output.output_info == "test" \ No newline at end of file + assert result.output.output_info == "test" + + +def get_async_fact_checking_guardrail(triggers: bool, output_info: Any | None = None): + async def async_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any + ): + return GuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return async_guardrail + + +@pytest.mark.asyncio +async def test_async_fact_checking_guardrail(): + guardrail = FactCheckingGuardrail(guardrail_function=get_async_fact_checking_guardrail(triggers=False)) + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert not result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail(guardrail_function=get_async_fact_checking_guardrail(triggers=True)) + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail( + guardrail_function=get_async_fact_checking_guardrail(triggers=True, output_info="test") + ) + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert result.output.tripwire_triggered + assert result.output.output_info == "test" + + +@pytest.mark.asyncio +async def test_invalid_fact_checking_guardrail_raises_user_error(): + with pytest.raises(UserError): + # Purposely ignoring type error + guardrail = FactCheckingGuardrail(guardrail_function="foo") # type: ignore + await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + + +@fact_checking_guardrail +def decorated_fact_checking_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any +) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput( + output_info="test_5", + tripwire_triggered=False, + ) + + +@fact_checking_guardrail(name="Custom name") +def decorated_named_fact_checking_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any +) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput( + output_info="test_6", + tripwire_triggered=False, + ) + +@pytest.mark.asyncio +async def test_fact_checking_guardrail_decorators(): + guardrail = decorated_fact_checking_guardrail + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "test_5" + + guardrail = decorated_named_fact_checking_guardrail + result = await guardrail.run( + agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "test_6" + assert guardrail.get_name() == "Custom name" \ No newline at end of file From bf5a3d9f44a401215bc01b7991f06693686ba612 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 18:16:43 +0400 Subject: [PATCH 10/23] Updated unit tests --- tests/test_guardrails.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index 1fdc7802..df75def9 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -4,7 +4,7 @@ import pytest -from src.agents import ( +from agents import ( Agent, GuardrailFunctionOutput, InputGuardrail, @@ -14,7 +14,7 @@ TResponseInputItem, UserError, ) -from src.agents.guardrail import input_guardrail, output_guardrail, fact_checking_guardrail +from agents.guardrail import input_guardrail, output_guardrail, fact_checking_guardrail def get_sync_guardrail(triggers: bool, output_info: Any | None = None): @@ -366,6 +366,7 @@ def decorated_named_fact_checking_guardrail( tripwire_triggered=False, ) + @pytest.mark.asyncio async def test_fact_checking_guardrail_decorators(): guardrail = decorated_fact_checking_guardrail @@ -381,4 +382,4 @@ async def test_fact_checking_guardrail_decorators(): ) assert not result.output.tripwire_triggered assert result.output.output_info == "test_6" - assert guardrail.get_name() == "Custom name" \ No newline at end of file + assert guardrail.get_name() == "Custom name" From 26057cb01307bf7f75849ffdc3f672abb098f1b1 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 22:10:56 +0400 Subject: [PATCH 11/23] Adding an example --- .../agent_patterns/fact_checking_guardrail.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 examples/agent_patterns/fact_checking_guardrail.py diff --git a/examples/agent_patterns/fact_checking_guardrail.py b/examples/agent_patterns/fact_checking_guardrail.py new file mode 100644 index 00000000..29b2eb1f --- /dev/null +++ b/examples/agent_patterns/fact_checking_guardrail.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import asyncio +import json + +from pydantic import BaseModel, Field + +from src.agents import ( + Agent, + GuardrailFunctionOutput, + FactCheckingGuardrailTripwireTriggered, + RunContextWrapper, + Runner, + fact_checking_guardrail, +) + + +""" +This example shows how to use fact checking guardrails. + +Fact checking guardrails are checks that run on both the original input and the final output of an agent. +Their primary purpose is to ensure the consistency and accuracy of the agent’s response by verifying that +the output aligns with known facts or the provided input data. They can be used to: +- Validate that the agent's output correctly reflects the information given in the input. +- Ensure that any factual details in the response match expected values. +- Detect discrepancies or potential misinformation. + +In this example, we'll use a contrived scenario where we verify if the agent's response contains data that matches the input. +""" + + +class MessageOutput(BaseModel): + reasoning: str = Field(description="Thoughts on how to respond to the user's message") + response: str = Field(description="The response to the user's message") + age: int | None = Field(description="Age of the person") + + +class ProductFactCheckingOutput(BaseModel): + reasoning: str + is_fact_correct: bool + + +guardrail_agent = Agent( + name="Guardrail Check", + instructions=( + "You are given a task to determine if the hypothesis is grounded in the provided evidence. " + "Rely solely on the contents of the evidence without using external knowledge." + ), + output_type=ProductFactCheckingOutput, +) + + +@fact_checking_guardrail +async def self_check_facts(context: RunContextWrapper, agent: Agent, output: MessageOutput, evidence: str) \ + -> GuardrailFunctionOutput: + """This is a facts checking guardrail function, which happens to call an agent to check if the output + is coherent with the input. + """ + message = ( + f"Evidence: {evidence}\n" + f"Hypothesis: {output.age}" + ) + + print(f"message: {message}") + + # Run the fact-checking agent using the constructed message. + result = await Runner.run(guardrail_agent, message, context=context.context) + final_output = result.final_output_as(ProductFactCheckingOutput) + + return GuardrailFunctionOutput( + output_info=final_output, + tripwire_triggered=final_output.is_fact_correct, + ) + + +async def main(): + agent = Agent( + name="Entities Extraction Agent", + instructions=""" + Extract the age of the person. + """, + fact_checking_guardrails=[self_check_facts], + output_type=MessageOutput, + ) + + await Runner.run(agent, "My name is Alex and I'm 28 years old.") + print("First message passed") + + # This should trip the guardrail + try: + result = await Runner.run( + agent, "My name is Alex." + ) + print( + f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}" + ) + + except FactCheckingGuardrailTripwireTriggered as e: + print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}") + +if __name__ == "__main__": + asyncio.run(main()) From 0472ea6af0b98878edd99e72eb12eff835e340ca Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 23:00:59 +0400 Subject: [PATCH 12/23] Fix example --- ..._guardrail.py => fact_checking_guardrails.py} | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) rename examples/agent_patterns/{fact_checking_guardrail.py => fact_checking_guardrails.py} (89%) diff --git a/examples/agent_patterns/fact_checking_guardrail.py b/examples/agent_patterns/fact_checking_guardrails.py similarity index 89% rename from examples/agent_patterns/fact_checking_guardrail.py rename to examples/agent_patterns/fact_checking_guardrails.py index 29b2eb1f..1d8f02c0 100644 --- a/examples/agent_patterns/fact_checking_guardrail.py +++ b/examples/agent_patterns/fact_checking_guardrails.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field -from src.agents import ( +from agents import ( Agent, GuardrailFunctionOutput, FactCheckingGuardrailTripwireTriggered, @@ -35,9 +35,9 @@ class MessageOutput(BaseModel): age: int | None = Field(description="Age of the person") -class ProductFactCheckingOutput(BaseModel): +class FactCheckingOutput(BaseModel): reasoning: str - is_fact_correct: bool + is_fact_wrong: bool guardrail_agent = Agent( @@ -46,7 +46,7 @@ class ProductFactCheckingOutput(BaseModel): "You are given a task to determine if the hypothesis is grounded in the provided evidence. " "Rely solely on the contents of the evidence without using external knowledge." ), - output_type=ProductFactCheckingOutput, + output_type=FactCheckingOutput, ) @@ -57,19 +57,19 @@ async def self_check_facts(context: RunContextWrapper, agent: Agent, output: Mes is coherent with the input. """ message = ( - f"Evidence: {evidence}\n" - f"Hypothesis: {output.age}" + f"Input: {evidence}\n" + f"Age: {output.age}" ) print(f"message: {message}") # Run the fact-checking agent using the constructed message. result = await Runner.run(guardrail_agent, message, context=context.context) - final_output = result.final_output_as(ProductFactCheckingOutput) + final_output = result.final_output_as(FactCheckingOutput) return GuardrailFunctionOutput( output_info=final_output, - tripwire_triggered=final_output.is_fact_correct, + tripwire_triggered=final_output.is_fact_wrong, ) From 17f70b8a1e64294d0ea5b6403ef1515237ac7e94 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Tue, 25 Mar 2025 23:03:27 +0400 Subject: [PATCH 13/23] Fix example --- examples/agent_patterns/fact_checking_guardrails.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/agent_patterns/fact_checking_guardrails.py b/examples/agent_patterns/fact_checking_guardrails.py index 1d8f02c0..6bb9fa36 100644 --- a/examples/agent_patterns/fact_checking_guardrails.py +++ b/examples/agent_patterns/fact_checking_guardrails.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field -from agents import ( +from src.agents import ( Agent, GuardrailFunctionOutput, FactCheckingGuardrailTripwireTriggered, @@ -77,7 +77,7 @@ async def main(): agent = Agent( name="Entities Extraction Agent", instructions=""" - Extract the age of the person. + Always respond age = 28. """, fact_checking_guardrails=[self_check_facts], output_type=MessageOutput, @@ -89,7 +89,7 @@ async def main(): # This should trip the guardrail try: result = await Runner.run( - agent, "My name is Alex." + agent, "My name is Alex and I'm 38." ) print( f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}" From 9a3635d8397f7088b59bbcbd5428c82afab46150 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Wed, 26 Mar 2025 03:56:01 +0400 Subject: [PATCH 14/23] Adding documentation --- docs/guardrails.md | 132 +++++++++++++++++- .../fact_checking_guardrails.py | 6 +- 2 files changed, 133 insertions(+), 5 deletions(-) diff --git a/docs/guardrails.md b/docs/guardrails.md index 2f0be0f2..83c5df90 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -2,10 +2,11 @@ Guardrails run _in parallel_ to your agents, enabling you to do checks and validations of user input. For example, imagine you have an agent that uses a very smart (and hence slow/expensive) model to help with customer requests. You wouldn't want malicious users to ask the model to help them with their math homework. So, you can run a guardrail with a fast/cheap model. If the guardrail detects malicious usage, it can immediately raise an error, which stops the expensive model from running and saves you time/money. -There are two kinds of guardrails: +There are two three kinds of guardrails: 1. Input guardrails run on the initial user input 2. Output guardrails run on the final agent output +3. Fact checking guardrails run on the initial user input and the final user output ## Input guardrails @@ -23,7 +24,7 @@ Input guardrails run in 3 steps: Output guardrails run in 3 steps: -1. First, the guardrail receives the same input passed to the agent. +1. First, the guardrail receives the output of the last agent. 2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in an [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] 3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, an [`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception. @@ -31,6 +32,19 @@ Output guardrails run in 3 steps: Output guardrails are intended to run on the final agent output, so an agent's guardrails only run if the agent is the *last* agent. Similar to the input guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability. +## Fact Checking guardrails + +Fact Checking guardrails run in 3 steps: + +1. First, the guardrail receives the same input passed to the first agent and the output of the last agent. +2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in an [`FactCheckingGuardrailResult`][agents.guardrail.FactCheckingGuardrailResult] +3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, an [`FactCheckingGuardrailTripwireTriggered`][agents.exceptions.FactCheckingGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception. + +!!! Note + + Fact checking guardrails are intended to run on user input and the final agent output, so an agent's guardrails only run if the agent is the *last* agent. Similar to the output guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability. + + ## Tripwires If the input or output fails the guardrail, the Guardrail can signal this with a tripwire. As soon as we see a guardrail that has triggered the tripwires, we immediately raise a `{Input,Output}GuardrailTripwireTriggered` exception and halt the Agent execution. @@ -152,3 +166,117 @@ async def main(): 2. This is the guardrail's output type. 3. This is the guardrail function that receives the agent's output, and returns the result. 4. This is the actual agent that defines the workflow. + +Fact checking guardrails are similar. + +```python +import asyncio +import json + +from pydantic import BaseModel, Field + +from agents import ( + Agent, + GuardrailFunctionOutput, + FactCheckingGuardrailTripwireTriggered, + RunContextWrapper, + Runner, + fact_checking_guardrail, +) + + +""" +This example shows how to use fact checking guardrails. + +Fact checking guardrails are checks that run on both the original input and the final output of an agent. +Their primary purpose is to ensure the consistency and accuracy of the agent’s response by verifying that +the output aligns with known facts or the provided input data. They can be used to: +- Validate that the agent's output correctly reflects the information given in the input. +- Ensure that any factual details in the response match expected values. +- Detect discrepancies or potential misinformation. + +In this example, we'll use a contrived scenario where we verify if the agent's response contains data that matches the input. +""" + + +class MessageOutput(BaseModel): # (1)! + reasoning: str = Field(description="Thoughts on how to respond to the user's message") + response: str = Field(description="The response to the user's message") + age: int | None = Field(description="Age of the person") + + +class FactCheckingOutput(BaseModel): # (2)! + reasoning: str + is_age_correct: bool + + +guardrail_agent = Agent( + name="Guardrail Check", + instructions=( + "You are given a task to determine if the hypothesis is grounded in the provided evidence. " + "Rely solely on the contents of the evidence without using external knowledge." + ), + output_type=FactCheckingOutput, +) + + +@fact_checking_guardrail +async def self_check_facts( # (3)! + context: RunContextWrapper, + agent: Agent, + output: MessageOutput, + evidence: str) \ + -> GuardrailFunctionOutput: + """This is a facts checking guardrail function, which happens to call an agent to check if the output + is coherent with the input. + """ + message = ( + f"Input: {evidence}\n" + f"Age: {output.age}" + ) + + print(f"message: {message}") + + # Run the fact-checking agent using the constructed message. + result = await Runner.run(guardrail_agent, message, context=context.context) + final_output = result.final_output_as(FactCheckingOutput) + + return GuardrailFunctionOutput( + output_info=final_output, + tripwire_triggered=not final_output.is_age_correct, + ) + + +async def main(): + agent = Agent( # (3)! + name="Entities Extraction Agent", + instructions=""" + Always respond age = 28. + """, + fact_checking_guardrails=[self_check_facts], + output_type=MessageOutput, + ) + + await Runner.run(agent, "My name is Alex and I'm 28 years old.") + print("First message passed") + + # This should trip the guardrail + try: + result = await Runner.run( + agent, "My name is Alex and I'm 38." + ) + print( + f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}" + ) + + except FactCheckingGuardrailTripwireTriggered as e: + print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +1. This is the actual agent's output type. +2. This is the guardrail's output type. +3. This is the guardrail function that receives the user input and agent's output, and returns the result. +4. This is the actual agent that defines the workflow. diff --git a/examples/agent_patterns/fact_checking_guardrails.py b/examples/agent_patterns/fact_checking_guardrails.py index 6bb9fa36..24b45cea 100644 --- a/examples/agent_patterns/fact_checking_guardrails.py +++ b/examples/agent_patterns/fact_checking_guardrails.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field -from src.agents import ( +from agents import ( Agent, GuardrailFunctionOutput, FactCheckingGuardrailTripwireTriggered, @@ -37,7 +37,7 @@ class MessageOutput(BaseModel): class FactCheckingOutput(BaseModel): reasoning: str - is_fact_wrong: bool + is_age_correct: bool guardrail_agent = Agent( @@ -69,7 +69,7 @@ async def self_check_facts(context: RunContextWrapper, agent: Agent, output: Mes return GuardrailFunctionOutput( output_info=final_output, - tripwire_triggered=final_output.is_fact_wrong, + tripwire_triggered=not final_output.is_age_correct, ) From 0070cbd03230dc0dd06ba0ae5694746205755366 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Wed, 26 Mar 2025 04:00:41 +0400 Subject: [PATCH 15/23] Adding documentation --- docs/guardrails.md | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/docs/guardrails.md b/docs/guardrails.md index 83c5df90..69f8d9cd 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -2,7 +2,7 @@ Guardrails run _in parallel_ to your agents, enabling you to do checks and validations of user input. For example, imagine you have an agent that uses a very smart (and hence slow/expensive) model to help with customer requests. You wouldn't want malicious users to ask the model to help them with their math homework. So, you can run a guardrail with a fast/cheap model. If the guardrail detects malicious usage, it can immediately raise an error, which stops the expensive model from running and saves you time/money. -There are two three kinds of guardrails: +There are three kinds of guardrails: 1. Input guardrails run on the initial user input 2. Output guardrails run on the final agent output @@ -47,7 +47,7 @@ Fact Checking guardrails run in 3 steps: ## Tripwires -If the input or output fails the guardrail, the Guardrail can signal this with a tripwire. As soon as we see a guardrail that has triggered the tripwires, we immediately raise a `{Input,Output}GuardrailTripwireTriggered` exception and halt the Agent execution. +If the input or output fails the guardrail, the Guardrail can signal this with a tripwire. As soon as we see a guardrail that has triggered the tripwires, we immediately raise a `{Input,Output,FactChecking}GuardrailTripwireTriggered` exception and halt the Agent execution. ## Implementing a guardrail @@ -170,7 +170,6 @@ async def main(): Fact checking guardrails are similar. ```python -import asyncio import json from pydantic import BaseModel, Field @@ -248,7 +247,7 @@ async def self_check_facts( # (3)! async def main(): - agent = Agent( # (3)! + agent = Agent( # (4)! name="Entities Extraction Agent", instructions=""" Always respond age = 28. @@ -271,9 +270,6 @@ async def main(): except FactCheckingGuardrailTripwireTriggered as e: print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}") - -if __name__ == "__main__": - asyncio.run(main()) ``` 1. This is the actual agent's output type. From 2602769a6b62a4e89bcbc2f6bfe32a72277ee733 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Wed, 26 Mar 2025 09:20:23 +0400 Subject: [PATCH 16/23] Fix formatting --- "eval \"$(ssh-agent -s)\"" | 7 ++++ "eval \"$(ssh-agent -s)\".pub" | 1 + src/agents/__init__.py | 12 +++---- src/agents/_run_impl.py | 10 ++++-- src/agents/agent.py | 5 +-- src/agents/exceptions.py | 2 +- src/agents/guardrail.py | 11 ++++-- src/agents/result.py | 2 +- src/agents/run.py | 25 ++++++++----- tests/test_guardrails.py | 66 +++++++++++++++++++++++++--------- 10 files changed, 101 insertions(+), 40 deletions(-) create mode 100644 "eval \"$(ssh-agent -s)\"" create mode 100644 "eval \"$(ssh-agent -s)\".pub" diff --git "a/eval \"$(ssh-agent -s)\"" "b/eval \"$(ssh-agent -s)\"" new file mode 100644 index 00000000..08890ba6 --- /dev/null +++ "b/eval \"$(ssh-agent -s)\"" @@ -0,0 +1,7 @@ +-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACDgli+WytIJ+zP8X8dx6jhvXKKZugu6yS49WJWLxVbpLwAAAJiPmjxZj5o8 +WQAAAAtzc2gtZWQyNTUxOQAAACDgli+WytIJ+zP8X8dx6jhvXKKZugu6yS49WJWLxVbpLw +AAAECpWJ6QKQub+iCVzQ60Di2V6A7B8doiXSocBOtHi2ODluCWL5bK0gn7M/xfx3HqOG9c +opm6C7rJLj1YlYvFVukvAAAAEWFtcmkzNjlAZ21haWwuY29tAQIDBA== +-----END OPENSSH PRIVATE KEY----- diff --git "a/eval \"$(ssh-agent -s)\".pub" "b/eval \"$(ssh-agent -s)\".pub" new file mode 100644 index 00000000..cb370cf8 --- /dev/null +++ "b/eval \"$(ssh-agent -s)\".pub" @@ -0,0 +1 @@ +ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOCWL5bK0gn7M/xfx3HqOG9copm6C7rJLj1YlYvFVukv amri369@gmail.com diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 537da0fb..eaf689a8 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -10,24 +10,24 @@ from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( AgentsException, + FactCheckingGuardrailTripwireTriggered, InputGuardrailTripwireTriggered, MaxTurnsExceeded, ModelBehaviorError, OutputGuardrailTripwireTriggered, UserError, - FactCheckingGuardrailTripwireTriggered, ) from .guardrail import ( + fact_checking_guardrail, + FactCheckingGuardrail, + FactCheckingGuardrailResult, GuardrailFunctionOutput, + input_guardrail, InputGuardrail, InputGuardrailResult, + output_guardrail, OutputGuardrail, OutputGuardrailResult, - FactCheckingGuardrail, - FactCheckingGuardrailResult, - input_guardrail, - output_guardrail, - fact_checking_guardrail, ) from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff from .items import ( diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 93b4e356..3ec5e8c5 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -33,12 +33,12 @@ from .computer import AsyncComputer, Computer from .exceptions import AgentsException, ModelBehaviorError, UserError from .guardrail import ( + FactCheckingGuardrail, + FactCheckingGuardrailResult, InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult, - FactCheckingGuardrail, - FactCheckingGuardrailResult ) from .handoffs import Handoff, HandoffInputData from .items import ( @@ -714,7 +714,11 @@ async def run_single_fact_checking_guardrail( agent_input: Any, ) -> FactCheckingGuardrailResult: with guardrail_span(guardrail.get_name()) as span_guardrail: - result = await guardrail.run(agent=agent, agent_output=agent_output, context=context, agent_input=agent_input) + result = await guardrail.run( + agent=agent, + agent_output=agent_output, + context=context, + agent_input=agent_input) span_guardrail.span_data.triggered = result.output.tripwire_triggered return result diff --git a/src/agents/agent.py b/src/agents/agent.py index dbf67929..824b1275 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -8,7 +8,7 @@ from typing_extensions import TypeAlias, TypedDict -from .guardrail import InputGuardrail, OutputGuardrail, FactCheckingGuardrail +from .guardrail import FactCheckingGuardrail, InputGuardrail, OutputGuardrail from .handoffs import Handoff from .items import ItemHelpers from .logger import logger @@ -130,7 +130,8 @@ class Agent(Generic[TContext]): """ fact_checking_guardrails: list[FactCheckingGuardrail[TContext]] = field(default_factory=list) - """A list of checks that run on the original input and the final output of the agent, after generating a response. + """A list of checks that run on the original input + and the final output of the agent, after generating a response. Runs only if the agent produces a final output. """ diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index e3f82cea..055b1fb8 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .guardrail import InputGuardrailResult, OutputGuardrailResult, FactCheckingGuardrailResult + from .guardrail import FactCheckingGuardrailResult, InputGuardrailResult, OutputGuardrailResult class AgentsException(Exception): diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 28937bfd..cc4fd4d7 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -211,11 +211,12 @@ class FactCheckingGuardrail(Generic[TContext]): """Fact checking guardrails are checks that run on the final output and the input of an agent. They can be used to do check if the output passes certain validation criteria - You can use the `@fact_checking_guardrail()` decorator to turn a function into an `OutputGuardrail`, + You can use the `@fact_checking_guardrail()` + decorator to turn a function into an `FactCheckingGuardrail`, or create an `OutputGuardrail` manually. Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, a - `OutputGuardrailTripwireTriggered` exception will be raised. + `FactCheckingGuardrailTripwireTriggered` exception will be raised. """ guardrail_function: Callable[ @@ -239,7 +240,11 @@ def get_name(self) -> str: return self.guardrail_function.__name__ async def run( - self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any, agent_input: Any + self, + context: RunContextWrapper[TContext], + agent: Agent[Any], + agent_output: Any, + agent_input: Any, ) -> FactCheckingGuardrailResult: if not callable(self.guardrail_function): raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") diff --git a/src/agents/result.py b/src/agents/result.py index 9daf615b..c1a8a550 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -10,7 +10,7 @@ from .agent_output import AgentOutputSchema from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded -from .guardrail import InputGuardrailResult, OutputGuardrailResult, FactCheckingGuardrailResult +from .guardrail import FactCheckingGuardrailResult, InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger from .stream_events import StreamEvent diff --git a/src/agents/run.py b/src/agents/run.py index a7fca980..5bb1b120 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -22,19 +22,19 @@ from .agent_output import AgentOutputSchema from .exceptions import ( AgentsException, + FactCheckingGuardrailTripwireTriggered, InputGuardrailTripwireTriggered, MaxTurnsExceeded, ModelBehaviorError, OutputGuardrailTripwireTriggered, - FactCheckingGuardrailTripwireTriggered, ) from .guardrail import ( + FactCheckingGuardrail, + FactCheckingGuardrailResult, InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult, - FactCheckingGuardrail, - FactCheckingGuardrailResult ) from .handoffs import Handoff, HandoffInputFilter, handoff from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem @@ -85,7 +85,8 @@ class RunConfig: """A list of output guardrails to run on the final output of the run.""" fact_checking_guardrails: list[FactCheckingGuardrail[Any]] | None = None - """A list of fact checking guardrails to run on the original input and the final output of the run.""" + """A list of fact checking guardrails to run on the original + input and the final output of the run.""" tracing_disabled: bool = False """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run. @@ -269,7 +270,8 @@ async def run( context_wrapper, ) fact_checking_guardrail_results = await cls._run_fact_checking_guardrails( - current_agent.fact_checking_guardrails + (run_config.fact_checking_guardrails or []), + current_agent.fact_checking_guardrails + + (run_config.fact_checking_guardrails or []), current_agent, turn_result.next_step.output, context_wrapper, @@ -614,13 +616,15 @@ async def _run_streamed_impl( ) try: - output_guardrail_results = await streamed_result._output_guardrails_task + output_guardrail_results = \ + await streamed_result._output_guardrails_task except Exception: # Exceptions will be checked in the stream_events loop output_guardrail_results = [] try: - fact_checking_guardrails_results = await streamed_result._fact_checking_guardrails_task + fact_checking_guardrails_results = \ + await streamed_result._fact_checking_guardrails_task except Exception: # Exceptions will be checked in the stream_events loop fact_checking_guardrails_results = [] @@ -928,7 +932,12 @@ async def _run_fact_checking_guardrails( guardrail_tasks = [ asyncio.create_task( - RunImpl.run_single_fact_checking_guardrail(guardrail, agent, agent_output, context, agent_input) + RunImpl.run_single_fact_checking_guardrail( + guardrail, + agent, + agent_output, + context, + agent_input) ) for guardrail in guardrails ] diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index df75def9..4de0077b 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -6,15 +6,15 @@ from agents import ( Agent, + FactCheckingGuardrail, GuardrailFunctionOutput, InputGuardrail, OutputGuardrail, - FactCheckingGuardrail, RunContextWrapper, TResponseInputItem, UserError, ) -from agents.guardrail import input_guardrail, output_guardrail, fact_checking_guardrail +from agents.guardrail import fact_checking_guardrail, input_guardrail, output_guardrail def get_sync_guardrail(triggers: bool, output_info: Any | None = None): @@ -264,7 +264,10 @@ async def test_output_guardrail_decorators(): def get_sync_fact_checking_guardrail(triggers: bool, output_info: Any | None = None): - def sync_guardrail(context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any): + def sync_guardrail(context: RunContextWrapper[Any], + agent: Agent[Any], + agent_output: Any, + agent_input: Any): return GuardrailFunctionOutput( output_info=output_info, tripwire_triggered=triggers, @@ -275,16 +278,24 @@ def sync_guardrail(context: RunContextWrapper[Any], agent: Agent[Any], agent_out @pytest.mark.asyncio async def test_sync_fact_checking_guardrail(): - guardrail = FactCheckingGuardrail(guardrail_function=get_sync_fact_checking_guardrail(triggers=False)) + guardrail = FactCheckingGuardrail( + guardrail_function=get_sync_fact_checking_guardrail(triggers=False)) result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert not result.output.tripwire_triggered assert result.output.output_info is None - guardrail = FactCheckingGuardrail(guardrail_function=get_sync_fact_checking_guardrail(triggers=True)) + guardrail = FactCheckingGuardrail( + guardrail_function=get_sync_fact_checking_guardrail(triggers=True)) result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert result.output.tripwire_triggered assert result.output.output_info is None @@ -293,7 +304,10 @@ async def test_sync_fact_checking_guardrail(): guardrail_function=get_sync_fact_checking_guardrail(triggers=True, output_info="test") ) result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert result.output.tripwire_triggered assert result.output.output_info == "test" @@ -313,16 +327,24 @@ async def async_guardrail( @pytest.mark.asyncio async def test_async_fact_checking_guardrail(): - guardrail = FactCheckingGuardrail(guardrail_function=get_async_fact_checking_guardrail(triggers=False)) + guardrail = FactCheckingGuardrail( + guardrail_function=get_async_fact_checking_guardrail(triggers=False)) result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert not result.output.tripwire_triggered assert result.output.output_info is None - guardrail = FactCheckingGuardrail(guardrail_function=get_async_fact_checking_guardrail(triggers=True)) + guardrail = FactCheckingGuardrail( + guardrail_function=get_async_fact_checking_guardrail(triggers=True)) result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert result.output.tripwire_triggered assert result.output.output_info is None @@ -331,7 +353,10 @@ async def test_async_fact_checking_guardrail(): guardrail_function=get_async_fact_checking_guardrail(triggers=True, output_info="test") ) result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert result.output.tripwire_triggered assert result.output.output_info == "test" @@ -343,7 +368,10 @@ async def test_invalid_fact_checking_guardrail_raises_user_error(): # Purposely ignoring type error guardrail = FactCheckingGuardrail(guardrail_function="foo") # type: ignore await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) @@ -371,14 +399,20 @@ def decorated_named_fact_checking_guardrail( async def test_fact_checking_guardrail_decorators(): guardrail = decorated_fact_checking_guardrail result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert not result.output.tripwire_triggered assert result.output.output_info == "test_5" guardrail = decorated_named_fact_checking_guardrail result = await guardrail.run( - agent=Agent(name="test"), agent_input="test", agent_output="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None) ) assert not result.output.tripwire_triggered assert result.output.output_info == "test_6" From 5a1fa3cd14902089edaff3c2b8446d24ec90472e Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Wed, 26 Mar 2025 09:45:53 +0400 Subject: [PATCH 17/23] Fix formatting --- .../fact_checking_guardrails.py | 22 +++++------ examples/mcp/git_example/main.py | 9 +---- src/agents/__init__.py | 6 +-- src/agents/_run_impl.py | 18 ++++----- src/agents/exceptions.py | 1 + src/agents/guardrail.py | 12 +++--- src/agents/run.py | 31 +++++++--------- tests/test_guardrails.py | 37 ++++++++++--------- 8 files changed, 63 insertions(+), 73 deletions(-) diff --git a/examples/agent_patterns/fact_checking_guardrails.py b/examples/agent_patterns/fact_checking_guardrails.py index 24b45cea..084a71ad 100644 --- a/examples/agent_patterns/fact_checking_guardrails.py +++ b/examples/agent_patterns/fact_checking_guardrails.py @@ -7,14 +7,13 @@ from agents import ( Agent, - GuardrailFunctionOutput, FactCheckingGuardrailTripwireTriggered, + GuardrailFunctionOutput, RunContextWrapper, Runner, fact_checking_guardrail, ) - """ This example shows how to use fact checking guardrails. @@ -51,15 +50,13 @@ class FactCheckingOutput(BaseModel): @fact_checking_guardrail -async def self_check_facts(context: RunContextWrapper, agent: Agent, output: MessageOutput, evidence: str) \ - -> GuardrailFunctionOutput: +async def self_check_facts( + context: RunContextWrapper, agent: Agent, output: MessageOutput, evidence: str +) -> GuardrailFunctionOutput: """This is a facts checking guardrail function, which happens to call an agent to check if the output - is coherent with the input. - """ - message = ( - f"Input: {evidence}\n" - f"Age: {output.age}" - ) + is coherent with the input. + """ + message = f"Input: {evidence}\nAge: {output.age}" print(f"message: {message}") @@ -88,9 +85,7 @@ async def main(): # This should trip the guardrail try: - result = await Runner.run( - agent, "My name is Alex and I'm 38." - ) + result = await Runner.run(agent, "My name is Alex and I'm 38.") print( f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}" ) @@ -98,5 +93,6 @@ async def main(): except FactCheckingGuardrailTripwireTriggered as e: print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}") + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/mcp/git_example/main.py b/examples/mcp/git_example/main.py index cfc15108..7d882481 100644 --- a/examples/mcp/git_example/main.py +++ b/examples/mcp/git_example/main.py @@ -29,14 +29,7 @@ async def main(): # Ask the user for the directory path directory_path = input("Please enter the path to the git repository: ") - async with MCPServerStdio( - params={ - "command": "uvx", - "args": [ - "mcp-server-git" - ] - } - ) as server: + async with MCPServerStdio(params={"command": "uvx", "args": ["mcp-server-git"]}) as server: with trace(workflow_name="MCP Git Example"): await run(server, directory_path) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index eaf689a8..5771c7cb 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -18,16 +18,16 @@ UserError, ) from .guardrail import ( - fact_checking_guardrail, FactCheckingGuardrail, FactCheckingGuardrailResult, GuardrailFunctionOutput, - input_guardrail, InputGuardrail, InputGuardrailResult, - output_guardrail, OutputGuardrail, OutputGuardrailResult, + fact_checking_guardrail, + input_guardrail, + output_guardrail, ) from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff from .items import ( diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 3ec5e8c5..84dffb1e 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -706,19 +706,17 @@ async def run_single_output_guardrail( @classmethod async def run_single_fact_checking_guardrail( - cls, - guardrail: FactCheckingGuardrail[TContext], - agent: Agent[Any], - agent_output: Any, - context: RunContextWrapper[TContext], - agent_input: Any, + cls, + guardrail: FactCheckingGuardrail[TContext], + agent: Agent[Any], + agent_output: Any, + context: RunContextWrapper[TContext], + agent_input: Any, ) -> FactCheckingGuardrailResult: with guardrail_span(guardrail.get_name()) as span_guardrail: result = await guardrail.run( - agent=agent, - agent_output=agent_output, - context=context, - agent_input=agent_input) + agent=agent, agent_output=agent_output, context=context, agent_input=agent_input + ) span_guardrail.span_data.triggered = result.output.tripwire_triggered return result diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 055b1fb8..e74b02ad 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -62,6 +62,7 @@ def __init__(self, guardrail_result: "OutputGuardrailResult"): f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" ) + class FactCheckingGuardrailTripwireTriggered(AgentsException): """Exception raised when a guardrail tripwire is triggered.""" diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index cc4fd4d7..ea17c3d9 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -67,6 +67,7 @@ class OutputGuardrailResult: output: GuardrailFunctionOutput """The output of the guardrail function.""" + @dataclass class FactCheckingGuardrailResult: """The result of a guardrail run.""" @@ -94,6 +95,7 @@ class FactCheckingGuardrailResult: output: GuardrailFunctionOutput """The output of the guardrail function.""" + @dataclass class InputGuardrail(Generic[TContext]): """Input guardrails are checks that run in parallel to the agent's execution. @@ -240,11 +242,11 @@ def get_name(self) -> str: return self.guardrail_function.__name__ async def run( - self, - context: RunContextWrapper[TContext], - agent: Agent[Any], - agent_output: Any, - agent_input: Any, + self, + context: RunContextWrapper[TContext], + agent: Agent[Any], + agent_output: Any, + agent_input: Any, ) -> FactCheckingGuardrailResult: if not callable(self.guardrail_function): raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") diff --git a/src/agents/run.py b/src/agents/run.py index 5bb1b120..ee4028df 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -270,12 +270,12 @@ async def run( context_wrapper, ) fact_checking_guardrail_results = await cls._run_fact_checking_guardrails( - current_agent.fact_checking_guardrails + - (run_config.fact_checking_guardrails or []), + current_agent.fact_checking_guardrails + + (run_config.fact_checking_guardrails or []), current_agent, turn_result.next_step.output, context_wrapper, - original_input + original_input, ) return RunResult( input=original_input, @@ -616,15 +616,15 @@ async def _run_streamed_impl( ) try: - output_guardrail_results = \ - await streamed_result._output_guardrails_task + output_guardrail_results = await streamed_result._output_guardrails_task except Exception: # Exceptions will be checked in the stream_events loop output_guardrail_results = [] try: - fact_checking_guardrails_results = \ + fact_checking_guardrails_results = ( await streamed_result._fact_checking_guardrails_task + ) except Exception: # Exceptions will be checked in the stream_events loop fact_checking_guardrails_results = [] @@ -920,12 +920,12 @@ async def _run_output_guardrails( @classmethod async def _run_fact_checking_guardrails( - cls, - guardrails: list[FactCheckingGuardrail[TContext]], - agent: Agent[TContext], - agent_output: Any, - context: RunContextWrapper[TContext], - agent_input: Any, + cls, + guardrails: list[FactCheckingGuardrail[TContext]], + agent: Agent[TContext], + agent_output: Any, + context: RunContextWrapper[TContext], + agent_input: Any, ) -> list[FactCheckingGuardrailResult]: if not guardrails: return [] @@ -933,11 +933,8 @@ async def _run_fact_checking_guardrails( guardrail_tasks = [ asyncio.create_task( RunImpl.run_single_fact_checking_guardrail( - guardrail, - agent, - agent_output, - context, - agent_input) + guardrail, agent, agent_output, context, agent_input + ) ) for guardrail in guardrails ] diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index 4de0077b..08510279 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -264,10 +264,9 @@ async def test_output_guardrail_decorators(): def get_sync_fact_checking_guardrail(triggers: bool, output_info: Any | None = None): - def sync_guardrail(context: RunContextWrapper[Any], - agent: Agent[Any], - agent_output: Any, - agent_input: Any): + def sync_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any + ): return GuardrailFunctionOutput( output_info=output_info, tripwire_triggered=triggers, @@ -279,23 +278,25 @@ def sync_guardrail(context: RunContextWrapper[Any], @pytest.mark.asyncio async def test_sync_fact_checking_guardrail(): guardrail = FactCheckingGuardrail( - guardrail_function=get_sync_fact_checking_guardrail(triggers=False)) + guardrail_function=get_sync_fact_checking_guardrail(triggers=False) + ) result = await guardrail.run( agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert not result.output.tripwire_triggered assert result.output.output_info is None guardrail = FactCheckingGuardrail( - guardrail_function=get_sync_fact_checking_guardrail(triggers=True)) + guardrail_function=get_sync_fact_checking_guardrail(triggers=True) + ) result = await guardrail.run( agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert result.output.tripwire_triggered assert result.output.output_info is None @@ -307,7 +308,7 @@ async def test_sync_fact_checking_guardrail(): agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert result.output.tripwire_triggered assert result.output.output_info == "test" @@ -328,23 +329,25 @@ async def async_guardrail( @pytest.mark.asyncio async def test_async_fact_checking_guardrail(): guardrail = FactCheckingGuardrail( - guardrail_function=get_async_fact_checking_guardrail(triggers=False)) + guardrail_function=get_async_fact_checking_guardrail(triggers=False) + ) result = await guardrail.run( agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert not result.output.tripwire_triggered assert result.output.output_info is None guardrail = FactCheckingGuardrail( - guardrail_function=get_async_fact_checking_guardrail(triggers=True)) + guardrail_function=get_async_fact_checking_guardrail(triggers=True) + ) result = await guardrail.run( agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert result.output.tripwire_triggered assert result.output.output_info is None @@ -356,7 +359,7 @@ async def test_async_fact_checking_guardrail(): agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert result.output.tripwire_triggered assert result.output.output_info == "test" @@ -371,7 +374,7 @@ async def test_invalid_fact_checking_guardrail_raises_user_error(): agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) @@ -402,7 +405,7 @@ async def test_fact_checking_guardrail_decorators(): agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert not result.output.tripwire_triggered assert result.output.output_info == "test_5" @@ -412,7 +415,7 @@ async def test_fact_checking_guardrail_decorators(): agent=Agent(name="test"), agent_input="test", agent_output="test", - context=RunContextWrapper(context=None) + context=RunContextWrapper(context=None), ) assert not result.output.tripwire_triggered assert result.output.output_info == "test_6" From ae22d0e52a062d99fbf5b6e8ef8dc4097d8726d0 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Wed, 26 Mar 2025 10:49:39 +0400 Subject: [PATCH 18/23] Refactoring --- "eval \"$(ssh-agent -s)\"" | 7 ------- "eval \"$(ssh-agent -s)\".pub" | 1 - 2 files changed, 8 deletions(-) delete mode 100644 "eval \"$(ssh-agent -s)\"" delete mode 100644 "eval \"$(ssh-agent -s)\".pub" diff --git "a/eval \"$(ssh-agent -s)\"" "b/eval \"$(ssh-agent -s)\"" deleted file mode 100644 index 08890ba6..00000000 --- "a/eval \"$(ssh-agent -s)\"" +++ /dev/null @@ -1,7 +0,0 @@ ------BEGIN OPENSSH PRIVATE KEY----- -b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW -QyNTUxOQAAACDgli+WytIJ+zP8X8dx6jhvXKKZugu6yS49WJWLxVbpLwAAAJiPmjxZj5o8 -WQAAAAtzc2gtZWQyNTUxOQAAACDgli+WytIJ+zP8X8dx6jhvXKKZugu6yS49WJWLxVbpLw -AAAECpWJ6QKQub+iCVzQ60Di2V6A7B8doiXSocBOtHi2ODluCWL5bK0gn7M/xfx3HqOG9c -opm6C7rJLj1YlYvFVukvAAAAEWFtcmkzNjlAZ21haWwuY29tAQIDBA== ------END OPENSSH PRIVATE KEY----- diff --git "a/eval \"$(ssh-agent -s)\".pub" "b/eval \"$(ssh-agent -s)\".pub" deleted file mode 100644 index cb370cf8..00000000 --- "a/eval \"$(ssh-agent -s)\".pub" +++ /dev/null @@ -1 +0,0 @@ -ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOCWL5bK0gn7M/xfx3HqOG9copm6C7rJLj1YlYvFVukv amri369@gmail.com From 9e3731678f9ec9bf22f41253ab08c7e36dee4fa3 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Thu, 27 Mar 2025 21:06:24 +0400 Subject: [PATCH 19/23] Update Readme --- examples/agent_patterns/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/agent_patterns/README.md b/examples/agent_patterns/README.md index 96b48920..b39d5f6f 100644 --- a/examples/agent_patterns/README.md +++ b/examples/agent_patterns/README.md @@ -51,4 +51,4 @@ You can definitely do this without any special Agents SDK features by using para This is really useful for latency: for example, you might have a very fast model that runs the guardrail and a slow model that runs the actual agent. You wouldn't want to wait for the slow model to finish, so guardrails let you quickly reject invalid inputs. -See the [`input_guardrails.py`](./input_guardrails.py) and [`output_guardrails.py`](./output_guardrails.py) files for examples. +See the [`input_guardrails.py`](./input_guardrails.py), [`output_guardrails.py`](./output_guardrails.py) and [`fact_checking_guardrails.py`](./fact_checking_guardrails.py) files for examples. From 6e4b656638fbc917137dac6828e28931b430fa93 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Thu, 27 Mar 2025 21:26:54 +0400 Subject: [PATCH 20/23] Fixing typecheck tests --- src/agents/run.py | 2 +- tests/test_result_cast.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agents/run.py b/src/agents/run.py index ee4028df..2b713094 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -630,7 +630,7 @@ async def _run_streamed_impl( fact_checking_guardrails_results = [] streamed_result.output_guardrail_results = output_guardrail_results - streamed_result.fact_checking_guardrails = fact_checking_guardrails_results + streamed_result.fact_checking_guardrail_results = fact_checking_guardrails_results streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index ec17e327..33155fe9 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -14,6 +14,7 @@ def create_run_result(final_output: Any) -> RunResult: final_output=final_output, input_guardrail_results=[], output_guardrail_results=[], + fact_checking_guardrail_results=[], _last_agent=Agent(name="test"), ) From df732ea95ec38117eca98a0534f0c89fa8bdb784 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Thu, 27 Mar 2025 21:28:56 +0400 Subject: [PATCH 21/23] Fix lint format --- src/agents/run.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agents/run.py b/src/agents/run.py index 2b713094..fb91f6a9 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -630,7 +630,8 @@ async def _run_streamed_impl( fact_checking_guardrails_results = [] streamed_result.output_guardrail_results = output_guardrail_results - streamed_result.fact_checking_guardrail_results = fact_checking_guardrails_results + streamed_result.fact_checking_guardrail_results = ( + fact_checking_guardrails_results) streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) From 1ea33f0f747b3bd81cb8752e98d552c797325010 Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Thu, 27 Mar 2025 22:05:05 +0400 Subject: [PATCH 22/23] Fix unit tests --- src/agents/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/result.py b/src/agents/result.py index c1a8a550..1b3849a1 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -16,9 +16,9 @@ from .stream_events import StreamEvent from .tracing import Trace from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming +from ._run_impl import QueueCompleteSentinel if TYPE_CHECKING: - from ._run_impl import QueueCompleteSentinel from .agent import Agent T = TypeVar("T") From d7639964244450e4ca8902a32ab7dc3888ba4e9a Mon Sep 17 00:00:00 2001 From: Mohamed Amri Date: Thu, 27 Mar 2025 22:13:19 +0400 Subject: [PATCH 23/23] Fix lint --- src/agents/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/result.py b/src/agents/result.py index 1b3849a1..71f467d4 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -8,6 +8,7 @@ from typing_extensions import TypeVar +from ._run_impl import QueueCompleteSentinel from .agent_output import AgentOutputSchema from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded from .guardrail import FactCheckingGuardrailResult, InputGuardrailResult, OutputGuardrailResult @@ -16,7 +17,6 @@ from .stream_events import StreamEvent from .tracing import Trace from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming -from ._run_impl import QueueCompleteSentinel if TYPE_CHECKING: from .agent import Agent