Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,13 @@ def create_session(self, session_name: str) -> str:
self.sessions.append(self.session_id)
return self.session_id

def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
assert len(tool_calls) == 1, "Only one tool call is supported"
tool_call = tool_calls[0]
def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]:
responses = []
for tool_call in tool_calls:
responses.append(self._run_single_tool(tool_call))
return responses

def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
# custom client tools
if tool_call.tool_name in self.client_tools:
tool = self.client_tools[tool_call.tool_name]
Expand All @@ -227,12 +230,11 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
tool_name=tool_call.tool_name,
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
)
tool_response = ToolResponseParam(
return ToolResponseParam(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
)
return tool_response

# cannot find tools
return ToolResponseParam(
Expand Down Expand Up @@ -302,14 +304,14 @@ def _create_turn_streaming(
yield chunk

# run the tools
tool_response = self._run_tool(tool_calls)
tool_responses = self._run_tool_calls(tool_calls)

# pass it to next iteration
turn_response = self.client.agents.turn.resume(
agent_id=self.agent_id,
session_id=session_id or self.session_id[-1],
turn_id=turn_id,
tool_responses=[tool_response],
tool_responses=tool_responses,
stream=True,
)
n_iter += 1
Expand Down Expand Up @@ -439,10 +441,13 @@ async def create_turn(
raise Exception("Turn did not complete")
return chunks[-1].event.payload.turn

async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
assert len(tool_calls) == 1, "Only one tool call is supported"
tool_call = tool_calls[0]
async def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]:
responses = []
for tool_call in tool_calls:
responses.append(await self._run_single_tool(tool_call))
return responses

async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
# custom client tools
if tool_call.tool_name in self.client_tools:
tool = self.client_tools[tool_call.tool_name]
Expand All @@ -464,12 +469,11 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
tool_name=tool_call.tool_name,
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
)
tool_response = ToolResponseParam(
return ToolResponseParam(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
)
return tool_response

# cannot find tools
return ToolResponseParam(
Expand Down Expand Up @@ -522,14 +526,14 @@ async def _create_turn_streaming(
yield chunk

# run the tools
tool_response = await self._run_tool(tool_calls)
tool_responses = await self._run_tool_calls(tool_calls)

# pass it to next iteration
turn_response = await self.client.agents.turn.resume(
agent_id=self.agent_id,
session_id=session_id or self.session_id[-1],
turn_id=turn_id,
tool_responses=[tool_response],
tool_responses=tool_responses,
stream=True,
)
n_iter += 1
Expand Down