Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,118 @@
from typing import Iterator, List, Optional, Tuple, Union

from llama_stack_client import LlamaStackClient
import logging

from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
from llama_stack_client.types.shared.tool_call import ToolCall
from llama_stack_client.types.shared_params.response_format import ResponseFormat
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
from llama_stack_client.types.shared_params.agent_config import ToolConfig

from .client_tool import ClientTool
from .tool_parser import ToolParser

DEFAULT_MAX_ITER = 10

logger = logging.getLogger(__name__)


class Agent:
def __init__(
self,
client: LlamaStackClient,
agent_config: AgentConfig,
client_tools: Tuple[ClientTool] = (),
# begin deprecated
agent_config: Optional[AgentConfig] = None,
client_tools: Tuple[ClientTool, ...] = (),
# end deprecated
tool_parser: Optional[ToolParser] = None,
model: Optional[str] = None,
instructions: Optional[str] = None,
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
max_infer_iters: Optional[int] = None,
input_shields: Optional[List[str]] = None,
output_shields: Optional[List[str]] = None,
response_format: Optional[ResponseFormat] = None,
enable_session_persistence: Optional[bool] = None,
):
"""Construct an Agent with the given parameters.

:param client: The LlamaStackClient instance.
:param agent_config: The AgentConfig instance.
::deprecated: use other parameters instead
:param client_tools: A tuple of ClientTool instances.
::deprecated: use tools instead
:param tool_parser: Custom logic that parses tool calls from a message.
:param model: The model to use for the agent.
:param instructions: The instructions for the agent.
:param tools: A list of tools for the agent. Values can be one of the following:
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
- a python function decorated with @client_tool
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
- an instance of ClientTool: A client tool object.
:param tool_config: The tool configuration for the agent.
:param sampling_params: The sampling parameters for the agent.
:param max_infer_iters: The maximum number of inference iterations.
:param input_shields: The input shields for the agent.
:param output_shields: The output shields for the agent.
:param response_format: The response format for the agent.
:param enable_session_persistence: Whether to enable session persistence.
"""
self.client = client

if agent_config is not None:
logger.warning("`agent_config` is deprecated. Use inlined parameters instead.")
if client_tools != ():
logger.warning("`client_tools` is deprecated. Use `tools` instead.")

# Construct agent_config from parameters if not provided
if agent_config is None:
# Create a minimal valid AgentConfig with required fields
if model is None or instructions is None:
raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")

agent_config = {
"model": model,
"instructions": instructions,
}

# Add optional parameters if provided
if enable_session_persistence is not None:
agent_config["enable_session_persistence"] = enable_session_persistence
if input_shields is not None:
agent_config["input_shields"] = input_shields
if max_infer_iters is not None:
agent_config["max_infer_iters"] = max_infer_iters
if output_shields is not None:
agent_config["output_shields"] = output_shields
if response_format is not None:
agent_config["response_format"] = response_format
if sampling_params is not None:
agent_config["sampling_params"] = sampling_params
if tool_config is not None:
agent_config["tool_config"] = tool_config
if tools is not None:
toolgroups: List[Toolgroup] = []
client_tools: List[ClientTool] = []

for tool in tools:
if isinstance(tool, str) or isinstance(tool, dict):
toolgroups.append(tool)
else:
client_tools.append(tool)

agent_config["toolgroups"] = toolgroups
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]

agent_config = AgentConfig(**agent_config)

self.agent_config = agent_config
self.agent_id = self._create_agent(agent_config)
self.client_tools = {t.get_name(): t for t in client_tools}
Expand Down