Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 63 additions & 12 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,21 @@
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn import Turn
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
from llama_stack_client.types.agents.turn_create_response import (
AgentTurnResponseStreamChunk,
)
from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
from llama_stack_client.types.agents.turn_response_event_payload import (
AgentTurnResponseStepCompletePayload,
)
from llama_stack_client.types.shared.tool_call import ToolCall
from llama_stack_client.types.agents.turn import CompletionMessage
from .client_tool import ClientTool
from .tool_parser import ToolParser
from datetime import datetime
import uuid
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
from llama_stack_client.types.tool_response import ToolResponse

DEFAULT_MAX_ITER = 10

Expand Down Expand Up @@ -119,24 +129,36 @@ def create_turn(
stream: bool = True,
) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
if stream:
return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream)
return self._create_turn_streaming(messages, session_id, toolgroups, documents)
else:
chunk = None
for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream):
chunks = []
for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
if chunk.event.payload.event_type == "turn_complete":
chunks.append(chunk)
pass
if not chunk:
raise Exception("No chunk returned")
if chunk.event.payload.event_type != "turn_complete":
if not chunks:
raise Exception("Turn did not complete")
return chunk.event.payload.turn

# merge chunks
return Turn(
input_messages=chunks[0].event.payload.turn.input_messages,
output_message=chunks[-1].event.payload.turn.output_message,
session_id=chunks[0].event.payload.turn.session_id,
steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
turn_id=chunks[0].event.payload.turn.turn_id,
started_at=chunks[0].event.payload.turn.started_at,
completed_at=chunks[-1].event.payload.turn.completed_at,
output_attachments=[
attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
],
)

def _create_turn_streaming(
self,
messages: List[Union[UserMessage, ToolResponseMessage]],
session_id: Optional[str] = None,
toolgroups: Optional[List[Toolgroup]] = None,
documents: Optional[List[Document]] = None,
stream: bool = True,
) -> Iterator[AgentTurnResponseStreamChunk]:
stop = False
n_iter = 0
Expand All @@ -161,10 +183,39 @@ def _create_turn_streaming(
elif not tool_calls:
yield chunk
else:
next_message = self._run_tool(tool_calls)
yield next_message
tool_execution_start_time = datetime.now()
tool_response_message = self._run_tool(tool_calls)
tool_execution_step = ToolExecutionStep(
step_type="tool_execution",
step_id=str(uuid.uuid4()),
tool_calls=tool_calls,
tool_responses=[
ToolResponse(
tool_name=tool_response_message.tool_name,
content=tool_response_message.content,
call_id=tool_response_message.call_id,
)
],
turn_id=chunk.event.payload.turn.turn_id,
completed_at=datetime.now(),
started_at=tool_execution_start_time,
)
yield AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
event_type="step_complete",
step_id=tool_execution_step.step_id,
step_type="tool_execution",
step_details=tool_execution_step,
)
)
)

# HACK: append the tool execution step to the turn
chunk.event.payload.turn.steps.append(tool_execution_step)
yield chunk

# continue the turn when there's a tool call
stop = False
messages = [next_message]
messages = [tool_response_message]
n_iter += 1
10 changes: 1 addition & 9 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from termcolor import cprint

from llama_stack_client.types import InterleavedContent, ToolResponseMessage
from llama_stack_client.types import InterleavedContent


def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
Expand Down Expand Up @@ -70,14 +70,6 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step
yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
return

if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield TurnStreamPrintableEvent(role="CustomTool", content=chunk.content, color="green")
return

event = chunk.event
event_type = event.payload.event_type

Expand Down