Skip to content

Commit 70f50db

Browse files
jawoszekcopybara-github
authored andcommitted
chore: add test for parallel agent to verify correct ordering of agents
PiperOrigin-RevId: 797817063
1 parent 81a53b5 commit 70f50db

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

tests/unittests/agents/test_parallel_agent.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,54 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
135135
# Sub-agents should have different branches.
136136
assert events[2].branch != events[1].branch
137137
assert events[2].branch != events[0].branch
138+
139+
140+
class _TestingAgentWithMultipleEvents(BaseAgent):
141+
"""Mock agent for testing."""
142+
143+
@override
144+
async def _run_async_impl(
145+
self, ctx: InvocationContext
146+
) -> AsyncGenerator[Event, None]:
147+
for _ in range(0, 3):
148+
event = Event(
149+
author=self.name,
150+
branch=ctx.branch,
151+
invocation_id=ctx.invocation_id,
152+
content=types.Content(
153+
parts=[types.Part(text=f'Hello, async {self.name}!')]
154+
),
155+
)
156+
yield event
157+
# Check that the event was processed by the consumer.
158+
assert event.custom_metadata is not None
159+
assert event.custom_metadata['processed']
160+
161+
162+
@pytest.mark.asyncio
163+
async def test_generating_one_event_per_agent_at_once(
164+
request: pytest.FixtureRequest,
165+
):
166+
# This test is to verify that the parallel agent won't generate more than one
167+
# event per agent at a time.
168+
agent1 = _TestingAgentWithMultipleEvents(
169+
name=f'{request.function.__name__}_test_agent_1'
170+
)
171+
agent2 = _TestingAgentWithMultipleEvents(
172+
name=f'{request.function.__name__}_test_agent_2'
173+
)
174+
parallel_agent = ParallelAgent(
175+
name=f'{request.function.__name__}_test_parallel_agent',
176+
sub_agents=[
177+
agent1,
178+
agent2,
179+
],
180+
)
181+
parent_ctx = await _create_parent_invocation_context(
182+
request.function.__name__, parallel_agent
183+
)
184+
185+
agen = parallel_agent.run_async(parent_ctx)
186+
async for event in agen:
187+
event.custom_metadata = {'processed': True}
188+
# Asserts on event are done in _TestingAgentWithMultipleEvents.

0 commit comments

Comments
 (0)