Skip to content

Commit 78962be

Browse files
authored
chore: refactor create_and_execute_turn and resume_turn (#1399)
# What does this PR do?

- Closes #1212

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/integration/agents/test_agents.py --inference-model "meta-llama/Llama-3.3-70B-Instruct"
```

<img width="1203" alt="image" src="https://github.com/user-attachments/assets/35b60017-b3f2-4e98-87f2-2868730261bd" />

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/agents/test_agents.py::test_rag_and_code_agent --inference-model "meta-llama/Llama-3.3-70B-Instruct"
```

[//]: # (## Documentation)
1 parent abfbaf3 commit 78962be

File tree

5 files changed

+2891
-931
lines changed

5 files changed

+2891
-931
lines changed

llama_stack/providers/inline/agents/meta_reference/agent_instance.py

Lines changed: 73 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import string
1313
import uuid
1414
from datetime import datetime
15-
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
15+
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
1616
from urllib.parse import urlparse
1717

1818
import httpx
@@ -31,7 +31,6 @@
3131
AgentTurnResponseStreamChunk,
3232
AgentTurnResponseTurnAwaitingInputPayload,
3333
AgentTurnResponseTurnCompletePayload,
34-
AgentTurnResponseTurnStartPayload,
3534
AgentTurnResumeRequest,
3635
Attachment,
3736
Document,
@@ -184,115 +183,49 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn
184183
span.set_attribute("session_id", request.session_id)
185184
span.set_attribute("agent_id", self.agent_id)
186185
span.set_attribute("request", request.model_dump_json())
187-
assert request.stream is True, "Non-streaming not supported"
188-
189-
session_info = await self.storage.get_session_info(request.session_id)
190-
if session_info is None:
191-
raise ValueError(f"Session {request.session_id} not found")
192-
193-
turns = await self.storage.get_session_turns(request.session_id)
194-
messages = await self.get_messages_from_turns(turns)
195-
messages.extend(request.messages)
196-
197186
turn_id = str(uuid.uuid4())
198187
span.set_attribute("turn_id", turn_id)
199-
start_time = datetime.now().astimezone().isoformat()
200-
yield AgentTurnResponseStreamChunk(
201-
event=AgentTurnResponseEvent(
202-
payload=AgentTurnResponseTurnStartPayload(
203-
turn_id=turn_id,
204-
)
205-
)
206-
)
207-
208-
steps = []
209-
output_message = None
210-
async for chunk in self.run(
211-
session_id=request.session_id,
212-
turn_id=turn_id,
213-
input_messages=messages,
214-
sampling_params=self.agent_config.sampling_params,
215-
stream=request.stream,
216-
documents=request.documents,
217-
toolgroups_for_turn=request.toolgroups,
218-
):
219-
if isinstance(chunk, CompletionMessage):
220-
logcat.info(
221-
"agents",
222-
f"returning result from the agent turn: {chunk}",
223-
)
224-
output_message = chunk
225-
continue
226-
227-
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
228-
event = chunk.event
229-
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
230-
steps.append(event.payload.step_details)
231-
188+
async for chunk in self._run_turn(request, turn_id):
232189
yield chunk
233190

234-
assert output_message is not None
235-
236-
turn = Turn(
237-
turn_id=turn_id,
238-
session_id=request.session_id,
239-
input_messages=request.messages,
240-
output_message=output_message,
241-
started_at=start_time,
242-
completed_at=datetime.now().astimezone().isoformat(),
243-
steps=steps,
244-
)
245-
await self.storage.add_turn_to_session(request.session_id, turn)
246-
if output_message.tool_calls:
247-
chunk = AgentTurnResponseStreamChunk(
248-
event=AgentTurnResponseEvent(
249-
payload=AgentTurnResponseTurnAwaitingInputPayload(
250-
turn=turn,
251-
)
252-
)
253-
)
254-
else:
255-
chunk = AgentTurnResponseStreamChunk(
256-
event=AgentTurnResponseEvent(
257-
payload=AgentTurnResponseTurnCompletePayload(
258-
turn=turn,
259-
)
260-
)
261-
)
262-
263-
yield chunk
264-
265191
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
266192
with tracing.span("resume_turn") as span:
267193
span.set_attribute("agent_id", self.agent_id)
268194
span.set_attribute("session_id", request.session_id)
269195
span.set_attribute("turn_id", request.turn_id)
270196
span.set_attribute("request", request.model_dump_json())
271-
assert request.stream is True, "Non-streaming not supported"
197+
async for chunk in self._run_turn(request):
198+
yield chunk
272199

273-
session_info = await self.storage.get_session_info(request.session_id)
274-
if session_info is None:
275-
raise ValueError(f"Session {request.session_id} not found")
200+
async def _run_turn(
201+
self,
202+
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
203+
turn_id: Optional[str] = None,
204+
) -> AsyncGenerator:
205+
assert request.stream is True, "Non-streaming not supported"
276206

277-
turns = await self.storage.get_session_turns(request.session_id)
278-
if len(turns) == 0:
279-
raise ValueError("No turns found for session")
207+
is_resume = isinstance(request, AgentTurnResumeRequest)
208+
session_info = await self.storage.get_session_info(request.session_id)
209+
if session_info is None:
210+
raise ValueError(f"Session {request.session_id} not found")
280211

281-
messages = await self.get_messages_from_turns(turns)
282-
messages.extend(request.tool_responses)
212+
turns = await self.storage.get_session_turns(request.session_id)
213+
if is_resume and len(turns) == 0:
214+
raise ValueError("No turns found for session")
283215

216+
steps = []
217+
messages = await self.get_messages_from_turns(turns)
218+
if is_resume:
219+
messages.extend(request.tool_responses)
284220
last_turn = turns[-1]
285221
last_turn_messages = self.turn_to_messages(last_turn)
286222
last_turn_messages = [
287223
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
288224
]
289-
290-
# TODO: figure out whether we should add the tool responses to the last turn messages
291225
last_turn_messages.extend(request.tool_responses)
292226

293-
# get the steps from the turn id
294-
steps = []
295-
steps = turns[-1].steps
227+
# get steps from the turn
228+
steps = last_turn.steps
296229

297230
# mark tool execution step as complete
298231
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
@@ -326,61 +259,66 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
326259
)
327260
)
328261
)
262+
input_messages = last_turn_messages
329263

330-
output_message = None
331-
async for chunk in self.run(
332-
session_id=request.session_id,
333-
turn_id=request.turn_id,
334-
input_messages=messages,
335-
sampling_params=self.agent_config.sampling_params,
336-
stream=request.stream,
337-
):
338-
if isinstance(chunk, CompletionMessage):
339-
output_message = chunk
340-
continue
341-
342-
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
343-
event = chunk.event
344-
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
345-
steps.append(event.payload.step_details)
346-
347-
yield chunk
264+
turn_id = request.turn_id
265+
start_time = last_turn.started_at
266+
else:
267+
messages.extend(request.messages)
268+
start_time = datetime.now().astimezone().isoformat()
269+
input_messages = request.messages
270+
271+
output_message = None
272+
async for chunk in self.run(
273+
session_id=request.session_id,
274+
turn_id=turn_id,
275+
input_messages=messages,
276+
sampling_params=self.agent_config.sampling_params,
277+
stream=request.stream,
278+
documents=request.documents if not is_resume else None,
279+
toolgroups_for_turn=request.toolgroups if not is_resume else None,
280+
):
281+
if isinstance(chunk, CompletionMessage):
282+
output_message = chunk
283+
continue
348284

349-
assert output_message is not None
285+
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
286+
event = chunk.event
287+
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
288+
steps.append(event.payload.step_details)
350289

351-
last_turn_start_time = datetime.now().astimezone().isoformat()
352-
if len(turns) > 0:
353-
last_turn_start_time = turns[-1].started_at
290+
yield chunk
354291

355-
turn = Turn(
356-
turn_id=request.turn_id,
357-
session_id=request.session_id,
358-
input_messages=last_turn_messages,
359-
output_message=output_message,
360-
started_at=last_turn_start_time,
361-
completed_at=datetime.now().astimezone().isoformat(),
362-
steps=steps,
363-
)
364-
await self.storage.add_turn_to_session(request.session_id, turn)
292+
assert output_message is not None
365293

366-
if output_message.tool_calls:
367-
chunk = AgentTurnResponseStreamChunk(
368-
event=AgentTurnResponseEvent(
369-
payload=AgentTurnResponseTurnAwaitingInputPayload(
370-
turn=turn,
371-
)
294+
turn = Turn(
295+
turn_id=turn_id,
296+
session_id=request.session_id,
297+
input_messages=input_messages,
298+
output_message=output_message,
299+
started_at=start_time,
300+
completed_at=datetime.now().astimezone().isoformat(),
301+
steps=steps,
302+
)
303+
await self.storage.add_turn_to_session(request.session_id, turn)
304+
if output_message.tool_calls:
305+
chunk = AgentTurnResponseStreamChunk(
306+
event=AgentTurnResponseEvent(
307+
payload=AgentTurnResponseTurnAwaitingInputPayload(
308+
turn=turn,
372309
)
373310
)
374-
else:
375-
chunk = AgentTurnResponseStreamChunk(
376-
event=AgentTurnResponseEvent(
377-
payload=AgentTurnResponseTurnCompletePayload(
378-
turn=turn,
379-
)
311+
)
312+
else:
313+
chunk = AgentTurnResponseStreamChunk(
314+
event=AgentTurnResponseEvent(
315+
payload=AgentTurnResponseTurnCompletePayload(
316+
turn=turn,
380317
)
381318
)
319+
)
382320

383-
yield chunk
321+
yield chunk
384322

385323
async def run(
386324
self,

0 commit comments

Comments
 (0)