docs/_static/llama-stack-spec.html (15 additions & 5 deletions)
@@ -9228,11 +9228,21 @@
           "type": "object",
           "properties": {
             "tool_responses": {
-              "type": "array",
-              "items": {
-                "$ref": "#/components/schemas/ToolResponseMessage"
-              },
-              "description": "The tool call responses to resume the turn with."
+              "oneOf": [
+                {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ToolResponse"
+                  }
+                },
+                {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ToolResponseMessage"
+                  }
+                }
+              ],
+              "description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse."
             },
             "stream": {
               "type": "boolean",
docs/_static/llama-stack-spec.yaml (9 additions & 4 deletions)
@@ -6153,11 +6153,16 @@ components:
       type: object
       properties:
         tool_responses:
-          type: array
-          items:
-            $ref: '#/components/schemas/ToolResponseMessage'
+          oneOf:
+            - type: array
+              items:
+                $ref: '#/components/schemas/ToolResponse'
+            - type: array
+              items:
+                $ref: '#/components/schemas/ToolResponseMessage'
          description: >-
-            The tool call responses to resume the turn with.
+            The tool call responses to resume the turn with. NOTE: ToolResponseMessage
+            will be deprecated. Use ToolResponse.
        stream:
          type: boolean
          description: Whether to stream the response.
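For illustration, a resume request body matching the new oneOf schema can carry ToolResponse items directly instead of ToolResponseMessage. A minimal sketch; the field values below are made up, and only the call_id, tool_name, content, and metadata fields come from the referenced schemas:

# Hypothetical request body using the new (preferred) ToolResponse variant
# of tool_responses; values are illustrative only.
resume_body = {
    "tool_responses": [
        {
            "call_id": "call-0",
            "tool_name": "get_boiling_point",
            "content": "-100",
            "metadata": {"source": "https://www.google.com"},
        }
    ],
    "stream": False,
}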
llama_stack/apis/agents/agents.py (3 additions & 2 deletions)
@@ -302,7 +302,7 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: List[ToolResponseMessage]
+    tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
     stream: Optional[bool] = False


@@ -363,7 +363,7 @@ async def resume_agent_turn(
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponseMessage],
+        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
         stream: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
         """Resume an agent turn with executed tool call responses.
@@ -374,6 +374,7 @@ async def resume_agent_turn(
         :param session_id: The ID of the session to resume.
         :param turn_id: The ID of the turn to resume.
         :param tool_responses: The tool call responses to resume the turn with.
+            NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
         :param stream: Whether to stream the response.
         :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
         """
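To make the updated signature concrete: a caller resuming a paused turn can now pass ToolResponse objects directly rather than wrapping them in ToolResponseMessage. A minimal sketch, assuming an object implementing the Agents API and an async calling context; the three IDs are placeholders, not values from this diff:

from llama_stack.apis.inference import ToolResponse

# agents_impl stands in for any implementation of the Agents protocol;
# the IDs would come from the turn that paused awaiting tool results.
turn = await agents_impl.resume_agent_turn(
    agent_id="agent-123",
    session_id="session-456",
    turn_id="turn-789",
    tool_responses=[
        ToolResponse(call_id="call-0", tool_name="get_boiling_point", content="-100"),
    ],
    stream=False,
)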
@@ -216,13 +216,25 @@ async def _run_turn(
         steps = []
         messages = await self.get_messages_from_turns(turns)
         if is_resume:
-            messages.extend(request.tool_responses)
+            if isinstance(request.tool_responses[0], ToolResponseMessage):
+                tool_response_messages = request.tool_responses
+                tool_responses = [
+                    ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
+                    for x in request.tool_responses
+                ]
+            else:
+                tool_response_messages = [
+                    ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
+                    for x in request.tool_responses
+                ]
+                tool_responses = request.tool_responses
+            messages.extend(tool_response_messages)
             last_turn = turns[-1]
             last_turn_messages = self.turn_to_messages(last_turn)
             last_turn_messages = [
                 x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
             ]
-            last_turn_messages.extend(request.tool_responses)
+            last_turn_messages.extend(tool_response_messages)

             # get steps from the turn
             steps = last_turn.steps
@@ -238,14 +250,7 @@
                 step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
                 turn_id=request.turn_id,
                 tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
-                tool_responses=[
-                    ToolResponse(
-                        call_id=x.call_id,
-                        tool_name=x.tool_name,
-                        content=x.content,
-                    )
-                    for x in request.tool_responses
-                ],
+                tool_responses=tool_responses,
                 completed_at=now,
                 started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
             )
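The resume path above accepts either variant and normalizes it into both shapes: ToolResponseMessage objects for the chat transcript and ToolResponse objects for the tool_execution step. A standalone sketch of that normalization, assuming (as the diff does) that both models expose call_id, tool_name, and content; note that converting in either direction carries no metadata, since ToolResponseMessage has no such field:

from typing import List, Tuple, Union

def normalize_tool_responses(
    items: Union[List[ToolResponse], List[ToolResponseMessage]],
) -> Tuple[List[ToolResponseMessage], List[ToolResponse]]:
    # Mirror of the branch in _run_turn: whichever variant arrives,
    # produce both the transcript form and the step form.
    if isinstance(items[0], ToolResponseMessage):
        messages = items
        responses = [
            ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
            for x in items
        ]
    else:
        messages = [
            ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
            for x in items
        ]
        responses = items
    return messages, responses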
@@ -27,6 +27,7 @@
 from llama_stack.apis.inference import (
     Inference,
     ToolConfig,
+    ToolResponse,
     ToolResponseMessage,
     UserMessage,
 )
@@ -168,7 +169,7 @@ async def resume_agent_turn(
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponseMessage],
+        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnResumeRequest(
tests/integration/agents/test_agents.py (26 additions & 3 deletions)
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Any, Dict
 from uuid import uuid4

 import pytest
@@ -40,6 +41,25 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
     return -1


+@client_tool
+def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+    """
+    Returns the boiling point of a liquid in Celcius or Fahrenheit
+
+    :param liquid_name: The name of the liquid
+    :param celcius: Whether to return the boiling point in Celcius
+    :return: The boiling point of the liquid in Celcius or Fahrenheit
+    """
+    if liquid_name.lower() == "polyjuice":
+        if celcius:
+            temp = -100
+        else:
+            temp = -212
+    else:
+        temp = -1
+    return {"content": temp, "metadata": {"source": "https://www.google.com"}}
+
+
 @pytest.fixture(scope="session")
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
     available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
Expand Down Expand Up @@ -551,8 +571,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
assert expected_kw in response.output_message.content.lower()


def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config):
client_tool = get_boiling_point
@pytest.mark.parametrize("client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)])
def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
client_tool, expectes_metadata = client_tools
agent_config = {
**agent_config,
"input_shields": [],
Expand All @@ -577,7 +598,9 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
assert len(steps) == 3
assert steps[0].step_type == "inference"
assert steps[1].step_type == "tool_execution"
assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
assert steps[1].tool_calls[0].tool_name.startswith("get_boiling_point")
if expectes_metadata:
assert steps[1].tool_responses[0].metadata["source"] == "https://www.google.com"
assert steps[2].step_type == "inference"

last_step_completed_at = None
Expand Down
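As a quick sanity check of the new tool's return shape, independent of the agent loop, calling the function directly should yield the dict that the metadata assertion relies on. A minimal sketch, assuming the @client_tool decorator leaves the wrapped function callable as plain Python:

# The "content" key carries the tool's value; "metadata" is what surfaces
# on the resulting ToolResponse in the tool_execution step.
result = get_boiling_point_with_metadata("polyjuice", celcius=True)
assert result["content"] == -100
assert result["metadata"]["source"] == "https://www.google.com"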