60 changes: 54 additions & 6 deletions src/app/endpoints/rlsapi_v1.py
@@ -9,9 +9,10 @@

from fastapi import APIRouter, Depends, HTTPException
from llama_stack.apis.agents.openai_responses import OpenAIResponseObject
from llama_stack_client import APIConnectionError
from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError

import constants
import metrics
from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
@@ -20,11 +21,13 @@
from models.config import Action
from models.responses import (
ForbiddenResponse,
InternalServerErrorResponse,
QuotaExceededResponse,
ServiceUnavailableResponse,
UnauthorizedResponse,
UnprocessableEntityResponse,
)
from models.rlsapi.requests import RlsapiV1InferRequest
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
from utils.responses import extract_text_from_response_output_item
from utils.suid import get_suid
@@ -40,10 +43,41 @@
),
403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
422: UnprocessableEntityResponse.openapi_response(),
429: QuotaExceededResponse.openapi_response(),
500: InternalServerErrorResponse.openapi_response(examples=["generic"]),
503: ServiceUnavailableResponse.openapi_response(),
}


def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
"""Build LLM instructions incorporating system context when available.

Enhances the default system prompt with RHEL system information to provide
the LLM with relevant context about the user's environment.

Args:
systeminfo: System information from the client (OS, version, arch).

Returns:
Instructions string for the LLM, with system context if available.
"""
base_prompt = constants.DEFAULT_SYSTEM_PROMPT

context_parts = []
if systeminfo.os:
context_parts.append(f"OS: {systeminfo.os}")
if systeminfo.version:
context_parts.append(f"Version: {systeminfo.version}")
if systeminfo.arch:
context_parts.append(f"Architecture: {systeminfo.arch}")

if not context_parts:
return base_prompt

system_context = ", ".join(context_parts)
return f"{base_prompt}\n\nUser's system: {system_context}"


def _get_default_model_id() -> str:
"""Get the default model ID from configuration.

@@ -77,14 +111,15 @@ def _get_default_model_id() -> str:
)


async def retrieve_simple_response(question: str) -> str:
async def retrieve_simple_response(question: str, instructions: str) -> str:
"""Retrieve a simple response from the LLM for a stateless query.

Uses the Responses API for simple stateless inference, consistent with
other endpoints (query_v2, streaming_query_v2).

Args:
question: The combined user input (question + context).
instructions: System instructions for the LLM.

Returns:
The LLM-generated response text.
@@ -101,7 +136,7 @@ async def retrieve_simple_response(question: str) -> str:
response = await client.responses.create(
input=question,
model=model_id,
instructions=constants.DEFAULT_SYSTEM_PROMPT,
instructions=instructions,
stream=False,
store=False,
)
@@ -144,15 +179,16 @@ async def infer_endpoint(

logger.info("Processing rlsapi v1 /infer request %s", request_id)

# Combine all input sources (question, stdin, attachments, terminal)
input_source = infer_request.get_input_source()
instructions = _build_instructions(infer_request.context.systeminfo)
logger.debug(
"Request %s: Combined input source length: %d", request_id, len(input_source)
)

try:
response_text = await retrieve_simple_response(input_source)
response_text = await retrieve_simple_response(input_source, instructions)
except APIConnectionError as e:
metrics.llm_calls_failures_total.inc()
logger.error(
"Unable to connect to Llama Stack for request %s: %s", request_id, e
)
@@ -161,6 +197,18 @@
cause=str(e),
)
raise HTTPException(**response.model_dump()) from e
except RateLimitError as e:
metrics.llm_calls_failures_total.inc()
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
response = QuotaExceededResponse(
response="The quota has been exceeded", cause=str(e)
)
raise HTTPException(**response.model_dump()) from e
except APIStatusError as e:
metrics.llm_calls_failures_total.inc()
logger.exception("API error for request %s: %s", request_id, e)
response = InternalServerErrorResponse.generic()
raise HTTPException(**response.model_dump()) from e

if not response_text:
logger.warning("Empty response from LLM for request %s", request_id)
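For illustration only (not part of the diff), here is what the new `_build_instructions` helper returns for a populated `RlsapiV1SystemInfo`, based on the logic added above; the prompt text itself comes from `constants.DEFAULT_SYSTEM_PROMPT`:

```python
# Sketch: exercises _build_instructions as added in this file.
from app.endpoints.rlsapi_v1 import _build_instructions
from models.rlsapi.requests import RlsapiV1SystemInfo

systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64")
print(_build_instructions(systeminfo))
# -> DEFAULT_SYSTEM_PROMPT followed by:
#    "\n\nUser's system: OS: RHEL, Version: 9.3, Architecture: x86_64"
# With all three fields empty, the helper returns DEFAULT_SYSTEM_PROMPT unchanged.
```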
4 changes: 2 additions & 2 deletions src/models/rlsapi/requests.py
@@ -126,7 +126,7 @@ class RlsapiV1InferRequest(ConfigurationBase):
Attributes:
question: User question string.
context: Context with system info, terminal output, etc. (defaults provided).
skip_rag: Whether to skip RAG retrieval (default False).
skip_rag: Reserved for future use. RAG retrieval is not yet implemented.

Example:
```python
@@ -152,7 +152,7 @@ class RlsapiV1InferRequest(ConfigurationBase):
)
skip_rag: bool = Field(
default=False,
description="Whether to skip RAG retrieval",
description="Reserved for future use. RAG retrieval is not yet implemented.",
examples=[False, True],
)

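For context (not part of the diff): clients can still send `skip_rag`, it just has no effect yet. A minimal request might look like the sketch below, which only uses fields shown in this model:

```python
# Sketch: skip_rag is accepted for forward compatibility but is currently
# ignored, since RAG retrieval is not implemented yet.
from models.rlsapi.requests import RlsapiV1InferRequest

request = RlsapiV1InferRequest(
    question="How do I list files?",
    skip_rag=True,  # no-op for now
)
```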
81 changes: 67 additions & 14 deletions tests/unit/app/endpoints/test_rlsapi_v1.py
@@ -13,6 +13,7 @@

import constants
from app.endpoints.rlsapi_v1 import (
_build_instructions,
_get_default_model_id,
infer_endpoint,
retrieve_simple_response,
@@ -30,7 +31,7 @@
from tests.unit.utils.auth_helpers import mock_authorization_resolvers
from utils.suid import check_suid

MOCK_AUTH: AuthTuple = ("test_user_id", "test_user", True, "test_token")
MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token")


def _setup_responses_mock(mocker: MockerFixture, create_behavior: Any) -> None:
@@ -87,6 +88,12 @@ def mock_empty_llm_response_fixture(mocker: MockerFixture) -> None:
_setup_responses_mock(mocker, mocker.AsyncMock(return_value=mock_response))


@pytest.fixture(name="mock_auth_resolvers")
def mock_auth_resolvers_fixture(mocker: MockerFixture) -> None:
"""Mock authorization resolvers for endpoint tests."""
mock_authorization_resolvers(mocker)


@pytest.fixture(name="mock_api_connection_error")
def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
"""Mock responses.create() to raise APIConnectionError."""
@@ -96,6 +103,47 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
)


# --- Test _build_instructions ---


@pytest.mark.parametrize(
("systeminfo_kwargs", "expected_contains", "expected_not_contains"),
[
pytest.param(
{"os": "RHEL", "version": "9.3", "arch": "x86_64"},
["OS: RHEL", "Version: 9.3", "Architecture: x86_64"],
[],
id="full_systeminfo",
),
pytest.param(
{"os": "RHEL", "version": "", "arch": ""},
["OS: RHEL"],
["Version:", "Architecture:"],
id="partial_systeminfo",
),
pytest.param(
{},
[constants.DEFAULT_SYSTEM_PROMPT],
["OS:", "Version:", "Architecture:"],
id="empty_systeminfo",
),
],
)
def test_build_instructions(
systeminfo_kwargs: dict[str, str],
expected_contains: list[str],
expected_not_contains: list[str],
) -> None:
"""Test _build_instructions with various system info combinations."""
systeminfo = RlsapiV1SystemInfo(**systeminfo_kwargs)
result = _build_instructions(systeminfo)

for expected in expected_contains:
assert expected in result
for not_expected in expected_not_contains:
assert not_expected not in result


# --- Test _get_default_model_id ---


@@ -151,7 +199,9 @@ async def test_retrieve_simple_response_success(
mock_configuration: AppConfig, mock_llm_response: None
) -> None:
"""Test retrieve_simple_response returns LLM response text."""
response = await retrieve_simple_response("How do I list files?")
response = await retrieve_simple_response(
"How do I list files?", constants.DEFAULT_SYSTEM_PROMPT
)
assert response == "This is a test LLM response."


Expand All @@ -160,7 +210,9 @@ async def test_retrieve_simple_response_empty_output(
mock_configuration: AppConfig, mock_empty_llm_response: None
) -> None:
"""Test retrieve_simple_response handles empty LLM output."""
response = await retrieve_simple_response("Test question")
response = await retrieve_simple_response(
"Test question", constants.DEFAULT_SYSTEM_PROMPT
)
assert response == ""


Expand All @@ -170,18 +222,19 @@ async def test_retrieve_simple_response_api_connection_error(
) -> None:
"""Test retrieve_simple_response propagates APIConnectionError."""
with pytest.raises(APIConnectionError):
await retrieve_simple_response("Test question")
await retrieve_simple_response("Test question", constants.DEFAULT_SYSTEM_PROMPT)


# --- Test infer_endpoint ---


@pytest.mark.asyncio
async def test_infer_minimal_request(
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
mock_configuration: AppConfig,
mock_llm_response: None,
mock_auth_resolvers: None,
) -> None:
"""Test /infer endpoint returns valid response with LLM text."""
mock_authorization_resolvers(mocker)
request = RlsapiV1InferRequest(question="How do I list files?")

response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -194,10 +247,11 @@ async def test_infer_minimal_request(

@pytest.mark.asyncio
async def test_infer_full_context_request(
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
mock_configuration: AppConfig,
mock_llm_response: None,
mock_auth_resolvers: None,
) -> None:
"""Test /infer endpoint handles full context (stdin, attachments, terminal)."""
mock_authorization_resolvers(mocker)
request = RlsapiV1InferRequest(
question="Why did this command fail?",
context=RlsapiV1Context(
@@ -217,10 +271,11 @@ async def test_infer_full_context_request(

@pytest.mark.asyncio
async def test_infer_generates_unique_request_ids(
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
mock_configuration: AppConfig,
mock_llm_response: None,
mock_auth_resolvers: None,
) -> None:
"""Test that each /infer call generates a unique request_id."""
mock_authorization_resolvers(mocker)
request = RlsapiV1InferRequest(question="How do I list files?")

response1 = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -231,12 +286,11 @@ async def test_infer_generates_unique_request_ids(

@pytest.mark.asyncio
async def test_infer_api_connection_error_returns_503(
mocker: MockerFixture,
mock_configuration: AppConfig,
mock_api_connection_error: None,
mock_auth_resolvers: None,
) -> None:
"""Test /infer endpoint returns 503 when LLM service is unavailable."""
mock_authorization_resolvers(mocker)
request = RlsapiV1InferRequest(question="Test question")

with pytest.raises(HTTPException) as exc_info:
@@ -247,12 +301,11 @@ async def test_infer_api_connection_error_returns_503(

@pytest.mark.asyncio
async def test_infer_empty_llm_response_returns_fallback(
mocker: MockerFixture,
mock_configuration: AppConfig,
mock_empty_llm_response: None,
mock_auth_resolvers: None,
) -> None:
"""Test /infer endpoint returns fallback text when LLM returns empty response."""
mock_authorization_resolvers(mocker)
request = RlsapiV1InferRequest(question="Test question")

response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
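Note: the new `RateLimitError` and `APIStatusError` branches in `infer_endpoint` are not covered by the tests in this diff. A follow-up test for the 429 path could mirror the existing 503 test in this module, assuming a `mock_rate_limit_error` fixture (hypothetical, not in this PR) that makes `responses.create()` raise `RateLimitError`:

```python
# Hypothetical follow-up test, not part of this PR: exercises the
# RateLimitError -> QuotaExceededResponse (HTTP 429) branch added to infer_endpoint.
@pytest.mark.asyncio
async def test_infer_rate_limit_error_returns_429(
    mock_configuration: AppConfig,
    mock_rate_limit_error: None,  # assumed fixture, analogous to mock_api_connection_error
    mock_auth_resolvers: None,
) -> None:
    """Test /infer endpoint returns 429 when the LLM quota is exceeded."""
    request = RlsapiV1InferRequest(question="Test question")

    with pytest.raises(HTTPException) as exc_info:
        await infer_endpoint(infer_request=request, auth=MOCK_AUTH)

    assert exc_info.value.status_code == 429
```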