@@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent):
3333 delay : float = 0
3434 """The delay before the agent generates an event."""
3535
36- @override
37- async def _run_async_impl (
38- self , ctx : InvocationContext
39- ) -> AsyncGenerator [Event , None ]:
40- await asyncio .sleep (self .delay )
41- yield Event (
36+ def event (self , ctx : InvocationContext ):
37+ return Event (
4238 author = self .name ,
4339 branch = ctx .branch ,
4440 invocation_id = ctx .invocation_id ,
@@ -47,6 +43,13 @@ async def _run_async_impl(
4743 ),
4844 )
4945
46+ @override
47+ async def _run_async_impl (
48+ self , ctx : InvocationContext
49+ ) -> AsyncGenerator [Event , None ]:
50+ await asyncio .sleep (self .delay )
51+ yield self .event (ctx )
52+
5053
5154async def _create_parent_invocation_context (
5255 test_name : str , agent : BaseAgent
@@ -137,26 +140,19 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
137140 assert events [2 ].branch != events [0 ].branch
138141
139142
140- class _TestingAgentWithMultipleEvents (BaseAgent ):
143+ class _TestingAgentWithMultipleEvents (_TestingAgent ):
141144 """Mock agent for testing."""
142145
143146 @override
144147 async def _run_async_impl (
145148 self , ctx : InvocationContext
146149 ) -> AsyncGenerator [Event , None ]:
147150 for _ in range (0 , 3 ):
148- event = Event (
149- author = self .name ,
150- branch = ctx .branch ,
151- invocation_id = ctx .invocation_id ,
152- content = types .Content (
153- parts = [types .Part (text = f'Hello, async { self .name } !' )]
154- ),
155- )
156- yield event
157- # Check that the event was processed by the consumer.
158- assert event .custom_metadata is not None
159- assert event .custom_metadata ['processed' ]
151+ event = self .event (ctx )
152+ yield event
153+ # Check that the event was processed by the consumer.
154+ assert event .custom_metadata is not None
155+ assert event .custom_metadata ['processed' ]
160156
161157
162158@pytest .mark .asyncio
@@ -186,3 +182,58 @@ async def test_generating_one_event_per_agent_at_once(
186182 async for event in agen :
187183 event .custom_metadata = {'processed' : True }
188184 # Asserts on event are done in _TestingAgentWithMultipleEvents.
185+
186+
187+ class _TestingAgentWithException (_TestingAgent ):
188+ """Mock agent for testing."""
189+
190+ @override
191+ async def _run_async_impl (
192+ self , ctx : InvocationContext
193+ ) -> AsyncGenerator [Event , None ]:
194+ yield self .event (ctx )
195+ raise Exception ()
196+
197+
198+ class _TestingAgentInfiniteEvents (_TestingAgent ):
199+ """Mock agent for testing."""
200+
201+ @override
202+ async def _run_async_impl (
203+ self , ctx : InvocationContext
204+ ) -> AsyncGenerator [Event , None ]:
205+ while True :
206+ yield self .event (ctx )
207+
208+
209+ @pytest .mark .asyncio
210+ async def test_stop_agent_if_sub_agent_fails (
211+ request : pytest .FixtureRequest ,
212+ ):
213+ # This test is to verify that the parallel agent and subagents will all stop
214+ # processing and throw exception to top level runner in case of exception.
215+ agent1 = _TestingAgentWithException (
216+ name = f'{ request .function .__name__ } _test_agent_1'
217+ )
218+ agent2 = _TestingAgentInfiniteEvents (
219+ name = f'{ request .function .__name__ } _test_agent_2'
220+ )
221+ parallel_agent = ParallelAgent (
222+ name = f'{ request .function .__name__ } _test_parallel_agent' ,
223+ sub_agents = [
224+ agent1 ,
225+ agent2 ,
226+ ],
227+ )
228+ parent_ctx = await _create_parent_invocation_context (
229+ request .function .__name__ , parallel_agent
230+ )
231+
232+ agen = parallel_agent .run_async (parent_ctx )
233+ # We expect to receive an exception from one of subagents.
234+ # The exception should be propagated to root agent and other subagents.
235+ # Otherwise we'll have an infinite loop.
236+ with pytest .raises (Exception ):
237+ async for _ in agen :
238+ # The infinite agent could iterate a few times depending on scheduling.
239+ pass
0 commit comments