Skip to content

Commit d754a32

Browse files
authored
Merge pull request #585 from bsatapat-jpg/stream_query
LCORE-693: Added rag_chunks to streaming_query
2 parents 87b749c + 54a5858 commit d754a32

File tree

5 files changed

+395
-31
lines changed

5 files changed

+395
-31
lines changed

src/app/endpoints/query.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from llama_stack_client.types.model_list_response import ModelListResponse
2323
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
2424
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
25-
from pydantic import AnyUrl
2625

2726
import constants
2827
import metrics
@@ -363,22 +362,7 @@ async def query_endpoint_handler( # pylint: disable=R0914
363362
for tc in summary.tool_calls
364363
]
365364

366-
logger.info("Extracting referenced documents...")
367-
referenced_docs = []
368-
doc_sources = set()
369-
for chunk in summary.rag_chunks:
370-
if chunk.source and chunk.source not in doc_sources:
371-
doc_sources.add(chunk.source)
372-
referenced_docs.append(
373-
ReferencedDocument(
374-
doc_url=(
375-
AnyUrl(chunk.source)
376-
if chunk.source.startswith("http")
377-
else None
378-
),
379-
doc_title=chunk.source,
380-
)
381-
)
365+
logger.info("Using referenced documents from response...")
382366

383367
logger.info("Building final response...")
384368
response = QueryResponse(

src/app/endpoints/streaming_query.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from models.responses import ForbiddenResponse, UnauthorizedResponse
4949
from utils.endpoints import (
5050
check_configuration_loaded,
51+
create_referenced_documents_with_metadata,
52+
create_rag_chunks_dict,
5153
get_agent,
5254
get_system_prompt,
5355
store_conversation_into_cache,
@@ -142,7 +144,7 @@ def stream_start_event(conversation_id: str) -> str:
142144
)
143145

144146

145-
def stream_end_event(metadata_map: dict) -> str:
147+
def stream_end_event(metadata_map: dict, summary: TurnSummary) -> str:
146148
"""
147149
Yield the end of the data stream.
148150
@@ -158,20 +160,27 @@ def stream_end_event(metadata_map: dict) -> str:
158160
str: A Server-Sent Events (SSE) formatted string
159161
representing the end of the data stream.
160162
"""
163+
# Process RAG chunks using utility function
164+
rag_chunks = create_rag_chunks_dict(summary)
165+
166+
# Extract referenced documents using utility function
167+
referenced_docs = create_referenced_documents_with_metadata(summary, metadata_map)
168+
169+
# Convert ReferencedDocument objects to dictionaries for JSON serialization
170+
referenced_docs_dict = [
171+
{
172+
"doc_url": str(doc.doc_url) if doc.doc_url else None,
173+
"doc_title": doc.doc_title,
174+
}
175+
for doc in referenced_docs
176+
]
177+
161178
return format_stream_data(
162179
{
163180
"event": "end",
164181
"data": {
165-
"referenced_documents": [
166-
{
167-
"doc_url": v["docs_url"],
168-
"doc_title": v["title"],
169-
}
170-
for v in filter(
171-
lambda v: ("docs_url" in v) and ("title" in v),
172-
metadata_map.values(),
173-
)
174-
],
182+
"rag_chunks": rag_chunks,
183+
"referenced_documents": referenced_docs_dict,
175184
"truncated": None, # TODO(jboos): implement truncated
176185
"input_tokens": 0, # TODO(jboos): implement input tokens
177186
"output_tokens": 0, # TODO(jboos): implement output tokens
@@ -668,6 +677,8 @@ async def response_generator(
668677
yield stream_start_event(conversation_id)
669678

670679
async for chunk in turn_response:
680+
if chunk.event is None:
681+
continue
671682
p = chunk.event.payload
672683
if p.event_type == "turn_complete":
673684
summary.llm_response = interleaved_content_as_str(
@@ -688,7 +699,7 @@ async def response_generator(
688699
chunk_id += 1
689700
yield event
690701

691-
yield stream_end_event(metadata_map)
702+
yield stream_end_event(metadata_map, summary)
692703

693704
if not is_transcripts_enabled():
694705
logger.debug("Transcript collection is disabled in the configuration")
@@ -702,7 +713,7 @@ async def response_generator(
702713
query=query_request.query,
703714
query_request=query_request,
704715
summary=summary,
705-
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
716+
rag_chunks=create_rag_chunks_dict(summary),
706717
truncated=False, # TODO(lucasagomes): implement truncation as part
707718
# of quota work
708719
attachments=query_request.attachments or [],

src/utils/endpoints.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
"""Utility functions for endpoint handlers."""
22

33
from contextlib import suppress
4+
from typing import Any
45
from fastapi import HTTPException, status
56
from llama_stack_client._client import AsyncLlamaStackClient
67
from llama_stack_client.lib.agents.agent import AsyncAgent
8+
from pydantic import AnyUrl, ValidationError
79

810
import constants
911
from models.cache_entry import CacheEntry
1012
from models.requests import QueryRequest
13+
from models.responses import ReferencedDocument
1114
from models.database.conversations import UserConversation
1215
from models.config import Action
1316
from app.database import get_session
1417
from configuration import AppConfig
1518
from utils.suid import get_suid
19+
from utils.types import TurnSummary
1620
from utils.types import GraniteToolParser
1721

1822

@@ -344,3 +348,216 @@ async def get_temp_agent(
344348
session_id = await agent.create_session(get_suid())
345349

346350
return agent, session_id, conversation_id
351+
352+
353+
def create_rag_chunks_dict(summary: TurnSummary) -> list[dict[str, Any]]:
    """
    Serialize the RAG chunks of a turn summary for a streaming response.

    Args:
        summary: TurnSummary whose ``rag_chunks`` are to be serialized.

    Returns:
        One dictionary per chunk, carrying its ``content``, ``source``
        and ``score`` fields.
    """

    def _as_dict(chunk: Any) -> dict[str, Any]:
        # Keep the JSON field names aligned with the chunk attributes.
        return {
            "content": chunk.content,
            "source": chunk.source,
            "score": chunk.score,
        }

    return [_as_dict(chunk) for chunk in summary.rag_chunks]
367+
368+
369+
def _process_http_source(
    src: str, doc_urls: set[str]
) -> tuple[AnyUrl | None, str] | None:
    """
    Turn an HTTP(S) chunk source into a ``(doc_url, doc_title)`` entry.

    Returns ``None`` when *src* was already seen (deduplication); otherwise
    records it in *doc_urls* and returns the validated URL — or ``None``
    when validation fails — together with a title derived from the last
    path segment of the URL.
    """
    if src in doc_urls:
        # Already emitted for an earlier chunk — skip the duplicate.
        return None
    doc_urls.add(src)

    validated: AnyUrl | None
    try:
        validated = AnyUrl(src)
    except ValidationError:
        logger.warning("Invalid URL in chunk source: %s", src)
        validated = None

    # A URL ending in "/" has an empty last segment; fall back to the
    # full source string in that case.
    title = src.rsplit("/", 1)[-1] or src
    return (validated, title)
384+
385+
386+
def _process_document_id(
    src: str,
    doc_ids: set[str],
    doc_urls: set[str],
    metas_by_id: dict[str, dict[str, Any]],
    metadata_map: dict[str, Any] | None,
) -> tuple[AnyUrl | None, str] | None:
    """
    Turn a non-HTTP document-id chunk source into a ``(doc_url, doc_title)`` entry.

    Looks the id up in *metas_by_id* (only when *metadata_map* was provided)
    to enrich the entry with a ``docs_url`` and ``title``. Deduplicates on
    both the document id (*doc_ids*) and the resolved URL (*doc_urls*).

    Args:
        src: The document-id source taken from a RAG chunk.
        doc_ids: Ids already emitted; mutated in place.
        doc_urls: URLs already emitted; mutated in place.
        metas_by_id: Per-id metadata dictionaries (``docs_url`` / ``title``).
        metadata_map: Original metadata map; only its truthiness is used
            to gate the metadata lookup.

    Returns:
        ``(validated_url_or_None, title)`` for a new document, or ``None``
        when the id or its URL was already seen.
    """
    if src in doc_ids:
        return None
    doc_ids.add(src)

    meta = metas_by_id.get(src, {}) if metadata_map else {}
    doc_url = meta.get("docs_url")
    title = meta.get("title")
    # Defensive: metadata comes from an external map, so discard values
    # that are not strings (or None).
    if not isinstance(doc_url, (str, type(None))):
        doc_url = None
    if not isinstance(title, (str, type(None))):
        title = None

    if doc_url:
        if doc_url in doc_urls:
            return None
        doc_urls.add(doc_url)

    try:
        validated_doc_url = None
        if doc_url and doc_url.startswith("http"):
            validated_doc_url = AnyUrl(doc_url)
    except ValidationError:
        logger.warning("Invalid URL in metadata: %s", doc_url)
        validated_doc_url = None

    # Title fallback chain: explicit title, else the URL's last path
    # segment, else the document id. The extra "or doc_url" guards against
    # an empty last segment when the URL ends with "/" (mirrors the
    # fallback in _process_http_source).
    doc_title = title or (
        (doc_url.rsplit("/", 1)[-1] or doc_url) if doc_url else src
    )
    return (validated_doc_url, doc_title)
422+
423+
424+
def _add_additional_metadata_docs(
    doc_urls: set[str],
    metas_by_id: dict[str, dict[str, Any]],
) -> list[tuple[AnyUrl | None, str]]:
    """
    Collect referenced documents present in the metadata map but not yet
    emitted for any RAG chunk.

    Entries without a string ``docs_url``, with an already-seen URL, or
    without a ``title`` are skipped. *doc_urls* is mutated in place.
    """
    extra: list[tuple[AnyUrl | None, str]] = []
    for meta in metas_by_id.values():
        url = meta.get("docs_url")
        label = meta.get("title")  # key must be lowercase "title"
        # Defensive: discard non-string metadata values.
        if not isinstance(url, (str, type(None))):
            url = None
        if not isinstance(label, (str, type(None))):
            label = None
        # Skip unusable URLs, duplicates, and untitled documents.
        if not url or url in doc_urls or label is None:
            continue
        doc_urls.add(url)

        parsed: AnyUrl | None = None
        try:
            if url.startswith("http"):
                parsed = AnyUrl(url)
        except ValidationError:
            logger.warning("Invalid URL in metadata_map: %s", url)
            parsed = None

        extra.append((parsed, label))
    return extra
450+
451+
452+
def _process_rag_chunks_for_documents(
    rag_chunks: list,
    metadata_map: dict[str, Any] | None = None,
) -> list[tuple[AnyUrl | None, str]]:
    """
    Reduce RAG chunks to a deduplicated list of ``(doc_url, doc_title)`` tuples.

    Core logic shared by both return formats: HTTP sources are handled
    directly, other sources are treated as document ids and enriched from
    *metadata_map* when available; finally any metadata-only documents are
    appended.
    """
    seen_urls: set[str] = set()
    seen_ids: set[str] = set()

    # Only dict-valued metadata entries are usable for enrichment.
    metas_by_id: dict[str, dict[str, Any]] = (
        {k: v for k, v in metadata_map.items() if isinstance(v, dict)}
        if metadata_map
        else {}
    )

    entries: list[tuple[AnyUrl | None, str]] = []
    for chunk in rag_chunks:
        source = chunk.source
        # Skip empty sources and the default RAG tool placeholder.
        if not source or source == constants.DEFAULT_RAG_TOOL:
            continue

        if source.startswith("http"):
            entry = _process_http_source(source, seen_urls)
        else:
            entry = _process_document_id(
                source, seen_ids, seen_urls, metas_by_id, metadata_map
            )
        if entry is not None:
            entries.append(entry)

    # Append metadata-map documents not already covered by a chunk.
    if metadata_map:
        entries.extend(_add_additional_metadata_docs(seen_urls, metas_by_id))

    return entries
493+
494+
495+
def create_referenced_documents(
    rag_chunks: list,
    metadata_map: dict[str, Any] | None = None,
    return_dict_format: bool = False,
) -> list[ReferencedDocument] | list[dict[str, str | None]]:
    """
    Build referenced documents from RAG chunks, optionally metadata-enriched.

    Unified entry point over `_process_rag_chunks_for_documents`: handles
    deduplication, metadata enrichment and URL validation, then renders the
    result in one of two shapes.

    Args:
        rag_chunks: RAG chunks carrying source information.
        metadata_map: Optional metadata about referenced documents.
        return_dict_format: When True, return plain dictionaries (streaming
            payloads); otherwise return ReferencedDocument objects (query
            endpoint).

    Returns:
        A list of ReferencedDocument objects, or of dictionaries with
        ``doc_url`` and ``doc_title`` keys.
    """
    entries = _process_rag_chunks_for_documents(rag_chunks, metadata_map)

    if not return_dict_format:
        return [
            ReferencedDocument(doc_url=url, doc_title=title)
            for url, title in entries
        ]
    return [
        {
            # AnyUrl is not JSON-serializable directly; stringify it.
            "doc_url": str(url) if url else None,
            "doc_title": title,
        }
        for url, title in entries
    ]
530+
531+
532+
# Backward compatibility functions
533+
def create_referenced_documents_with_metadata(
    summary: TurnSummary, metadata_map: dict[str, Any]
) -> list[ReferencedDocument]:
    """
    Build referenced documents for streaming, enriched from *metadata_map*.

    Backward-compatibility wrapper kept for the streaming endpoint; returns
    ReferencedDocument objects for consistency with the query endpoint.
    """
    return [
        ReferencedDocument(doc_url=url, doc_title=title)
        for url, title in _process_rag_chunks_for_documents(
            summary.rag_chunks, metadata_map
        )
    ]
548+
549+
550+
def create_referenced_documents_from_chunks(
    rag_chunks: list,
) -> list[ReferencedDocument]:
    """
    Build referenced documents from RAG chunks for the query endpoint.

    Backward-compatibility wrapper around the shared chunk-processing
    logic; no metadata enrichment is applied.
    """
    return [
        ReferencedDocument(doc_url=url, doc_title=title)
        for url, title in _process_rag_chunks_for_documents(rag_chunks)
    ]

0 commit comments

Comments (0)