
Commit 76fa45b

Merge pull request #59 from umago/query-endpoint

Address /query endpoint compatibility

2 parents 983172b + 2c46aa9

File tree

10 files changed: +616 −26 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+tests/test_results/

 # Translations
 *.mo

pdm.lock

Lines changed: 15 additions & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ dev = [
     "black>=25.1.0",
     "pytest>=8.3.2",
     "pytest-cov>=5.0.0",
+    "pytest-mock>=3.14.0",
 ]

 [tool.pdm.scripts]
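The new pytest-mock dev dependency suggests the endpoint tests stub out the Llama Stack client. A minimal sketch of how the `mocker` fixture might exercise the `select_model_id` helper added in `src/app/endpoints/query.py` below; the import path, test name, and model identifiers are assumptions, not taken from the commit:

```python
# Sketch only: the import path and identifiers below are assumptions.
from app.endpoints.query import select_model_id


def test_select_model_id_defaults_to_first_llm(mocker):
    """select_model_id should fall back to the first available LLM."""
    mock_client = mocker.Mock()
    mock_client.models.list.return_value = [
        mocker.Mock(model_type="embedding", identifier="provider/embed"),
        mocker.Mock(model_type="llm", identifier="provider/llm-a"),
    ]
    # no model/provider in the request -> the first LLM wins
    query_request = mocker.Mock(model=None, provider=None)
    assert select_model_id(mock_client, query_request) == "provider/llm-a"
```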

src/app/endpoints/query.py

Lines changed: 101 additions & 22 deletions
@@ -7,65 +7,144 @@
 from llama_stack_client import LlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage  # type: ignore

-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Request, HTTPException, status

 from client import get_llama_stack_client
 from configuration import configuration
 from models.responses import QueryResponse
+from models.requests import QueryRequest, Attachment
+import constants

 logger = logging.getLogger("app.endpoints.handlers")
-router = APIRouter(tags=["models"])
+router = APIRouter(tags=["query"])


 query_response: dict[int | str, dict[str, Any]] = {
     200: {
-        "query": "User query",
-        "answer": "LLM ansert",
+        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+        "response": "LLM answer",
     },
 }


 @router.post("/query", responses=query_response)
-def query_endpoint_handler(request: Request, query: str) -> QueryResponse:
+def query_endpoint_handler(
+    request: Request, query_request: QueryRequest
+) -> QueryResponse:
     llama_stack_config = configuration.llama_stack_configuration
     logger.info("LLama stack config: %s", llama_stack_config)
-
     client = get_llama_stack_client(llama_stack_config)
-
-    # retrieve list of available models
-    models = client.models.list()
-
-    # select the first LLM
-    llm = next(m for m in models if m.model_type == "llm")
-    model_id = llm.identifier
-
-    logger.info("Model: %s", model_id)
-
-    response = retrieve_response(client, model_id, query)
-
-    return QueryResponse(query=query, response=response)
+    model_id = select_model_id(client, query_request)
+    response = retrieve_response(client, model_id, query_request)
+    return QueryResponse(
+        conversation_id=query_request.conversation_id, response=response
+    )


-def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str:
+def select_model_id(client: LlamaStackClient, query_request: QueryRequest) -> str:
+    """Select the model ID based on the request or available models."""
+    models = client.models.list()
+    model_id = query_request.model
+    provider_id = query_request.provider
+
+    # TODO(lucasagomes): support default model selection via configuration
+    if not model_id:
+        logger.info("No model specified in request, using the first available LLM")
+        try:
+            return next(m for m in models if m.model_type == "llm").identifier
+        except (StopIteration, AttributeError):
+            message = "No LLM model found in available models"
+            logger.error(message)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                    "cause": message,
+                },
+            )
+
+    logger.info(f"Searching for model: {model_id}, provider: {provider_id}")
+    if not any(
+        m.identifier == model_id and m.provider_id == provider_id for m in models
+    ):
+        message = f"Model {model_id} from provider {provider_id} not found in available models"
+        logger.error(message)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                "cause": message,
+            },
+        )
+
+    return model_id
+
+
+def retrieve_response(
+    client: LlamaStackClient, model_id: str, query_request: QueryRequest
+) -> str:

     available_shields = [shield.identifier for shield in client.shields.list()]
     if not available_shields:
         logger.info("No available shields. Disabling safety")
     else:
         logger.info(f"Available shields found: {available_shields}")

+    # use the system prompt from the request, or the default one
+    system_prompt = (
+        query_request.system_prompt
+        if query_request.system_prompt
+        else constants.DEFAULT_SYSTEM_PROMPT
+    )
+    logger.debug(f"Using system prompt: {system_prompt}")
+
+    # TODO(lucasagomes): redact attachment content before sending it to the LLM
+    # if attachments are provided, validate them
+    if query_request.attachments:
+        validate_attachments_metadata(query_request.attachments)
+
     agent = Agent(
         client,
         model=model_id,
-        instructions="You are a helpful assistant",
+        instructions=system_prompt,
         input_shields=available_shields if available_shields else [],
         tools=[],
     )
     session_id = agent.create_session("chat_session")
     response = agent.create_turn(
-        messages=[UserMessage(role="user", content=prompt)],
+        messages=[UserMessage(role="user", content=query_request.query)],
         session_id=session_id,
+        documents=query_request.get_documents(),
         stream=False,
     )

     return str(response.output_message.content)
+
+
+def validate_attachments_metadata(attachments: list[Attachment]) -> None:
+    """Validate the attachment metadata provided in the request.
+
+    Raises HTTPException if any attachment has an improper type or content type.
+    """
+    for attachment in attachments:
+        if attachment.attachment_type not in constants.ATTACHMENT_TYPES:
+            message = (
+                f"Attachment with improper type {attachment.attachment_type} detected"
+            )
+            logger.error(message)
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail={
+                    "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                    "cause": message,
+                },
+            )
+        if attachment.content_type not in constants.ATTACHMENT_CONTENT_TYPES:
+            message = f"Attachment with improper content type {attachment.content_type} detected"
+            logger.error(message)
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail={
+                    "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                    "cause": message,
+                },
+            )

src/constants.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request"
+
+# Supported attachment types
+ATTACHMENT_TYPES = frozenset(
+    {
+        "alert",
+        "api object",
+        "configuration",
+        "error message",
+        "event",
+        "log",
+        "stack trace",
+    }
+)
+
+# Supported attachment content types
+ATTACHMENT_CONTENT_TYPES = frozenset(
+    {"text/plain", "application/json", "application/yaml", "application/xml"}
+)
+
+DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"
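These frozensets are what `validate_attachments_metadata` checks membership against. A quick sketch of the failure path, assuming the import paths mirror the file tree in this commit:

```python
# Sketch only: import paths are assumptions based on the file tree.
from fastapi import HTTPException

from app.endpoints.query import validate_attachments_metadata
from models.requests import Attachment

ok = Attachment(attachment_type="log", content_type="text/plain", content="...")
validate_attachments_metadata([ok])  # passes: both values are in the frozensets

bad = Attachment(attachment_type="screenshot", content_type="image/png", content="...")
try:
    validate_attachments_metadata([bad])
except HTTPException as exc:
    print(exc.status_code)       # 422
    print(exc.detail["cause"])   # Attachment with improper type screenshot detected
```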

src/models/requests.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+from pydantic import BaseModel, model_validator
+from llama_stack_client.types.agents.turn_create_params import Document
+from typing import Optional, Self
+
+
+class Attachment(BaseModel):
+    """Model representing an attachment that can be sent from the UI as part of a query.
+
+    A list of attachments can be an optional part of a 'query' request.
+
+    Attributes:
+        attachment_type: The attachment type, like "log", "configuration" etc.
+        content_type: The content type as defined in the MIME standard
+        content: The actual attachment content
+
+    YAML attachments with **kind** and **metadata/name** attributes will
+    be handled as resources with the specified name:
+    ```
+    kind: Pod
+    metadata:
+        name: private-reg
+    ```
+    """
+
+    attachment_type: str
+    content_type: str
+    content: str
+
+    # provides examples for /docs endpoint
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "attachment_type": "log",
+                    "content_type": "text/plain",
+                    "content": "this is attachment",
+                },
+                {
+                    "attachment_type": "configuration",
+                    "content_type": "application/yaml",
+                    "content": "kind: Pod\n metadata:\n name: private-reg",
+                },
+                {
+                    "attachment_type": "configuration",
+                    "content_type": "application/yaml",
+                    "content": "foo: bar",
+                },
+            ]
+        }
+    }
+
+
+# TODO(lucasagomes): add media_type when needed, current implementation
+# does not support streaming response, so this is not used
+class QueryRequest(BaseModel):
+    """Model representing a request for the LLM (Language Model).
+
+    Attributes:
+        query: The query string.
+        conversation_id: The optional conversation ID (UUID).
+        provider: The optional provider.
+        model: The optional model.
+        attachments: The optional attachments.
+
+    Example:
+        ```python
+        query_request = QueryRequest(query="Tell me about Kubernetes")
+        ```
+    """
+
+    query: str
+    conversation_id: Optional[str] = None
+    provider: Optional[str] = None
+    model: Optional[str] = None
+    system_prompt: Optional[str] = None
+    attachments: Optional[list[Attachment]] = None
+
+    # provides examples for /docs endpoint
+    model_config = {
+        "extra": "forbid",
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "query": "write a deployment yaml for the mongodb image",
+                    "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+                    "provider": "openai",
+                    "model": "model-name",
+                    "system_prompt": "You are a helpful assistant",
+                    "attachments": [
+                        {
+                            "attachment_type": "log",
+                            "content_type": "text/plain",
+                            "content": "this is attachment",
+                        },
+                        {
+                            "attachment_type": "configuration",
+                            "content_type": "application/yaml",
+                            "content": "kind: Pod\n metadata:\n name: private-reg",
+                        },
+                        {
+                            "attachment_type": "configuration",
+                            "content_type": "application/yaml",
+                            "content": "foo: bar",
+                        },
+                    ],
+                }
+            ]
+        },
+    }
+
+    def get_documents(self) -> list[Document]:
+        """Returns the list of documents from the attachments."""
+        if not self.attachments:
+            return []
+        return [
+            Document(content=att.content, mime_type=att.content_type)
+            for att in self.attachments
+        ]
+
+    @model_validator(mode="after")
+    def validate_provider_and_model(self) -> Self:
+        """Perform validation on the provider and model."""
+        if self.model and not self.provider:
+            raise ValueError("Provider must be specified if model is specified")
+        if self.provider and not self.model:
+            raise ValueError("Model must be specified if provider is specified")
+        return self
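To see the request model in action: `get_documents()` maps each attachment onto the llama-stack `Document` shape, and the `model_validator` rejects a model without a provider (and vice versa). A short sketch using only what this file defines; the import path is an assumption:

```python
# Sketch only: the import path is an assumption.
from pydantic import ValidationError

from models.requests import Attachment, QueryRequest

req = QueryRequest(
    query="Tell me about Kubernetes",
    attachments=[
        Attachment(
            attachment_type="log",
            content_type="text/plain",
            content="this is attachment",
        )
    ],
)
# Document is a TypedDict, so this prints plain dicts:
# [{'content': 'this is attachment', 'mime_type': 'text/plain'}]
print(req.get_documents())

try:
    QueryRequest(query="hi", model="model-name")  # provider missing
except ValidationError as exc:
    print(exc)  # wraps "Provider must be specified if model is specified"
```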
