Skip to content

Commit aba9f5f

Browse files
authored
OutputParser -> ToolParser refactor (#130)
# What does this PR do? - See discussion in #121 (comment) ## Test Plan test with meta-llama/llama-stack-apps#166 ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior --inference-model "meta-llama/Llama-3.3-70B-Instruct" ``` <img width="1697" alt="image" src="https://github.com/user-attachments/assets/c036cbf6-9fc1-4064-82af-fa1984300653" /> ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
1 parent 7584610 commit aba9f5f

File tree

5 files changed

+85
-83
lines changed

5 files changed

+85
-83
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from llama_stack_client.types.agents.turn import Turn
1212
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
1313
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
14-
15-
14+
from llama_stack_client.types.shared.tool_call import ToolCall
15+
from llama_stack_client.types.agents.turn import CompletionMessage
1616
from .client_tool import ClientTool
17-
from .output_parser import OutputParser
17+
from .tool_parser import ToolParser
1818

1919
DEFAULT_MAX_ITER = 10
2020

@@ -25,14 +25,14 @@ def __init__(
2525
client: LlamaStackClient,
2626
agent_config: AgentConfig,
2727
client_tools: Tuple[ClientTool] = (),
28-
output_parser: Optional[OutputParser] = None,
28+
tool_parser: Optional[ToolParser] = None,
2929
):
3030
self.client = client
3131
self.agent_config = agent_config
3232
self.agent_id = self._create_agent(agent_config)
3333
self.client_tools = {t.get_name(): t for t in client_tools}
3434
self.sessions = []
35-
self.output_parser = output_parser
35+
self.tool_parser = tool_parser
3636
self.builtin_tools = {}
3737
for tg in agent_config["toolgroups"]:
3838
for tool in self.client.tools.list(toolgroup_id=tg):
@@ -54,33 +54,38 @@ def create_session(self, session_name: str) -> int:
5454
self.sessions.append(self.session_id)
5555
return self.session_id
5656

57-
def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
57+
def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
5858
if chunk.event.payload.event_type != "turn_complete":
59-
return
60-
message = chunk.event.payload.turn.output_message
61-
62-
if self.output_parser:
63-
self.output_parser.parse(message)
59+
return []
6460

65-
def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
66-
if chunk.event.payload.event_type != "turn_complete":
67-
return False
6861
message = chunk.event.payload.turn.output_message
6962
if message.stop_reason == "out_of_tokens":
70-
return False
63+
return []
7164

72-
return len(message.tool_calls) > 0
65+
if self.tool_parser:
66+
return self.tool_parser.get_tool_calls(message)
7367

74-
def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
75-
message = chunk.event.payload.turn.output_message
76-
tool_call = message.tool_calls[0]
68+
return message.tool_calls
69+
70+
def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
71+
assert len(tool_calls) == 1, "Only one tool call is supported"
72+
tool_call = tool_calls[0]
7773

7874
# custom client tools
7975
if tool_call.tool_name in self.client_tools:
8076
tool = self.client_tools[tool_call.tool_name]
8177
# NOTE: tool.run() expects a list of messages, we only pass in last message here
8278
# but we could pass in the entire message history
83-
result_message = tool.run([message])
79+
result_message = tool.run(
80+
[
81+
CompletionMessage(
82+
role="assistant",
83+
content=tool_call.tool_name,
84+
tool_calls=[tool_call],
85+
stop_reason="end_of_turn",
86+
)
87+
]
88+
)
8489
return result_message
8590

8691
# builtin tools executed by tool_runtime
@@ -149,14 +154,14 @@ def _create_turn_streaming(
149154
# by default, we stop after the first turn
150155
stop = True
151156
for chunk in response:
152-
self._process_chunk(chunk)
157+
tool_calls = self._get_tool_calls(chunk)
153158
if hasattr(chunk, "error"):
154159
yield chunk
155160
return
156-
elif not self._has_tool_call(chunk):
161+
elif not tool_calls:
157162
yield chunk
158163
else:
159-
next_message = self._run_tool(chunk)
164+
next_message = self._run_tool(tool_calls)
160165
yield next_message
161166

162167
# continue the turn when there's a tool call

src/llama_stack_client/lib/agents/output_parser.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/llama_stack_client/lib/agents/react/agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from pydantic import BaseModel
77
from typing import Dict, Any
88
from ..agent import Agent
9-
from .output_parser import ReActOutputParser
10-
from ..output_parser import OutputParser
9+
from .tool_parser import ReActToolParser
10+
from ..tool_parser import ToolParser
1111
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
1212

1313
from typing import Tuple, Optional
@@ -39,7 +39,7 @@ def __init__(
3939
model: str,
4040
builtin_toolgroups: Tuple[str] = (),
4141
client_tools: Tuple[ClientTool] = (),
42-
output_parser: OutputParser = ReActOutputParser(),
42+
tool_parser: ToolParser = ReActToolParser(),
4343
json_response_format: bool = False,
4444
custom_agent_config: Optional[AgentConfig] = None,
4545
):
@@ -101,5 +101,5 @@ def get_tool_defs():
101101
client=client,
102102
agent_config=agent_config,
103103
client_tools=client_tools,
104-
output_parser=output_parser,
104+
tool_parser=tool_parser,
105105
)

src/llama_stack_client/lib/agents/react/output_parser.py renamed to src/llama_stack_client/lib/agents/react/tool_parser.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
# the root directory of this source tree.
66

77
from pydantic import BaseModel, ValidationError
8-
from typing import Dict, Any, Optional
9-
from ..output_parser import OutputParser
8+
from typing import Dict, Any, Optional, List
9+
from ..tool_parser import ToolParser
1010
from llama_stack_client.types.shared.completion_message import CompletionMessage
1111
from llama_stack_client.types.shared.tool_call import ToolCall
1212

@@ -24,23 +24,24 @@ class ReActOutput(BaseModel):
2424
answer: Optional[str] = None
2525

2626

27-
class ReActOutputParser(OutputParser):
28-
def parse(self, output_message: CompletionMessage) -> None:
27+
class ReActToolParser(ToolParser):
28+
def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
29+
tool_calls = []
2930
response_text = str(output_message.content)
3031
try:
3132
react_output = ReActOutput.model_validate_json(response_text)
3233
except ValidationError as e:
3334
print(f"Error parsing action: {e}")
34-
return
35+
return tool_calls
3536

3637
if react_output.answer:
37-
return
38+
return tool_calls
3839

3940
if react_output.action:
4041
tool_name = react_output.action.tool_name
4142
tool_params = react_output.action.tool_params
4243
if tool_name and tool_params:
4344
call_id = str(uuid.uuid4())
44-
output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
45+
tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
4546

46-
return
47+
return tool_calls
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from abc import abstractmethod
8+
from typing import List
9+
10+
from llama_stack_client.types.agents.turn import CompletionMessage
11+
from llama_stack_client.types.shared.tool_call import ToolCall
12+
13+
14+
class ToolParser:
15+
"""
16+
Abstract base class for parsing agent responses into tool calls. Implement this class to customize how
17+
agent outputs are processed and transformed into executable tool calls.
18+
19+
To use this class:
20+
1. Create a subclass of ToolParser
21+
2. Implement the `get_tool_calls` method
22+
3. Pass your parser instance to the Agent's constructor
23+
24+
Example:
25+
class MyCustomParser(ToolParser):
26+
def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
27+
# Add your custom parsing logic here
28+
return extracted_tool_calls
29+
30+
Methods:
31+
get_tool_calls(output_message: CompletionMessage) -> List[ToolCall]:
32+
Abstract method that must be implemented by subclasses to process
33+
the agent's response and extract tool calls.
34+
35+
Args:
36+
output_message (CompletionMessage): The response message from the agent's turn
37+
38+
Returns:
39+
List[ToolCall]: A list of parsed tool calls; empty if no tools should be called
40+
"""
41+
42+
@abstractmethod
43+
def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
44+
raise NotImplementedError

0 commit comments

Comments
 (0)