1717from __future__ import annotations
1818
1919import asyncio
20- from typing import Any
20+ import sys
2121from typing import AsyncGenerator
2222from typing import ClassVar
23- from typing import Dict
24- from typing import Type
2523
2624from typing_extensions import override
2725
@@ -49,6 +47,70 @@ def _create_branch_ctx_for_sub_agent(
4947 return invocation_context
5048
5149
50+ # TODO - remove once Python <3.11 is no longer supported.
51+ async def _merge_agent_run_pre_3_11 (
52+ agent_runs : list [AsyncGenerator [Event , None ]],
53+ ) -> AsyncGenerator [Event , None ]:
54+ """Merges the agent run event generator.
55+ This version works in Python 3.9 and 3.10 and uses custom replacement for
56+ asyncio.TaskGroup for tasks cancellation and exception handling.
57+
58+ This implementation guarantees for each agent, it won't move on until the
59+ generated event is processed by upstream runner.
60+
61+ Args:
62+ agent_runs: A list of async generators that yield events from each agent.
63+
64+ Yields:
65+ Event: The next event from the merged generator.
66+ """
67+ sentinel = object ()
68+ queue = asyncio .Queue ()
69+
70+ def propagate_exceptions (tasks ):
71+ # Propagate exceptions and errors from tasks.
72+ for task in tasks :
73+ if task .done ():
74+ # Ignore the result (None) of correctly finished tasks and re-raise
75+ # exceptions and errors.
76+ task .result ()
77+
78+ # Agents are processed in parallel.
79+ # Events for each agent are put on queue sequentially.
80+ async def process_an_agent (events_for_one_agent ):
81+ try :
82+ async for event in events_for_one_agent :
83+ resume_signal = asyncio .Event ()
84+ await queue .put ((event , resume_signal ))
85+ # Wait for upstream to consume event before generating new events.
86+ await resume_signal .wait ()
87+ finally :
88+ # Mark agent as finished.
89+ await queue .put ((sentinel , None ))
90+
91+ tasks = []
92+ try :
93+ for events_for_one_agent in agent_runs :
94+ tasks .append (asyncio .create_task (process_an_agent (events_for_one_agent )))
95+
96+ sentinel_count = 0
97+ # Run until all agents finished processing.
98+ while sentinel_count < len (agent_runs ):
99+ propagate_exceptions (tasks )
100+ event , resume_signal = await queue .get ()
101+ # Agent finished processing.
102+ if event is sentinel :
103+ sentinel_count += 1
104+ else :
105+ yield event
106+ # Signal to agent that event has been processed by runner and it can
107+ # continue now.
108+ resume_signal .set ()
109+ finally :
110+ for task in tasks :
111+ task .cancel ()
112+
113+
52114async def _merge_agent_run (
53115 agent_runs : list [AsyncGenerator [Event , None ]],
54116) -> AsyncGenerator [Event , None ]:
@@ -63,30 +125,37 @@ async def _merge_agent_run(
63125 Yields:
64126 Event: The next event from the merged generator.
65127 """
66- tasks = [
67- asyncio .create_task (events_for_one_agent .__anext__ ())
68- for events_for_one_agent in agent_runs
69- ]
70- pending_tasks = set (tasks )
71-
72- while pending_tasks :
73- done , pending_tasks = await asyncio .wait (
74- pending_tasks , return_when = asyncio .FIRST_COMPLETED
75- )
76- for task in done :
77- try :
78- yield task .result ()
79-
80- # Find the generator that produced this event and move it on.
81- for i , original_task in enumerate (tasks ):
82- if task == original_task :
83- new_task = asyncio .create_task (agent_runs [i ].__anext__ ())
84- tasks [i ] = new_task
85- pending_tasks .add (new_task )
86- break # stop iterating once found
87-
88- except StopAsyncIteration :
89- continue
128+ sentinel = object ()
129+ queue = asyncio .Queue ()
130+
131+ # Agents are processed in parallel.
132+ # Events for each agent are put on queue sequentially.
133+ async def process_an_agent (events_for_one_agent ):
134+ try :
135+ async for event in events_for_one_agent :
136+ resume_signal = asyncio .Event ()
137+ await queue .put ((event , resume_signal ))
138+ # Wait for upstream to consume event before generating new events.
139+ await resume_signal .wait ()
140+ finally :
141+ # Mark agent as finished.
142+ await queue .put ((sentinel , None ))
143+
144+ async with asyncio .TaskGroup () as tg :
145+ for events_for_one_agent in agent_runs :
146+ tg .create_task (process_an_agent (events_for_one_agent ))
147+
148+ sentinel_count = 0
149+ # Run until all agents finished processing.
150+ while sentinel_count < len (agent_runs ):
151+ event , resume_signal = await queue .get ()
152+ # Agent finished processing.
153+ if event is sentinel :
154+ sentinel_count += 1
155+ else :
156+ yield event
157+ # Signal to agent that it should generate next event.
158+ resume_signal .set ()
90159
91160
92161class ParallelAgent (BaseAgent ):
@@ -112,10 +181,19 @@ async def _run_async_impl(
112181 )
113182 for sub_agent in self .sub_agents
114183 ]
115-
116- async with Aclosing (_merge_agent_run (agent_runs )) as agen :
117- async for event in agen :
118- yield event
184+ try :
185+ # TODO remove if once Python <3.11 is no longer supported.
186+ if sys .version_info >= (3 , 11 ):
187+ async with Aclosing (_merge_agent_run (agent_runs )) as agen :
188+ async for event in agen :
189+ yield event
190+ else :
191+ async with Aclosing (_merge_agent_run_pre_3_11 (agent_runs )) as agen :
192+ async for event in agen :
193+ yield event
194+ finally :
195+ for sub_agent_run in agent_runs :
196+ await sub_agent_run .aclose ()
119197
120198 @override
121199 async def _run_live_impl (
0 commit comments