Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import (
AgentTurnResponseStreamChunk,
)
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import AgentTurnResponseStreamChunk
from llama_stack_client.types.shared.tool_call import ToolCall
from llama_stack_client.types.shared_params.agent_config import ToolConfig
from llama_stack_client.types.shared_params.response_format import ResponseFormat
from llama_stack_client.types.shared_params.sampling_params import SamplingParams

from .client_tool import client_tool, ClientTool
from ..._types import Headers
from .client_tool import ClientTool, client_tool
from .tool_parser import ToolParser

DEFAULT_MAX_ITER = 10
Expand All @@ -27,7 +30,9 @@

class AgentUtils:
@staticmethod
def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]]) -> List[ClientTool]:
def get_client_tools(
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]],
) -> List[ClientTool]:
if not tools:
return []

Expand All @@ -37,7 +42,10 @@ def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[

@staticmethod
def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]:
if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
if chunk.event.payload.event_type not in {
"turn_complete",
"turn_awaiting_input",
}:
return []

message = chunk.event.payload.turn.output_message
Expand All @@ -51,7 +59,10 @@ def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[To

@staticmethod
def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
if chunk.event.payload.event_type not in [
"turn_complete",
"turn_awaiting_input",
]:
return None

return chunk.event.payload.turn.turn_id
Expand Down Expand Up @@ -228,7 +239,10 @@ def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
if tool_call.tool_name in self.builtin_tools:
tool_result = self.client.tool_runtime.invoke_tool(
tool_name=tool_call.tool_name,
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
kwargs={
**tool_call.arguments,
**self.builtin_tools[tool_call.tool_name],
},
)
return ToolResponseParam(
call_id=tool_call.call_id,
Expand All @@ -250,11 +264,21 @@ def create_turn(
toolgroups: Optional[List[Toolgroup]] = None,
documents: Optional[List[Document]] = None,
stream: bool = True,
extra_headers: Headers | None = None,
) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
if stream:
return self._create_turn_streaming(messages, session_id, toolgroups, documents)
return self._create_turn_streaming(messages, session_id, toolgroups, documents, extra_headers=extra_headers)
else:
chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)]
chunks = [
x
for x in self._create_turn_streaming(
messages,
session_id,
toolgroups,
documents,
extra_headers=extra_headers,
)
]
if not chunks:
raise Exception("Turn did not complete")

Expand All @@ -276,6 +300,7 @@ def _create_turn_streaming(
session_id: Optional[str] = None,
toolgroups: Optional[List[Toolgroup]] = None,
documents: Optional[List[Document]] = None,
extra_headers: Headers | None = None,
) -> Iterator[AgentTurnResponseStreamChunk]:
n_iter = 0

Expand All @@ -288,6 +313,7 @@ def _create_turn_streaming(
stream=True,
documents=documents,
toolgroups=toolgroups,
extra_headers=extra_headers,
)

# 2. process turn and resume if there's a tool call
Expand Down Expand Up @@ -324,6 +350,7 @@ def _create_turn_streaming(
turn_id=turn_id,
tool_responses=tool_responses,
stream=True,
extra_headers=extra_headers,
)
n_iter += 1

Expand Down Expand Up @@ -478,7 +505,10 @@ async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
if tool_call.tool_name in self.builtin_tools:
tool_result = await self.client.tool_runtime.invoke_tool(
tool_name=tool_call.tool_name,
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
kwargs={
**tool_call.arguments,
**self.builtin_tools[tool_call.tool_name],
},
)
return ToolResponseParam(
call_id=tool_call.call_id,
Expand Down
Loading