From f3f5e4f4da3217dd40317c393b812bb9949f95fe Mon Sep 17 00:00:00 2001
From: "Eric Huang (AI Platform)"
Date: Mon, 27 Jan 2025 15:21:26 -0800
Subject: [PATCH 1/2] client tool fix

---
 src/llama_stack_client/lib/agents/agent.py | 44 +++++++++++++---------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 5af8ba14..091df938 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -75,21 +75,29 @@ def create_turn(
         toolgroups: Optional[List[Toolgroup]] = None,
         documents: Optional[List[Document]] = None,
     ):
-        response = self.client.agents.turn.create(
-            agent_id=self.agent_id,
-            # use specified session_id or last session created
-            session_id=session_id or self.session_id[-1],
-            messages=messages,
-            stream=True,
-            documents=documents,
-            toolgroups=toolgroups,
-        )
-        for chunk in response:
-            if hasattr(chunk, "error"):
-                yield chunk
-                return
-            elif not self._has_tool_call(chunk):
-                yield chunk
-            else:
-                next_message = self._run_tool(chunk)
-                yield next_message
+        stop = False
+        while not stop:
+            response = self.client.agents.turn.create(
+                agent_id=self.agent_id,
+                # use specified session_id or last session created
+                session_id=session_id or self.session_id[-1],
+                messages=messages,
+                stream=True,
+                documents=documents,
+                toolgroups=toolgroups,
+            )
+            # by default, we stop after the first turn
+            stop = True
+            for chunk in response:
+                if hasattr(chunk, "error"):
+                    yield chunk
+                    return
+                elif not self._has_tool_call(chunk):
+                    yield chunk
+                else:
+                    next_message = self._run_tool(chunk)
+                    yield next_message
+
+                    # continue the turn when there's a tool call
+                    stop = False
+                    messages = [next_message]

From 364086b0303c290b298cab13b6b2b389fbc9361d Mon Sep 17 00:00:00 2001
From: "Eric Huang (AI Platform)"
Date: Mon, 27 Jan 2025 15:21:26 -0800
Subject: [PATCH 2/2] add max iter limit

---
 src/llama_stack_client/lib/agents/agent.py | 
48 ++++++++++++++-------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 5af8ba14..3951b281 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -13,6 +13,7 @@ from .client_tool import ClientTool +DEFAULT_MAX_ITER = 10 class Agent: def __init__( @@ -75,21 +76,32 @@ def create_turn( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, ): - response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - ) - for chunk in response: - if hasattr(chunk, "error"): - yield chunk - return - elif not self._has_tool_call(chunk): - yield chunk - else: - next_message = self._run_tool(chunk) - yield next_message + stop = False + n_iter = 0 + max_iter = self.agent_config.get('max_infer_iters', DEFAULT_MAX_ITER) + while not stop and n_iter < max_iter: + response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + ) + # by default, we stop after the first turn + stop = True + for chunk in response: + if hasattr(chunk, "error"): + yield chunk + return + elif not self._has_tool_call(chunk): + yield chunk + else: + next_message = self._run_tool(chunk) + yield next_message + + # continue the turn when there's a tool call + stop = False + messages = [next_message] + n_iter += 1 \ No newline at end of file