Skip to content

Commit ffd4c09

Browse files
committed
Capture agent snapshot for async tool calls
1 parent 1acec90 commit ffd4c09

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

src/agents/realtime/session.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
218218
if event.type == "error":
219219
await self._put_event(RealtimeError(info=self._event_info, error=event.error))
220220
elif event.type == "function_call":
221+
agent_snapshot = self._current_agent
221222
if self._async_tool_calls:
222-
self._enqueue_tool_call_task(event)
223+
self._enqueue_tool_call_task(event, agent_snapshot)
223224
else:
224-
await self._handle_tool_call(event)
225+
await self._handle_tool_call(event, agent_snapshot=agent_snapshot)
225226
elif event.type == "audio":
226227
await self._put_event(
227228
RealtimeAudio(
@@ -389,11 +390,17 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None:
389390
"""Put an event into the queue."""
390391
await self._event_queue.put(event)
391392

392-
async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
393+
async def _handle_tool_call(
394+
self,
395+
event: RealtimeModelToolCallEvent,
396+
*,
397+
agent_snapshot: RealtimeAgent | None = None,
398+
) -> None:
393399
"""Handle a tool call event."""
400+
agent = agent_snapshot or self._current_agent
394401
tools, handoffs = await asyncio.gather(
395-
self._current_agent.get_all_tools(self._context_wrapper),
396-
self._get_handoffs(self._current_agent, self._context_wrapper),
402+
agent.get_all_tools(self._context_wrapper),
403+
self._get_handoffs(agent, self._context_wrapper),
397404
)
398405
function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
399406
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
@@ -403,7 +410,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
403410
RealtimeToolStart(
404411
info=self._event_info,
405412
tool=function_map[event.name],
406-
agent=self._current_agent,
413+
agent=agent,
407414
)
408415
)
409416

@@ -428,7 +435,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
428435
info=self._event_info,
429436
tool=func_tool,
430437
output=result,
431-
agent=self._current_agent,
438+
agent=agent,
432439
)
433440
)
434441
elif event.name in handoff_map:
@@ -449,7 +456,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
449456
)
450457

451458
# Store previous agent for event
452-
previous_agent = self._current_agent
459+
previous_agent = agent
453460

454461
# Update current agent
455462
self._current_agent = result
@@ -757,9 +764,13 @@ def _cleanup_guardrail_tasks(self) -> None:
757764
task.cancel()
758765
self._guardrail_tasks.clear()
759766

760-
def _enqueue_tool_call_task(self, event: RealtimeModelToolCallEvent) -> None:
767+
def _enqueue_tool_call_task(
768+
self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent
769+
) -> None:
761770
"""Run tool calls in the background to avoid blocking realtime transport."""
762-
task = asyncio.create_task(self._handle_tool_call(event))
771+
task = asyncio.create_task(
772+
self._handle_tool_call(event, agent_snapshot=agent_snapshot)
773+
)
763774
self._tool_call_tasks.add(task)
764775
task.add_done_callback(self._on_tool_call_task_done)
765776

0 commit comments

Comments
 (0)