diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py
index 8e4c12f5..ffb7227a 100644
--- a/src/app/endpoints/rlsapi_v1.py
+++ b/src/app/endpoints/rlsapi_v1.py
@@ -5,22 +5,31 @@
 """
 
 import logging
-from typing import Annotated, Any
+from typing import Annotated, Any, cast
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
+from llama_stack_client import APIConnectionError  # type: ignore
+from llama_stack_client.types import UserMessage  # type: ignore
+from llama_stack_client.types.alpha.agents.turn import Turn
 
+import constants
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
 from authorization.middleware import authorize
+from client import AsyncLlamaStackClientHolder
+from configuration import configuration
 from models.config import Action
 from models.responses import (
     ForbiddenResponse,
+    ServiceUnavailableResponse,
     UnauthorizedResponse,
     UnprocessableEntityResponse,
 )
 from models.rlsapi.requests import RlsapiV1InferRequest
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
+from utils.endpoints import get_temp_agent
 from utils.suid import get_suid
+from utils.types import content_to_str
 
 logger = logging.getLogger(__name__)
 router = APIRouter(tags=["rlsapi-v1"])
@@ -33,9 +42,84 @@
     ),
     403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
     422: UnprocessableEntityResponse.openapi_response(),
+    503: ServiceUnavailableResponse.openapi_response(),
 }
 
 
+def _get_default_model_id() -> str:
+    """Get the default model ID from configuration.
+
+    Returns the model identifier in Llama Stack format (provider/model).
+
+    Returns:
+        The model identifier string.
+
+    Raises:
+        HTTPException: If no model can be determined from configuration.
+    """
+    if configuration.inference is None:
+        msg = "No inference configuration available"
+        logger.error(msg)
+        raise HTTPException(
+            status_code=503,
+            detail={"response": "Service configuration error", "cause": msg},
+        )
+
+    model_id = configuration.inference.default_model
+    provider_id = configuration.inference.default_provider
+
+    if model_id and provider_id:
+        return f"{provider_id}/{model_id}"
+
+    msg = "No default model configured for rlsapi v1 inference"
+    logger.error(msg)
+    raise HTTPException(
+        status_code=503,
+        detail={"response": "Service configuration error", "cause": msg},
+    )
+
+
+async def retrieve_simple_response(question: str) -> str:
+    """Retrieve a simple response from the LLM for a stateless query.
+
+    Creates a temporary agent, sends a single turn with the user's question,
+    and returns the LLM response text. No conversation persistence or tools.
+
+    Args:
+        question: The combined user input (question + context).
+
+    Returns:
+        The LLM-generated response text.
+
+    Raises:
+        APIConnectionError: If the Llama Stack service is unreachable.
+        HTTPException: 503 if no model is configured.
+ """ + client = AsyncLlamaStackClientHolder().get_client() + model_id = _get_default_model_id() + + logger.debug("Using model %s for rlsapi v1 inference", model_id) + + agent, session_id, _ = await get_temp_agent( + client, model_id, constants.DEFAULT_SYSTEM_PROMPT + ) + + response = await agent.create_turn( + messages=[UserMessage(role="user", content=question).model_dump()], + session_id=session_id, + stream=False, + ) + response = cast(Turn, response) + + if getattr(response, "output_message", None) is None: + return "" + + if getattr(response.output_message, "content", None) is None: + return "" + + return content_to_str(response.output_message.content) + + @router.post("/infer", responses=infer_responses) @authorize(Action.RLSAPI_V1_INFER) async def infer_endpoint( @@ -55,6 +139,9 @@ async def infer_endpoint( Returns: RlsapiV1InferResponse containing the generated response text and request ID. + + Raises: + HTTPException: 503 if the LLM service is unavailable. """ # Authentication enforced by get_auth_dependency(), authorization by @authorize decorator. _ = auth @@ -66,14 +153,28 @@ async def infer_endpoint( # Combine all input sources (question, stdin, attachments, terminal) input_source = infer_request.get_input_source() - logger.debug("Combined input source length: %d", len(input_source)) - - # NOTE(major): Placeholder until we wire up the LLM integration. - response_text = ( - "Inference endpoint is functional. " - "LLM integration will be added in a subsequent update." + logger.debug( + "Request %s: Combined input source length: %d", request_id, len(input_source) ) + try: + response_text = await retrieve_simple_response(input_source) + except APIConnectionError as e: + logger.error( + "Unable to connect to Llama Stack for request %s: %s", request_id, e + ) + response = ServiceUnavailableResponse( + backend_name="Llama Stack", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + + if not response_text: + logger.warning("Empty response from LLM for request %s", request_id) + response_text = constants.UNABLE_TO_PROCESS_RESPONSE + + logger.info("Completed rlsapi v1 /infer request %s", request_id) + return RlsapiV1InferResponse( data=RlsapiV1InferData( text=response_text, diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py index 596754e5..39e66c45 100644 --- a/tests/unit/app/endpoints/test_rlsapi_v1.py +++ b/tests/unit/app/endpoints/test_rlsapi_v1.py @@ -1,11 +1,22 @@ """Unit tests for the rlsapi v1 /infer REST API endpoint.""" +# pylint: disable=protected-access +# pylint: disable=unused-argument + import pytest +from fastapi import HTTPException, status +from llama_stack_client import APIConnectionError from pydantic import ValidationError from pytest_mock import MockerFixture -from app.endpoints.rlsapi_v1 import infer_endpoint +import constants +from app.endpoints.rlsapi_v1 import ( + _get_default_model_id, + infer_endpoint, + retrieve_simple_response, +) from authentication.interface import AuthTuple +from configuration import AppConfig from models.rlsapi.requests import ( RlsapiV1Attachment, RlsapiV1Context, @@ -14,28 +25,200 @@ RlsapiV1Terminal, ) from models.rlsapi.responses import RlsapiV1InferResponse +from tests.unit.conftest import AgentFixtures from tests.unit.utils.auth_helpers import mock_authorization_resolvers from utils.suid import check_suid MOCK_AUTH: AuthTuple = ("test_user_id", "test_user", True, "test_token") +@pytest.fixture(name="mock_configuration") +def mock_configuration_fixture( + 
+    mocker: MockerFixture, minimal_config: AppConfig
+) -> AppConfig:
+    """Extend minimal_config with inference defaults and patch it."""
+    minimal_config.inference.default_model = "gpt-4-turbo"
+    minimal_config.inference.default_provider = "openai"
+    mocker.patch("app.endpoints.rlsapi_v1.configuration", minimal_config)
+    return minimal_config
+
+
+@pytest.fixture(name="mock_llm_response")
+def mock_llm_response_fixture(
+    mocker: MockerFixture, prepare_agent_mocks: AgentFixtures
+) -> None:
+    """Mock the LLM integration for successful responses."""
+    mock_client, mock_agent = prepare_agent_mocks
+
+    # Create mock output message with content
+    mock_output_message = mocker.Mock()
+    mock_output_message.content = "This is a test LLM response."
+
+    # Create mock turn response
+    mock_turn = mocker.Mock()
+    mock_turn.output_message = mock_output_message
+    mock_turn.steps = []
+
+    # Use AsyncMock for async method
+    mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
+
+    # Mock get_temp_agent to return our mock agent
+    mocker.patch(
+        "app.endpoints.rlsapi_v1.get_temp_agent",
+        return_value=(mock_agent, "test_session_id", None),
+    )
+
+    # Mock the client holder
+    mock_client_holder = mocker.Mock()
+    mock_client_holder.get_client.return_value = mock_client
+    mocker.patch(
+        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
+        return_value=mock_client_holder,
+    )
+
+
+@pytest.fixture(name="mock_empty_llm_response")
+def mock_empty_llm_response_fixture(
+    mocker: MockerFixture, prepare_agent_mocks: AgentFixtures
+) -> None:
+    """Mock the LLM integration for empty responses (output_message=None)."""
+    mock_client, mock_agent = prepare_agent_mocks
+
+    # Create mock turn response with no output
+    mock_turn = mocker.Mock()
+    mock_turn.output_message = None
+    mock_turn.steps = []
+
+    # Use AsyncMock for async method
+    mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
+
+    # Mock get_temp_agent to return our mock agent
+    mocker.patch(
+        "app.endpoints.rlsapi_v1.get_temp_agent",
+        return_value=(mock_agent, "test_session_id", None),
+    )
+
+    # Mock the client holder
+    mock_client_holder = mocker.Mock()
+    mock_client_holder.get_client.return_value = mock_client
+    mocker.patch(
+        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
+        return_value=mock_client_holder,
+    )
+
+
+@pytest.fixture(name="mock_api_connection_error")
+def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
+    """Mock AsyncLlamaStackClientHolder to raise APIConnectionError."""
+    mock_client_holder = mocker.Mock()
+    mock_client_holder.get_client.side_effect = APIConnectionError(
+        request=mocker.Mock()
+    )
+    mocker.patch(
+        "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder",
+        return_value=mock_client_holder,
+    )
+
+
+# --- Test _get_default_model_id ---
+
+
+def test_get_default_model_id_success(mock_configuration: AppConfig) -> None:
+    """Test _get_default_model_id returns properly formatted model ID."""
+    model_id = _get_default_model_id()
+    assert model_id == "openai/gpt-4-turbo"
+
+
+@pytest.mark.parametrize(
+    ("config_setup", "expected_message"),
+    [
+        pytest.param(
+            "missing_model",
+            "No default model configured",
+            id="missing_model_config",
+        ),
+        pytest.param(
+            "none_inference",
+            "No inference configuration available",
+            id="none_inference_config",
+        ),
+    ],
+)
+def test_get_default_model_id_errors(
+    mocker: MockerFixture,
+    minimal_config: AppConfig,
+    config_setup: str,
+    expected_message: str,
+) -> None:
+    """Test _get_default_model_id raises HTTPException for invalid configs."""
+    if config_setup == "missing_model":
+        # Config exists but no model/provider defaults
+        mocker.patch("app.endpoints.rlsapi_v1.configuration", minimal_config)
+    else:
+        # inference is None
+        mock_config = mocker.Mock()
+        mock_config.inference = None
+        mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config)
+
+    with pytest.raises(HTTPException) as exc_info:
+        _get_default_model_id()
+
+    assert exc_info.value.status_code == 503
+    assert expected_message in str(exc_info.value.detail)
+
+
+# --- Test retrieve_simple_response ---
+
+
 @pytest.mark.asyncio
-async def test_infer_minimal_request(mocker: MockerFixture) -> None:
-    """Test /infer endpoint returns valid response with UUID request_id."""
+async def test_retrieve_simple_response_success(
+    mock_configuration: AppConfig, mock_llm_response: None
+) -> None:
+    """Test retrieve_simple_response returns LLM response text."""
+    response = await retrieve_simple_response("How do I list files?")
+    assert response == "This is a test LLM response."
+
+
+@pytest.mark.asyncio
+async def test_retrieve_simple_response_empty_output(
+    mock_configuration: AppConfig, mock_empty_llm_response: None
+) -> None:
+    """Test retrieve_simple_response handles empty LLM output."""
+    response = await retrieve_simple_response("Test question")
+    assert response == ""
+
+
+@pytest.mark.asyncio
+async def test_retrieve_simple_response_api_connection_error(
+    mock_configuration: AppConfig, mock_api_connection_error: None
+) -> None:
+    """Test retrieve_simple_response propagates APIConnectionError."""
+    with pytest.raises(APIConnectionError):
+        await retrieve_simple_response("Test question")
+
+
+# --- Test infer_endpoint ---
+
+
+@pytest.mark.asyncio
+async def test_infer_minimal_request(
+    mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
+) -> None:
+    """Test /infer endpoint returns valid response with LLM text."""
     mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(question="How do I list files?")
 
     response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
 
     assert isinstance(response, RlsapiV1InferResponse)
-    assert response.data.text
-    # Verify request_id is valid SUID
+    assert response.data.text == "This is a test LLM response."
     assert check_suid(response.data.request_id)
 
 
 @pytest.mark.asyncio
-async def test_infer_full_context_request(mocker: MockerFixture) -> None:
+async def test_infer_full_context_request(
+    mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
+) -> None:
     """Test /infer endpoint handles full context (stdin, attachments, terminal)."""
     mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(
@@ -56,7 +239,9 @@ async def test_infer_full_context_request(mocker: MockerFixture) -> None:
 
 
 @pytest.mark.asyncio
-async def test_infer_generates_unique_request_ids(mocker: MockerFixture) -> None:
+async def test_infer_generates_unique_request_ids(
+    mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
+) -> None:
     """Test that each /infer call generates a unique request_id."""
     mock_authorization_resolvers(mocker)
     request = RlsapiV1InferRequest(question="How do I list files?")
@@ -67,8 +252,48 @@ async def test_infer_generates_unique_request_ids(mocker: MockerFixture) -> None
     assert response1.data.request_id != response2.data.request_id
 
 
+@pytest.mark.asyncio
+async def test_infer_api_connection_error_returns_503(
+    mocker: MockerFixture,
+    mock_configuration: AppConfig,
+    mock_api_connection_error: None,
+) -> None:
+    """Test /infer endpoint returns 503 when LLM service is unavailable."""
+    mock_authorization_resolvers(mocker)
+    request = RlsapiV1InferRequest(question="Test question")
+
+    with pytest.raises(HTTPException) as exc_info:
+        await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
+
+    assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
+
+
+@pytest.mark.asyncio
+async def test_infer_empty_llm_response_returns_fallback(
+    mocker: MockerFixture,
+    mock_configuration: AppConfig,
+    mock_empty_llm_response: None,
+) -> None:
+    """Test /infer endpoint returns fallback text when LLM returns empty response."""
+    mock_authorization_resolvers(mocker)
+    request = RlsapiV1InferRequest(question="Test question")
+
+    response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
+
+    assert response.data.text == constants.UNABLE_TO_PROCESS_RESPONSE
+
+
+# --- Test request validation ---
+
+
 @pytest.mark.parametrize("invalid_question", ["", " ", "\t\n"])
 def test_infer_rejects_invalid_question(invalid_question: str) -> None:
     """Test that empty or whitespace-only questions are rejected."""
     with pytest.raises(ValidationError):
         RlsapiV1InferRequest(question=invalid_question)
+
+
+def test_infer_request_question_is_stripped() -> None:
    """Test that question whitespace is stripped during validation."""
+    request = RlsapiV1InferRequest(question=" How do I list files? ")
+    assert request.question == "How do I list files?"
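
For reviewers who want to poke at the new endpoint by hand, below is a minimal client sketch, not part of the change itself. The base URL, router prefix, and bearer-token header are assumptions (they depend on how this router is mounted and which auth dependency is configured in a given deployment); the response shape {"data": {"text": ..., "request_id": ...}} follows RlsapiV1InferResponse / RlsapiV1InferData from the diff, and 503 is the status this PR returns when Llama Stack is unreachable or no default model is configured.

"""Minimal sketch for manually exercising POST /infer (assumptions noted above)."""

import httpx

BASE_URL = "http://localhost:8080/rlsapi/v1"  # assumed host and router prefix
TOKEN = "replace-me"  # assumed bearer-token auth; adjust to your deployment


def ask(question: str) -> None:
    """Send a question and print the request ID and generated text."""
    resp = httpx.post(
        f"{BASE_URL}/infer",
        json={"question": question},
        headers={"Authorization": f"Bearer {TOKEN}"},
        timeout=60.0,
    )
    if resp.status_code == 503:
        # Returned when the Llama Stack backend is unavailable or unconfigured.
        print("LLM backend unavailable:", resp.json())
        return
    resp.raise_for_status()
    data = resp.json()["data"]
    print(data["request_id"], data["text"])


if __name__ == "__main__":
    ask("How do I list files?")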