|
4 | 4 | import logging |
5 | 5 | from typing import Any, AsyncIterator |
6 | 6 |
|
| 7 | +from llama_stack_client import APIConnectionError |
7 | 8 | from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore |
8 | 9 | from llama_stack_client import AsyncLlamaStackClient # type: ignore |
9 | 10 | from llama_stack_client.types import UserMessage # type: ignore |
10 | 11 |
|
11 | | -from fastapi import APIRouter, Request, Depends |
| 12 | +from fastapi import APIRouter, HTTPException, Request, Depends, status |
12 | 13 | from fastapi.responses import StreamingResponse |
13 | 14 |
|
14 | 15 | from client import get_async_llama_stack_client |
15 | 16 | from configuration import configuration |
16 | 17 | from models.requests import QueryRequest |
17 | 18 | import constants |
18 | 19 | from utils.auth import auth_dependency |
| 20 | +from utils.endpoints import check_configuration_loaded |
19 | 21 | from utils.common import retrieve_user_id |
20 | 22 |
|
21 | 23 |
|
@@ -128,47 +130,63 @@ async def streaming_query_endpoint_handler( |
128 | 130 | auth: Any = Depends(auth_dependency), |
129 | 131 | ) -> StreamingResponse: |
130 | 132 | """Handle request to the /streaming_query endpoint.""" |
| 133 | + check_configuration_loaded(configuration) |
| 134 | + |
131 | 135 | llama_stack_config = configuration.llama_stack_configuration |
132 | 136 | logger.info("LLama stack config: %s", llama_stack_config) |
133 | | - client = await get_async_llama_stack_client(llama_stack_config) |
134 | | - model_id = select_model_id(await client.models.list(), query_request) |
135 | | - conversation_id = retrieve_conversation_id(query_request) |
136 | | - response = await retrieve_response(client, model_id, query_request) |
137 | | - |
138 | | - async def response_generator(turn_response: Any) -> AsyncIterator[str]: |
139 | | - """Generate SSE formatted streaming response.""" |
140 | | - chunk_id = 0 |
141 | | - complete_response = "" |
142 | | - |
143 | | - # Send start event |
144 | | - yield stream_start_event(conversation_id) |
145 | | - |
146 | | - async for chunk in turn_response: |
147 | | - if event := stream_build_event(chunk, chunk_id): |
148 | | - complete_response += json.loads(event.replace("data: ", ""))["data"][ |
149 | | - "token" |
150 | | - ] |
151 | | - chunk_id += 1 |
152 | | - yield event |
153 | | - |
154 | | - yield stream_end_event() |
155 | | - |
156 | | - if not is_transcripts_enabled(): |
157 | | - logger.debug("Transcript collection is disabled in the configuration") |
158 | | - else: |
159 | | - store_transcript( |
160 | | - user_id=retrieve_user_id(auth), |
161 | | - conversation_id=conversation_id, |
162 | | - query_is_valid=True, # TODO(lucasagomes): implement as part of query validation |
163 | | - query=query_request.query, |
164 | | - query_request=query_request, |
165 | | - response=complete_response, |
166 | | - rag_chunks=[], # TODO(lucasagomes): implement rag_chunks |
167 | | - truncated=False, # TODO(lucasagomes): implement truncation as part of quota work |
168 | | - attachments=query_request.attachments or [], |
169 | | - ) |
170 | | - |
171 | | - return StreamingResponse(response_generator(response)) |
| 137 | + |
| 138 | + try: |
| 139 | + # try to get Llama Stack client |
| 140 | + client = await get_async_llama_stack_client(llama_stack_config) |
| 141 | + model_id = select_model_id(await client.models.list(), query_request) |
| 142 | + conversation_id = retrieve_conversation_id(query_request) |
| 143 | + response = await retrieve_response(client, model_id, query_request) |
| 144 | + |
| 145 | + async def response_generator(turn_response: Any) -> AsyncIterator[str]: |
| 146 | + """Generate SSE formatted streaming response.""" |
| 147 | + chunk_id = 0 |
| 148 | + complete_response = "" |
| 149 | + |
| 150 | + # Send start event |
| 151 | + yield stream_start_event(conversation_id) |
| 152 | + |
| 153 | + async for chunk in turn_response: |
| 154 | + if event := stream_build_event(chunk, chunk_id): |
| 155 | + complete_response += json.loads(event.replace("data: ", ""))[ |
| 156 | + "data" |
| 157 | + ]["token"] |
| 158 | + chunk_id += 1 |
| 159 | + yield event |
| 160 | + |
| 161 | + yield stream_end_event() |
| 162 | + |
| 163 | + if not is_transcripts_enabled(): |
| 164 | + logger.debug("Transcript collection is disabled in the configuration") |
| 165 | + else: |
| 166 | + store_transcript( |
| 167 | + user_id=retrieve_user_id(auth), |
| 168 | + conversation_id=conversation_id, |
| 169 | + query_is_valid=True, # TODO(lucasagomes): implement as part of query validation |
| 170 | + query=query_request.query, |
| 171 | + query_request=query_request, |
| 172 | + response=complete_response, |
| 173 | + rag_chunks=[], # TODO(lucasagomes): implement rag_chunks |
| 174 | + truncated=False, # TODO(lucasagomes): implement truncation as part |
| 175 | + # of quota work |
| 176 | + attachments=query_request.attachments or [], |
| 177 | + ) |
| 178 | + |
| 179 | + return StreamingResponse(response_generator(response)) |
 | 180 | +    # handle failure to connect to the Llama Stack server |
| 181 | + except APIConnectionError as e: |
| 182 | + logger.error("Unable to connect to Llama Stack: %s", e) |
| 183 | + raise HTTPException( |
| 184 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 185 | + detail={ |
| 186 | + "response": "Unable to connect to Llama Stack", |
| 187 | + "cause": str(e), |
| 188 | + }, |
| 189 | + ) from e |
172 | 190 |
|
173 | 191 |
|
174 | 192 | async def retrieve_response( |
|
0 commit comments