diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index a10e4770..ecbf5bff 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -31,6 +31,7 @@ ) from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse +from app.endpoints.query_v2 import get_mcp_tools from observability import InferenceEventData, build_inference_event, send_splunk_event from utils.responses import extract_text_from_response_output_item from utils.suid import get_suid @@ -142,7 +143,9 @@ def _get_default_model_id() -> str: ) -async def retrieve_simple_response(question: str, instructions: str) -> str: +async def retrieve_simple_response( + question: str, instructions: str, tools: list | None = None +) -> str: """Retrieve a simple response from the LLM for a stateless query. Uses the Responses API for simple stateless inference, consistent with @@ -151,6 +154,7 @@ async def retrieve_simple_response(question: str, instructions: str) -> str: Args: question: The combined user input (question + context). instructions: System instructions for the LLM. + tools: Optional list of MCP tool definitions for the LLM. Returns: The LLM-generated response text. @@ -168,6 +172,7 @@ async def retrieve_simple_response(question: str, instructions: str) -> str: input=question, model=model_id, instructions=instructions, + tools=tools or [], stream=False, store=False, ) @@ -255,13 +260,16 @@ async def infer_endpoint( input_source = infer_request.get_input_source() instructions = _build_instructions(infer_request.context.systeminfo) + mcp_tools = get_mcp_tools(configuration.mcp_servers) logger.debug( "Request %s: Combined input source length: %d", request_id, len(input_source) ) start_time = time.monotonic() try: - response_text = await retrieve_simple_response(input_source, instructions) + response_text = await retrieve_simple_response( + input_source, instructions, tools=mcp_tools + ) inference_time = time.monotonic() - start_time except APIConnectionError as e: inference_time = time.monotonic() - start_time diff --git a/tests/integration/endpoints/test_rlsapi_v1_integration.py b/tests/integration/endpoints/test_rlsapi_v1_integration.py index 746b0299..8f7586d3 100644 --- a/tests/integration/endpoints/test_rlsapi_v1_integration.py +++ b/tests/integration/endpoints/test_rlsapi_v1_integration.py @@ -349,6 +349,104 @@ async def test_rlsapi_v1_infer_input_source_combination( assert expected in input_content +# ========================================== +# MCP Tools Passthrough Tests +# ========================================== + + +@pytest.mark.asyncio +async def test_rlsapi_v1_infer_no_mcp_servers_passes_empty_tools( + rlsapi_config: AppConfig, + mock_authorization: None, + test_auth: AuthTuple, + mocker: MockerFixture, +) -> None: + """Regression: no MCP servers configured passes empty tools list. + + When mcp_servers is empty (the default), get_mcp_tools returns [], + and responses.create should receive tools=[]. + """ + _ = rlsapi_config + + mock_response = mocker.Mock() + mock_response.output = [_create_mock_response_output(mocker, "response text")] + + mock_responses = mocker.Mock() + mock_responses.create = mocker.AsyncMock(return_value=mock_response) + + mock_client = mocker.Mock() + mock_client.responses = mock_responses + + mock_holder_class = mocker.patch( + "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder" + ) + mock_holder_class.return_value.get_client.return_value = mock_client + + mocker.patch( + "app.endpoints.rlsapi_v1.get_mcp_tools", + return_value=[], + ) + + await infer_endpoint( + infer_request=RlsapiV1InferRequest(question="How do I list files?"), + request=_create_mock_request(mocker), + background_tasks=_create_mock_background_tasks(mocker), + auth=test_auth, + ) + + call_kwargs = mock_responses.create.call_args.kwargs + assert call_kwargs["tools"] == [] + + +@pytest.mark.asyncio +async def test_rlsapi_v1_infer_mcp_tools_passed_to_llm( + rlsapi_config: AppConfig, + mock_authorization: None, + test_auth: AuthTuple, + mocker: MockerFixture, +) -> None: + """Test that MCP tool definitions are forwarded to responses.create().""" + _ = rlsapi_config + + mock_response = mocker.Mock() + mock_response.output = [_create_mock_response_output(mocker, "enriched response")] + + mock_responses = mocker.Mock() + mock_responses.create = mocker.AsyncMock(return_value=mock_response) + + mock_client = mocker.Mock() + mock_client.responses = mock_responses + + mock_holder_class = mocker.patch( + "app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder" + ) + mock_holder_class.return_value.get_client.return_value = mock_client + + mcp_tools = [ + { + "type": "mcp", + "server_label": "rag-knowledge-base", + "server_url": "http://rag-server:8080/sse", + "require_approval": "never", + } + ] + mocker.patch( + "app.endpoints.rlsapi_v1.get_mcp_tools", + return_value=mcp_tools, + ) + + response = await infer_endpoint( + infer_request=RlsapiV1InferRequest(question="How do I configure SELinux?"), + request=_create_mock_request(mocker), + background_tasks=_create_mock_background_tasks(mocker), + auth=test_auth, + ) + + call_kwargs = mock_responses.create.call_args.kwargs + assert call_kwargs["tools"] == mcp_tools + assert response.data.text == "enriched response" + + # ========================================== # Skip RAG Tests # ========================================== diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py index 54aa95c9..1b00b774 100644 --- a/tests/unit/app/endpoints/test_rlsapi_v1.py +++ b/tests/unit/app/endpoints/test_rlsapi_v1.py @@ -518,3 +518,92 @@ def test_infer_request_question_is_stripped() -> None: """Test that question whitespace is stripped during validation.""" request = RlsapiV1InferRequest(question=" How do I list files? ") assert request.question == "How do I list files?" + + +# --- Test MCP tools passthrough --- + + +def _setup_responses_mock_with_capture( + mocker: MockerFixture, response_text: str = "Test response." +) -> Any: + """Set up responses.create mock and return the create mock for assertion. + + Unlike _setup_responses_mock, this returns the mock_create object so + callers can inspect call_args to verify tools were passed correctly. + + Args: + mocker: The pytest mocker fixture. + response_text: Text for the mock LLM response. + + Returns: + The mock create coroutine, whose call_args can be inspected. + """ + mock_response = mocker.Mock() + mock_response.output = [_create_mock_response_output(mocker, response_text)] + + mock_create = mocker.AsyncMock(return_value=mock_response) + _setup_responses_mock(mocker, mock_create) + return mock_create + + +@pytest.mark.asyncio +async def test_retrieve_simple_response_passes_tools( + mocker: MockerFixture, mock_configuration: AppConfig +) -> None: + """Test that retrieve_simple_response forwards tools to responses.create().""" + mock_create = _setup_responses_mock_with_capture(mocker) + tools = [ + { + "type": "mcp", + "server_label": "test-mcp", + "server_url": "http://localhost:9000/sse", + "require_approval": "never", + } + ] + + await retrieve_simple_response("Test question", "Instructions", tools=tools) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["tools"] == tools + + +@pytest.mark.asyncio +async def test_retrieve_simple_response_defaults_to_empty_tools( + mocker: MockerFixture, mock_configuration: AppConfig +) -> None: + """Test that retrieve_simple_response passes empty list when tools is None.""" + mock_create = _setup_responses_mock_with_capture(mocker) + + await retrieve_simple_response("Test question", "Instructions") + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["tools"] == [] + + +@pytest.mark.asyncio +async def test_infer_endpoint_calls_get_mcp_tools( + mocker: MockerFixture, + mock_configuration: AppConfig, + mock_llm_response: None, + mock_auth_resolvers: None, +) -> None: + """Test that infer_endpoint calls get_mcp_tools with configuration.mcp_servers.""" + mock_get_mcp_tools = mocker.patch( + "app.endpoints.rlsapi_v1.get_mcp_tools", + return_value=[{"type": "mcp", "server_label": "test"}], + ) + + infer_request = RlsapiV1InferRequest(question="How do I list files?") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) + + mock_get_mcp_tools.assert_called_once_with(mock_configuration.mcp_servers)