diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index dcf38426..c40ef4c8 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -11,10 +11,10 @@
 from llama_stack_client.types.agents.turn import Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
-
-
+from llama_stack_client.types.shared.tool_call import ToolCall
+from llama_stack_client.types.agents.turn import CompletionMessage
 from .client_tool import ClientTool
-from .output_parser import OutputParser
+from .tool_parser import ToolParser
 
 DEFAULT_MAX_ITER = 10
 
@@ -25,14 +25,14 @@ def __init__(
         client: LlamaStackClient,
         agent_config: AgentConfig,
         client_tools: Tuple[ClientTool] = (),
-        output_parser: Optional[OutputParser] = None,
+        tool_parser: Optional[ToolParser] = None,
     ):
         self.client = client
         self.agent_config = agent_config
         self.agent_id = self._create_agent(agent_config)
         self.client_tools = {t.get_name(): t for t in client_tools}
         self.sessions = []
-        self.output_parser = output_parser
+        self.tool_parser = tool_parser
         self.builtin_tools = {}
         for tg in agent_config["toolgroups"]:
             for tool in self.client.tools.list(toolgroup_id=tg):
@@ -54,33 +54,38 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
-    def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
+    def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
         if chunk.event.payload.event_type != "turn_complete":
-            return
-        message = chunk.event.payload.turn.output_message
-
-        if self.output_parser:
-            self.output_parser.parse(message)
+            return []
 
-    def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
-        if chunk.event.payload.event_type != "turn_complete":
-            return False
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
-            return False
+            return []
 
-        return len(message.tool_calls) > 0
+        if self.tool_parser:
+            return self.tool_parser.get_tool_calls(message)
 
-    def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
-        message = chunk.event.payload.turn.output_message
-        tool_call = message.tool_calls[0]
+        return message.tool_calls
+
+    def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
+        assert len(tool_calls) == 1, "Only one tool call is supported"
+        tool_call = tool_calls[0]
 
         # custom client tools
         if tool_call.tool_name in self.client_tools:
             tool = self.client_tools[tool_call.tool_name]
             # NOTE: tool.run() expects a list of messages, we only pass in last message here
             # but we could pass in the entire message history
-            result_message = tool.run([message])
+            result_message = tool.run(
+                [
+                    CompletionMessage(
+                        role="assistant",
+                        content=tool_call.tool_name,
+                        tool_calls=[tool_call],
+                        stop_reason="end_of_turn",
+                    )
+                ]
+            )
             return result_message
 
         # builtin tools executed by tool_runtime
@@ -149,14 +154,14 @@ def _create_turn_streaming(
             # by default, we stop after the first turn
             stop = True
             for chunk in response:
-                self._process_chunk(chunk)
+                tool_calls = self._get_tool_calls(chunk)
                 if hasattr(chunk, "error"):
                     yield chunk
                     return
-                elif not self._has_tool_call(chunk):
+                elif not tool_calls:
                     yield chunk
                 else:
-                    next_message = self._run_tool(chunk)
+                    next_message = self._run_tool(tool_calls)
                     yield next_message
 
                     # continue the turn when there's a tool call
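
With this change, a parser only returns `ToolCall` objects; the agent loop in `_create_turn_streaming` decides whether to execute them or yield the chunk. A minimal sketch of a custom parser built on the new hook — the `JsonToolParser` name and its `{"tool": ..., "args": ...}` JSON shape are illustrative assumptions, not a format this library defines:

```python
import json
import uuid
from typing import List

from llama_stack_client.lib.agents.tool_parser import ToolParser
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall


class JsonToolParser(ToolParser):
    """Sketch: extract a single tool call from a JSON object in the model output.

    Assumes the model emits {"tool": ..., "args": {...}}; this shape is an
    example, not something the library prescribes.
    """

    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
        try:
            payload = json.loads(str(output_message.content))
        except json.JSONDecodeError:
            return []  # no parseable tool call; the agent yields the chunk as-is
        if not isinstance(payload, dict) or "tool" not in payload:
            return []
        return [
            ToolCall(
                call_id=str(uuid.uuid4()),
                tool_name=payload["tool"],
                arguments=payload.get("args", {}),
            )
        ]
```

An instance would be passed as `Agent(client, agent_config, tool_parser=JsonToolParser())`; when `tool_parser` is omitted, `_get_tool_calls` falls back to the model-native `message.tool_calls`.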
diff --git a/src/llama_stack_client/lib/agents/output_parser.py b/src/llama_stack_client/lib/agents/output_parser.py
deleted file mode 100644
index 20c8468e..00000000
--- a/src/llama_stack_client/lib/agents/output_parser.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from abc import abstractmethod
-
-from llama_stack_client.types.agents.turn import CompletionMessage
-
-
-class OutputParser:
-    """
-    Abstract base class for parsing agent responses. Implement this class to customize how
-    agent outputs are processed and transformed.
-
-    This class allows developers to define custom parsing logic for agent responses,
-    which can be useful for:
-    - Extracting specific information from the response
-    - Formatting or structuring the output in a specific way
-    - Validating or sanitizing the agent's response
-
-    To use this class:
-    1. Create a subclass of OutputParser
-    2. Implement the `parse` method
-    3. Pass your parser instance to the Agent's constructor
-
-    Example:
-        class MyCustomParser(OutputParser):
-            def parse(self, output_message: CompletionMessage) -> CompletionMessage:
-                # Add your custom parsing logic here
-                return processed_message
-
-    Methods:
-        parse(output_message: CompletionMessage) -> CompletionMessage:
-            Abstract method that must be implemented by subclasses to process
-            the agent's response.
-
-            Args:
-                output_message (CompletionMessage): The response message from agent turn
-
-            Returns: None
-                Modifies the output_message in place
-    """
-
-    @abstractmethod
-    def parse(self, output_message: CompletionMessage) -> None:
-        raise NotImplementedError
diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py
index 3d40a08b..622d4420 100644
--- a/src/llama_stack_client/lib/agents/react/agent.py
+++ b/src/llama_stack_client/lib/agents/react/agent.py
@@ -6,8 +6,8 @@
 from pydantic import BaseModel
 from typing import Dict, Any
 from ..agent import Agent
-from .output_parser import ReActOutputParser
-from ..output_parser import OutputParser
+from .tool_parser import ReActToolParser
+from ..tool_parser import ToolParser
 from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
 from typing import Tuple, Optional
 
@@ -39,7 +39,7 @@ def __init__(
         model: str,
         builtin_toolgroups: Tuple[str] = (),
         client_tools: Tuple[ClientTool] = (),
-        output_parser: OutputParser = ReActOutputParser(),
+        tool_parser: ToolParser = ReActToolParser(),
         json_response_format: bool = False,
         custom_agent_config: Optional[AgentConfig] = None,
     ):
@@ -101,5 +101,5 @@ def get_tool_defs():
             client=client,
             agent_config=agent_config,
             client_tools=client_tools,
-            output_parser=output_parser,
+            tool_parser=tool_parser,
         )
diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/tool_parser.py
similarity index 72%
rename from src/llama_stack_client/lib/agents/react/output_parser.py
rename to src/llama_stack_client/lib/agents/react/tool_parser.py
index 71177a6f..e668d28d 100644
--- a/src/llama_stack_client/lib/agents/react/output_parser.py
+++ b/src/llama_stack_client/lib/agents/react/tool_parser.py
@@ -5,8 +5,8 @@
 # the root directory of this source tree.
 
 from pydantic import BaseModel, ValidationError
-from typing import Dict, Any, Optional
-from ..output_parser import OutputParser
+from typing import Dict, Any, Optional, List
+from ..tool_parser import ToolParser
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared.tool_call import ToolCall
 
@@ -24,23 +24,24 @@ class ReActOutput(BaseModel):
     answer: Optional[str] = None
 
 
-class ReActOutputParser(OutputParser):
-    def parse(self, output_message: CompletionMessage) -> None:
+class ReActToolParser(ToolParser):
+    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+        tool_calls = []
         response_text = str(output_message.content)
         try:
             react_output = ReActOutput.model_validate_json(response_text)
         except ValidationError as e:
             print(f"Error parsing action: {e}")
-            return
+            return tool_calls
 
         if react_output.answer:
-            return
+            return tool_calls
 
         if react_output.action:
             tool_name = react_output.action.tool_name
             tool_params = react_output.action.tool_params
             if tool_name and tool_params:
                 call_id = str(uuid.uuid4())
-                output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
+                tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
 
-        return
+        return tool_calls
diff --git a/src/llama_stack_client/lib/agents/tool_parser.py b/src/llama_stack_client/lib/agents/tool_parser.py
new file mode 100644
index 00000000..dc0c5ba4
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/tool_parser.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import abstractmethod
+from typing import List
+
+from llama_stack_client.types.agents.turn import CompletionMessage
+from llama_stack_client.types.shared.tool_call import ToolCall
+
+
+class ToolParser:
+    """
+    Abstract base class for parsing agent responses into tool calls. Implement this class to customize how
+    agent outputs are processed and transformed into executable tool calls.
+
+    To use this class:
+    1. Create a subclass of ToolParser
+    2. Implement the `get_tool_calls` method
+    3. Pass your parser instance to the Agent's constructor
+
+    Example:
+        class MyCustomParser(ToolParser):
+            def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+                # Add your custom parsing logic here
+                return extracted_tool_calls
+
+    Methods:
+        get_tool_calls(output_message: CompletionMessage) -> List[ToolCall]:
+            Abstract method that must be implemented by subclasses to process
+            the agent's response and extract tool calls.
+
+            Args:
+                output_message (CompletionMessage): The response message from the agent turn
+
+            Returns:
+                List[ToolCall]: The parsed tool calls; an empty list if no tools should be called
+    """
+
+    @abstractmethod
+    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+        raise NotImplementedError
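
To illustrate the renamed parser's new contract end to end, a hedged sketch of feeding a ReAct-formatted message through `ReActToolParser`. The exact `ReActOutput` fields beyond `action`/`answer` (the `thought` key here) are assumed from this file's schema, and the tool name and params are placeholders:

```python
from llama_stack_client.lib.agents.react.tool_parser import ReActToolParser
from llama_stack_client.types.shared.completion_message import CompletionMessage

# A ReAct-style response carrying an action; field names follow the
# ReActOutput model above ("thought" is assumed to be part of it).
message = CompletionMessage(
    role="assistant",
    content='{"thought": "I should search", '
    '"action": {"tool_name": "web_search", "tool_params": {"query": "llama"}}, '
    '"answer": null}',
    tool_calls=[],
    stop_reason="end_of_turn",
)

calls = ReActToolParser().get_tool_calls(message)
# Expected: one ToolCall with tool_name="web_search" and
# arguments={"query": "llama"}; an empty list if validation fails
# or the model produced a final answer instead.
print([c.tool_name for c in calls])
```

Returning the calls instead of mutating `output_message.tool_calls` in place is what lets `_get_tool_calls` in agent.py treat parser output and model-native tool calls uniformly.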