diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 2888ba67a..9df0ae8c2 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -2043,6 +2043,7 @@ async def _handle_gateway_failure(self, gateway: DbGateway) -> None: None Examples: + >>> from mcpgateway.services.gateway_service import GatewayService >>> service = GatewayService() >>> gateway = type('Gateway', (), { ... 'id': 'gw1', 'name': 'test_gw', 'enabled': True, 'reachable': True @@ -3463,76 +3464,76 @@ def get_httpx_client_factory( if tools: logger.info(f"Fetched {len(tools)} tools from gateway") - # Fetch resources if supported - resources = [] - logger.debug(f"Checking for resources support: {capabilities.get('resources')}") - if capabilities.get("resources"): - try: - response = await session.list_resources() - raw_resources = response.resources - for resource in raw_resources: - resource_data = resource.model_dump(by_alias=True, exclude_none=True) - # Convert AnyUrl to string if present - if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): - resource_data["uri"] = str(resource_data["uri"]) - # Add default content if not present - if "content" not in resource_data: - resource_data["content"] = "" - try: - resources.append(ResourceCreate.model_validate(resource_data)) - except Exception: - # If validation fails, create minimal resource - resources.append( - ResourceCreate( - uri=str(resource_data.get("uri", "")), - name=resource_data.get("name", ""), - description=resource_data.get("description"), - mime_type=resource_data.get("mimeType"), - uri_template=resource_data.get("uriTemplate") or None, - content="", - ) + # Fetch resources if supported + resources = [] + logger.debug(f"Checking for resources support: {capabilities.get('resources')}") + if capabilities.get("resources"): + try: + response = await session.list_resources() + raw_resources = response.resources + for resource in raw_resources: + resource_data = resource.model_dump(by_alias=True, exclude_none=True) + # Convert AnyUrl to string if present + if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): + resource_data["uri"] = str(resource_data["uri"]) + # Add default content if not present + if "content" not in resource_data: + resource_data["content"] = "" + try: + resources.append(ResourceCreate.model_validate(resource_data)) + except Exception: + # If validation fails, create minimal resource + resources.append( + ResourceCreate( + uri=str(resource_data.get("uri", "")), + name=resource_data.get("name", ""), + description=resource_data.get("description"), + mime_type=resource_data.get("mimeType"), + uri_template=resource_data.get("uriTemplate") or None, + content="", ) - logger.info(f"Fetched {len(resources)} resources from gateway") - except Exception as e: - logger.warning(f"Failed to fetch resources: {e}") - - # resource template URI - try: - response_templates = await session.list_resource_templates() - raw_resources_templates = response_templates.resourceTemplates - resource_templates = [] - for resource_template in raw_resources_templates: - resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) - - if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): - resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) - resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) - - if "content" not in resource_template_data: - resource_template_data["content"] = "" + ) + logger.info(f"Fetched {len(resources)} resources from gateway") + except Exception as e: + logger.warning(f"Failed to fetch resources: {e}") - resources.append(ResourceCreate.model_validate(resource_template_data)) - resource_templates.append(ResourceCreate.model_validate(resource_template_data)) - logger.info(f"Fetched {len(resource_templates)} resource templates from gateway") - except Exception as e: - logger.warning(f"Failed to fetch resource templates: {e}") + # resource template URI + try: + response_templates = await session.list_resource_templates() + raw_resources_templates = response_templates.resourceTemplates + resource_templates = [] + for resource_template in raw_resources_templates: + resource_template_data = resource_template.model_dump(by_alias=True, exclude_none=True) + + if "uriTemplate" in resource_template_data: # and hasattr(resource_template_data["uriTemplate"], "unicode_string"): + resource_template_data["uri_template"] = str(resource_template_data["uriTemplate"]) + resource_template_data["uri"] = str(resource_template_data["uriTemplate"]) + + if "content" not in resource_template_data: + resource_template_data["content"] = "" + + resources.append(ResourceCreate.model_validate(resource_template_data)) + resource_templates.append(ResourceCreate.model_validate(resource_template_data)) + logger.info(f"Fetched {len(resource_templates)} resource templates from gateway") + except Exception as e: + logger.warning(f"Failed to fetch resource templates: {e}") - # Fetch prompts if supported - prompts = [] - logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") - if capabilities.get("prompts"): - try: - response = await session.list_prompts() - raw_prompts = response.prompts - for prompt in raw_prompts: - prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) - # Add default template if not present - if "template" not in prompt_data: - prompt_data["template"] = "" - prompts.append(PromptCreate.model_validate(prompt_data)) - logger.info(f"Fetched {len(prompts)} prompts from gateway") - except Exception as e: - logger.warning(f"Failed to fetch prompts: {e}") + # Fetch prompts if supported + prompts = [] + logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") + if capabilities.get("prompts"): + try: + response = await session.list_prompts() + raw_prompts = response.prompts + for prompt in raw_prompts: + prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) + # Add default template if not present + if "template" not in prompt_data: + prompt_data["template"] = "" + prompts.append(PromptCreate.model_validate(prompt_data)) + logger.info(f"Fetched {len(prompts)} prompts from gateway") + except Exception as e: + logger.warning(f"Failed to fetch prompts: {e}") return capabilities, tools, resources, prompts raise GatewayConnectionError(f"Failed to initialize gateway at{server_url}") diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 591b00d38..264af6583 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -25,11 +25,16 @@ import mimetypes import os import re +import ssl import time from typing import Any, AsyncGenerator, Dict, List, Optional, Union import uuid # Third-Party +import httpx +from mcp import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client import parse from sqlalchemy import and_, case, delete, desc, Float, func, not_, or_, select from sqlalchemy.exc import IntegrityError @@ -39,6 +44,7 @@ from mcpgateway.common.models import ResourceContent, ResourceTemplate, TextContent from mcpgateway.config import settings from mcpgateway.db import EmailTeam +from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import Resource as DbResource from mcpgateway.db import ResourceMetric from mcpgateway.db import ResourceSubscription as DbSubscription @@ -47,10 +53,13 @@ from mcpgateway.schemas import ResourceCreate, ResourceMetrics, ResourceRead, ResourceSubscription, ResourceUpdate, TopPerformer from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.observability_service import current_trace_id, ObservabilityService from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor +from mcpgateway.utils.services_auth import decode_auth from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr +from mcpgateway.utils.validate_signature import validate_signature # Plugin support imports (conditional) try: @@ -115,6 +124,7 @@ def __init__(self) -> None: """Initialize the resource service.""" self._event_service = EventService(channel_name="mcpgateway:resource_events") self._template_cache: Dict[str, ResourceTemplate] = {} + self.oauth_manager = OAuthManager(request_timeout=int(os.getenv("OAUTH_REQUEST_TIMEOUT", "30")), max_retries=int(os.getenv("OAUTH_MAX_RETRIES", "3"))) # Initialize plugin manager if plugins are enabled in settings self._plugin_manager = None @@ -672,6 +682,423 @@ async def _record_resource_metric(self, db: Session, resource: DbResource, start db.add(metric) db.commit() + async def _record_invoke_resource_metric(self, db: Session, resource_id: str, start_time: float, success: bool, error_message: Optional[str]) -> None: + """ + Records a metric for invoking resource. + + Args: + db: Database Session + resource_id: unique identifier to access & invoke resource + start_time: Monotonic start time of the access + success: True if successful, False otherwise + error_message: Error message if failed, None otherwise + """ + end_time = time.monotonic() + response_time = end_time - start_time + + metric = ResourceMetric( + resource_id=resource_id, + response_time=response_time, + is_success=success, + error_message=error_message, + ) + db.add(metric) + db.commit() + + def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext: + """Create an SSL context with the provided CA certificate. + + Args: + ca_certificate: CA certificate in PEM format + + Returns: + ssl.SSLContext: Configured SSL context + """ + ctx = ssl.create_default_context() + ctx.load_verify_locations(cadata=ca_certificate) + return ctx + + async def invoke_resource(self, db: Session, resource_id: str, resource_uri: str, resource_template_uri: Optional[str] = None) -> Any: + """ + Invoke a resource via its configured gateway using SSE or StreamableHTTP transport. + + This method determines the correct URI to invoke, loads the associated resource + and gateway from the database, validates certificates if applicable, prepares + authentication headers (OAuth, header-based, or none), and then connects to + the gateway to read the resource using the appropriate transport. + + The function supports: + - CA certificate validation / SSL context creation + - OAuth client-credentials and authorization-code flow + - Header-based auth + - SSE transport gateways + - StreamableHTTP transport gateways + + Args: + db (Session): + SQLAlchemy session for retrieving resource and gateway information. + resource_id (str): + ID of the resource to invoke. + resource_uri (str): + Direct resource URI configured for the resource. + resource_template_uri (Optional[str]): + URI from the template. Overrides `resource_uri` when provided. + + Returns: + Any: The text content returned by the remote resource, or ``None`` if the + gateway could not be contacted or an error occurred. + + Raises: + Exception: Any unhandled internal errors (e.g., DB issues). + + --- + Doctest Examples + ---------------- + + >>> class FakeDB: + ... "Simple DB stub returning fake resource and gateway rows." + ... def execute(self, query): + ... class Result: + ... def scalar_one_or_none(self): + ... # Return fake objects with the needed attributes + ... class FakeResource: + ... id = "res123" + ... name = "Demo Resource" + ... gateway_id = "gw1" + ... return FakeResource() + ... return Result() + + >>> class FakeGateway: + ... id = "gw1" + ... name = "Fake Gateway" + ... url = "https://fake.gateway" + ... ca_certificate = None + ... ca_certificate_sig = None + ... transport = "sse" + ... auth_type = None + ... auth_value = {} + + >>> # Monkeypatch the DB lookup for gateway + >>> def fake_execute_gateway(self, query): + ... class Result: + ... def scalar_one_or_none(self_inner): + ... return FakeGateway() + ... return Result() + + >>> FakeDB.execute_gateway = fake_execute_gateway + + >>> class FakeService: + ... "Service stub replacing network calls with predictable outputs." + ... async def invoke_resource(self, db, resource_id, resource_uri, resource_template_uri=None): + ... # Represent the behavior of a successful SSE response. + ... return "hello from gateway" + + >>> svc = FakeService() + >>> import asyncio + >>> asyncio.run(svc.invoke_resource(FakeDB(), "res123", "/test")) + 'hello from gateway' + + --- + Example: Template URI overrides resource URI + -------------------------------------------- + + >>> class FakeService2(FakeService): + ... async def invoke_resource(self, db, resource_id, resource_uri, resource_template_uri=None): + ... if resource_template_uri: + ... return f"using template: {resource_template_uri}" + ... return f"using direct: {resource_uri}" + + >>> svc2 = FakeService2() + >>> asyncio.run(svc2.invoke_resource(FakeDB(), "res123", "/direct", "/template")) + 'using template: /template' + + """ + + uri = None + if resource_uri and resource_template_uri: + uri = resource_template_uri + elif resource_uri: + uri = resource_uri + + logger.info(f"Invoking the resource: {uri}") + gateway_id = None + resource_info = None + resource_info = db.execute(select(DbResource).where(DbResource.id == resource_id)).scalar_one_or_none() + user_email = settings.platform_admin_email + + if resource_info: + gateway_id = getattr(resource_info, "gateway_id", None) + resource_name = getattr(resource_info, "name", None) + if gateway_id: + gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none() + + start_time = time.monotonic() + success = False + error_message = None + + # Create database span for observability dashboard + trace_id = current_trace_id.get() + db_span_id = None + db_span_ended = False + observability_service = ObservabilityService() if trace_id else None + + if trace_id and observability_service: + try: + db_span_id = observability_service.start_span( + db=db, + trace_id=trace_id, + name="invoke.resource", + attributes={ + "resource.name": resource_name if resource_name else "unknown", + "resource.id": str(resource_id) if resource_id else "unknown", + "resource.uri": str(uri) or "unknown", + "gateway.transport": getattr(gateway, "transport") or "uknown", + "gateway.url": getattr(gateway, "url") or "unknown", + }, + ) + logger.debug(f"✓ Created resource.read span: {db_span_id} for resource: {resource_id} & {uri}") + except Exception as e: + logger.warning(f"Failed to start the observability span for invoking resource: {e}") + db_span_id = None + + with create_span( + "invoke.resource", + { + "resource.name": resource_name if resource_name else "unknown", + "resource.id": str(resource_id) if resource_id else "unknown", + "resource.uri": str(uri) or "unknown", + "gateway.transport": getattr(gateway, "transport") or "uknown", + "gateway.url": getattr(gateway, "url") or "unknown", + }, + ) as span: + valid = False + if gateway.ca_certificate: + if settings.enable_ed25519_signing: + public_key_pem = settings.ed25519_public_key + valid = validate_signature(gateway.ca_certificate.encode(), gateway.ca_certificate_sig, public_key_pem) + else: + valid = True + + if valid: + ssl_context = self.create_ssl_context(gateway.ca_certificate) + else: + ssl_context = None + + def _get_httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + """Factory function to create httpx.AsyncClient with optional CA certificate. + + Args: + headers: Optional headers for the client + timeout: Optional timeout for the client + auth: Optional auth for the client + + Returns: + httpx.AsyncClient: Configured HTTPX async client + """ + return httpx.AsyncClient( + verify=ssl_context if ssl_context else True, # pylint: disable=cell-var-from-loop + follow_redirects=True, + headers=headers, + timeout=timeout or httpx.Timeout(30.0), + auth=auth, + ) + + try: + # Handle different authentication types + headers = {} + if gateway and gateway.auth_type == "oauth" and gateway.oauth_config: + grant_type = gateway.oauth_config.get("grant_type", "client_credentials") + + if grant_type == "authorization_code": + # For Authorization Code flow, try to get stored tokens + try: + # First-Party + from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel + + token_storage = TokenStorageService(db) + # Get user-specific OAuth token + # if not user_email: + # if span: + # span.set_attribute("health.status", "unhealthy") + # span.set_attribute("error.message", "User email required for OAuth token") + # await self._handle_gateway_failure(gateway) + + access_token: str = await token_storage.get_user_token(gateway.id, user_email) + + if access_token: + headers["Authorization"] = f"Bearer {access_token}" + else: + if span: + span.set_attribute("health.status", "unhealthy") + span.set_attribute("error.message", "No valid OAuth token for user") + # await self._handle_gateway_failure(gateway) + + except Exception as e: + logger.error(f"Failed to obtain stored OAuth token for gateway {gateway.name}: {e}") + if span: + span.set_attribute("health.status", "unhealthy") + span.set_attribute("error.message", "Failed to obtain stored OAuth token") + # await self._handle_gateway_failure(gateway) + else: + # For Client Credentials flow, get token directly + try: + access_token: str = await self.oauth_manager.get_access_token(gateway.oauth_config) + headers["Authorization"] = f"Bearer {access_token}" + except Exception as e: + if span: + span.set_attribute("health.status", "unhealthy") + span.set_attribute("error.message", str(e)) + # await self._handle_gateway_failure(gateway) + else: + # Handle non-OAuth authentication (existing logic) + auth_data = gateway.auth_value or {} + if isinstance(auth_data, str): + headers = decode_auth(auth_data) + elif isinstance(auth_data, dict): + headers = {str(k): str(v) for k, v in auth_data.items()} + else: + headers = {} + + async def connect_to_sse_session(server_url: str, uri: str, authentication: Optional[Dict[str, str]] = None) -> str | None: + """ + Connect to an SSE-based gateway and retrieve the text content of a resource. + + This helper establishes an SSE (Server-Sent Events) session with the remote + gateway, initializes a `ClientSession`, invokes `read_resource()` for the + given URI, and returns the textual content from the first item in the + response's `contents` list. + + If any error occurs (network failure, unexpected response format, session + initialization failure, etc.), the method logs the exception and returns + ``None`` instead of raising. + + Args: + server_url (str): + The base URL of the SSE gateway to connect to. + uri (str): + The resource URI that should be requested from the gateway. + authentication (Optional[Dict[str, str]]): + Optional dictionary of headers (e.g., OAuth Bearer tokens) to + include in the SSE connection request. Defaults to an empty + dictionary when not provided. + + Returns: + str | None: + The text content returned by the remote resource, or ``None`` if the + SSE connection fails or the response is invalid. + + Notes: + - This function assumes the SSE client context manager yields: + ``(read_stream, write_stream, get_session_id)``. + - The expected response object from `session.read_resource()` must have a + `contents` attribute containing a list, where the first element has a + `text` attribute. + """ + if authentication is None: + authentication = {} + try: + async with sse_client(url=server_url, headers=authentication, timeout=settings.health_check_timeout, httpx_client_factory=_get_httpx_client_factory) as ( + read_stream, + write_stream, + _get_session_id, + ): + async with ClientSession(read_stream, write_stream) as session: + _ = await session.initialize() + resource_response = await session.read_resource(uri=uri) + return getattr(getattr(resource_response, "contents")[0], "text") + except Exception as e: + logger.debug(f"Exception while connecting to sse gateway: {e}") + return None + + async def connect_to_streamablehttp_server(server_url: str, uri: str, authentication: Optional[Dict[str, str]] = None) -> str | None: + """ + Connect to a StreamableHTTP gateway and retrieve the text content of a resource. + + This helper establishes a StreamableHTTP client session with the specified + gateway, initializes a `ClientSession`, invokes `read_resource()` for the + given URI, and returns the textual content from the first element in the + response's `contents` list. + + If any exception occurs during connection, session initialization, or + resource reading, the function logs the error and returns ``None`` instead + of propagating the exception. + + Args: + server_url (str): + The endpoint URL of the StreamableHTTP gateway. + uri (str): + The resource URI to request from the gateway. + authentication (Optional[Dict[str, str]]): + Optional dictionary of authentication headers (e.g., API keys or + Bearer tokens). Defaults to an empty dictionary when not provided. + + Returns: + str | None: + The text content returned by the StreamableHTTP resource, or ``None`` + if the connection fails or the response format is invalid. + + Notes: + - The `streamablehttp_client` context manager must yield a tuple: + ``(read_stream, write_stream, get_session_id)``. + - The expected `resource_response` returned by ``session.read_resource()`` + must contain a `contents` list, whose first element exposes a `text` + attribute. + """ + if authentication is None: + authentication = {} + try: + async with streamablehttp_client(url=server_url, headers=authentication, timeout=settings.health_check_timeout, httpx_client_factory=_get_httpx_client_factory) as ( + read_stream, + write_stream, + _get_session_id, + ): + async with ClientSession(read_stream, write_stream) as session: + _ = await session.initialize() + resource_response = await session.read_resource(uri=uri) + return getattr(getattr(resource_response, "contents")[0], "text") + except Exception as e: + logger.debug(f"Exception while connecting to streamablehttp gateway: {e}") + return None + + if span: + span.set_attribute("success", True) + span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) + + resource_text = "" + if (gateway.transport).lower() == "sse": + resource_text = await connect_to_sse_session(server_url=gateway.url, authentication=headers, uri=uri) + else: + resource_text = await connect_to_streamablehttp_server(server_url=gateway.url, authentication=headers, uri=uri) + return resource_text + except Exception as e: + success = False + error_message = str(e) + raise + finally: + if resource_text: + try: + await self._record_invoke_resource_metric(db, resource_id, start_time, success, error_message) + except Exception as metrics_error: + logger.warning(f"Failed to invoke resource metric: {metrics_error}") + + # End Invoke resource span for Observability dashboard + if db_span_id and observability_service and not db_span_ended: + try: + observability_service.end_span( + db=db, + span_id=db_span_id, + status="ok" if success else "error", + status_message=error_message if error_message else None, + ) + db_span_ended = True + logger.debug(f"✓ Ended invoke.resource span: {db_span_id}") + except Exception as e: + logger.warning(f"Failed to end observability span for invoking resource: {e}") + async def read_resource( self, db: Session, @@ -919,9 +1346,24 @@ async def read_resource( # If content is already a Pydantic content model, return as-is if isinstance(content, (ResourceContent, TextContent)): + resource_response = await self.invoke_resource( + db=db, resource_id=getattr(content, "id"), resource_uri=getattr(content, "uri") or None, resource_template_uri=getattr(content, "text") or None + ) + if resource_response: + setattr(content, "text", resource_response) return content # If content is any object that quacks like content (e.g., MagicMock with .text/.blob), return as-is if hasattr(content, "text") or hasattr(content, "blob"): + if hasattr(content, "blob"): + resource_response = await self.invoke_resource( + db=db, resource_id=getattr(content, "id"), resource_uri=getattr(content, "uri") or None, resource_template_uri=getattr(content, "blob") or None + ) + setattr(content, "blob", resource_response) + elif hasattr(content, "text"): + resource_response = await self.invoke_resource( + db=db, resource_id=getattr(content, "id"), resource_uri=getattr(content, "uri") or None, resource_template_uri=getattr(content, "text") or None + ) + setattr(content, "text", resource_response) return content # Normalize primitive types to ResourceContent if isinstance(content, bytes): diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index ce80ba66f..5e128fe71 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -12629,9 +12629,10 @@ async function handleResourceFormSubmit(e) { // Check if URI contains '{' and '}' if (uri && uri.includes("{") && uri.includes("}")) { template = uri; + // append uri_template only when uri is a templatized resource + formData.append("uri_template", template); } - formData.append("uri_template", template); const nameValidation = validateInputName(name, "resource"); const uriValidation = validateInputName(uri, "resource URI"); diff --git a/tests/unit/mcpgateway/services/test_resource_service.py b/tests/unit/mcpgateway/services/test_resource_service.py index 43df3913b..11b1a8fea 100644 --- a/tests/unit/mcpgateway/services/test_resource_service.py +++ b/tests/unit/mcpgateway/services/test_resource_service.py @@ -433,21 +433,38 @@ async def test_list_server_resources(self, resource_service, mock_db, mock_resou # --------------------------------------------------------------------------- # # Resource reading tests # # --------------------------------------------------------------------------- # - +from unittest.mock import patch class TestResourceReading: """Test resource reading functionality.""" @pytest.mark.asyncio - async def test_read_resource_success(self, mock_db, mock_resource): - """Test successful resource reading.""" - from mcpgateway.services.resource_service import ResourceService + @patch("mcpgateway.services.resource_service.ssl.create_default_context") + async def test_read_resource_success(self, mock_ssl, mock_db, mock_resource): + mock_ctx = MagicMock() + mock_ssl.return_value = mock_ctx + mock_scalar = MagicMock() + mock_resource.gateway.ca_certificate = "-----BEGIN CERTIFICATE-----\nABC\n-----END CERTIFICATE-----" mock_scalar.scalar_one_or_none.return_value = mock_resource mock_db.execute.return_value = mock_scalar - resource_service_instance = ResourceService() - result = await resource_service_instance.read_resource(mock_db, resource_id=mock_resource.id) + + from mcpgateway.services.resource_service import ResourceService + service = ResourceService() + + result = await service.read_resource(mock_db, resource_id=mock_resource.id) assert result is not None + # @pytest.mark.asyncio + # async def test_read_resource_success(self, mock_db, mock_resource): + # """Test successful resource reading.""" + # from mcpgateway.services.resource_service import ResourceService + # mock_scalar = MagicMock() + # mock_scalar.scalar_one_or_none.return_value = mock_resource + # mock_db.execute.return_value = mock_scalar + # resource_service_instance = ResourceService() + # result = await resource_service_instance.read_resource(mock_db, resource_id=mock_resource.id) + # assert result is not None + @pytest.mark.asyncio async def test_read_resource_not_found(self, resource_service, mock_db): """Test reading non-existent resource.""" diff --git a/tests/unit/mcpgateway/services/test_resource_service_plugins.py b/tests/unit/mcpgateway/services/test_resource_service_plugins.py index 28bfa9b23..33856187f 100644 --- a/tests/unit/mcpgateway/services/test_resource_service_plugins.py +++ b/tests/unit/mcpgateway/services/test_resource_service_plugins.py @@ -157,7 +157,8 @@ async def test_admin_add_resource_with_valid_mime_type(self, mock_register_resou assert resource_create.uri_template == "greetme://morning/{name}" @pytest.mark.asyncio - async def test_read_resource_with_pre_fetch_hook(self,resource_service_with_plugins): + @patch("mcpgateway.services.resource_service.ssl.create_default_context") + async def test_read_resource_with_pre_fetch_hook(self, mock_ssl, resource_service_with_plugins): """Test read_resource executes pre-fetch hook and passes correct context.""" service = resource_service_with_plugins @@ -174,6 +175,9 @@ async def test_read_resource_with_pre_fetch_hook(self,resource_service_with_plug text="Test content", ) + mock_ctx = MagicMock() + mock_ssl.return_value = mock_ctx + # Mock DB row returned by scalar_one_or_none mock_db_row = MagicMock() mock_db_row.content = fake_resource_content @@ -264,7 +268,8 @@ async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugi mock_manager.invoke_hook.assert_called() @pytest.mark.asyncio - async def test_read_resource_uri_modified_by_plugin(self, mock_db, resource_service_with_plugins): + @patch("mcpgateway.services.resource_service.ssl.create_default_context") + async def test_read_resource_uri_modified_by_plugin(self, mock_ssl, mock_db, resource_service_with_plugins): """Test read_resource with plugin modifying URI and a mocked SQLAlchemy Session.""" service = resource_service_with_plugins @@ -283,6 +288,9 @@ async def test_read_resource_uri_modified_by_plugin(self, mock_db, resource_serv mock_db_row.content = fake_resource_content mock_db_row.uri = fake_resource_content.uri mock_db_row.uri_template = None + + mock_ctx = MagicMock() + mock_ssl.return_value = mock_ctx # Configure scalar_one_or_none to return the mocked row mock_db.execute.return_value.scalar_one_or_none.return_value = mock_db_row @@ -290,6 +298,7 @@ async def test_read_resource_uri_modified_by_plugin(self, mock_db, resource_serv # Plugin modifies the URI (can return same or a different URI) modified_payload = MagicMock() modified_payload.uri = "cached://test://resource" + modified_payload.gateway.ca_certificate = "-----BEGIN CERTIFICATE-----\nABC\n-----END CERTIFICATE-----" async def plugin_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: @@ -317,7 +326,8 @@ async def plugin_side_effect(hook_type, payload, global_context, local_contexts= @pytest.mark.asyncio - async def test_read_resource_content_filtered_by_plugin(self, resource_service_with_plugins, mock_db): + @patch("mcpgateway.services.resource_service.ssl.create_default_context") + async def test_read_resource_content_filtered_by_plugin(self, mock_ssl, resource_service_with_plugins, mock_db): """Test read_resource with content filtering by post-fetch hook.""" # First-Party from mcpgateway.plugins.framework.models import PluginResult @@ -328,6 +338,9 @@ async def test_read_resource_content_filtered_by_plugin(self, resource_service_w service = resource_service_with_plugins mock_manager = service._plugin_manager + mock_ctx = MagicMock() + mock_ssl.return_value = mock_ctx + # Setup mock resource with sensitive data mock_resource = MagicMock() original_content = ResourceContent( @@ -464,13 +477,17 @@ def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=N assert mock_manager.invoke_hook.call_count == 2 @pytest.mark.asyncio - async def test_read_resource_with_template(self, resource_service_with_plugins, mock_db): + @patch("mcpgateway.services.resource_service.ssl.create_default_context") + async def test_read_resource_with_template(self, mock_ssl, resource_service_with_plugins, mock_db): """Test read_resource with template resource and plugins.""" import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins mock_manager = service._plugin_manager + mock_ctx = MagicMock() + mock_ssl.return_value = mock_ctx + # Setup mock resource mock_resource = MagicMock() mock_template_content = ResourceContent( @@ -479,6 +496,7 @@ async def test_read_resource_with_template(self, resource_service_with_plugins, uri="test://123/data", text="Template content for id=123", ) + mock_resource.gateway.ca_certificate = "-----BEGIN CERTIFICATE-----\nABC\n-----END CERTIFICATE-----" mock_resource.content = mock_template_content mock_resource.uri = "test://123/data" # Ensure uri is set at the top level mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource @@ -495,7 +513,8 @@ async def test_read_resource_with_template(self, resource_service_with_plugins, assert mock_manager.invoke_hook.call_count >= 2 # Pre and post fetch @pytest.mark.asyncio - async def test_read_resource_context_propagation(self, resource_service_with_plugins, mock_db): + @patch("mcpgateway.services.resource_service.ssl.create_default_context") + async def test_read_resource_context_propagation(self,mock_ssl, resource_service_with_plugins, mock_db): """Test context propagation from pre-fetch to post-fetch.""" # First-Party from mcpgateway.plugins.framework.models import PluginResult @@ -506,6 +525,9 @@ async def test_read_resource_context_propagation(self, resource_service_with_plu service = resource_service_with_plugins mock_manager = service._plugin_manager + mock_ctx = MagicMock() + mock_ssl.return_value = mock_ctx + # Setup mock resource mock_resource = MagicMock() mock_resource.content = ResourceContent( @@ -514,6 +536,7 @@ async def test_read_resource_context_propagation(self, resource_service_with_plu uri="test://resource", text="Test content", ) + mock_resource.gateway.ca_certificate = "-----BEGIN CERTIFICATE-----\nABC\n-----END CERTIFICATE-----" mock_resource.uri = "test://resource" # Ensure uri is set at the top level mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None @@ -583,17 +606,22 @@ async def test_plugin_manager_initialization_failure(self): assert service._plugin_manager is None # Should fail gracefully @pytest.mark.asyncio - async def test_read_resource_no_request_id(self, resource_service_with_plugins, mock_db): + @patch("mcpgateway.services.resource_service.ssl.create_default_context") + async def test_read_resource_no_request_id(self, mock_ssl,resource_service_with_plugins, mock_db): """Test read_resource generates request_id if not provided.""" import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins mock_manager = service._plugin_manager + mock_ctx = MagicMock() + mock_ssl.return_value = mock_ctx + # Setup mock resource mock_resource = MagicMock() mock_resource.content = ResourceContent(type="resource", id="test://resource", uri="test://resource", text="Test") mock_resource.uri = "test://resource" # Ensure uri is set at the top level + mock_resource.gateway.ca_certificate = "-----BEGIN CERTIFICATE-----\nABC\n-----END CERTIFICATE-----" mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None