|
7 | 7 | from llama_stack_client import LlamaStackClient # type: ignore |
8 | 8 | from llama_stack_client.types import UserMessage # type: ignore |
9 | 9 |
|
10 | | -from fastapi import APIRouter, Request |
| 10 | +from fastapi import APIRouter, Request, HTTPException, status |
11 | 11 |
|
12 | 12 | from client import get_llama_stack_client |
13 | 13 | from configuration import configuration |
14 | 14 | from models.responses import QueryResponse |
| 15 | +from models.requests import QueryRequest, Attachment |
| 16 | +import constants |
15 | 17 |
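The new imports reference request models and shared constants that live outside this diff. Judging from how the handler uses them below, `models/requests.py` presumably defines something along these lines (a sketch only; field names are inferred from attribute access in this diff, and the real definitions may differ):

```python
# models/requests.py -- hypothetical sketch, reconstructed from how the
# handler accesses these objects; not the actual file from this PR.
from pydantic import BaseModel


class Attachment(BaseModel):
    """User-supplied context sent along with a query."""

    attachment_type: str  # validated against constants.ATTACHMENT_TYPES
    content_type: str     # validated against constants.ATTACHMENT_CONTENT_TYPES
    content: str


class QueryRequest(BaseModel):
    """Body of POST /query."""

    query: str
    conversation_id: str | None = None
    model: str | None = None
    provider: str | None = None
    system_prompt: str | None = None
    attachments: list[Attachment] | None = None

    def get_documents(self) -> list[dict[str, str]]:
        # Assumed conversion of attachments into the `documents` payload
        # passed to Agent.create_turn; the exact shape is a guess.
        if not self.attachments:
            return []
        return [
            {"content": a.content, "mime_type": a.content_type}
            for a in self.attachments
        ]
```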
|
16 | 18 | logger = logging.getLogger("app.endpoints.handlers") |
17 | | -router = APIRouter(tags=["models"]) |
| 19 | +router = APIRouter(tags=["query"]) |
18 | 20 |
|
19 | 21 |
|
20 | 22 | query_response: dict[int | str, dict[str, Any]] = { |
21 | 23 | 200: { |
22 | | - "query": "User query", |
23 | | - "answer": "LLM ansert", |
| 24 | + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", |
| 25 | +        "response": "LLM answer", |
24 | 26 | }, |
25 | 27 | } |
26 | 28 |
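For the documented 200 payload to line up with what the handler returns, `QueryResponse` in `models/responses.py` presumably carries the same two fields; a minimal assumed shape:

```python
# models/responses.py -- assumed shape, matching the fields the handler sets.
from pydantic import BaseModel


class QueryResponse(BaseModel):
    conversation_id: str | None = None
    response: str
```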
|
27 | 29 |
|
28 | 30 | @router.post("/query", responses=query_response) |
29 | | -def query_endpoint_handler(request: Request, query: str) -> QueryResponse: |
| 31 | +def query_endpoint_handler( |
| 32 | + request: Request, query_request: QueryRequest |
| 33 | +) -> QueryResponse: |
30 | 34 | llama_stack_config = configuration.llama_stack_configuration |
31 | 35 | logger.info("LLama stack config: %s", llama_stack_config) |
32 | | - |
33 | 36 | client = get_llama_stack_client(llama_stack_config) |
34 | | - |
35 | | - # retrieve list of available models |
36 | | - models = client.models.list() |
37 | | - |
38 | | - # select the first LLM |
39 | | - llm = next(m for m in models if m.model_type == "llm") |
40 | | - model_id = llm.identifier |
41 | | - |
42 | | - logger.info("Model: %s", model_id) |
43 | | - |
44 | | - response = retrieve_response(client, model_id, query) |
45 | | - |
46 | | - return QueryResponse(query=query, response=response) |
| 37 | + model_id = select_model_id(client, query_request) |
| 38 | + response = retrieve_response(client, model_id, query_request) |
| 39 | + return QueryResponse( |
| 40 | + conversation_id=query_request.conversation_id, response=response |
| 41 | + ) |
47 | 42 |
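Since the handler now takes a `QueryRequest` body instead of a bare `query` string, a call would look roughly like this (the base URL, model, and provider values are placeholders, not taken from this PR):

```python
# Illustrative client call; adjust the base URL and model/provider to your deployment.
import requests

payload = {
    "query": "How do I configure logging?",
    "model": "my-llm",          # optional; the first available LLM is used when omitted
    "provider": "my-provider",  # optional; must match the chosen model's provider
    "system_prompt": "You are a helpful assistant",  # optional; default prompt otherwise
}

resp = requests.post("http://localhost:8080/query", json=payload, timeout=60)
print(resp.json())  # -> {"conversation_id": ..., "response": "..."}
```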
|
48 | 43 |
|
49 | | -def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str: |
| 44 | +def select_model_id(client: LlamaStackClient, query_request: QueryRequest) -> str: |
| 45 | + """Select the model ID based on the request or available models.""" |
| 46 | + models = client.models.list() |
| 47 | + model_id = query_request.model |
| 48 | + provider_id = query_request.provider |
| 49 | + |
| 50 | + # TODO(lucasagomes): support default model selection via configuration |
| 51 | + if not model_id: |
| 52 | + logger.info("No model specified in request, using the first available LLM") |
| 53 | + try: |
| 54 | + return next(m for m in models if m.model_type == "llm").identifier |
| 55 | + except (StopIteration, AttributeError): |
| 56 | + message = "No LLM model found in available models" |
| 57 | + logger.error(message) |
| 58 | + raise HTTPException( |
| 59 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 60 | + detail={ |
| 61 | + "response": constants.UNABLE_TO_PROCESS_RESPONSE, |
| 62 | + "cause": message, |
| 63 | + }, |
| 64 | + ) |
| 65 | + |
| 66 | + logger.info(f"Searching for model: {model_id}, provider: {provider_id}") |
| 67 | + if not any( |
| 68 | + m.identifier == model_id and m.provider_id == provider_id for m in models |
| 69 | + ): |
| 70 | + message = f"Model {model_id} from provider {provider_id} not found in available models" |
| 71 | + logger.error(message) |
| 72 | + raise HTTPException( |
| 73 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 74 | + detail={ |
| 75 | + "response": constants.UNABLE_TO_PROCESS_RESPONSE, |
| 76 | + "cause": message, |
| 77 | + }, |
| 78 | + ) |
| 79 | + |
| 80 | + return model_id |
| 81 | + |
| 82 | + |
| 83 | +def retrieve_response( |
| 84 | + client: LlamaStackClient, model_id: str, query_request: QueryRequest |
| 85 | +) -> str: |
50 | 86 |
|
51 | 87 | available_shields = [shield.identifier for shield in client.shields.list()] |
52 | 88 | if not available_shields: |
53 | 89 | logger.info("No available shields. Disabling safety") |
54 | 90 | else: |
55 | 91 | logger.info(f"Available shields found: {available_shields}") |
56 | 92 |
|
| 93 | +    # use the system prompt from the request, or the default one |
| 94 | + system_prompt = ( |
| 95 | + query_request.system_prompt |
| 96 | + if query_request.system_prompt |
| 97 | + else constants.DEFAULT_SYSTEM_PROMPT |
| 98 | + ) |
| 99 | + logger.debug(f"Using system prompt: {system_prompt}") |
| 100 | + |
| 101 | + # TODO(lucasagomes): redact attachments content before sending to LLM |
| 102 | + # if attachments are provided, validate them |
| 103 | + if query_request.attachments: |
| 104 | + validate_attachments_metadata(query_request.attachments) |
| 105 | + |
57 | 106 | agent = Agent( |
58 | 107 | client, |
59 | 108 | model=model_id, |
60 | | - instructions="You are a helpful assistant", |
| 109 | + instructions=system_prompt, |
61 | 110 | input_shields=available_shields if available_shields else [], |
62 | 111 | tools=[], |
63 | 112 | ) |
64 | 113 | session_id = agent.create_session("chat_session") |
65 | 114 | response = agent.create_turn( |
66 | | - messages=[UserMessage(role="user", content=prompt)], |
| 115 | + messages=[UserMessage(role="user", content=query_request.query)], |
67 | 116 | session_id=session_id, |
| 117 | + documents=query_request.get_documents(), |
68 | 118 | stream=False, |
69 | 119 | ) |
70 | 120 |
|
71 | 121 | return str(response.output_message.content) |
| 122 | + |
| 123 | + |
| 124 | +def validate_attachments_metadata(attachments: list[Attachment]) -> None: |
| 125 | + """Validate the attachments metadata provided in the request. |
| 126 | + Raises HTTPException if any attachment has an improper type or content type. |
| 127 | + """ |
| 128 | + for attachment in attachments: |
| 129 | + if attachment.attachment_type not in constants.ATTACHMENT_TYPES: |
| 130 | + message = ( |
| 131 | + f"Attachment with improper type {attachment.attachment_type} detected" |
| 132 | + ) |
| 133 | + logger.error(message) |
| 134 | + raise HTTPException( |
| 135 | + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, |
| 136 | + detail={ |
| 137 | + "response": constants.UNABLE_TO_PROCESS_RESPONSE, |
| 138 | + "cause": message, |
| 139 | + }, |
| 140 | + ) |
| 141 | + if attachment.content_type not in constants.ATTACHMENT_CONTENT_TYPES: |
| 142 | + message = f"Attachment with improper content type {attachment.content_type} detected" |
| 143 | + logger.error(message) |
| 144 | + raise HTTPException( |
| 145 | + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, |
| 146 | + detail={ |
| 147 | + "response": constants.UNABLE_TO_PROCESS_RESPONSE, |
| 148 | + "cause": message, |
| 149 | + }, |
| 150 | + ) |
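The handlers rely on several values from `constants` that are not part of this diff. Only the names are confirmed by the code above; the values below are guesses (the default prompt mirrors the string the old code passed to `Agent`):

```python
# constants.py -- hypothetical values; only the names appear in this PR.
UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request"

DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"

# Allowed values for Attachment.attachment_type / Attachment.content_type.
ATTACHMENT_TYPES = frozenset({"log", "configuration"})
ATTACHMENT_CONTENT_TYPES = frozenset({"text/plain", "application/yaml"})
```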