Skip to content

Commit 43ea2f6

Browse files
authored
feat(agent): support multiple tool calls (#192)
Summary: Supported by llamastack/llama-stack#1556 Test Plan: Tested in llamastack/llama-stack#1556
1 parent fa73d7d commit 43ea2f6

File tree

1 file changed

+18
-14
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+18
-14
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,13 @@ def create_session(self, session_name: str) -> str:
200200
self.sessions.append(self.session_id)
201201
return self.session_id
202202

203-
def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
204-
assert len(tool_calls) == 1, "Only one tool call is supported"
205-
tool_call = tool_calls[0]
203+
def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]:
204+
responses = []
205+
for tool_call in tool_calls:
206+
responses.append(self._run_single_tool(tool_call))
207+
return responses
206208

209+
def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
207210
# custom client tools
208211
if tool_call.tool_name in self.client_tools:
209212
tool = self.client_tools[tool_call.tool_name]
@@ -227,12 +230,11 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
227230
tool_name=tool_call.tool_name,
228231
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
229232
)
230-
tool_response = ToolResponseParam(
233+
return ToolResponseParam(
231234
call_id=tool_call.call_id,
232235
tool_name=tool_call.tool_name,
233236
content=tool_result.content,
234237
)
235-
return tool_response
236238

237239
# cannot find tools
238240
return ToolResponseParam(
@@ -302,14 +304,14 @@ def _create_turn_streaming(
302304
yield chunk
303305

304306
# run the tools
305-
tool_response = self._run_tool(tool_calls)
307+
tool_responses = self._run_tool_calls(tool_calls)
306308

307309
# pass it to next iteration
308310
turn_response = self.client.agents.turn.resume(
309311
agent_id=self.agent_id,
310312
session_id=session_id or self.session_id[-1],
311313
turn_id=turn_id,
312-
tool_responses=[tool_response],
314+
tool_responses=tool_responses,
313315
stream=True,
314316
)
315317
n_iter += 1
@@ -439,10 +441,13 @@ async def create_turn(
439441
raise Exception("Turn did not complete")
440442
return chunks[-1].event.payload.turn
441443

442-
async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
443-
assert len(tool_calls) == 1, "Only one tool call is supported"
444-
tool_call = tool_calls[0]
444+
async def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]:
445+
responses = []
446+
for tool_call in tool_calls:
447+
responses.append(await self._run_single_tool(tool_call))
448+
return responses
445449

450+
async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
446451
# custom client tools
447452
if tool_call.tool_name in self.client_tools:
448453
tool = self.client_tools[tool_call.tool_name]
@@ -464,12 +469,11 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
464469
tool_name=tool_call.tool_name,
465470
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
466471
)
467-
tool_response = ToolResponseParam(
472+
return ToolResponseParam(
468473
call_id=tool_call.call_id,
469474
tool_name=tool_call.tool_name,
470475
content=tool_result.content,
471476
)
472-
return tool_response
473477

474478
# cannot find tools
475479
return ToolResponseParam(
@@ -522,14 +526,14 @@ async def _create_turn_streaming(
522526
yield chunk
523527

524528
# run the tools
525-
tool_response = await self._run_tool(tool_calls)
529+
tool_responses = await self._run_tool_calls(tool_calls)
526530

527531
# pass it to next iteration
528532
turn_response = await self.client.agents.turn.resume(
529533
agent_id=self.agent_id,
530534
session_id=session_id or self.session_id[-1],
531535
turn_id=turn_id,
532-
tool_responses=[tool_response],
536+
tool_responses=tool_responses,
533537
stream=True,
534538
)
535539
n_iter += 1

0 commit comments

Comments
 (0)