Skip to content

Commit 59a8b0f

Browse files
authored
Fix streaming trace end before guardrails complete (openai#1921)
1 parent 04eec50 commit 59a8b0f

File tree

2 files changed

+239
-0
lines changed

2 files changed

+239
-0
lines changed

src/agents/run.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,15 @@ async def _start_streaming(
11381138

11391139
streamed_result.is_complete = True
11401140
finally:
1141+
if streamed_result._input_guardrails_task:
1142+
try:
1143+
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
1144+
streamed_result
1145+
)
1146+
except Exception as e:
1147+
logger.debug(
1148+
f"Error in streamed_result finalize for agent {current_agent.name} - {e}"
1149+
)
11411150
if current_span:
11421151
current_span.finish(reset_current=True)
11431152
if streamed_result.trace:
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from datetime import datetime
5+
from typing import Any
6+
7+
import pytest
8+
from openai.types.responses import ResponseCompletedEvent
9+
10+
from agents import Agent, GuardrailFunctionOutput, InputGuardrail, RunContextWrapper, Runner
11+
from agents.exceptions import InputGuardrailTripwireTriggered
12+
from agents.items import TResponseInputItem
13+
from tests.fake_model import FakeModel
14+
from tests.test_responses import get_text_message
15+
from tests.testing_processor import fetch_events, fetch_ordered_spans
16+
17+
18+
def make_input_guardrail(delay_seconds: float, *, trip: bool) -> InputGuardrail[Any]:
19+
async def guardrail(
20+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
21+
) -> GuardrailFunctionOutput:
22+
# Simulate variable guardrail completion timing.
23+
if delay_seconds > 0:
24+
await asyncio.sleep(delay_seconds)
25+
return GuardrailFunctionOutput(
26+
output_info={"delay": delay_seconds}, tripwire_triggered=trip
27+
)
28+
29+
name = "tripping_input_guardrail" if trip else "delayed_input_guardrail"
30+
return InputGuardrail(guardrail_function=guardrail, name=name)
31+
32+
33+
@pytest.mark.asyncio
34+
@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2])
35+
async def test_run_streamed_input_guardrail_timing_is_consistent(guardrail_delay: float):
36+
"""Ensure streaming behavior matches when input guardrail finishes before and after LLM stream.
37+
38+
We verify that:
39+
- The sequence of streamed event types is identical.
40+
- Final output matches.
41+
- Exactly one input guardrail result is recorded and does not trigger.
42+
"""
43+
44+
# Arrange: Agent with a single text output and a delayed input guardrail
45+
model = FakeModel()
46+
model.set_next_output([get_text_message("Final response")])
47+
48+
agent = Agent(
49+
name="TimingAgent",
50+
model=model,
51+
input_guardrails=[make_input_guardrail(guardrail_delay, trip=False)],
52+
)
53+
54+
# Act: Run streamed and collect event types
55+
result = Runner.run_streamed(agent, input="Hello")
56+
event_types: list[str] = []
57+
58+
async for event in result.stream_events():
59+
event_types.append(event.type)
60+
61+
# Assert: Guardrail results populated and identical behavioral outcome
62+
assert len(result.input_guardrail_results) == 1, "Expected exactly one input guardrail result"
63+
assert result.input_guardrail_results[0].guardrail.get_name() == "delayed_input_guardrail", (
64+
"Guardrail name mismatch"
65+
)
66+
assert result.input_guardrail_results[0].output.tripwire_triggered is False, (
67+
"Guardrail should not trigger in this test"
68+
)
69+
70+
# Final output should be the text from the model's single message
71+
assert result.final_output == "Final response"
72+
73+
# Minimal invariants on event sequence to ensure stability across timing
74+
# Must start with agent update and include raw response events
75+
assert len(event_types) >= 3, f"Unexpectedly few events: {event_types}"
76+
assert event_types[0] == "agent_updated_stream_event"
77+
# Ensure we observed raw response events in the stream irrespective of guardrail timing
78+
assert any(t == "raw_response_event" for t in event_types)
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_run_streamed_input_guardrail_sequences_match_between_fast_and_slow():
83+
"""Run twice with fast vs slow input guardrail and compare event sequences exactly."""
84+
85+
async def run_once(delay: float) -> list[str]:
86+
model = FakeModel()
87+
model.set_next_output([get_text_message("Final response")])
88+
agent = Agent(
89+
name="TimingAgent",
90+
model=model,
91+
input_guardrails=[make_input_guardrail(delay, trip=False)],
92+
)
93+
result = Runner.run_streamed(agent, input="Hello")
94+
events: list[str] = []
95+
async for ev in result.stream_events():
96+
events.append(ev.type)
97+
return events
98+
99+
events_fast = await run_once(0.0)
100+
events_slow = await run_once(0.2)
101+
102+
assert events_fast == events_slow, (
103+
f"Event sequences differ between guardrail timings:\nfast={events_fast}\nslow={events_slow}"
104+
)
105+
106+
107+
@pytest.mark.asyncio
108+
@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2])
109+
async def test_run_streamed_input_guardrail_tripwire_raises(guardrail_delay: float):
110+
"""Guardrail tripwire must raise from stream_events regardless of timing."""
111+
112+
model = FakeModel()
113+
model.set_next_output([get_text_message("Final response")])
114+
115+
agent = Agent(
116+
name="TimingAgentTrip",
117+
model=model,
118+
input_guardrails=[make_input_guardrail(guardrail_delay, trip=True)],
119+
)
120+
121+
result = Runner.run_streamed(agent, input="Hello")
122+
123+
with pytest.raises(InputGuardrailTripwireTriggered) as excinfo:
124+
async for _ in result.stream_events():
125+
pass
126+
127+
# Exception contains the guardrail result and run data
128+
exc = excinfo.value
129+
assert exc.guardrail_result.output.tripwire_triggered is True
130+
assert exc.run_data is not None
131+
assert len(exc.run_data.input_guardrail_results) == 1
132+
assert (
133+
exc.run_data.input_guardrail_results[0].guardrail.get_name() == "tripping_input_guardrail"
134+
)
135+
136+
137+
class SlowCompleteFakeModel(FakeModel):
138+
"""A FakeModel that delays just before emitting ResponseCompletedEvent in streaming."""
139+
140+
def __init__(self, delay_seconds: float, tracing_enabled: bool = True):
141+
super().__init__(tracing_enabled=tracing_enabled)
142+
self._delay_seconds = delay_seconds
143+
144+
async def stream_response(self, *args, **kwargs):
145+
async for ev in super().stream_response(*args, **kwargs):
146+
if isinstance(ev, ResponseCompletedEvent) and self._delay_seconds > 0:
147+
await asyncio.sleep(self._delay_seconds)
148+
yield ev
149+
150+
151+
def _get_span_by_type(spans, span_type: str):
152+
for s in spans:
153+
exported = s.export()
154+
if not exported:
155+
continue
156+
if exported.get("span_data", {}).get("type") == span_type:
157+
return s
158+
return None
159+
160+
161+
def _iso(s: str | None) -> datetime:
162+
assert s is not None
163+
return datetime.fromisoformat(s)
164+
165+
166+
@pytest.mark.asyncio
167+
async def test_parent_span_and_trace_finish_after_slow_input_guardrail():
168+
"""Agent span and trace finish after guardrail when guardrail completes last."""
169+
170+
model = FakeModel(tracing_enabled=True)
171+
model.set_next_output([get_text_message("Final response")])
172+
agent = Agent(
173+
name="TimingAgentTrace",
174+
model=model,
175+
input_guardrails=[make_input_guardrail(0.2, trip=False)], # guardrail slower than model
176+
)
177+
178+
result = Runner.run_streamed(agent, input="Hello")
179+
async for _ in result.stream_events():
180+
pass
181+
182+
spans = fetch_ordered_spans()
183+
agent_span = _get_span_by_type(spans, "agent")
184+
guardrail_span = _get_span_by_type(spans, "guardrail")
185+
generation_span = _get_span_by_type(spans, "generation")
186+
187+
assert agent_span and guardrail_span and generation_span, (
188+
"Expected agent, guardrail, generation spans"
189+
)
190+
191+
# Agent span must finish last
192+
assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at)
193+
assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at)
194+
195+
# Trace should end after all spans end
196+
events = fetch_events()
197+
assert events[-1] == "trace_end"
198+
199+
200+
@pytest.mark.asyncio
201+
async def test_parent_span_and_trace_finish_after_slow_model():
202+
"""Agent span and trace finish after model when model completes last."""
203+
204+
model = SlowCompleteFakeModel(delay_seconds=0.2, tracing_enabled=True)
205+
model.set_next_output([get_text_message("Final response")])
206+
agent = Agent(
207+
name="TimingAgentTrace",
208+
model=model,
209+
input_guardrails=[make_input_guardrail(0.0, trip=False)], # guardrail faster than model
210+
)
211+
212+
result = Runner.run_streamed(agent, input="Hello")
213+
async for _ in result.stream_events():
214+
pass
215+
216+
spans = fetch_ordered_spans()
217+
agent_span = _get_span_by_type(spans, "agent")
218+
guardrail_span = _get_span_by_type(spans, "guardrail")
219+
generation_span = _get_span_by_type(spans, "generation")
220+
221+
assert agent_span and guardrail_span and generation_span, (
222+
"Expected agent, guardrail, generation spans"
223+
)
224+
225+
# Agent span must finish last
226+
assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at)
227+
assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at)
228+
229+
events = fetch_events()
230+
assert events[-1] == "trace_end"

0 commit comments

Comments
 (0)