Merged
2 changes: 1 addition & 1 deletion src/llama_stack_client/_base_client.py
@@ -518,7 +518,7 @@ def _build_request(
             # so that passing a `TypedDict` doesn't cause an error.
             # https://github.com/microsoft/pyright/issues/3526#event-6715453066
             params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
-            json=json_data,
+            json=json_data if is_given(json_data) else None,
             files=files,
             **kwargs,
         )
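
For context, the new guard matters because the generated client distinguishes "argument omitted" from an explicit None using a sentinel, and only is_given() values should reach the wire. A minimal sketch of that pattern, with illustrative names rather than the SDK's actual definitions:

    from typing import Any, Union

    class NotGiven:
        # Marks "argument was omitted" so it can be told apart from an explicit None.
        def __bool__(self) -> bool:
            return False

    NOT_GIVEN = NotGiven()

    def is_given(value: Union[Any, NotGiven]) -> bool:
        return not isinstance(value, NotGiven)

    def build_request(json_data: Union[Any, NotGiven] = NOT_GIVEN) -> dict:
        # Without the guard, the sentinel object itself would be passed to httpx
        # as the JSON body; mapping it to None lets httpx omit the body entirely.
        return {"json": json_data if is_given(json_data) else None}

    assert build_request() == {"json": None}
    assert build_request(json_data={"a": 1}) == {"json": {"a": 1}}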
16 changes: 8 additions & 8 deletions src/llama_stack_client/_client.py
@@ -126,16 +126,16 @@ def __init__(
     ) -> None:
         """Construct a new synchronous llama-stack-client client instance.

-        This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
+        This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
         """
         if api_key is None:
-            api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
+            api_key = os.environ.get("LLAMA_STACK_API_KEY")
         self.api_key = api_key

         if base_url is None:
-            base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
+            base_url = os.environ.get("LLAMA_STACK_BASE_URL")
         if base_url is None:
-            base_url = f"http://any-hosted-llama-stack.com"
+            base_url = "http://any-hosted-llama-stack.com"

         custom_headers = default_headers or {}
         custom_headers["X-LlamaStack-Client-Version"] = __version__
@@ -342,16 +342,16 @@ def __init__(
     ) -> None:
         """Construct a new async llama-stack-client client instance.

-        This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
+        This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
         """
         if api_key is None:
-            api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
+            api_key = os.environ.get("LLAMA_STACK_API_KEY")
         self.api_key = api_key

         if base_url is None:
-            base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
+            base_url = os.environ.get("LLAMA_STACK_BASE_URL")
         if base_url is None:
-            base_url = f"http://any-hosted-llama-stack.com"
+            base_url = "http://any-hosted-llama-stack.com"

         custom_headers = default_headers or {}
         custom_headers["X-LlamaStack-Client-Version"] = __version__
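
The renamed variables are consulted only when the matching constructor arguments are omitted. A hypothetical usage sketch (the base URL value here is illustrative, not a real endpoint):

    import os

    # Set before constructing the client; read only when api_key/base_url are omitted.
    os.environ["LLAMA_STACK_API_KEY"] = "my-api-key"
    os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5001"

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient()  # picks up both values from the environment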
4 changes: 2 additions & 2 deletions src/llama_stack_client/_files.py
@@ -71,7 +71,7 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_tuple_t(file):
         return (file[0], _read_file_content(file[1]), *file[2:])

-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError("Expected file types input to be a FileContent type or to be a tuple")


 def _read_file_content(file: FileContent) -> HttpxFileContent:
@@ -113,7 +113,7 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_tuple_t(file):
         return (file[0], await _async_read_file_content(file[1]), *file[2:])

-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError("Expected file types input to be a FileContent type or to be a tuple")


 async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
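
The removed f-string prefixes were no-ops (the messages contain no placeholders). As a reminder of what the tuple branch above accepts, a hedged sketch of the common shapes (these follow httpx's file-tuple convention; the exact FileTypes alias is defined in the SDK's types module):

    file_content = b"hello"                               # bare FileContent
    file_pair = ("notes.txt", b"hello")                   # (filename, content)
    file_triple = ("notes.txt", b"hello", "text/plain")   # (filename, content, content_type)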
8 changes: 4 additions & 4 deletions src/llama_stack_client/_response.py
@@ -229,7 +229,7 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
             # the response class ourselves but that is something that should be supported directly in httpx
             # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
             if cast_to != httpx.Response:
-                raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
+                raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_to`")
             return cast(R, response)

         if (
@@ -245,9 +245,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:

         if (
             cast_to is not object
-            and not origin is list
-            and not origin is dict
-            and not origin is Union
+            and origin is not list
+            and origin is not dict
+            and origin is not Union
             and not issubclass(origin, BaseModel)
         ):
             raise RuntimeError(
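
The operator swap is purely stylistic: `not origin is list` already parses as `not (origin is list)`, so behavior is unchanged and the PEP 8-preferred `is not` spelling simply reads better. A quick check of the equivalence:

    from typing import List, get_origin

    origin = get_origin(List[int])  # -> <class 'list'>
    assert (not origin is dict) == (origin is not dict)  # same result, clearer spelling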
2 changes: 1 addition & 1 deletion src/llama_stack_client/_utils/_logs.py
@@ -14,7 +14,7 @@ def _basic_config() -> None:


 def setup_logging() -> None:
-    env = os.environ.get("LLAMA_STACK_CLIENT_LOG")
+    env = os.environ.get("LLAMA_STACK_LOG")
     if env == "debug":
         _basic_config()
         logger.setLevel(logging.DEBUG)
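
With this rename, debug logging is controlled by LLAMA_STACK_LOG. A hypothetical sketch (this assumes the module logger is named after the package, and imports the private _logs module directly just for demonstration):

    import logging
    import os

    os.environ["LLAMA_STACK_LOG"] = "debug"  # formerly LLAMA_STACK_CLIENT_LOG

    from llama_stack_client._utils._logs import setup_logging

    setup_logging()
    logging.getLogger("llama_stack_client").debug("debug logging is now enabled")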
119 changes: 44 additions & 75 deletions src/llama_stack_client/lib/agents/agent.py
@@ -3,8 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import uuid
-from datetime import datetime
 from typing import Iterator, List, Optional, Tuple, Union

 from llama_stack_client import LlamaStackClient
@@ -16,13 +14,7 @@
 from llama_stack_client.types.agents.turn_create_response import (
     AgentTurnResponseStreamChunk,
 )
-from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
-from llama_stack_client.types.agents.turn_response_event_payload import (
-    AgentTurnResponseStepCompletePayload,
-)
 from llama_stack_client.types.shared.tool_call import ToolCall
-from llama_stack_client.types.tool_execution_step import ToolExecutionStep
-from llama_stack_client.types.tool_response import ToolResponse

 from .client_tool import ClientTool
 from .tool_parser import ToolParser
@@ -66,7 +58,7 @@ def create_session(self, session_name: str) -> str:
         return self.session_id

     def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
-        if chunk.event.payload.event_type != "turn_complete":
+        if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
             return []

         message = chunk.event.payload.turn.output_message
@@ -78,6 +70,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:

         return message.tool_calls

+    def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
+        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
+            return None
+
+        return chunk.event.payload.turn.turn_id
+
     def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
         assert len(tool_calls) == 1, "Only one tool call is supported"
         tool_call = tool_calls[0]
@@ -132,27 +130,10 @@ def create_turn(
         if stream:
             return self._create_turn_streaming(messages, session_id, toolgroups, documents)
         else:
-            chunks = []
-            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
-                if chunk.event.payload.event_type == "turn_complete":
-                    chunks.append(chunk)
-                pass
+            chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)]
             if not chunks:
                 raise Exception("Turn did not complete")

-            # merge chunks
-            return Turn(
-                input_messages=chunks[0].event.payload.turn.input_messages,
-                output_message=chunks[-1].event.payload.turn.output_message,
-                session_id=chunks[0].event.payload.turn.session_id,
-                steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
-                turn_id=chunks[0].event.payload.turn.turn_id,
-                started_at=chunks[0].event.payload.turn.started_at,
-                completed_at=chunks[-1].event.payload.turn.completed_at,
-                output_attachments=[
-                    attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
-                ],
-            )
+            return chunks[-1].event.payload.turn

     def _create_turn_streaming(
         self,
@@ -161,62 +142,50 @@ def _create_turn_streaming(
         toolgroups: Optional[List[Toolgroup]] = None,
         documents: Optional[List[Document]] = None,
     ) -> Iterator[AgentTurnResponseStreamChunk]:
-        stop = False
         n_iter = 0
         max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
-        while not stop and n_iter < max_iter:
-            response = self.client.agents.turn.create(
-                agent_id=self.agent_id,
-                # use specified session_id or last session created
-                session_id=session_id or self.session_id[-1],
-                messages=messages,
-                stream=True,
-                documents=documents,
-                toolgroups=toolgroups,
-            )
-            # by default, we stop after the first turn
-            stop = True
-            for chunk in response:
+
+        # 1. create an agent turn
+        turn_response = self.client.agents.turn.create(
+            agent_id=self.agent_id,
+            # use specified session_id or last session created
+            session_id=session_id or self.session_id[-1],
+            messages=messages,
+            stream=True,
+            documents=documents,
+            toolgroups=toolgroups,
+            allow_turn_resume=True,
+        )
+
+        # 2. process turn and resume if there's a tool call
+        is_turn_complete = False
+        while not is_turn_complete:
+            is_turn_complete = True
+            for chunk in turn_response:
                 tool_calls = self._get_tool_calls(chunk)
                 if hasattr(chunk, "error"):
                     yield chunk
                     return
                 elif not tool_calls:
                     yield chunk
                 else:
-                    tool_execution_start_time = datetime.now()
+                    is_turn_complete = False
+                    turn_id = self._get_turn_id(chunk)
+                    if n_iter == 0:
+                        yield chunk

                     # run the tools
                     tool_response_message = self._run_tool(tool_calls)
-                    tool_execution_step = ToolExecutionStep(
-                        step_type="tool_execution",
-                        step_id=str(uuid.uuid4()),
-                        tool_calls=tool_calls,
-                        tool_responses=[
-                            ToolResponse(
-                                tool_name=tool_response_message.tool_name,
-                                content=tool_response_message.content,
-                                call_id=tool_response_message.call_id,
-                            )
-                        ],
-                        turn_id=chunk.event.payload.turn.turn_id,
-                        completed_at=datetime.now(),
-                        started_at=tool_execution_start_time,
-                    )
-                    yield AgentTurnResponseStreamChunk(
-                        event=TurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                event_type="step_complete",
-                                step_id=tool_execution_step.step_id,
-                                step_type="tool_execution",
-                                step_details=tool_execution_step,
-                            )
-                        )
-                    )
+                    # pass it to next iteration
+                    turn_response = self.client.agents.turn.resume(
+                        agent_id=self.agent_id,
+                        session_id=session_id or self.session_id[-1],
+                        turn_id=turn_id,
+                        tool_responses=[tool_response_message],
+                        stream=True,
+                    )

-                    # HACK: append the tool execution step to the turn
-                    chunk.event.payload.turn.steps.append(tool_execution_step)
-                    yield chunk
-
-                    # continue the turn when there's a tool call
-                    stop = False
-                    messages = [tool_response_message]
                     n_iter += 1
                     break

+        if n_iter >= max_iter:
+            raise Exception(f"Turn did not complete in {max_iter} iterations")
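
Taken together, the streaming loop now drives the server-side resume API instead of fabricating step_complete chunks on the client: a turn that triggers a client tool pauses as turn_awaiting_input, the client runs the tool, and turn.resume continues the same turn_id, so the server records the tool-execution step itself. A hypothetical end-to-end sketch (the model id, port, and exact Agent constructor shape are assumptions, not part of this diff):

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent

    client = LlamaStackClient(base_url="http://localhost:5001")  # illustrative
    agent = Agent(
        client,
        agent_config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",  # illustrative
            "instructions": "You are a helpful assistant.",
            "max_infer_iters": 5,
        },
    )
    session_id = agent.create_session("demo-session")

    # With stream=False, create_turn now returns the final Turn taken from the
    # last streamed chunk, including tool-execution steps recorded server-side.
    turn = agent.create_turn(
        messages=[{"role": "user", "content": "Say hello."}],
        session_id=session_id,
        stream=False,
    )
    print(turn.output_message.content)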
6 changes: 4 additions & 2 deletions src/llama_stack_client/lib/agents/event_logger.py
@@ -75,7 +75,7 @@ def _yield_printable_events(
         event = chunk.event
         event_type = event.payload.event_type

-        if event_type in {"turn_start", "turn_complete"}:
+        if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}:
             # Currently not logging any turn related info
             yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
             return
@@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional[str]]:
         if hasattr(chunk, "event"):
             previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
             previous_step_type = (
-                chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
+                chunk.event.payload.step_type
+                if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"}
+                else None
             )
             return previous_event_type, previous_step_type
         return None, None