Skip to content

Commit fe4b0fc

Browse files
committed
Updated streaming query endpoint
1 parent 2e73753 commit fe4b0fc

File tree

1 file changed

+58
-40
lines changed

(summary repeated above: 1 file changed, +58 −40 lines)

src/app/endpoints/streaming_query.py

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
import logging
55
from typing import Any, AsyncIterator
66

7+
from llama_stack_client import APIConnectionError
78
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
89
from llama_stack_client import AsyncLlamaStackClient # type: ignore
910
from llama_stack_client.types import UserMessage # type: ignore
1011

11-
from fastapi import APIRouter, Request, Depends
12+
from fastapi import APIRouter, HTTPException, Request, Depends, status
1213
from fastapi.responses import StreamingResponse
1314

1415
from client import get_async_llama_stack_client
1516
from configuration import configuration
1617
from models.requests import QueryRequest
1718
import constants
1819
from utils.auth import auth_dependency
20+
from utils.endpoints import check_configuration_loaded
1921
from utils.common import retrieve_user_id
2022

2123

@@ -128,47 +130,63 @@ async def streaming_query_endpoint_handler(
128130
auth: Any = Depends(auth_dependency),
129131
) -> StreamingResponse:
130132
"""Handle request to the /streaming_query endpoint."""
133+
check_configuration_loaded(configuration)
134+
131135
llama_stack_config = configuration.llama_stack_configuration
132136
logger.info("LLama stack config: %s", llama_stack_config)
133-
client = await get_async_llama_stack_client(llama_stack_config)
134-
model_id = select_model_id(await client.models.list(), query_request)
135-
conversation_id = retrieve_conversation_id(query_request)
136-
response = await retrieve_response(client, model_id, query_request)
137-
138-
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
139-
"""Generate SSE formatted streaming response."""
140-
chunk_id = 0
141-
complete_response = ""
142-
143-
# Send start event
144-
yield stream_start_event(conversation_id)
145-
146-
async for chunk in turn_response:
147-
if event := stream_build_event(chunk, chunk_id):
148-
complete_response += json.loads(event.replace("data: ", ""))["data"][
149-
"token"
150-
]
151-
chunk_id += 1
152-
yield event
153-
154-
yield stream_end_event()
155-
156-
if not is_transcripts_enabled():
157-
logger.debug("Transcript collection is disabled in the configuration")
158-
else:
159-
store_transcript(
160-
user_id=retrieve_user_id(auth),
161-
conversation_id=conversation_id,
162-
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
163-
query=query_request.query,
164-
query_request=query_request,
165-
response=complete_response,
166-
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
167-
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
168-
attachments=query_request.attachments or [],
169-
)
170-
171-
return StreamingResponse(response_generator(response))
137+
138+
try:
139+
# try to get Llama Stack client
140+
client = await get_async_llama_stack_client(llama_stack_config)
141+
model_id = select_model_id(await client.models.list(), query_request)
142+
conversation_id = retrieve_conversation_id(query_request)
143+
response = await retrieve_response(client, model_id, query_request)
144+
145+
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
146+
"""Generate SSE formatted streaming response."""
147+
chunk_id = 0
148+
complete_response = ""
149+
150+
# Send start event
151+
yield stream_start_event(conversation_id)
152+
153+
async for chunk in turn_response:
154+
if event := stream_build_event(chunk, chunk_id):
155+
complete_response += json.loads(event.replace("data: ", ""))[
156+
"data"
157+
]["token"]
158+
chunk_id += 1
159+
yield event
160+
161+
yield stream_end_event()
162+
163+
if not is_transcripts_enabled():
164+
logger.debug("Transcript collection is disabled in the configuration")
165+
else:
166+
store_transcript(
167+
user_id=retrieve_user_id(auth),
168+
conversation_id=conversation_id,
169+
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
170+
query=query_request.query,
171+
query_request=query_request,
172+
response=complete_response,
173+
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
174+
truncated=False, # TODO(lucasagomes): implement truncation as part
175+
# of quota work
176+
attachments=query_request.attachments or [],
177+
)
178+
179+
return StreamingResponse(response_generator(response))
180+
# connection to Llama Stack server
181+
except APIConnectionError as e:
182+
logger.error("Unable to connect to Llama Stack: %s", e)
183+
raise HTTPException(
184+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
185+
detail={
186+
"response": "Unable to connect to Llama Stack",
187+
"cause": str(e),
188+
},
189+
) from e
172190

173191

174192
async def retrieve_response(

0 commit comments

Comments (0)