|  | 
|  | 1 | +"""Test guardrail task cleanup to ensure proper exception handling. | 
|  | 2 | +
 | 
|  | 3 | +This test verifies the fix for the bug where _cleanup_guardrail_tasks() was not | 
|  | 4 | +properly awaiting cancelled tasks, which could lead to unhandled task exceptions | 
|  | 5 | +and potential memory leaks. | 
|  | 6 | +""" | 
|  | 7 | + | 
|  | 8 | +import asyncio | 
|  | 9 | +from unittest.mock import AsyncMock, Mock, PropertyMock | 
|  | 10 | + | 
|  | 11 | +import pytest | 
|  | 12 | + | 
|  | 13 | +from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail | 
|  | 14 | +from agents.realtime import RealtimeSession | 
|  | 15 | +from agents.realtime.agent import RealtimeAgent | 
|  | 16 | +from agents.realtime.config import RealtimeRunConfig | 
|  | 17 | +from agents.realtime.model import RealtimeModel | 
|  | 18 | +from agents.realtime.model_events import RealtimeModelTranscriptDeltaEvent | 
|  | 19 | + | 
|  | 20 | + | 
class MockRealtimeModel(RealtimeModel):
    """Recording stand-in for a realtime model.

    Nothing is sent anywhere: each call only updates attributes on the instance
    so a test can inspect afterwards what the session did with the model.
    """

    def __init__(self):
        super().__init__()
        # Lifecycle flags, flipped by connect() and close().
        self.connect_called = False
        self.close_called = False
        # Listeners in registration order.
        self.listeners = []
        # Every event handed to send_event(), followed by per-type views of it.
        self.sent_events = []
        self.sent_messages = []
        self.sent_audio = []
        self.sent_tool_outputs = []
        self.interrupts_called = 0

    async def connect(self, options=None):
        """Mark the model as connected; ``options`` is accepted and ignored."""
        self.connect_called = True

    def add_listener(self, listener):
        """Append ``listener`` to the registered listeners."""
        self.listeners.append(listener)

    def remove_listener(self, listener):
        """Drop ``listener`` if it is registered; otherwise do nothing."""
        if listener not in self.listeners:
            return
        self.listeners.remove(listener)

    async def send_event(self, event):
        """Record ``event`` and mirror it into the matching per-type tracker."""
        from agents.realtime.model_inputs import (
            RealtimeModelSendAudio,
            RealtimeModelSendInterrupt,
            RealtimeModelSendToolOutput,
            RealtimeModelSendUserInput,
        )

        self.sent_events.append(event)

        # The per-type lists below are kept up to date for compatibility with
        # older tracking; sent_events above always holds the full record.
        if isinstance(event, RealtimeModelSendUserInput):
            self.sent_messages.append(event.user_input)
            return
        if isinstance(event, RealtimeModelSendAudio):
            audio_record = (event.audio, event.commit)
            self.sent_audio.append(audio_record)
            return
        if isinstance(event, RealtimeModelSendToolOutput):
            tool_record = (event.tool_call, event.output, event.start_response)
            self.sent_tool_outputs.append(tool_record)
            return
        if isinstance(event, RealtimeModelSendInterrupt):
            self.interrupts_called += 1

    async def close(self):
        """Mark the model as closed."""
        self.close_called = True
|  | 67 | + | 
|  | 68 | + | 
@pytest.fixture
def mock_model():
    """Provide a fresh ``MockRealtimeModel`` instance to the requesting test."""
    model = MockRealtimeModel()
    return model
|  | 72 | + | 
|  | 73 | + | 
@pytest.fixture
def mock_agent():
    """Provide a ``RealtimeAgent``-spec'd mock with no tools, handoffs or guardrails."""
    stub = Mock(spec=RealtimeAgent)

    # PropertyMock only works when attached to the class, so the two property
    # stand-ins go on the mock's own type rather than on the instance.
    stub_type = type(stub)
    stub_type.handoffs = PropertyMock(return_value=[])
    stub_type.output_guardrails = PropertyMock(return_value=[])

    stub.name = "test_agent"
    stub.get_all_tools = AsyncMock(return_value=[])
    return stub
|  | 82 | + | 
|  | 83 | + | 
@pytest.mark.asyncio
async def test_guardrail_task_cleanup_awaits_cancelled_tasks(mock_model, mock_agent):
    """Test that cleanup properly awaits cancelled guardrail tasks.

    A slow guardrail is started through a transcript delta, then
    ``_cleanup_guardrail_tasks()`` is awaited while the guardrail is still
    sleeping. By the time cleanup returns, the test expects that:

    1. the guardrail coroutine has received ``CancelledError``,
    2. the task object captured before cleanup is finished (``done()``), and
    3. the session no longer tracks any guardrail tasks.
    """

    # Events let the test observe what happened inside the guardrail coroutine.
    task_started = asyncio.Event()
    task_cancelled = asyncio.Event()

    async def slow_guardrail_func(context, agent, output):
        """A guardrail that takes time to execute."""
        task_started.set()
        try:
            # Long enough that only cancellation can end it during the test.
            await asyncio.sleep(10)
            return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
        except asyncio.CancelledError:
            task_cancelled.set()
            raise

    guardrail = OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")

    run_config: RealtimeRunConfig = {
        "output_guardrails": [guardrail],
        "guardrails_settings": {"debounce_text_length": 5},
    }

    session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)

    # The delta is longer than debounce_text_length, so it should trigger a guardrail run.
    transcript_event = RealtimeModelTranscriptDeltaEvent(
        item_id="item_1", delta="hello world", response_id="resp_1"
    )

    await session.on_event(transcript_event)

    # Wait for the guardrail task to start
    await asyncio.wait_for(task_started.wait(), timeout=1.0)

    # Verify a guardrail task was created and keep a handle on it, because
    # cleanup is expected to empty the session's collection.
    assert len(session._guardrail_tasks) == 1
    task = next(iter(session._guardrail_tasks))
    assert not task.done()

    # Now cleanup the session - this should cancel and await the task
    await session._cleanup_guardrail_tasks()

    # Verify the task was cancelled and properly awaited
    assert task_cancelled.is_set(), "Task should have received CancelledError"
    assert task.done(), "Cleanup should not return before the cancelled task has finished"
    assert len(session._guardrail_tasks) == 0, "Tasks list should be cleared"

    # NOTE: warnings about unhandled task exceptions are not asserted here; the
    # checks above only establish that the task was driven to completion.
|  | 143 | + | 
|  | 144 | + | 
@pytest.mark.asyncio
async def test_guardrail_task_cleanup_with_exception(mock_model, mock_agent):
    """Test that cleanup survives a guardrail that fails while being cancelled.

    The guardrail below turns the ``CancelledError`` it receives into a
    ``RuntimeError``. Awaiting ``_cleanup_guardrail_tasks()`` must still return
    normally and leave the session with no tracked guardrail tasks.
    """

    started = asyncio.Event()
    error_raised = asyncio.Event()

    async def failing_guardrail_func(context, agent, output):
        """A guardrail that raises an exception."""
        started.set()
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError as cancel_exc:
            error_raised.set()
            # Simulate an error during cleanup
            raise RuntimeError("Cleanup error") from cancel_exc
        return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)

    failing_guardrail = OutputGuardrail(
        guardrail_function=failing_guardrail_func,
        name="failing_guardrail",
    )
    config: RealtimeRunConfig = {
        "output_guardrails": [failing_guardrail],
        "guardrails_settings": {"debounce_text_length": 5},
    }
    session = RealtimeSession(mock_model, mock_agent, None, run_config=config)

    # Feed one transcript delta to trigger the guardrail.
    await session.on_event(
        RealtimeModelTranscriptDeltaEvent(
            item_id="item_1", delta="hello world", response_id="resp_1"
        )
    )

    # Do not proceed until the guardrail coroutine is actually running.
    await asyncio.wait_for(started.wait(), timeout=1.0)

    # The RuntimeError raised inside the guardrail must not escape from cleanup.
    await session._cleanup_guardrail_tasks()

    # Cleanup finished: the failure path ran and nothing is left tracked.
    assert error_raised.is_set()
    assert len(session._guardrail_tasks) == 0
|  | 195 | + | 
|  | 196 | + | 
@pytest.mark.asyncio
async def test_guardrail_task_cleanup_with_multiple_tasks(mock_model, mock_agent):
    """Test cleanup with multiple pending guardrail tasks.

    Three transcript deltas (distinct item and response ids) are sent to trigger
    guardrail runs, then ``_cleanup_guardrail_tasks()`` is awaited. The test
    snapshots whatever tasks the session is tracking beforehand and checks that
    every one of them is finished once cleanup returns.

    How many tasks get created, and how many of them have started running when
    cleanup happens, is not pinned down here, so the count assertions only
    require "at least one".
    """

    tasks_started = asyncio.Event()
    tasks_cancelled = 0

    async def slow_guardrail_func(context, agent, output):
        """A slow guardrail that counts how many times it was cancelled."""
        nonlocal tasks_cancelled
        tasks_started.set()
        try:
            await asyncio.sleep(10)
            return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
        except asyncio.CancelledError:
            tasks_cancelled += 1
            raise

    guardrail = OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")

    run_config: RealtimeRunConfig = {
        "output_guardrails": [guardrail],
        "guardrails_settings": {"debounce_text_length": 5},
    }

    session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)

    # Trigger guardrails multiple times to create multiple tasks
    for i in range(3):
        transcript_event = RealtimeModelTranscriptDeltaEvent(
            item_id=f"item_{i}", delta="hello world", response_id=f"resp_{i}"
        )
        await session.on_event(transcript_event)

    # Wait for at least one task to start
    await asyncio.wait_for(tasks_started.wait(), timeout=1.0)

    # Snapshot the tracked tasks: cleanup is expected to empty the session's
    # collection, so the handles must be captured first.
    pending_tasks = list(session._guardrail_tasks)
    initial_task_count = len(pending_tasks)
    assert initial_task_count >= 1, "At least one guardrail task should exist"

    # Cleanup should cancel and await all tasks
    await session._cleanup_guardrail_tasks()

    # Verify all tasks were cancelled, finished and cleared
    assert tasks_cancelled >= 1, "At least one task should have been cancelled"
    assert all(task.done() for task in pending_tasks), (
        "Every guardrail task pending before cleanup should be finished afterwards"
    )
    assert len(session._guardrail_tasks) == 0
0 commit comments