
Commit 548f2de

Authored by yanxi0830 (Xi Yan) and co-authors
feat: unify max infer iters with server/client tools (#173)
# What does this PR do?

- See https://github.com/meta-llama/llama-stack/pull/1309/files

## Test Plan

- See https://github.com/meta-llama/llama-stack/pull/1309/files

---------

Co-authored-by: Xi Yan <xiyan@Mac.attlocal.net>
1 parent ee5dd2b · commit 548f2de
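The practical effect is that the inference-iteration cap is configured once on the agent and enforced by the server for both server and client tools, instead of being re-checked in the client streaming loop. A minimal sketch of the unified configuration, assuming a locally running Llama Stack server; the base URL and model id below are placeholders:

```python
# Minimal sketch (assumptions: placeholder server URL and model id).
# max_infer_iters is set once on AgentConfig; with this change the server
# enforces it for both server and client tools, so the client loop no
# longer re-checks it.
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL

agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    instructions="You are a helpful assistant.",
    max_infer_iters=5,  # single cap, enforced server-side
)
agent = Agent(client, agent_config)
```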

File tree

3 files changed: +10 -19 lines changed


src/llama_stack_client/lib/agents/agent.py

Lines changed: 7 additions & 8 deletions
@@ -11,9 +11,7 @@
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import CompletionMessage, Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
-from llama_stack_client.types.agents.turn_create_response import (
-    AgentTurnResponseStreamChunk,
-)
+from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 from llama_stack_client.types.shared.tool_call import ToolCall

 from .client_tool import ClientTool
@@ -143,7 +141,6 @@ def _create_turn_streaming(
         documents: Optional[List[Document]] = None,
     ) -> Iterator[AgentTurnResponseStreamChunk]:
         n_iter = 0
-        max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)

         # 1. create an agent turn
         turn_response = self.client.agents.turn.create(
@@ -170,12 +167,18 @@ def _create_turn_streaming(
                     yield chunk
                 else:
                     is_turn_complete = False
+                    # End of turn is reached, do not resume even if there's a tool call
+                    if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
+                        yield chunk
+                        break
+
                     turn_id = self._get_turn_id(chunk)
                     if n_iter == 0:
                         yield chunk

                     # run the tools
                     tool_response_message = self._run_tool(tool_calls)
+
                     # pass it to next iteration
                     turn_response = self.client.agents.turn.resume(
                         agent_id=self.agent_id,
@@ -185,7 +188,3 @@ def _create_turn_streaming(
                         stream=True,
                     )
                     n_iter += 1
-                    break
-
-            if n_iter >= max_iter:
-                raise Exception(f"Turn did not complete in {max_iter} iterations")
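After this change the streaming loop ends on the server's signal: a chunk whose stop_reason is "end_of_turn" closes the turn even when it carries a tool call, and hitting the iteration cap no longer raises on the client. A hedged consumer sketch, continuing from the configuration sketch above (the session name and user message are made up; the event fields mirror this file's chunk handling):

```python
# Hedged consumer sketch: iterate the turn stream until the server completes
# the turn; no local max_iter bookkeeping is needed on the client anymore.
session_id = agent.create_session("demo-session")  # hypothetical session name

for chunk in agent.create_turn(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    session_id=session_id,
):
    payload = chunk.event.payload
    if payload.event_type == "turn_complete":
        # Reached when stop_reason is "end_of_turn" or tool iterations finish.
        print(payload.turn.output_message.content)
```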

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 1 addition & 10 deletions
@@ -7,16 +7,7 @@
 import inspect
 import json
 from abc import abstractmethod
-from typing import (
-    Callable,
-    Dict,
-    get_args,
-    get_origin,
-    get_type_hints,
-    List,
-    TypeVar,
-    Union,
-)
+from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union

 from llama_stack_client.types import Message, ToolResponseMessage
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam

uv.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default.
