Skip to content

Commit a417236

Browse files
committed
Added rag_chunks to streaming_query
1 parent 6529c37 commit a417236

File tree

1 file changed

+47
-14
lines changed

1 file changed

+47
-14
lines changed

src/app/endpoints/streaming_query.py

Lines changed: 47 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
from metrics.utils import update_llm_token_count_from_turn
3030
from models.config import Action
3131
from models.requests import QueryRequest
32-
from models.responses import UnauthorizedResponse, ForbiddenResponse
32+
from models.responses import UnauthorizedResponse, ForbiddenResponse, RAGChunk, ReferencedDocument
3333
from models.database.conversations import UserConversation
3434
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
3535
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
@@ -135,7 +135,7 @@ def stream_start_event(conversation_id: str) -> str:
135135
)
136136

137137

138-
def stream_end_event(metadata_map: dict) -> str:
138+
def stream_end_event(metadata_map: dict, summary: TurnSummary) -> str:
139139
"""
140140
Yield the end of the data stream.
141141
@@ -151,20 +151,44 @@ def stream_end_event(metadata_map: dict) -> str:
151151
str: A Server-Sent Events (SSE) formatted string
152152
representing the end of the data stream.
153153
"""
154+
# Process RAG chunks
155+
rag_chunks = [
156+
{
157+
"content": chunk.content,
158+
"source": chunk.source,
159+
"score": chunk.score
160+
}
161+
for chunk in summary.rag_chunks
162+
]
163+
164+
# Extract referenced documents from RAG chunks
165+
referenced_docs = []
166+
doc_sources = set()
167+
for chunk in summary.rag_chunks:
168+
if chunk.source and chunk.source not in doc_sources:
169+
doc_sources.add(chunk.source)
170+
referenced_docs.append({
171+
"doc_url": chunk.source if chunk.source.startswith("http") else None,
172+
"doc_title": chunk.source.split("/")[-1] if chunk.source else None,
173+
})
174+
175+
# Add any additional referenced documents from metadata_map
176+
for v in filter(
177+
lambda v: ("docs_url" in v) and ("title" in v),
178+
metadata_map.values(),
179+
):
180+
if v["docs_url"] not in doc_sources:
181+
referenced_docs.append({
182+
"doc_url": v["docs_url"],
183+
"doc_title": v["title"],
184+
})
185+
154186
return format_stream_data(
155187
{
156188
"event": "end",
157189
"data": {
158-
"referenced_documents": [
159-
{
160-
"doc_url": v["docs_url"],
161-
"doc_title": v["title"],
162-
}
163-
for v in filter(
164-
lambda v: ("docs_url" in v) and ("title" in v),
165-
metadata_map.values(),
166-
)
167-
],
190+
"rag_chunks": rag_chunks,
191+
"referenced_documents": referenced_docs,
168192
"truncated": None, # TODO(jboos): implement truncated
169193
"input_tokens": 0, # TODO(jboos): implement input tokens
170194
"output_tokens": 0, # TODO(jboos): implement output tokens
@@ -680,11 +704,20 @@ async def response_generator(
680704
chunk_id += 1
681705
yield event
682706

683-
yield stream_end_event(metadata_map)
707+
yield stream_end_event(metadata_map, summary)
684708

685709
if not is_transcripts_enabled():
686710
logger.debug("Transcript collection is disabled in the configuration")
687711
else:
712+
# Convert RAG chunks to serializable format for store_transcript
713+
rag_chunks_for_transcript = [
714+
{
715+
"content": chunk.content,
716+
"source": chunk.source,
717+
"score": chunk.score
718+
}
719+
for chunk in summary.rag_chunks
720+
]
688721
store_transcript(
689722
user_id=user_id,
690723
conversation_id=conversation_id,
@@ -694,7 +727,7 @@ async def response_generator(
694727
query=query_request.query,
695728
query_request=query_request,
696729
summary=summary,
697-
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
730+
rag_chunks=rag_chunks_for_transcript,
698731
truncated=False, # TODO(lucasagomes): implement truncation as part
699732
# of quota work
700733
attachments=query_request.attachments or [],

0 commit comments

Comments
 (0)