@@ -200,10 +200,13 @@ def create_session(self, session_name: str) -> str:
200
200
self .sessions .append (self .session_id )
201
201
return self .session_id
202
202
203
- def _run_tool (self , tool_calls : List [ToolCall ]) -> ToolResponseParam :
204
- assert len (tool_calls ) == 1 , "Only one tool call is supported"
205
- tool_call = tool_calls [0 ]
203
+ def _run_tool_calls (self , tool_calls : List [ToolCall ]) -> List [ToolResponseParam ]:
204
+ responses = []
205
+ for tool_call in tool_calls :
206
+ responses .append (self ._run_single_tool (tool_call ))
207
+ return responses
206
208
209
+ def _run_single_tool (self , tool_call : ToolCall ) -> ToolResponseParam :
207
210
# custom client tools
208
211
if tool_call .tool_name in self .client_tools :
209
212
tool = self .client_tools [tool_call .tool_name ]
@@ -227,12 +230,11 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
227
230
tool_name = tool_call .tool_name ,
228
231
kwargs = {** tool_call .arguments , ** self .builtin_tools [tool_call .tool_name ]},
229
232
)
230
- tool_response = ToolResponseParam (
233
+ return ToolResponseParam (
231
234
call_id = tool_call .call_id ,
232
235
tool_name = tool_call .tool_name ,
233
236
content = tool_result .content ,
234
237
)
235
- return tool_response
236
238
237
239
# cannot find tools
238
240
return ToolResponseParam (
@@ -302,14 +304,14 @@ def _create_turn_streaming(
302
304
yield chunk
303
305
304
306
# run the tools
305
- tool_response = self ._run_tool (tool_calls )
307
+ tool_responses = self ._run_tool_calls (tool_calls )
306
308
307
309
# pass it to next iteration
308
310
turn_response = self .client .agents .turn .resume (
309
311
agent_id = self .agent_id ,
310
312
session_id = session_id or self .session_id [- 1 ],
311
313
turn_id = turn_id ,
312
- tool_responses = [ tool_response ] ,
314
+ tool_responses = tool_responses ,
313
315
stream = True ,
314
316
)
315
317
n_iter += 1
@@ -439,10 +441,13 @@ async def create_turn(
439
441
raise Exception ("Turn did not complete" )
440
442
return chunks [- 1 ].event .payload .turn
441
443
442
- async def _run_tool (self , tool_calls : List [ToolCall ]) -> ToolResponseParam :
443
- assert len (tool_calls ) == 1 , "Only one tool call is supported"
444
- tool_call = tool_calls [0 ]
444
+ async def _run_tool_calls (self , tool_calls : List [ToolCall ]) -> List [ToolResponseParam ]:
445
+ responses = []
446
+ for tool_call in tool_calls :
447
+ responses .append (await self ._run_single_tool (tool_call ))
448
+ return responses
445
449
450
+ async def _run_single_tool (self , tool_call : ToolCall ) -> ToolResponseParam :
446
451
# custom client tools
447
452
if tool_call .tool_name in self .client_tools :
448
453
tool = self .client_tools [tool_call .tool_name ]
@@ -464,12 +469,11 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
464
469
tool_name = tool_call .tool_name ,
465
470
kwargs = {** tool_call .arguments , ** self .builtin_tools [tool_call .tool_name ]},
466
471
)
467
- tool_response = ToolResponseParam (
472
+ return ToolResponseParam (
468
473
call_id = tool_call .call_id ,
469
474
tool_name = tool_call .tool_name ,
470
475
content = tool_result .content ,
471
476
)
472
- return tool_response
473
477
474
478
# cannot find tools
475
479
return ToolResponseParam (
@@ -522,14 +526,14 @@ async def _create_turn_streaming(
522
526
yield chunk
523
527
524
528
# run the tools
525
- tool_response = await self ._run_tool (tool_calls )
529
+ tool_responses = await self ._run_tool_calls (tool_calls )
526
530
527
531
# pass it to next iteration
528
532
turn_response = await self .client .agents .turn .resume (
529
533
agent_id = self .agent_id ,
530
534
session_id = session_id or self .session_id [- 1 ],
531
535
turn_id = turn_id ,
532
- tool_responses = [ tool_response ] ,
536
+ tool_responses = tool_responses ,
533
537
stream = True ,
534
538
)
535
539
n_iter += 1
0 commit comments