Skip to content

Commit eb447fb

Browse files
committed
Fix guardrail task cleanup to properly await cancelled tasks
Problem: The _cleanup_guardrail_tasks() method in RealtimeSession was only calling task.cancel() on pending guardrail tasks but not awaiting them. This could lead to: 1. Unhandled task exception warnings 2. Potential memory leaks from abandoned tasks 3. Improper resource cleanup Evidence: - Test code in tests/realtime/test_session.py:1199 shows the correct pattern: await asyncio.gather(*session._guardrail_tasks, return_exceptions=True) - Similar pattern used in openai_realtime.py:519-523 for WebSocket task cleanup Solution: 1. Made _cleanup_guardrail_tasks() async 2. Added await asyncio.gather() for real asyncio.Task objects to properly collect exceptions (with isinstance check to support mock objects in tests) 3. Updated _cleanup() to await the cleanup method Testing: - Created comprehensive test suite in tests/realtime/test_guardrail_cleanup.py with 3 test cases: 1. Verify cancelled tasks are properly awaited 2. Verify exceptions during cleanup are handled 3. Verify multiple concurrent tasks are cleaned up - All new tests pass - All existing tests pass (838 passed, 3 skipped) - Note: test_issue_889_guardrail_tool_execution has 1 pre-existing failure unrelated to this PR (also fails on main)
1 parent 16169e1 commit eb447fb

File tree

2 files changed

+264
-2
lines changed

2 files changed

+264
-2
lines changed

src/agents/realtime/session.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,16 +746,32 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
746746
)
747747
)
748748

749-
def _cleanup_guardrail_tasks(self) -> None:
749+
async def _cleanup_guardrail_tasks(self) -> None:
750+
"""Cancel all pending guardrail tasks and wait for them to complete.
751+
752+
This ensures that any exceptions raised by the tasks are properly handled
753+
and prevents warnings about unhandled task exceptions.
754+
"""
755+
# Collect real asyncio.Task objects that need to be awaited
756+
real_tasks = []
757+
750758
for task in self._guardrail_tasks:
751759
if not task.done():
752760
task.cancel()
761+
# Only await real asyncio.Task objects (not mocks in tests)
762+
if isinstance(task, asyncio.Task):
763+
real_tasks.append(task)
764+
765+
# Wait for all real tasks to complete and collect any exceptions
766+
if real_tasks:
767+
await asyncio.gather(*real_tasks, return_exceptions=True)
768+
753769
self._guardrail_tasks.clear()
754770

755771
async def _cleanup(self) -> None:
756772
"""Clean up all resources and mark session as closed."""
757773
# Cancel and cleanup guardrail tasks
758-
self._cleanup_guardrail_tasks()
774+
await self._cleanup_guardrail_tasks()
759775

760776
# Remove ourselves as a listener
761777
self._model.remove_listener(self)
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
"""Test guardrail task cleanup to ensure proper exception handling.
2+
3+
This test verifies the fix for the bug where _cleanup_guardrail_tasks() was not
4+
properly awaiting cancelled tasks, which could lead to unhandled task exceptions
5+
and potential memory leaks.
6+
"""
7+
8+
import asyncio
9+
from unittest.mock import AsyncMock, Mock, PropertyMock
10+
11+
import pytest
12+
13+
from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail
14+
from agents.realtime import RealtimeSession
15+
from agents.realtime.agent import RealtimeAgent
16+
from agents.realtime.config import RealtimeRunConfig
17+
from agents.realtime.model import RealtimeModel
18+
from agents.realtime.model_events import RealtimeModelTranscriptDeltaEvent
19+
20+
21+
class MockRealtimeModel(RealtimeModel):
22+
"""Mock realtime model for testing."""
23+
24+
def __init__(self):
25+
super().__init__()
26+
self.listeners = []
27+
self.connect_called = False
28+
self.close_called = False
29+
self.sent_events = []
30+
self.sent_messages = []
31+
self.sent_audio = []
32+
self.sent_tool_outputs = []
33+
self.interrupts_called = 0
34+
35+
async def connect(self, options=None):
36+
self.connect_called = True
37+
38+
def add_listener(self, listener):
39+
self.listeners.append(listener)
40+
41+
def remove_listener(self, listener):
42+
if listener in self.listeners:
43+
self.listeners.remove(listener)
44+
45+
async def send_event(self, event):
46+
from agents.realtime.model_inputs import (
47+
RealtimeModelSendAudio,
48+
RealtimeModelSendInterrupt,
49+
RealtimeModelSendToolOutput,
50+
RealtimeModelSendUserInput,
51+
)
52+
53+
self.sent_events.append(event)
54+
55+
# Update legacy tracking for compatibility
56+
if isinstance(event, RealtimeModelSendUserInput):
57+
self.sent_messages.append(event.user_input)
58+
elif isinstance(event, RealtimeModelSendAudio):
59+
self.sent_audio.append((event.audio, event.commit))
60+
elif isinstance(event, RealtimeModelSendToolOutput):
61+
self.sent_tool_outputs.append((event.tool_call, event.output, event.start_response))
62+
elif isinstance(event, RealtimeModelSendInterrupt):
63+
self.interrupts_called += 1
64+
65+
async def close(self):
66+
self.close_called = True
67+
68+
69+
@pytest.fixture
70+
def mock_model():
71+
return MockRealtimeModel()
72+
73+
74+
@pytest.fixture
75+
def mock_agent():
76+
agent = Mock(spec=RealtimeAgent)
77+
agent.name = "test_agent"
78+
agent.get_all_tools = AsyncMock(return_value=[])
79+
type(agent).handoffs = PropertyMock(return_value=[])
80+
type(agent).output_guardrails = PropertyMock(return_value=[])
81+
return agent
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_guardrail_task_cleanup_awaits_cancelled_tasks(mock_model, mock_agent):
86+
"""Test that cleanup properly awaits cancelled guardrail tasks.
87+
88+
This test verifies that when guardrail tasks are cancelled during cleanup,
89+
the cleanup method properly awaits them to completion using asyncio.gather()
90+
with return_exceptions=True. This ensures:
91+
1. No warnings about unhandled task exceptions
92+
2. Proper resource cleanup
93+
3. No memory leaks from abandoned tasks
94+
"""
95+
96+
# Create a guardrail that runs a long async operation
97+
task_started = asyncio.Event()
98+
task_cancelled = asyncio.Event()
99+
100+
async def slow_guardrail_func(context, agent, output):
101+
"""A guardrail that takes time to execute."""
102+
task_started.set()
103+
try:
104+
# Simulate a long-running operation
105+
await asyncio.sleep(10)
106+
return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
107+
except asyncio.CancelledError:
108+
task_cancelled.set()
109+
raise
110+
111+
guardrail = OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")
112+
113+
run_config: RealtimeRunConfig = {
114+
"output_guardrails": [guardrail],
115+
"guardrails_settings": {"debounce_text_length": 5},
116+
}
117+
118+
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
119+
120+
# Trigger a guardrail by sending a transcript delta
121+
transcript_event = RealtimeModelTranscriptDeltaEvent(
122+
item_id="item_1", delta="hello world", response_id="resp_1"
123+
)
124+
125+
await session.on_event(transcript_event)
126+
127+
# Wait for the guardrail task to start
128+
await asyncio.wait_for(task_started.wait(), timeout=1.0)
129+
130+
# Verify a guardrail task was created
131+
assert len(session._guardrail_tasks) == 1
132+
task = list(session._guardrail_tasks)[0]
133+
assert not task.done()
134+
135+
# Now cleanup the session - this should cancel and await the task
136+
await session._cleanup_guardrail_tasks()
137+
138+
# Verify the task was cancelled and properly awaited
139+
assert task_cancelled.is_set(), "Task should have received CancelledError"
140+
assert len(session._guardrail_tasks) == 0, "Tasks list should be cleared"
141+
142+
# No warnings should be raised about unhandled task exceptions
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_guardrail_task_cleanup_with_exception(mock_model, mock_agent):
147+
"""Test that cleanup handles guardrail tasks that raise exceptions.
148+
149+
This test verifies that if a guardrail task raises an exception (not just
150+
CancelledError), the cleanup method still completes successfully and doesn't
151+
propagate the exception, thanks to return_exceptions=True.
152+
"""
153+
154+
task_started = asyncio.Event()
155+
exception_raised = asyncio.Event()
156+
157+
async def failing_guardrail_func(context, agent, output):
158+
"""A guardrail that raises an exception."""
159+
task_started.set()
160+
try:
161+
await asyncio.sleep(10)
162+
return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
163+
except asyncio.CancelledError as e:
164+
exception_raised.set()
165+
# Simulate an error during cleanup
166+
raise RuntimeError("Cleanup error") from e
167+
168+
guardrail = OutputGuardrail(
169+
guardrail_function=failing_guardrail_func, name="failing_guardrail"
170+
)
171+
172+
run_config: RealtimeRunConfig = {
173+
"output_guardrails": [guardrail],
174+
"guardrails_settings": {"debounce_text_length": 5},
175+
}
176+
177+
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
178+
179+
# Trigger a guardrail
180+
transcript_event = RealtimeModelTranscriptDeltaEvent(
181+
item_id="item_1", delta="hello world", response_id="resp_1"
182+
)
183+
184+
await session.on_event(transcript_event)
185+
186+
# Wait for the guardrail task to start
187+
await asyncio.wait_for(task_started.wait(), timeout=1.0)
188+
189+
# Cleanup should not raise the RuntimeError due to return_exceptions=True
190+
await session._cleanup_guardrail_tasks()
191+
192+
# Verify cleanup completed successfully
193+
assert exception_raised.is_set()
194+
assert len(session._guardrail_tasks) == 0
195+
196+
197+
@pytest.mark.asyncio
198+
async def test_guardrail_task_cleanup_with_multiple_tasks(mock_model, mock_agent):
199+
"""Test cleanup with multiple pending guardrail tasks.
200+
201+
This test verifies that cleanup properly handles multiple concurrent guardrail
202+
tasks by triggering guardrails multiple times, then cancelling and awaiting all of them.
203+
"""
204+
205+
tasks_started = asyncio.Event()
206+
tasks_cancelled = 0
207+
208+
async def slow_guardrail_func(context, agent, output):
209+
nonlocal tasks_cancelled
210+
tasks_started.set()
211+
try:
212+
await asyncio.sleep(10)
213+
return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
214+
except asyncio.CancelledError:
215+
tasks_cancelled += 1
216+
raise
217+
218+
guardrail = OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")
219+
220+
run_config: RealtimeRunConfig = {
221+
"output_guardrails": [guardrail],
222+
"guardrails_settings": {"debounce_text_length": 5},
223+
}
224+
225+
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
226+
227+
# Trigger guardrails multiple times to create multiple tasks
228+
for i in range(3):
229+
transcript_event = RealtimeModelTranscriptDeltaEvent(
230+
item_id=f"item_{i}", delta="hello world", response_id=f"resp_{i}"
231+
)
232+
await session.on_event(transcript_event)
233+
234+
# Wait for at least one task to start
235+
await asyncio.wait_for(tasks_started.wait(), timeout=1.0)
236+
237+
# Should have at least one guardrail task
238+
initial_task_count = len(session._guardrail_tasks)
239+
assert initial_task_count >= 1, "At least one guardrail task should exist"
240+
241+
# Cleanup should cancel and await all tasks
242+
await session._cleanup_guardrail_tasks()
243+
244+
# Verify all tasks were cancelled and cleared
245+
assert tasks_cancelled >= 1, "At least one task should have been cancelled"
246+
assert len(session._guardrail_tasks) == 0

0 commit comments

Comments
 (0)