|
12 | 12 | import string |
13 | 13 | import uuid |
14 | 14 | from datetime import datetime |
15 | | -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple |
| 15 | +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union |
16 | 16 | from urllib.parse import urlparse |
17 | 17 |
|
18 | 18 | import httpx |
|
31 | 31 | AgentTurnResponseStreamChunk, |
32 | 32 | AgentTurnResponseTurnAwaitingInputPayload, |
33 | 33 | AgentTurnResponseTurnCompletePayload, |
34 | | - AgentTurnResponseTurnStartPayload, |
35 | 34 | AgentTurnResumeRequest, |
36 | 35 | Attachment, |
37 | 36 | Document, |
@@ -184,115 +183,49 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn |
184 | 183 | span.set_attribute("session_id", request.session_id) |
185 | 184 | span.set_attribute("agent_id", self.agent_id) |
186 | 185 | span.set_attribute("request", request.model_dump_json()) |
187 | | - assert request.stream is True, "Non-streaming not supported" |
188 | | - |
189 | | - session_info = await self.storage.get_session_info(request.session_id) |
190 | | - if session_info is None: |
191 | | - raise ValueError(f"Session {request.session_id} not found") |
192 | | - |
193 | | - turns = await self.storage.get_session_turns(request.session_id) |
194 | | - messages = await self.get_messages_from_turns(turns) |
195 | | - messages.extend(request.messages) |
196 | | - |
197 | 186 | turn_id = str(uuid.uuid4()) |
198 | 187 | span.set_attribute("turn_id", turn_id) |
199 | | - start_time = datetime.now().astimezone().isoformat() |
200 | | - yield AgentTurnResponseStreamChunk( |
201 | | - event=AgentTurnResponseEvent( |
202 | | - payload=AgentTurnResponseTurnStartPayload( |
203 | | - turn_id=turn_id, |
204 | | - ) |
205 | | - ) |
206 | | - ) |
207 | | - |
208 | | - steps = [] |
209 | | - output_message = None |
210 | | - async for chunk in self.run( |
211 | | - session_id=request.session_id, |
212 | | - turn_id=turn_id, |
213 | | - input_messages=messages, |
214 | | - sampling_params=self.agent_config.sampling_params, |
215 | | - stream=request.stream, |
216 | | - documents=request.documents, |
217 | | - toolgroups_for_turn=request.toolgroups, |
218 | | - ): |
219 | | - if isinstance(chunk, CompletionMessage): |
220 | | - logcat.info( |
221 | | - "agents", |
222 | | - f"returning result from the agent turn: {chunk}", |
223 | | - ) |
224 | | - output_message = chunk |
225 | | - continue |
226 | | - |
227 | | - assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" |
228 | | - event = chunk.event |
229 | | - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: |
230 | | - steps.append(event.payload.step_details) |
231 | | - |
| 188 | + async for chunk in self._run_turn(request, turn_id): |
232 | 189 | yield chunk |
233 | 190 |
|
234 | | - assert output_message is not None |
235 | | - |
236 | | - turn = Turn( |
237 | | - turn_id=turn_id, |
238 | | - session_id=request.session_id, |
239 | | - input_messages=request.messages, |
240 | | - output_message=output_message, |
241 | | - started_at=start_time, |
242 | | - completed_at=datetime.now().astimezone().isoformat(), |
243 | | - steps=steps, |
244 | | - ) |
245 | | - await self.storage.add_turn_to_session(request.session_id, turn) |
246 | | - if output_message.tool_calls: |
247 | | - chunk = AgentTurnResponseStreamChunk( |
248 | | - event=AgentTurnResponseEvent( |
249 | | - payload=AgentTurnResponseTurnAwaitingInputPayload( |
250 | | - turn=turn, |
251 | | - ) |
252 | | - ) |
253 | | - ) |
254 | | - else: |
255 | | - chunk = AgentTurnResponseStreamChunk( |
256 | | - event=AgentTurnResponseEvent( |
257 | | - payload=AgentTurnResponseTurnCompletePayload( |
258 | | - turn=turn, |
259 | | - ) |
260 | | - ) |
261 | | - ) |
262 | | - |
263 | | - yield chunk |
264 | | - |
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
    """Resume an existing turn and stream the resulting chunks.

    Opens a ``resume_turn`` tracing span, records the agent id, session id,
    turn id and the serialized request as span attributes, then re-yields
    every chunk produced by ``self._run_turn(request)``.

    Args:
        request: The resume request; ``session_id``, ``turn_id`` and
            ``model_dump_json()`` are read here, everything else is
            consumed by ``_run_turn``.

    Yields:
        Each chunk produced by ``_run_turn``, unmodified.
    """
    with tracing.span("resume_turn") as span:
        span.set_attribute("agent_id", self.agent_id)
        span.set_attribute("session_id", request.session_id)
        span.set_attribute("turn_id", request.turn_id)
        span.set_attribute("request", request.model_dump_json())
        # No turn_id argument is passed: _run_turn takes it from the
        # resume request itself (request.turn_id).
        async for chunk in self._run_turn(request):
            yield chunk
272 | 199 |
|
273 | | - session_info = await self.storage.get_session_info(request.session_id) |
274 | | - if session_info is None: |
275 | | - raise ValueError(f"Session {request.session_id} not found") |
| 200 | + async def _run_turn( |
| 201 | + self, |
| 202 | + request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest], |
| 203 | + turn_id: Optional[str] = None, |
| 204 | + ) -> AsyncGenerator: |
| 205 | + assert request.stream is True, "Non-streaming not supported" |
276 | 206 |
|
277 | | - turns = await self.storage.get_session_turns(request.session_id) |
278 | | - if len(turns) == 0: |
279 | | - raise ValueError("No turns found for session") |
| 207 | + is_resume = isinstance(request, AgentTurnResumeRequest) |
| 208 | + session_info = await self.storage.get_session_info(request.session_id) |
| 209 | + if session_info is None: |
| 210 | + raise ValueError(f"Session {request.session_id} not found") |
280 | 211 |
|
281 | | - messages = await self.get_messages_from_turns(turns) |
282 | | - messages.extend(request.tool_responses) |
| 212 | + turns = await self.storage.get_session_turns(request.session_id) |
| 213 | + if is_resume and len(turns) == 0: |
| 214 | + raise ValueError("No turns found for session") |
283 | 215 |
|
| 216 | + steps = [] |
| 217 | + messages = await self.get_messages_from_turns(turns) |
| 218 | + if is_resume: |
| 219 | + messages.extend(request.tool_responses) |
284 | 220 | last_turn = turns[-1] |
285 | 221 | last_turn_messages = self.turn_to_messages(last_turn) |
286 | 222 | last_turn_messages = [ |
287 | 223 | x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) |
288 | 224 | ] |
289 | | - |
290 | | - # TODO: figure out whether we should add the tool responses to the last turn messages |
291 | 225 | last_turn_messages.extend(request.tool_responses) |
292 | 226 |
|
293 | | - # get the steps from the turn id |
294 | | - steps = [] |
295 | | - steps = turns[-1].steps |
| 227 | + # get steps from the turn |
| 228 | + steps = last_turn.steps |
296 | 229 |
|
297 | 230 | # mark tool execution step as complete |
298 | 231 | # if there's no tool execution in progress step (due to storage, or tool call parsing on client), |
@@ -326,61 +259,66 @@ async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: |
326 | 259 | ) |
327 | 260 | ) |
328 | 261 | ) |
| 262 | + input_messages = last_turn_messages |
329 | 263 |
|
330 | | - output_message = None |
331 | | - async for chunk in self.run( |
332 | | - session_id=request.session_id, |
333 | | - turn_id=request.turn_id, |
334 | | - input_messages=messages, |
335 | | - sampling_params=self.agent_config.sampling_params, |
336 | | - stream=request.stream, |
337 | | - ): |
338 | | - if isinstance(chunk, CompletionMessage): |
339 | | - output_message = chunk |
340 | | - continue |
341 | | - |
342 | | - assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" |
343 | | - event = chunk.event |
344 | | - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: |
345 | | - steps.append(event.payload.step_details) |
346 | | - |
347 | | - yield chunk |
| 264 | + turn_id = request.turn_id |
| 265 | + start_time = last_turn.started_at |
| 266 | + else: |
| 267 | + messages.extend(request.messages) |
| 268 | + start_time = datetime.now().astimezone().isoformat() |
| 269 | + input_messages = request.messages |
| 270 | + |
| 271 | + output_message = None |
| 272 | + async for chunk in self.run( |
| 273 | + session_id=request.session_id, |
| 274 | + turn_id=turn_id, |
| 275 | + input_messages=messages, |
| 276 | + sampling_params=self.agent_config.sampling_params, |
| 277 | + stream=request.stream, |
| 278 | + documents=request.documents if not is_resume else None, |
| 279 | + toolgroups_for_turn=request.toolgroups if not is_resume else None, |
| 280 | + ): |
| 281 | + if isinstance(chunk, CompletionMessage): |
| 282 | + output_message = chunk |
| 283 | + continue |
348 | 284 |
|
349 | | - assert output_message is not None |
| 285 | + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" |
| 286 | + event = chunk.event |
| 287 | + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: |
| 288 | + steps.append(event.payload.step_details) |
350 | 289 |
|
351 | | - last_turn_start_time = datetime.now().astimezone().isoformat() |
352 | | - if len(turns) > 0: |
353 | | - last_turn_start_time = turns[-1].started_at |
| 290 | + yield chunk |
354 | 291 |
|
355 | | - turn = Turn( |
356 | | - turn_id=request.turn_id, |
357 | | - session_id=request.session_id, |
358 | | - input_messages=last_turn_messages, |
359 | | - output_message=output_message, |
360 | | - started_at=last_turn_start_time, |
361 | | - completed_at=datetime.now().astimezone().isoformat(), |
362 | | - steps=steps, |
363 | | - ) |
364 | | - await self.storage.add_turn_to_session(request.session_id, turn) |
| 292 | + assert output_message is not None |
365 | 293 |
|
366 | | - if output_message.tool_calls: |
367 | | - chunk = AgentTurnResponseStreamChunk( |
368 | | - event=AgentTurnResponseEvent( |
369 | | - payload=AgentTurnResponseTurnAwaitingInputPayload( |
370 | | - turn=turn, |
371 | | - ) |
| 294 | + turn = Turn( |
| 295 | + turn_id=turn_id, |
| 296 | + session_id=request.session_id, |
| 297 | + input_messages=input_messages, |
| 298 | + output_message=output_message, |
| 299 | + started_at=start_time, |
| 300 | + completed_at=datetime.now().astimezone().isoformat(), |
| 301 | + steps=steps, |
| 302 | + ) |
| 303 | + await self.storage.add_turn_to_session(request.session_id, turn) |
| 304 | + if output_message.tool_calls: |
| 305 | + chunk = AgentTurnResponseStreamChunk( |
| 306 | + event=AgentTurnResponseEvent( |
| 307 | + payload=AgentTurnResponseTurnAwaitingInputPayload( |
| 308 | + turn=turn, |
372 | 309 | ) |
373 | 310 | ) |
374 | | - else: |
375 | | - chunk = AgentTurnResponseStreamChunk( |
376 | | - event=AgentTurnResponseEvent( |
377 | | - payload=AgentTurnResponseTurnCompletePayload( |
378 | | - turn=turn, |
379 | | - ) |
| 311 | + ) |
| 312 | + else: |
| 313 | + chunk = AgentTurnResponseStreamChunk( |
| 314 | + event=AgentTurnResponseEvent( |
| 315 | + payload=AgentTurnResponseTurnCompletePayload( |
| 316 | + turn=turn, |
380 | 317 | ) |
381 | 318 | ) |
| 319 | + ) |
382 | 320 |
|
383 | | - yield chunk |
| 321 | + yield chunk |
384 | 322 |
|
385 | 323 | async def run( |
386 | 324 | self, |
|
0 commit comments