Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import Attachment, ToolResponseMessage, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig

from .custom_tool import CustomTool

class Agent:
def __init__(self, client: LlamaStackClient, agent_config: AgentConfig, custom_tools: Tuple[CustomTool] = ()):
def __init__(
self,
client: LlamaStackClient,
agent_config: AgentConfig,
custom_tools: Tuple[CustomTool] = (),
):
self.client = client
self.agent_config = agent_config
self.agent_id = self._create_agent(agent_config)
Expand Down Expand Up @@ -72,7 +78,10 @@ def create_turn(
stream=True,
)
for chunk in response:
if not self._has_tool_call(chunk):
if hasattr(chunk, "error"):
yield chunk
return
elif not self._has_tool_call(chunk):
yield chunk
else:
next_message = self._run_tool(chunk)
Expand Down
18 changes: 12 additions & 6 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

from typing import List, Optional, Union

from termcolor import cprint

from llama_stack_client.types import ToolResponseMessage

from termcolor import cprint


def interleaved_text_media_as_str(content: Union[str, List[str]], sep: str = " ") -> str:
def interleaved_text_media_as_str(
content: Union[str, List[str]], sep: str = " "
) -> str:
def _process(c) -> str:
if isinstance(c, str):
return c
Expand Down Expand Up @@ -49,14 +51,18 @@ def print(self, flush=True):

class EventLogger:
def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
if hasattr(chunk, "error"):
yield LogEvent(role=None, content=chunk.error["message"], color="red")
return
if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield LogEvent(role="CustomTool", content=chunk.content, color="green")
yield LogEvent(
role="CustomTool", content=chunk.content, color="green"
)
return

event = chunk.event
event_type = event.payload.event_type

Expand Down Expand Up @@ -153,4 +159,4 @@ def log(self, event_generator):
for chunk in event_generator:
for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type):
yield log_event
previous_event_type, previous_step_type = self._get_event_type_step_type(chunk)
previous_event_type, previous_step_type = self._get_event_type_step_type(chunk)