12 changes: 6 additions & 6 deletions src/llama_stack_client/_client.py
@@ -126,14 +126,14 @@ def __init__(
     ) -> None:
         """Construct a new synchronous llama-stack-client client instance.
 
-        This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
+        This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
         """
         if api_key is None:
-            api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
+            api_key = os.environ.get("LLAMA_STACK_API_KEY")
         self.api_key = api_key
 
         if base_url is None:
-            base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
+            base_url = os.environ.get("LLAMA_STACK_BASE_URL")
         if base_url is None:
             base_url = f"http://any-hosted-llama-stack.com"
 
@@ -342,14 +342,14 @@ def __init__(
     ) -> None:
         """Construct a new async llama-stack-client client instance.
 
-        This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
+        This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
         """
         if api_key is None:
-            api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
+            api_key = os.environ.get("LLAMA_STACK_API_KEY")
         self.api_key = api_key
 
         if base_url is None:
-            base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
+            base_url = os.environ.get("LLAMA_STACK_BASE_URL")
         if base_url is None:
             base_url = f"http://any-hosted-llama-stack.com"
 
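For reference, a minimal sketch of how the renamed variables are picked up (the key and URL values below are placeholders, not part of this diff):

```python
import os

from llama_stack_client import LlamaStackClient

# The client now reads LLAMA_STACK_API_KEY / LLAMA_STACK_BASE_URL instead of
# the old LLAMA_STACK_CLIENT_* names; placeholder values for illustration.
os.environ["LLAMA_STACK_API_KEY"] = "my-api-key"
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:8321"

client = LlamaStackClient()  # api_key and base_url are inferred from the env
```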
2 changes: 1 addition & 1 deletion src/llama_stack_client/_utils/_logs.py
@@ -14,7 +14,7 @@ def _basic_config() -> None:
 
 
 def setup_logging() -> None:
-    env = os.environ.get("LLAMA_STACK_CLIENT_LOG")
+    env = os.environ.get("LLAMA_STACK_LOG")
     if env == "debug":
         _basic_config()
         logger.setLevel(logging.DEBUG)
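Likewise, debug logging is now switched on via `LLAMA_STACK_LOG=debug`. A small sketch, assuming `setup_logging` runs at package import time, so the variable must be set first:

```python
import os

# Set before importing the package so setup_logging sees it
# (an assumption about import-time configuration).
os.environ["LLAMA_STACK_LOG"] = "debug"

import llama_stack_client  # noqa: E402  -- deliberate import after the env var
```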
100 changes: 47 additions & 53 deletions src/llama_stack_client/lib/agents/agent.py
@@ -13,18 +13,10 @@
 from llama_stack_client.types.agents.turn_create_response import (
     AgentTurnResponseStreamChunk,
 )
-from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
-from llama_stack_client.types.agents.turn_response_event_payload import (
-    AgentTurnResponseStepCompletePayload,
-)
 from llama_stack_client.types.shared.tool_call import ToolCall
 from llama_stack_client.types.agents.turn import CompletionMessage
 from .client_tool import ClientTool
 from .tool_parser import ToolParser
-from datetime import datetime
-import uuid
-from llama_stack_client.types.tool_execution_step import ToolExecutionStep
-from llama_stack_client.types.tool_response import ToolResponse
 
 DEFAULT_MAX_ITER = 10
 
@@ -65,7 +57,7 @@ def create_session(self, session_name: str) -> int:
         return self.session_id
 
     def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
-        if chunk.event.payload.event_type != "turn_complete":
+        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
             return []
 
         message = chunk.event.payload.turn.output_message
@@ -77,6 +69,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
 
         return message.tool_calls
 
+    def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
+        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
+            return None
+
+        return chunk.event.payload.turn.turn_id
+
     def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
         assert len(tool_calls) == 1, "Only one tool call is supported"
         tool_call = tool_calls[0]
@@ -163,59 +161,55 @@ def _create_turn_streaming(
-        stop = False
         n_iter = 0
         max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
-        while not stop and n_iter < max_iter:
-            response = self.client.agents.turn.create(
+
+        # 1. create an agent turn
+        turn_response = self.client.agents.turn.create(
+            agent_id=self.agent_id,
+            # use specified session_id or last session created
+            session_id=session_id or self.session_id[-1],
+            messages=messages,
+            stream=True,
+            documents=documents,
+            toolgroups=toolgroups,
+        )
+        is_turn_complete = True
+        turn_id = None
+        for chunk in turn_response:
+            tool_calls = self._get_tool_calls(chunk)
+            if hasattr(chunk, "error"):
+                yield chunk
+                return
+            elif not tool_calls:
+                yield chunk
+            else:
+                is_turn_complete = False
+                turn_id = self._get_turn_id(chunk)
+                yield chunk
+                break
+
+        # 2. while the turn is not complete, continue the turn
+        while not is_turn_complete and n_iter < max_iter:
+            is_turn_complete = True
+            assert turn_id is not None, "turn_id is None"
+
+            # run the tools
+            tool_response_message = self._run_tool(tool_calls)
+
+            continue_response = self.client.agents.turn.continue_(
                 agent_id=self.agent_id,
-                # use specified session_id or last session created
                 session_id=session_id or self.session_id[-1],
-                messages=messages,
+                turn_id=turn_id,
+                tool_responses=[tool_response_message],
                 stream=True,
-                documents=documents,
-                toolgroups=toolgroups,
             )
-            # by default, we stop after the first turn
-            stop = True
-            for chunk in response:
+            for chunk in continue_response:
                 tool_calls = self._get_tool_calls(chunk)
                 if hasattr(chunk, "error"):
                     yield chunk
                     return
                 elif not tool_calls:
                     yield chunk
                 else:
-                    tool_execution_start_time = datetime.now()
-                    tool_response_message = self._run_tool(tool_calls)
-                    tool_execution_step = ToolExecutionStep(
-                        step_type="tool_execution",
-                        step_id=str(uuid.uuid4()),
-                        tool_calls=tool_calls,
-                        tool_responses=[
-                            ToolResponse(
-                                tool_name=tool_response_message.tool_name,
-                                content=tool_response_message.content,
-                                call_id=tool_response_message.call_id,
-                            )
-                        ],
-                        turn_id=chunk.event.payload.turn.turn_id,
-                        completed_at=datetime.now(),
-                        started_at=tool_execution_start_time,
-                    )
-                    yield AgentTurnResponseStreamChunk(
-                        event=TurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                event_type="step_complete",
-                                step_id=tool_execution_step.step_id,
-                                step_type="tool_execution",
-                                step_details=tool_execution_step,
-                            )
-                        )
-                    )
-
-                    # HACK: append the tool execution step to the turn
-                    chunk.event.payload.turn.steps.append(tool_execution_step)
                     yield chunk
-
-                    # continue the turn when there's a tool call
-                    stop = False
-                    messages = [tool_response_message]
+                    is_turn_complete = False
+                    turn_id = self._get_turn_id(chunk)
             n_iter += 1
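For context, this is roughly how a caller exercises the new create-then-continue flow; `create_turn` hides the `turn_awaiting_input` boundary from the consumer. A sketch with illustrative values (the model id, instructions, and server URL are placeholders, not part of this diff):

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:8321")

# max_infer_iters caps the continuation loop in _create_turn_streaming above.
agent = Agent(
    client,
    agent_config={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "instructions": "You are a helpful assistant.",
        "max_infer_iters": 5,
    },
)
session_id = agent.create_session("demo-session")

# Chunks from the initial turn.create stream and any follow-up
# turn.continue_ streams are yielded through one generator.
for chunk in agent.create_turn(
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    session_id=session_id,
):
    print(chunk)
```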
6 changes: 4 additions & 2 deletions src/llama_stack_client/lib/agents/event_logger.py
@@ -75,7 +75,7 @@ def _yield_printable_events(
         event = chunk.event
         event_type = event.payload.event_type
 
-        if event_type in {"turn_start", "turn_complete"}:
+        if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}:
             # Currently not logging any turn realted info
             yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
             return
@@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional[str]]:
         if hasattr(chunk, "event"):
             previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
             previous_step_type = (
-                chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
+                chunk.event.payload.step_type
+                if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"}
+                else None
             )
             return previous_event_type, previous_step_type
         return None, None
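These printable events are typically consumed through `EventLogger`, e.g. (a sketch; `response` is assumed to be the chunk iterator returned by `agent.create_turn(...)`):

```python
from llama_stack_client.lib.agents.event_logger import EventLogger

# turn_awaiting_input chunks are now passed through silently, like
# turn_start and turn_complete.
for log in EventLogger().log(response):
    log.print()
```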