From b92d2f11198ee5ca994e07773161502f7397c919 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Thu, 13 Mar 2025 11:34:51 -0700 Subject: [PATCH] feat(agent): support multiple tool calls Summary: Test Plan: --- src/llama_stack_client/lib/agents/agent.py | 32 ++++++++++++---------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 09c353b1..3c43bc0e 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -200,10 +200,13 @@ def create_session(self, session_name: str) -> str: self.sessions.append(self.session_id) return self.session_id - def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam: - assert len(tool_calls) == 1, "Only one tool call is supported" - tool_call = tool_calls[0] + def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]: + responses = [] + for tool_call in tool_calls: + responses.append(self._run_single_tool(tool_call)) + return responses + def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] @@ -227,12 +230,11 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam: tool_name=tool_call.tool_name, kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]}, ) - tool_response = ToolResponseParam( + return ToolResponseParam( call_id=tool_call.call_id, tool_name=tool_call.tool_name, content=tool_result.content, ) - return tool_response # cannot find tools return ToolResponseParam( @@ -302,14 +304,14 @@ def _create_turn_streaming( yield chunk # run the tools - tool_response = self._run_tool(tool_calls) + tool_responses = self._run_tool_calls(tool_calls) # pass it to next iteration turn_response = self.client.agents.turn.resume( agent_id=self.agent_id, session_id=session_id or self.session_id[-1], turn_id=turn_id, - tool_responses=[tool_response], + tool_responses=tool_responses, stream=True, ) n_iter += 1 @@ -439,10 +441,13 @@ async def create_turn( raise Exception("Turn did not complete") return chunks[-1].event.payload.turn - async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam: - assert len(tool_calls) == 1, "Only one tool call is supported" - tool_call = tool_calls[0] + async def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]: + responses = [] + for tool_call in tool_calls: + responses.append(await self._run_single_tool(tool_call)) + return responses + async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] @@ -464,12 +469,11 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam: tool_name=tool_call.tool_name, kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]}, ) - tool_response = ToolResponseParam( + return ToolResponseParam( call_id=tool_call.call_id, tool_name=tool_call.tool_name, content=tool_result.content, ) - return tool_response # cannot find tools return ToolResponseParam( @@ -522,14 +526,14 @@ async def _create_turn_streaming( yield chunk # run the tools - tool_response = await self._run_tool(tool_calls) + tool_responses = await self._run_tool_calls(tool_calls) # pass it to next iteration turn_response = await self.client.agents.turn.resume( agent_id=self.agent_id, session_id=session_id or self.session_id[-1], turn_id=turn_id, - tool_responses=[tool_response], + tool_responses=tool_responses, stream=True, ) n_iter += 1