Skip to content

Commit 627265f

Browse files
committed
Streaming improvements
1 parent 8c26807 commit 627265f

File tree

2 files changed

+334
-82
lines changed

2 files changed

+334
-82
lines changed

src/app/endpoints/streaming_query.py

Lines changed: 239 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33
import json
44
import logging
55
import re
6-
from typing import Any, AsyncIterator
6+
from typing import Any, AsyncIterator, Iterator
77

88
from cachetools import TTLCache # type: ignore
99

1010
from llama_stack_client import APIConnectionError
1111
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
1212
from llama_stack_client import AsyncLlamaStackClient # type: ignore
13-
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
1413
from llama_stack_client.types import UserMessage # type: ignore
1514

15+
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
16+
from llama_stack_client.types.shared import ToolCall
17+
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
18+
1619
from fastapi import APIRouter, HTTPException, Request, Depends, status
1720
from fastapi.responses import StreamingResponse
1821

@@ -122,7 +125,9 @@ def stream_end_event(metadata_map: dict) -> str:
122125
)
123126

124127

125-
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
128+
# pylint: disable=R1702
129+
# pylint: disable=R0912
130+
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
126131
"""Build a streaming event from a chunk response.
127132
128133
This function processes chunks from the Llama Stack streaming response and formats
@@ -137,52 +142,254 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
137142
chunk_id: The current chunk ID counter (gets incremented for each token)
138143
139144
Returns:
140-
str | None: A formatted SSE data string with event information, or None if
141-
the chunk doesn't contain processable event data
145+
Iterator[str]: An iterator of formatted SSE data strings, one per event emitted for the chunk
142146
"""
143-
# pylint: disable=R1702
144-
if hasattr(chunk.event, "payload"):
145-
if chunk.event.payload.event_type == "step_progress":
146-
if hasattr(chunk.event.payload.delta, "text"):
147-
text = chunk.event.payload.delta.text
148-
return format_stream_data(
147+
# -----------------------------------
148+
# Error handling
149+
# -----------------------------------
150+
if hasattr(chunk, "error"):
151+
yield format_stream_data(
152+
{
153+
"event": "error",
154+
"data": {
155+
"id": chunk_id,
156+
"token": chunk.error["message"],
157+
},
158+
}
159+
)
160+
return
161+
162+
# -----------------------------------
163+
# Turn handling
164+
# -----------------------------------
165+
if chunk.event.payload.event_type in {"turn_start", "turn_awaiting_input"}:
166+
yield format_stream_data(
167+
{
168+
"event": "token",
169+
"data": {
170+
"id": chunk_id,
171+
"token": "",
172+
},
173+
}
174+
)
175+
return
176+
177+
if chunk.event.payload.event_type == "turn_complete":
178+
yield format_stream_data(
179+
{
180+
"event": "turn_complete",
181+
"data": {
182+
"id": chunk_id,
183+
"token": chunk.event.payload.turn.output_message.content,
184+
},
185+
}
186+
)
187+
return
188+
189+
# -----------------------------------
190+
# Shield handling
191+
# -----------------------------------
192+
if chunk.event.payload.step_type == "shield_call":
193+
if chunk.event.payload.event_type == "step_complete":
194+
violation = chunk.event.payload.step_details.violation
195+
if not violation:
196+
yield format_stream_data(
149197
{
150198
"event": "token",
151199
"data": {
152200
"id": chunk_id,
153201
"role": chunk.event.payload.step_type,
154-
"token": text,
202+
"token": "No Violation",
155203
},
156204
}
157205
)
158-
if (
159-
chunk.event.payload.event_type == "step_complete"
160-
and chunk.event.payload.step_details.step_type == "tool_execution"
161-
):
162-
for r in chunk.event.payload.step_details.tool_responses:
163-
if r.tool_name == "knowledge_search" and r.content:
164-
for text_content_item in r.content:
165-
if isinstance(text_content_item, TextContentItem):
166-
for match in METADATA_PATTERN.findall(
167-
text_content_item.text
168-
):
169-
meta = json.loads(match.replace("'", '"'))
170-
metadata_map[meta["document_id"]] = meta
171-
if chunk.event.payload.step_details.tool_calls:
172-
tool_name = str(
173-
chunk.event.payload.step_details.tool_calls[0].tool_name
206+
else:
207+
yield format_stream_data(
208+
{
209+
"event": "token",
210+
"data": {
211+
"id": chunk_id,
212+
"role": chunk.event.payload.step_type,
213+
"token": f"{violation.metadata} {violation.user_message}",
214+
},
215+
}
174216
)
175-
return format_stream_data(
217+
return
218+
219+
# -----------------------------------
220+
# Inference handling
221+
# -----------------------------------
222+
if chunk.event.payload.step_type == "inference":
223+
if chunk.event.payload.event_type == "step_start":
224+
yield format_stream_data(
225+
{
226+
"event": "token",
227+
"data": {
228+
"id": chunk_id,
229+
"role": chunk.event.payload.step_type,
230+
"token": "",
231+
},
232+
}
233+
)
234+
235+
elif chunk.event.payload.event_type == "step_progress":
236+
if chunk.event.payload.delta.type == "tool_call":
237+
if isinstance(chunk.event.payload.delta.tool_call, str):
238+
yield format_stream_data(
239+
{
240+
"event": "tool_call",
241+
"data": {
242+
"id": chunk_id,
243+
"role": chunk.event.payload.step_type,
244+
"token": chunk.event.payload.delta.tool_call,
245+
},
246+
}
247+
)
248+
elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
249+
yield format_stream_data(
250+
{
251+
"event": "tool_call",
252+
"data": {
253+
"id": chunk_id,
254+
"role": chunk.event.payload.step_type,
255+
"token": chunk.event.payload.delta.tool_call.tool_name,
256+
},
257+
}
258+
)
259+
260+
elif chunk.event.payload.delta.type == "text":
261+
yield format_stream_data(
176262
{
177263
"event": "token",
178264
"data": {
179265
"id": chunk_id,
180266
"role": chunk.event.payload.step_type,
181-
"token": tool_name,
267+
"token": chunk.event.payload.delta.text,
182268
},
183269
}
184270
)
185-
return None
271+
272+
elif chunk.event.payload.event_type == "step_complete":
273+
yield format_stream_data(
274+
{
275+
"event": "step_complete",
276+
"data": {
277+
"id": chunk_id,
278+
"token": "",
279+
},
280+
}
281+
)
282+
return
283+
284+
# -----------------------------------
285+
# Tool Execution handling
286+
# -----------------------------------
287+
if chunk.event.payload.step_type == "tool_execution":
288+
if chunk.event.payload.event_type == "step_start":
289+
yield format_stream_data(
290+
{
291+
"event": "tool_call",
292+
"data": {
293+
"id": chunk_id,
294+
# PatternFly Chat UI expects 'role=inference' to render correctly
295+
"role": "inference", # chunk.event.payload.step_type,
296+
"token": "",
297+
},
298+
}
299+
)
300+
301+
elif chunk.event.payload.event_type == "step_complete":
302+
for t in chunk.event.payload.step_details.tool_calls:
303+
yield format_stream_data(
304+
{
305+
"event": "tool_call",
306+
"data": {
307+
"id": chunk_id,
308+
# PatternFly Chat UI expects 'role=inference' to render correctly
309+
"role": "inference", # chunk.event.payload.step_type,
310+
"token": f"Tool:{t.tool_name} arguments:{t.arguments}",
311+
},
312+
}
313+
)
314+
315+
for r in chunk.event.payload.step_details.tool_responses:
316+
if r.tool_name == "query_from_memory":
317+
inserted_context = interleaved_content_as_str(r.content)
318+
yield format_stream_data(
319+
{
320+
"event": "tool_call",
321+
"data": {
322+
"id": chunk_id,
323+
# PatternFly Chat UI expects 'role=inference' to render correctly
324+
"role": "inference", # chunk.event.payload.step_type,
325+
"token": f"Fetched {len(inserted_context)} bytes from memory",
326+
},
327+
}
328+
)
329+
330+
elif r.tool_name == "knowledge_search" and r.content:
331+
summary = ""
332+
for i, text_content_item in enumerate(r.content):
333+
if isinstance(text_content_item, TextContentItem):
334+
if i == 0:
335+
summary = text_content_item.text
336+
summary = summary[: summary.find("\n")]
337+
for match in METADATA_PATTERN.findall(
338+
text_content_item.text
339+
):
340+
meta = json.loads(match.replace("'", '"'))
341+
metadata_map[meta["document_id"]] = meta
342+
yield format_stream_data(
343+
{
344+
"event": "tool_call",
345+
"data": {
346+
"id": chunk_id,
347+
# PatternFly Chat UI expects 'role=inference' to render correctly
348+
"role": "inference", # chunk.event.payload.step_type,
349+
"token": f"Tool:{r.tool_name} summary:{summary}",
350+
},
351+
}
352+
)
353+
354+
else:
355+
yield format_stream_data(
356+
{
357+
"event": "tool_call",
358+
"data": {
359+
"id": chunk_id,
360+
# PatternFly Chat UI expects 'role=inference' to render correctly
361+
"role": "inference", # chunk.event.payload.step_type,
362+
"token": f"Tool:{r.tool_name} response:{r.content}",
363+
},
364+
}
365+
)
366+
367+
# We swallow the 'step_complete' event and re-emit 'tool_call' events with the tool details
368+
# Ensure we send a 'step_complete' event so the UI knows the 'tool_execution' completed.
369+
yield format_stream_data(
370+
{
371+
"event": "step_complete",
372+
"data": {
373+
"id": chunk_id,
374+
"token": "",
375+
},
376+
}
377+
)
378+
379+
return
380+
381+
# -----------------------------------
382+
# Catch-all for everything else
383+
# -----------------------------------
384+
yield format_stream_data(
385+
{
386+
"event": "heartbeat",
387+
"data": {
388+
"id": chunk_id,
389+
"token": "heartbeat",
390+
},
391+
}
392+
)
186393

187394

188395
@router.post("/streaming_query")
@@ -222,7 +429,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
222429
yield stream_start_event(conversation_id)
223430

224431
async for chunk in turn_response:
225-
if event := stream_build_event(chunk, chunk_id, metadata_map):
432+
for event in stream_build_event(chunk, chunk_id, metadata_map):
226433
complete_response += json.loads(event.replace("data: ", ""))[
227434
"data"
228435
]["token"]

0 commit comments

Comments
 (0)