@@ -218,10 +218,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
218218 if event .type == "error" :
219219 await self ._put_event (RealtimeError (info = self ._event_info , error = event .error ))
220220 elif event .type == "function_call" :
221+ agent_snapshot = self ._current_agent
221222 if self ._async_tool_calls :
222- self ._enqueue_tool_call_task (event )
223+ self ._enqueue_tool_call_task (event , agent_snapshot )
223224 else :
224- await self ._handle_tool_call (event )
225+ await self ._handle_tool_call (event , agent_snapshot = agent_snapshot )
225226 elif event .type == "audio" :
226227 await self ._put_event (
227228 RealtimeAudio (
@@ -389,11 +390,17 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None:
389390 """Put an event into the queue."""
390391 await self ._event_queue .put (event )
391392
392- async def _handle_tool_call (self , event : RealtimeModelToolCallEvent ) -> None :
393+ async def _handle_tool_call (
394+ self ,
395+ event : RealtimeModelToolCallEvent ,
396+ * ,
397+ agent_snapshot : RealtimeAgent | None = None ,
398+ ) -> None :
393399 """Handle a tool call event."""
400+ agent = agent_snapshot or self ._current_agent
394401 tools , handoffs = await asyncio .gather (
395- self . _current_agent .get_all_tools (self ._context_wrapper ),
396- self ._get_handoffs (self . _current_agent , self ._context_wrapper ),
402+ agent .get_all_tools (self ._context_wrapper ),
403+ self ._get_handoffs (agent , self ._context_wrapper ),
397404 )
398405 function_map = {tool .name : tool for tool in tools if isinstance (tool , FunctionTool )}
399406 handoff_map = {handoff .tool_name : handoff for handoff in handoffs }
@@ -403,7 +410,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
403410 RealtimeToolStart (
404411 info = self ._event_info ,
405412 tool = function_map [event .name ],
406- agent = self . _current_agent ,
413+ agent = agent ,
407414 )
408415 )
409416
@@ -428,7 +435,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
428435 info = self ._event_info ,
429436 tool = func_tool ,
430437 output = result ,
431- agent = self . _current_agent ,
438+ agent = agent ,
432439 )
433440 )
434441 elif event .name in handoff_map :
@@ -449,7 +456,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
449456 )
450457
451458 # Store previous agent for event
452- previous_agent = self . _current_agent
459+ previous_agent = agent
453460
454461 # Update current agent
455462 self ._current_agent = result
@@ -757,9 +764,13 @@ def _cleanup_guardrail_tasks(self) -> None:
757764 task .cancel ()
758765 self ._guardrail_tasks .clear ()
759766
760- def _enqueue_tool_call_task (self , event : RealtimeModelToolCallEvent ) -> None :
767+ def _enqueue_tool_call_task (
768+ self , event : RealtimeModelToolCallEvent , agent_snapshot : RealtimeAgent
769+ ) -> None :
761770 """Run tool calls in the background to avoid blocking realtime transport."""
762- task = asyncio .create_task (self ._handle_tool_call (event ))
771+ task = asyncio .create_task (
772+ self ._handle_tool_call (event , agent_snapshot = agent_snapshot )
773+ )
763774 self ._tool_call_tasks .add (task )
764775 task .add_done_callback (self ._on_tool_call_task_done )
765776
0 commit comments