@@ -3,8 +3,11 @@
 import logging
 from typing import Any
 
+from llama_stack_client import LlamaStackClient
+
 from fastapi import APIRouter, Request
 
+from configuration import configuration
 from models.responses import QueryResponse
 
 logger = logging.getLogger(__name__)
@@ -19,6 +22,28 @@
 }
 
 
-@router.get("/query", responses=query_response)
-def info_endpoint_handler(request: Request) -> QueryResponse:
-    return QueryResponse(query="foo", response="bar")
+@router.post("/query", responses=query_response)
+def info_endpoint_handler(request: Request, query: str) -> QueryResponse:
+    llama_stack_config = configuration.llama_stack_configuration
+    logger.info("Llama stack config: %s", llama_stack_config)
+    client = LlamaStackClient(
+        base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
+    )
+
+    # retrieve the list of available models
+    models = client.models.list()
+
+    # select the first LLM
+    llm = next(m for m in models if m.model_type == "llm")
+    model_id = llm.identifier
+
+    logger.info("Model: %s", model_id)
+
+    response = client.inference.chat_completion(
+        model_id=model_id,
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": query},
+        ],
+    )
+    return QueryResponse(query=query, response=str(response.completion_message.content))
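For a quick smoke test of the new endpoint: because `query` is declared as a plain `str` parameter on a POST route, FastAPI reads it from the query string rather than the request body. A minimal client sketch, assuming the service is reachable locally on port 8080 and the router has no extra path prefix (both are assumptions, not part of this change):

```python
import requests

# Assumed base URL; adjust to wherever the FastAPI app is actually served.
BASE_URL = "http://localhost:8080"

# `query` is a scalar parameter, so it goes in the query string, not a JSON body.
resp = requests.post(f"{BASE_URL}/query", params={"query": "Say hello"})
resp.raise_for_status()
print(resp.json())  # e.g. {"query": "Say hello", "response": "..."}
```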