Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 73 additions & 135 deletions llama_stack/providers/inline/agents/meta_reference/agent_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import string
import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import httpx
Expand All @@ -31,7 +31,6 @@
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
Expand Down Expand Up @@ -184,115 +183,49 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"

session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")

turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)

turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
start_time = datetime.now().astimezone().isoformat()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)

steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
logcat.info(
"agents",
f"returning result from the agent turn: {chunk}",
)
output_message = chunk
continue

assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)

async for chunk in self._run_turn(request, turn_id):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

beauty

yield chunk

assert output_message is not None

turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)

yield chunk

async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
async for chunk in self._run_turn(request):
yield chunk

session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
async def _run_turn(
self,
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
turn_id: Optional[str] = None,
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"

turns = await self.storage.get_session_turns(request.session_id)
if len(turns) == 0:
raise ValueError("No turns found for session")
is_resume = isinstance(request, AgentTurnResumeRequest)
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")

messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
turns = await self.storage.get_session_turns(request.session_id)
if is_resume and len(turns) == 0:
raise ValueError("No turns found for session")

steps = []
messages = await self.get_messages_from_turns(turns)
if is_resume:
messages.extend(request.tool_responses)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]

# TODO: figure out whether we should add the tool responses to the last turn messages
last_turn_messages.extend(request.tool_responses)

# get the steps from the turn id
steps = []
steps = turns[-1].steps
# get steps from the turn
steps = last_turn.steps

# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
Expand Down Expand Up @@ -326,61 +259,66 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
)
)
)
input_messages = last_turn_messages

output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue

assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)

yield chunk
turn_id = request.turn_id
start_time = last_turn.started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat()
input_messages = request.messages

output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents if not is_resume else None,
toolgroups_for_turn=request.toolgroups if not is_resume else None,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue

assert output_message is not None
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)

last_turn_start_time = datetime.now().astimezone().isoformat()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
yield chunk

turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now().astimezone().isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
assert output_message is not None

if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=input_messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)

yield chunk
yield chunk

async def run(
self,
Expand Down
Loading