Skip to content

Commit 9645cee

Browse files
jawoszekcopybara-github
authored andcommitted
chore: add test for parallel agent to verify correct handling of exceptions
PiperOrigin-RevId: 797825924
1 parent 70f50db commit 9645cee

File tree

1 file changed

+70
-19
lines changed

1 file changed

+70
-19
lines changed

tests/unittests/agents/test_parallel_agent.py

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent):
3333
delay: float = 0
3434
"""The delay before the agent generates an event."""
3535

36-
@override
37-
async def _run_async_impl(
38-
self, ctx: InvocationContext
39-
) -> AsyncGenerator[Event, None]:
40-
await asyncio.sleep(self.delay)
41-
yield Event(
36+
def event(self, ctx: InvocationContext):
37+
return Event(
4238
author=self.name,
4339
branch=ctx.branch,
4440
invocation_id=ctx.invocation_id,
@@ -47,6 +43,13 @@ async def _run_async_impl(
4743
),
4844
)
4945

46+
@override
47+
async def _run_async_impl(
48+
self, ctx: InvocationContext
49+
) -> AsyncGenerator[Event, None]:
50+
await asyncio.sleep(self.delay)
51+
yield self.event(ctx)
52+
5053

5154
async def _create_parent_invocation_context(
5255
test_name: str, agent: BaseAgent
@@ -137,26 +140,19 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
137140
assert events[2].branch != events[0].branch
138141

139142

140-
class _TestingAgentWithMultipleEvents(BaseAgent):
143+
class _TestingAgentWithMultipleEvents(_TestingAgent):
141144
"""Mock agent for testing."""
142145

143146
@override
144147
async def _run_async_impl(
145148
self, ctx: InvocationContext
146149
) -> AsyncGenerator[Event, None]:
147150
for _ in range(0, 3):
148-
event = Event(
149-
author=self.name,
150-
branch=ctx.branch,
151-
invocation_id=ctx.invocation_id,
152-
content=types.Content(
153-
parts=[types.Part(text=f'Hello, async {self.name}!')]
154-
),
155-
)
156-
yield event
157-
# Check that the event was processed by the consumer.
158-
assert event.custom_metadata is not None
159-
assert event.custom_metadata['processed']
151+
event = self.event(ctx)
152+
yield event
153+
# Check that the event was processed by the consumer.
154+
assert event.custom_metadata is not None
155+
assert event.custom_metadata['processed']
160156

161157

162158
@pytest.mark.asyncio
@@ -186,3 +182,58 @@ async def test_generating_one_event_per_agent_at_once(
186182
async for event in agen:
187183
event.custom_metadata = {'processed': True}
188184
# Asserts on event are done in _TestingAgentWithMultipleEvents.
185+
186+
187+
class _TestingAgentWithException(_TestingAgent):
188+
"""Mock agent for testing."""
189+
190+
@override
191+
async def _run_async_impl(
192+
self, ctx: InvocationContext
193+
) -> AsyncGenerator[Event, None]:
194+
yield self.event(ctx)
195+
raise Exception()
196+
197+
198+
class _TestingAgentInfiniteEvents(_TestingAgent):
199+
"""Mock agent for testing."""
200+
201+
@override
202+
async def _run_async_impl(
203+
self, ctx: InvocationContext
204+
) -> AsyncGenerator[Event, None]:
205+
while True:
206+
yield self.event(ctx)
207+
208+
209+
@pytest.mark.asyncio
210+
async def test_stop_agent_if_sub_agent_fails(
211+
request: pytest.FixtureRequest,
212+
):
213+
# This test is to verify that the parallel agent and subagents will all stop
214+
# processing and throw exception to top level runner in case of exception.
215+
agent1 = _TestingAgentWithException(
216+
name=f'{request.function.__name__}_test_agent_1'
217+
)
218+
agent2 = _TestingAgentInfiniteEvents(
219+
name=f'{request.function.__name__}_test_agent_2'
220+
)
221+
parallel_agent = ParallelAgent(
222+
name=f'{request.function.__name__}_test_parallel_agent',
223+
sub_agents=[
224+
agent1,
225+
agent2,
226+
],
227+
)
228+
parent_ctx = await _create_parent_invocation_context(
229+
request.function.__name__, parallel_agent
230+
)
231+
232+
agen = parallel_agent.run_async(parent_ctx)
233+
# We expect to receive an exception from one of subagents.
234+
# The exception should be propagated to root agent and other subagents.
235+
# Otherwise we'll have an infinite loop.
236+
with pytest.raises(Exception):
237+
async for _ in agen:
238+
# The infinite agent could iterate a few times depending on scheduling.
239+
pass

0 commit comments

Comments
 (0)