diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py
index 72fdf382..a7d8a4f1 100644
--- a/src/app/endpoints/rlsapi_v1.py
+++ b/src/app/endpoints/rlsapi_v1.py
@@ -9,9 +9,10 @@
 from fastapi import APIRouter, Depends, HTTPException
 from llama_stack.apis.agents.openai_responses import OpenAIResponseObject
-from llama_stack_client import APIConnectionError
+from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError

 import constants
+import metrics
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
 from authorization.middleware import authorize
@@ -20,11 +21,13 @@
 from models.config import Action
 from models.responses import (
     ForbiddenResponse,
+    InternalServerErrorResponse,
+    QuotaExceededResponse,
     ServiceUnavailableResponse,
     UnauthorizedResponse,
     UnprocessableEntityResponse,
 )
-from models.rlsapi.requests import RlsapiV1InferRequest
+from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
 from utils.responses import extract_text_from_response_output_item
 from utils.suid import get_suid
@@ -40,10 +43,41 @@
     ),
     403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
     422: UnprocessableEntityResponse.openapi_response(),
+    429: QuotaExceededResponse.openapi_response(),
+    500: InternalServerErrorResponse.openapi_response(examples=["generic"]),
     503: ServiceUnavailableResponse.openapi_response(),
 }


+def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
+    """Build LLM instructions incorporating system context when available.
+
+    Enhances the default system prompt with RHEL system information to provide
+    the LLM with relevant context about the user's environment.
+
+    Args:
+        systeminfo: System information from the client (OS, version, arch).
+
+    Returns:
+        Instructions string for the LLM, with system context if available.
+    """
+    base_prompt = constants.DEFAULT_SYSTEM_PROMPT
+
+    context_parts = []
+    if systeminfo.os:
+        context_parts.append(f"OS: {systeminfo.os}")
+    if systeminfo.version:
+        context_parts.append(f"Version: {systeminfo.version}")
+    if systeminfo.arch:
+        context_parts.append(f"Architecture: {systeminfo.arch}")
+
+    if not context_parts:
+        return base_prompt
+
+    system_context = ", ".join(context_parts)
+    return f"{base_prompt}\n\nUser's system: {system_context}"
+
+
 def _get_default_model_id() -> str:
     """Get the default model ID from configuration.

@@ -77,7 +111,7 @@ def _get_default_model_id() -> str:
     )


-async def retrieve_simple_response(question: str) -> str:
+async def retrieve_simple_response(question: str, instructions: str) -> str:
     """Retrieve a simple response from the LLM for a stateless query.

     Uses the Responses API for simple stateless inference, consistent with
@@ -85,6 +119,7 @@ async def retrieve_simple_response(question: str) -> str:

     Args:
         question: The combined user input (question + context).
+        instructions: System instructions for the LLM.

     Returns:
         The LLM-generated response text.
@@ -101,7 +136,7 @@ async def retrieve_simple_response(question: str) -> str:
         response = await client.responses.create(
             input=question,
             model=model_id,
-            instructions=constants.DEFAULT_SYSTEM_PROMPT,
+            instructions=instructions,
             stream=False,
             store=False,
         )
@@ -144,15 +179,16 @@ async def infer_endpoint(
     logger.info("Processing rlsapi v1 /infer request %s", request_id)

-    # Combine all input sources (question, stdin, attachments, terminal)
     input_source = infer_request.get_input_source()
+    instructions = _build_instructions(infer_request.context.systeminfo)
     logger.debug(
         "Request %s: Combined input source length: %d", request_id, len(input_source)
     )

     try:
-        response_text = await retrieve_simple_response(input_source)
+        response_text = await retrieve_simple_response(input_source, instructions)
     except APIConnectionError as e:
+        metrics.llm_calls_failures_total.inc()
         logger.error(
             "Unable to connect to Llama Stack for request %s: %s", request_id, e
         )
@@ -161,6 +197,18 @@ async def infer_endpoint(
             cause=str(e),
         )
         raise HTTPException(**response.model_dump()) from e
+    except RateLimitError as e:
+        metrics.llm_calls_failures_total.inc()
+        logger.error("Rate limit exceeded for request %s: %s", request_id, e)
+        response = QuotaExceededResponse(
+            response="The quota has been exceeded", cause=str(e)
+        )
+        raise HTTPException(**response.model_dump()) from e
+    except APIStatusError as e:
+        metrics.llm_calls_failures_total.inc()
+        logger.exception("API error for request %s: %s", request_id, e)
+        response = InternalServerErrorResponse.generic()
+        raise HTTPException(**response.model_dump()) from e

     if not response_text:
         logger.warning("Empty response from LLM for request %s", request_id)
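
An aside for reviewers, not part of the patch: with the _build_instructions helper above, populated system info appends a "User's system: ..." suffix to the default prompt, while empty system info returns constants.DEFAULT_SYSTEM_PROMPT unchanged. A minimal sketch; the RlsapiV1SystemInfo keyword arguments mirror the new unit tests further down.

import constants
from app.endpoints.rlsapi_v1 import _build_instructions
from models.rlsapi.requests import RlsapiV1SystemInfo

# Populated system info: the default prompt gains a system-context suffix.
full = _build_instructions(RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64"))
assert full.endswith("User's system: OS: RHEL, Version: 9.3, Architecture: x86_64")

# Empty system info: the default prompt is returned as-is.
assert _build_instructions(RlsapiV1SystemInfo()) == constants.DEFAULT_SYSTEM_PROMPT
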
diff --git a/src/models/rlsapi/requests.py b/src/models/rlsapi/requests.py
index 5bb05b78..fc0f7724 100644
--- a/src/models/rlsapi/requests.py
+++ b/src/models/rlsapi/requests.py
@@ -126,7 +126,7 @@ class RlsapiV1InferRequest(ConfigurationBase):
     Attributes:
         question: User question string.
         context: Context with system info, terminal output, etc. (defaults provided).
-        skip_rag: Whether to skip RAG retrieval (default False).
+        skip_rag: Reserved for future use. RAG retrieval is not yet implemented.

     Example:
         ```python
@@ -152,7 +152,7 @@ class RlsapiV1InferRequest(ConfigurationBase):
     )
     skip_rag: bool = Field(
         default=False,
-        description="Whether to skip RAG retrieval",
+        description="Reserved for future use. RAG retrieval is not yet implemented.",
         examples=[False, True],
     )
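
Another aside, not part of the patch: the systeminfo that _build_instructions consumes travels inside the request's context field defined in this model module. A hedged sketch of a request carrying it, assuming RlsapiV1Context and RlsapiV1SystemInfo are exported by models.rlsapi.requests and accept these keyword arguments; the endpoint only reads request.context.systeminfo and leaves skip_rag unused.

from models.rlsapi.requests import (
    RlsapiV1Context,
    RlsapiV1InferRequest,
    RlsapiV1SystemInfo,
)

# Hypothetical payload: question, context and context.systeminfo come from this
# patch; the exact constructor shapes of the context models are assumptions.
request = RlsapiV1InferRequest(
    question="Why did this command fail?",
    context=RlsapiV1Context(
        systeminfo=RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64"),
    ),
)

assert request.skip_rag is False  # documented default above; currently unused
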
diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py
index 87351398..1bfae922 100644
--- a/tests/unit/app/endpoints/test_rlsapi_v1.py
+++ b/tests/unit/app/endpoints/test_rlsapi_v1.py
@@ -13,6 +13,7 @@
 import constants
 from app.endpoints.rlsapi_v1 import (
+    _build_instructions,
     _get_default_model_id,
     infer_endpoint,
     retrieve_simple_response,
@@ -30,7 +31,7 @@
 from tests.unit.utils.auth_helpers import mock_authorization_resolvers
 from utils.suid import check_suid

-MOCK_AUTH: AuthTuple = ("test_user_id", "test_user", True, "test_token")
+MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token")


 def _setup_responses_mock(mocker: MockerFixture, create_behavior: Any) -> None:
@@ -87,6 +88,12 @@ def mock_empty_llm_response_fixture(mocker: MockerFixture) -> None:
     _setup_responses_mock(mocker, mocker.AsyncMock(return_value=mock_response))


+@pytest.fixture(name="mock_auth_resolvers")
+def mock_auth_resolvers_fixture(mocker: MockerFixture) -> None:
+    """Mock authorization resolvers for endpoint tests."""
+    mock_authorization_resolvers(mocker)
+
+
 @pytest.fixture(name="mock_api_connection_error")
 def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
     """Mock responses.create() to raise APIConnectionError."""
     _setup_responses_mock(
@@ -96,6 +103,47 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
     )


+# --- Test _build_instructions ---
+
+
+@pytest.mark.parametrize(
+    ("systeminfo_kwargs", "expected_contains", "expected_not_contains"),
+    [
+        pytest.param(
+            {"os": "RHEL", "version": "9.3", "arch": "x86_64"},
+            ["OS: RHEL", "Version: 9.3", "Architecture: x86_64"],
+            [],
+            id="full_systeminfo",
+        ),
+        pytest.param(
+            {"os": "RHEL", "version": "", "arch": ""},
+            ["OS: RHEL"],
+            ["Version:", "Architecture:"],
+            id="partial_systeminfo",
+        ),
+        pytest.param(
+            {},
+            [constants.DEFAULT_SYSTEM_PROMPT],
+            ["OS:", "Version:", "Architecture:"],
+            id="empty_systeminfo",
+        ),
+    ],
+)
+def test_build_instructions(
+    systeminfo_kwargs: dict[str, str],
+    expected_contains: list[str],
+    expected_not_contains: list[str],
+) -> None:
+    """Test _build_instructions with various system info combinations."""
+    systeminfo = RlsapiV1SystemInfo(**systeminfo_kwargs)
+    result = _build_instructions(systeminfo)
+
+    for expected in expected_contains:
+        assert expected in result
+    for not_expected in expected_not_contains:
+        assert not_expected not in result
+
+
 # --- Test _get_default_model_id ---


@@ -151,7 +199,9 @@ async def test_retrieve_simple_response_success(
     mock_configuration: AppConfig, mock_llm_response: None
 ) -> None:
     """Test retrieve_simple_response returns LLM response text."""
-    response = await retrieve_simple_response("How do I list files?")
+    response = await retrieve_simple_response(
+        "How do I list files?", constants.DEFAULT_SYSTEM_PROMPT
+    )

     assert response == "This is a test LLM response."
@@ -160,7 +210,9 @@ async def test_retrieve_simple_response_empty_output(
     mock_configuration: AppConfig, mock_empty_llm_response: None
 ) -> None:
     """Test retrieve_simple_response handles empty LLM output."""
-    response = await retrieve_simple_response("Test question")
+    response = await retrieve_simple_response(
+        "Test question", constants.DEFAULT_SYSTEM_PROMPT
+    )

     assert response == ""

@@ -170,7 +222,7 @@ async def test_retrieve_simple_response_api_connection_error(
 ) -> None:
     """Test retrieve_simple_response propagates APIConnectionError."""
     with pytest.raises(APIConnectionError):
-        await retrieve_simple_response("Test question")
+        await retrieve_simple_response("Test question", constants.DEFAULT_SYSTEM_PROMPT)


 # --- Test infer_endpoint ---


@@ -178,10 +230,11 @@ async def test_retrieve_simple_response_api_connection_error(
 @pytest.mark.asyncio
 async def test_infer_minimal_request(
-    mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
+    mock_configuration: AppConfig,
+    mock_llm_response: None,
+    mock_auth_resolvers: None,
 ) -> None:
     """Test /infer endpoint returns valid response with LLM text."""
-    mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(question="How do I list files?")

     response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -194,10 +247,11 @@ async def test_infer_full_context_request(
-    mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
+    mock_configuration: AppConfig,
+    mock_llm_response: None,
+    mock_auth_resolvers: None,
 ) -> None:
     """Test /infer endpoint handles full context (stdin, attachments, terminal)."""
-    mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(
         question="Why did this command fail?",
         context=RlsapiV1Context(
@@ -217,10 +271,11 @@ async def test_infer_generates_unique_request_ids(
 @pytest.mark.asyncio
 async def test_infer_generates_unique_request_ids(
-    mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
+    mock_configuration: AppConfig,
+    mock_llm_response: None,
+    mock_auth_resolvers: None,
 ) -> None:
     """Test that each /infer call generates a unique request_id."""
-    mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(question="How do I list files?")

     response1 = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -231,12 +286,11 @@ async def test_infer_api_connection_error_returns_503(
-    mocker: MockerFixture,
     mock_configuration: AppConfig,
     mock_api_connection_error: None,
+    mock_auth_resolvers: None,
 ) -> None:
     """Test /infer endpoint returns 503 when LLM service is unavailable."""
-    mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(question="Test question")

     with pytest.raises(HTTPException) as exc_info:
@@ -247,12 +301,11 @@ async def test_infer_empty_llm_response_returns_fallback(
-    mocker: MockerFixture,
     mock_configuration: AppConfig,
     mock_empty_llm_response: None,
+    mock_auth_resolvers: None,
 ) -> None:
     """Test /infer endpoint returns fallback text when LLM returns empty response."""
-    mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(question="Test question")

     response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
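
A final aside, not part of the patch: the new RateLimitError and APIStatusError branches in infer_endpoint are not exercised by the tests shown above. Below is a sketch of how the 429 path could be covered inside this same test module, reusing its MOCK_AUTH constant, _setup_responses_mock helper, and the mock_configuration / mock_auth_resolvers fixtures, plus a "from llama_stack_client import RateLimitError" at the top of the file. The RateLimitError constructor arguments are an assumption (an OpenAI-style message/response/body signature), not something this diff confirms.

@pytest.mark.asyncio
async def test_infer_rate_limit_returns_429(
    mocker: MockerFixture,
    mock_configuration: AppConfig,
    mock_auth_resolvers: None,
) -> None:
    """Sketch: /infer should map RateLimitError from the client to HTTP 429."""
    # Assumed constructor shape; adjust to the real RateLimitError signature.
    error = RateLimitError(
        "rate limited", response=mocker.Mock(status_code=429), body=None
    )
    _setup_responses_mock(mocker, mocker.AsyncMock(side_effect=error))
    request = RlsapiV1InferRequest(question="Test question")

    with pytest.raises(HTTPException) as exc_info:
        await infer_endpoint(infer_request=request, auth=MOCK_AUTH)

    assert exc_info.value.status_code == 429
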