Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ async def create_turn(
raise Exception("Turn did not complete")
return chunks[-1].event.payload.turn

async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
assert len(tool_calls) == 1, "Only one tool call is supported"
tool_call = tool_calls[0]

Expand All @@ -464,20 +464,18 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
tool_name=tool_call.tool_name,
kwargs=tool_call.arguments,
)
tool_response_message = ToolResponseMessage(
tool_response = ToolResponseParam(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
role="tool",
)
return tool_response_message
return tool_response

# cannot find tools
return ToolResponseMessage(
return ToolResponseParam(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called.",
role="tool",
)

async def _create_turn_streaming(
Expand Down Expand Up @@ -524,14 +522,14 @@ async def _create_turn_streaming(
yield chunk

# run the tools
tool_response_message = await self._run_tool(tool_calls)
tool_response = await self._run_tool(tool_calls)

# pass it to next iteration
turn_response = await self.client.agents.turn.resume(
agent_id=self.agent_id,
session_id=session_id or self.session_id[-1],
turn_id=turn_id,
tool_responses=[tool_response_message],
tool_responses=[tool_response],
stream=True,
)
n_iter += 1
Expand Down