From 706a1e61d7a0ba0e10ad4a4392bb60c6ceb02b14 Mon Sep 17 00:00:00 2001
From: Eric Huang
Date: Thu, 13 Feb 2025 10:40:26 -0800
Subject: [PATCH] feat: include complete turn response in Agent.create_turn
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
In https://github.com/meta-llama/llama-stack-client-python/pull/102, we made
a turn's behavior more complete by automatically passing the tool response
back and creating another turn when a client tool is used. However, this
creates a problem with the non-streaming API, where the response object only
contains information since the last tool call. This PR is a hacky attempt to
address this by combining the Turn responses into one.

I think ideally we should move all of the loop logic to the server side,
where a turn would pause and the client SDK would pass tool responses back
to resume it.

I also changed the client to yield a proper ToolExecutionStep event instead
of a ToolResponseMessage, so that client-side tool execution is logged the
same way as server-side tool execution. I.e. it now outputs:

"tool_execution> Tool:load_url Response:{"content": "\nToday Google announced that they have released the source code to PebbleOS. This is massive for Rebble, and will accelerate our"

instead of:

"CustomTool> {"content": "\nToday Google announced that they have released the source code to PebbleOS. This is massive for Rebble, and will accelerate our efforts to "

Test Plan:
Added a test in https://github.com/meta-llama/llama-stack/pull/1078.

Ran a simple script with an Agent and a client tool (a minimal sketch is
included below). Observe that the returned response has steps from both
created turns.
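For reference, a minimal sketch of such a script. The base URL, model id, and
`LoadUrlTool` client-tool implementation are placeholders for whatever your
setup provides; only `Agent.create_turn` with `stream=False` is the behavior
under test here:

```python
import uuid

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

# Hypothetical ClientTool subclass that fetches a URL and returns its text;
# any client tool works here.
from my_tools import LoadUrlTool

client = LlamaStackClient(base_url="http://localhost:8321")  # assumes a locally running stack

agent = Agent(
    client,
    agent_config=AgentConfig(
        model="meta-llama/Llama-3.1-8B-Instruct",  # any model your distribution serves
        instructions="You are a helpful assistant.",
        enable_session_persistence=False,
    ),
    client_tools=(LoadUrlTool(),),
)
session_id = agent.create_session(f"test-session-{uuid.uuid4()}")

# Non-streaming call: with this patch, the returned Turn is the merge of every
# internal turn, not just the one created after the last client tool call.
response = agent.create_turn(
    messages=[
        {
            "role": "user",
            "content": "load https://llama-stack.readthedocs.io/en/latest/introduction/index.html and summarize it",
        }
    ],
    session_id=session_id,
    stream=False,
)

for step in response.steps:
    print(step.step_type)  # inference, tool_execution, inference
```

The merged Turn returned by the call above looks like: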
```
Turn(
│   input_messages=[
│   │   UserMessage(
│   │   │   content='load https://llama-stack.readthedocs.io/en/latest/introduction/index.html and summarize it',
│   │   │   role='user',
│   │   │   context=None
│   │   )
│   ],
│   output_message=CompletionMessage(
│   │   content="The document from the given URL is about Google releasing the source code to PebbleOS, which is a significant development for Rebble. This allows Rebble to accelerate its efforts to produce new hardware. 
Rebble had been working on its own replacement firmware, RebbleOS, but the release of PebbleOS's source code will help Rebble to build a production-ready real-time OS for the Pebble.",
│   │   role='assistant',
│   │   stop_reason='end_of_turn',
│   │   tool_calls=[]
│   ),
│   session_id='dec1c6c0-ed9b-42c1-97d7-906871acd5ba',
│   started_at=datetime.datetime(2025, 2, 12, 16, 38, 14, 643186),
│   steps=[
│   │   InferenceStep(
│   │   │   api_model_response=CompletionMessage(
│   │   │   │   content='',
│   │   │   │   role='assistant',
│   │   │   │   stop_reason='end_of_turn',
│   │   │   │   tool_calls=[
│   │   │   │   │   ToolCall(
│   │   │   │   │   │   arguments={'url': 'https://llama-stack.readthedocs.io/en/latest/introduction/index.html'},
│   │   │   │   │   │   call_id='5d09151b-8a53-4292-be8d-f21e134d5142',
│   │   │   │   │   │   tool_name='load_url'
│   │   │   │   │   )
│   │   │   │   ]
│   │   │   ),
│   │   │   step_id='d724a238-d02b-4d77-a4bc-a978a54979c6',
│   │   │   step_type='inference',
│   │   │   turn_id='0496c654-cd02-48bb-a2ab-d1a0a5e91aba',
│   │   │   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 523310),
│   │   │   started_at=datetime.datetime(2025, 2, 12, 16, 38, 14, 654535)
│   │   ),
│   │   ToolExecutionStep(
│   │   │   step_id='49f19a5e-6a1e-4b1c-9232-fbafb82f2f89',
│   │   │   step_type='tool_execution',
│   │   │   tool_calls=[
│   │   │   │   ToolCall(
│   │   │   │   │   arguments={'url': 'https://llama-stack.readthedocs.io/en/latest/introduction/index.html'},
│   │   │   │   │   call_id='5d09151b-8a53-4292-be8d-f21e134d5142',
│   │   │   │   │   tool_name='load_url'
│   │   │   │   )
│   │   │   ],
│   │   │   tool_responses=[
│   │   │   │   ToolResponse(
│   │   │   │   │   call_id='5d09151b-8a53-4292-be8d-f21e134d5142',
│   │   │   │   │   content='{"content": "\nToday Google announced that they have released the source code to PebbleOS. This is massive for Rebble, and will accelerate our efforts to produce new hardware.\n\nPreviously, we have been working on our own replacement firmware: RebbleOS. As you can see by the commit history though, progress was slow. Building a production-ready realtime OS for the Pebble is no small feat, and although we were confident we’d get there given enough time, it was never our ideal path. Thanks to the hard work of many people both within Google and not, we finally have our hands on the original source code for PebbleOS. You can read Google’s blog post on this for even more information.\n\nThis does not mean we instantly have the ability to start developing updates for PebbleOS though, we first will need to spend some concentrated time getting it to build. But before we talk about that, let’s talk about Rebble itself.\n"}',
│   │   │   │   │   tool_name='load_url'
│   │   │   │   )
│   │   │   ],
│   │   │   turn_id='0496c654-cd02-48bb-a2ab-d1a0a5e91aba',
│   │   │   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 534830),
│   │   │   started_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 534756)
│   │   ),
│   │   InferenceStep(
│   │   │   api_model_response=CompletionMessage(
│   │   │   │   content="The document from the given URL is about Google releasing the source code to PebbleOS, which is a significant development for Rebble. This allows Rebble to accelerate its efforts to produce new hardware. 
Rebble had been working on its own replacement firmware, RebbleOS, but the release of PebbleOS's source code will help Rebble to build a production-ready real-time OS for the Pebble.",
│   │   │   │   role='assistant',
│   │   │   │   stop_reason='end_of_turn',
│   │   │   │   tool_calls=[]
│   │   │   ),
│   │   │   step_id='5e6daa91-e689-4d7a-a7f9-d7c3da2eca5a',
│   │   │   step_type='inference',
│   │   │   turn_id='8f65d88d-7643-4dd7-acc7-48cd9e8aa449',
│   │   │   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 16, 179107),
│   │   │   started_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 561449)
│   │   )
│   ],
│   turn_id='0496c654-cd02-48bb-a2ab-d1a0a5e91aba',
│   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 16, 191199),
│   output_attachments=[]
)
```
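For completeness, the streaming path can be exercised the same way. After this
change, a client tool invocation surfaces as a `step_complete` event carrying a
`tool_execution` step, so `EventLogger` renders it like a server-side tool
call. A sketch, reusing the hypothetical agent and session from above:

```python
from llama_stack_client.lib.agents.event_logger import EventLogger

response = agent.create_turn(
    messages=[{"role": "user", "content": "load the same URL and summarize it"}],
    session_id=session_id,
    stream=True,
)

# Client tool execution is now logged as a tool_execution step, e.g.:
# tool_execution> Tool:load_url Response:{"content": "\nToday Google announced ...
for log in EventLogger().log(response):
    log.print()
```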
---
 src/llama_stack_client/lib/agents/agent.py    | 75 ++++++++++++++++---
 .../lib/agents/event_logger.py                | 10 +--
 2 files changed, 64 insertions(+), 21 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index c40ef4c8..0a8ab226 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -10,11 +10,21 @@
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
-from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
+from llama_stack_client.types.agents.turn_create_response import (
+    AgentTurnResponseStreamChunk,
+)
+from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
+from llama_stack_client.types.agents.turn_response_event_payload import (
+    AgentTurnResponseStepCompletePayload,
+)
 from llama_stack_client.types.shared.tool_call import ToolCall
 from llama_stack_client.types.agents.turn import CompletionMessage
 
 from .client_tool import ClientTool
 from .tool_parser import ToolParser
+from datetime import datetime
+import uuid
+from llama_stack_client.types.tool_execution_step import ToolExecutionStep
+from llama_stack_client.types.tool_response import ToolResponse
 
 DEFAULT_MAX_ITER = 10
 
@@ -119,16 +129,29 @@ def create_turn(
         stream: bool = True,
     ) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
         if stream:
-            return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream)
+            return self._create_turn_streaming(messages, session_id, toolgroups, documents)
         else:
-            chunk = None
-            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream):
+            chunks = []
+            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
+                if chunk.event.payload.event_type == "turn_complete":
+                    chunks.append(chunk)
                 pass
-            if not chunk:
-                raise Exception("No chunk returned")
-            if chunk.event.payload.event_type != "turn_complete":
+            if not chunks:
                 raise Exception("Turn did not complete")
-            return chunk.event.payload.turn
+
+            # merge chunks
+            return Turn(
+                input_messages=chunks[0].event.payload.turn.input_messages,
+                output_message=chunks[-1].event.payload.turn.output_message,
+                session_id=chunks[0].event.payload.turn.session_id,
+                steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
+                turn_id=chunks[0].event.payload.turn.turn_id,
+                started_at=chunks[0].event.payload.turn.started_at,
+                completed_at=chunks[-1].event.payload.turn.completed_at,
+                output_attachments=[
+                    attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
+                ],
+            )
 
     def _create_turn_streaming(
         self,
@@ -136,7 +159,6 @@ def _create_turn_streaming(
         session_id: Optional[str] = None,
         toolgroups: Optional[List[Toolgroup]] = None,
         documents: Optional[List[Document]] = None,
-        stream: bool = True,
     ) -> Iterator[AgentTurnResponseStreamChunk]:
         stop = False
         n_iter = 0
@@ -161,10 +183,39 @@ def _create_turn_streaming(
             elif not tool_calls:
                 yield chunk
             else:
-                next_message = self._run_tool(tool_calls)
-                yield next_message
+                tool_execution_start_time = datetime.now()
+                tool_response_message = self._run_tool(tool_calls)
+                tool_execution_step = ToolExecutionStep(
+                    step_type="tool_execution",
+                    step_id=str(uuid.uuid4()),
+                    tool_calls=tool_calls,
+                    tool_responses=[
+                        ToolResponse(
+                            tool_name=tool_response_message.tool_name,
+                            content=tool_response_message.content,
+                            call_id=tool_response_message.call_id,
+                        )
+                    ],
+                    turn_id=chunk.event.payload.turn.turn_id,
+                    completed_at=datetime.now(),
+                    started_at=tool_execution_start_time,
+                )
+                yield AgentTurnResponseStreamChunk(
+                    event=TurnResponseEvent(
+                        payload=AgentTurnResponseStepCompletePayload(
+                            event_type="step_complete",
+                            step_id=tool_execution_step.step_id,
+                            step_type="tool_execution",
+                            step_details=tool_execution_step,
+                        )
+                    )
+                )
+
+                # HACK: append the tool execution step to the turn
+                chunk.event.payload.turn.steps.append(tool_execution_step)
+                yield chunk
 
                 # continue the turn when there's a tool call
                 stop = False
-                messages = [next_message]
+                messages = [tool_response_message]
             n_iter += 1
diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index dff81994..fbf627f2 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -8,7 +8,7 @@
 
 from termcolor import cprint
 
-from llama_stack_client.types import InterleavedContent, ToolResponseMessage
+from llama_stack_client.types import InterleavedContent
 
 
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
@@ -70,14 +70,6 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step
             yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
             return
 
-        if not hasattr(chunk, "event"):
-            # Need to check for custom tool first
-            # since it does not produce event but instead
-            # a Message
-            if isinstance(chunk, ToolResponseMessage):
-                yield TurnStreamPrintableEvent(role="CustomTool", content=chunk.content, color="green")
-                return
-
         event = chunk.event
         event_type = event.payload.event_type