From cd36a77e205013dc326bc066b7ac77f10777290c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 17:38:21 -0800 Subject: [PATCH 01/14] 3/n --- llama_stack/apis/agents/agents.py | 9 +++ .../agents/meta_reference/agent_instance.py | 72 ++++++++++++------- .../inline/agents/meta_reference/agents.py | 14 +++- 3 files changed, 69 insertions(+), 26 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index a83538b359..adf4313d7e 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -297,6 +297,15 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): tool_config: Optional[ToolConfig] = None +@json_schema_type +class AgentTurnContinueRequest(BaseModel): + agent_id: str + session_id: str + turn_id: str + tool_responses: List[ToolResponseMessage] + stream: Optional[bool] = False + + @json_schema_type class AgentTurnResponseStreamChunk(BaseModel): """streamed agent turn completion response.""" diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8da3f3a141..2ae71ded6b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -23,6 +23,7 @@ AgentConfig, AgentToolGroup, AgentToolGroupWithArgs, + AgentTurnContinueRequest, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -30,7 +31,6 @@ AgentTurnResponseStepProgressPayload, AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, - AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, Attachment, @@ -227,25 +227,51 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn ) await self.storage.add_turn_to_session(request.session_id, turn) - if output_message.tool_calls: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnAwaitingInputPayload( - turn=turn, - ) - ) - ) - else: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, - ) + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, ) ) - + ) yield chunk + async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator: + with tracing.span("continue_turn") as span: + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("session_id", request.session_id) + span.set_attribute("turn_id", request.turn_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" + + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") + + turns = await self.storage.get_session_turns(request.session_id) + + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) + + for i, turn in enumerate(turns): + messages.extend(self.turn_to_messages(turn)) + + messages.extend(request.messages) + + # steps = [] + # output_message = None + # async for chunk in self.run( + # session_id=request.session_id, + # turn_id=request.turn_id, + # input_messages=messages, + # sampling_params=self.agent_config.sampling_params, + # stream=request.stream, + # documents=request.documents, + # toolgroups_for_turn=request.toolgroups, + # ): + # if isinstance(chunk, CompletionMessage): + async def run( self, session_id: str, @@ -626,7 +652,11 @@ async def _run( input_messages = input_messages + [message] else: log.info(f"{str(message)}") - # 1. Start the tool execution step and progress + tool_call = message.tool_calls[0] + if tool_call.tool_name in client_tools: + yield message + return + step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -636,8 +666,6 @@ async def _run( ) ) ) - - tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -652,12 +680,6 @@ async def _run( ) ) - # If tool is a client tool, yield CompletionMessage and return - if tool_call.tool_name in client_tools: - yield message - return - - # If tool is a builtin server tool, execute it tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index dfbc41262c..bdde890162 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -20,6 +20,7 @@ AgentSessionCreateResponse, AgentStepResponse, AgentToolGroup, + AgentTurnContinueRequest, AgentTurnCreateRequest, Document, Session, @@ -177,7 +178,18 @@ async def continue_agent_turn( tool_responses: List[ToolResponseMessage], stream: Optional[bool] = False, ) -> AsyncGenerator: - pass + if stream: + return self._continue_agent_turn_streaming(request) + else: + raise NotImplementedError("Non-streaming agent turns not yet implemented") + + async def _continue_agent_turn_streaming( + self, + request: AgentTurnContinueRequest, + ) -> AsyncGenerator: + agent = await self.get_agent(request.agent_id) + async for event in agent.continue_turn(request): + yield event async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") From ee3c174bb3687150af577665b49172504a8a57a8 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 17:40:39 -0800 Subject: [PATCH 02/14] add back 2/n --- .../agents/meta_reference/agent_instance.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2ae71ded6b..a6fba6f040 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -31,6 +31,7 @@ AgentTurnResponseStepProgressPayload, AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, + AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, Attachment, @@ -227,13 +228,23 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn ) await self.storage.add_turn_to_session(request.session_id, turn) - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) ) ) - ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + yield chunk async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator: From 157cf320d9c93675503003eb918e84b0d9c3ad90 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 17:52:01 -0800 Subject: [PATCH 03/14] add back 2/n --- .../inline/agents/meta_reference/agent_instance.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index a6fba6f040..6a581cf11f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -663,11 +663,7 @@ async def _run( input_messages = input_messages + [message] else: log.info(f"{str(message)}") - tool_call = message.tool_calls[0] - if tool_call.tool_name in client_tools: - yield message - return - + # 1. Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -691,6 +687,13 @@ async def _run( ) ) + # If tool is a client tool, yield CompletionMessage and return + tool_call = message.tool_calls[0] + if tool_call.tool_name in client_tools: + yield message + return + + # If tool is a builtin server tool, execute it tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value From 22355e3b1f2e19b4c387c18d6dcda21eec237939 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 17:53:29 -0800 Subject: [PATCH 04/14] add back 2/n --- .../providers/inline/agents/meta_reference/agent_instance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 6a581cf11f..bd12222a87 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -268,7 +268,7 @@ async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerat for i, turn in enumerate(turns): messages.extend(self.turn_to_messages(turn)) - messages.extend(request.messages) + messages.extend(request.tool_responses) # steps = [] # output_message = None @@ -673,6 +673,7 @@ async def _run( ) ) ) + tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -688,7 +689,6 @@ async def _run( ) # If tool is a client tool, yield CompletionMessage and return - tool_call = message.tool_calls[0] if tool_call.tool_name in client_tools: yield message return From 4923270122d05b5ed5d91c7bfc215c8485955698 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 18:00:57 -0800 Subject: [PATCH 05/14] continue turn --- .../agents/meta_reference/agent_instance.py | 73 ++++++++++++++++--- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index bd12222a87..f7960185ed 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -270,18 +270,67 @@ async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerat messages.extend(request.tool_responses) - # steps = [] - # output_message = None - # async for chunk in self.run( - # session_id=request.session_id, - # turn_id=request.turn_id, - # input_messages=messages, - # sampling_params=self.agent_config.sampling_params, - # stream=request.stream, - # documents=request.documents, - # toolgroups_for_turn=request.toolgroups, - # ): - # if isinstance(chunk, CompletionMessage): + # get the steps from the turn id + steps = [] + if len(turns) > 0: + steps = turns[-1].steps + + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=request.turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue + + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) + + yield chunk + + assert output_message is not None + + last_turn_messages = [] + last_turn_start_time = datetime.now() + if len(turns) > 0: + last_turn_start_time = turns[-1].started_at + last_turn_messages = self.turn_to_messages(turns[-1]) + + turn = Turn( + turn_id=request.turn_id, + session_id=request.session_id, + input_messages=last_turn_messages, + output_message=output_message, + started_at=last_turn_start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk async def run( self, From 5e00e9f260f2a75c1bf294068b776d40f1ab2928 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 19:33:21 -0800 Subject: [PATCH 06/14] persist pending tool execution --- .../agents/meta_reference/agent_instance.py | 49 +++++++++++++++++++ .../inline/agents/meta_reference/agents.py | 7 +++ .../agents/meta_reference/persistence.py | 14 +++++- 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f7960185ed..b0d8225112 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -275,6 +275,36 @@ async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerat if len(turns) > 0: steps = turns[-1].steps + # mark tool execution step as complete + in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( + request.session_id, request.turn_id + ) + tool_execution_step = ToolExecutionStep( + step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), + turn_id=request.turn_id, + tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), + tool_responses=[ + ToolResponse( + call_id=x.call_id, + tool_name=x.tool_name, + content=x.content, + ) + for x in in_progress_tool_call_step.tool_responses + ], + completed_at=datetime.now(), + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else datetime.now()), + ) + steps.append(tool_execution_step) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=tool_execution_step.step_id, + step_details=tool_execution_step, + ) + ) + ) + output_message = None async for chunk in self.run( session_id=request.session_id, @@ -302,6 +332,14 @@ async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerat last_turn_start_time = turns[-1].started_at last_turn_messages = self.turn_to_messages(turns[-1]) + # add tool responses to the last turn messages + last_turn_messages.extend(request.tool_responses) + # filter out non User / Tool messages + # TODO: should we just keep all message types in Turn.input_messages? + last_turn_messages = [ + m for m in last_turn_messages if isinstance(m, UserMessage) or isinstance(m, ToolResponseMessage) + ] + turn = Turn( turn_id=request.turn_id, session_id=request.session_id, @@ -739,6 +777,17 @@ async def _run( # If tool is a client tool, yield CompletionMessage and return if tool_call.tool_name in client_tools: + await self.storage.set_in_progress_tool_call_step( + session_id, + turn_id, + ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[tool_call], + tool_responses=[], + started_at=datetime.now(), + ), + ) yield message return diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index bdde890162..35038b3399 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -178,6 +178,13 @@ async def continue_agent_turn( tool_responses: List[ToolResponseMessage], stream: Optional[bool] = False, ) -> AsyncGenerator: + request = AgentTurnContinueRequest( + agent_id=agent_id, + session_id=session_id, + turn_id=turn_id, + tool_responses=tool_responses, + stream=stream, + ) if stream: return self._continue_agent_turn_streaming(request) else: diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 4b8ad6d4ad..3c3866873b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -12,7 +12,7 @@ from pydantic import BaseModel -from llama_stack.apis.agents import Turn +from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -84,3 +84,15 @@ async def get_session_turns(self, session_id: str) -> List[Turn]: continue turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns + + async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): + await self.kvstore.set( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + value=step.model_dump_json(), + ) + + async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: + value = await self.kvstore.get( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + ) + return ToolExecutionStep(**json.loads(value)) if value else None From 9a07e709ee3e45bed2203a00e3df74ced43b2b2e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 19:48:54 -0800 Subject: [PATCH 07/14] rename --- .../inline/agents/meta_reference/agent_instance.py | 4 ++-- .../providers/inline/agents/meta_reference/agents.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b0d8225112..d0347e8cd9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -23,7 +23,6 @@ AgentConfig, AgentToolGroup, AgentToolGroupWithArgs, - AgentTurnContinueRequest, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -34,6 +33,7 @@ AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, + AgentTurnResumeRequest, Attachment, Document, InferenceStep, @@ -247,7 +247,7 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn yield chunk - async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator: + async def continue_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: with tracing.span("continue_turn") as span: span.set_attribute("agent_id", self.agent_id) span.set_attribute("session_id", request.session_id) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 2b87b4bd60..ea438bab8e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -20,8 +20,8 @@ AgentSessionCreateResponse, AgentStepResponse, AgentToolGroup, - AgentTurnContinueRequest, AgentTurnCreateRequest, + AgentTurnResumeRequest, Document, Session, Turn, @@ -178,7 +178,7 @@ async def resume_agent_turn( tool_responses: List[ToolResponseMessage], stream: Optional[bool] = False, ) -> AsyncGenerator: - request = AgentTurnContinueRequest( + request = AgentTurnResumeRequest( agent_id=agent_id, session_id=session_id, turn_id=turn_id, @@ -192,7 +192,7 @@ async def resume_agent_turn( async def _continue_agent_turn_streaming( self, - request: AgentTurnContinueRequest, + request: AgentTurnResumeRequest, ) -> AsyncGenerator: agent = await self.get_agent(request.agent_id) async for event in agent.continue_turn(request): From 97f9580b1ab8fd814a9ded2a4c4c84442b5c562e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 19:49:50 -0800 Subject: [PATCH 08/14] rename --- .../providers/inline/agents/meta_reference/agent_instance.py | 4 ++-- llama_stack/providers/inline/agents/meta_reference/agents.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index d0347e8cd9..2344d5a17b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -247,8 +247,8 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn yield chunk - async def continue_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: - with tracing.span("continue_turn") as span: + async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: + with tracing.span("resume_turn") as span: span.set_attribute("agent_id", self.agent_id) span.set_attribute("session_id", request.session_id) span.set_attribute("turn_id", request.turn_id) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index ea438bab8e..19b4c0925d 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -195,7 +195,7 @@ async def _continue_agent_turn_streaming( request: AgentTurnResumeRequest, ) -> AsyncGenerator: agent = await self.get_agent(request.agent_id) - async for event in agent.continue_turn(request): + async for event in agent.resume_turn(request): yield event async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: From 9c40529e93dd6198fc97fe38ca55f1a0cf242f02 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 21:36:50 -0800 Subject: [PATCH 09/14] fix tool execution step from tool response --- .../providers/inline/agents/meta_reference/agent_instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2344d5a17b..1d731fd8f2 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -289,7 +289,7 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: tool_name=x.tool_name, content=x.content, ) - for x in in_progress_tool_call_step.tool_responses + for x in request.tool_responses ], completed_at=datetime.now(), started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else datetime.now()), From 99bc54b0337dc0b5984f299cf11f2cdf35ce48a4 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 22:16:37 -0800 Subject: [PATCH 10/14] fix duplicate tool msg --- .../inline/agents/meta_reference/agent_instance.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e3f4b2173a..e064e400f5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -269,6 +269,9 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: messages.extend(self.turn_to_messages(turn)) messages.extend(request.tool_responses) + last_turn_messages = [ + x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + ] # get the steps from the turn id steps = [] @@ -326,19 +329,9 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: assert output_message is not None - last_turn_messages = [] last_turn_start_time = datetime.now() if len(turns) > 0: last_turn_start_time = turns[-1].started_at - last_turn_messages = self.turn_to_messages(turns[-1]) - - # add tool responses to the last turn messages - last_turn_messages.extend(request.tool_responses) - # filter out non User / Tool messages - # TODO: should we just keep all message types in Turn.input_messages? - last_turn_messages = [ - m for m in last_turn_messages if isinstance(m, UserMessage) or isinstance(m, ToolResponseMessage) - ] turn = Turn( turn_id=request.turn_id, From 2c06704d63ca0f323018b5cb140d534449d2bb02 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 22:40:51 -0800 Subject: [PATCH 11/14] refactor --- .../agents/meta_reference/agent_instance.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e064e400f5..66a12e5906 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -157,6 +157,15 @@ def turn_to_messages(self, turn: Turn) -> List[Message]: async def create_session(self, name: str) -> str: return await self.storage.create_session(name) + async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) + + for i, turn in enumerate(turns): + messages.extend(self.turn_to_messages(turn)) + return messages + async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) @@ -169,14 +178,7 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) - - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) - - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) - + messages = await self.get_messages_from_turns(turns) messages.extend(request.messages) turn_id = str(uuid.uuid4()) @@ -260,15 +262,9 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) - - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) - - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) - + messages = await self.get_messages_from_turns(turns) messages.extend(request.tool_responses) + last_turn_messages = [ x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) ] From fa4a56cf6c858a30bcd834f99c17d700dc8d9741 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 22:41:23 -0800 Subject: [PATCH 12/14] refactor --- .../providers/inline/agents/meta_reference/agent_instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 66a12e5906..0b2b917f62 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -162,7 +162,7 @@ async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: if self.agent_config.instructions != "": messages.append(SystemMessage(content=self.agent_config.instructions)) - for i, turn in enumerate(turns): + for turn in turns: messages.extend(self.turn_to_messages(turn)) return messages From b1b45ed3208a97d0eb859047dec516093df540ae Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 22:46:17 -0800 Subject: [PATCH 13/14] add comment --- .../providers/inline/agents/meta_reference/agent_instance.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 0b2b917f62..1e41e5a01c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -275,6 +275,8 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: steps = turns[-1].steps # mark tool execution step as complete + # if there's no tool execution in progress step (due to storage, or tool call parsing on client), + # we'll create a new tool execution step with current time in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( request.session_id, request.turn_id ) From ea050f7fa891a4da0fffc2d442a75c6b5e13fcf4 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 21 Feb 2025 11:42:18 -0800 Subject: [PATCH 14/14] datetime nit --- .../providers/inline/agents/meta_reference/agent_instance.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1e41e5a01c..edd253356f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -280,6 +280,7 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( request.session_id, request.turn_id ) + now = datetime.now() tool_execution_step = ToolExecutionStep( step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, @@ -292,8 +293,8 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: ) for x in request.tool_responses ], - completed_at=datetime.now(), - started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else datetime.now()), + completed_at=now, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), ) steps.append(tool_execution_step) yield AgentTurnResponseStreamChunk(