Skip to content

Commit aaff961

Browse files
authored
fix: react agent with custom tool parser n_iters (#184)
# What does this PR do? - A custom tool_parser on the client side does not work correctly with the latest change unifying max_infer_iters. - This is because we output an `end_of_message` stop reason. - A temporary fix is to keep tracking n_iters when a custom tool_parser is set on the client; this will no longer be needed once we move ReAct to the server side. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` python -m examples.agents.react_agent localhost 8321 ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant)
1 parent c2f73b1 commit aaff961

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6+
import logging
67
from typing import Iterator, List, Optional, Tuple, Union
78

89
from llama_stack_client import LlamaStackClient
9-
import logging
1010

1111
from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
1212
from llama_stack_client.types.agent_create_params import AgentConfig
1313
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
1414
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
1515
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
1616
from llama_stack_client.types.shared.tool_call import ToolCall
17+
from llama_stack_client.types.shared_params.agent_config import ToolConfig
1718
from llama_stack_client.types.shared_params.response_format import ResponseFormat
1819
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
19-
from llama_stack_client.types.shared_params.agent_config import ToolConfig
2020

2121
from .client_tool import ClientTool
2222
from .tool_parser import ToolParser
@@ -91,10 +91,10 @@ def __init__(
9191
# Add optional parameters if provided
9292
if enable_session_persistence is not None:
9393
agent_config["enable_session_persistence"] = enable_session_persistence
94-
if input_shields is not None:
95-
agent_config["input_shields"] = input_shields
9694
if max_infer_iters is not None:
9795
agent_config["max_infer_iters"] = max_infer_iters
96+
if input_shields is not None:
97+
agent_config["input_shields"] = input_shields
9898
if output_shields is not None:
9999
agent_config["output_shields"] = output_shields
100100
if response_format is not None:
@@ -254,7 +254,9 @@ def _create_turn_streaming(
254254
else:
255255
is_turn_complete = False
256256
# End of turn is reached, do not resume even if there's a tool call
257-
if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
257+
# We only check for this if tool_parser is not set, because otherwise
258+
# tool call will be parsed on client side, and server will always return "end_of_turn"
259+
if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
258260
yield chunk
259261
break
260262

@@ -274,3 +276,6 @@ def _create_turn_streaming(
274276
stream=True,
275277
)
276278
n_iter += 1
279+
280+
if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER):
281+
raise Exception("Max inference iterations reached")

src/llama_stack_client/lib/agents/react/tool_parser.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
import uuid
8+
from typing import List, Optional, Union
9+
710
from pydantic import BaseModel, ValidationError
8-
from typing import Optional, List, Union
9-
from ..tool_parser import ToolParser
11+
1012
from llama_stack_client.types.shared.completion_message import CompletionMessage
1113
from llama_stack_client.types.shared.tool_call import ToolCall
12-
13-
import uuid
14+
from ..tool_parser import ToolParser
1415

1516

1617
class Param(BaseModel):

0 commit comments

Comments (0)