@@ -135,3 +135,54 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
135135 # Sub-agents should have different branches.
136136 assert events [2 ].branch != events [1 ].branch
137137 assert events [2 ].branch != events [0 ].branch
138+
139+
140+ class _TestingAgentWithMultipleEvents (BaseAgent ):
141+ """Mock agent for testing."""
142+
143+ @override
144+ async def _run_async_impl (
145+ self , ctx : InvocationContext
146+ ) -> AsyncGenerator [Event , None ]:
147+ for _ in range (0 , 3 ):
148+ event = Event (
149+ author = self .name ,
150+ branch = ctx .branch ,
151+ invocation_id = ctx .invocation_id ,
152+ content = types .Content (
153+ parts = [types .Part (text = f'Hello, async { self .name } !' )]
154+ ),
155+ )
156+ yield event
157+ # Check that the event was processed by the consumer.
158+ assert event .custom_metadata is not None
159+ assert event .custom_metadata ['processed' ]
160+
161+
162+ @pytest .mark .asyncio
163+ async def test_generating_one_event_per_agent_at_once (
164+ request : pytest .FixtureRequest ,
165+ ):
166+ # This test is to verify that the parallel agent won't generate more than one
167+ # event per agent at a time.
168+ agent1 = _TestingAgentWithMultipleEvents (
169+ name = f'{ request .function .__name__ } _test_agent_1'
170+ )
171+ agent2 = _TestingAgentWithMultipleEvents (
172+ name = f'{ request .function .__name__ } _test_agent_2'
173+ )
174+ parallel_agent = ParallelAgent (
175+ name = f'{ request .function .__name__ } _test_parallel_agent' ,
176+ sub_agents = [
177+ agent1 ,
178+ agent2 ,
179+ ],
180+ )
181+ parent_ctx = await _create_parent_invocation_context (
182+ request .function .__name__ , parallel_agent
183+ )
184+
185+ agen = parallel_agent .run_async (parent_ctx )
186+ async for event in agen :
187+ event .custom_metadata = {'processed' : True }
188+ # Asserts on event are done in _TestingAgentWithMultipleEvents.
0 commit comments