From 4eec3e337f61889514f4bbef74b8aac5503aa0e1 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 14 Nov 2025 14:41:51 +0530 Subject: [PATCH 01/34] Add correlation ID system for unified request tracking Signed-off-by: Shoumi --- .env.example | 11 + mcpgateway/auth.py | 67 +++- mcpgateway/config.py | 6 + mcpgateway/main.py | 8 + mcpgateway/middleware/correlation_id.py | 121 +++++++ mcpgateway/middleware/http_auth_middleware.py | 12 +- .../middleware/request_logging_middleware.py | 5 + mcpgateway/observability.py | 18 + mcpgateway/services/a2a_service.py | 1 + mcpgateway/services/logging_service.py | 54 ++- mcpgateway/services/tool_service.py | 6 +- mcpgateway/utils/correlation_id.py | 175 ++++++++++ .../middleware/test_correlation_id.py | 230 +++++++++++++ .../test_correlation_id_json_formatter.py | 307 ++++++++++++++++++ .../mcpgateway/utils/test_correlation_id.py | 216 ++++++++++++ 15 files changed, 1226 insertions(+), 11 deletions(-) create mode 100644 mcpgateway/middleware/correlation_id.py create mode 100644 mcpgateway/utils/correlation_id.py create mode 100644 tests/unit/mcpgateway/middleware/test_correlation_id.py create mode 100644 tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py create mode 100644 tests/unit/mcpgateway/utils/test_correlation_id.py diff --git a/.env.example b/.env.example index 693715cf2..c4f3293c5 100644 --- a/.env.example +++ b/.env.example @@ -659,6 +659,17 @@ LOG_MAX_SIZE_MB=1 LOG_BACKUP_COUNT=5 LOG_BUFFER_SIZE_MB=1.0 +# Correlation ID / Request Tracking +# Enable automatic correlation ID tracking for unified request tracing +# Options: true (default), false +CORRELATION_ID_ENABLED=true +# HTTP header name for correlation ID (default: X-Correlation-ID) +CORRELATION_ID_HEADER=X-Correlation-ID +# Preserve incoming correlation IDs from clients (default: true) +CORRELATION_ID_PRESERVE=true +# Include correlation ID in HTTP response headers (default: true) +CORRELATION_ID_RESPONSE_HEADER=true + # Transport Protocol Configuration # Options: all (default), sse, streamablehttp, http # - all: Enable all transport protocols diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index ea633ee5d..629f4df46 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -26,11 +26,63 @@ from mcpgateway.config import settings from mcpgateway.db import EmailUser, SessionLocal from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, HttpAuthResolveUserPayload, HttpHeaderPayload, HttpHookType, PluginViolationError -from mcpgateway.services.team_management_service import TeamManagementService +from mcpgateway.services.team_management_service import TeamManagementService # pylint: disable=import-outside-toplevel +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.verify_credentials import verify_jwt_token # Security scheme -bearer_scheme = HTTPBearer(auto_error=False) +security = HTTPBearer(auto_error=False) + + +def _log_auth_event( + logger: logging.Logger, + message: str, + level: int = logging.INFO, + user_id: Optional[str] = None, + auth_method: Optional[str] = None, + auth_success: bool = False, + security_event: Optional[str] = None, + security_severity: str = "low", + **extra_context +) -> None: + """Log authentication event with structured context and request_id. + + This helper creates structured log records that include request_id from the + correlation ID context, enabling end-to-end tracing of authentication flows. 
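+
+    A minimal usage sketch (the logger, message, and field values here are
+    illustrative, not taken from actual gateway call sites):
+
+        logger = logging.getLogger(__name__)
+        _log_auth_event(
+            logger,
+            "JWT authentication succeeded",
+            user_id="user@example.com",
+            auth_method="jwt",
+            auth_success=True,
+        )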
+ + Args: + logger: Logger instance to use + message: Log message + level: Log level (default: INFO) + user_id: User identifier + auth_method: Authentication method used (jwt, api_token, etc.) + auth_success: Whether authentication succeeded + security_event: Type of security event (authentication, authorization, etc.) + security_severity: Severity level (low, medium, high, critical) + **extra_context: Additional context fields + """ + # Get request_id from correlation ID context + request_id = get_correlation_id() + + # Build structured log record + extra = { + 'request_id': request_id, + 'entity_type': 'auth', + 'auth_success': auth_success, + 'security_event': security_event or 'authentication', + 'security_severity': security_severity, + } + + if user_id: + extra['user_id'] = user_id + if auth_method: + extra['auth_method'] = auth_method + + # Add any additional context + extra.update(extra_context) + + # Log with structured context + logger.log(level, message, extra=extra) def get_db() -> Generator[Session, Never, None]: @@ -169,10 +221,15 @@ async def get_current_user( if request and hasattr(request, "headers"): headers = dict(request.headers) - # Get request ID from request state (set by middleware) or generate new one - request_id = getattr(request.state, "request_id", None) if request else None + # Get request ID from correlation ID context (set by CorrelationIDMiddleware) + request_id = get_correlation_id() if not request_id: - request_id = uuid.uuid4().hex + # Fallback chain for safety + if request and hasattr(request, "state") and hasattr(request.state, "request_id"): + request_id = request.state.request_id + else: + request_id = uuid.uuid4().hex + logger.debug(f"Generated fallback request ID in get_current_user: {request_id}") # Get plugin contexts from request state if available global_context = getattr(request.state, "plugin_global_context", None) if request else None diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 9d3017876..8b541f096 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -776,6 +776,12 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]: # Enable span events observability_events_enabled: bool = Field(default=True, description="Enable event logging within spans") + # Correlation ID Settings + correlation_id_enabled: bool = Field(default=True, description="Enable automatic correlation ID tracking for requests") + correlation_id_header: str = Field(default="X-Correlation-ID", description="HTTP header name for correlation ID") + correlation_id_preserve: bool = Field(default=True, description="Preserve correlation IDs from incoming requests") + correlation_id_response_header: bool = Field(default=True, description="Include correlation ID in response headers") + @field_validator("log_level", mode="before") @classmethod def validate_log_level(cls, v: str) -> str: diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 6eb6e0f2d..2df30e6ec 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -70,6 +70,7 @@ from mcpgateway.db import refresh_slugs_on_startup, SessionLocal from mcpgateway.db import Tool as DbTool from mcpgateway.handlers.sampling import SamplingHandler +from mcpgateway.middleware.correlation_id import CorrelationIDMiddleware from mcpgateway.middleware.http_auth_middleware import HttpAuthMiddleware from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission @@ -1169,6 +1170,13 @@ async def 
_call_streamable_http(self, scope, receive, send): # Add HTTP authentication hook middleware for plugins (before auth dependencies) if plugin_manager: app.add_middleware(HttpAuthMiddleware, plugin_manager=plugin_manager) + logger.info("🔌 HTTP authentication hooks enabled for plugins") + +# Add correlation ID middleware if enabled +# Note: Registered AFTER HttpAuthMiddleware so it executes FIRST (middleware runs in LIFO order) +if settings.correlation_id_enabled: + app.add_middleware(CorrelationIDMiddleware) + logger.info(f"✅ Correlation ID tracking enabled (header: {settings.correlation_id_header})") # Add custom DocsAuthMiddleware app.add_middleware(DocsAuthMiddleware) diff --git a/mcpgateway/middleware/correlation_id.py b/mcpgateway/middleware/correlation_id.py new file mode 100644 index 000000000..bcb033a3d --- /dev/null +++ b/mcpgateway/middleware/correlation_id.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/middleware/correlation_id.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: MCP Gateway Contributors + +Correlation ID (Request ID) Middleware. + +This middleware handles X-Correlation-ID HTTP headers and maps them to the internal +request_id used throughout the system for unified request tracing. + +Key concept: HTTP X-Correlation-ID header → Internal request_id field (single ID for entire request flow) + +The middleware automatically extracts or generates request IDs for every HTTP request, +stores them in context variables for async-safe propagation across services, and +injects them back into response headers for client-side correlation. + +This enables end-to-end tracing: HTTP → Middleware → Services → Plugins → Logs (all with same request_id) +""" + +# Standard +import logging +from typing import Callable + +# Third-Party +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + +# First-Party +from mcpgateway.config import settings +from mcpgateway.utils.correlation_id import ( + clear_correlation_id, + extract_correlation_id_from_headers, + generate_correlation_id, + set_correlation_id, +) + +logger = logging.getLogger(__name__) + + +class CorrelationIDMiddleware(BaseHTTPMiddleware): + """Middleware for automatic request ID (correlation ID) handling. + + This middleware: + 1. Extracts request ID from X-Correlation-ID header in incoming requests + 2. Generates a new UUID if no correlation ID is present + 3. Stores the ID in context variables for the request lifecycle (used as request_id throughout system) + 4. Injects the request ID into X-Correlation-ID response header + 5. Cleans up context after request completion + + The request ID extracted/generated here becomes the unified request_id used in: + - All log entries (request_id field) + - GlobalContext.request_id (when plugins execute) + - Service method calls for tracing + - Database queries for request tracking + + Configuration is controlled via settings: + - correlation_id_enabled: Enable/disable the middleware + - correlation_id_header: Header name to use (default: X-Correlation-ID) + - correlation_id_preserve: Whether to preserve incoming IDs (default: True) + - correlation_id_response_header: Whether to add ID to responses (default: True) + """ + + def __init__(self, app): + """Initialize the correlation ID (request ID) middleware. 
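+
+        A minimal wiring sketch (assumes from fastapi import FastAPI; settings
+        are left at their defaults). As noted in main.py, Starlette executes
+        middleware in LIFO registration order, so register this middleware
+        last to have it run first:
+
+            app = FastAPI()
+            app.add_middleware(CorrelationIDMiddleware)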
+ + Args: + app: The FastAPI application instance + """ + super().__init__(app) + self.header_name = getattr(settings, 'correlation_id_header', 'X-Correlation-ID') + self.preserve_incoming = getattr(settings, 'correlation_id_preserve', True) + self.add_to_response = getattr(settings, 'correlation_id_response_header', True) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process the request and manage request ID (correlation ID) lifecycle. + + Extracts or generates a request ID, stores it in context variables for use throughout + the request lifecycle (becomes request_id in logs, services, plugins), and injects + it back into the X-Correlation-ID response header. + + Args: + request: The incoming HTTP request + call_next: The next middleware or route handler + + Returns: + Response: The HTTP response with correlation ID header added + """ + # Extract correlation ID from incoming request headers + correlation_id = None + if self.preserve_incoming: + correlation_id = extract_correlation_id_from_headers( + dict(request.headers), + self.header_name + ) + + # Generate new correlation ID if none was provided + if not correlation_id: + correlation_id = generate_correlation_id() + logger.debug(f"Generated new correlation ID: {correlation_id}") + else: + logger.debug(f"Using client-provided correlation ID: {correlation_id}") + + # Store correlation ID in context variable for this request + # This makes it available to all downstream code (auth, services, plugins, logs) + set_correlation_id(correlation_id) + + try: + # Process the request + response = await call_next(request) + + # Add correlation ID to response headers if enabled + if self.add_to_response: + response.headers[self.header_name] = correlation_id + + return response + + finally: + # Clean up context after request completes + # Note: ContextVar automatically cleans up, but explicit cleanup is good practice + clear_correlation_id() diff --git a/mcpgateway/middleware/http_auth_middleware.py b/mcpgateway/middleware/http_auth_middleware.py index 84058641f..36f987e3c 100644 --- a/mcpgateway/middleware/http_auth_middleware.py +++ b/mcpgateway/middleware/http_auth_middleware.py @@ -17,6 +17,7 @@ # First-Party from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager +from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id logger = logging.getLogger(__name__) @@ -60,9 +61,14 @@ async def dispatch(self, request: Request, call_next): if not self.plugin_manager: return await call_next(request) - # Generate request ID for tracing and store in request state - # This ensures all hooks and downstream code see the same request ID - request_id = uuid.uuid4().hex + # Use correlation ID from CorrelationIDMiddleware if available + # This ensures all hooks and downstream code see the same unified request ID + request_id = get_correlation_id() + if not request_id: + # Fallback if correlation ID middleware is disabled + request_id = generate_correlation_id() + logger.debug(f"Correlation ID not found, generated fallback: {request_id}") + request.state.request_id = request_id # Create global context for hooks diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index db286b20f..f0dc10805 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -23,6 +23,7 @@ # 
First-Party from mcpgateway.services.logging_service import LoggingService +from mcpgateway.utils.correlation_id import get_correlation_id # Initialize logging service first logging_service = LoggingService() @@ -171,12 +172,16 @@ async def dispatch(self, request: Request, call_next: Callable): # Mask sensitive headers masked_headers = mask_sensitive_headers(dict(request.headers)) + # Get correlation ID for request tracking + request_id = get_correlation_id() + logger.log( log_level, f"📩 Incoming request: {request.method} {request.url.path}\n" f"Query params: {dict(request.query_params)}\n" f"Headers: {masked_headers}\n" f"Body: {payload_str}{'... [truncated]' if truncated else ''}", + extra={"request_id": request_id}, ) except Exception as e: diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py index 6714dd392..31ab246b0 100644 --- a/mcpgateway/observability.py +++ b/mcpgateway/observability.py @@ -440,6 +440,24 @@ def create_span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Any: # Return a no-op context manager if tracing is not configured or available return nullcontext() + # Auto-inject correlation ID into all spans for request tracing + try: + # Import here to avoid circular dependency + from mcpgateway.utils.correlation_id import get_correlation_id + + correlation_id = get_correlation_id() + if correlation_id: + if attributes is None: + attributes = {} + # Add correlation ID if not already present + if "correlation_id" not in attributes: + attributes["correlation_id"] = correlation_id + if "request_id" not in attributes: + attributes["request_id"] = correlation_id # Alias for compatibility + except ImportError: + # Correlation ID module not available, continue without it + pass + # Start span and return the context manager span_context = _TRACER.start_as_current_span(name) diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py index 33f3468d0..a410f18b7 100644 --- a/mcpgateway/services/a2a_service.py +++ b/mcpgateway/services/a2a_service.py @@ -28,6 +28,7 @@ from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.services_auth import encode_auth # ,decode_auth diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index 4f21111c0..0cad12837 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -25,6 +25,7 @@ from mcpgateway.common.models import LogLevel from mcpgateway.config import settings from mcpgateway.services.log_storage_service import LogStorageService +from mcpgateway.utils.correlation_id import get_correlation_id AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name try: @@ -38,8 +39,57 @@ # Create a text formatter text_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -# Create a JSON formatter -json_formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s") + +class CorrelationIdJsonFormatter(jsonlogger.JsonFormatter): + """JSON formatter that includes correlation ID and OpenTelemetry trace context.""" + + def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: dict) -> None: + """Add custom fields to the log record. 
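+
+        An illustrative emitted record (values are examples only; trace_id,
+        span_id, and trace_flags appear only when an OpenTelemetry span is
+        recording):
+
+            {"@timestamp": "2025-11-14T09:11:51Z", "hostname": "gateway-1",
+             "process_id": 4321, "request_id": "3f2a...", "message": "Tool invoked"}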
+ + Args: + log_record: The dictionary that will be logged as JSON + record: The original LogRecord + message_dict: Additional message fields + + """ + super().add_fields(log_record, record, message_dict) + + # Add timestamp in ISO 8601 format with 'Z' suffix for UTC + import os + import socket + from datetime import datetime, timezone + + dt = datetime.fromtimestamp(record.created, tz=timezone.utc) + log_record["@timestamp"] = dt.isoformat().replace("+00:00", "Z") + + # Add hostname and process ID for log aggregation + log_record["hostname"] = socket.gethostname() + log_record["process_id"] = os.getpid() + + # Add correlation ID from context + correlation_id = get_correlation_id() + if correlation_id: + log_record["request_id"] = correlation_id + + # Add OpenTelemetry trace context if available + try: + from opentelemetry import trace + + span = trace.get_current_span() + if span and span.is_recording(): + span_context = span.get_span_context() + if span_context.is_valid: + # Format trace_id and span_id as hex strings + log_record["trace_id"] = format(span_context.trace_id, "032x") + log_record["span_id"] = format(span_context.span_id, "016x") + log_record["trace_flags"] = format(span_context.trace_flags, "02x") + except (ImportError, Exception): + # OpenTelemetry not available or error accessing span + pass + + +# Create a JSON formatter with correlation ID support +json_formatter = CorrelationIdJsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s") # Note: Don't use basicConfig here as it conflicts with our custom dual logging setup # The LoggingService.initialize() method will properly configure all handlers diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 731b67e0a..88f23ddc0 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -26,6 +26,9 @@ from urllib.parse import parse_qs, urlparse import uuid +# First-Party (early import for correlation_id) +from mcpgateway.utils.correlation_id import get_correlation_id + # Third-Party import httpx import jq @@ -1224,7 +1227,8 @@ async def invoke_tool( global_context.server_id = gateway_id else: # Create new context (fallback when middleware didn't run) - request_id = uuid.uuid4().hex + # Use correlation ID from context if available, otherwise generate new one + request_id = get_correlation_id() or uuid.uuid4().hex gateway_id = getattr(tool, "gateway_id", "unknown") server_id = gateway_id if isinstance(gateway_id, str) else "unknown" global_context = GlobalContext(request_id=request_id, server_id=server_id, tenant_id=None, user=app_user_email) diff --git a/mcpgateway/utils/correlation_id.py b/mcpgateway/utils/correlation_id.py new file mode 100644 index 000000000..2ff58f2bc --- /dev/null +++ b/mcpgateway/utils/correlation_id.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/utils/correlation_id.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: MCP Gateway Contributors + +Correlation ID (Request ID) Utilities. + +This module provides async-safe utilities for managing correlation IDs (also known as +request IDs) throughout the request lifecycle using Python's contextvars. + +The correlation ID is a unique identifier that tracks a single request as it flows +through all components of the system (HTTP → Middleware → Services → Plugins → Logs). 
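+
+A minimal usage sketch (doctest style, using the helpers defined below):
+
+    >>> set_correlation_id("abc-123")
+    >>> get_correlation_id()
+    'abc-123'
+    >>> clear_correlation_id()
+    >>> get_correlation_id() is None
+    True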
+ +Key concepts: +- ContextVar provides per-request isolation in async environments +- Correlation IDs can be client-provided (X-Correlation-ID header) or auto-generated +- The same ID is used as request_id throughout logs, services, and plugin contexts +- Thread-safe and async-safe (no cross-contamination between concurrent requests) +""" + +# Standard +from contextvars import ContextVar +import logging +from typing import Dict, Optional +import uuid + +logger = logging.getLogger(__name__) + +# Context variable for storing correlation ID (request ID) per-request +# This is async-safe and provides automatic isolation between concurrent requests +_correlation_id_context: ContextVar[Optional[str]] = ContextVar('correlation_id', default=None) + + +def get_correlation_id() -> Optional[str]: + """Get the current correlation ID (request ID) from context. + + Returns the correlation ID for the current async task/request. Each request + has its own isolated context, so concurrent requests won't interfere. + + Returns: + Optional[str]: The correlation ID if set, None otherwise + """ + return _correlation_id_context.get() + + +def set_correlation_id(correlation_id: str) -> None: + """Set the correlation ID (request ID) for the current context. + + Stores the correlation ID in a context variable that's automatically isolated + per async task. This ID will be used as request_id throughout the system. + + Args: + correlation_id: The correlation ID to set (typically a UUID or client-provided ID) + """ + _correlation_id_context.set(correlation_id) + + +def clear_correlation_id() -> None: + """Clear the correlation ID (request ID) from the current context. + + Should be called at the end of request processing to clean up context. + In practice, FastAPI middleware automatically handles context cleanup. + + Note: This is optional as ContextVar automatically cleans up when the + async task completes. + """ + _correlation_id_context.set(None) + + +def generate_correlation_id() -> str: + """Generate a new correlation ID (UUID4 hex format). + + Creates a new random UUID suitable for use as a correlation ID. + Uses UUID4 which provides 122 bits of randomness. + + Returns: + str: A new UUID in hex format (32 characters, no hyphens) + """ + return uuid.uuid4().hex + + +def extract_correlation_id_from_headers(headers: Dict[str, str], header_name: str = "X-Correlation-ID") -> Optional[str]: + """Extract correlation ID from HTTP headers. + + Searches for the correlation ID header (case-insensitive) and returns its value. + Validates that the value is non-empty after stripping whitespace. + + Args: + headers: Dictionary of HTTP headers + header_name: Name of the correlation ID header (default: X-Correlation-ID) + + Returns: + Optional[str]: The correlation ID if found and valid, None otherwise + + Example: + >>> headers = {"X-Correlation-ID": "abc-123"} + >>> extract_correlation_id_from_headers(headers) + 'abc-123' + + >>> headers = {"x-correlation-id": "def-456"} # Case insensitive + >>> extract_correlation_id_from_headers(headers) + 'def-456' + """ + # Headers can be accessed case-insensitively in FastAPI/Starlette + for key, value in headers.items(): + if key.lower() == header_name.lower(): + correlation_id = value.strip() + if correlation_id: + return correlation_id + return None + + +def get_or_generate_correlation_id() -> str: + """Get the current correlation ID or generate a new one if not set. + + This is a convenience function that ensures you always have a correlation ID. 
+ If the current context doesn't have a correlation ID, it generates and sets + a new one. + + Returns: + str: The correlation ID (either existing or newly generated) + + Example: + >>> # First call generates new ID + >>> id1 = get_or_generate_correlation_id() + >>> # Second call returns same ID + >>> id2 = get_or_generate_correlation_id() + >>> assert id1 == id2 + """ + correlation_id = get_correlation_id() + if not correlation_id: + correlation_id = generate_correlation_id() + set_correlation_id(correlation_id) + return correlation_id + + +def validate_correlation_id(correlation_id: Optional[str], max_length: int = 255) -> bool: + """Validate a correlation ID for safety and length. + + Checks that the correlation ID is: + - Non-empty after stripping whitespace + - Within the maximum length limit + - Contains only safe characters (alphanumeric, hyphens, underscores) + + Args: + correlation_id: The correlation ID to validate + max_length: Maximum allowed length (default: 255) + + Returns: + bool: True if valid, False otherwise + + Example: + >>> validate_correlation_id("abc-123") + True + >>> validate_correlation_id("abc 123") # Spaces not allowed + False + >>> validate_correlation_id("a" * 300) # Too long + False + """ + if not correlation_id or not correlation_id.strip(): + return False + + correlation_id = correlation_id.strip() + + if len(correlation_id) > max_length: + logger.warning(f"Correlation ID too long: {len(correlation_id)} > {max_length}") + return False + + # Allow alphanumeric, hyphens, and underscores only + if not all(c.isalnum() or c in ('-', '_') for c in correlation_id): + logger.warning(f"Correlation ID contains invalid characters: {correlation_id}") + return False + + return True diff --git a/tests/unit/mcpgateway/middleware/test_correlation_id.py b/tests/unit/mcpgateway/middleware/test_correlation_id.py new file mode 100644 index 000000000..029d482fc --- /dev/null +++ b/tests/unit/mcpgateway/middleware/test_correlation_id.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +"""Tests for correlation ID middleware.""" + +import pytest +from unittest.mock import Mock, patch +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from mcpgateway.middleware.correlation_id import CorrelationIDMiddleware +from mcpgateway.utils.correlation_id import get_correlation_id + + +@pytest.fixture +def app(): + """Create a test FastAPI app with correlation ID middleware.""" + test_app = FastAPI() + + # Add the correlation ID middleware + test_app.add_middleware(CorrelationIDMiddleware) + + @test_app.get("/test") + async def test_endpoint(request: Request): + # Get correlation ID from context + correlation_id = get_correlation_id() + return {"correlation_id": correlation_id} + + return test_app + + +@pytest.fixture +def client(app): + """Create a test client.""" + return TestClient(app) + + +def test_middleware_generates_correlation_id_when_not_provided(client): + """Test that middleware generates a correlation ID when not provided by client.""" + response = client.get("/test") + + assert response.status_code == 200 + data = response.json() + + # Should have a correlation ID in response body + assert "correlation_id" in data + assert data["correlation_id"] is not None + assert len(data["correlation_id"]) == 32 # UUID hex format + + # Should have correlation ID in response headers + assert "X-Correlation-ID" in response.headers + assert response.headers["X-Correlation-ID"] == data["correlation_id"] + + +def test_middleware_preserves_client_correlation_id(client): 
+ """Test that middleware preserves correlation ID from client.""" + client_id = "client-provided-id-123" + + response = client.get("/test", headers={"X-Correlation-ID": client_id}) + + assert response.status_code == 200 + data = response.json() + + # Should use the client-provided ID + assert data["correlation_id"] == client_id + + # Should echo it back in response headers + assert response.headers["X-Correlation-ID"] == client_id + + +def test_middleware_case_insensitive_header(client): + """Test that middleware handles case-insensitive headers.""" + client_id = "lowercase-header-id" + + response = client.get("/test", headers={"x-correlation-id": client_id}) + + assert response.status_code == 200 + data = response.json() + + # Should use the client-provided ID regardless of case + assert data["correlation_id"] == client_id + + +def test_middleware_strips_whitespace_from_header(client): + """Test that middleware strips whitespace from correlation ID header.""" + client_id = " whitespace-id " + + response = client.get("/test", headers={"X-Correlation-ID": client_id}) + + assert response.status_code == 200 + data = response.json() + + # Should strip whitespace + assert data["correlation_id"] == "whitespace-id" + + +def test_middleware_clears_correlation_id_after_request(app): + """Test that middleware clears correlation ID after request completes.""" + client = TestClient(app) + + # Make a request + response = client.get("/test") + assert response.status_code == 200 + + # After request completes, correlation ID should be cleared + # (Note: This happens in a different context, so we can't directly test it here, + # but we verify that multiple requests get different IDs) + response2 = client.get("/test") + assert response2.status_code == 200 + + # Two requests without client-provided IDs should have different correlation IDs + assert response.json()["correlation_id"] != response2.json()["correlation_id"] + + +def test_middleware_handles_empty_header(client): + """Test that middleware generates new ID when header is empty.""" + response = client.get("/test", headers={"X-Correlation-ID": ""}) + + assert response.status_code == 200 + data = response.json() + + # Should generate a new ID when header is empty + assert data["correlation_id"] is not None + assert len(data["correlation_id"]) == 32 + + +def test_middleware_with_custom_settings(monkeypatch): + """Test middleware with custom configuration settings.""" + # Create a mock settings object + mock_settings = Mock() + mock_settings.correlation_id_header = "X-Request-ID" + mock_settings.correlation_id_preserve = False + mock_settings.correlation_id_response_header = False + + # Create app with custom settings + app = FastAPI() + + # Patch settings at module level + with patch("mcpgateway.middleware.correlation_id.settings", mock_settings): + app.add_middleware(CorrelationIDMiddleware) + + @app.get("/test") + async def test_endpoint(): + return {"correlation_id": get_correlation_id()} + + client = TestClient(app) + + # Test with custom header name + response = client.get("/test", headers={"X-Request-ID": "custom-id"}) + + assert response.status_code == 200 + + # When preserve=False, should always generate new ID (not use client's) + # When response_header=False, should not include in response headers + assert "X-Request-ID" not in response.headers + + +def test_middleware_integration_with_multiple_requests(client): + """Test middleware properly isolates correlation IDs across multiple requests.""" + ids = [] + + for i in range(5): + response = 
client.get("/test", headers={"X-Correlation-ID": f"request-{i}"}) + assert response.status_code == 200 + ids.append(response.json()["correlation_id"]) + + # Each request should have its unique correlation ID + assert len(ids) == 5 + assert len(set(ids)) == 5 # All unique + for i, correlation_id in enumerate(ids): + assert correlation_id == f"request-{i}" + + +def test_middleware_context_isolation(): + """Test that correlation ID is properly isolated per request context.""" + app = FastAPI() + app.add_middleware(CorrelationIDMiddleware) + + correlation_ids_seen = [] + + @app.get("/capture") + async def capture_endpoint(): + # Capture the correlation ID during request handling + correlation_id = get_correlation_id() + correlation_ids_seen.append(correlation_id) + return {"captured": correlation_id} + + client = TestClient(app) + + # Make multiple concurrent-like requests + for i in range(3): + response = client.get("/capture", headers={"X-Correlation-ID": f"id-{i}"}) + assert response.status_code == 200 + + # Each request should have captured its own unique ID + assert len(correlation_ids_seen) == 3 + assert correlation_ids_seen[0] == "id-0" + assert correlation_ids_seen[1] == "id-1" + assert correlation_ids_seen[2] == "id-2" + + +def test_middleware_preserves_correlation_id_through_request_lifecycle(): + """Test that correlation ID remains consistent throughout entire request.""" + captured_ids = [] + + app = FastAPI() + + @app.middleware("http") + async def capture_middleware(request: Request, call_next): + # Capture ID at middleware level (after CorrelationIDMiddleware sets it) + captured_ids.append(("middleware", get_correlation_id())) + response = await call_next(request) + return response + + # Add CorrelationIDMiddleware last so it executes first (LIFO) + app.add_middleware(CorrelationIDMiddleware) + + @app.get("/test") + async def test_endpoint(): + # Capture ID at endpoint level + captured_ids.append(("endpoint", get_correlation_id())) + return {"ok": True} + + client = TestClient(app) + response = client.get("/test", headers={"X-Correlation-ID": "consistent-id"}) + + assert response.status_code == 200 + + # Both captures should have the same correlation ID + assert len(captured_ids) == 2 + assert captured_ids[0][1] == "consistent-id" # Middleware capture + assert captured_ids[1][1] == "consistent-id" # Endpoint capture diff --git a/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py new file mode 100644 index 000000000..55a7e6137 --- /dev/null +++ b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +"""Tests for correlation ID JSON formatter.""" + +import json +import logging +from datetime import datetime, timezone +from io import StringIO +from unittest.mock import Mock, patch + +import pytest + +from mcpgateway.services.logging_service import CorrelationIdJsonFormatter +from mcpgateway.utils.correlation_id import set_correlation_id, clear_correlation_id + + +@pytest.fixture +def formatter(): + """Create a test JSON formatter.""" + return CorrelationIdJsonFormatter() + + +@pytest.fixture +def logger_with_formatter(formatter): + """Create a test logger with JSON formatter.""" + logger = logging.getLogger("test_correlation_logger") + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + + # Add string stream handler + stream = StringIO() + handler = logging.StreamHandler(stream) + handler.setFormatter(formatter) + 
logger.addHandler(handler) + + return logger, stream + + +def test_formatter_includes_correlation_id(logger_with_formatter): + """Test that formatter includes correlation ID in log records.""" + logger, stream = logger_with_formatter + + # Set correlation ID + test_id = "test-correlation-123" + set_correlation_id(test_id) + + # Log a message + logger.info("Test message") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include correlation ID + assert "request_id" in log_record + assert log_record["request_id"] == test_id + + clear_correlation_id() + + +def test_formatter_without_correlation_id(logger_with_formatter): + """Test formatter when correlation ID is not set.""" + logger, stream = logger_with_formatter + + # Clear any existing correlation ID + clear_correlation_id() + + # Log a message + logger.info("Test message without correlation ID") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # request_id should not be present + assert "request_id" not in log_record or log_record.get("request_id") is None + + +def test_formatter_includes_standard_fields(logger_with_formatter): + """Test that formatter includes standard log fields.""" + logger, stream = logger_with_formatter + + # Log a message + logger.info("Standard fields test") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Check for standard fields + assert "message" in log_record + assert log_record["message"] == "Standard fields test" + assert "@timestamp" in log_record + assert "hostname" in log_record + assert "process_id" in log_record + # Note: levelname is included by the JsonFormatter format string if specified + + +def test_formatter_includes_opentelemetry_trace_context(logger_with_formatter): + """Test that formatter includes OpenTelemetry trace context when available.""" + logger, stream = logger_with_formatter + + # Mock OpenTelemetry span + mock_span_context = Mock() + mock_span_context.trace_id = 0x1234567890abcdef1234567890abcdef + mock_span_context.span_id = 0x1234567890abcdef + mock_span_context.trace_flags = 0x01 + mock_span_context.is_valid = True + + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_span.get_span_context.return_value = mock_span_context + + with patch("opentelemetry.trace.get_current_span") as mock_get_span: + mock_get_span.return_value = mock_span + + # Log a message + logger.info("Test with trace context") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include trace context + assert "trace_id" in log_record + assert "span_id" in log_record + assert "trace_flags" in log_record + + # Verify hex formatting + assert log_record["trace_id"] == "1234567890abcdef1234567890abcdef" + assert log_record["span_id"] == "1234567890abcdef" + assert log_record["trace_flags"] == "01" + + +def test_formatter_handles_missing_opentelemetry(logger_with_formatter): + """Test that formatter gracefully handles missing OpenTelemetry.""" + logger, stream = logger_with_formatter + + # Simulate ImportError for opentelemetry + import sys + with patch.dict(sys.modules, {"opentelemetry.trace": None}): + # Log a message + logger.info("Test without OpenTelemetry") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should not fail, just exclude trace fields + assert "trace_id" not in log_record + assert "span_id" not in 
log_record + assert "message" in log_record + + +def test_formatter_timestamp_format(logger_with_formatter): + """Test that timestamp is in ISO 8601 format with 'Z' suffix.""" + logger, stream = logger_with_formatter + + # Log a message + logger.info("Timestamp test") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Check timestamp format + assert "@timestamp" in log_record + timestamp = log_record["@timestamp"] + + # Should end with 'Z' (Zulu/UTC time) + assert timestamp.endswith("Z") + + # Should be parseable as ISO 8601 + # Remove 'Z' and parse + datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + + +def test_formatter_with_extra_fields(logger_with_formatter): + """Test that formatter includes extra fields from log record.""" + logger, stream = logger_with_formatter + + # Log with extra fields + logger.info("Extra fields test", extra={"user_id": "user-123", "action": "login"}) + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include extra fields + assert log_record.get("user_id") == "user-123" + assert log_record.get("action") == "login" + + +def test_formatter_correlation_id_with_trace_context(logger_with_formatter): + """Test that both correlation ID and trace context coexist.""" + logger, stream = logger_with_formatter + + # Set correlation ID + set_correlation_id("both-test-id") + + # Mock OpenTelemetry span + mock_span_context = Mock() + mock_span_context.trace_id = 0xabcdef + mock_span_context.span_id = 0x123456 + mock_span_context.trace_flags = 0x01 + mock_span_context.is_valid = True + + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_span.get_span_context.return_value = mock_span_context + + with patch("opentelemetry.trace.get_current_span") as mock_get_span: + mock_get_span.return_value = mock_span + + # Log a message + logger.info("Test with both IDs") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should include both correlation ID and trace context + assert log_record.get("request_id") == "both-test-id" + assert "trace_id" in log_record + assert "span_id" in log_record + + clear_correlation_id() + + +def test_formatter_multiple_log_entries(logger_with_formatter): + """Test that formatter handles multiple log entries correctly.""" + logger, stream = logger_with_formatter + + # Log multiple messages with different correlation IDs + set_correlation_id("first-id") + logger.info("First message") + + set_correlation_id("second-id") + logger.info("Second message") + + clear_correlation_id() + logger.info("Third message") + + # Get all logged output + output = stream.getvalue() + log_lines = output.strip().split("\n") + + assert len(log_lines) == 3 + + # Parse each line + first_record = json.loads(log_lines[0]) + second_record = json.loads(log_lines[1]) + third_record = json.loads(log_lines[2]) + + # Verify correlation IDs + assert first_record.get("request_id") == "first-id" + assert second_record.get("request_id") == "second-id" + assert "request_id" not in third_record or third_record.get("request_id") is None + + +def test_formatter_process_id_and_hostname(logger_with_formatter): + """Test that formatter includes process ID and hostname.""" + logger, stream = logger_with_formatter + + # Log a message + logger.info("Process info test") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Check process_id and hostname + 
assert "process_id" in log_record + assert isinstance(log_record["process_id"], int) + assert log_record["process_id"] > 0 + + assert "hostname" in log_record + assert isinstance(log_record["hostname"], str) + assert len(log_record["hostname"]) > 0 + + +def test_formatter_handles_invalid_span_context(logger_with_formatter): + """Test that formatter handles invalid span context gracefully.""" + logger, stream = logger_with_formatter + + # Mock span with invalid context + mock_span_context = Mock() + mock_span_context.is_valid = False + + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_span.get_span_context.return_value = mock_span_context + + with patch("opentelemetry.trace.get_current_span") as mock_get_span: + mock_get_span.return_value = mock_span + + # Log a message + logger.info("Test with invalid span") + + # Get the logged output + output = stream.getvalue() + log_record = json.loads(output.strip()) + + # Should not include trace context when invalid + assert "trace_id" not in log_record + assert "span_id" not in log_record + # But message should still be logged + assert log_record["message"] == "Test with invalid span" diff --git a/tests/unit/mcpgateway/utils/test_correlation_id.py b/tests/unit/mcpgateway/utils/test_correlation_id.py new file mode 100644 index 000000000..6b80ae163 --- /dev/null +++ b/tests/unit/mcpgateway/utils/test_correlation_id.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +"""Tests for correlation ID utilities.""" + +import asyncio +import pytest +from mcpgateway.utils.correlation_id import ( + clear_correlation_id, + extract_correlation_id_from_headers, + generate_correlation_id, + get_correlation_id, + get_or_generate_correlation_id, + set_correlation_id, + validate_correlation_id, +) + + +def test_generate_correlation_id(): + """Test correlation ID generation.""" + id1 = generate_correlation_id() + id2 = generate_correlation_id() + + assert id1 is not None + assert id2 is not None + assert id1 != id2 + assert len(id1) == 32 # UUID4 hex is 32 characters + assert len(id2) == 32 + + +def test_set_and_get_correlation_id(): + """Test setting and getting correlation ID.""" + test_id = "test-correlation-123" + + set_correlation_id(test_id) + retrieved_id = get_correlation_id() + + assert retrieved_id == test_id + + clear_correlation_id() + + +def test_clear_correlation_id(): + """Test clearing correlation ID.""" + test_id = "test-correlation-456" + + set_correlation_id(test_id) + assert get_correlation_id() == test_id + + clear_correlation_id() + assert get_correlation_id() is None + + +def test_get_correlation_id_returns_none_when_not_set(): + """Test getting correlation ID when not set.""" + clear_correlation_id() + assert get_correlation_id() is None + + +def test_extract_correlation_id_from_headers(): + """Test extracting correlation ID from headers.""" + headers = {"X-Correlation-ID": "header-correlation-789"} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id == "header-correlation-789" + + +def test_extract_correlation_id_from_headers_case_insensitive(): + """Test case-insensitive header extraction.""" + headers = {"x-correlation-id": "lowercase-id"} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id == "lowercase-id" + + +def test_extract_correlation_id_from_headers_custom_header(): + """Test extracting from custom header name.""" + headers = {"X-Request-ID": "custom-request-id"} + + correlation_id = extract_correlation_id_from_headers(headers, 
"X-Request-ID") + assert correlation_id == "custom-request-id" + + +def test_extract_correlation_id_from_headers_not_found(): + """Test when correlation ID header is not present.""" + headers = {"Content-Type": "application/json"} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id is None + + +def test_extract_correlation_id_from_headers_empty_value(): + """Test when correlation ID header has empty value.""" + headers = {"X-Correlation-ID": " "} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id is None + + +def test_get_or_generate_correlation_id_when_not_set(): + """Test get_or_generate when ID is not set.""" + clear_correlation_id() + + correlation_id = get_or_generate_correlation_id() + + assert correlation_id is not None + assert len(correlation_id) == 32 + assert get_correlation_id() == correlation_id # Should be stored + + clear_correlation_id() + + +def test_get_or_generate_correlation_id_when_already_set(): + """Test get_or_generate when ID is already set.""" + test_id = "existing-correlation-id" + set_correlation_id(test_id) + + correlation_id = get_or_generate_correlation_id() + + assert correlation_id == test_id + + clear_correlation_id() + + +def test_validate_correlation_id_valid(): + """Test validation of valid correlation IDs.""" + assert validate_correlation_id("abc-123") is True + assert validate_correlation_id("test_id_456") is True + assert validate_correlation_id("UPPER-lower-123_mix") is True + + +def test_validate_correlation_id_invalid(): + """Test validation of invalid correlation IDs.""" + assert validate_correlation_id(None) is False + assert validate_correlation_id("") is False + assert validate_correlation_id(" ") is False + assert validate_correlation_id("id with spaces") is False + assert validate_correlation_id("id@special!chars") is False + + +def test_validate_correlation_id_too_long(): + """Test validation rejects overly long IDs.""" + long_id = "a" * 256 # Default max is 255 + + assert validate_correlation_id(long_id) is False + assert validate_correlation_id(long_id, max_length=300) is True + + +@pytest.mark.asyncio +async def test_correlation_id_isolation_between_async_tasks(): + """Test that correlation IDs are isolated between concurrent async tasks.""" + results = [] + + async def task_with_id(task_id: str): + set_correlation_id(task_id) + await asyncio.sleep(0.01) # Simulate async work + retrieved_id = get_correlation_id() + results.append((task_id, retrieved_id)) + clear_correlation_id() + + # Run multiple tasks concurrently + await asyncio.gather( + task_with_id("task-1"), + task_with_id("task-2"), + task_with_id("task-3"), + ) + + # Each task should have retrieved its own ID + assert len(results) == 3 + for task_id, retrieved_id in results: + assert task_id == retrieved_id + + +@pytest.mark.asyncio +async def test_correlation_id_inheritance_in_nested_tasks(): + """Test that correlation ID is inherited by child async tasks.""" + + async def parent_task(): + set_correlation_id("parent-id") + parent_id = get_correlation_id() + + async def child_task(): + return get_correlation_id() + + child_id = await child_task() + + clear_correlation_id() + return parent_id, child_id + + parent_id, child_id = await parent_task() + + # Child should inherit parent's correlation ID + assert parent_id == "parent-id" + assert child_id == "parent-id" + + +def test_correlation_id_context_isolation(): + """Test that correlation ID is properly isolated per context.""" + clear_correlation_id() + 
+ # Set ID in one context + set_correlation_id("context-1") + assert get_correlation_id() == "context-1" + + # Overwrite with new ID + set_correlation_id("context-2") + assert get_correlation_id() == "context-2" + + clear_correlation_id() + assert get_correlation_id() is None + + +def test_extract_correlation_id_strips_whitespace(): + """Test that extracted correlation ID has whitespace stripped.""" + headers = {"X-Correlation-ID": " trimmed-id "} + + correlation_id = extract_correlation_id_from_headers(headers) + assert correlation_id == "trimmed-id" From 29c103aa1e490fcd5baf11f54766aeb7973d5ce4 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 14 Nov 2025 16:29:00 +0530 Subject: [PATCH 02/34] replace undefined bearer_scheme with security Signed-off-by: Shoumi --- mcpgateway/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index 629f4df46..edca68967 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -171,7 +171,7 @@ async def get_team_from_token(payload: Dict[str, Any], db: Session) -> Optional[ async def get_current_user( - credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme), + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), db: Session = Depends(get_db), request: Optional[object] = None, ) -> EmailUser: From 6bea3600e2a25d668cdbcceb0908f996308a44b1 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 14 Nov 2025 17:30:01 +0530 Subject: [PATCH 03/34] lint & test fixes Signed-off-by: Shoumi --- mcpgateway/middleware/http_auth_middleware.py | 1 - mcpgateway/services/a2a_service.py | 1 - mcpgateway/services/logging_service.py | 34 +++++++++++-------- .../test_correlation_id_json_formatter.py | 8 ++--- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/mcpgateway/middleware/http_auth_middleware.py b/mcpgateway/middleware/http_auth_middleware.py index 36f987e3c..8b73ffacd 100644 --- a/mcpgateway/middleware/http_auth_middleware.py +++ b/mcpgateway/middleware/http_auth_middleware.py @@ -8,7 +8,6 @@ # Standard import logging -import uuid # Third-Party from fastapi import Request diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py index a410f18b7..33f3468d0 100644 --- a/mcpgateway/services/a2a_service.py +++ b/mcpgateway/services/a2a_service.py @@ -28,7 +28,6 @@ from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService -from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.services_auth import encode_auth # ,decode_auth diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index 0cad12837..0140da796 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -27,6 +27,13 @@ from mcpgateway.services.log_storage_service import LogStorageService from mcpgateway.utils.correlation_id import get_correlation_id +# Optional OpenTelemetry support +try: + # Third-Party + from opentelemetry import trace # type: ignore[import-untyped] +except ImportError: + trace = None # type: ignore[assignment] + AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name try: # Optional import; only used for filtering a known benign upstream error @@ -72,20 +79,19 @@ def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: 
log_record["request_id"] = correlation_id # Add OpenTelemetry trace context if available - try: - from opentelemetry import trace - - span = trace.get_current_span() - if span and span.is_recording(): - span_context = span.get_span_context() - if span_context.is_valid: - # Format trace_id and span_id as hex strings - log_record["trace_id"] = format(span_context.trace_id, "032x") - log_record["span_id"] = format(span_context.span_id, "016x") - log_record["trace_flags"] = format(span_context.trace_flags, "02x") - except (ImportError, Exception): - # OpenTelemetry not available or error accessing span - pass + if trace is not None: + try: + span = trace.get_current_span() + if span and span.is_recording(): + span_context = span.get_span_context() + if span_context.is_valid: + # Format trace_id and span_id as hex strings + log_record["trace_id"] = format(span_context.trace_id, "032x") + log_record["span_id"] = format(span_context.span_id, "016x") + log_record["trace_flags"] = format(span_context.trace_flags, "02x") + except Exception: + # Error accessing span context + pass # Create a JSON formatter with correlation ID support diff --git a/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py index 55a7e6137..de6ce9220 100644 --- a/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py +++ b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py @@ -208,8 +208,8 @@ def test_formatter_correlation_id_with_trace_context(logger_with_formatter): mock_span.is_recording.return_value = True mock_span.get_span_context.return_value = mock_span_context - with patch("opentelemetry.trace.get_current_span") as mock_get_span: - mock_get_span.return_value = mock_span + with patch("mcpgateway.services.logging_service.trace") as mock_trace: + mock_trace.get_current_span.return_value = mock_span # Log a message logger.info("Test with both IDs") @@ -290,8 +290,8 @@ def test_formatter_handles_invalid_span_context(logger_with_formatter): mock_span.is_recording.return_value = True mock_span.get_span_context.return_value = mock_span_context - with patch("opentelemetry.trace.get_current_span") as mock_get_span: - mock_get_span.return_value = mock_span + with patch("mcpgateway.services.logging_service.trace") as mock_trace: + mock_trace.get_current_span.return_value = mock_span # Log a message logger.info("Test with invalid span") From 9caa57f3f025d5c10f430a2b11064b30ed61ce9e Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 14 Nov 2025 18:00:03 +0530 Subject: [PATCH 04/34] fixes for lint Signed-off-by: Shoumi --- mcpgateway/observability.py | 13 ++++++------- mcpgateway/services/logging_service.py | 11 +++-------- mcpgateway/services/tool_service.py | 4 +--- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py index 31ab246b0..82016b4f3 100644 --- a/mcpgateway/observability.py +++ b/mcpgateway/observability.py @@ -15,10 +15,9 @@ import os from typing import Any, Callable, cast, Dict, Optional -# Try to import OpenTelemetry core components - make them truly optional +# Third-Party - Try to import OpenTelemetry core components - make them truly optional OTEL_AVAILABLE = False try: - # Third-Party from opentelemetry import trace from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider @@ -93,6 +92,9 @@ class _ConsoleSpanExporterStub: # pragma: no cover - test patch replaces this # Shimming is a 
non-critical, best-effort step for tests; log and continue. logging.getLogger(__name__).debug("Skipping OpenTelemetry shim setup: %s", exc) +# First-Party +from mcpgateway.utils.correlation_id import get_correlation_id # noqa: E402 + # Try to import optional exporters try: OTLP_SPAN_EXPORTER = getattr(_im("opentelemetry.exporter.otlp.proto.grpc.trace_exporter"), "OTLPSpanExporter") @@ -442,9 +444,6 @@ def create_span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Any: # Auto-inject correlation ID into all spans for request tracing try: - # Import here to avoid circular dependency - from mcpgateway.utils.correlation_id import get_correlation_id - correlation_id = get_correlation_id() if correlation_id: if attributes is None: @@ -454,8 +453,8 @@ def create_span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Any: attributes["correlation_id"] = correlation_id if "request_id" not in attributes: attributes["request_id"] = correlation_id # Alias for compatibility - except ImportError: - # Correlation ID module not available, continue without it + except Exception: + # Correlation ID not available or error getting it, continue without it pass # Start span and return the context manager diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index 0140da796..58f95224d 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -16,6 +16,7 @@ import logging from logging.handlers import RotatingFileHandler import os +import socket from typing import Any, AsyncGenerator, Dict, List, NotRequired, Optional, TextIO, TypedDict # Third-Party @@ -27,17 +28,15 @@ from mcpgateway.services.log_storage_service import LogStorageService from mcpgateway.utils.correlation_id import get_correlation_id -# Optional OpenTelemetry support +# Optional OpenTelemetry support (Third-Party) try: - # Third-Party from opentelemetry import trace # type: ignore[import-untyped] except ImportError: trace = None # type: ignore[assignment] AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name try: - # Optional import; only used for filtering a known benign upstream error - # Third-Party + # Optional import; only used for filtering a known benign upstream error (Third-Party) from anyio import ClosedResourceError as AnyioClosedResourceError # pylint: disable=invalid-name except Exception: # pragma: no cover - environment without anyio AnyioClosedResourceError = None # pylint: disable=invalid-name @@ -62,10 +61,6 @@ def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: super().add_fields(log_record, record, message_dict) # Add timestamp in ISO 8601 format with 'Z' suffix for UTC - import os - import socket - from datetime import datetime, timezone - dt = datetime.fromtimestamp(record.created, tz=timezone.utc) log_record["@timestamp"] = dt.isoformat().replace("+00:00", "Z") diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 88f23ddc0..6fb8a9454 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -26,9 +26,6 @@ from urllib.parse import parse_qs, urlparse import uuid -# First-Party (early import for correlation_id) -from mcpgateway.utils.correlation_id import get_correlation_id - # Third-Party import httpx import jq @@ -46,6 +43,7 @@ from mcpgateway.common.models import Tool as PydanticTool from mcpgateway.common.models import ToolResult from mcpgateway.config import settings +from mcpgateway.utils.correlation_id 
import get_correlation_id from mcpgateway.db import A2AAgent as DbA2AAgent from mcpgateway.db import EmailTeam from mcpgateway.db import Gateway as DbGateway From 9af6c9f9ae902ff0f7b07ae30021a7f4181ee530 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 14 Nov 2025 18:13:42 +0530 Subject: [PATCH 05/34] pylint fixes Signed-off-by: Shoumi --- mcpgateway/observability.py | 2 +- mcpgateway/services/logging_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py index 82016b4f3..7694ab398 100644 --- a/mcpgateway/observability.py +++ b/mcpgateway/observability.py @@ -93,7 +93,7 @@ class _ConsoleSpanExporterStub: # pragma: no cover - test patch replaces this logging.getLogger(__name__).debug("Skipping OpenTelemetry shim setup: %s", exc) # First-Party -from mcpgateway.utils.correlation_id import get_correlation_id # noqa: E402 +from mcpgateway.utils.correlation_id import get_correlation_id # noqa: E402 # pylint: disable=wrong-import-position # Try to import optional exporters try: diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index 58f95224d..e81b3929b 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -49,7 +49,7 @@ class CorrelationIdJsonFormatter(jsonlogger.JsonFormatter): """JSON formatter that includes correlation ID and OpenTelemetry trace context.""" - def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: dict) -> None: + def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: dict) -> None: # pylint: disable=arguments-renamed """Add custom fields to the log record. Args: From e7f543f5c5de7d0652ae8d9d54a2a4575ecc2fd5 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 14 Nov 2025 18:40:51 +0530 Subject: [PATCH 06/34] test fixes Signed-off-by: Shoumi --- .../mcpgateway/services/test_correlation_id_json_formatter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py index de6ce9220..337e23f27 100644 --- a/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py +++ b/tests/unit/mcpgateway/services/test_correlation_id_json_formatter.py @@ -110,8 +110,8 @@ def test_formatter_includes_opentelemetry_trace_context(logger_with_formatter): mock_span.is_recording.return_value = True mock_span.get_span_context.return_value = mock_span_context - with patch("opentelemetry.trace.get_current_span") as mock_get_span: - mock_get_span.return_value = mock_span + with patch("mcpgateway.services.logging_service.trace") as mock_trace: + mock_trace.get_current_span.return_value = mock_span # Log a message logger.info("Test with trace context") From abba8d56581af0d6af4fce3031979eaedda51b54 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 14 Nov 2025 18:57:54 +0530 Subject: [PATCH 07/34] Bandit fixes Signed-off-by: Shoumi --- mcpgateway/observability.py | 4 ++-- mcpgateway/services/logging_service.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py index 7694ab398..459e57f18 100644 --- a/mcpgateway/observability.py +++ b/mcpgateway/observability.py @@ -453,9 +453,9 @@ def create_span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Any: attributes["correlation_id"] = correlation_id if "request_id" not in attributes: 
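# --- Note on the correlation lookup used in create_span above ---
# get_correlation_id() comes from mcpgateway/utils/correlation_id.py, added in
# PATCH 01 of this series. A minimal sketch of the contextvars-based pattern it
# implies; only get_correlation_id() is confirmed by these patches, the variable
# and setter names below are illustrative assumptions:
from contextvars import ContextVar
from typing import Optional
import uuid

_correlation_id: ContextVar[Optional[str]] = ContextVar("correlation_id", default=None)

def get_correlation_id() -> Optional[str]:
    # Return the correlation ID bound to the current request context, if any.
    return _correlation_id.get()

def set_correlation_id(value: Optional[str] = None) -> str:
    # Bind a correlation ID for this context, generating one when absent.
    value = value or uuid.uuid4().hex
    _correlation_id.set(value)
    return value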
attributes["request_id"] = correlation_id # Alias for compatibility - except Exception: + except Exception as exc: # Correlation ID not available or error getting it, continue without it - pass + logger.debug("Failed to add correlation_id to span: %s", exc) # Start span and return the context manager span_context = _TRACER.start_as_current_span(name) diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index e81b3929b..ef39abfde 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -84,8 +84,8 @@ def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: log_record["trace_id"] = format(span_context.trace_id, "032x") log_record["span_id"] = format(span_context.span_id, "016x") log_record["trace_flags"] = format(span_context.trace_flags, "02x") - except Exception: - # Error accessing span context + except Exception: # nosec B110 - intentionally catching all exceptions for optional tracing + # Error accessing span context, continue without trace fields pass From 220dfe875f7e0aeec8082d3c6461bff96a6f6b58 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 20 Nov 2025 15:16:23 +0530 Subject: [PATCH 08/34] fix for test Signed-off-by: Shoumi --- .../middleware/request_logging_middleware.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index f0dc10805..5a6ebe3a4 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -175,14 +175,25 @@ async def dispatch(self, request: Request, call_next: Callable): # Get correlation ID for request tracking request_id = get_correlation_id() - logger.log( - log_level, - f"📩 Incoming request: {request.method} {request.url.path}\n" - f"Query params: {dict(request.query_params)}\n" - f"Headers: {masked_headers}\n" - f"Body: {payload_str}{'... [truncated]' if truncated else ''}", - extra={"request_id": request_id}, - ) + # Try to log with extra parameter, fall back to without if not supported + try: + logger.log( + log_level, + f"📩 Incoming request: {request.method} {request.url.path}\n" + f"Query params: {dict(request.query_params)}\n" + f"Headers: {masked_headers}\n" + f"Body: {payload_str}{'... [truncated]' if truncated else ''}", + extra={"request_id": request_id}, + ) + except TypeError: + # Fall back for test loggers that don't accept extra parameter + logger.log( + log_level, + f"📩 Incoming request: {request.method} {request.url.path}\n" + f"Query params: {dict(request.query_params)}\n" + f"Headers: {masked_headers}\n" + f"Body: {payload_str}{'... 
[truncated]' if truncated else ''}",
+            )
     except Exception as e:
         logger.warning(f"Failed to log request body: {e}")

From 1f8c6733b2e19aa62c6aee82510bcfdff23f821b Mon Sep 17 00:00:00 2001
From: Shoumi
Date: Mon, 24 Nov 2025 12:40:09 +0530
Subject: [PATCH 09/34] additional changes for UI & middleware

Signed-off-by: Shoumi

---
 mcpgateway/admin.py                           | 117 ++
 ...6f7g8h9i0_add_structured_logging_tables.py | 204 ++++
 mcpgateway/config.py                          |  37 +
 mcpgateway/db.py                              | 246 +++++
 mcpgateway/main.py                            |  31 +-
 mcpgateway/middleware/auth_middleware.py      |  25 +
 .../middleware/request_logging_middleware.py  | 160 ++-
 mcpgateway/routers/log_search.py              | 605 +++++++++++
 mcpgateway/services/a2a_service.py            |  83 ++
 mcpgateway/services/audit_trail_service.py    | 425 ++++++++
 mcpgateway/services/log_aggregator.py         | 399 +++++++
 mcpgateway/services/performance_tracker.py    | 324 ++++++
 mcpgateway/services/security_logger.py        | 640 +++++++++++
 mcpgateway/services/server_service.py         | 240 ++++-
 mcpgateway/services/structured_logger.py      | 408 +++++++
 mcpgateway/services/tool_service.py           | 194 +++-
 mcpgateway/static/admin.js                    | 996 ++++++++++++++++++
 mcpgateway/templates/admin.html               | 271 ++---
 18 files changed, 5193 insertions(+), 212 deletions(-)
 create mode 100644 mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py
 create mode 100644 mcpgateway/routers/log_search.py
 create mode 100644 mcpgateway/services/audit_trail_service.py
 create mode 100644 mcpgateway/services/log_aggregator.py
 create mode 100644 mcpgateway/services/performance_tracker.py
 create mode 100644 mcpgateway/services/security_logger.py

diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py
index d5f981875..22e5f8013 100644
--- a/mcpgateway/admin.py
+++ b/mcpgateway/admin.py
@@ -118,6 +118,8 @@
 from mcpgateway.services.plugin_service import get_plugin_service
 from mcpgateway.services.prompt_service import PromptNameConflictError, PromptNotFoundError, PromptService
 from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService, ResourceURIConflictError
+from mcpgateway.services.structured_logger import get_structured_logger
+from mcpgateway.services.audit_trail_service import get_audit_trail_service
 from mcpgateway.services.root_service import RootService
 from mcpgateway.services.server_service import ServerError, ServerNameConflictError, ServerNotFoundError, ServerService
 from mcpgateway.services.tag_service import TagService
@@ -12536,6 +12538,7 @@ async def list_plugins(
         HTTPException: If there's an error retrieving plugins
     """
     LOGGER.debug(f"User {get_user_email(user)} requested plugin list")
+    structured_logger = get_structured_logger()
 
     try:
         # Get plugin service
@@ -12556,10 +12559,41 @@ async def list_plugins(
         enabled_count = sum(1 for p in plugins if p["status"] == "enabled")
         disabled_count = sum(1 for p in plugins if p["status"] == "disabled")
 
+        # Log plugin marketplace browsing activity
+        structured_logger.info(
+            "User browsed plugin marketplace",
+            user_id=str(user.id),
+            user_email=get_user_email(user),
+            component="plugin_marketplace",
+            category="business_logic",
+            resource_type="plugin_list",
+            resource_action="browse",
+            custom_fields={
+                "search_query": search,
+                "filter_mode": mode,
+                "filter_hook": hook,
+                "filter_tag": tag,
+                "results_count": len(plugins),
+                "enabled_count": enabled_count,
+                "disabled_count": disabled_count,
+                "has_filters": any([search, mode, hook, tag])
+            },
+            db=db
+        )
+
         return PluginListResponse(plugins=plugins,
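# --- Note on the try/except TypeError fallback introduced in PATCH 08 above ---
# The fallback repeats the full logger.log() call verbatim. A small helper
# along these lines (an illustrative sketch, not part of this series) would
# keep a single copy of the message while still tolerating test loggers that
# reject the `extra` keyword:
import logging

def log_with_optional_extra(logger: logging.Logger, level: int, msg: str, **extra) -> None:
    try:
        logger.log(level, msg, extra=extra)
    except TypeError:
        # Stub loggers used in some tests do not accept the `extra` kwarg.
        logger.log(level, msg)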
total=len(plugins), enabled_count=enabled_count, disabled_count=disabled_count) except Exception as e: LOGGER.error(f"Error listing plugins: {e}") + structured_logger.error( + f"Failed to list plugins in marketplace", + user_id=str(user.id), + user_email=get_user_email(user), + error=e, + component="plugin_marketplace", + category="business_logic", + db=db + ) raise HTTPException(status_code=500, detail=str(e)) @@ -12579,6 +12613,7 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user HTTPException: If there's an error getting plugin statistics """ LOGGER.debug(f"User {get_user_email(user)} requested plugin statistics") + structured_logger = get_structured_logger() try: # Get plugin service @@ -12592,10 +12627,39 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user # Get statistics stats = plugin_service.get_plugin_statistics() + # Log marketplace analytics access + structured_logger.info( + f"User accessed plugin marketplace statistics", + user_id=str(user.id), + user_email=get_user_email(user), + component="plugin_marketplace", + category="business_logic", + resource_type="plugin_stats", + resource_action="view", + custom_fields={ + "total_plugins": stats.get("total_plugins", 0), + "enabled_plugins": stats.get("enabled_plugins", 0), + "disabled_plugins": stats.get("disabled_plugins", 0), + "hooks_count": len(stats.get("plugins_by_hook", {})), + "tags_count": len(stats.get("plugins_by_tag", {})), + "authors_count": len(stats.get("plugins_by_author", {})) + }, + db=db + ) + return PluginStatsResponse(**stats) except Exception as e: LOGGER.error(f"Error getting plugin statistics: {e}") + structured_logger.error( + f"Failed to get plugin marketplace statistics", + user_id=str(user.id), + user_email=get_user_email(user), + error=e, + component="plugin_marketplace", + category="business_logic", + db=db + ) raise HTTPException(status_code=500, detail=str(e)) @@ -12616,6 +12680,8 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends( HTTPException: If plugin not found """ LOGGER.debug(f"User {get_user_email(user)} requested details for plugin {name}") + structured_logger = get_structured_logger() + audit_service = get_audit_trail_service() try: # Get plugin service @@ -12630,14 +12696,65 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends( plugin = plugin_service.get_plugin_by_name(name) if not plugin: + structured_logger.warning( + f"Plugin '{name}' not found in marketplace", + user_id=str(user.id), + user_email=get_user_email(user), + component="plugin_marketplace", + category="business_logic", + custom_fields={"plugin_name": name, "action": "view_details"}, + db=db + ) raise HTTPException(status_code=404, detail=f"Plugin '{name}' not found") + # Log plugin view activity + structured_logger.info( + f"User viewed plugin details: '{name}'", + user_id=str(user.id), + user_email=get_user_email(user), + component="plugin_marketplace", + category="business_logic", + resource_type="plugin", + resource_id=name, + resource_action="view_details", + custom_fields={ + "plugin_name": name, + "plugin_version": plugin.get("version"), + "plugin_author": plugin.get("author"), + "plugin_status": plugin.get("status"), + "plugin_mode": plugin.get("mode"), + "plugin_hooks": plugin.get("hooks", []), + "plugin_tags": plugin.get("tags", []) + }, + db=db + ) + + # Create audit trail for plugin access + audit_service.log_audit( + user_id=str(user.id), + user_email=get_user_email(user), + 
resource_type="plugin", + resource_id=name, + action="view", + description=f"Viewed plugin '{name}' details in marketplace", + db=db + ) + return PluginDetail(**plugin) except HTTPException: raise except Exception as e: LOGGER.error(f"Error getting plugin details: {e}") + structured_logger.error( + f"Failed to get plugin details: '{name}'", + user_id=str(user.id), + user_email=get_user_email(user), + error=e, + component="plugin_marketplace", + category="business_logic", + db=db + ) raise HTTPException(status_code=500, detail=str(e)) diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py new file mode 100644 index 000000000..dae7b6266 --- /dev/null +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -0,0 +1,204 @@ +"""Add structured logging tables + +Revision ID: k5e6f7g8h9i0 +Revises: f3a3a3d901b8 +Create Date: 2025-01-15 12:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'k5e6f7g8h9i0' +down_revision = 'f3a3a3d901b8' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Add structured logging tables.""" + # Create structured_log_entries table + op.create_table( + 'structured_log_entries', + sa.Column('id', sa.String(36), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('correlation_id', sa.String(64), nullable=True), + sa.Column('request_id', sa.String(64), nullable=True), + sa.Column('level', sa.String(20), nullable=False), + sa.Column('component', sa.String(100), nullable=False), + sa.Column('message', sa.Text(), nullable=False), + sa.Column('logger', sa.String(255), nullable=True), + sa.Column('user_id', sa.String(255), nullable=True), + sa.Column('user_email', sa.String(255), nullable=True), + sa.Column('client_ip', sa.String(45), nullable=True), + sa.Column('user_agent', sa.Text(), nullable=True), + sa.Column('request_path', sa.String(500), nullable=True), + sa.Column('request_method', sa.String(10), nullable=True), + sa.Column('duration_ms', sa.Float(), nullable=True), + sa.Column('operation_type', sa.String(100), nullable=True), + sa.Column('is_security_event', sa.Boolean(), nullable=False, server_default='0'), + sa.Column('security_severity', sa.String(20), nullable=True), + sa.Column('threat_indicators', sa.JSON(), nullable=True), + sa.Column('context', sa.JSON(), nullable=True), + sa.Column('error_details', sa.JSON(), nullable=True), + sa.Column('performance_metrics', sa.JSON(), nullable=True), + sa.Column('hostname', sa.String(255), nullable=False), + sa.Column('process_id', sa.Integer(), nullable=False), + sa.Column('thread_id', sa.Integer(), nullable=True), + sa.Column('version', sa.String(50), nullable=False), + sa.Column('environment', sa.String(50), nullable=False, server_default='production'), + sa.Column('trace_id', sa.String(32), nullable=True), + sa.Column('span_id', sa.String(16), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # Create indexes for structured_log_entries + op.create_index('ix_structured_log_entries_timestamp', 'structured_log_entries', ['timestamp'], unique=False) + op.create_index('ix_structured_log_entries_level', 'structured_log_entries', ['level'], unique=False) + op.create_index('ix_structured_log_entries_component', 'structured_log_entries', ['component'], unique=False) + op.create_index('ix_structured_log_entries_correlation_id', 'structured_log_entries', 
['correlation_id'], unique=False) + op.create_index('ix_structured_log_entries_request_id', 'structured_log_entries', ['request_id'], unique=False) + op.create_index('ix_structured_log_entries_user_id', 'structured_log_entries', ['user_id'], unique=False) + op.create_index('ix_structured_log_entries_user_email', 'structured_log_entries', ['user_email'], unique=False) + op.create_index('ix_structured_log_entries_operation_type', 'structured_log_entries', ['operation_type'], unique=False) + op.create_index('ix_structured_log_entries_is_security_event', 'structured_log_entries', ['is_security_event'], unique=False) + op.create_index('ix_structured_log_entries_security_severity', 'structured_log_entries', ['security_severity'], unique=False) + op.create_index('ix_structured_log_entries_trace_id', 'structured_log_entries', ['trace_id'], unique=False) + + # Composite indexes matching db.py + op.create_index('idx_log_correlation_time', 'structured_log_entries', ['correlation_id', 'timestamp'], unique=False) + op.create_index('idx_log_user_time', 'structured_log_entries', ['user_id', 'timestamp'], unique=False) + op.create_index('idx_log_level_time', 'structured_log_entries', ['level', 'timestamp'], unique=False) + op.create_index('idx_log_component_time', 'structured_log_entries', ['component', 'timestamp'], unique=False) + op.create_index('idx_log_security', 'structured_log_entries', ['is_security_event', 'security_severity', 'timestamp'], unique=False) + op.create_index('idx_log_operation', 'structured_log_entries', ['operation_type', 'timestamp'], unique=False) + op.create_index('idx_log_trace', 'structured_log_entries', ['trace_id', 'timestamp'], unique=False) + + # Create performance_metrics table + op.create_table( + 'performance_metrics', + sa.Column('id', sa.String(36), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('operation_type', sa.String(100), nullable=False), + sa.Column('component', sa.String(100), nullable=False), + sa.Column('request_count', sa.Integer(), nullable=False, server_default='0'), + sa.Column('error_count', sa.Integer(), nullable=False, server_default='0'), + sa.Column('error_rate', sa.Float(), nullable=False, server_default='0.0'), + sa.Column('avg_duration_ms', sa.Float(), nullable=False), + sa.Column('min_duration_ms', sa.Float(), nullable=False), + sa.Column('max_duration_ms', sa.Float(), nullable=False), + sa.Column('p50_duration_ms', sa.Float(), nullable=False), + sa.Column('p95_duration_ms', sa.Float(), nullable=False), + sa.Column('p99_duration_ms', sa.Float(), nullable=False), + sa.Column('window_start', sa.DateTime(timezone=True), nullable=False), + sa.Column('window_end', sa.DateTime(timezone=True), nullable=False), + sa.Column('window_duration_seconds', sa.Integer(), nullable=False), + sa.Column('metric_metadata', sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # Create indexes for performance_metrics + op.create_index('ix_performance_metrics_timestamp', 'performance_metrics', ['timestamp'], unique=False) + op.create_index('ix_performance_metrics_component', 'performance_metrics', ['component'], unique=False) + op.create_index('ix_performance_metrics_operation_type', 'performance_metrics', ['operation_type'], unique=False) + op.create_index('ix_performance_metrics_window_start', 'performance_metrics', ['window_start'], unique=False) + op.create_index('idx_perf_operation_time', 'performance_metrics', ['operation_type', 'window_start'], unique=False) + 
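# --- Note on the percentile columns created above ---
# The avg/min/max/p50/p95/p99 columns are presumably populated by the new
# mcpgateway/services/log_aggregator.py, whose source this excerpt omits. A
# minimal sketch of such window statistics (the helper name is hypothetical):
import statistics
from typing import Dict, List

def summarize_window(durations_ms: List[float], error_count: int) -> Dict[str, float]:
    if not durations_ms:
        return {}
    if len(durations_ms) >= 2:
        # quantiles(n=100) yields 99 cut points; indexes 49/94/98 -> p50/p95/p99.
        cuts = statistics.quantiles(durations_ms, n=100)
        p50, p95, p99 = cuts[49], cuts[94], cuts[98]
    else:
        p50 = p95 = p99 = durations_ms[0]
    count = len(durations_ms)
    return {
        "request_count": count,
        "error_count": error_count,
        "error_rate": error_count / count,
        "avg_duration_ms": statistics.fmean(durations_ms),
        "min_duration_ms": min(durations_ms),
        "max_duration_ms": max(durations_ms),
        "p50_duration_ms": p50,
        "p95_duration_ms": p95,
        "p99_duration_ms": p99,
    }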
op.create_index('idx_perf_component_time', 'performance_metrics', ['component', 'window_start'], unique=False) + op.create_index('idx_perf_window', 'performance_metrics', ['window_start', 'window_end'], unique=False) + + # Create security_events table + op.create_table( + 'security_events', + sa.Column('id', sa.String(36), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('detected_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('event_type', sa.String(100), nullable=False), + sa.Column('severity', sa.String(20), nullable=False), + sa.Column('category', sa.String(100), nullable=False), + sa.Column('user_id', sa.String(255), nullable=True), + sa.Column('user_email', sa.String(255), nullable=True), + sa.Column('client_ip', sa.String(45), nullable=False), + sa.Column('user_agent', sa.Text(), nullable=True), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('action_taken', sa.String(100), nullable=True), + sa.Column('threat_score', sa.Float(), nullable=False, server_default='0.0'), + sa.Column('threat_indicators', sa.JSON(), nullable=True), + sa.Column('failed_attempts_count', sa.Integer(), nullable=False, server_default='0'), + sa.Column('context', sa.JSON(), nullable=True), + sa.Column('correlation_id', sa.String(255), nullable=True), + sa.Column('resolved', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('resolved_by', sa.String(255), nullable=True), + sa.Column('resolution_notes', sa.Text(), nullable=True), + sa.Column('alert_sent', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('alert_sent_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # Create indexes for security_events + op.create_index('ix_security_events_timestamp', 'security_events', ['timestamp'], unique=False) + op.create_index('ix_security_events_detected_at', 'security_events', ['detected_at'], unique=False) + op.create_index('ix_security_events_correlation_id', 'security_events', ['correlation_id'], unique=False) + op.create_index('ix_security_events_event_type', 'security_events', ['event_type'], unique=False) + op.create_index('ix_security_events_severity', 'security_events', ['severity'], unique=False) + op.create_index('ix_security_events_category', 'security_events', ['category'], unique=False) + op.create_index('ix_security_events_user_id', 'security_events', ['user_id'], unique=False) + op.create_index('ix_security_events_user_email', 'security_events', ['user_email'], unique=False) + op.create_index('ix_security_events_client_ip', 'security_events', ['client_ip'], unique=False) + op.create_index('idx_security_event_time', 'security_events', ['event_type', 'timestamp'], unique=False) + op.create_index('idx_security_severity_time', 'security_events', ['severity', 'timestamp'], unique=False) + op.create_index('idx_security_user_time', 'security_events', ['user_id', 'timestamp'], unique=False) + + # Create audit_trails table + op.create_table( + 'audit_trails', + sa.Column('id', sa.String(36), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), + sa.Column('correlation_id', sa.String(64), nullable=True), + sa.Column('action', sa.String(50), nullable=False), + sa.Column('resource_type', sa.String(100), nullable=False), + sa.Column('resource_id', sa.String(255), nullable=False), + sa.Column('resource_name', sa.String(500), nullable=True), + sa.Column('user_id', 
sa.String(255), nullable=False), + sa.Column('user_email', sa.String(255), nullable=True), + sa.Column('team_id', sa.String(36), nullable=True), + sa.Column('client_ip', sa.String(45), nullable=True), + sa.Column('user_agent', sa.Text(), nullable=True), + sa.Column('request_path', sa.String(500), nullable=True), + sa.Column('request_method', sa.String(10), nullable=True), + sa.Column('old_values', sa.JSON(), nullable=True), + sa.Column('new_values', sa.JSON(), nullable=True), + sa.Column('changes', sa.JSON(), nullable=True), + sa.Column('data_classification', sa.String(50), nullable=True), + sa.Column('requires_review', sa.Boolean(), nullable=False, server_default='0'), + sa.Column('success', sa.Boolean(), nullable=False), + sa.Column('error_message', sa.Text(), nullable=True), + sa.Column('context', sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # Create indexes for audit_trails + op.create_index('ix_audit_trails_timestamp', 'audit_trails', ['timestamp'], unique=False) + op.create_index('ix_audit_trails_correlation_id', 'audit_trails', ['correlation_id'], unique=False) + op.create_index('ix_audit_trails_action', 'audit_trails', ['action'], unique=False) + op.create_index('ix_audit_trails_resource_type', 'audit_trails', ['resource_type'], unique=False) + op.create_index('ix_audit_trails_resource_id', 'audit_trails', ['resource_id'], unique=False) + op.create_index('ix_audit_trails_user_id', 'audit_trails', ['user_id'], unique=False) + op.create_index('ix_audit_trails_user_email', 'audit_trails', ['user_email'], unique=False) + op.create_index('ix_audit_trails_team_id', 'audit_trails', ['team_id'], unique=False) + op.create_index('ix_audit_trails_data_classification', 'audit_trails', ['data_classification'], unique=False) + op.create_index('ix_audit_trails_requires_review', 'audit_trails', ['requires_review'], unique=False) + op.create_index('ix_audit_trails_success', 'audit_trails', ['success'], unique=False) + op.create_index('idx_audit_action_time', 'audit_trails', ['action', 'timestamp'], unique=False) + op.create_index('idx_audit_resource_time', 'audit_trails', ['resource_type', 'resource_id', 'timestamp'], unique=False) + op.create_index('idx_audit_user_time', 'audit_trails', ['user_id', 'timestamp'], unique=False) + op.create_index('idx_audit_classification', 'audit_trails', ['data_classification', 'timestamp'], unique=False) + op.create_index('idx_audit_review', 'audit_trails', ['requires_review', 'timestamp'], unique=False) + + +def downgrade() -> None: + """Remove structured logging tables.""" + op.drop_table('audit_trails') + op.drop_table('security_events') + op.drop_table('performance_metrics') + op.drop_table('structured_log_entries') diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 8b541f096..1d521dcb6 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -782,6 +782,43 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]: correlation_id_preserve: bool = Field(default=True, description="Preserve correlation IDs from incoming requests") correlation_id_response_header: bool = Field(default=True, description="Include correlation ID in response headers") + # Structured Logging Configuration + structured_logging_enabled: bool = Field(default=True, description="Enable structured JSON logging with database persistence") + structured_logging_database_enabled: bool = Field(default=True, description="Persist structured logs to database") + structured_logging_external_enabled: bool = Field(default=False, description="Send logs to external 
systems") + + # Performance Tracking Configuration + performance_tracking_enabled: bool = Field(default=True, description="Enable performance tracking and metrics") + performance_threshold_database_query_ms: float = Field(default=100.0, description="Alert threshold for database queries (ms)") + performance_threshold_tool_invocation_ms: float = Field(default=2000.0, description="Alert threshold for tool invocations (ms)") + performance_threshold_resource_read_ms: float = Field(default=1000.0, description="Alert threshold for resource reads (ms)") + performance_threshold_http_request_ms: float = Field(default=500.0, description="Alert threshold for HTTP requests (ms)") + performance_degradation_multiplier: float = Field(default=1.5, description="Alert if performance degrades by this multiplier vs baseline") + + # Security Logging Configuration + security_logging_enabled: bool = Field(default=True, description="Enable security event logging") + security_failed_auth_threshold: int = Field(default=5, description="Failed auth attempts before high severity alert") + security_threat_score_alert: float = Field(default=0.7, description="Threat score threshold for alerts (0.0-1.0)") + security_rate_limit_window_minutes: int = Field(default=5, description="Time window for rate limit checks (minutes)") + + # Metrics Aggregation Configuration + metrics_aggregation_enabled: bool = Field(default=True, description="Enable automatic log aggregation into performance metrics") + metrics_aggregation_window_minutes: int = Field(default=5, description="Time window for metrics aggregation (minutes)") + + # Log Search Configuration + log_search_max_results: int = Field(default=1000, description="Maximum results per log search query") + log_retention_days: int = Field(default=30, description="Number of days to retain logs in database") + + # External Log Integration Configuration + elasticsearch_enabled: bool = Field(default=False, description="Send logs to Elasticsearch") + elasticsearch_url: Optional[str] = Field(default=None, description="Elasticsearch cluster URL") + elasticsearch_index_prefix: str = Field(default="mcpgateway-logs", description="Elasticsearch index prefix") + syslog_enabled: bool = Field(default=False, description="Send logs to syslog") + syslog_host: Optional[str] = Field(default=None, description="Syslog server host") + syslog_port: int = Field(default=514, description="Syslog server port") + webhook_logging_enabled: bool = Field(default=False, description="Send logs to webhook endpoints") + webhook_logging_urls: List[str] = Field(default_factory=list, description="Webhook URLs for log delivery") + @field_validator("log_level", mode="before") @classmethod def validate_log_level(cls, v: str) -> str: diff --git a/mcpgateway/db.py b/mcpgateway/db.py index f5289951d..ee5303350 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -3797,6 +3797,252 @@ def init_db(): raise Exception(f"Failed to initialize database: {str(e)}") +# ============================================================================ +# Structured Logging Models +# ============================================================================ + + +class StructuredLogEntry(Base): + """Structured log entry for comprehensive logging and analysis. + + Stores all log entries with correlation IDs, performance metrics, + and security context for advanced search and analytics. 
+ """ + + __tablename__ = "structured_log_entries" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Correlation and request tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + + # Log metadata + level: Mapped[str] = mapped_column(String(20), nullable=False, index=True) # DEBUG, INFO, WARNING, ERROR, CRITICAL + component: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + message: Mapped[str] = mapped_column(Text, nullable=False) + logger: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # User and request context + user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + client_ip: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) # IPv6 max length + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) + + # Performance data + duration_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + operation_type: Mapped[Optional[str]] = mapped_column(String(100), index=True, nullable=True) + + # Security context + is_security_event: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + security_severity: Mapped[Optional[str]] = mapped_column(String(20), index=True, nullable=True) # LOW, MEDIUM, HIGH, CRITICAL + threat_indicators: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + # Structured context data + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + error_details: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + performance_metrics: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + # System information + hostname: Mapped[str] = mapped_column(String(255), nullable=False) + process_id: Mapped[int] = mapped_column(Integer, nullable=False) + thread_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + version: Mapped[str] = mapped_column(String(50), nullable=False) + environment: Mapped[str] = mapped_column(String(50), nullable=False, default="production") + + # OpenTelemetry trace context + trace_id: Mapped[Optional[str]] = mapped_column(String(32), index=True, nullable=True) + span_id: Mapped[Optional[str]] = mapped_column(String(16), nullable=True) + + # Indexes for performance + __table_args__ = ( + Index('idx_log_correlation_time', 'correlation_id', 'timestamp'), + Index('idx_log_user_time', 'user_id', 'timestamp'), + Index('idx_log_level_time', 'level', 'timestamp'), + Index('idx_log_component_time', 'component', 'timestamp'), + Index('idx_log_security', 'is_security_event', 'security_severity', 'timestamp'), + Index('idx_log_operation', 'operation_type', 'timestamp'), + Index('idx_log_trace', 'trace_id', 'timestamp'), + ) + + +class PerformanceMetric(Base): + """Aggregated performance metrics from log analysis. + + Stores time-windowed aggregations of operation performance + for analytics and trend analysis. 
+ """ + + __tablename__ = "performance_metrics" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamp + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Metric identification + operation_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + component: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + + # Aggregated metrics + request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_rate: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + + # Duration metrics (in milliseconds) + avg_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + min_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + max_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p50_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p95_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p99_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + + # Time window + window_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) + window_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + window_duration_seconds: Mapped[int] = mapped_column(Integer, nullable=False) + + # Additional context + metric_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index('idx_perf_operation_time', 'operation_type', 'window_start'), + Index('idx_perf_component_time', 'component', 'window_start'), + Index('idx_perf_window', 'window_start', 'window_end'), + ) + + +class SecurityEvent(Base): + """Security event logging for threat detection and audit trails. + + Specialized table for security events with enhanced context + and threat analysis capabilities. + """ + + __tablename__ = "security_events" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + detected_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=utc_now) + + # Correlation tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + log_entry_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("structured_log_entries.id"), index=True, nullable=True) + + # Event classification + event_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # auth_failure, suspicious_activity, rate_limit, etc. + severity: Mapped[str] = mapped_column(String(20), nullable=False, index=True) # LOW, MEDIUM, HIGH, CRITICAL + category: Mapped[str] = mapped_column(String(50), nullable=False, index=True) # authentication, authorization, data_access, etc. 
+ + # User and request context + user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + client_ip: Mapped[str] = mapped_column(String(45), nullable=False, index=True) + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Event details + description: Mapped[str] = mapped_column(Text, nullable=False) + action_taken: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) # blocked, allowed, flagged, etc. + + # Threat analysis + threat_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) # 0.0-1.0 + threat_indicators: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False, default=dict) + failed_attempts_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + # Resolution tracking + resolved: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + resolved_by: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + resolution_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Alert tracking + alert_sent: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + alert_sent_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + alert_recipients: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + + # Additional context + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index('idx_security_type_time', 'event_type', 'timestamp'), + Index('idx_security_severity_time', 'severity', 'timestamp'), + Index('idx_security_user_time', 'user_id', 'timestamp'), + Index('idx_security_ip_time', 'client_ip', 'timestamp'), + Index('idx_security_unresolved', 'resolved', 'severity', 'timestamp'), + ) + + +class AuditTrail(Base): + """Comprehensive audit trail for data access and changes. + + Tracks all significant system changes and data access for + compliance and security auditing. + """ + + __tablename__ = "audit_trails" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Correlation tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + + # Action details + action: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # create, read, update, delete, execute, etc. + resource_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # tool, resource, prompt, user, etc. 
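# --- Note on the SecurityEvent model above ---
# config.py pairs this model with security_failed_auth_threshold (default 5).
# An illustrative mapping from failed-attempt counts to the LOW/MEDIUM/HIGH/
# CRITICAL severities named in the column comments; the real escalation logic
# is in the new security_logger.py, which this excerpt does not show:
def auth_failure_severity(failed_attempts: int, threshold: int = 5) -> str:
    if failed_attempts >= threshold * 2:
        return "CRITICAL"
    if failed_attempts >= threshold:
        return "HIGH"
    if failed_attempts > 1:
        return "MEDIUM"
    return "LOW"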
+ resource_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + resource_name: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + + # User context + user_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + team_id: Mapped[Optional[str]] = mapped_column(String(36), index=True, nullable=True) + + # Request context + client_ip: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) + + # Change tracking + old_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + new_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + changes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + # Data classification + data_classification: Mapped[Optional[str]] = mapped_column(String(50), index=True, nullable=True) # public, internal, confidential, restricted + requires_review: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + + # Result + success: Mapped[bool] = mapped_column(Boolean, nullable=False, index=True) + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Additional context + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index('idx_audit_action_time', 'action', 'timestamp'), + Index('idx_audit_resource_time', 'resource_type', 'resource_id', 'timestamp'), + Index('idx_audit_user_time', 'user_id', 'timestamp'), + Index('idx_audit_classification', 'data_classification', 'timestamp'), + Index('idx_audit_review', 'requires_review', 'timestamp'), + ) + + if __name__ == "__main__": # Wait for database to be ready before initializing wait_for_db_ready(max_tries=int(settings.db_max_retries), interval=int(settings.db_retry_interval_ms) / 1000, sync=True) # Converting ms to s diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 2df30e6ec..c14a20bd7 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -1172,11 +1172,11 @@ async def _call_streamable_http(self, scope, receive, send): app.add_middleware(HttpAuthMiddleware, plugin_manager=plugin_manager) logger.info("🔌 HTTP authentication hooks enabled for plugins") -# Add correlation ID middleware if enabled -# Note: Registered AFTER HttpAuthMiddleware so it executes FIRST (middleware runs in LIFO order) -if settings.correlation_id_enabled: - app.add_middleware(CorrelationIDMiddleware) - logger.info(f"✅ Correlation ID tracking enabled (header: {settings.correlation_id_header})") +# Add request logging middleware FIRST (always enabled for gateway boundary logging) +# IMPORTANT: Must be registered BEFORE CorrelationIDMiddleware so it executes AFTER correlation ID is set +# Gateway boundary logging (request_started/completed) runs regardless of log_requests setting +# Detailed payload logging only runs if log_detailed_requests=True +app.add_middleware(RequestLoggingMiddleware, enable_gateway_logging=True, log_detailed_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024) # Convert MB to bytes # Add custom DocsAuthMiddleware app.add_middleware(DocsAuthMiddleware) @@ -1184,9 +1184,11 @@ async def 
_call_streamable_http(self, scope, receive, send): # Trust all proxies (or lock down with a list of host patterns) app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") -# Add request logging middleware if enabled -if settings.log_requests: - app.add_middleware(RequestLoggingMiddleware, log_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024) # Convert MB to bytes +# Add correlation ID middleware if enabled +# Note: Registered AFTER RequestLoggingMiddleware so correlation ID is available when RequestLoggingMiddleware executes +if settings.correlation_id_enabled: + app.add_middleware(CorrelationIDMiddleware) + logger.info(f"✅ Correlation ID tracking enabled (header: {settings.correlation_id_header})") # Add observability middleware if enabled # Note: Middleware runs in REVERSE order (last added runs first) @@ -4988,6 +4990,19 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr app.include_router(tag_router) app.include_router(export_import_router) +# Include log search router if structured logging is enabled +if getattr(settings, "structured_logging_enabled", True): + try: + # First-Party + from mcpgateway.routers.log_search import router as log_search_router + + app.include_router(log_search_router) + logger.info("Log search router included - structured logging enabled") + except ImportError as e: + logger.warning(f"Failed to import log search router: {e}") +else: + logger.info("Log search router not included - structured logging disabled") + # Conditionally include observability router if enabled if settings.observability_enabled: # First-Party diff --git a/mcpgateway/middleware/auth_middleware.py b/mcpgateway/middleware/auth_middleware.py index a8868ccbe..3a35de675 100644 --- a/mcpgateway/middleware/auth_middleware.py +++ b/mcpgateway/middleware/auth_middleware.py @@ -28,8 +28,10 @@ # First-Party from mcpgateway.auth import get_current_user from mcpgateway.db import SessionLocal +from mcpgateway.services.security_logger import get_security_logger logger = logging.getLogger(__name__) +security_logger = get_security_logger() class AuthContextMiddleware(BaseHTTPMiddleware): @@ -88,10 +90,33 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: # Store user in request state for downstream use request.state.user = user logger.info(f"✓ Authenticated user for observability: {user.email}") + + # Log successful authentication + security_logger.log_authentication_attempt( + user_id=str(user.id), + user_email=user.email, + auth_method="bearer_token", + success=True, + client_ip=request.client.host if request.client else "unknown", + user_agent=request.headers.get("user-agent"), + db=db + ) except Exception as e: # Silently fail - let route handlers enforce auth if needed logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}") + + # Log failed authentication attempt + security_logger.log_authentication_attempt( + user_id="unknown", + user_email=None, + auth_method="bearer_token", + success=False, + client_ip=request.client.host if request.client else "unknown", + user_agent=request.headers.get("user-agent"), + failure_reason=str(e), + db=db if db else None + ) finally: # Always close database session diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index 5a6ebe3a4..f9a78903a 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ 
b/mcpgateway/middleware/request_logging_middleware.py @@ -15,6 +15,7 @@ # Standard import json import logging +import time from typing import Callable # Third-Party @@ -23,12 +24,16 @@ # First-Party from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.utils.correlation_id import get_correlation_id # Initialize logging service first logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger for gateway boundary logging +structured_logger = get_structured_logger("gateway") + SENSITIVE_KEYS = {"password", "secret", "token", "apikey", "access_token", "refresh_token", "client_secret", "authorization", "jwt_token"} @@ -107,17 +112,19 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): masking sensitive information like passwords, tokens, and authorization headers. """ - def __init__(self, app, log_requests: bool = True, log_level: str = "DEBUG", max_body_size: int = 4096): + def __init__(self, app, enable_gateway_logging: bool = True, log_detailed_requests: bool = False, log_level: str = "DEBUG", max_body_size: int = 4096): """Initialize the request logging middleware. Args: app: The FastAPI application instance - log_requests: Whether to enable request logging + enable_gateway_logging: Whether to enable gateway boundary logging (request_started/completed) + log_detailed_requests: Whether to enable detailed request/response payload logging log_level: The log level for requests (not used, logs at INFO) max_body_size: Maximum request body size to log in bytes """ super().__init__(app) - self.log_requests = log_requests + self.enable_gateway_logging = enable_gateway_logging + self.log_detailed_requests = log_detailed_requests self.log_level = log_level.upper() self.max_body_size = max_body_size # Expected to be in bytes @@ -131,9 +138,71 @@ async def dispatch(self, request: Request, call_next: Callable): Returns: Response: The HTTP response from downstream handlers """ - # Skip logging if disabled - if not self.log_requests: - return await call_next(request) + # Track start time for total duration + start_time = time.time() + + # Get correlation ID and request metadata for boundary logging + correlation_id = get_correlation_id() + path = request.url.path + method = request.method + user_agent = request.headers.get("user-agent", "unknown") + client_ip = request.client.host if request.client else "unknown" + + # Skip boundary logging for health checks and static assets + skip_paths = ["/health", "/healthz", "/static", "/favicon.ico"] + should_log_boundary = self.enable_gateway_logging and not any(path.startswith(skip_path) for skip_path in skip_paths) + + # Log gateway request started + if should_log_boundary: + try: + structured_logger.log( + level="INFO", + message=f"Request started: {method} {path}", + component="gateway", + correlation_id=correlation_id, + operation_type="http_request", + request_method=method, + request_path=path, + user_agent=user_agent, + client_ip=client_ip, + metadata={ + "event": "request_started", + "query_params": str(request.query_params) if request.query_params else None + } + ) + except Exception as e: + logger.warning(f"Failed to log request start: {e}") + + # Skip detailed logging if disabled + if not self.log_detailed_requests: + response = await call_next(request) + + # Still log request completed even if detailed logging is disabled + if should_log_boundary: + duration_ms = (time.time() - start_time) * 1000 + 
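# --- Note on the middleware registration comments in main.py above ---
# Starlette applies add_middleware() in LIFO order: the middleware added last
# wraps the stack and runs first on the way in, which is why
# CorrelationIDMiddleware is registered after RequestLoggingMiddleware. A
# minimal demonstration of the ordering:
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware

class Tag(BaseHTTPMiddleware):
    def __init__(self, app, name: str):
        super().__init__(app)
        self.name = name

    async def dispatch(self, request, call_next):
        print(f"enter {self.name}")
        response = await call_next(request)
        print(f"exit {self.name}")
        return response

app = Starlette()
app.add_middleware(Tag, name="request-logging")  # added first, runs second
app.add_middleware(Tag, name="correlation-id")   # added last, runs first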
try: + log_level = "ERROR" if response.status_code >= 500 else "WARNING" if response.status_code >= 400 else "INFO" + structured_logger.log( + level=log_level, + message=f"Request completed: {method} {path} - {response.status_code}", + component="gateway", + correlation_id=correlation_id, + operation_type="http_request", + request_method=method, + request_path=path, + response_status_code=response.status_code, + user_agent=user_agent, + client_ip=client_ip, + duration_ms=duration_ms, + metadata={ + "event": "request_completed", + "response_time_category": "fast" if duration_ms < 100 else "normal" if duration_ms < 1000 else "slow" + } + ) + except Exception as e: + logger.warning(f"Failed to log request completion: {e}") + + return response # Always log at INFO level for request payloads to ensure visibility log_level = logging.INFO @@ -211,5 +280,82 @@ async def receive(): new_scope = request.scope.copy() new_request = Request(new_scope, receive=receive) - response: Response = await call_next(new_request) + # Process request + try: + response: Response = await call_next(new_request) + status_code = response.status_code + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + + # Log request failed + if should_log_boundary: + try: + structured_logger.log( + level="ERROR", + message=f"Request failed: {method} {path}", + component="gateway", + correlation_id=correlation_id, + operation_type="http_request", + request_method=method, + request_path=path, + user_agent=user_agent, + client_ip=client_ip, + duration_ms=duration_ms, + error=e, + metadata={ + "event": "request_failed" + } + ) + except Exception as log_error: + logger.warning(f"Failed to log request failure: {log_error}") + + raise + + # Calculate total duration + duration_ms = (time.time() - start_time) * 1000 + + # Log gateway request completed + if should_log_boundary: + try: + log_level = "ERROR" if status_code >= 500 else "WARNING" if status_code >= 400 else "INFO" + + structured_logger.log( + level=log_level, + message=f"Request completed: {method} {path} - {status_code}", + component="gateway", + correlation_id=correlation_id, + operation_type="http_request", + request_method=method, + request_path=path, + response_status_code=status_code, + user_agent=user_agent, + client_ip=client_ip, + duration_ms=duration_ms, + metadata={ + "event": "request_completed", + "response_time_category": self._categorize_response_time(duration_ms) + } + ) + except Exception as e: + logger.warning(f"Failed to log request completion: {e}") + return response + + @staticmethod + def _categorize_response_time(duration_ms: float) -> str: + """Categorize response time for analytics. + + Args: + duration_ms: Response time in milliseconds + + Returns: + Category string + """ + if duration_ms < 100: + return "fast" + elif duration_ms < 500: + return "normal" + elif duration_ms < 2000: + return "slow" + else: + return "very_slow" diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py new file mode 100644 index 000000000..3cc2174ad --- /dev/null +++ b/mcpgateway/routers/log_search.py @@ -0,0 +1,605 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/routers/log_search.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Log Search API Router. + +This module provides REST API endpoints for searching and analyzing structured logs, +security events, audit trails, and performance metrics. 
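# --- Note on the log search API defined below ---
# An illustrative client call against the /api/logs/search endpoint in this
# file (base URL and bearer token are placeholders; the JSON fields match the
# LogSearchRequest model that follows):
import httpx

resp = httpx.post(
    "http://localhost:4444/api/logs/search",
    headers={"Authorization": "Bearer <token>"},
    json={
        "level": ["ERROR"],
        "component": ["gateway"],
        "has_error": True,
        "limit": 50,
        "sort_order": "desc",
    },
)
resp.raise_for_status()
print(resp.json()["total"])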
+""" + +# Standard +from datetime import datetime, timedelta, timezone +import logging +from typing import Any, Dict, List, Optional + +# Third-Party +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field +from sqlalchemy import and_, or_, desc, select +from sqlalchemy.orm import Session +from sqlalchemy.sql import func as sa_func + +# First-Party +from mcpgateway.db import ( + AuditTrail, + PerformanceMetric, + SecurityEvent, + StructuredLogEntry, + get_db, +) +from mcpgateway.middleware.rbac import require_permission, get_current_user_with_permissions + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/logs", tags=["logs"]) + + +# Request/Response Models +class LogSearchRequest(BaseModel): + """Log search request parameters.""" + search_text: Optional[str] = Field(None, description="Text search query") + level: Optional[List[str]] = Field(None, description="Log levels to filter") + component: Optional[List[str]] = Field(None, description="Components to filter") + category: Optional[List[str]] = Field(None, description="Categories to filter") + correlation_id: Optional[str] = Field(None, description="Correlation ID to filter") + user_id: Optional[str] = Field(None, description="User ID to filter") + start_time: Optional[datetime] = Field(None, description="Start timestamp") + end_time: Optional[datetime] = Field(None, description="End timestamp") + min_duration_ms: Optional[float] = Field(None, description="Minimum duration") + max_duration_ms: Optional[float] = Field(None, description="Maximum duration") + has_error: Optional[bool] = Field(None, description="Filter for errors") + limit: int = Field(100, ge=1, le=1000, description="Maximum results") + offset: int = Field(0, ge=0, description="Result offset") + sort_by: str = Field("timestamp", description="Field to sort by") + sort_order: str = Field("desc", description="Sort order (asc/desc)") + + +class LogEntry(BaseModel): + """Log entry response model.""" + id: str + timestamp: datetime + level: str + component: str + message: str + correlation_id: Optional[str] = None + user_id: Optional[str] = None + user_email: Optional[str] = None + duration_ms: Optional[float] = None + operation_type: Optional[str] = None + request_path: Optional[str] = None + request_method: Optional[str] = None + is_security_event: bool = False + error_details: Optional[Dict[str, Any]] = None + + class Config: + from_attributes = True + + +class LogSearchResponse(BaseModel): + """Log search response.""" + total: int + results: List[LogEntry] + + +class CorrelationTraceRequest(BaseModel): + """Correlation trace request.""" + correlation_id: str + + +class CorrelationTraceResponse(BaseModel): + """Correlation trace response with all related logs.""" + correlation_id: str + total_duration_ms: Optional[float] + log_count: int + error_count: int + logs: List[LogEntry] + security_events: List[Dict[str, Any]] + audit_trails: List[Dict[str, Any]] + performance_metrics: Optional[Dict[str, Any]] + + +class SecurityEventResponse(BaseModel): + """Security event response model.""" + id: str + timestamp: datetime + event_type: str + severity: str + category: str + user_id: Optional[str] + client_ip: str + description: str + threat_score: float + action_taken: Optional[str] + resolved: bool + + class Config: + from_attributes = True + + +class AuditTrailResponse(BaseModel): + """Audit trail response model.""" + id: str + timestamp: datetime + action: str + resource_type: str + resource_id: Optional[str] + 
user_id: str + success: bool + requires_review: bool + data_classification: Optional[str] + + class Config: + from_attributes = True + + +class PerformanceMetricResponse(BaseModel): + """Performance metric response model.""" + id: str + timestamp: datetime + component: str + operation_type: str + window_start: datetime + window_end: datetime + request_count: int + error_count: int + error_rate: float + avg_duration_ms: float + min_duration_ms: float + max_duration_ms: float + p50_duration_ms: float + p95_duration_ms: float + p99_duration_ms: float + + class Config: + from_attributes = True + + +# API Endpoints +@router.post("/search", response_model=LogSearchResponse) +@require_permission("logs:read") +async def search_logs( + request: LogSearchRequest, + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db) +) -> LogSearchResponse: + """Search structured logs with filters and pagination. + + Args: + request: Search parameters + db: Database session + _: Permission check dependency + + Returns: + Search results with pagination + """ + try: + # Build base query + stmt = select(StructuredLogEntry) + + # Apply filters + conditions = [] + + if request.search_text: + conditions.append( + or_( + StructuredLogEntry.message.ilike(f"%{request.search_text}%"), + StructuredLogEntry.component.ilike(f"%{request.search_text}%") + ) + ) + + if request.level: + conditions.append(StructuredLogEntry.level.in_(request.level)) + + if request.component: + conditions.append(StructuredLogEntry.component.in_(request.component)) + + # Note: category field doesn't exist in StructuredLogEntry + # if request.category: + # conditions.append(StructuredLogEntry.category.in_(request.category)) + + if request.correlation_id: + conditions.append(StructuredLogEntry.correlation_id == request.correlation_id) + + if request.user_id: + conditions.append(StructuredLogEntry.user_id == request.user_id) + + if request.start_time: + conditions.append(StructuredLogEntry.timestamp >= request.start_time) + + if request.end_time: + conditions.append(StructuredLogEntry.timestamp <= request.end_time) + + if request.min_duration_ms is not None: + conditions.append(StructuredLogEntry.duration_ms >= request.min_duration_ms) + + if request.max_duration_ms is not None: + conditions.append(StructuredLogEntry.duration_ms <= request.max_duration_ms) + + if request.has_error is not None: + if request.has_error: + conditions.append(StructuredLogEntry.error_details.isnot(None)) + else: + conditions.append(StructuredLogEntry.error_details.is_(None)) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + # Get total count + count_stmt = select(sa_func.count()).select_from(stmt.subquery()) + total = db.execute(count_stmt).scalar() or 0 + + # Apply sorting + sort_column = getattr(StructuredLogEntry, request.sort_by, StructuredLogEntry.timestamp) + if request.sort_order == "desc": + stmt = stmt.order_by(desc(sort_column)) + else: + stmt = stmt.order_by(sort_column) + + # Apply pagination + stmt = stmt.limit(request.limit).offset(request.offset) + + # Execute query + results = db.execute(stmt).scalars().all() + + # Convert to response models + log_entries = [ + LogEntry( + id=str(log.id), + timestamp=log.timestamp, + level=log.level, + component=log.component, + message=log.message, + correlation_id=log.correlation_id, + user_id=log.user_id, + user_email=log.user_email, + duration_ms=log.duration_ms, + operation_type=log.operation_type, + request_path=log.request_path, + request_method=log.request_method, + 
is_security_event=log.is_security_event, + error_details=log.error_details, + ) + for log in results + ] + + return LogSearchResponse( + total=total, + results=log_entries + ) + + except Exception as e: + logger.error(f"Log search failed: {e}") + raise HTTPException(status_code=500, detail="Log search failed") + + +@router.get("/trace/{correlation_id}", response_model=CorrelationTraceResponse) +@require_permission("logs:read") +async def trace_correlation_id( + correlation_id: str, + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db) +) -> CorrelationTraceResponse: + """Get all logs and events for a correlation ID. + + Args: + correlation_id: Correlation ID to trace + db: Database session + _: Permission check dependency + + Returns: + Complete trace of all related logs and events + """ + try: + # Get structured logs + log_stmt = select(StructuredLogEntry).where( + StructuredLogEntry.correlation_id == correlation_id + ).order_by(StructuredLogEntry.timestamp) + + logs = db.execute(log_stmt).scalars().all() + + # Get security events + security_stmt = select(SecurityEvent).where( + SecurityEvent.correlation_id == correlation_id + ).order_by(SecurityEvent.timestamp) + + security_events = db.execute(security_stmt).scalars().all() + + # Get audit trails + audit_stmt = select(AuditTrail).where( + AuditTrail.correlation_id == correlation_id + ).order_by(AuditTrail.timestamp) + + audit_trails = db.execute(audit_stmt).scalars().all() + + # Calculate metrics + durations = [log.duration_ms for log in logs if log.duration_ms is not None] + total_duration = sum(durations) if durations else None + error_count = sum(1 for log in logs if log.error_details) + + # Get performance metrics (if any aggregations exist) + perf_metrics = None + if logs: + component = logs[0].component + operation = logs[0].operation_type + if component and operation: + perf_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation_type == operation + ) + ).order_by(desc(PerformanceMetric.window_start)).limit(1) + + perf = db.execute(perf_stmt).scalar_one_or_none() + if perf: + perf_metrics = { + "avg_duration_ms": perf.avg_duration_ms, + "p95_duration_ms": perf.p95_duration_ms, + "p99_duration_ms": perf.p99_duration_ms, + "error_rate": perf.error_rate, + } + + return CorrelationTraceResponse( + correlation_id=correlation_id, + total_duration_ms=total_duration, + log_count=len(logs), + error_count=error_count, + logs=[ + LogEntry( + id=str(log.id), + timestamp=log.timestamp, + level=log.level, + component=log.component, + message=log.message, + correlation_id=log.correlation_id, + user_id=log.user_id, + user_email=log.user_email, + duration_ms=log.duration_ms, + operation_type=log.operation_type, + request_path=log.request_path, + request_method=log.request_method, + is_security_event=log.is_security_event, + error_details=log.error_details, + ) + for log in logs + ], + security_events=[ + { + "id": str(event.id), + "timestamp": event.timestamp.isoformat(), + "event_type": event.event_type, + "severity": event.severity, + "description": event.description, + "threat_score": event.threat_score, + } + for event in security_events + ], + audit_trails=[ + { + "id": str(audit.id), + "timestamp": audit.timestamp.isoformat(), + "action": audit.action, + "resource_type": audit.resource_type, + "resource_id": audit.resource_id, + "success": audit.success, + } + for audit in audit_trails + ], + performance_metrics=perf_metrics, + ) + + except 
Exception as e: + logger.error(f"Correlation trace failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Correlation trace failed: {str(e)}") + + +@router.get("/security-events", response_model=List[SecurityEventResponse]) +@require_permission("security:read") +async def get_security_events( + severity: Optional[List[str]] = Query(None), + event_type: Optional[List[str]] = Query(None), + resolved: Optional[bool] = Query(None), + start_time: Optional[datetime] = Query(None), + end_time: Optional[datetime] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = Query(0, ge=0), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db) +) -> List[SecurityEventResponse]: + """Get security events with filters. + + Args: + severity: Filter by severity levels + event_type: Filter by event types + resolved: Filter by resolution status + start_time: Start timestamp + end_time: End timestamp + limit: Maximum results + offset: Result offset + db: Database session + _: Permission check dependency + + Returns: + List of security events + """ + try: + stmt = select(SecurityEvent) + + conditions = [] + if severity: + conditions.append(SecurityEvent.severity.in_(severity)) + if event_type: + conditions.append(SecurityEvent.event_type.in_(event_type)) + if resolved is not None: + conditions.append(SecurityEvent.resolved == resolved) + if start_time: + conditions.append(SecurityEvent.timestamp >= start_time) + if end_time: + conditions.append(SecurityEvent.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(SecurityEvent.timestamp)).limit(limit).offset(offset) + + events = db.execute(stmt).scalars().all() + + return [ + SecurityEventResponse( + id=str(event.id), + timestamp=event.timestamp, + event_type=event.event_type, + severity=event.severity, + category=event.category, + user_id=event.user_id, + client_ip=event.client_ip, + description=event.description, + threat_score=event.threat_score, + action_taken=event.action_taken, + resolved=event.resolved, + ) + for event in events + ] + + except Exception as e: + logger.error(f"Security events query failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Security events query failed: {str(e)}") + + +@router.get("/audit-trails", response_model=List[AuditTrailResponse]) +@require_permission("audit:read") +async def get_audit_trails( + action: Optional[List[str]] = Query(None), + resource_type: Optional[List[str]] = Query(None), + user_id: Optional[str] = Query(None), + requires_review: Optional[bool] = Query(None), + start_time: Optional[datetime] = Query(None), + end_time: Optional[datetime] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = Query(0, ge=0), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db) +) -> List[AuditTrailResponse]: + """Get audit trails with filters. 
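+
+    Results are ordered newest-first by timestamp. Example query (illustrative):
+
+        GET /api/logs/audit-trails?action=DELETE&requires_review=true&limit=50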
+ + Args: + action: Filter by actions + resource_type: Filter by resource types + user_id: Filter by user ID + requires_review: Filter by review requirement + start_time: Start timestamp + end_time: End timestamp + limit: Maximum results + offset: Result offset + db: Database session + _: Permission check dependency + + Returns: + List of audit trail entries + """ + try: + stmt = select(AuditTrail) + + conditions = [] + if action: + conditions.append(AuditTrail.action.in_(action)) + if resource_type: + conditions.append(AuditTrail.resource_type.in_(resource_type)) + if user_id: + conditions.append(AuditTrail.user_id == user_id) + if requires_review is not None: + conditions.append(AuditTrail.requires_review == requires_review) + if start_time: + conditions.append(AuditTrail.timestamp >= start_time) + if end_time: + conditions.append(AuditTrail.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(AuditTrail.timestamp)).limit(limit).offset(offset) + + trails = db.execute(stmt).scalars().all() + + return [ + AuditTrailResponse( + id=str(trail.id), + timestamp=trail.timestamp, + action=trail.action, + resource_type=trail.resource_type, + resource_id=trail.resource_id, + user_id=trail.user_id, + success=trail.success, + requires_review=trail.requires_review, + data_classification=trail.data_classification, + ) + for trail in trails + ] + + except Exception as e: + logger.error(f"Audit trails query failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Audit trails query failed: {str(e)}") + + +@router.get("/performance-metrics", response_model=List[PerformanceMetricResponse]) +@require_permission("metrics:read") +async def get_performance_metrics( + component: Optional[str] = Query(None), + operation: Optional[str] = Query(None), + hours: int = Query(24, ge=1, le=1000), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db) +) -> List[PerformanceMetricResponse]: + """Get performance metrics. 
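+
+    Windows are returned newest-first by window start. Example query (illustrative):
+
+        GET /api/logs/performance-metrics?component=gateway&hours=6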
+ + Args: + component: Filter by component + operation: Filter by operation + hours: Hours of history + db: Database session + _: Permission check dependency + + Returns: + List of performance metrics + """ + try: + since = datetime.now(timezone.utc) - timedelta(hours=hours) + + stmt = select(PerformanceMetric).where( + PerformanceMetric.window_start >= since + ) + + if component: + stmt = stmt.where(PerformanceMetric.component == component) + if operation: + stmt = stmt.where(PerformanceMetric.operation_type == operation) + + stmt = stmt.order_by(desc(PerformanceMetric.window_start)) + + metrics = db.execute(stmt).scalars().all() + + return [ + PerformanceMetricResponse( + id=str(metric.id), + timestamp=metric.timestamp, + component=metric.component, + operation_type=metric.operation_type, + window_start=metric.window_start, + window_end=metric.window_end, + request_count=metric.request_count, + error_count=metric.error_count, + error_rate=metric.error_rate, + avg_duration_ms=metric.avg_duration_ms, + min_duration_ms=metric.min_duration_ms, + max_duration_ms=metric.max_duration_ms, + p50_duration_ms=metric.p50_duration_ms, + p95_duration_ms=metric.p95_duration_ms, + p99_duration_ms=metric.p99_duration_ms, + ) + for metric in metrics + ] + + except Exception as e: + logger.error(f"Performance metrics query failed: {e}") + raise HTTPException(status_code=500, detail="Performance metrics query failed") diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py index 33f3468d0..9d0cb704e 100644 --- a/mcpgateway/services/a2a_service.py +++ b/mcpgateway/services/a2a_service.py @@ -26,6 +26,7 @@ from mcpgateway.db import A2AAgentMetric, EmailTeam from mcpgateway.schemas import A2AAgentCreate, A2AAgentMetrics, A2AAgentRead, A2AAgentUpdate from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService from mcpgateway.utils.create_slug import slugify @@ -35,6 +36,9 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger for A2A lifecycle tracking +structured_logger = get_structured_logger("a2a_service") + class A2AAgentError(Exception): """Base class for A2A agent-related errors. 
@@ -279,6 +283,25 @@ async def register_agent( ) logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id})") + + # Log A2A agent registration for lifecycle tracking + structured_logger.info( + f"A2A agent '{new_agent.name}' registered successfully", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="a2a_agent", + resource_id=str(new_agent.id), + resource_action="create", + custom_fields={ + "agent_name": new_agent.name, + "agent_type": new_agent.agent_type, + "protocol_version": new_agent.protocol_version, + "visibility": visibility, + "endpoint_url": new_agent.endpoint_url + } + ) + return self._db_to_schema(db=db, db_agent=new_agent) except A2AAgentNameConflictError as ie: @@ -802,14 +825,74 @@ async def invoke_agent(self, db: Session, agent_name: str, parameters: Dict[str, token_value = getattr(db_row, "auth_value", None) if db_row else None if token_value: headers["Authorization"] = f"Bearer {token_value}" + + # Add correlation ID to outbound headers for distributed tracing + from mcpgateway.utils.correlation_id import get_correlation_id + correlation_id = get_correlation_id() + if correlation_id: + headers["X-Correlation-ID"] = correlation_id + + # Log A2A external call start + call_start_time = datetime.now(timezone.utc) + structured_logger.log( + level="INFO", + message=f"A2A external call started: {agent_name}", + component="a2a_service", + correlation_id=correlation_id, + metadata={ + "event": "a2a_call_started", + "agent_name": agent_name, + "agent_id": agent.id, + "endpoint_url": agent.endpoint_url, + "interaction_type": interaction_type, + "protocol_version": agent.protocol_version + } + ) http_response = await client.post(agent.endpoint_url, json=request_data, headers=headers) + call_duration_ms = (datetime.now(timezone.utc) - call_start_time).total_seconds() * 1000 if http_response.status_code == 200: response = http_response.json() success = True + + # Log successful A2A call + structured_logger.log( + level="INFO", + message=f"A2A external call completed: {agent_name}", + component="a2a_service", + correlation_id=correlation_id, + duration_ms=call_duration_ms, + metadata={ + "event": "a2a_call_completed", + "agent_name": agent_name, + "agent_id": agent.id, + "status_code": http_response.status_code, + "success": True + } + ) else: error_message = f"HTTP {http_response.status_code}: {http_response.text}" + + # Log failed A2A call + structured_logger.log( + level="ERROR", + message=f"A2A external call failed: {agent_name}", + component="a2a_service", + correlation_id=correlation_id, + duration_ms=call_duration_ms, + error_details={ + "error_type": "A2AHTTPError", + "error_message": error_message + }, + metadata={ + "event": "a2a_call_failed", + "agent_name": agent_name, + "agent_id": agent.id, + "status_code": http_response.status_code + } + ) + raise A2AAgentError(error_message) except Exception as e: diff --git a/mcpgateway/services/audit_trail_service.py b/mcpgateway/services/audit_trail_service.py new file mode 100644 index 000000000..d086ee11c --- /dev/null +++ b/mcpgateway/services/audit_trail_service.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/audit_trail_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Audit Trail Service. + +This module provides audit trail management for CRUD operations, +data access tracking, and compliance logging. 
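+
+A minimal usage sketch (identifiers below are hypothetical):
+
+    from mcpgateway.services.audit_trail_service import get_audit_trail_service
+
+    audit = get_audit_trail_service()
+    audit.log_crud_operation(
+        operation="CREATE",
+        resource_type="tool",
+        resource_id="tool-123",
+        user_id="alice@example.com",
+    )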
+""" + +# Standard +from datetime import datetime, timezone +from enum import Enum +import logging +from typing import Any, Dict, Optional + +# Third-Party +from sqlalchemy import select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.db import AuditTrail, SessionLocal +from mcpgateway.utils.correlation_id import get_correlation_id + +logger = logging.getLogger(__name__) + + +class AuditAction(str, Enum): + """Audit trail action types.""" + CREATE = "CREATE" + READ = "READ" + UPDATE = "UPDATE" + DELETE = "DELETE" + EXECUTE = "EXECUTE" + ACCESS = "ACCESS" + EXPORT = "EXPORT" + IMPORT = "IMPORT" + + +class DataClassification(str, Enum): + """Data classification levels.""" + PUBLIC = "public" + INTERNAL = "internal" + CONFIDENTIAL = "confidential" + RESTRICTED = "restricted" + + +class AuditTrailService: + """Service for managing audit trails and compliance logging. + + Provides comprehensive audit trail management with data classification, + change tracking, and compliance reporting capabilities. + """ + + def __init__(self): + """Initialize audit trail service.""" + pass + + def log_action( + self, + action: str, + resource_type: str, + resource_id: str, + user_id: str, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + client_ip: Optional[str] = None, + user_agent: Optional[str] = None, + request_path: Optional[str] = None, + request_method: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + changes: Optional[Dict[str, Any]] = None, + data_classification: Optional[str] = None, + requires_review: bool = False, + success: bool = True, + error_message: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None + ) -> Optional[AuditTrail]: + """Log an audit trail entry. + + Args: + action: Action performed (CREATE, READ, UPDATE, DELETE, etc.) + resource_type: Type of resource (tool, server, prompt, etc.) 
+ resource_id: ID of the resource + user_id: User who performed the action + user_email: User's email address + team_id: Team ID if applicable + resource_name: Name of the resource + client_ip: Client IP address + user_agent: Client user agent + request_path: HTTP request path + request_method: HTTP request method + old_values: Previous values before change + new_values: New values after change + changes: Specific changes made + data_classification: Data classification level + requires_review: Whether this action requires review + success: Whether the action succeeded + error_message: Error message if failed + context: Additional context + db: Optional database session + + Returns: + Created AuditTrail entry or None if logging disabled + """ + correlation_id = get_correlation_id() + + # Use provided session or create new one + close_db = False + if db is None: + db = SessionLocal() + close_db = True + + try: + # Create audit trail entry + audit_entry = AuditTrail( + timestamp=datetime.now(timezone.utc), + correlation_id=correlation_id, + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + request_path=request_path, + request_method=request_method, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + success=success, + error_message=error_message, + context=context + ) + + db.add(audit_entry) + db.commit() + db.refresh(audit_entry) + + logger.debug( + f"Audit trail logged: {action} {resource_type}/{resource_id} by {user_id}", + extra={ + "correlation_id": correlation_id, + "action": action, + "resource_type": resource_type, + "resource_id": resource_id, + "user_id": user_id, + "success": success + } + ) + + return audit_entry + + except Exception as e: + logger.error( + f"Failed to log audit trail: {e}", + exc_info=True, + extra={ + "correlation_id": correlation_id, + "action": action, + "resource_type": resource_type, + "resource_id": resource_id + } + ) + if close_db: + db.rollback() + return None + + finally: + if close_db: + db.close() + + def log_crud_operation( + self, + operation: str, + resource_type: str, + resource_id: str, + user_id: str, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + success: bool = True, + error_message: Optional[str] = None, + db: Optional[Session] = None, + **kwargs + ) -> Optional[AuditTrail]: + """Log a CRUD operation with change tracking. 
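+
+        For UPDATE operations a field-level ``changes`` dictionary is derived
+        automatically by diffing ``old_values`` against ``new_values``.
+        Example (illustrative; identifiers are hypothetical):
+
+            >>> svc = AuditTrailService()
+            >>> svc.log_crud_operation(  # doctest: +SKIP
+            ...     operation="UPDATE",
+            ...     resource_type="tool",
+            ...     resource_id="t-1",
+            ...     user_id="u-1",
+            ...     old_values={"timeout": 30},
+            ...     new_values={"timeout": 60},
+            ... )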
+ + Args: + operation: CRUD operation (CREATE, READ, UPDATE, DELETE) + resource_type: Type of resource + resource_id: ID of the resource + user_id: User who performed the operation + user_email: User's email + team_id: Team ID if applicable + resource_name: Name of the resource + old_values: Previous values (for UPDATE/DELETE) + new_values: New values (for CREATE/UPDATE) + success: Whether the operation succeeded + error_message: Error message if failed + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + # Calculate changes for UPDATE operations + changes = None + if operation == "UPDATE" and old_values and new_values: + changes = {} + for key in set(old_values.keys()) | set(new_values.keys()): + old_val = old_values.get(key) + new_val = new_values.get(key) + if old_val != new_val: + changes[key] = {"old": old_val, "new": new_val} + + # Determine data classification based on resource type + data_classification = None + if resource_type in ["user", "team", "token", "credential"]: + data_classification = DataClassification.CONFIDENTIAL.value + elif resource_type in ["tool", "server", "prompt", "resource"]: + data_classification = DataClassification.INTERNAL.value + + # Determine if review is required + requires_review = False + if data_classification == DataClassification.CONFIDENTIAL.value: + requires_review = True + if operation == "DELETE" and resource_type in ["tool", "server", "gateway"]: + requires_review = True + + return self.log_action( + action=operation, + resource_type=resource_type, + resource_id=resource_id, + user_id=user_id, + user_email=user_email, + team_id=team_id, + resource_name=resource_name, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + success=success, + error_message=error_message, + db=db, + **kwargs + ) + + def log_data_access( + self, + resource_type: str, + resource_id: str, + user_id: str, + access_type: str = "READ", + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + data_classification: Optional[str] = None, + db: Optional[Session] = None, + **kwargs + ) -> Optional[AuditTrail]: + """Log data access for compliance tracking. + + Args: + resource_type: Type of resource accessed + resource_id: ID of the resource + user_id: User who accessed the data + access_type: Type of access (READ, EXPORT, etc.) + user_email: User's email + team_id: Team ID if applicable + resource_name: Name of the resource + data_classification: Data classification level + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + requires_review = data_classification in [ + DataClassification.CONFIDENTIAL.value, + DataClassification.RESTRICTED.value + ] + + return self.log_action( + action=access_type, + resource_type=resource_type, + resource_id=resource_id, + user_id=user_id, + user_email=user_email, + team_id=team_id, + resource_name=resource_name, + data_classification=data_classification, + requires_review=requires_review, + success=True, + db=db, + **kwargs + ) + + def log_audit( + self, + user_id: str, + resource_type: str, + resource_id: str, + action: str, + user_email: Optional[str] = None, + description: Optional[str] = None, + db: Optional[Session] = None, + **kwargs + ) -> Optional[AuditTrail]: + """Convenience method for simple audit logging. 
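+
+        Example (illustrative; identifiers are hypothetical):
+
+            >>> svc = AuditTrailService()
+            >>> svc.log_audit(  # doctest: +SKIP
+            ...     user_id="u-1",
+            ...     resource_type="server",
+            ...     resource_id="s-9",
+            ...     action="EXECUTE",
+            ...     description="Manual restart",
+            ... )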
+ + Args: + user_id: User who performed the action + resource_type: Type of resource + resource_id: ID of the resource + action: Action performed + user_email: User's email + description: Description of the action + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + # Build context if description provided + context = kwargs.pop("context", {}) + if description: + context["description"] = description + + return self.log_action( + action=action, + resource_type=resource_type, + resource_id=resource_id, + user_id=user_id, + user_email=user_email, + context=context if context else None, + db=db, + **kwargs + ) + + def get_audit_trail( + self, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + user_id: Optional[str] = None, + action: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 100, + offset: int = 0, + db: Optional[Session] = None + ) -> list[AuditTrail]: + """Query audit trail entries. + + Args: + resource_type: Filter by resource type + resource_id: Filter by resource ID + user_id: Filter by user ID + action: Filter by action + start_time: Filter by start time + end_time: Filter by end time + limit: Maximum number of results + offset: Offset for pagination + db: Optional database session + + Returns: + List of AuditTrail entries + """ + close_db = False + if db is None: + db = SessionLocal() + close_db = True + + try: + query = select(AuditTrail) + + if resource_type: + query = query.where(AuditTrail.resource_type == resource_type) + if resource_id: + query = query.where(AuditTrail.resource_id == resource_id) + if user_id: + query = query.where(AuditTrail.user_id == user_id) + if action: + query = query.where(AuditTrail.action == action) + if start_time: + query = query.where(AuditTrail.timestamp >= start_time) + if end_time: + query = query.where(AuditTrail.timestamp <= end_time) + + query = query.order_by(AuditTrail.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = db.execute(query) + return list(result.scalars().all()) + + finally: + if close_db: + db.close() + + +# Singleton instance +_audit_trail_service: Optional[AuditTrailService] = None + + +def get_audit_trail_service() -> AuditTrailService: + """Get or create the singleton audit trail service instance. + + Returns: + AuditTrailService instance + """ + global _audit_trail_service + if _audit_trail_service is None: + _audit_trail_service = AuditTrailService() + return _audit_trail_service diff --git a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py new file mode 100644 index 000000000..6301ebb7f --- /dev/null +++ b/mcpgateway/services/log_aggregator.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/log_aggregator.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Log Aggregation Service. + +This module provides aggregation of performance metrics from structured logs +into time-windowed statistics for analysis and monitoring. 
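+
+A minimal sketch of the intended flow (uses the module-level accessor defined below):
+
+    from mcpgateway.services.log_aggregator import get_log_aggregator
+
+    aggregator = get_log_aggregator()
+    # Roll the most recent window of structured logs into PerformanceMetric rows
+    metrics = aggregator.aggregate_all_components()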
+""" + +# Standard +from datetime import datetime, timedelta, timezone +import logging +import statistics +from typing import Dict, List, Optional, Tuple + +# Third-Party +from sqlalchemy import and_, func, select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.db import PerformanceMetric, StructuredLogEntry, SessionLocal +from mcpgateway.config import settings + +logger = logging.getLogger(__name__) + + +class LogAggregator: + """Aggregates structured logs into performance metrics.""" + + def __init__(self): + """Initialize log aggregator.""" + self.aggregation_window_minutes = getattr(settings, "metrics_aggregation_window_minutes", 5) + self.enabled = getattr(settings, "metrics_aggregation_enabled", True) + + def aggregate_performance_metrics( + self, + component: str, + operation: str, + window_start: Optional[datetime] = None, + window_end: Optional[datetime] = None, + db: Optional[Session] = None + ) -> Optional[PerformanceMetric]: + """Aggregate performance metrics for a component and operation. + + Args: + component: Component name + operation: Operation name + window_start: Start of aggregation window (defaults to N minutes ago) + window_end: End of aggregation window (defaults to now) + db: Optional database session + + Returns: + Created PerformanceMetric or None if no data + """ + if not self.enabled: + return None + + # Default time window + if window_end is None: + window_end = datetime.now(timezone.utc) + if window_start is None: + window_start = window_end - timedelta(minutes=self.aggregation_window_minutes) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + # Query structured logs for this component/operation in time window + stmt = select(StructuredLogEntry).where( + and_( + StructuredLogEntry.component == component, + StructuredLogEntry.category == "performance", + StructuredLogEntry.resource_action == operation, + StructuredLogEntry.timestamp >= window_start, + StructuredLogEntry.timestamp <= window_end, + StructuredLogEntry.duration_ms.isnot(None) + ) + ) + + results = db.execute(stmt).scalars().all() + + if not results: + return None + + # Extract durations + durations = [r.duration_ms for r in results if r.duration_ms is not None] + + if not durations: + return None + + # Calculate statistics + count = len(durations) + total_duration = sum(durations) + avg_duration = statistics.mean(durations) + min_duration = min(durations) + max_duration = max(durations) + + # Calculate percentiles + sorted_durations = sorted(durations) + p50 = self._percentile(sorted_durations, 0.50) + p95 = self._percentile(sorted_durations, 0.95) + p99 = self._percentile(sorted_durations, 0.99) + + # Count errors + error_count = sum(1 for r in results if r.error_message is not None) + error_rate = error_count / count if count > 0 else 0.0 + + # Aggregate database metrics + db_queries = [r.database_query_count for r in results if r.database_query_count is not None] + total_db_queries = sum(db_queries) if db_queries else 0 + avg_db_queries = statistics.mean(db_queries) if db_queries else 0.0 + + db_durations = [r.database_query_duration_ms for r in results if r.database_query_duration_ms is not None] + total_db_duration = sum(db_durations) if db_durations else 0.0 + avg_db_duration = statistics.mean(db_durations) if db_durations else 0.0 + + # Aggregate cache metrics + cache_hits = sum(r.cache_hits for r in results if r.cache_hits is not None) + cache_misses = sum(r.cache_misses for r in results if r.cache_misses is not None) + 
cache_total = cache_hits + cache_misses + cache_hit_rate = cache_hits / cache_total if cache_total > 0 else 0.0 + + # Create performance metric + metric = PerformanceMetric( + component=component, + operation=operation, + window_start=window_start, + window_end=window_end, + request_count=count, + error_count=error_count, + error_rate=error_rate, + total_duration_ms=total_duration, + avg_duration_ms=avg_duration, + min_duration_ms=min_duration, + max_duration_ms=max_duration, + p50_duration_ms=p50, + p95_duration_ms=p95, + p99_duration_ms=p99, + total_database_queries=total_db_queries, + avg_database_queries=avg_db_queries, + total_database_duration_ms=total_db_duration, + avg_database_duration_ms=avg_db_duration, + cache_hits=cache_hits, + cache_misses=cache_misses, + cache_hit_rate=cache_hit_rate, + ) + + db.add(metric) + db.commit() + db.refresh(metric) + + logger.info( + f"Aggregated performance metrics for {component}.{operation}: " + f"{count} requests, {avg_duration:.2f}ms avg, {error_rate:.2%} error rate" + ) + + return metric + + except Exception as e: + logger.error(f"Failed to aggregate performance metrics: {e}") + if db: + db.rollback() + return None + + finally: + if should_close: + db.close() + + def aggregate_all_components( + self, + window_start: Optional[datetime] = None, + window_end: Optional[datetime] = None, + db: Optional[Session] = None + ) -> List[PerformanceMetric]: + """Aggregate metrics for all components and operations. + + Args: + window_start: Start of aggregation window + window_end: End of aggregation window + db: Optional database session + + Returns: + List of created PerformanceMetric records + """ + if not self.enabled: + return [] + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + # Get unique component/operation pairs + if window_end is None: + window_end = datetime.now(timezone.utc) + if window_start is None: + window_start = window_end - timedelta(minutes=self.aggregation_window_minutes) + + stmt = select( + StructuredLogEntry.component, + StructuredLogEntry.resource_action + ).where( + and_( + StructuredLogEntry.category == "performance", + StructuredLogEntry.timestamp >= window_start, + StructuredLogEntry.timestamp <= window_end, + StructuredLogEntry.duration_ms.isnot(None) + ) + ).distinct() + + pairs = db.execute(stmt).all() + + metrics = [] + for component, operation in pairs: + if component and operation: + metric = self.aggregate_performance_metrics( + component=component, + operation=operation, + window_start=window_start, + window_end=window_end, + db=db + ) + if metric: + metrics.append(metric) + + return metrics + + finally: + if should_close: + db.close() + + def get_recent_metrics( + self, + component: Optional[str] = None, + operation: Optional[str] = None, + hours: int = 24, + db: Optional[Session] = None + ) -> List[PerformanceMetric]: + """Get recent performance metrics. 
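+
+        Example (illustrative):
+
+            >>> aggregator = LogAggregator()
+            >>> recent = aggregator.get_recent_metrics(component="gateway", hours=6)  # doctest: +SKIP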
+ + Args: + component: Optional component filter + operation: Optional operation filter + hours: Hours of history to retrieve + db: Optional database session + + Returns: + List of PerformanceMetric records + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + since = datetime.now(timezone.utc) - timedelta(hours=hours) + + stmt = select(PerformanceMetric).where( + PerformanceMetric.window_start >= since + ) + + if component: + stmt = stmt.where(PerformanceMetric.component == component) + if operation: + stmt = stmt.where(PerformanceMetric.operation == operation) + + stmt = stmt.order_by(PerformanceMetric.window_start.desc()) + + return db.execute(stmt).scalars().all() + + finally: + if should_close: + db.close() + + def get_degradation_alerts( + self, + threshold_multiplier: float = 1.5, + hours: int = 24, + db: Optional[Session] = None + ) -> List[Dict[str, any]]: + """Identify performance degradations by comparing recent vs baseline. + + Args: + threshold_multiplier: Alert if recent is X times slower than baseline + hours: Hours of recent data to check + db: Optional database session + + Returns: + List of degradation alerts with details + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) + baseline_cutoff = recent_cutoff - timedelta(hours=hours * 2) + + # Get unique component/operation pairs + stmt = select( + PerformanceMetric.component, + PerformanceMetric.operation + ).distinct() + + pairs = db.execute(stmt).all() + + alerts = [] + for component, operation in pairs: + # Get recent metrics + recent_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation == operation, + PerformanceMetric.window_start >= recent_cutoff + ) + ) + recent_metrics = db.execute(recent_stmt).scalars().all() + + # Get baseline metrics + baseline_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation == operation, + PerformanceMetric.window_start >= baseline_cutoff, + PerformanceMetric.window_start < recent_cutoff + ) + ) + baseline_metrics = db.execute(baseline_stmt).scalars().all() + + if not recent_metrics or not baseline_metrics: + continue + + recent_avg = statistics.mean([m.avg_duration_ms for m in recent_metrics]) + baseline_avg = statistics.mean([m.avg_duration_ms for m in baseline_metrics]) + + if recent_avg > baseline_avg * threshold_multiplier: + alerts.append({ + "component": component, + "operation": operation, + "recent_avg_ms": recent_avg, + "baseline_avg_ms": baseline_avg, + "degradation_ratio": recent_avg / baseline_avg, + "recent_error_rate": statistics.mean([m.error_rate for m in recent_metrics]), + "baseline_error_rate": statistics.mean([m.error_rate for m in baseline_metrics]), + }) + + return alerts + + finally: + if should_close: + db.close() + + @staticmethod + def _percentile(sorted_values: List[float], percentile: float) -> float: + """Calculate percentile from sorted values. 
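+
+        Uses linear interpolation between the two nearest ranks.
+        Example:
+
+            >>> LogAggregator._percentile([10.0, 20.0, 30.0, 40.0], 0.5)
+            25.0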
+ + Args: + sorted_values: Sorted list of values + percentile: Percentile to calculate (0.0 to 1.0) + + Returns: + Percentile value + """ + if not sorted_values: + return 0.0 + + k = (len(sorted_values) - 1) * percentile + f = int(k) + c = f + 1 + + if c >= len(sorted_values): + return sorted_values[-1] + + d0 = sorted_values[f] * (c - k) + d1 = sorted_values[c] * (k - f) + + return d0 + d1 + + +# Global log aggregator instance +_log_aggregator: Optional[LogAggregator] = None + + +def get_log_aggregator() -> LogAggregator: + """Get or create the global log aggregator instance. + + Returns: + Global LogAggregator instance + """ + global _log_aggregator + if _log_aggregator is None: + _log_aggregator = LogAggregator() + return _log_aggregator diff --git a/mcpgateway/services/performance_tracker.py b/mcpgateway/services/performance_tracker.py new file mode 100644 index 000000000..db78c6382 --- /dev/null +++ b/mcpgateway/services/performance_tracker.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/performance_tracker.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Performance Tracking Service. + +This module provides performance tracking and analytics for all operations +across the MCP Gateway, enabling identification of bottlenecks and +optimization opportunities. +""" + +# Standard +from collections import defaultdict +from contextlib import contextmanager +from datetime import datetime, timezone +import logging +import statistics +import time +from typing import Any, Dict, Generator, List, Optional + +# First-Party +from mcpgateway.config import settings +from mcpgateway.utils.correlation_id import get_correlation_id + +logger = logging.getLogger(__name__) + + +class PerformanceTracker: + """Tracks and analyzes performance metrics across requests. + + Provides context managers for tracking operation timing, + aggregation of metrics, and threshold-based alerting. + """ + + def __init__(self): + """Initialize performance tracker.""" + self.operation_timings: Dict[str, List[float]] = defaultdict(list) + + # Performance thresholds (seconds) from settings or defaults + self.performance_thresholds = { + "database_query": getattr(settings, "perf_threshold_database_query", 0.1), + "tool_invocation": getattr(settings, "perf_threshold_tool_invocation", 2.0), + "authentication": getattr(settings, "perf_threshold_authentication", 0.5), + "cache_operation": getattr(settings, "perf_threshold_cache_operation", 0.01), + "a2a_task": getattr(settings, "perf_threshold_a2a_task", 5.0), + "request_total": getattr(settings, "perf_threshold_request_total", 10.0), + "resource_fetch": getattr(settings, "perf_threshold_resource_fetch", 1.0), + "prompt_processing": getattr(settings, "perf_threshold_prompt_processing", 0.5), + } + + # Max buffer size per operation type + self.max_samples = getattr(settings, "perf_max_samples_per_operation", 1000) + + @contextmanager + def track_operation( + self, + operation_name: str, + component: Optional[str] = None, + log_slow: bool = True, + extra_context: Optional[Dict[str, Any]] = None + ) -> Generator[None, None, None]: + """Context manager to track operation performance. 
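+
+        The timing sample is recorded even when the wrapped block raises;
+        the exception is re-raised after the sample is stored.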
+ + Args: + operation_name: Name of the operation being tracked + component: Component/module name for context + log_slow: Whether to log operations exceeding thresholds + extra_context: Additional context to include in logs + + Yields: + None + + Example: + >>> tracker = PerformanceTracker() + >>> with tracker.track_operation("database_query", component="tool_service"): + ... # Perform database operation + ... pass + """ + start_time = time.time() + correlation_id = get_correlation_id() + error_occurred = False + + try: + yield + except Exception: + error_occurred = True + raise + finally: + duration = time.time() - start_time + + # Record timing + self.operation_timings[operation_name].append(duration) + + # Limit buffer size + if len(self.operation_timings[operation_name]) > self.max_samples: + self.operation_timings[operation_name].pop(0) + + # Check threshold and log if needed + threshold = self.performance_thresholds.get(operation_name, float('inf')) + threshold_exceeded = duration > threshold + + if log_slow and threshold_exceeded: + context = { + "operation": operation_name, + "duration_ms": duration * 1000, + "threshold_ms": threshold * 1000, + "exceeded_by_ms": (duration - threshold) * 1000, + "component": component, + "correlation_id": correlation_id, + "error_occurred": error_occurred, + } + if extra_context: + context.update(extra_context) + + logger.warning( + f"Slow operation detected: {operation_name} took {duration*1000:.2f}ms " + f"(threshold: {threshold*1000:.2f}ms)", + extra=context + ) + + def record_timing( + self, + operation_name: str, + duration: float, + component: Optional[str] = None, + extra_context: Optional[Dict[str, Any]] = None + ) -> None: + """Manually record a timing measurement. + + Args: + operation_name: Name of the operation + duration: Duration in seconds + component: Component/module name + extra_context: Additional context + """ + self.operation_timings[operation_name].append(duration) + + # Limit buffer size + if len(self.operation_timings[operation_name]) > self.max_samples: + self.operation_timings[operation_name].pop(0) + + # Check threshold + threshold = self.performance_thresholds.get(operation_name, float('inf')) + if duration > threshold: + context = { + "operation": operation_name, + "duration_ms": duration * 1000, + "threshold_ms": threshold * 1000, + "component": component, + "correlation_id": get_correlation_id(), + } + if extra_context: + context.update(extra_context) + + logger.warning( + f"Slow operation: {operation_name} took {duration*1000:.2f}ms", + extra=context + ) + + def get_performance_summary( + self, + operation_name: Optional[str] = None, + min_samples: int = 1 + ) -> Dict[str, Any]: + """Get performance summary for analytics. 
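+
+        Percentiles are computed with linear interpolation over the recorded
+        samples; operations with fewer than ``min_samples`` timings are omitted.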
+ + Args: + operation_name: Specific operation to summarize (None for all) + min_samples: Minimum samples required to include in summary + + Returns: + Dictionary containing performance statistics + + Example: + >>> tracker = PerformanceTracker() + >>> summary = tracker.get_performance_summary() + >>> isinstance(summary, dict) + True + """ + summary = {} + + operations = ( + {operation_name: self.operation_timings[operation_name]} + if operation_name and operation_name in self.operation_timings + else self.operation_timings + ) + + for op_name, timings in operations.items(): + if len(timings) < min_samples: + continue + + # Calculate percentiles + sorted_timings = sorted(timings) + count = len(sorted_timings) + + def percentile(p: float) -> float: + """Calculate percentile value.""" + k = (count - 1) * p + f = int(k) + c = k - f + if f + 1 < count: + return sorted_timings[f] * (1 - c) + sorted_timings[f + 1] * c + return sorted_timings[f] + + summary[op_name] = { + "count": count, + "avg_duration_ms": statistics.mean(timings) * 1000, + "min_duration_ms": min(timings) * 1000, + "max_duration_ms": max(timings) * 1000, + "p50_duration_ms": percentile(0.5) * 1000, + "p95_duration_ms": percentile(0.95) * 1000, + "p99_duration_ms": percentile(0.99) * 1000, + "threshold_ms": self.performance_thresholds.get(op_name, float('inf')) * 1000, + "threshold_violations": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float('inf'))), + "violation_rate": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float('inf'))) / count, + } + + return summary + + def get_operation_stats(self, operation_name: str) -> Optional[Dict[str, Any]]: + """Get statistics for a specific operation. + + Args: + operation_name: Name of the operation + + Returns: + Statistics dictionary or None if no data + """ + if operation_name not in self.operation_timings: + return None + + timings = self.operation_timings[operation_name] + if not timings: + return None + + return { + "operation": operation_name, + "sample_count": len(timings), + "avg_duration_ms": statistics.mean(timings) * 1000, + "min_duration_ms": min(timings) * 1000, + "max_duration_ms": max(timings) * 1000, + "total_time_ms": sum(timings) * 1000, + "threshold_ms": self.performance_thresholds.get(operation_name, float('inf')) * 1000, + } + + def clear_stats(self, operation_name: Optional[str] = None) -> None: + """Clear performance statistics. + + Args: + operation_name: Specific operation to clear (None for all) + """ + if operation_name: + if operation_name in self.operation_timings: + self.operation_timings[operation_name].clear() + else: + self.operation_timings.clear() + + def set_threshold(self, operation_name: str, threshold_seconds: float) -> None: + """Set or update performance threshold for an operation. + + Args: + operation_name: Name of the operation + threshold_seconds: Threshold in seconds + """ + self.performance_thresholds[operation_name] = threshold_seconds + + def check_performance_degradation( + self, + operation_name: str, + baseline_multiplier: float = 2.0 + ) -> Dict[str, Any]: + """Check if performance has degraded compared to baseline. 
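+
+        Compares the mean of the most recent 10 samples against the mean of
+        the older samples. Example (a fresh tracker has no data yet):
+
+            >>> tracker = PerformanceTracker()
+            >>> tracker.check_performance_degradation("database_query")
+            {'degraded': False, 'reason': 'no_data'}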
+
+        Args:
+            operation_name: Name of the operation to check
+            baseline_multiplier: Multiplier for degradation detection
+
+        Returns:
+            Dictionary with degradation analysis
+        """
+        if operation_name not in self.operation_timings:
+            return {"degraded": False, "reason": "no_data"}
+
+        timings = self.operation_timings[operation_name]
+        if len(timings) < 10:
+            return {"degraded": False, "reason": "insufficient_samples"}
+
+        # Compare recent timings to overall average
+        recent_count = min(10, len(timings))
+        recent_timings = timings[-recent_count:]
+        historical_timings = timings[:-recent_count] if len(timings) > recent_count else timings
+
+        if not historical_timings:
+            return {"degraded": False, "reason": "insufficient_historical_data"}
+
+        recent_avg = statistics.mean(recent_timings)
+        historical_avg = statistics.mean(historical_timings)
+
+        degraded = recent_avg > (historical_avg * baseline_multiplier)
+
+        return {
+            "degraded": degraded,
+            "recent_avg_ms": recent_avg * 1000,
+            "historical_avg_ms": historical_avg * 1000,
+            "multiplier": recent_avg / historical_avg if historical_avg > 0 else 0,
+            "threshold_multiplier": baseline_multiplier,
+        }
+
+
+# Global performance tracker instance
+_performance_tracker: Optional[PerformanceTracker] = None
+
+
+def get_performance_tracker() -> PerformanceTracker:
+    """Get or create the global performance tracker instance.
+
+    Returns:
+        Global PerformanceTracker instance
+    """
+    global _performance_tracker
+    if _performance_tracker is None:
+        _performance_tracker = PerformanceTracker()
+    return _performance_tracker
diff --git a/mcpgateway/services/security_logger.py b/mcpgateway/services/security_logger.py
new file mode 100644
index 000000000..2f846b44a
--- /dev/null
+++ b/mcpgateway/services/security_logger.py
@@ -0,0 +1,641 @@
+# -*- coding: utf-8 -*-
+"""Location: ./mcpgateway/services/security_logger.py
+Copyright 2025
+SPDX-License-Identifier: Apache-2.0
+
+Security Logger Service.
+
+This module provides specialized logging for security events, threat detection,
+and audit trail management with automated threat analysis and alerting.
+"""
+
+# Standard
+from datetime import datetime, timedelta, timezone
+from enum import Enum
+import logging
+from typing import Any, Dict, Optional
+
+# Third-Party
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
+
+# First-Party
+from mcpgateway.config import settings
+from mcpgateway.db import AuditTrail, SecurityEvent, SessionLocal
+from mcpgateway.utils.correlation_id import get_correlation_id
+
+logger = logging.getLogger(__name__)
+
+
+class SecuritySeverity(str, Enum):
+    """Security event severity levels."""
+    LOW = "LOW"
+    MEDIUM = "MEDIUM"
+    HIGH = "HIGH"
+    CRITICAL = "CRITICAL"
+
+
+class SecurityEventType(str, Enum):
+    """Types of security events."""
+    AUTHENTICATION_FAILURE = "authentication_failure"
+    AUTHENTICATION_SUCCESS = "authentication_success"
+    AUTHORIZATION_FAILURE = "authorization_failure"
+    SUSPICIOUS_ACTIVITY = "suspicious_activity"
+    RATE_LIMIT_EXCEEDED = "rate_limit_exceeded"
+    BRUTE_FORCE_ATTEMPT = "brute_force_attempt"
+    TOKEN_MANIPULATION = "token_manipulation"
+    DATA_EXFILTRATION = "data_exfiltration"
+    PRIVILEGE_ESCALATION = "privilege_escalation"
+    INJECTION_ATTEMPT = "injection_attempt"
+    ANOMALOUS_BEHAVIOR = "anomalous_behavior"
+
+
+class SecurityLogger:
+    """Specialized logger for security events and audit trails.
+
+    Provides threat detection, security event logging, and audit trail
+    management with automated analysis and alerting capabilities.
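+
+    Usage sketch (illustrative; the client IP below is a documentation address):
+
+        security_logger = SecurityLogger()
+        security_logger.log_authentication_attempt(
+            user_id="u-1",
+            user_email=None,
+            auth_method="jwt",
+            success=False,
+            client_ip="203.0.113.7",
+            failure_reason="expired token",
+        )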
+ """ + + def __init__(self): + """Initialize security logger.""" + self.failed_auth_threshold = getattr(settings, "security_failed_auth_threshold", 5) + self.threat_score_alert_threshold = getattr(settings, "security_threat_score_alert", 0.7) + self.rate_limit_window_minutes = getattr(settings, "security_rate_limit_window", 5) + + def log_authentication_attempt( + self, + user_id: str, + user_email: Optional[str], + auth_method: str, + success: bool, + client_ip: str, + user_agent: Optional[str] = None, + failure_reason: Optional[str] = None, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None + ) -> Optional[SecurityEvent]: + """Log authentication attempts with security analysis. + + Args: + user_id: User identifier + user_email: User email address + auth_method: Authentication method used + success: Whether authentication succeeded + client_ip: Client IP address + user_agent: Client user agent + failure_reason: Reason for failure if applicable + additional_context: Additional event context + db: Optional database session + + Returns: + Created SecurityEvent or None if logging disabled + """ + correlation_id = get_correlation_id() + + # Count recent failed attempts + failed_attempts = self._count_recent_failures( + user_id=user_id, + client_ip=client_ip, + db=db + ) + + # Calculate threat score + threat_score = self._calculate_auth_threat_score( + success=success, + failed_attempts=failed_attempts, + auth_method=auth_method + ) + + # Determine severity + if not success: + if failed_attempts >= self.failed_auth_threshold: + severity = SecuritySeverity.HIGH + elif failed_attempts >= 3: + severity = SecuritySeverity.MEDIUM + else: + severity = SecuritySeverity.LOW + else: + severity = SecuritySeverity.LOW + + # Build event description + description = f"Authentication {'successful' if success else 'failed'} for user {user_id}" + if not success and failure_reason: + description += f": {failure_reason}" + + # Build context + context = { + "auth_method": auth_method, + "failed_attempts_recent": failed_attempts, + "user_agent": user_agent, + **(additional_context or {}) + } + + # Create security event + event = self._create_security_event( + event_type=SecurityEventType.AUTHENTICATION_SUCCESS if success else SecurityEventType.AUTHENTICATION_FAILURE, + severity=severity, + category="authentication", + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + threat_score=threat_score, + failed_attempts_count=failed_attempts, + context=context, + action_taken="allowed" if success else "denied", + correlation_id=correlation_id, + db=db + ) + + # Log to standard logger as well + log_level = logging.WARNING if not success else logging.INFO + logger.log( + log_level, + f"Authentication attempt: {description}", + extra={ + "security_event": True, + "event_type": event.event_type if event else None, + "severity": severity.value, + "threat_score": threat_score, + "correlation_id": correlation_id, + } + ) + + return event + + def log_data_access( + self, + action: str, + resource_type: str, + resource_id: str, + resource_name: Optional[str], + user_id: str, + user_email: Optional[str], + team_id: Optional[str], + client_ip: Optional[str], + user_agent: Optional[str], + success: bool, + data_classification: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + error_message: Optional[str] = None, + additional_context: Optional[Dict[str, Any]] = None, 
+ db: Optional[Session] = None + ) -> Optional[AuditTrail]: + """Log data access for audit trails. + + Args: + action: Action performed (create, read, update, delete, execute) + resource_type: Type of resource accessed + resource_id: Resource identifier + resource_name: Resource name + user_id: User performing the action + user_email: User email + team_id: Team context + client_ip: Client IP address + user_agent: Client user agent + success: Whether action succeeded + data_classification: Data sensitivity classification + old_values: Previous values (for updates) + new_values: New values (for updates/creates) + error_message: Error message if failed + additional_context: Additional context + db: Optional database session + + Returns: + Created AuditTrail entry or None + """ + correlation_id = get_correlation_id() + + # Determine if audit requires review + requires_review = self._requires_audit_review( + action=action, + resource_type=resource_type, + data_classification=data_classification, + success=success + ) + + # Calculate changes + changes = None + if old_values and new_values: + changes = { + k: {"old": old_values.get(k), "new": new_values.get(k)} + for k in set(old_values.keys()) | set(new_values.keys()) + if old_values.get(k) != new_values.get(k) + } + + # Create audit trail + audit = self._create_audit_trail( + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + success=success, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + error_message=error_message, + context=additional_context, + correlation_id=correlation_id, + db=db + ) + + # Log sensitive data access as security event + if data_classification in ["confidential", "restricted", "sensitive"]: + self._create_security_event( + event_type="data_access", + severity=SecuritySeverity.MEDIUM if success else SecuritySeverity.HIGH, + category="data_access", + user_id=user_id, + user_email=user_email, + client_ip=client_ip or "unknown", + user_agent=user_agent, + description=f"Access to {data_classification} {resource_type}: {resource_name or resource_id}", + threat_score=0.3 if success else 0.6, + context={ + "action": action, + "resource_type": resource_type, + "resource_id": resource_id, + "data_classification": data_classification, + }, + correlation_id=correlation_id, + db=db + ) + + return audit + + def log_suspicious_activity( + self, + activity_type: str, + description: str, + user_id: Optional[str], + user_email: Optional[str], + client_ip: str, + user_agent: Optional[str], + threat_score: float, + severity: SecuritySeverity, + threat_indicators: Dict[str, Any], + action_taken: str, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None + ) -> Optional[SecurityEvent]: + """Log suspicious activity with threat analysis. 
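+
+        Example (illustrative; values are hypothetical):
+
+            >>> sec = SecurityLogger()
+            >>> sec.log_suspicious_activity(  # doctest: +SKIP
+            ...     activity_type="rate_limit_probe",
+            ...     description="Burst of 50 requests in 2 seconds",
+            ...     user_id=None,
+            ...     user_email=None,
+            ...     client_ip="203.0.113.7",
+            ...     user_agent=None,
+            ...     threat_score=0.8,
+            ...     severity=SecuritySeverity.HIGH,
+            ...     threat_indicators={"requests_per_second": 25},
+            ...     action_taken="throttled",
+            ... )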
+ + Args: + activity_type: Type of suspicious activity + description: Event description + user_id: User identifier (if known) + user_email: User email (if known) + client_ip: Client IP address + user_agent: Client user agent + threat_score: Calculated threat score (0.0-1.0) + severity: Event severity + threat_indicators: Dictionary of threat indicators + action_taken: Action taken in response + additional_context: Additional context + db: Optional database session + + Returns: + Created SecurityEvent or None + """ + correlation_id = get_correlation_id() + + event = self._create_security_event( + event_type=SecurityEventType.SUSPICIOUS_ACTIVITY, + severity=severity, + category="suspicious_activity", + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + threat_score=threat_score, + threat_indicators=threat_indicators, + action_taken=action_taken, + context=additional_context, + correlation_id=correlation_id, + db=db + ) + + logger.warning( + f"Suspicious activity detected: {description}", + extra={ + "security_event": True, + "activity_type": activity_type, + "severity": severity.value, + "threat_score": threat_score, + "action_taken": action_taken, + "correlation_id": correlation_id, + } + ) + + return event + + def _count_recent_failures( + self, + user_id: Optional[str] = None, + client_ip: Optional[str] = None, + minutes: Optional[int] = None, + db: Optional[Session] = None + ) -> int: + """Count recent authentication failures. + + Args: + user_id: User identifier + client_ip: Client IP address + minutes: Time window in minutes + db: Optional database session + + Returns: + Count of recent failures + """ + if not user_id and not client_ip: + return 0 + + window_minutes = minutes or self.rate_limit_window_minutes + since = datetime.now(timezone.utc) - timedelta(minutes=window_minutes) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + stmt = select(func.count(SecurityEvent.id)).where( + SecurityEvent.event_type == SecurityEventType.AUTHENTICATION_FAILURE, + SecurityEvent.timestamp >= since + ) + + if user_id: + stmt = stmt.where(SecurityEvent.user_id == user_id) + if client_ip: + stmt = stmt.where(SecurityEvent.client_ip == client_ip) + + result = db.execute(stmt).scalar() + return result or 0 + + finally: + if should_close: + db.close() + + def _calculate_auth_threat_score( + self, + success: bool, + failed_attempts: int, + auth_method: str + ) -> float: + """Calculate threat score for authentication attempt. + + Args: + success: Whether authentication succeeded + failed_attempts: Count of recent failures + auth_method: Authentication method used + + Returns: + Threat score from 0.0 to 1.0 + """ + if success: + return 0.0 + + # Base score for failure + score = 0.3 + + # Increase based on failed attempts + if failed_attempts >= 10: + score += 0.5 + elif failed_attempts >= 5: + score += 0.3 + elif failed_attempts >= 3: + score += 0.2 + + # Cap at 1.0 + return min(score, 1.0) + + def _requires_audit_review( + self, + action: str, + resource_type: str, + data_classification: Optional[str], + success: bool + ) -> bool: + """Determine if audit entry requires manual review. 
+ + Args: + action: Action performed + resource_type: Resource type + data_classification: Data classification + success: Whether action succeeded + + Returns: + True if review required + """ + # Failed actions on sensitive data require review + if not success and data_classification in ["confidential", "restricted"]: + return True + + # Deletions of sensitive data require review + if action == "delete" and data_classification in ["confidential", "restricted"]: + return True + + # Privilege modifications require review + if resource_type in ["role", "permission", "team_member"]: + return True + + return False + + def _create_security_event( + self, + event_type: str, + severity: SecuritySeverity, + category: str, + client_ip: str, + description: str, + threat_score: float, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + user_agent: Optional[str] = None, + action_taken: Optional[str] = None, + failed_attempts_count: int = 0, + threat_indicators: Optional[Dict[str, Any]] = None, + context: Optional[Dict[str, Any]] = None, + correlation_id: Optional[str] = None, + db: Optional[Session] = None + ) -> Optional[SecurityEvent]: + """Create a security event record. + + Args: + event_type: Type of security event + severity: Event severity + category: Event category + client_ip: Client IP address + description: Event description + threat_score: Threat score (0.0-1.0) + user_id: User identifier + user_email: User email + user_agent: User agent string + action_taken: Action taken + failed_attempts_count: Failed attempts count + threat_indicators: Threat indicators + context: Additional context + correlation_id: Correlation ID + db: Optional database session + + Returns: + Created SecurityEvent or None + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + event = SecurityEvent( + event_type=event_type, + severity=severity.value, + category=category, + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + action_taken=action_taken, + threat_score=threat_score, + threat_indicators=threat_indicators or {}, + failed_attempts_count=failed_attempts_count, + context=context, + correlation_id=correlation_id, + ) + + db.add(event) + db.commit() + db.refresh(event) + + return event + + except Exception as e: + logger.error(f"Failed to create security event: {e}") + db.rollback() + return None + + finally: + if should_close: + db.close() + + def _create_audit_trail( + self, + action: str, + resource_type: str, + user_id: str, + success: bool, + resource_id: Optional[str] = None, + resource_name: Optional[str] = None, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + client_ip: Optional[str] = None, + user_agent: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + changes: Optional[Dict[str, Any]] = None, + data_classification: Optional[str] = None, + requires_review: bool = False, + error_message: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + correlation_id: Optional[str] = None, + db: Optional[Session] = None + ) -> Optional[AuditTrail]: + """Create an audit trail record. 
+ + Args: + action: Action performed + resource_type: Resource type + user_id: User performing action + success: Whether action succeeded + resource_id: Resource identifier + resource_name: Resource name + user_email: User email + team_id: Team context + client_ip: Client IP + user_agent: User agent + old_values: Previous values + new_values: New values + changes: Calculated changes + data_classification: Data classification + requires_review: Whether manual review needed + error_message: Error message if failed + context: Additional context + correlation_id: Correlation ID + db: Optional database session + + Returns: + Created AuditTrail or None + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + audit = AuditTrail( + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + success=success, + error_message=error_message, + context=context, + correlation_id=correlation_id, + ) + + db.add(audit) + db.commit() + db.refresh(audit) + + return audit + + except Exception as e: + logger.error(f"Failed to create audit trail: {e}") + db.rollback() + return None + + finally: + if should_close: + db.close() + + +# Global security logger instance +_security_logger: Optional[SecurityLogger] = None + + +def get_security_logger() -> SecurityLogger: + """Get or create the global security logger instance. + + Returns: + Global SecurityLogger instance + """ + global _security_logger + if _security_logger is None: + _security_logger = SecurityLogger() + return _security_logger + + +# Import settings here to avoid circular imports +from mcpgateway.config import settings diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index e7f8aae4d..4f321524a 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -34,6 +34,9 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.schemas import ServerCreate, ServerMetrics, ServerRead, ServerUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.structured_logger import get_structured_logger +from mcpgateway.services.audit_trail_service import get_audit_trail_service +from mcpgateway.services.performance_tracker import get_performance_tracker from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -130,6 +133,9 @@ def __init__(self) -> None: """ self._event_subscribers: List[asyncio.Queue] = [] self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) + self._structured_logger = get_structured_logger() + self._audit_trail = get_audit_trail_service() + self._performance_tracker = get_performance_tracker() async def initialize(self) -> None: """Initialize the server service.""" @@ -549,17 +555,87 @@ async def register_server( logger.debug(f"Server Data: {server_data}") await self._notify_server_added(db_server) logger.info(f"Registered server: {server_in.name}") + + # Structured logging: Audit trail for server creation + await self._audit_trail.log_action( + user_id=created_by 
or "system", + action="create_server", + resource_type="server", + resource_id=db_server.id, + details={ + "server_name": db_server.name, + "visibility": visibility, + "team_id": team_id, + "associated_tools_count": len(db_server.tools), + "associated_resources_count": len(db_server.resources), + "associated_prompts_count": len(db_server.prompts), + "associated_a2a_agents_count": len(db_server.a2a_agents), + }, + metadata={ + "created_from_ip": created_from_ip, + "created_via": created_via, + "created_user_agent": created_user_agent, + }, + ) + + # Structured logging: Log successful server creation + await self._structured_logger.log( + level="info", + message="Server created successfully", + event_type="server_created", + component="server_service", + server_id=db_server.id, + server_name=db_server.name, + visibility=visibility, + created_by=created_by, + ) + db_server.team = self._get_team_name(db, db_server.team_id) return self._convert_server_to_read(db_server) except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + await self._structured_logger.log( + level="error", + message="Server creation failed due to database integrity error", + event_type="server_creation_failed", + component="server_service", + server_name=server_in.name, + error_type="IntegrityError", + error_message=str(ie), + created_by=created_by, + ) raise ie except ServerNameConflictError as se: db.rollback() + + # Structured logging: Log name conflict error + await self._structured_logger.log( + level="warning", + message="Server creation failed due to name conflict", + event_type="server_name_conflict", + component="server_service", + server_name=server_in.name, + visibility=visibility, + created_by=created_by, + ) raise se except Exception as ex: db.rollback() + + # Structured logging: Log generic server creation failure + await self._structured_logger.log( + level="error", + message="Server creation failed", + event_type="server_creation_failed", + component="server_service", + server_name=server_in.name, + error_type=type(ex).__name__, + error_message=str(ex), + created_by=created_by, + ) raise ServerError(f"Failed to register server: {str(ex)}") async def list_servers(self, db: Session, include_inactive: bool = False, tags: Optional[List[str]] = None) -> List[ServerRead]: @@ -927,6 +1003,43 @@ async def update_server( await self._notify_server_updated(server) logger.info(f"Updated server: {server.name}") + # Structured logging: Audit trail for server update + changes = [] + if server_update.name: + changes.append(f"name: {server_update.name}") + if server_update.visibility: + changes.append(f"visibility: {server_update.visibility}") + if server_update.team_id: + changes.append(f"team_id: {server_update.team_id}") + + await self._audit_trail.log_action( + user_id=user_email or "system", + action="update_server", + resource_type="server", + resource_id=server.id, + details={ + "server_name": server.name, + "changes": ", ".join(changes) if changes else "metadata only", + "version": server.version, + }, + metadata={ + "modified_from_ip": modified_from_ip, + "modified_via": modified_via, + "modified_user_agent": modified_user_agent, + }, + ) + + # Structured logging: Log successful server update + await self._structured_logger.log( + level="info", + message="Server updated successfully", + event_type="server_updated", + component="server_service", + server_id=server.id, + server_name=server.name, + modified_by=user_email, + ) + # 
             # Build a dictionary with associated IDs
             server_data = {
                 "id": server.id,
@@ -946,13 +1059,47 @@
         except IntegrityError as ie:
             db.rollback()
             logger.error(f"IntegrityErrors in group: {ie}")
+
+            # Structured logging: Log database integrity error
+            self._structured_logger.log(
+                level="ERROR",
+                message="Server update failed due to database integrity error",
+                event_type="server_update_failed",
+                component="server_service",
+                server_id=server_id,
+                error_type="IntegrityError",
+                error_message=str(ie),
+                modified_by=user_email,
+            )
             raise ie
         except ServerNameConflictError as snce:
             db.rollback()
             logger.error(f"Server name conflict: {snce}")
+
+            # Structured logging: Log name conflict error
+            self._structured_logger.log(
+                level="WARNING",
+                message="Server update failed due to name conflict",
+                event_type="server_name_conflict",
+                component="server_service",
+                server_id=server_id,
+                modified_by=user_email,
+            )
             raise snce
         except Exception as e:
             db.rollback()
+
+            # Structured logging: Log generic server update failure
+            self._structured_logger.log(
+                level="ERROR",
+                message="Server update failed",
+                event_type="server_update_failed",
+                component="server_service",
+                server_id=server_id,
+                error_type=type(e).__name__,
+                error_message=str(e),
+                modified_by=user_email,
+            )
             raise ServerError(f"Failed to update server: {str(e)}")
 
     async def toggle_server_status(self, db: Session, server_id: str, activate: bool, user_email: Optional[str] = None) -> ServerRead:
@@ -1013,6 +1160,30 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool
             else:
                 await self._notify_server_deactivated(server)
             logger.info(f"Server {server.name} {'activated' if activate else 'deactivated'}")
+
+            # Structured logging: Audit trail for server status toggle
+            await self._audit_trail.log_action(
+                user_id=user_email or "system",
+                action="activate_server" if activate else "deactivate_server",
+                resource_type="server",
+                resource_id=server.id,
+                details={
+                    "server_name": server.name,
+                    "new_status": "active" if activate else "inactive",
+                },
+            )
+
+            # Structured logging: Log server status change
+            self._structured_logger.log(
+                level="INFO",
+                message=f"Server {'activated' if activate else 'deactivated'}",
+                event_type="server_status_changed",
+                component="server_service",
+                server_id=server.id,
+                server_name=server.name,
+                new_status="active" if activate else "inactive",
+                changed_by=user_email,
+            )
 
             server_data = {
                 "id": server.id,
@@ -1030,9 +1201,30 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool
             logger.info(f"Server Data: {server_data}")
             return self._convert_server_to_read(server)
         except PermissionError as e:
+            # Structured logging: Log permission error
+            self._structured_logger.log(
+                level="WARNING",
+                message="Server status toggle failed due to insufficient permissions",
+                event_type="server_status_toggle_permission_denied",
+                component="server_service",
+                server_id=server_id,
+                user_email=user_email,
+            )
             raise e
         except Exception as e:
             db.rollback()
+
+            # Structured logging: Log generic server status toggle failure
+            self._structured_logger.log(
+                level="ERROR",
+                message="Server status toggle failed",
+                event_type="server_status_toggle_failed",
+                component="server_service",
+                server_id=server_id,
+                error_type=type(e).__name__,
+                error_message=str(e),
+                user_email=user_email,
+            )
             raise ServerError(f"Failed to toggle server status: {str(e)}")
 
     async def delete_server(self, db: Session, server_id: str, user_email: Optional[str] = None) -> None:
@@ -1081,11 +1273,55 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[str] = None) -> None:
             await self._notify_server_deleted(server_info)
             logger.info(f"Deleted server: {server_info['name']}")
-        except PermissionError:
+
+            # Structured logging: Audit trail for server deletion
+            await self._audit_trail.log_action(
+                user_id=user_email or "system",
+                action="delete_server",
+                resource_type="server",
+                resource_id=server_info["id"],
+                details={
+                    "server_name": server_info["name"],
+                },
+            )
+
+            # Structured logging: Log successful server deletion
+            self._structured_logger.log(
+                level="INFO",
+                message="Server deleted successfully",
+                event_type="server_deleted",
+                component="server_service",
+                server_id=server_info["id"],
+                server_name=server_info["name"],
+                deleted_by=user_email,
+            )
+        except PermissionError as pe:
             db.rollback()
-            raise
+
+            # Structured logging: Log permission error
+            self._structured_logger.log(
+                level="WARNING",
+                message="Server deletion failed due to insufficient permissions",
+                event_type="server_deletion_permission_denied",
+                component="server_service",
+                server_id=server_id,
+                user_email=user_email,
+            )
+            raise pe
         except Exception as e:
             db.rollback()
+
+            # Structured logging: Log generic server deletion failure
+            self._structured_logger.log(
+                level="ERROR",
+                message="Server deletion failed",
+                event_type="server_deletion_failed",
+                component="server_service",
+                server_id=server_id,
+                error_type=type(e).__name__,
+                error_message=str(e),
+                user_email=user_email,
+            )
             raise ServerError(f"Failed to delete server: {str(e)}")
 
     async def _publish_event(self, event: Dict[str, Any]) -> None:
diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py
new file mode 100644
index 000000000..b43b98958
--- /dev/null
+++ b/mcpgateway/services/structured_logger.py
@@ -0,0 +1,408 @@
+# -*- coding: utf-8 -*-
+"""Location: ./mcpgateway/services/structured_logger.py
+Copyright 2025
+SPDX-License-Identifier: Apache-2.0
+
+Structured Logger Service.
+
+This module provides comprehensive structured logging with component-based loggers,
+automatic enrichment, intelligent routing, and database persistence.
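+
+Example (illustrative):
+    logger = get_structured_logger("my_component")
+    logger.info("Cache warmed", duration_ms=12.5, custom_fields={"entries": 128})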
+"""
+
+# Standard
+from datetime import datetime, timezone
+from enum import Enum
+import logging
+import os
+import socket
+import traceback
+from typing import Any, Dict, List, Optional, Union
+
+# Third-Party
+from sqlalchemy.orm import Session
+
+# First-Party
+from mcpgateway.config import settings
+from mcpgateway.db import StructuredLogEntry, SessionLocal
+from mcpgateway.services.performance_tracker import get_performance_tracker
+from mcpgateway.utils.correlation_id import get_correlation_id
+
+logger = logging.getLogger(__name__)
+
+
+class LogLevel(str, Enum):
+    """Log levels matching Python logging."""
+    DEBUG = "DEBUG"
+    INFO = "INFO"
+    WARNING = "WARNING"
+    ERROR = "ERROR"
+    CRITICAL = "CRITICAL"
+
+
+class LogCategory(str, Enum):
+    """Log categories for classification."""
+    APPLICATION = "application"
+    REQUEST = "request"
+    SECURITY = "security"
+    PERFORMANCE = "performance"
+    DATABASE = "database"
+    AUTHENTICATION = "authentication"
+    AUTHORIZATION = "authorization"
+    EXTERNAL_SERVICE = "external_service"
+    BUSINESS_LOGIC = "business_logic"
+    SYSTEM = "system"
+
+
+class LogEnricher:
+    """Enriches log entries with contextual information."""
+
+    @staticmethod
+    def enrich(entry: Dict[str, Any]) -> Dict[str, Any]:
+        """Enrich log entry with system and context information.
+
+        Args:
+            entry: Base log entry
+
+        Returns:
+            Enriched log entry
+        """
+        # Get correlation ID
+        correlation_id = get_correlation_id()
+        if correlation_id:
+            entry["correlation_id"] = correlation_id
+
+        # Add hostname and process info
+        entry.setdefault("hostname", socket.gethostname())
+        entry.setdefault("process_id", os.getpid())
+
+        # Add timestamp if not present
+        if "timestamp" not in entry:
+            entry["timestamp"] = datetime.now(timezone.utc)
+
+        # Add performance metrics if available
+        perf_tracker = get_performance_tracker()
+        if correlation_id and perf_tracker:
+            current_ops = perf_tracker.get_current_operations(correlation_id)
+            if current_ops:
+                entry["active_operations"] = len(current_ops)
+
+        # Add OpenTelemetry trace context if available
+        try:
+            from opentelemetry import trace
+
+            span = trace.get_current_span()
+            if span and span.get_span_context().is_valid:
+                ctx = span.get_span_context()
+                entry["trace_id"] = format(ctx.trace_id, "032x")
+                entry["span_id"] = format(ctx.span_id, "016x")
+        except Exception:  # covers ImportError too - tracing is optional
+            pass
+
+        return entry
+
+
+class LogRouter:
+    """Routes log entries to appropriate destinations."""
+
+    def __init__(self):
+        """Initialize log router."""
+        self.database_enabled = getattr(settings, "structured_logging_database_enabled", True)
+        self.external_enabled = getattr(settings, "structured_logging_external_enabled", False)
+
+    def route(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None:
+        """Route log entry to configured destinations.
+
+        Args:
+            entry: Log entry to route
+            db: Optional database session
+        """
+        # Always log to standard Python logger
+        self._log_to_python_logger(entry)
+
+        # Persist to database if enabled
+        if self.database_enabled:
+            self._persist_to_database(entry, db)
+
+        # Send to external systems if enabled
+        if self.external_enabled:
+            self._send_to_external(entry)
+
+    def _log_to_python_logger(self, entry: Dict[str, Any]) -> None:
+        """Log to standard Python logger.
+ + Args: + entry: Log entry + """ + level_str = entry.get("level", "INFO") + level = getattr(logging, level_str, logging.INFO) + + message = entry.get("message", "") + component = entry.get("component", "") + + log_message = f"[{component}] {message}" if component else message + + # Build extra dict for structured logging + extra = { + k: v for k, v in entry.items() + if k not in ["message", "level"] + } + + logger.log(level, log_message, extra=extra) + + def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None: + """Persist log entry to database. + + Args: + entry: Log entry + db: Optional database session + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + # Build error_details JSON from error-related fields + error_details = None + if any([entry.get("error_type"), entry.get("error_message"), entry.get("error_stack_trace"), entry.get("error_context")]): + error_details = { + "error_type": entry.get("error_type"), + "error_message": entry.get("error_message"), + "error_stack_trace": entry.get("error_stack_trace"), + "error_context": entry.get("error_context"), + } + + # Build performance_metrics JSON from performance-related fields + performance_metrics = None + perf_fields = { + "database_query_count": entry.get("database_query_count"), + "database_query_duration_ms": entry.get("database_query_duration_ms"), + "cache_hits": entry.get("cache_hits"), + "cache_misses": entry.get("cache_misses"), + "external_api_calls": entry.get("external_api_calls"), + "external_api_duration_ms": entry.get("external_api_duration_ms"), + "memory_usage_mb": entry.get("memory_usage_mb"), + "cpu_usage_percent": entry.get("cpu_usage_percent"), + } + if any(v is not None for v in perf_fields.values()): + performance_metrics = {k: v for k, v in perf_fields.items() if v is not None} + + # Build threat_indicators JSON from security-related fields + threat_indicators = None + security_fields = { + "security_event_type": entry.get("security_event_type"), + "security_threat_score": entry.get("security_threat_score"), + "security_action_taken": entry.get("security_action_taken"), + } + if any(v is not None for v in security_fields.values()): + threat_indicators = {k: v for k, v in security_fields.items() if v is not None} + + # Build context JSON from remaining fields + context_fields = { + "team_id": entry.get("team_id"), + "request_query": entry.get("request_query"), + "request_headers": entry.get("request_headers"), + "request_body_size": entry.get("request_body_size"), + "response_status_code": entry.get("response_status_code"), + "response_body_size": entry.get("response_body_size"), + "response_headers": entry.get("response_headers"), + "business_event_type": entry.get("business_event_type"), + "business_entity_type": entry.get("business_entity_type"), + "business_entity_id": entry.get("business_entity_id"), + "resource_type": entry.get("resource_type"), + "resource_id": entry.get("resource_id"), + "resource_action": entry.get("resource_action"), + "category": entry.get("category"), + "custom_fields": entry.get("custom_fields"), + "tags": entry.get("tags"), + "metadata": entry.get("metadata"), + } + context = {k: v for k, v in context_fields.items() if v is not None} + + # Determine if this is a security event + is_security_event = entry.get("is_security_event", False) or bool(threat_indicators) + security_severity = entry.get("security_severity") + + log_entry = StructuredLogEntry( + timestamp=entry.get("timestamp", 
datetime.now(timezone.utc)), + level=entry.get("level", "INFO"), + component=entry.get("component"), + message=entry.get("message", ""), + correlation_id=entry.get("correlation_id"), + request_id=entry.get("request_id"), + trace_id=entry.get("trace_id"), + span_id=entry.get("span_id"), + user_id=entry.get("user_id"), + user_email=entry.get("user_email"), + client_ip=entry.get("client_ip"), + user_agent=entry.get("user_agent"), + request_method=entry.get("request_method"), + request_path=entry.get("request_path"), + duration_ms=entry.get("duration_ms"), + operation_type=entry.get("operation_type"), + is_security_event=is_security_event, + security_severity=security_severity, + threat_indicators=threat_indicators, + context=context if context else None, + error_details=error_details, + performance_metrics=performance_metrics, + hostname=entry.get("hostname"), + process_id=entry.get("process_id"), + thread_id=entry.get("thread_id"), + environment=entry.get("environment", getattr(settings, "environment", "development")), + version=entry.get("version", getattr(settings, "version", "unknown")), + ) + + db.add(log_entry) + db.commit() + + except Exception as e: + logger.error(f"Failed to persist log entry to database: {e}") + if db: + db.rollback() + + finally: + if should_close: + db.close() + + def _send_to_external(self, entry: Dict[str, Any]) -> None: + """Send log entry to external systems. + + Args: + entry: Log entry + """ + # Placeholder for external logging integration + # Will be implemented in log exporters + pass + + +class StructuredLogger: + """Main structured logger with enrichment and routing.""" + + def __init__(self, component: str): + """Initialize structured logger. + + Args: + component: Component name for log entries + """ + self.component = component + self.enricher = LogEnricher() + self.router = LogRouter() + + def log( + self, + level: Union[LogLevel, str], + message: str, + category: Optional[Union[LogCategory, str]] = None, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + error: Optional[Exception] = None, + duration_ms: Optional[float] = None, + custom_fields: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, + db: Optional[Session] = None, + **kwargs: Any + ) -> None: + """Log a structured message. 
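+
+        Example (illustrative; database persistence is assumed enabled):
+
+            StructuredLogger("tool_service").log(
+                LogLevel.ERROR,
+                "Tool invocation failed",
+                category=LogCategory.BUSINESS_LOGIC,
+                error=ValueError("bad input"),
+                duration_ms=42.0,
+            )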
+ + Args: + level: Log level + message: Log message + category: Log category + user_id: User identifier + user_email: User email + team_id: Team identifier + error: Exception object + duration_ms: Operation duration + custom_fields: Additional custom fields + tags: Log tags + db: Optional database session + **kwargs: Additional fields to include + """ + # Build base entry + entry: Dict[str, Any] = { + "level": level.value if isinstance(level, LogLevel) else level, + "component": self.component, + "message": message, + "category": category.value if isinstance(category, LogCategory) and category else category if category else None, + "user_id": user_id, + "user_email": user_email, + "team_id": team_id, + "duration_ms": duration_ms, + "custom_fields": custom_fields, + "tags": tags, + } + + # Add error information if present + if error: + entry["error_type"] = type(error).__name__ + entry["error_message"] = str(error) + entry["error_stack_trace"] = "".join(traceback.format_exception(type(error), error, error.__traceback__)) + + # Add any additional kwargs + entry.update(kwargs) + + # Enrich entry with context + entry = self.enricher.enrich(entry) + + # Route to destinations + self.router.route(entry, db) + + def debug(self, message: str, **kwargs: Any) -> None: + """Log debug message.""" + self.log(LogLevel.DEBUG, message, **kwargs) + + def info(self, message: str, **kwargs: Any) -> None: + """Log info message.""" + self.log(LogLevel.INFO, message, **kwargs) + + def warning(self, message: str, **kwargs: Any) -> None: + """Log warning message.""" + self.log(LogLevel.WARNING, message, **kwargs) + + def error(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: + """Log error message.""" + self.log(LogLevel.ERROR, message, error=error, **kwargs) + + def critical(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: + """Log critical message.""" + self.log(LogLevel.CRITICAL, message, error=error, **kwargs) + + +class ComponentLogger: + """Logger factory for component-specific loggers.""" + + _loggers: Dict[str, StructuredLogger] = {} + + @classmethod + def get_logger(cls, component: str) -> StructuredLogger: + """Get or create a logger for a specific component. + + Args: + component: Component name + + Returns: + StructuredLogger instance for the component + """ + if component not in cls._loggers: + cls._loggers[component] = StructuredLogger(component) + return cls._loggers[component] + + @classmethod + def clear_loggers(cls) -> None: + """Clear all cached loggers (useful for testing).""" + cls._loggers.clear() + + +# Global structured logger instance for backward compatibility +def get_structured_logger(component: str = "mcpgateway") -> StructuredLogger: + """Get a structured logger instance. 
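+
+    Example (illustrative):
+
+        log = get_structured_logger("server_service")
+        log.warning("Upstream responded slowly", duration_ms=950.0)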
+
+    Args:
+        component: Component name
+
+    Returns:
+        StructuredLogger instance
+    """
+    return ComponentLogger.get_logger(component)
diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py
index 6fb8a9454..358e05f17 100644
--- a/mcpgateway/services/tool_service.py
+++ b/mcpgateway/services/tool_service.py
@@ -67,6 +67,8 @@
 from mcpgateway.services.event_service import EventService
 from mcpgateway.services.logging_service import LoggingService
 from mcpgateway.services.oauth_manager import OAuthManager
+from mcpgateway.services.performance_tracker import get_performance_tracker
+from mcpgateway.services.structured_logger import get_structured_logger, LogCategory
 from mcpgateway.services.team_management_service import TeamManagementService
 from mcpgateway.utils.create_slug import slugify
 from mcpgateway.utils.display_name import generate_display_name
@@ -82,6 +84,10 @@
 logging_service = LoggingService()
 logger = logging_service.get_logger(__name__)
 
+# Initialize performance tracker and structured logger for tool operations
+perf_tracker = get_performance_tracker()
+structured_logger = get_structured_logger("tool_service")
+
 
 def extract_using_jq(data, jq_filter=""):
     """
@@ -1448,11 +1454,74 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers):
 
         Returns:
             ToolResult: Result of tool call
         """
-        async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams:
-            async with ClientSession(*streams) as session:
-                await session.initialize()
-                tool_call_result = await session.call_tool(tool.original_name, arguments)
-                return tool_call_result
+        # Get correlation ID for distributed tracing
+        correlation_id = get_correlation_id()
+
+        # Propagate the correlation ID downstream; copying avoids mutating the
+        # shared headers dict, and an explicit None-safe merge keeps empty
+        # header dicts from silently skipping the injection
+        if correlation_id:
+            headers = {**(headers or {}), "X-Correlation-ID": correlation_id}
+
+        # Log MCP call start
+        mcp_start_time = time.time()
+        structured_logger.log(
+            level="INFO",
+            message=f"MCP tool call started: {tool.original_name}",
+            component="tool_service",
+            correlation_id=correlation_id,
+            metadata={
+                "event": "mcp_call_started",
+                "tool_name": tool.original_name,
+                "tool_id": tool.id,
+                "server_url": server_url,
+                "transport": "sse"
+            }
+        )
+
+        try:
+            async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams:
+                async with ClientSession(*streams) as session:
+                    await session.initialize()
+                    tool_call_result = await session.call_tool(tool.original_name, arguments)
+
+                    # Log successful MCP call
+                    mcp_duration_ms = (time.time() - mcp_start_time) * 1000
+                    structured_logger.log(
+                        level="INFO",
+                        message=f"MCP tool call completed: {tool.original_name}",
+                        component="tool_service",
+                        correlation_id=correlation_id,
+                        duration_ms=mcp_duration_ms,
+                        metadata={
+                            "event": "mcp_call_completed",
+                            "tool_name": tool.original_name,
+                            "tool_id": tool.id,
+                            "transport": "sse",
+                            "success": True
+                        }
+                    )
+
+                    return tool_call_result
+        except Exception as e:
+            # Log failed MCP call; pass the exception itself so the router
+            # persists the structured error details
+            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
+            structured_logger.log(
+                level="ERROR",
+                message=f"MCP tool call failed: {tool.original_name}",
+                component="tool_service",
+                correlation_id=correlation_id,
+                duration_ms=mcp_duration_ms,
+                error=e,
+                metadata={
+                    "event": "mcp_call_failed",
+                    "tool_name": tool.original_name,
+                    "tool_id": tool.id,
+                    "transport": "sse"
+                }
+            )
+            raise
 
 async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers):
        """Connect to an MCP server running with Streamable HTTP transport.
@@ -1464,11 +1533,74 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers):
 
         Returns:
             ToolResult: Result of tool call
         """
-        async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
-            async with ClientSession(read_stream, write_stream) as session:
-                await session.initialize()
-                tool_call_result = await session.call_tool(tool.original_name, arguments)
-                return tool_call_result
+        # Get correlation ID for distributed tracing
+        correlation_id = get_correlation_id()
+
+        # Propagate the correlation ID downstream; copying avoids mutating the
+        # shared headers dict, and an explicit None-safe merge keeps empty
+        # header dicts from silently skipping the injection
+        if correlation_id:
+            headers = {**(headers or {}), "X-Correlation-ID": correlation_id}
+
+        # Log MCP call start
+        mcp_start_time = time.time()
+        structured_logger.log(
+            level="INFO",
+            message=f"MCP tool call started: {tool.original_name}",
+            component="tool_service",
+            correlation_id=correlation_id,
+            metadata={
+                "event": "mcp_call_started",
+                "tool_name": tool.original_name,
+                "tool_id": tool.id,
+                "server_url": server_url,
+                "transport": "streamablehttp"
+            }
+        )
+
+        try:
+            async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id):
+                async with ClientSession(read_stream, write_stream) as session:
+                    await session.initialize()
+                    tool_call_result = await session.call_tool(tool.original_name, arguments)
+
+                    # Log successful MCP call
+                    mcp_duration_ms = (time.time() - mcp_start_time) * 1000
+                    structured_logger.log(
+                        level="INFO",
+                        message=f"MCP tool call completed: {tool.original_name}",
+                        component="tool_service",
+                        correlation_id=correlation_id,
+                        duration_ms=mcp_duration_ms,
+                        metadata={
+                            "event": "mcp_call_completed",
+                            "tool_name": tool.original_name,
+                            "tool_id": tool.id,
+                            "transport": "streamablehttp",
+                            "success": True
+                        }
+                    )
+
+                    return tool_call_result
+        except Exception as e:
+            # Log failed MCP call; pass the exception itself so the router
+            # persists the structured error details
+            mcp_duration_ms = (time.time() - mcp_start_time) * 1000
+            structured_logger.log(
+                level="ERROR",
+                message=f"MCP tool call failed: {tool.original_name}",
+                component="tool_service",
+                correlation_id=correlation_id,
+                duration_ms=mcp_duration_ms,
+                error=e,
+                metadata={
+                    "event": "mcp_call_failed",
+                    "tool_name": tool.original_name,
+                    "tool_id": tool.id,
+                    "transport": "streamablehttp"
+                }
+            )
+            raise
 
 tool_gateway_id = tool.gateway_id
 tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.enabled)).scalar_one_or_none()
@@ -1548,11 +1680,51 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers):
                 span.set_attribute("error.message", str(e))
             raise ToolInvocationError(f"Tool invocation failed: {error_message}")
         finally:
+            # Calculate duration
+            duration_ms = (time.monotonic() - start_time) * 1000
+
             # Add final span attributes
             if span:
                 span.set_attribute("success", success)
-                span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000)
+                span.set_attribute("duration.ms", duration_ms)
+
+            # Record tool metric
             await self._record_tool_metric(db, tool, start_time, success, error_message)
+
+            # Log structured message with performance tracking
+            if success:
+                structured_logger.info(
+                    f"Tool '{name}' invoked successfully",
+                    user_id=app_user_email,
+                    resource_type="tool",
+                    resource_id=str(tool.id),
+                    resource_action="invoke",
+ duration_ms=duration_ms, + custom_fields={ + "tool_name": name, + "integration_type": tool.integration_type, + "arguments_count": len(arguments) if arguments else 0 + } + ) + else: + structured_logger.error( + f"Tool '{name}' invocation failed", + error=Exception(error_message) if error_message else None, + user_id=app_user_email, + resource_type="tool", + resource_id=str(tool.id), + resource_action="invoke", + duration_ms=duration_ms, + custom_fields={ + "tool_name": name, + "integration_type": tool.integration_type, + "error_message": error_message + } + ) + + # Track performance with threshold checking + with perf_tracker.track_operation("tool_invocation", name): + pass # Duration already captured above async def update_tool( self, diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index ef418b617..914db4d77 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -6573,6 +6573,14 @@ function showTab(tabName) { initializeLLMChat(); } + if (tabName === "logs") { + // Load structured logs when tab is first opened + const logsTbody = safeGetElement("logs-tbody"); + if (logsTbody && logsTbody.children.length === 0) { + searchStructuredLogs(); + } + } + if (tabName === "teams") { // Load Teams list if not already loaded const teamsList = safeGetElement("teams-list"); @@ -23977,6 +23985,275 @@ function updateEntityStatus(type, data) { const isEnabled = data.enabled !== undefined ? data.enabled : data.isActive; updateEntityActionButtons(actionCell, type, data.id, isEnabled); +// ============================================================================ +// Structured Logging UI Functions +// ============================================================================ + +// Current log search state +let currentLogPage = 0; +let currentLogLimit = 50; +let currentLogFilters = {}; + +/** + * Search structured logs with filters + */ +async function searchStructuredLogs() { + const levelFilter = document.getElementById('log-level-filter')?.value; + const componentFilter = document.getElementById('log-component-filter')?.value; + const searchQuery = document.getElementById('log-search')?.value; + + // Restore default log table headers (in case we're coming from performance metrics view) + restoreLogTableHeaders(); + + // Build search request + const searchRequest = { + limit: currentLogLimit, + offset: currentLogPage * currentLogLimit, + sort_by: 'timestamp', + sort_order: 'desc' + }; + + // Only add filters if they have actual values (not empty strings) + if (searchQuery && searchQuery.trim() !== '') { + const trimmedSearch = searchQuery.trim(); + // Check if search is a correlation ID (32 hex chars or UUID format) or text search + const correlationIdPattern = /^([0-9a-f]{32}|[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})$/i; + if (correlationIdPattern.test(trimmedSearch)) { + searchRequest.correlation_id = trimmedSearch; + } else { + searchRequest.search_text = trimmedSearch; + } + } + if (levelFilter && levelFilter !== '') { + searchRequest.level = [levelFilter.toUpperCase()]; + } + if (componentFilter && componentFilter !== '') { + searchRequest.component = [componentFilter]; + } + + // Store filters for pagination + currentLogFilters = searchRequest; + + try { + const response = await fetch(`${getRootPath()}/api/logs/search`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${getAuthToken()}` + }, + body: JSON.stringify(searchRequest) + }); + + if (!response.ok) { + const errorText = await 
response.text(); + console.error('API Error Response:', errorText); + throw new Error(`Failed to search logs: ${response.statusText} - ${errorText}`); + } + + const data = await response.json(); + displayLogResults(data); + } catch (error) { + console.error('Error searching logs:', error); + showToast('Failed to search logs: ' + error.message, 'error'); + document.getElementById('logs-tbody').innerHTML = ` + + ❌ Error: ${escapeHtml(error.message)} + + `; + } +} + +/** + * Display log search results + */ +function displayLogResults(data) { + const tbody = document.getElementById('logs-tbody'); + const logCount = document.getElementById('log-count'); + const logStats = document.getElementById('log-stats'); + const prevButton = document.getElementById('prev-page'); + const nextButton = document.getElementById('next-page'); + + // Ensure default headers are shown for log view + restoreLogTableHeaders(); + + if (!data.results || data.results.length === 0) { + tbody.innerHTML = ` + + 📭 No logs found matching your criteria + + `; + logCount.textContent = '0 logs'; + logStats.innerHTML = 'No results'; + return; + } + + // Update stats + logCount.textContent = `${data.total.toLocaleString()} logs`; + const start = currentLogPage * currentLogLimit + 1; + const end = Math.min(start + data.results.length - 1, data.total); + logStats.innerHTML = ` + + Showing ${start}-${end} of ${data.total.toLocaleString()} logs + + `; + + // Update pagination buttons + prevButton.disabled = currentLogPage === 0; + nextButton.disabled = end >= data.total; + + // Render log entries + tbody.innerHTML = data.results.map(log => { + const levelClass = getLogLevelClass(log.level); + const durationDisplay = log.duration_ms ? `${log.duration_ms.toFixed(2)}ms` : '-'; + const correlationId = log.correlation_id || '-'; + const userDisplay = log.user_email || log.user_id || '-'; + + return ` + + + ${formatTimestamp(log.timestamp)} + + + + ${log.level} + + + + ${escapeHtml(log.component || '-')} + + + ${escapeHtml(truncateText(log.message, 80))} + ${log.error_details ? '⚠️' : ''} + + + ${escapeHtml(userDisplay)} + + + ${durationDisplay} + + + ${correlationId !== '-' ? ` + + ` : '-'} + + + `; + }).join(''); +} + +/** + * Get CSS class for log level badge + */ +function getLogLevelClass(level) { + const classes = { + 'DEBUG': 'bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200', + 'INFO': 'bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200', + 'WARNING': 'bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200', + 'ERROR': 'bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200', + 'CRITICAL': 'bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200' + }; + return classes[level] || classes['INFO']; +} + +/** + * Format timestamp for display + */ +function formatTimestamp(timestamp) { + const date = new Date(timestamp); + return date.toLocaleString('en-US', { + month: 'short', + day: 'numeric', + hour: '2-digit', + minute: '2-digit', + second: '2-digit' + }); +} + +/** + * Truncate text with ellipsis + */ +function truncateText(text, maxLength) { + if (!text) return ''; + return text.length > maxLength ? text.substring(0, maxLength) + '...' 
: text; +} + +/** + * Show detailed log entry (future enhancement - modal) + */ +function showLogDetails(logId, correlationId) { + if (correlationId) { + showCorrelationTrace(correlationId); + } else { + console.log('Log details:', logId); + showToast('Full log details view coming soon', 'info'); + } +} + +/** + * Restore default log table headers + */ +function restoreLogTableHeaders() { + const thead = document.getElementById('logs-thead'); + if (thead) { + thead.innerHTML = ` + + + Time + + + Level + + + Component + + + Message + + + User + + + Duration + + + Correlation ID + + + `; + } +} + +/** + * Trace all logs for a correlation ID + */ +async function showCorrelationTrace(correlationId) { + if (!correlationId) { + const searchInput = document.getElementById('log-search'); + correlationId = prompt('Enter Correlation ID to trace:', searchInput?.value || ''); + if (!correlationId) return; + } + + try { + const response = await fetch(`${getRootPath()}/api/logs/trace/${encodeURIComponent(correlationId)}`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${getAuthToken()}` + } + }); + + if (!response.ok) { + throw new Error(`Failed to fetch trace: ${response.statusText}`); + } + + const trace = await response.json(); + displayCorrelationTrace(trace); + } catch (error) { + console.error('Error fetching correlation trace:', error); + showToast('Failed to fetch correlation trace: ' + error.message, 'error'); } } @@ -24018,6 +24295,36 @@ function generateStatusBadgeHtml(enabled, reachable, typeLabel) { `; + * Restore default log table headers + */ +function restoreLogTableHeaders() { + const thead = document.getElementById('logs-thead'); + if (thead) { + thead.innerHTML = ` + + + Time + + + Level + + + Component + + + Message + + + User + + + Duration + + + Correlation ID + + + `; } } @@ -24137,3 +24444,692 @@ console.log("🔧 MCP SERVERS SEARCH DEBUG FUNCTIONS LOADED!"); console.log("💡 Use: window.emergencyFixMCPSearch() to fix search"); console.log("💡 Use: window.testMCPSearchManually('github') to test search"); console.log("💡 Use: window.debugMCPSearchState() to check current state"); + +/** + * Display correlation trace results + */ +function displayCorrelationTrace(trace) { + const tbody = document.getElementById('logs-tbody'); + const thead = document.getElementById('logs-thead'); + const logCount = document.getElementById('log-count'); + const logStats = document.getElementById('log-stats'); + + // Calculate total events + const totalEvents = (trace.logs?.length || 0) + + (trace.security_events?.length || 0) + + (trace.audit_trails?.length || 0); + + // Update table headers for trace view + if (thead) { + thead.innerHTML = ` + + + Time + + + Event Type + + + Component + + + Message/Description + + + User + + + Duration + + + Status/Severity + + + `; + } + + // Update stats + logCount.textContent = `${totalEvents} events`; + logStats.innerHTML = ` +
+
+ Correlation ID:
+ ${escapeHtml(trace.correlation_id)} +
+
+ Logs: ${trace.log_count || 0} +
+
+ Security: ${trace.security_events?.length || 0} +
+
+ Audit: ${trace.audit_trails?.length || 0} +
+
+ Duration: ${trace.total_duration_ms ? trace.total_duration_ms.toFixed(2) + 'ms' : 'N/A'} +
+
+ `; + + if (totalEvents === 0) { + tbody.innerHTML = ` + + 📭 No events found for this correlation ID + + `; + return; + } + + // Combine all events into a unified timeline + const allEvents = []; + + // Add logs + (trace.logs || []).forEach(log => { + const levelClass = getLogLevelClass(log.level); + allEvents.push({ + timestamp: new Date(log.timestamp), + html: ` + + + ${formatTimestamp(log.timestamp)} + + + + 📝 Log + + + + ${escapeHtml(log.component || '-')} + + + ${escapeHtml(log.message)} + ${log.error_details ? `
⚠️ ${escapeHtml(log.error_details.error_message || JSON.stringify(log.error_details))}` : ''} + + + ${escapeHtml(log.user_email || log.user_id || '-')} + + + ${log.duration_ms ? log.duration_ms.toFixed(2) + 'ms' : '-'} + + + + ${log.level} + + + + ` + }); + }); + + // Add security events + (trace.security_events || []).forEach(event => { + const severityClass = getSeverityClass(event.severity); + const threatScore = event.threat_score ? (event.threat_score * 100).toFixed(0) : 0; + allEvents.push({ + timestamp: new Date(event.timestamp), + html: ` + + + ${formatTimestamp(event.timestamp)} + + + + 🛡️ Security + + + + ${escapeHtml(event.event_type || '-')} + + + ${escapeHtml(event.description || '-')} + + + ${escapeHtml(event.user_email || event.user_id || '-')} + + + - + + +
+ + ${event.severity} + +
+ Threat: +
+
+
+ ${threatScore}% +
+
+ + + ` + }); + }); + + // Add audit trails + (trace.audit_trails || []).forEach(audit => { + const actionBadgeColors = { + 'create': 'bg-green-200 text-green-800', + 'update': 'bg-blue-200 text-blue-800', + 'delete': 'bg-red-200 text-red-800', + 'read': 'bg-gray-200 text-gray-800' + }; + const actionBadge = actionBadgeColors[audit.action?.toLowerCase()] || 'bg-purple-200 text-purple-800'; + const statusIcon = audit.success ? '✓' : '✗'; + const statusClass = audit.success ? 'text-green-600' : 'text-red-600'; + const statusBg = audit.success ? 'bg-green-100 dark:bg-green-900' : 'bg-red-100 dark:bg-red-900'; + + allEvents.push({ + timestamp: new Date(audit.timestamp), + html: ` + + + ${formatTimestamp(audit.timestamp)} + + + + 📋 ${audit.action?.toUpperCase()} + + + + ${escapeHtml(audit.resource_type || '-')} + + + ${audit.action}: ${audit.resource_type} + ${escapeHtml(audit.resource_id || '-')} + + + ${escapeHtml(audit.user_email || audit.user_id || '-')} + + + - + + + + ${statusIcon} ${audit.success ? 'Success' : 'Failed'} + + + + ` + }); + }); + + // Sort all events chronologically + allEvents.sort((a, b) => a.timestamp - b.timestamp); + + // Render sorted events + tbody.innerHTML = allEvents.map(event => event.html).join(''); +} + +/** + * Show security events + */ +async function showSecurityEvents() { + try { + const response = await fetch(`${getRootPath()}/api/logs/security-events?limit=50&resolved=false`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${getAuthToken()}` + } + }); + + if (!response.ok) { + throw new Error(`Failed to fetch security events: ${response.statusText}`); + } + + const events = await response.json(); + displaySecurityEvents(events); + } catch (error) { + console.error('Error fetching security events:', error); + showToast('Failed to fetch security events: ' + error.message, 'error'); + } +} + +/** + * Display security events + */ +function displaySecurityEvents(events) { + const tbody = document.getElementById('logs-tbody'); + const thead = document.getElementById('logs-thead'); + const logCount = document.getElementById('log-count'); + const logStats = document.getElementById('log-stats'); + + // Update table headers for security events + if (thead) { + thead.innerHTML = ` + + + Time + + + Severity + + + Event Type + + + Description + + + User/Source + + + Threat Score + + + Correlation ID + + + `; + } + + logCount.textContent = `${events.length} security events`; + logStats.innerHTML = ` + + 🛡️ Unresolved Security Events + + `; + + if (events.length === 0) { + tbody.innerHTML = ` + + ✅ No unresolved security events + + `; + return; + } + + tbody.innerHTML = events.map(event => { + const severityClass = getSeverityClass(event.severity); + const threatScore = (event.threat_score * 100).toFixed(0); + + return ` + + + ${formatTimestamp(event.timestamp)} + + + + ${event.severity} + + + + ${escapeHtml(event.event_type)} + + + ${escapeHtml(event.description)} + + + ${escapeHtml(event.user_email || event.user_id || '-')} + + +
+
+
+
+ ${threatScore}% +
+ + + ${event.correlation_id ? ` + + ` : '-'} + + + `; + }).join(''); +} + +/** + * Get CSS class for severity badge + */ +function getSeverityClass(severity) { + const classes = { + 'LOW': 'bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200', + 'MEDIUM': 'bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200', + 'HIGH': 'bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200', + 'CRITICAL': 'bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200' + }; + return classes[severity] || classes['MEDIUM']; +} + +/** + * Show audit trail + */ +async function showAuditTrail() { + try { + const response = await fetch(`${getRootPath()}/api/logs/audit-trails?limit=50&requires_review=true`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${getAuthToken()}` + } + }); + + if (!response.ok) { + throw new Error(`Failed to fetch audit trails: ${response.statusText}`); + } + + const trails = await response.json(); + displayAuditTrail(trails); + } catch (error) { + console.error('Error fetching audit trails:', error); + showToast('Failed to fetch audit trails: ' + error.message, 'error'); + } +} + +/** + * Display audit trail entries + */ +function displayAuditTrail(trails) { + const tbody = document.getElementById('logs-tbody'); + const thead = document.getElementById('logs-thead'); + const logCount = document.getElementById('log-count'); + const logStats = document.getElementById('log-stats'); + + // Update table headers for audit trail + if (thead) { + thead.innerHTML = ` + + + Time + + + Action + + + Resource Type + + + Resource + + + User + + + Status + + + Correlation ID + + + `; + } + + logCount.textContent = `${trails.length} audit entries`; + logStats.innerHTML = ` + + 📝 Audit Trail Entries Requiring Review + + `; + + if (trails.length === 0) { + tbody.innerHTML = ` + + ✅ No audit entries require review + + `; + return; + } + + tbody.innerHTML = trails.map(trail => { + const actionClass = trail.success ? 'text-green-600' : 'text-red-600'; + const actionIcon = trail.success ? '✓' : '✗'; + + // Determine action badge color + const actionBadgeColors = { + 'create': 'bg-green-200 text-green-800 dark:bg-green-800 dark:text-green-200', + 'update': 'bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200', + 'delete': 'bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200', + 'read': 'bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200', + 'activate': 'bg-teal-200 text-teal-800 dark:bg-teal-800 dark:text-teal-200', + 'deactivate': 'bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200' + }; + const actionBadge = actionBadgeColors[trail.action.toLowerCase()] || 'bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200'; + + // Format resource name with ID + const resourceName = trail.resource_name || trail.resource_id || '-'; + const resourceDisplay = ` +
${escapeHtml(resourceName)}
+ ${trail.resource_id && trail.resource_name ? `
ID: ${escapeHtml(trail.resource_id)}
` : ''} + ${trail.data_classification ? `
🔒 ${escapeHtml(trail.data_classification)}
` : ''} + `; + + return ` + + + ${formatTimestamp(trail.timestamp)} + + + + ${trail.action.toUpperCase()} + + + + ${escapeHtml(trail.resource_type || '-')} + + + ${resourceDisplay} + + + ${escapeHtml(trail.user_email || trail.user_id || '-')} + + + ${actionIcon} ${trail.success ? 'Success' : 'Failed'} + + + ${trail.correlation_id ? ` + + ` : '-'} + + + `; + }).join(''); +} + +/** + * Show performance metrics + */ +async function showPerformanceMetrics() { + try { + const response = await fetch(`${getRootPath()}/api/logs/performance-metrics?hours=24`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${getAuthToken()}` + } + }); + + if (!response.ok) { + throw new Error(`Failed to fetch performance metrics: ${response.statusText}`); + } + + const metrics = await response.json(); + displayPerformanceMetrics(metrics); + } catch (error) { + console.error('Error fetching performance metrics:', error); + showToast('Failed to fetch performance metrics: ' + error.message, 'error'); + } +} + +/** + * Display performance metrics + */ +function displayPerformanceMetrics(metrics) { + const tbody = document.getElementById('logs-tbody'); + const thead = document.getElementById('logs-thead'); + const logCount = document.getElementById('log-count'); + const logStats = document.getElementById('log-stats'); + + // Update table headers for performance metrics + if (thead) { + thead.innerHTML = ` + + + Time + + + Component + + + Operation + + + Avg Duration + + + Requests + + + Error Rate + + + P99 Duration + + + `; + } + + logCount.textContent = `${metrics.length} metrics`; + logStats.innerHTML = ` + + ⚡ Performance Metrics (Last 24 Hours) + + `; + + if (metrics.length === 0) { + tbody.innerHTML = ` + + 📊 No performance metrics available + + `; + return; + } + + tbody.innerHTML = metrics.map(metric => { + const errorRatePercent = (metric.error_rate * 100).toFixed(2); + const errorClass = metric.error_rate > 0.1 ? 'text-red-600' : 'text-green-600'; + + return ` + + + ${formatTimestamp(metric.window_start)} + + + ${escapeHtml(metric.component || '-')} + + + ${escapeHtml(metric.operation_type || '-')} + + +
+
Avg: ${metric.avg_duration_ms.toFixed(2)}ms
+
P95: ${metric.p95_duration_ms.toFixed(2)}ms
+
+ + + ${metric.request_count.toLocaleString()} requests + + + ${errorRatePercent}% + ${metric.error_rate > 0.1 ? '⚠️' : ''} + + +
+ P99: ${metric.p99_duration_ms.toFixed(2)}ms +
+ + + `; + }).join(''); +} + +/** + * Navigate to previous log page + */ +function previousLogPage() { + if (currentLogPage > 0) { + currentLogPage--; + searchStructuredLogs(); + } +} + +/** + * Navigate to next log page + */ +function nextLogPage() { + currentLogPage++; + searchStructuredLogs(); +} + +/** + * Get auth token from session + */ +function getAuthToken() { + // Check cookie first (matches HTMX authentication) + const jwtToken = getCookie('jwt_token') || getCookie('access_token') || getCookie('token'); + if (jwtToken) { + return jwtToken; + } + + // Fallback: check localStorage + const localToken = localStorage.getItem('auth_token'); + if (localToken) { + return localToken; + } + + // Last resort: check input field + const tokenInput = document.querySelector('input[name="auth_token"]'); + if (tokenInput && tokenInput.value) { + return tokenInput.value; + } + + // No token found - log warning for debugging + console.warn('No authentication token found for API request'); + return ''; +} + +/** + * Get root path for API calls + */ +function getRootPath() { + return window.ROOT_PATH || ''; +} + +/** + * Escape HTML to prevent XSS + */ +function escapeHtml(text) { + if (!text) return ''; + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; +} + +/** + * Show toast notification + */ +function showToast(message, type = 'info') { + // Check if showMessage function exists (from existing admin.js) + if (typeof showMessage === 'function') { + showMessage(message, type === 'error' ? 'danger' : type); + } else { + console.log(`[${type.toUpperCase()}] ${message}`); + } +} + +// Make functions globally available for HTML onclick handlers +window.searchStructuredLogs = searchStructuredLogs; +window.showCorrelationTrace = showCorrelationTrace; +window.showSecurityEvents = showSecurityEvents; +window.showAuditTrail = showAuditTrail; +window.showPerformanceMetrics = showPerformanceMetrics; +window.previousLogPage = previousLogPage; +window.nextLogPage = nextLogPage; +window.showLogDetails = showLogDetails; diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 97ba325bd..94cba2c52 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -370,7 +370,7 @@ class="text-gray-700 dark:text-gray-300 block px-4 py-2 text-sm hover:bg-gray-100 dark:hover:bg-gray-700" role="menuitem" onclick="showTab('logs')" - >📋 System Logs📋 Structured Logs - System Logs + 📋 Structured Logs & Analytics + +
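All three panels wired up above (security events, audit trail, performance metrics) pivot on correlation_id: each table renders the id as a truncated button so one request can be followed across views. A minimal client-side sketch of the same stitching, assuming each fetched entry is a dict shaped like the rows rendered above; only the timestamp, correlation_id, and duration_ms keys are relied on, and the sample id is hypothetical:

from collections import defaultdict
from operator import itemgetter

def build_traces(entries):
    # Group already-fetched log entries by correlation_id and order each
    # trace chronologically, as the correlation-trace modal does.
    traces = defaultdict(list)
    for entry in entries:
        cid = entry.get("correlation_id")
        if cid:  # rows without a correlation id render as '-' and are skipped
            traces[cid].append(entry)
    for steps in traces.values():
        steps.sort(key=itemgetter("timestamp"))  # ISO timestamps sort lexically
    return dict(traces)

# e.g. total latency attributed to one request across components:
# steps = build_traces(fetched)["3f2a9c..."]  # hypothetical correlation id
# total_ms = sum(s.get("duration_ms") or 0 for s in steps)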
@@ -616,77 +650,40 @@
Component
Search / Correlation ID
- Loading stats...
+ Loading...
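Both review panels are thin wrappers over plain GET endpoints (see the fetch() calls added to admin.js earlier in the series), so the same data is scriptable outside the UI. A minimal sketch, assuming a locally running gateway and an admin bearer token; the base URL and environment variable names are illustrative, and only the endpoint paths and query parameters come from this series:

import os

import requests

base = os.environ.get("GATEWAY_URL", "http://localhost:4444")  # assumed local default
headers = {"Authorization": f"Bearer {os.environ['GATEWAY_TOKEN']}"}  # assumed token source

# Same call the Performance Metrics button issues: aggregated windows, last 24h.
metrics = requests.get(
    f"{base}/api/logs/performance-metrics",
    params={"hours": 24},
    headers=headers,
    timeout=30,
).json()
for m in metrics:
    print(m["component"], m["operation_type"], f"avg={m['avg_duration_ms']:.1f}ms", f"err={m['error_rate']:.2%}")

# Same call the Audit Trail button issues: entries flagged for review.
trails = requests.get(
    f"{base}/api/logs/audit-trails",
    params={"limit": 50, "requires_review": "true"},
    headers=headers,
    timeout=30,
).json()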
@@ -694,28 +691,43 @@ - + + + + Breakdown by Type: let currentLogPage = 0; const logsPerPage = 100; - async function refreshLogs() { - const level = document.getElementById("log-level-filter").value; - const entityType = document.getElementById("log-entity-filter").value; - const search = document.getElementById("log-search").value; - - const params = new URLSearchParams({ - limit: logsPerPage, - offset: currentLogPage * logsPerPage, - order: "desc", - }); - - if (level) params.append("level", level); - if (entityType) params.append("entity_type", entityType); - if (search) params.append("search", search); - - try { - const headers = {}; - const token = localStorage.getItem("token"); - if (token) { - headers["Authorization"] = `Bearer ${token}`; - } - - const response = await fetch( - `${window.ROOT_PATH || ""}/admin/logs?${params}`, - { - headers: headers, - credentials: "same-origin", - }, - ); - - if (!response.ok) throw new Error(`HTTP ${response.status}`); - - const data = await response.json(); - displayLogs(data.logs); - updateLogStats(data.stats); - updateLogCount(data.total); - } catch (error) { - console.error("Error fetching logs:", error); - showErrorMessage("Failed to fetch logs"); - } - } - - function displayLogs(logs) { - const tbody = document.getElementById("logs-tbody"); - tbody.innerHTML = ""; - - logs.forEach((log) => { - const row = document.createElement("tr"); - row.className = "hover:bg-gray-50 dark:hover:bg-gray-700"; - - const timestamp = new Date(log.timestamp).toLocaleString(); - const levelClass = getLevelClass(log.level); - const entity = log.entity_name || log.entity_id || "-"; - - row.innerHTML = ` - - - - - `; - - tbody.appendChild(row); - }); - } - - function getLevelClass(level) { - switch (level) { - case "debug": - return "bg-gray-100 text-gray-800"; - case "info": - return "bg-blue-100 text-blue-800"; - case "warning": - return "bg-yellow-100 text-yellow-800"; - case "error": - return "bg-red-100 text-red-800"; - case "critical": - return "bg-red-600 text-white"; - default: - return "bg-gray-100 text-gray-800"; - } - } - - function updateLogStats(stats) { - if (!stats) return; - - const statsDiv = document.getElementById("log-stats"); - const levelDist = stats.level_distribution || {}; - const entityDist = stats.entity_distribution || {}; - - let html = ` -
- Buffer: ${stats.usage_percent || 0}% (${stats.buffer_size_mb || 0}/${stats.max_size_mb || 0} MB) - Total: ${stats.total_logs || 0} logs - `; - - if (Object.keys(levelDist).length > 0) { - html += "Levels: "; - for (const [level, count] of Object.entries(levelDist)) { - html += `${level}(${count}) `; - } - html += ""; - } - - html += "
"; - statsDiv.innerHTML = html; - } - - function updateLogCount(total) { - document.getElementById("log-count").textContent = `${total} logs`; - - // Update pagination buttons - document.getElementById("prev-page").disabled = currentLogPage === 0; - document.getElementById("next-page").disabled = - (currentLogPage + 1) * logsPerPage >= total; - } + // Main search function for structured logs + // Note: Structured logging functions are defined in admin.js which loads below: + // - searchStructuredLogs() - Search logs with filters + // - displayLogResults() - Display log table + // - showCorrelationTrace() - Show correlation trace modal + // - displayCorrelationTrace() - Display trace results + // - Helper functions: getLevelClass(), formatDuration(), getDurationClass() + // + // Keeping all structured logging UI logic centralized in admin.js to avoid + // duplication and maintenance issues. admin.js loads last and provides the + // definitive implementations. + + // Note: showSecurityEvents, showAuditTrail, showPerformanceMetrics, + // updateLogStats, and updateLogCount are also defined in admin.js which + // is loaded below and overrides any inline definitions. + // Keeping functions centralized in admin.js to avoid duplication and maintenance issues. function toggleLogStream() { const button = document.getElementById("stream-toggle"); @@ -11303,7 +11206,7 @@

Breakdown by Type:

try { // Use the same auth approach as other admin endpoints const headers = {}; - const token = localStorage.getItem("token"); + const token = getAuthToken(); if (token) { headers["Authorization"] = `Bearer ${token}`; } @@ -11354,7 +11257,7 @@

Breakdown by Type:

async function showLogFiles() { try { const headers = {}; - const token = localStorage.getItem("token"); + const token = getAuthToken(); if (token) { headers["Authorization"] = `Bearer ${token}`; } @@ -11422,7 +11325,7 @@

Available Log Files

async function downloadLogFile(filename) { try { const headers = {}; - const token = localStorage.getItem("token"); + const token = getAuthToken(); if (token) { headers["Authorization"] = `Bearer ${token}`; } @@ -11483,7 +11386,7 @@

Available Log Files

document.addEventListener("DOMContentLoaded", () => { const logFilters = [ "log-level-filter", - "log-entity-filter", + "log-component-filter", "log-search", ]; logFilters.forEach((id) => { From fca544964312cf54acf20a18ba41c12a0ed2585e Mon Sep 17 00:00:00 2001 From: Shoumi Date: Mon, 24 Nov 2025 17:25:58 +0530 Subject: [PATCH 10/34] fix bug Signed-off-by: Shoumi --- mcpgateway/services/structured_logger.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py index b43b98958..92814c253 100644 --- a/mcpgateway/services/structured_logger.py +++ b/mcpgateway/services/structured_logger.py @@ -81,11 +81,15 @@ def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: entry["timestamp"] = datetime.now(timezone.utc) # Add performance metrics if available - perf_tracker = get_performance_tracker() - if correlation_id and perf_tracker: - current_ops = perf_tracker.get_current_operations(correlation_id) - if current_ops: - entry["active_operations"] = len(current_ops) + try: + perf_tracker = get_performance_tracker() + if correlation_id and perf_tracker and hasattr(perf_tracker, 'get_current_operations'): + current_ops = perf_tracker.get_current_operations(correlation_id) + if current_ops: + entry["active_operations"] = len(current_ops) + except Exception: + # Silently skip if performance tracker is unavailable or method doesn't exist + pass # Add OpenTelemetry trace context if available try: @@ -257,7 +261,11 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No db.commit() except Exception as e: - logger.error(f"Failed to persist log entry to database: {e}") + logger.error(f"Failed to persist log entry to database: {e}", exc_info=True) + # Also print to console for immediate visibility + import traceback + print(f"ERROR persisting log to database: {e}") + traceback.print_exc() if db: db.rollback() From 1c5fa705ebccd030ee75f62152b71639d11c0d5e Mon Sep 17 00:00:00 2001 From: Shoumi Date: Mon, 24 Nov 2025 19:27:32 +0530 Subject: [PATCH 11/34] dropdown mismatch fix Signed-off-by: Shoumi --- mcpgateway/templates/admin.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 94cba2c52..660a33d34 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -657,10 +657,10 @@ class="mt-1 px-1.5 block w-full border-gray-300 rounded-md shadow-sm dark:bg-gray-700 dark:border-gray-600 text-gray-700 dark:text-gray-300" > + - From 88749964968ccbd341f399e335be0ef56339685f Mon Sep 17 00:00:00 2001 From: Shoumi Date: Tue, 25 Nov 2025 11:25:16 +0530 Subject: [PATCH 12/34] fixes for UI Signed-off-by: Shoumi --- mcpgateway/config.py | 1 + mcpgateway/main.py | 55 +++++- mcpgateway/routers/log_search.py | 10 + mcpgateway/services/log_aggregator.py | 271 ++++++++++++++++++-------- mcpgateway/static/admin.js | 2 +- mcpgateway/templates/admin.html | 12 +- 6 files changed, 258 insertions(+), 93 deletions(-) diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 1d521dcb6..4956a5ff2 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -803,6 +803,7 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]: # Metrics Aggregation Configuration metrics_aggregation_enabled: bool = Field(default=True, description="Enable automatic log aggregation into performance metrics") + metrics_aggregation_backfill_hours: int = Field(default=6, ge=0, le=168, description="Hours 
of structured logs to backfill into performance metrics on startup") metrics_aggregation_window_minutes: int = Field(default=5, description="Time window for metrics aggregation (minutes)") # Log Search Configuration diff --git a/mcpgateway/main.py b/mcpgateway/main.py index c14a20bd7..41eabbd32 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -27,7 +27,7 @@ # Standard import asyncio -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from datetime import datetime import json import os as _os # local alias to avoid collisions @@ -114,6 +114,7 @@ from mcpgateway.services.import_service import ImportError as ImportServiceError from mcpgateway.services.import_service import ImportService, ImportValidationError from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.log_aggregator import get_log_aggregator from mcpgateway.services.metrics import setup_metrics from mcpgateway.services.prompt_service import PromptError, PromptNameConflictError, PromptNotFoundError, PromptService from mcpgateway.services.resource_service import ResourceError, ResourceNotFoundError, ResourceService, ResourceURIConflictError @@ -407,6 +408,10 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: Exception: Any unhandled error that occurs during service initialisation or shutdown is re-raised to the caller. """ + aggregation_stop_event: Optional[asyncio.Event] = None + aggregation_loop_task: Optional[asyncio.Task] = None + aggregation_backfill_task: Optional[asyncio.Task] = None + # Initialize logging service FIRST to ensure all logging goes to dual output await logging_service.initialize() logger.info("Starting MCP Gateway services") @@ -462,6 +467,46 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: # Reconfigure uvicorn loggers after startup to capture access logs in dual output logging_service.configure_uvicorn_after_startup() + if settings.metrics_aggregation_enabled: + aggregation_stop_event = asyncio.Event() + log_aggregator = get_log_aggregator() + + async def run_log_backfill() -> None: + hours = getattr(settings, "metrics_aggregation_backfill_hours", 0) + if hours <= 0: + return + try: + await asyncio.to_thread(log_aggregator.backfill, hours) + logger.info("Log aggregation backfill completed for last %s hour(s)", hours) + except Exception as backfill_error: # pragma: no cover - defensive logging + logger.warning("Log aggregation backfill failed: %s", backfill_error) + + async def run_log_aggregation_loop() -> None: + interval_seconds = max(1, int(settings.metrics_aggregation_window_minutes)) * 60 + logger.info( + "Starting log aggregation loop (window=%s min)", + log_aggregator.aggregation_window_minutes, + ) + try: + while not aggregation_stop_event.is_set(): + try: + await asyncio.to_thread(log_aggregator.aggregate_all_components) + except Exception as agg_error: # pragma: no cover - defensive logging + logger.warning("Log aggregation loop iteration failed: %s", agg_error) + + try: + await asyncio.wait_for(aggregation_stop_event.wait(), timeout=interval_seconds) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + logger.debug("Log aggregation loop cancelled") + raise + finally: + logger.info("Log aggregation loop stopped") + + aggregation_backfill_task = asyncio.create_task(run_log_backfill()) + aggregation_loop_task = asyncio.create_task(run_log_aggregation_loop()) + yield except Exception as e: logger.error(f"Error during startup: {str(e)}") @@ -475,6 +520,14 @@ async def 
lifespan(_app: FastAPI) -> AsyncIterator[None]: raise SystemExit(1) raise finally: + if aggregation_stop_event is not None: + aggregation_stop_event.set() + for task in (aggregation_backfill_task, aggregation_loop_task): + if task: + task.cancel() + with suppress(asyncio.CancelledError): + await task + # Shutdown plugin manager if plugin_manager: try: diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py index 3cc2174ad..8c5e5cb52 100644 --- a/mcpgateway/routers/log_search.py +++ b/mcpgateway/routers/log_search.py @@ -22,6 +22,7 @@ from sqlalchemy.sql import func as sa_func # First-Party +from mcpgateway.config import settings from mcpgateway.db import ( AuditTrail, PerformanceMetric, @@ -30,6 +31,7 @@ get_db, ) from mcpgateway.middleware.rbac import require_permission, get_current_user_with_permissions +from mcpgateway.services.log_aggregator import get_log_aggregator logger = logging.getLogger(__name__) @@ -578,6 +580,14 @@ async def get_performance_metrics( stmt = stmt.order_by(desc(PerformanceMetric.window_start)) metrics = db.execute(stmt).scalars().all() + + if not metrics and settings.metrics_aggregation_enabled: + try: + aggregator = get_log_aggregator() + aggregator.backfill(hours=hours, db=db) + metrics = db.execute(stmt).scalars().all() + except Exception as agg_error: # pragma: no cover - defensive logging + logger.warning("On-demand metrics aggregation failed: %s", agg_error) return [ PerformanceMetricResponse( diff --git a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py index 6301ebb7f..987a93cfc 100644 --- a/mcpgateway/services/log_aggregator.py +++ b/mcpgateway/services/log_aggregator.py @@ -12,11 +12,12 @@ # Standard from datetime import datetime, timedelta, timezone import logging +import math import statistics -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple # Third-Party -from sqlalchemy import and_, func, select +from sqlalchemy import and_, select from sqlalchemy.orm import Session # First-Party @@ -36,8 +37,8 @@ def __init__(self): def aggregate_performance_metrics( self, - component: str, - operation: str, + component: Optional[str], + operation_type: Optional[str], window_start: Optional[datetime] = None, window_end: Optional[datetime] = None, db: Optional[Session] = None @@ -56,12 +57,10 @@ def aggregate_performance_metrics( """ if not self.enabled: return None + if not component or not operation_type: + return None - # Default time window - if window_end is None: - window_end = datetime.now(timezone.utc) - if window_start is None: - window_start = window_end - timedelta(minutes=self.aggregation_window_minutes) + window_start, window_end = self._resolve_window_bounds(window_start, window_end) should_close = False if db is None: @@ -73,10 +72,9 @@ def aggregate_performance_metrics( stmt = select(StructuredLogEntry).where( and_( StructuredLogEntry.component == component, - StructuredLogEntry.category == "performance", - StructuredLogEntry.resource_action == operation, + StructuredLogEntry.operation_type == operation_type, StructuredLogEntry.timestamp >= window_start, - StructuredLogEntry.timestamp <= window_end, + StructuredLogEntry.timestamp < window_end, StructuredLogEntry.duration_ms.isnot(None) ) ) @@ -87,74 +85,49 @@ def aggregate_performance_metrics( return None # Extract durations - durations = [r.duration_ms for r in results if r.duration_ms is not None] + durations = sorted(r.duration_ms for r in results if r.duration_ms is not None) if not 
durations: return None # Calculate statistics count = len(durations) - total_duration = sum(durations) - avg_duration = statistics.mean(durations) - min_duration = min(durations) - max_duration = max(durations) + avg_duration = statistics.fmean(durations) if hasattr(statistics, "fmean") else statistics.mean(durations) + min_duration = durations[0] + max_duration = durations[-1] # Calculate percentiles - sorted_durations = sorted(durations) - p50 = self._percentile(sorted_durations, 0.50) - p95 = self._percentile(sorted_durations, 0.95) - p99 = self._percentile(sorted_durations, 0.99) + p50 = self._percentile(durations, 0.50) + p95 = self._percentile(durations, 0.95) + p99 = self._percentile(durations, 0.99) # Count errors - error_count = sum(1 for r in results if r.error_message is not None) + error_count = self._calculate_error_count(results) error_rate = error_count / count if count > 0 else 0.0 - # Aggregate database metrics - db_queries = [r.database_query_count for r in results if r.database_query_count is not None] - total_db_queries = sum(db_queries) if db_queries else 0 - avg_db_queries = statistics.mean(db_queries) if db_queries else 0.0 - - db_durations = [r.database_query_duration_ms for r in results if r.database_query_duration_ms is not None] - total_db_duration = sum(db_durations) if db_durations else 0.0 - avg_db_duration = statistics.mean(db_durations) if db_durations else 0.0 - - # Aggregate cache metrics - cache_hits = sum(r.cache_hits for r in results if r.cache_hits is not None) - cache_misses = sum(r.cache_misses for r in results if r.cache_misses is not None) - cache_total = cache_hits + cache_misses - cache_hit_rate = cache_hits / cache_total if cache_total > 0 else 0.0 - - # Create performance metric - metric = PerformanceMetric( + metric = self._upsert_metric( component=component, - operation=operation, + operation_type=operation_type, window_start=window_start, window_end=window_end, request_count=count, error_count=error_count, error_rate=error_rate, - total_duration_ms=total_duration, avg_duration_ms=avg_duration, min_duration_ms=min_duration, max_duration_ms=max_duration, p50_duration_ms=p50, p95_duration_ms=p95, p99_duration_ms=p99, - total_database_queries=total_db_queries, - avg_database_queries=avg_db_queries, - total_database_duration_ms=total_db_duration, - avg_database_duration_ms=avg_db_duration, - cache_hits=cache_hits, - cache_misses=cache_misses, - cache_hit_rate=cache_hit_rate, + metric_metadata={ + "sample_size": count, + "generated_at": datetime.now(timezone.utc).isoformat(), + }, + db=db, ) - db.add(metric) - db.commit() - db.refresh(metric) - logger.info( - f"Aggregated performance metrics for {component}.{operation}: " + f"Aggregated performance metrics for {component}.{operation_type}: " f"{count} requests, {avg_duration:.2f}ms avg, {error_rate:.2%} error rate" ) @@ -195,21 +168,17 @@ def aggregate_all_components( should_close = True try: - # Get unique component/operation pairs - if window_end is None: - window_end = datetime.now(timezone.utc) - if window_start is None: - window_start = window_end - timedelta(minutes=self.aggregation_window_minutes) - + window_start, window_end = self._resolve_window_bounds(window_start, window_end) + stmt = select( StructuredLogEntry.component, - StructuredLogEntry.resource_action + StructuredLogEntry.operation_type ).where( and_( - StructuredLogEntry.category == "performance", StructuredLogEntry.timestamp >= window_start, - StructuredLogEntry.timestamp <= window_end, - 
StructuredLogEntry.duration_ms.isnot(None) + StructuredLogEntry.timestamp < window_end, + StructuredLogEntry.duration_ms.isnot(None), + StructuredLogEntry.operation_type.isnot(None) ) ).distinct() @@ -220,7 +189,7 @@ def aggregate_all_components( if component and operation: metric = self.aggregate_performance_metrics( component=component, - operation=operation, + operation_type=operation, window_start=window_start, window_end=window_end, db=db @@ -267,7 +236,7 @@ def get_recent_metrics( if component: stmt = stmt.where(PerformanceMetric.component == component) if operation: - stmt = stmt.where(PerformanceMetric.operation == operation) + stmt = stmt.where(PerformanceMetric.operation_type == operation) stmt = stmt.order_by(PerformanceMetric.window_start.desc()) @@ -282,7 +251,7 @@ def get_degradation_alerts( threshold_multiplier: float = 1.5, hours: int = 24, db: Optional[Session] = None - ) -> List[Dict[str, any]]: + ) -> List[Dict[str, Any]]: """Identify performance degradations by comparing recent vs baseline. Args: @@ -305,7 +274,7 @@ def get_degradation_alerts( # Get unique component/operation pairs stmt = select( PerformanceMetric.component, - PerformanceMetric.operation + PerformanceMetric.operation_type ).distinct() pairs = db.execute(stmt).all() @@ -316,7 +285,7 @@ def get_degradation_alerts( recent_stmt = select(PerformanceMetric).where( and_( PerformanceMetric.component == component, - PerformanceMetric.operation == operation, + PerformanceMetric.operation_type == operation, PerformanceMetric.window_start >= recent_cutoff ) ) @@ -326,7 +295,7 @@ def get_degradation_alerts( baseline_stmt = select(PerformanceMetric).where( and_( PerformanceMetric.component == component, - PerformanceMetric.operation == operation, + PerformanceMetric.operation_type == operation, PerformanceMetric.window_start >= baseline_cutoff, PerformanceMetric.window_start < recent_cutoff ) @@ -356,31 +325,161 @@ def get_degradation_alerts( if should_close: db.close() - @staticmethod - def _percentile(sorted_values: List[float], percentile: float) -> float: - """Calculate percentile from sorted values. - + def backfill(self, hours: int, db: Optional[Session] = None) -> int: + """Backfill metrics for a historical time range. 
+ Args: - sorted_values: Sorted list of values - percentile: Percentile to calculate (0.0 to 1.0) - + hours: Number of hours of history to aggregate + db: Optional shared database session + Returns: - Percentile value + Count of performance metric windows processed """ + if not self.enabled or hours <= 0: + return 0 + + window_minutes = self.aggregation_window_minutes + window_delta = timedelta(minutes=window_minutes) + total_windows = max(1, math.ceil((hours * 60) / window_minutes)) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + _, latest_end = self._resolve_window_bounds(None, None) + current_start = latest_end - (window_delta * total_windows) + processed = 0 + + while current_start < latest_end: + current_end = current_start + window_delta + created = self.aggregate_all_components( + window_start=current_start, + window_end=current_end, + db=db, + ) + if created: + processed += 1 + current_start = current_end + + return processed + + finally: + if should_close: + db.close() + + @staticmethod + def _percentile(sorted_values: List[float], percentile: float) -> float: + """Calculate percentile from sorted values.""" if not sorted_values: return 0.0 - + + if len(sorted_values) == 1: + return float(sorted_values[0]) + k = (len(sorted_values) - 1) * percentile - f = int(k) - c = f + 1 - - if c >= len(sorted_values): - return sorted_values[-1] - + f = math.floor(k) + c = math.ceil(k) + + if f == c: + return float(sorted_values[int(k)]) + d0 = sorted_values[f] * (c - k) d1 = sorted_values[c] * (k - f) - - return d0 + d1 + return float(d0 + d1) + + @staticmethod + def _calculate_error_count(entries: List[StructuredLogEntry]) -> int: + """Calculate error occurrences for a batch of log entries.""" + error_levels = {"ERROR", "CRITICAL"} + return sum(1 for entry in entries if (entry.level and entry.level.upper() in error_levels) or entry.error_details) + + def _resolve_window_bounds( + self, + window_start: Optional[datetime], + window_end: Optional[datetime], + ) -> Tuple[datetime, datetime]: + """Resolve and normalize aggregation window bounds.""" + window_delta = timedelta(minutes=self.aggregation_window_minutes) + + if window_end is None: + reference = datetime.now(timezone.utc) + else: + reference = window_end.astimezone(timezone.utc) + + reference = reference.replace(second=0, microsecond=0) + minutes_offset = reference.minute % self.aggregation_window_minutes + if window_end is None and minutes_offset: + reference = reference - timedelta(minutes=minutes_offset) + + resolved_end = reference if window_end is None else reference + + if window_start is None: + resolved_start = resolved_end - window_delta + else: + resolved_start = window_start.astimezone(timezone.utc) + + if resolved_end <= resolved_start: + resolved_start = resolved_end - window_delta + + return resolved_start, resolved_end + + def _upsert_metric( + self, + component: str, + operation_type: str, + window_start: datetime, + window_end: datetime, + request_count: int, + error_count: int, + error_rate: float, + avg_duration_ms: float, + min_duration_ms: float, + max_duration_ms: float, + p50_duration_ms: float, + p95_duration_ms: float, + p99_duration_ms: float, + metric_metadata: Optional[Dict[str, Any]], + db: Session, + ) -> PerformanceMetric: + """Create or update a performance metric window.""" + + existing_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation_type == operation_type, + 
PerformanceMetric.window_start == window_start, + PerformanceMetric.window_end == window_end, + ) + ) + + metric = db.execute(existing_stmt).scalar_one_or_none() + + if metric is None: + metric = PerformanceMetric( + component=component, + operation_type=operation_type, + window_start=window_start, + window_end=window_end, + window_duration_seconds=int((window_end - window_start).total_seconds()), + ) + db.add(metric) + + metric.request_count = request_count + metric.error_count = error_count + metric.error_rate = error_rate + metric.avg_duration_ms = avg_duration_ms + metric.min_duration_ms = min_duration_ms + metric.max_duration_ms = max_duration_ms + metric.p50_duration_ms = p50_duration_ms + metric.p95_duration_ms = p95_duration_ms + metric.p99_duration_ms = p99_duration_ms + metric.metric_metadata = metric_metadata + + db.commit() + db.refresh(metric) + return metric # Global log aggregator instance diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index 914db4d77..c8dce059c 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -24025,7 +24025,7 @@ async function searchStructuredLogs() { } } if (levelFilter && levelFilter !== '') { - searchRequest.level = [levelFilter.toUpperCase()]; + searchRequest.level = [levelFilter]; } if (componentFilter && componentFilter !== '') { searchRequest.component = [componentFilter]; diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 660a33d34..c569a0b0a 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -637,13 +637,14 @@ @@ -655,6 +656,7 @@ `; return; @@ -25130,6 +25205,7 @@ window.showCorrelationTrace = showCorrelationTrace; window.showSecurityEvents = showSecurityEvents; window.showAuditTrail = showAuditTrail; window.showPerformanceMetrics = showPerformanceMetrics; +window.handlePerformanceAggregationChange = handlePerformanceAggregationChange; window.previousLogPage = previousLogPage; window.nextLogPage = nextLogPage; window.showLogDetails = showLogDetails; diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index c569a0b0a..1442a98f8 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -627,8 +627,28 @@ + + + -
+
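Worth noting about the aggregator above: _percentile() interpolates linearly between closest ranks rather than snapping to the nearest sample, so small windows still get smooth p95/p99 values. A standalone check of the same formula, restated here for illustration rather than imported from the gateway code:

import math

def percentile(sorted_values, p):
    # Linear interpolation between closest ranks, matching the behaviour of
    # LogAggregator._percentile in the patch above.
    if not sorted_values:
        return 0.0
    if len(sorted_values) == 1:
        return float(sorted_values[0])
    k = (len(sorted_values) - 1) * p
    f, c = math.floor(k), math.ceil(k)
    if f == c:
        return float(sorted_values[int(k)])
    return float(sorted_values[f] * (c - k) + sorted_values[c] * (k - f))

durations = [10.0, 12.0, 15.0, 40.0, 220.0]  # illustrative window samples
# p50 lands exactly on the middle sample; p95 interpolates between the last
# two samples: k = 4 * 0.95 = 3.8, giving 40.0 * 0.2 + 220.0 * 0.8 = 184.0.
assert percentile(durations, 0.50) == 15.0
assert math.isclose(percentile(durations, 0.95), 184.0)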
From 4683a981748622fd036203731198a9a9486049ba Mon Sep 17 00:00:00 2001 From: Shoumi Date: Wed, 26 Nov 2025 18:12:19 +0530 Subject: [PATCH 15/34] flake8 fixes Signed-off-by: Shoumi --- mcpgateway/admin.py | 50 +-- ...6f7g8h9i0_add_structured_logging_tables.py | 337 +++++++++--------- mcpgateway/auth.py | 16 +- mcpgateway/db.py | 138 +++---- mcpgateway/main.py | 6 +- mcpgateway/middleware/auth_middleware.py | 12 +- mcpgateway/middleware/correlation_id.py | 11 +- .../middleware/request_logging_middleware.py | 51 ++- mcpgateway/routers/log_search.py | 198 +++++----- mcpgateway/services/a2a_service.py | 43 +-- mcpgateway/services/audit_trail_service.py | 133 +++---- mcpgateway/services/gateway_service.py | 46 +-- mcpgateway/services/log_aggregator.py | 206 +++++------ mcpgateway/services/performance_tracker.py | 154 ++++---- mcpgateway/services/prompt_service.py | 46 +-- mcpgateway/services/resource_service.py | 44 +-- mcpgateway/services/security_logger.py | 219 +++++------- mcpgateway/services/server_service.py | 36 +- mcpgateway/services/structured_logger.py | 119 +++---- mcpgateway/services/tool_service.py | 148 +++----- mcpgateway/utils/correlation_id.py | 4 +- 21 files changed, 875 insertions(+), 1142 deletions(-) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 55f7009bc..0befe2593 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -12573,7 +12573,7 @@ async def list_plugins( # Log plugin marketplace browsing activity structured_logger.info( - f"User browsed plugin marketplace", + "User browsed plugin marketplace", user_id=str(user.id), user_email=get_user_email(user), component="plugin_marketplace", @@ -12588,9 +12588,9 @@ async def list_plugins( "results_count": len(plugins), "enabled_count": enabled_count, "disabled_count": disabled_count, - "has_filters": any([search, mode, hook, tag]) + "has_filters": any([search, mode, hook, tag]), }, - db=db + db=db, ) return PluginListResponse(plugins=plugins, total=len(plugins), enabled_count=enabled_count, disabled_count=disabled_count) @@ -12598,13 +12598,7 @@ async def list_plugins( except Exception as e: LOGGER.error(f"Error listing plugins: {e}") structured_logger.error( - f"Failed to list plugins in marketplace", - user_id=str(user.id), - user_email=get_user_email(user), - error=e, - component="plugin_marketplace", - category="business_logic", - db=db + "Failed to list plugins in marketplace", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db ) raise HTTPException(status_code=500, detail=str(e)) @@ -12641,7 +12635,7 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user # Log marketplace analytics access structured_logger.info( - f"User accessed plugin marketplace statistics", + "User accessed plugin marketplace statistics", user_id=str(user.id), user_email=get_user_email(user), component="plugin_marketplace", @@ -12654,9 +12648,9 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user "disabled_plugins": stats.get("disabled_plugins", 0), "hooks_count": len(stats.get("plugins_by_hook", {})), "tags_count": len(stats.get("plugins_by_tag", {})), - "authors_count": len(stats.get("plugins_by_author", {})) + "authors_count": len(stats.get("plugins_by_author", {})), }, - db=db + db=db, ) return PluginStatsResponse(**stats) @@ -12664,13 +12658,7 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user except Exception as e: LOGGER.error(f"Error getting plugin 
statistics: {e}") structured_logger.error( - f"Failed to get plugin marketplace statistics", - user_id=str(user.id), - user_email=get_user_email(user), - error=e, - component="plugin_marketplace", - category="business_logic", - db=db + "Failed to get plugin marketplace statistics", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db ) raise HTTPException(status_code=500, detail=str(e)) @@ -12715,7 +12703,7 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends( component="plugin_marketplace", category="business_logic", custom_fields={"plugin_name": name, "action": "view_details"}, - db=db + db=db, ) raise HTTPException(status_code=404, detail=f"Plugin '{name}' not found") @@ -12736,20 +12724,14 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends( "plugin_status": plugin.get("status"), "plugin_mode": plugin.get("mode"), "plugin_hooks": plugin.get("hooks", []), - "plugin_tags": plugin.get("tags", []) + "plugin_tags": plugin.get("tags", []), }, - db=db + db=db, ) # Create audit trail for plugin access audit_service.log_audit( - user_id=str(user.id), - user_email=get_user_email(user), - resource_type="plugin", - resource_id=name, - action="view", - description=f"Viewed plugin '{name}' details in marketplace", - db=db + user_id=str(user.id), user_email=get_user_email(user), resource_type="plugin", resource_id=name, action="view", description=f"Viewed plugin '{name}' details in marketplace", db=db ) return PluginDetail(**plugin) @@ -12759,13 +12741,7 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends( except Exception as e: LOGGER.error(f"Error getting plugin details: {e}") structured_logger.error( - f"Failed to get plugin details: '{name}'", - user_id=str(user.id), - user_email=get_user_email(user), - error=e, - component="plugin_marketplace", - category="business_logic", - db=db + f"Failed to get plugin details: '{name}'", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db ) raise HTTPException(status_code=500, detail=str(e)) diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py index ec2840255..951840ba2 100644 --- a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -5,12 +5,13 @@ Create Date: 2025-01-15 12:00:00.000000 """ + from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
-revision = 'k5e6f7g8h9i0' -down_revision = ('f3a3a3d901b8', '191a2def08d7') +revision = "k5e6f7g8h9i0" +down_revision = "z1a2b3c4d5e6" branch_labels = None depends_on = None @@ -19,188 +20,188 @@ def upgrade() -> None: """Add structured logging tables.""" # Create structured_log_entries table op.create_table( - 'structured_log_entries', - sa.Column('id', sa.String(36), nullable=False), - sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), - sa.Column('correlation_id', sa.String(64), nullable=True), - sa.Column('request_id', sa.String(64), nullable=True), - sa.Column('level', sa.String(20), nullable=False), - sa.Column('component', sa.String(100), nullable=False), - sa.Column('message', sa.Text(), nullable=False), - sa.Column('logger', sa.String(255), nullable=True), - sa.Column('user_id', sa.String(255), nullable=True), - sa.Column('user_email', sa.String(255), nullable=True), - sa.Column('client_ip', sa.String(45), nullable=True), - sa.Column('user_agent', sa.Text(), nullable=True), - sa.Column('request_path', sa.String(500), nullable=True), - sa.Column('request_method', sa.String(10), nullable=True), - sa.Column('duration_ms', sa.Float(), nullable=True), - sa.Column('operation_type', sa.String(100), nullable=True), - sa.Column('is_security_event', sa.Boolean(), nullable=False, server_default='0'), - sa.Column('security_severity', sa.String(20), nullable=True), - sa.Column('threat_indicators', sa.JSON(), nullable=True), - sa.Column('context', sa.JSON(), nullable=True), - sa.Column('error_details', sa.JSON(), nullable=True), - sa.Column('performance_metrics', sa.JSON(), nullable=True), - sa.Column('hostname', sa.String(255), nullable=False), - sa.Column('process_id', sa.Integer(), nullable=False), - sa.Column('thread_id', sa.Integer(), nullable=True), - sa.Column('version', sa.String(50), nullable=False), - sa.Column('environment', sa.String(50), nullable=False, server_default='production'), - sa.Column('trace_id', sa.String(32), nullable=True), - sa.Column('span_id', sa.String(16), nullable=True), - sa.PrimaryKeyConstraint('id') + "structured_log_entries", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("correlation_id", sa.String(64), nullable=True), + sa.Column("request_id", sa.String(64), nullable=True), + sa.Column("level", sa.String(20), nullable=False), + sa.Column("component", sa.String(100), nullable=False), + sa.Column("message", sa.Text(), nullable=False), + sa.Column("logger", sa.String(255), nullable=True), + sa.Column("user_id", sa.String(255), nullable=True), + sa.Column("user_email", sa.String(255), nullable=True), + sa.Column("client_ip", sa.String(45), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("request_path", sa.String(500), nullable=True), + sa.Column("request_method", sa.String(10), nullable=True), + sa.Column("duration_ms", sa.Float(), nullable=True), + sa.Column("operation_type", sa.String(100), nullable=True), + sa.Column("is_security_event", sa.Boolean(), nullable=False, server_default="0"), + sa.Column("security_severity", sa.String(20), nullable=True), + sa.Column("threat_indicators", sa.JSON(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), + sa.Column("error_details", sa.JSON(), nullable=True), + sa.Column("performance_metrics", sa.JSON(), nullable=True), + sa.Column("hostname", sa.String(255), nullable=False), + sa.Column("process_id", sa.Integer(), nullable=False), + sa.Column("thread_id", sa.Integer(), 
nullable=True), + sa.Column("version", sa.String(50), nullable=False), + sa.Column("environment", sa.String(50), nullable=False, server_default="production"), + sa.Column("trace_id", sa.String(32), nullable=True), + sa.Column("span_id", sa.String(16), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - + # Create indexes for structured_log_entries - op.create_index('ix_structured_log_entries_timestamp', 'structured_log_entries', ['timestamp'], unique=False) - op.create_index('ix_structured_log_entries_level', 'structured_log_entries', ['level'], unique=False) - op.create_index('ix_structured_log_entries_component', 'structured_log_entries', ['component'], unique=False) - op.create_index('ix_structured_log_entries_correlation_id', 'structured_log_entries', ['correlation_id'], unique=False) - op.create_index('ix_structured_log_entries_request_id', 'structured_log_entries', ['request_id'], unique=False) - op.create_index('ix_structured_log_entries_user_id', 'structured_log_entries', ['user_id'], unique=False) - op.create_index('ix_structured_log_entries_user_email', 'structured_log_entries', ['user_email'], unique=False) - op.create_index('ix_structured_log_entries_operation_type', 'structured_log_entries', ['operation_type'], unique=False) - op.create_index('ix_structured_log_entries_is_security_event', 'structured_log_entries', ['is_security_event'], unique=False) - op.create_index('ix_structured_log_entries_security_severity', 'structured_log_entries', ['security_severity'], unique=False) - op.create_index('ix_structured_log_entries_trace_id', 'structured_log_entries', ['trace_id'], unique=False) - + op.create_index("ix_structured_log_entries_timestamp", "structured_log_entries", ["timestamp"], unique=False) + op.create_index("ix_structured_log_entries_level", "structured_log_entries", ["level"], unique=False) + op.create_index("ix_structured_log_entries_component", "structured_log_entries", ["component"], unique=False) + op.create_index("ix_structured_log_entries_correlation_id", "structured_log_entries", ["correlation_id"], unique=False) + op.create_index("ix_structured_log_entries_request_id", "structured_log_entries", ["request_id"], unique=False) + op.create_index("ix_structured_log_entries_user_id", "structured_log_entries", ["user_id"], unique=False) + op.create_index("ix_structured_log_entries_user_email", "structured_log_entries", ["user_email"], unique=False) + op.create_index("ix_structured_log_entries_operation_type", "structured_log_entries", ["operation_type"], unique=False) + op.create_index("ix_structured_log_entries_is_security_event", "structured_log_entries", ["is_security_event"], unique=False) + op.create_index("ix_structured_log_entries_security_severity", "structured_log_entries", ["security_severity"], unique=False) + op.create_index("ix_structured_log_entries_trace_id", "structured_log_entries", ["trace_id"], unique=False) + # Composite indexes matching db.py - op.create_index('idx_log_correlation_time', 'structured_log_entries', ['correlation_id', 'timestamp'], unique=False) - op.create_index('idx_log_user_time', 'structured_log_entries', ['user_id', 'timestamp'], unique=False) - op.create_index('idx_log_level_time', 'structured_log_entries', ['level', 'timestamp'], unique=False) - op.create_index('idx_log_component_time', 'structured_log_entries', ['component', 'timestamp'], unique=False) - op.create_index('idx_log_security', 'structured_log_entries', ['is_security_event', 'security_severity', 'timestamp'], unique=False) - op.create_index('idx_log_operation', 
'structured_log_entries', ['operation_type', 'timestamp'], unique=False) - op.create_index('idx_log_trace', 'structured_log_entries', ['trace_id', 'timestamp'], unique=False) - + op.create_index("idx_log_correlation_time", "structured_log_entries", ["correlation_id", "timestamp"], unique=False) + op.create_index("idx_log_user_time", "structured_log_entries", ["user_id", "timestamp"], unique=False) + op.create_index("idx_log_level_time", "structured_log_entries", ["level", "timestamp"], unique=False) + op.create_index("idx_log_component_time", "structured_log_entries", ["component", "timestamp"], unique=False) + op.create_index("idx_log_security", "structured_log_entries", ["is_security_event", "security_severity", "timestamp"], unique=False) + op.create_index("idx_log_operation", "structured_log_entries", ["operation_type", "timestamp"], unique=False) + op.create_index("idx_log_trace", "structured_log_entries", ["trace_id", "timestamp"], unique=False) + # Create performance_metrics table op.create_table( - 'performance_metrics', - sa.Column('id', sa.String(36), nullable=False), - sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), - sa.Column('operation_type', sa.String(100), nullable=False), - sa.Column('component', sa.String(100), nullable=False), - sa.Column('request_count', sa.Integer(), nullable=False, server_default='0'), - sa.Column('error_count', sa.Integer(), nullable=False, server_default='0'), - sa.Column('error_rate', sa.Float(), nullable=False, server_default='0.0'), - sa.Column('avg_duration_ms', sa.Float(), nullable=False), - sa.Column('min_duration_ms', sa.Float(), nullable=False), - sa.Column('max_duration_ms', sa.Float(), nullable=False), - sa.Column('p50_duration_ms', sa.Float(), nullable=False), - sa.Column('p95_duration_ms', sa.Float(), nullable=False), - sa.Column('p99_duration_ms', sa.Float(), nullable=False), - sa.Column('window_start', sa.DateTime(timezone=True), nullable=False), - sa.Column('window_end', sa.DateTime(timezone=True), nullable=False), - sa.Column('window_duration_seconds', sa.Integer(), nullable=False), - sa.Column('metric_metadata', sa.JSON(), nullable=True), - sa.PrimaryKeyConstraint('id') + "performance_metrics", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("operation_type", sa.String(100), nullable=False), + sa.Column("component", sa.String(100), nullable=False), + sa.Column("request_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("error_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("error_rate", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("avg_duration_ms", sa.Float(), nullable=False), + sa.Column("min_duration_ms", sa.Float(), nullable=False), + sa.Column("max_duration_ms", sa.Float(), nullable=False), + sa.Column("p50_duration_ms", sa.Float(), nullable=False), + sa.Column("p95_duration_ms", sa.Float(), nullable=False), + sa.Column("p99_duration_ms", sa.Float(), nullable=False), + sa.Column("window_start", sa.DateTime(timezone=True), nullable=False), + sa.Column("window_end", sa.DateTime(timezone=True), nullable=False), + sa.Column("window_duration_seconds", sa.Integer(), nullable=False), + sa.Column("metric_metadata", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - + # Create indexes for performance_metrics - op.create_index('ix_performance_metrics_timestamp', 'performance_metrics', ['timestamp'], unique=False) - 
op.create_index('ix_performance_metrics_component', 'performance_metrics', ['component'], unique=False) - op.create_index('ix_performance_metrics_operation_type', 'performance_metrics', ['operation_type'], unique=False) - op.create_index('ix_performance_metrics_window_start', 'performance_metrics', ['window_start'], unique=False) - op.create_index('idx_perf_operation_time', 'performance_metrics', ['operation_type', 'window_start'], unique=False) - op.create_index('idx_perf_component_time', 'performance_metrics', ['component', 'window_start'], unique=False) - op.create_index('idx_perf_window', 'performance_metrics', ['window_start', 'window_end'], unique=False) - + op.create_index("ix_performance_metrics_timestamp", "performance_metrics", ["timestamp"], unique=False) + op.create_index("ix_performance_metrics_component", "performance_metrics", ["component"], unique=False) + op.create_index("ix_performance_metrics_operation_type", "performance_metrics", ["operation_type"], unique=False) + op.create_index("ix_performance_metrics_window_start", "performance_metrics", ["window_start"], unique=False) + op.create_index("idx_perf_operation_time", "performance_metrics", ["operation_type", "window_start"], unique=False) + op.create_index("idx_perf_component_time", "performance_metrics", ["component", "window_start"], unique=False) + op.create_index("idx_perf_window", "performance_metrics", ["window_start", "window_end"], unique=False) + # Create security_events table op.create_table( - 'security_events', - sa.Column('id', sa.String(36), nullable=False), - sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), - sa.Column('detected_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('event_type', sa.String(100), nullable=False), - sa.Column('severity', sa.String(20), nullable=False), - sa.Column('category', sa.String(100), nullable=False), - sa.Column('user_id', sa.String(255), nullable=True), - sa.Column('user_email', sa.String(255), nullable=True), - sa.Column('client_ip', sa.String(45), nullable=False), - sa.Column('user_agent', sa.Text(), nullable=True), - sa.Column('description', sa.Text(), nullable=False), - sa.Column('action_taken', sa.String(100), nullable=True), - sa.Column('threat_score', sa.Float(), nullable=False, server_default='0.0'), - sa.Column('threat_indicators', sa.JSON(), nullable=True), - sa.Column('failed_attempts_count', sa.Integer(), nullable=False, server_default='0'), - sa.Column('context', sa.JSON(), nullable=True), - sa.Column('correlation_id', sa.String(255), nullable=True), - sa.Column('resolved', sa.Boolean(), nullable=False, server_default='false'), - sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('resolved_by', sa.String(255), nullable=True), - sa.Column('resolution_notes', sa.Text(), nullable=True), - sa.Column('alert_sent', sa.Boolean(), nullable=False, server_default='false'), - sa.Column('alert_sent_at', sa.DateTime(timezone=True), nullable=True), - sa.PrimaryKeyConstraint('id') + "security_events", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("detected_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("event_type", sa.String(100), nullable=False), + sa.Column("severity", sa.String(20), nullable=False), + sa.Column("category", sa.String(100), nullable=False), + sa.Column("user_id", sa.String(255), nullable=True), + sa.Column("user_email", sa.String(255), nullable=True), + sa.Column("client_ip", sa.String(45), 
nullable=False), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=False), + sa.Column("action_taken", sa.String(100), nullable=True), + sa.Column("threat_score", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("threat_indicators", sa.JSON(), nullable=True), + sa.Column("failed_attempts_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("context", sa.JSON(), nullable=True), + sa.Column("correlation_id", sa.String(255), nullable=True), + sa.Column("resolved", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("resolved_by", sa.String(255), nullable=True), + sa.Column("resolution_notes", sa.Text(), nullable=True), + sa.Column("alert_sent", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("alert_sent_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - + # Create indexes for security_events - op.create_index('ix_security_events_timestamp', 'security_events', ['timestamp'], unique=False) - op.create_index('ix_security_events_detected_at', 'security_events', ['detected_at'], unique=False) - op.create_index('ix_security_events_correlation_id', 'security_events', ['correlation_id'], unique=False) - op.create_index('ix_security_events_event_type', 'security_events', ['event_type'], unique=False) - op.create_index('ix_security_events_severity', 'security_events', ['severity'], unique=False) - op.create_index('ix_security_events_category', 'security_events', ['category'], unique=False) - op.create_index('ix_security_events_user_id', 'security_events', ['user_id'], unique=False) - op.create_index('ix_security_events_user_email', 'security_events', ['user_email'], unique=False) - op.create_index('ix_security_events_client_ip', 'security_events', ['client_ip'], unique=False) - op.create_index('idx_security_event_time', 'security_events', ['event_type', 'timestamp'], unique=False) - op.create_index('idx_security_severity_time', 'security_events', ['severity', 'timestamp'], unique=False) - op.create_index('idx_security_user_time', 'security_events', ['user_id', 'timestamp'], unique=False) - + op.create_index("ix_security_events_timestamp", "security_events", ["timestamp"], unique=False) + op.create_index("ix_security_events_detected_at", "security_events", ["detected_at"], unique=False) + op.create_index("ix_security_events_correlation_id", "security_events", ["correlation_id"], unique=False) + op.create_index("ix_security_events_event_type", "security_events", ["event_type"], unique=False) + op.create_index("ix_security_events_severity", "security_events", ["severity"], unique=False) + op.create_index("ix_security_events_category", "security_events", ["category"], unique=False) + op.create_index("ix_security_events_user_id", "security_events", ["user_id"], unique=False) + op.create_index("ix_security_events_user_email", "security_events", ["user_email"], unique=False) + op.create_index("ix_security_events_client_ip", "security_events", ["client_ip"], unique=False) + op.create_index("idx_security_event_time", "security_events", ["event_type", "timestamp"], unique=False) + op.create_index("idx_security_severity_time", "security_events", ["severity", "timestamp"], unique=False) + op.create_index("idx_security_user_time", "security_events", ["user_id", "timestamp"], unique=False) + # Create audit_trails table op.create_table( - 'audit_trails', - sa.Column('id', sa.String(36), 
nullable=False), - sa.Column('timestamp', sa.DateTime(timezone=True), nullable=False), - sa.Column('correlation_id', sa.String(64), nullable=True), - sa.Column('request_id', sa.String(64), nullable=True), - sa.Column('action', sa.String(100), nullable=False), - sa.Column('resource_type', sa.String(100), nullable=False), - sa.Column('resource_id', sa.String(255), nullable=False), - sa.Column('resource_name', sa.String(500), nullable=True), - sa.Column('user_id', sa.String(255), nullable=False), - sa.Column('user_email', sa.String(255), nullable=True), - sa.Column('team_id', sa.String(36), nullable=True), - sa.Column('client_ip', sa.String(45), nullable=True), - sa.Column('user_agent', sa.Text(), nullable=True), - sa.Column('request_path', sa.String(500), nullable=True), - sa.Column('request_method', sa.String(10), nullable=True), - sa.Column('old_values', sa.JSON(), nullable=True), - sa.Column('new_values', sa.JSON(), nullable=True), - sa.Column('changes', sa.JSON(), nullable=True), - sa.Column('data_classification', sa.String(50), nullable=True), - sa.Column('requires_review', sa.Boolean(), nullable=False, server_default='0'), - sa.Column('success', sa.Boolean(), nullable=False), - sa.Column('error_message', sa.Text(), nullable=True), - sa.Column('context', sa.JSON(), nullable=True), - sa.PrimaryKeyConstraint('id') + "audit_trails", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("correlation_id", sa.String(64), nullable=True), + sa.Column("request_id", sa.String(64), nullable=True), + sa.Column("action", sa.String(100), nullable=False), + sa.Column("resource_type", sa.String(100), nullable=False), + sa.Column("resource_id", sa.String(255), nullable=False), + sa.Column("resource_name", sa.String(500), nullable=True), + sa.Column("user_id", sa.String(255), nullable=False), + sa.Column("user_email", sa.String(255), nullable=True), + sa.Column("team_id", sa.String(36), nullable=True), + sa.Column("client_ip", sa.String(45), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("request_path", sa.String(500), nullable=True), + sa.Column("request_method", sa.String(10), nullable=True), + sa.Column("old_values", sa.JSON(), nullable=True), + sa.Column("new_values", sa.JSON(), nullable=True), + sa.Column("changes", sa.JSON(), nullable=True), + sa.Column("data_classification", sa.String(50), nullable=True), + sa.Column("requires_review", sa.Boolean(), nullable=False, server_default="0"), + sa.Column("success", sa.Boolean(), nullable=False), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - + # Create indexes for audit_trails - op.create_index('ix_audit_trails_timestamp', 'audit_trails', ['timestamp'], unique=False) - op.create_index('ix_audit_trails_correlation_id', 'audit_trails', ['correlation_id'], unique=False) - op.create_index('ix_audit_trails_request_id', 'audit_trails', ['request_id'], unique=False) - op.create_index('ix_audit_trails_action', 'audit_trails', ['action'], unique=False) - op.create_index('ix_audit_trails_resource_type', 'audit_trails', ['resource_type'], unique=False) - op.create_index('ix_audit_trails_resource_id', 'audit_trails', ['resource_id'], unique=False) - op.create_index('ix_audit_trails_user_id', 'audit_trails', ['user_id'], unique=False) - op.create_index('ix_audit_trails_user_email', 'audit_trails', ['user_email'], unique=False) - 
op.create_index('ix_audit_trails_team_id', 'audit_trails', ['team_id'], unique=False) - op.create_index('ix_audit_trails_data_classification', 'audit_trails', ['data_classification'], unique=False) - op.create_index('ix_audit_trails_requires_review', 'audit_trails', ['requires_review'], unique=False) - op.create_index('ix_audit_trails_success', 'audit_trails', ['success'], unique=False) - op.create_index('idx_audit_action_time', 'audit_trails', ['action', 'timestamp'], unique=False) - op.create_index('idx_audit_resource_time', 'audit_trails', ['resource_type', 'resource_id', 'timestamp'], unique=False) - op.create_index('idx_audit_user_time', 'audit_trails', ['user_id', 'timestamp'], unique=False) - op.create_index('idx_audit_classification', 'audit_trails', ['data_classification', 'timestamp'], unique=False) - op.create_index('idx_audit_review', 'audit_trails', ['requires_review', 'timestamp'], unique=False) + op.create_index("ix_audit_trails_timestamp", "audit_trails", ["timestamp"], unique=False) + op.create_index("ix_audit_trails_correlation_id", "audit_trails", ["correlation_id"], unique=False) + op.create_index("ix_audit_trails_request_id", "audit_trails", ["request_id"], unique=False) + op.create_index("ix_audit_trails_action", "audit_trails", ["action"], unique=False) + op.create_index("ix_audit_trails_resource_type", "audit_trails", ["resource_type"], unique=False) + op.create_index("ix_audit_trails_resource_id", "audit_trails", ["resource_id"], unique=False) + op.create_index("ix_audit_trails_user_id", "audit_trails", ["user_id"], unique=False) + op.create_index("ix_audit_trails_user_email", "audit_trails", ["user_email"], unique=False) + op.create_index("ix_audit_trails_team_id", "audit_trails", ["team_id"], unique=False) + op.create_index("ix_audit_trails_data_classification", "audit_trails", ["data_classification"], unique=False) + op.create_index("ix_audit_trails_requires_review", "audit_trails", ["requires_review"], unique=False) + op.create_index("ix_audit_trails_success", "audit_trails", ["success"], unique=False) + op.create_index("idx_audit_action_time", "audit_trails", ["action", "timestamp"], unique=False) + op.create_index("idx_audit_resource_time", "audit_trails", ["resource_type", "resource_id", "timestamp"], unique=False) + op.create_index("idx_audit_user_time", "audit_trails", ["user_id", "timestamp"], unique=False) + op.create_index("idx_audit_classification", "audit_trails", ["data_classification", "timestamp"], unique=False) + op.create_index("idx_audit_review", "audit_trails", ["requires_review", "timestamp"], unique=False) def downgrade() -> None: """Remove structured logging tables.""" - op.drop_table('audit_trails') - op.drop_table('security_events') - op.drop_table('performance_metrics') - op.drop_table('structured_log_entries') + op.drop_table("audit_trails") + op.drop_table("security_events") + op.drop_table("performance_metrics") + op.drop_table("structured_log_entries") diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index edca68967..c0300a124 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -43,7 +43,7 @@ def _log_auth_event( auth_success: bool = False, security_event: Optional[str] = None, security_severity: str = "low", - **extra_context + **extra_context, ) -> None: """Log authentication event with structured context and request_id. 
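# Illustrative call site for the helper above (not part of the patch; names and
# values are hypothetical, assuming only the signature shown in this hunk):
#
#   _log_auth_event(
#       logger,
#       "JWT authentication succeeded",
#       user_id="user@example.com",
#       auth_method="jwt",
#       auth_success=True,
#       security_severity="low",
#   )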
@@ -66,17 +66,17 @@ def _log_auth_event( # Build structured log record extra = { - 'request_id': request_id, - 'entity_type': 'auth', - 'auth_success': auth_success, - 'security_event': security_event or 'authentication', - 'security_severity': security_severity, + "request_id": request_id, + "entity_type": "auth", + "auth_success": auth_success, + "security_event": security_event or "authentication", + "security_severity": security_severity, } if user_id: - extra['user_id'] = user_id + extra["user_id"] = user_id if auth_method: - extra['auth_method'] = auth_method + extra["auth_method"] = auth_method # Add any additional context extra.update(extra_context) diff --git a/mcpgateway/db.py b/mcpgateway/db.py index ee5303350..f548c8af4 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -3804,29 +3804,29 @@ def init_db(): class StructuredLogEntry(Base): """Structured log entry for comprehensive logging and analysis. - + Stores all log entries with correlation IDs, performance metrics, and security context for advanced search and analytics. """ - + __tablename__ = "structured_log_entries" - + # Primary key id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) - + # Timestamps timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) - + # Correlation and request tracking correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) - + # Log metadata level: Mapped[str] = mapped_column(String(20), nullable=False, index=True) # DEBUG, INFO, WARNING, ERROR, CRITICAL component: Mapped[str] = mapped_column(String(100), nullable=False, index=True) message: Mapped[str] = mapped_column(Text, nullable=False) logger: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - + # User and request context user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) @@ -3834,68 +3834,68 @@ class StructuredLogEntry(Base): user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) - + # Performance data duration_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True) operation_type: Mapped[Optional[str]] = mapped_column(String(100), index=True, nullable=True) - + # Security context is_security_event: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) security_severity: Mapped[Optional[str]] = mapped_column(String(20), index=True, nullable=True) # LOW, MEDIUM, HIGH, CRITICAL threat_indicators: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) - + # Structured context data context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) error_details: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) performance_metrics: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) - + # System information hostname: Mapped[str] = mapped_column(String(255), nullable=False) process_id: Mapped[int] = mapped_column(Integer, nullable=False) thread_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) version: Mapped[str] = mapped_column(String(50), nullable=False) environment: Mapped[str] = 
mapped_column(String(50), nullable=False, default="production") - + # OpenTelemetry trace context trace_id: Mapped[Optional[str]] = mapped_column(String(32), index=True, nullable=True) span_id: Mapped[Optional[str]] = mapped_column(String(16), nullable=True) - + # Indexes for performance __table_args__ = ( - Index('idx_log_correlation_time', 'correlation_id', 'timestamp'), - Index('idx_log_user_time', 'user_id', 'timestamp'), - Index('idx_log_level_time', 'level', 'timestamp'), - Index('idx_log_component_time', 'component', 'timestamp'), - Index('idx_log_security', 'is_security_event', 'security_severity', 'timestamp'), - Index('idx_log_operation', 'operation_type', 'timestamp'), - Index('idx_log_trace', 'trace_id', 'timestamp'), + Index("idx_log_correlation_time", "correlation_id", "timestamp"), + Index("idx_log_user_time", "user_id", "timestamp"), + Index("idx_log_level_time", "level", "timestamp"), + Index("idx_log_component_time", "component", "timestamp"), + Index("idx_log_security", "is_security_event", "security_severity", "timestamp"), + Index("idx_log_operation", "operation_type", "timestamp"), + Index("idx_log_trace", "trace_id", "timestamp"), ) class PerformanceMetric(Base): """Aggregated performance metrics from log analysis. - + Stores time-windowed aggregations of operation performance for analytics and trend analysis. """ - + __tablename__ = "performance_metrics" - + # Primary key id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) - + # Timestamp timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) - + # Metric identification operation_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) component: Mapped[str] = mapped_column(String(100), nullable=False, index=True) - + # Aggregated metrics request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) error_rate: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) - + # Duration metrics (in milliseconds) avg_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) min_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) @@ -3903,143 +3903,143 @@ class PerformanceMetric(Base): p50_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) p95_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) p99_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) - + # Time window window_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) window_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) window_duration_seconds: Mapped[int] = mapped_column(Integer, nullable=False) - + # Additional context metric_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) - + __table_args__ = ( - Index('idx_perf_operation_time', 'operation_type', 'window_start'), - Index('idx_perf_component_time', 'component', 'window_start'), - Index('idx_perf_window', 'window_start', 'window_end'), + Index("idx_perf_operation_time", "operation_type", "window_start"), + Index("idx_perf_component_time", "component", "window_start"), + Index("idx_perf_window", "window_start", "window_end"), ) class SecurityEvent(Base): """Security event logging for threat detection and audit trails. - + Specialized table for security events with enhanced context and threat analysis capabilities. 
""" - + __tablename__ = "security_events" - + # Primary key id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) - + # Timestamps timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) detected_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=utc_now) - + # Correlation tracking correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) log_entry_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("structured_log_entries.id"), index=True, nullable=True) - + # Event classification event_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # auth_failure, suspicious_activity, rate_limit, etc. severity: Mapped[str] = mapped_column(String(20), nullable=False, index=True) # LOW, MEDIUM, HIGH, CRITICAL category: Mapped[str] = mapped_column(String(50), nullable=False, index=True) # authentication, authorization, data_access, etc. - + # User and request context user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) client_ip: Mapped[str] = mapped_column(String(45), nullable=False, index=True) user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - + # Event details description: Mapped[str] = mapped_column(Text, nullable=False) action_taken: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) # blocked, allowed, flagged, etc. - + # Threat analysis threat_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) # 0.0-1.0 threat_indicators: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False, default=dict) failed_attempts_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - + # Resolution tracking resolved: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) resolved_by: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) resolution_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - + # Alert tracking alert_sent: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) alert_sent_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) alert_recipients: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) - + # Additional context context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) - + __table_args__ = ( - Index('idx_security_type_time', 'event_type', 'timestamp'), - Index('idx_security_severity_time', 'severity', 'timestamp'), - Index('idx_security_user_time', 'user_id', 'timestamp'), - Index('idx_security_ip_time', 'client_ip', 'timestamp'), - Index('idx_security_unresolved', 'resolved', 'severity', 'timestamp'), + Index("idx_security_type_time", "event_type", "timestamp"), + Index("idx_security_severity_time", "severity", "timestamp"), + Index("idx_security_user_time", "user_id", "timestamp"), + Index("idx_security_ip_time", "client_ip", "timestamp"), + Index("idx_security_unresolved", "resolved", "severity", "timestamp"), ) class AuditTrail(Base): """Comprehensive audit trail for data access and changes. - + Tracks all significant system changes and data access for compliance and security auditing. 
""" - + __tablename__ = "audit_trails" - + # Primary key id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) - + # Timestamps timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) - + # Correlation tracking correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) - + # Action details action: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # create, read, update, delete, execute, etc. resource_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # tool, resource, prompt, user, etc. resource_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) resource_name: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) - + # User context user_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True) user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) team_id: Mapped[Optional[str]] = mapped_column(String(36), index=True, nullable=True) - + # Request context client_ip: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) - + # Change tracking old_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) new_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) changes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) - + # Data classification data_classification: Mapped[Optional[str]] = mapped_column(String(50), index=True, nullable=True) # public, internal, confidential, restricted requires_review: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) - + # Result success: Mapped[bool] = mapped_column(Boolean, nullable=False, index=True) error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - + # Additional context context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) - + __table_args__ = ( - Index('idx_audit_action_time', 'action', 'timestamp'), - Index('idx_audit_resource_time', 'resource_type', 'resource_id', 'timestamp'), - Index('idx_audit_user_time', 'user_id', 'timestamp'), - Index('idx_audit_classification', 'data_classification', 'timestamp'), - Index('idx_audit_review', 'requires_review', 'timestamp'), + Index("idx_audit_action_time", "action", "timestamp"), + Index("idx_audit_resource_time", "resource_type", "resource_id", "timestamp"), + Index("idx_audit_user_time", "user_id", "timestamp"), + Index("idx_audit_classification", "data_classification", "timestamp"), + Index("idx_audit_review", "requires_review", "timestamp"), ) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 0a9277f65..4474afbd4 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -1231,7 +1231,9 @@ async def _call_streamable_http(self, scope, receive, send): # IMPORTANT: Must be registered BEFORE CorrelationIDMiddleware so it executes AFTER correlation ID is set # Gateway boundary logging (request_started/completed) runs regardless of log_requests setting # Detailed payload logging only runs if log_detailed_requests=True -app.add_middleware(RequestLoggingMiddleware, 
enable_gateway_logging=True, log_detailed_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024) # Convert MB to bytes +app.add_middleware( + RequestLoggingMiddleware, enable_gateway_logging=True, log_detailed_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024 +) # Convert MB to bytes # Add custom DocsAuthMiddleware app.add_middleware(DocsAuthMiddleware) @@ -5068,7 +5070,7 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr try: # First-Party from mcpgateway.routers.log_search import router as log_search_router - + app.include_router(log_search_router) logger.info("Log search router included - structured logging enabled") except ImportError as e: diff --git a/mcpgateway/middleware/auth_middleware.py b/mcpgateway/middleware/auth_middleware.py index 5fa0d160e..1c2dc7a6c 100644 --- a/mcpgateway/middleware/auth_middleware.py +++ b/mcpgateway/middleware/auth_middleware.py @@ -92,15 +92,15 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: # Note: EmailUser uses 'email' as primary key, not 'id' user_email = user.email user_id = user_email # For EmailUser, email IS the ID - + # Expunge the user from the session so it can be used after session closes # This makes the object detached but with all attributes already loaded db.expunge(user) - + # Store user in request state for downstream use request.state.user = user logger.info(f"✓ Authenticated user: {user_email if user_email else user_id}") - + # Log successful authentication security_logger.log_authentication_attempt( user_id=user_id, @@ -109,13 +109,13 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: success=True, client_ip=request.client.host if request.client else "unknown", user_agent=request.headers.get("user-agent"), - db=db + db=db, ) except Exception as e: # Silently fail - let route handlers enforce auth if needed logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}") - + # Log failed authentication attempt security_logger.log_authentication_attempt( user_id="unknown", @@ -125,7 +125,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: client_ip=request.client.host if request.client else "unknown", user_agent=request.headers.get("user-agent"), failure_reason=str(e), - db=db if db else None + db=db if db else None, ) finally: diff --git a/mcpgateway/middleware/correlation_id.py b/mcpgateway/middleware/correlation_id.py index bcb033a3d..7d9a31193 100644 --- a/mcpgateway/middleware/correlation_id.py +++ b/mcpgateway/middleware/correlation_id.py @@ -68,9 +68,9 @@ def __init__(self, app): app: The FastAPI application instance """ super().__init__(app) - self.header_name = getattr(settings, 'correlation_id_header', 'X-Correlation-ID') - self.preserve_incoming = getattr(settings, 'correlation_id_preserve', True) - self.add_to_response = getattr(settings, 'correlation_id_response_header', True) + self.header_name = getattr(settings, "correlation_id_header", "X-Correlation-ID") + self.preserve_incoming = getattr(settings, "correlation_id_preserve", True) + self.add_to_response = getattr(settings, "correlation_id_response_header", True) async def dispatch(self, request: Request, call_next: Callable) -> Response: """Process the request and manage request ID (correlation ID) lifecycle. 
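# Behaviour sketch for dispatch() below (illustrative only; the host name is
# hypothetical and the settings are the defaults read in __init__ above):
#
#   curl -H "X-Correlation-ID: abc-123" https://gateway.example/tools
#     -> preserve_incoming=True keeps "abc-123" as this request's correlation ID
#     -> add_to_response=True echoes "X-Correlation-ID: abc-123" on the response
#
# When the header is absent, dispatch() generates a fresh correlation ID before
# calling the downstream application.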
@@ -89,10 +89,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: # Extract correlation ID from incoming request headers correlation_id = None if self.preserve_incoming: - correlation_id = extract_correlation_id_from_headers( - dict(request.headers), - self.header_name - ) + correlation_id = extract_correlation_id_from_headers(dict(request.headers), self.header_name) # Generate new correlation ID if none was provided if not correlation_id: diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index 4b9318dec..d16e276a8 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -131,7 +131,7 @@ def __init__(self, app, enable_gateway_logging: bool = True, log_detailed_reques self.log_detailed_requests = log_detailed_requests self.log_level = log_level.upper() self.max_body_size = max_body_size # Expected to be in bytes - + async def _resolve_user_identity(self, request: Request): """Best-effort extraction of user identity for request logs.""" # Prefer context injected by upstream middleware @@ -182,7 +182,7 @@ async def dispatch(self, request: Request, call_next: Callable): """ # Track start time for total duration start_time = time.time() - + # Get correlation ID and request metadata for boundary logging correlation_id = get_correlation_id() path = request.url.path @@ -190,11 +190,11 @@ async def dispatch(self, request: Request, call_next: Callable): user_agent = request.headers.get("user-agent", "unknown") client_ip = request.client.host if request.client else "unknown" user_id, user_email = await self._resolve_user_identity(request) - + # Skip boundary logging for health checks and static assets skip_paths = ["/health", "/healthz", "/static", "/favicon.ico"] should_log_boundary = self.enable_gateway_logging and not any(path.startswith(skip_path) for skip_path in skip_paths) - + # Log gateway request started if should_log_boundary: try: @@ -210,18 +210,15 @@ async def dispatch(self, request: Request, call_next: Callable): request_path=path, user_agent=user_agent, client_ip=client_ip, - metadata={ - "event": "request_started", - "query_params": str(request.query_params) if request.query_params else None - } + metadata={"event": "request_started", "query_params": str(request.query_params) if request.query_params else None}, ) except Exception as e: logger.warning(f"Failed to log request start: {e}") - + # Skip detailed logging if disabled if not self.log_detailed_requests: response = await call_next(request) - + # Still log request completed even if detailed logging is disabled if should_log_boundary: duration_ms = (time.time() - start_time) * 1000 @@ -241,14 +238,11 @@ async def dispatch(self, request: Request, call_next: Callable): user_agent=user_agent, client_ip=client_ip, duration_ms=duration_ms, - metadata={ - "event": "request_completed", - "response_time_category": "fast" if duration_ms < 100 else "normal" if duration_ms < 1000 else "slow" - } + metadata={"event": "request_completed", "response_time_category": "fast" if duration_ms < 100 else "normal" if duration_ms < 1000 else "slow"}, ) except Exception as e: logger.warning(f"Failed to log request completion: {e}") - + return response # Always log at INFO level for request payloads to ensure visibility @@ -333,7 +327,7 @@ async def receive(): status_code = response.status_code except Exception as e: duration_ms = (time.time() - start_time) * 1000 - + # Log request failed 
if should_log_boundary: try: @@ -351,23 +345,21 @@ async def receive(): client_ip=client_ip, duration_ms=duration_ms, error=e, - metadata={ - "event": "request_failed" - } + metadata={"event": "request_failed"}, ) except Exception as log_error: logger.warning(f"Failed to log request failure: {log_error}") - + raise - + # Calculate total duration duration_ms = (time.time() - start_time) * 1000 - + # Log gateway request completed if should_log_boundary: try: log_level = "ERROR" if status_code >= 500 else "WARNING" if status_code >= 400 else "INFO" - + structured_logger.log( level=log_level, message=f"Request completed: {method} {path} - {status_code}", @@ -382,23 +374,20 @@ async def receive(): user_agent=user_agent, client_ip=client_ip, duration_ms=duration_ms, - metadata={ - "event": "request_completed", - "response_time_category": self._categorize_response_time(duration_ms) - } + metadata={"event": "request_completed", "response_time_category": self._categorize_response_time(duration_ms)}, ) except Exception as e: logger.warning(f"Failed to log request completion: {e}") - + return response - + @staticmethod def _categorize_response_time(duration_ms: float) -> str: """Categorize response time for analytics. - + Args: duration_ms: Response time in milliseconds - + Returns: Category string """ diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py index 5fb14696f..1fe74a15c 100644 --- a/mcpgateway/routers/log_search.py +++ b/mcpgateway/routers/log_search.py @@ -100,21 +100,13 @@ def _aggregate_custom_windows( needs_rebuild = True if needs_rebuild: - db.execute( - delete(PerformanceMetric).where( - PerformanceMetric.window_duration_seconds == window_duration_seconds - ) - ) + db.execute(delete(PerformanceMetric).where(PerformanceMetric.window_duration_seconds == window_duration_seconds)) db.commit() sample_row = None max_existing = None if not needs_rebuild: - max_existing = db.execute( - select(sa_func.max(PerformanceMetric.window_start)).where( - PerformanceMetric.window_duration_seconds == window_duration_seconds - ) - ).scalar() + max_existing = db.execute(select(sa_func.max(PerformanceMetric.window_start)).where(PerformanceMetric.window_duration_seconds == window_duration_seconds)).scalar() if max_existing: current_start = max_existing if max_existing.tzinfo else max_existing.replace(tzinfo=timezone.utc) @@ -142,6 +134,7 @@ def _aggregate_custom_windows( # Request/Response Models class LogSearchRequest(BaseModel): """Log search request parameters.""" + search_text: Optional[str] = Field(None, description="Text search query") level: Optional[List[str]] = Field(None, description="Log levels to filter") component: Optional[List[str]] = Field(None, description="Components to filter") @@ -161,6 +154,7 @@ class LogSearchRequest(BaseModel): class LogEntry(BaseModel): """Log entry response model.""" + id: str timestamp: datetime level: str @@ -175,24 +169,27 @@ class LogEntry(BaseModel): request_method: Optional[str] = None is_security_event: bool = False error_details: Optional[Dict[str, Any]] = None - + class Config: from_attributes = True class LogSearchResponse(BaseModel): """Log search response.""" + total: int results: List[LogEntry] class CorrelationTraceRequest(BaseModel): """Correlation trace request.""" + correlation_id: str class CorrelationTraceResponse(BaseModel): """Correlation trace response with all related logs.""" + correlation_id: str total_duration_ms: Optional[float] log_count: int @@ -205,6 +202,7 @@ class CorrelationTraceResponse(BaseModel): class 
SecurityEventResponse(BaseModel): """Security event response model.""" + id: str timestamp: datetime event_type: str @@ -216,13 +214,14 @@ class SecurityEventResponse(BaseModel): threat_score: float action_taken: Optional[str] resolved: bool - + class Config: from_attributes = True class AuditTrailResponse(BaseModel): """Audit trail response model.""" + id: str timestamp: datetime correlation_id: Optional[str] = None @@ -235,13 +234,14 @@ class AuditTrailResponse(BaseModel): success: bool requires_review: bool data_classification: Optional[str] - + class Config: from_attributes = True class PerformanceMetricResponse(BaseModel): """Performance metric response model.""" + id: str timestamp: datetime component: str @@ -257,7 +257,7 @@ class PerformanceMetricResponse(BaseModel): p50_duration_ms: float p95_duration_ms: float p99_duration_ms: float - + class Config: from_attributes = True @@ -265,90 +265,81 @@ class Config: # API Endpoints @router.post("/search", response_model=LogSearchResponse) @require_permission("logs:read") -async def search_logs( - request: LogSearchRequest, - user=Depends(get_current_user_with_permissions), - db: Session = Depends(get_db) -) -> LogSearchResponse: +async def search_logs(request: LogSearchRequest, user=Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> LogSearchResponse: """Search structured logs with filters and pagination. - + Args: request: Search parameters user: Authenticated user injected by the permission-check dependency db: Database session - + Returns: Search results with pagination """ try: # Build base query stmt = select(StructuredLogEntry) - + # Apply filters conditions = [] - + if request.search_text: - conditions.append( - or_( - StructuredLogEntry.message.ilike(f"%{request.search_text}%"), - StructuredLogEntry.component.ilike(f"%{request.search_text}%") - ) - ) - + conditions.append(or_(StructuredLogEntry.message.ilike(f"%{request.search_text}%"), StructuredLogEntry.component.ilike(f"%{request.search_text}%"))) + if request.level: conditions.append(StructuredLogEntry.level.in_(request.level)) - + if request.component: conditions.append(StructuredLogEntry.component.in_(request.component)) - + # Note: category field doesn't exist in StructuredLogEntry # if request.category: # conditions.append(StructuredLogEntry.category.in_(request.category)) - + if request.correlation_id: conditions.append(StructuredLogEntry.correlation_id == request.correlation_id) - + if request.user_id: conditions.append(StructuredLogEntry.user_id == request.user_id) - + if request.start_time: conditions.append(StructuredLogEntry.timestamp >= request.start_time) - + if request.end_time: conditions.append(StructuredLogEntry.timestamp <= request.end_time) - + if request.min_duration_ms is not None: conditions.append(StructuredLogEntry.duration_ms >= request.min_duration_ms) - + if request.max_duration_ms is not None: conditions.append(StructuredLogEntry.duration_ms <= request.max_duration_ms) - + if request.has_error is not None: if request.has_error: conditions.append(StructuredLogEntry.error_details.isnot(None)) else: conditions.append(StructuredLogEntry.error_details.is_(None)) - + if conditions: stmt = stmt.where(and_(*conditions)) - + # Get total count count_stmt = select(sa_func.count()).select_from(stmt.subquery()) total = db.execute(count_stmt).scalar() or 0 - + # Apply sorting sort_column = getattr(StructuredLogEntry, request.sort_by, StructuredLogEntry.timestamp) if request.sort_order == "desc": stmt = stmt.order_by(desc(sort_column)) else: stmt = stmt.order_by(sort_column)
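# Note: the getattr() fallback above means an unrecognized sort_by value sorts
# by timestamp rather than raising AttributeError.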
- + # Apply pagination stmt = stmt.limit(request.limit).offset(request.offset) - + # Execute query results = db.execute(stmt).scalars().all() - + # Convert to response models log_entries = [ LogEntry( @@ -369,12 +360,9 @@ async def search_logs( ) for log in results ] - - return LogSearchResponse( - total=total, - results=log_entries - ) - + + return LogSearchResponse(total=total, results=log_entries) + except Exception as e: logger.error(f"Log search failed: {e}") raise HTTPException(status_code=500, detail="Log search failed") @@ -382,61 +370,51 @@ async def search_logs( @router.get("/trace/{correlation_id}", response_model=CorrelationTraceResponse) @require_permission("logs:read") -async def trace_correlation_id( - correlation_id: str, - user=Depends(get_current_user_with_permissions), - db: Session = Depends(get_db) -) -> CorrelationTraceResponse: +async def trace_correlation_id(correlation_id: str, user=Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> CorrelationTraceResponse: """Get all logs and events for a correlation ID. - + Args: correlation_id: Correlation ID to trace user: Authenticated user injected by the permission-check dependency db: Database session - + Returns: Complete trace of all related logs and events """ try: # Get structured logs - log_stmt = select(StructuredLogEntry).where( - StructuredLogEntry.correlation_id == correlation_id - ).order_by(StructuredLogEntry.timestamp) - + log_stmt = select(StructuredLogEntry).where(StructuredLogEntry.correlation_id == correlation_id).order_by(StructuredLogEntry.timestamp) + logs = db.execute(log_stmt).scalars().all() - + # Get security events - security_stmt = select(SecurityEvent).where( - SecurityEvent.correlation_id == correlation_id - ).order_by(SecurityEvent.timestamp) - + security_stmt = select(SecurityEvent).where(SecurityEvent.correlation_id == correlation_id).order_by(SecurityEvent.timestamp) + security_events = db.execute(security_stmt).scalars().all() - + # Get audit trails - audit_stmt = select(AuditTrail).where( - AuditTrail.correlation_id == correlation_id - ).order_by(AuditTrail.timestamp) - + audit_stmt = select(AuditTrail).where(AuditTrail.correlation_id == correlation_id).order_by(AuditTrail.timestamp) + audit_trails = db.execute(audit_stmt).scalars().all() - + # Calculate metrics durations = [log.duration_ms for log in logs if log.duration_ms is not None] total_duration = sum(durations) if durations else None error_count = sum(1 for log in logs if log.error_details) - + # Get performance metrics (if any aggregations exist) perf_metrics = None if logs: component = logs[0].component operation = logs[0].operation_type if component and operation: - perf_stmt = select(PerformanceMetric).where( - and_( - PerformanceMetric.component == component, - PerformanceMetric.operation_type == operation - ) - ).order_by(desc(PerformanceMetric.window_start)).limit(1) - + perf_stmt = ( + select(PerformanceMetric) + .where(and_(PerformanceMetric.component == component, PerformanceMetric.operation_type == operation)) + .order_by(desc(PerformanceMetric.window_start)) + .limit(1) + ) + perf = db.execute(perf_stmt).scalar_one_or_none() if perf: perf_metrics = { @@ -445,7 +423,7 @@ async def trace_correlation_id( "p99_duration_ms": perf.p99_duration_ms, "error_rate": perf.error_rate, } - + return CorrelationTraceResponse( correlation_id=correlation_id, total_duration_ms=total_duration, @@ -494,7 +472,7 @@ async def trace_correlation_id( ], performance_metrics=perf_metrics, ) - + except Exception as e: logger.error(f"Correlation trace failed:
{e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Correlation trace failed: {str(e)}") @@ -511,10 +489,10 @@ async def get_security_events( limit: int = Query(100, ge=1, le=1000), offset: int = Query(0, ge=0), user=Depends(get_current_user_with_permissions), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> List[SecurityEventResponse]: """Get security events with filters. - + Args: severity: Filter by severity levels event_type: Filter by event types @@ -525,13 +503,13 @@ async def get_security_events( offset: Result offset db: Database session _: Permission check dependency - + Returns: List of security events """ try: stmt = select(SecurityEvent) - + conditions = [] if severity: conditions.append(SecurityEvent.severity.in_(severity)) @@ -543,14 +521,14 @@ async def get_security_events( conditions.append(SecurityEvent.timestamp >= start_time) if end_time: conditions.append(SecurityEvent.timestamp <= end_time) - + if conditions: stmt = stmt.where(and_(*conditions)) - + stmt = stmt.order_by(desc(SecurityEvent.timestamp)).limit(limit).offset(offset) - + events = db.execute(stmt).scalars().all() - + return [ SecurityEventResponse( id=str(event.id), @@ -567,7 +545,7 @@ async def get_security_events( ) for event in events ] - + except Exception as e: logger.error(f"Security events query failed: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Security events query failed: {str(e)}") @@ -585,10 +563,10 @@ async def get_audit_trails( limit: int = Query(100, ge=1, le=1000), offset: int = Query(0, ge=0), user=Depends(get_current_user_with_permissions), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> List[AuditTrailResponse]: """Get audit trails with filters. - + Args: action: Filter by actions resource_type: Filter by resource types @@ -600,13 +578,13 @@ async def get_audit_trails( offset: Result offset db: Database session _: Permission check dependency - + Returns: List of audit trail entries """ try: stmt = select(AuditTrail) - + conditions = [] if action: conditions.append(AuditTrail.action.in_(action)) @@ -620,14 +598,14 @@ async def get_audit_trails( conditions.append(AuditTrail.timestamp >= start_time) if end_time: conditions.append(AuditTrail.timestamp <= end_time) - + if conditions: stmt = stmt.where(and_(*conditions)) - + stmt = stmt.order_by(desc(AuditTrail.timestamp)).limit(limit).offset(offset) - + trails = db.execute(stmt).scalars().all() - + return [ AuditTrailResponse( id=str(trail.id), @@ -645,7 +623,7 @@ async def get_audit_trails( ) for trail in trails ] - + except Exception as e: logger.error(f"Audit trails query failed: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Audit trails query failed: {str(e)}") @@ -659,17 +637,17 @@ async def get_performance_metrics( hours: float = Query(24.0, ge=MIN_PERFORMANCE_RANGE_HOURS, le=1000.0, description="Historical window to display"), aggregation: str = Query(_DEFAULT_AGGREGATION_KEY, regex="^(5m|24h)$", description="Aggregation level for metrics"), user=Depends(get_current_user_with_permissions), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> List[PerformanceMetricResponse]: """Get performance metrics. 
- + Args: component: Filter by component operation: Filter by operation hours: Hours of history to display aggregation: Aggregation level for metrics ("5m" or "24h") user: Authenticated user injected by the permission-check dependency db: Database session - + Returns: List of performance metrics """ try: @@ -692,21 +670,19 @@ async def get_performance_metrics( except Exception as agg_error: # pragma: no cover - defensive logging logger.warning("On-demand metrics aggregation failed: %s", agg_error) - stmt = select(PerformanceMetric).where( - PerformanceMetric.window_duration_seconds == window_duration_seconds - ) - + stmt = select(PerformanceMetric).where(PerformanceMetric.window_duration_seconds == window_duration_seconds) + if component: stmt = stmt.where(PerformanceMetric.component == component) if operation: stmt = stmt.where(PerformanceMetric.operation_type == operation) stmt = stmt.order_by(desc(PerformanceMetric.window_start), desc(PerformanceMetric.timestamp)) - + metrics = db.execute(stmt).scalars().all() metrics = _deduplicate_metrics(metrics) - + return [ PerformanceMetricResponse( id=str(metric.id), @@ -727,7 +703,7 @@ async def get_performance_metrics( ) for metric in metrics ] - + except Exception as e: logger.error(f"Performance metrics query failed: {e}") raise HTTPException(status_code=500, detail="Performance metrics query failed") diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py index c4aa7f137..526c478b1 100644 --- a/mcpgateway/services/a2a_service.py +++ b/mcpgateway/services/a2a_service.py @@ -283,7 +283,7 @@ async def register_agent( ) logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id})") - + # Log A2A agent registration for lifecycle tracking structured_logger.info( f"A2A agent '{new_agent.name}' registered successfully", @@ -298,10 +298,10 @@ async def register_agent( "agent_type": new_agent.agent_type, "protocol_version": new_agent.protocol_version, "visibility": visibility, - "endpoint_url": new_agent.endpoint_url - } + "endpoint_url": new_agent.endpoint_url, + }, ) - + return self._db_to_schema(db=db, db_agent=new_agent) except A2AAgentNameConflictError as ie: @@ -862,13 +862,14 @@ async def invoke_agent( token_value = getattr(db_row, "auth_value", None) if db_row else None if token_value: headers["Authorization"] = f"Bearer {token_value}" - + # Add correlation ID to outbound headers for distributed tracing from mcpgateway.utils.correlation_id import get_correlation_id + correlation_id = get_correlation_id() if correlation_id: headers["X-Correlation-ID"] = correlation_id - + # Log A2A external call start call_start_time = datetime.now(timezone.utc) structured_logger.log( @@ -884,8 +885,8 @@ async def invoke_agent( "agent_id": agent.id, "endpoint_url": agent.endpoint_url, "interaction_type": interaction_type, - "protocol_version": agent.protocol_version - } + "protocol_version": agent.protocol_version, + }, ) http_response = await client.post(agent.endpoint_url, json=request_data, headers=headers) @@ -894,7 +895,7 @@ async def invoke_agent( if http_response.status_code == 200: response = http_response.json() success = True - + # Log successful A2A call structured_logger.log( level="INFO", @@ -904,17 +905,11 @@ async def invoke_agent( user_email=user_email, correlation_id=correlation_id, duration_ms=call_duration_ms, - metadata={ - "event": "a2a_call_completed", - "agent_name": agent_name, - "agent_id": agent.id, - "status_code": http_response.status_code, - "success": True - } + metadata={"event": "a2a_call_completed", "agent_name": agent_name, "agent_id": agent.id, "status_code": http_response.status_code,
"success": True}, ) else: error_message = f"HTTP {http_response.status_code}: {http_response.text}" - + # Log failed A2A call structured_logger.log( level="ERROR", @@ -924,18 +919,10 @@ async def invoke_agent( user_email=user_email, correlation_id=correlation_id, duration_ms=call_duration_ms, - error_details={ - "error_type": "A2AHTTPError", - "error_message": error_message - }, - metadata={ - "event": "a2a_call_failed", - "agent_name": agent_name, - "agent_id": agent.id, - "status_code": http_response.status_code - } + error_details={"error_type": "A2AHTTPError", "error_message": error_message}, + metadata={"event": "a2a_call_failed", "agent_name": agent_name, "agent_id": agent.id, "status_code": http_response.status_code}, ) - + raise A2AAgentError(error_message) except Exception as e: diff --git a/mcpgateway/services/audit_trail_service.py b/mcpgateway/services/audit_trail_service.py index 8dda99c6d..1e2b83ee5 100644 --- a/mcpgateway/services/audit_trail_service.py +++ b/mcpgateway/services/audit_trail_service.py @@ -28,6 +28,7 @@ class AuditAction(str, Enum): """Audit trail action types.""" + CREATE = "CREATE" READ = "READ" UPDATE = "UPDATE" @@ -40,6 +41,7 @@ class AuditAction(str, Enum): class DataClassification(str, Enum): """Data classification levels.""" + PUBLIC = "public" INTERNAL = "internal" CONFIDENTIAL = "confidential" @@ -58,15 +60,14 @@ class DataClassification(str, Enum): class AuditTrailService: """Service for managing audit trails and compliance logging. - + Provides comprehensive audit trail management with data classification, change tracking, and compliance reporting capabilities. """ - + def __init__(self): """Initialize audit trail service.""" - pass - + def log_action( self, action: str, @@ -90,10 +91,10 @@ def log_action( context: Optional[Dict[str, Any]] = None, details: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None, - db: Optional[Session] = None + db: Optional[Session] = None, ) -> Optional[AuditTrail]: """Log an audit trail entry. - + Args: action: Action performed (CREATE, READ, UPDATE, DELETE, etc.) resource_type: Type of resource (tool, server, prompt, etc.) 
@@ -117,18 +118,18 @@ def log_action( details: Extra key/value payload (stored under context.details) metadata: Extra metadata payload (stored under context.metadata) db: Optional database session - + Returns: Created AuditTrail entry or None if logging disabled """ correlation_id = get_or_generate_correlation_id() - + # Use provided session or create new one close_db = False if db is None: db = SessionLocal() close_db = True - + try: context_payload: Dict[str, Any] = dict(context) if context else {} if details: @@ -165,46 +166,30 @@ def log_action( requires_review=requires_review_flag, success=success, error_message=error_message, - context=context_value + context=context_value, ) - + db.add(audit_entry) db.commit() db.refresh(audit_entry) - + logger.debug( f"Audit trail logged: {action} {resource_type}/{resource_id} by {user_id}", - extra={ - "correlation_id": correlation_id, - "action": action, - "resource_type": resource_type, - "resource_id": resource_id, - "user_id": user_id, - "success": success - } + extra={"correlation_id": correlation_id, "action": action, "resource_type": resource_type, "resource_id": resource_id, "user_id": user_id, "success": success}, ) - + return audit_entry - + except Exception as e: - logger.error( - f"Failed to log audit trail: {e}", - exc_info=True, - extra={ - "correlation_id": correlation_id, - "action": action, - "resource_type": resource_type, - "resource_id": resource_id - } - ) + logger.error(f"Failed to log audit trail: {e}", exc_info=True, extra={"correlation_id": correlation_id, "action": action, "resource_type": resource_type, "resource_id": resource_id}) if close_db: db.rollback() return None - + finally: if close_db: db.close() - + def _determine_requires_review( self, action: Optional[str], @@ -238,10 +223,10 @@ def log_crud_operation( success: bool = True, error_message: Optional[str] = None, db: Optional[Session] = None, - **kwargs + **kwargs, ) -> Optional[AuditTrail]: """Log a CRUD operation with change tracking. - + Args: operation: CRUD operation (CREATE, READ, UPDATE, DELETE) resource_type: Type of resource @@ -256,7 +241,7 @@ def log_crud_operation( error_message: Error message if failed db: Optional database session **kwargs: Additional arguments passed to log_action - + Returns: Created AuditTrail entry """ @@ -269,21 +254,21 @@ def log_crud_operation( new_val = new_values.get(key) if old_val != new_val: changes[key] = {"old": old_val, "new": new_val} - + # Determine data classification based on resource type data_classification = None if resource_type in ["user", "team", "token", "credential"]: data_classification = DataClassification.CONFIDENTIAL.value elif resource_type in ["tool", "server", "prompt", "resource"]: data_classification = DataClassification.INTERNAL.value - + # Determine if review is required requires_review = False if data_classification == DataClassification.CONFIDENTIAL.value: requires_review = True if operation == "DELETE" and resource_type in ["tool", "server", "gateway"]: requires_review = True - + return self.log_action( action=operation, resource_type=resource_type, @@ -300,9 +285,9 @@ def log_crud_operation( success=success, error_message=error_message, db=db, - **kwargs + **kwargs, ) - + def log_data_access( self, resource_type: str, @@ -314,10 +299,10 @@ def log_data_access( resource_name: Optional[str] = None, data_classification: Optional[str] = None, db: Optional[Session] = None, - **kwargs + **kwargs, ) -> Optional[AuditTrail]: """Log data access for compliance tracking. 
- + Args: resource_type: Type of resource accessed resource_id: ID of the resource @@ -329,15 +314,12 @@ def log_data_access( data_classification: Data classification level db: Optional database session **kwargs: Additional arguments passed to log_action - + Returns: Created AuditTrail entry """ - requires_review = data_classification in [ - DataClassification.CONFIDENTIAL.value, - DataClassification.RESTRICTED.value - ] - + requires_review = data_classification in [DataClassification.CONFIDENTIAL.value, DataClassification.RESTRICTED.value] + return self.log_action( action=access_type, resource_type=resource_type, @@ -350,22 +332,14 @@ def log_data_access( requires_review=requires_review, success=True, db=db, - **kwargs + **kwargs, ) - + def log_audit( - self, - user_id: str, - resource_type: str, - resource_id: str, - action: str, - user_email: Optional[str] = None, - description: Optional[str] = None, - db: Optional[Session] = None, - **kwargs + self, user_id: str, resource_type: str, resource_id: str, action: str, user_email: Optional[str] = None, description: Optional[str] = None, db: Optional[Session] = None, **kwargs ) -> Optional[AuditTrail]: """Convenience method for simple audit logging. - + Args: user_id: User who performed the action resource_type: Type of resource @@ -375,7 +349,7 @@ def log_audit( description: Description of the action db: Optional database session **kwargs: Additional arguments passed to log_action - + Returns: Created AuditTrail entry """ @@ -383,18 +357,9 @@ def log_audit( context = kwargs.pop("context", {}) if description: context["description"] = description - - return self.log_action( - action=action, - resource_type=resource_type, - resource_id=resource_id, - user_id=user_id, - user_email=user_email, - context=context if context else None, - db=db, - **kwargs - ) - + + return self.log_action(action=action, resource_type=resource_type, resource_id=resource_id, user_id=user_id, user_email=user_email, context=context if context else None, db=db, **kwargs) + def get_audit_trail( self, resource_type: Optional[str] = None, @@ -405,10 +370,10 @@ def get_audit_trail( end_time: Optional[datetime] = None, limit: int = 100, offset: int = 0, - db: Optional[Session] = None + db: Optional[Session] = None, ) -> list[AuditTrail]: """Query audit trail entries. - + Args: resource_type: Filter by resource type resource_id: Filter by resource ID @@ -419,7 +384,7 @@ def get_audit_trail( limit: Maximum number of results offset: Offset for pagination db: Optional database session - + Returns: List of AuditTrail entries """ @@ -427,10 +392,10 @@ def get_audit_trail( if db is None: db = SessionLocal() close_db = True - + try: query = select(AuditTrail) - + if resource_type: query = query.where(AuditTrail.resource_type == resource_type) if resource_id: @@ -443,13 +408,13 @@ def get_audit_trail( query = query.where(AuditTrail.timestamp >= start_time) if end_time: query = query.where(AuditTrail.timestamp <= end_time) - + query = query.order_by(AuditTrail.timestamp.desc()) query = query.limit(limit).offset(offset) - + result = db.execute(query) return list(result.scalars().all()) - + finally: if close_db: db.close() @@ -461,7 +426,7 @@ def get_audit_trail( def get_audit_trail_service() -> AuditTrailService: """Get or create the singleton audit trail service instance. 
- + Returns: AuditTrailService instance """ diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 537f773de..38ad4770a 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -816,7 +816,7 @@ async def register_gateway( await self._notify_gateway_added(db_gateway) logger.info(f"Registered gateway: {gateway.name}") - + # Structured logging: Audit trail for gateway creation audit_trail.log_action( user_id=created_by or "system", @@ -842,7 +842,7 @@ async def register_gateway( }, db=db, ) - + # Structured logging: Log successful gateway creation structured_logger.log( level="INFO", @@ -870,7 +870,7 @@ async def register_gateway( if TYPE_CHECKING: ge: ExceptionGroup[GatewayConnectionError] logger.error(f"GatewayConnectionError in group: {ge.exceptions}") - + structured_logger.log( level="ERROR", message="Gateway creation failed due to connection error", @@ -887,7 +887,7 @@ async def register_gateway( if TYPE_CHECKING: gnce: ExceptionGroup[GatewayNameConflictError] logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}") - + structured_logger.log( level="WARNING", message="Gateway creation failed due to name conflict", @@ -903,7 +903,7 @@ async def register_gateway( if TYPE_CHECKING: guce: ExceptionGroup[GatewayDuplicateConflictError] logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}") - + structured_logger.log( level="WARNING", message="Gateway creation failed due to duplicate", @@ -919,7 +919,7 @@ async def register_gateway( if TYPE_CHECKING: ve: ExceptionGroup[ValueError] logger.error(f"ValueErrors in group: {ve.exceptions}") - + structured_logger.log( level="ERROR", message="Gateway creation failed due to validation error", @@ -936,7 +936,7 @@ async def register_gateway( if TYPE_CHECKING: re: ExceptionGroup[RuntimeError] logger.error(f"RuntimeErrors in group: {re.exceptions}") - + structured_logger.log( level="ERROR", message="Gateway creation failed due to runtime error", @@ -953,7 +953,7 @@ async def register_gateway( if TYPE_CHECKING: ie: ExceptionGroup[IntegrityError] logger.error(f"IntegrityErrors in group: {ie.exceptions}") - + structured_logger.log( level="ERROR", message="Gateway creation failed due to database integrity error", @@ -1585,7 +1585,7 @@ async def update_gateway( await self._notify_gateway_updated(gateway) logger.info(f"Updated gateway: {gateway.name}") - + # Structured logging: Audit trail for gateway update audit_trail.log_action( user_id=user_email or modified_by or "system", @@ -1607,7 +1607,7 @@ async def update_gateway( }, db=db, ) - + # Structured logging: Log successful gateway update structured_logger.log( level="INFO", @@ -1625,7 +1625,7 @@ async def update_gateway( }, db=db, ) - + gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)) @@ -1633,7 +1633,7 @@ async def update_gateway( return None except GatewayNameConflictError as ge: logger.error(f"GatewayNameConflictError in group: {ge}") - + structured_logger.log( level="WARNING", message="Gateway update failed due to name conflict", @@ -1648,7 +1648,7 @@ async def update_gateway( raise ge except GatewayNotFoundError as gnfe: logger.error(f"GatewayNotFoundError: {gnfe}") - + structured_logger.log( level="ERROR", message="Gateway update failed - gateway not found", @@ -1663,7 +1663,7 @@ async def update_gateway( raise gnfe except IntegrityError as ie: logger.error(f"IntegrityErrors in group: 
{ie}") - + structured_logger.log( level="ERROR", message="Gateway update failed due to database integrity error", @@ -1678,7 +1678,7 @@ async def update_gateway( raise ie except PermissionError as pe: db.rollback() - + structured_logger.log( level="WARNING", message="Gateway update failed due to permission error", @@ -1693,7 +1693,7 @@ async def update_gateway( raise except Exception as e: db.rollback() - + structured_logger.log( level="ERROR", message="Gateway update failed", @@ -1913,7 +1913,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo await self.tool_service.toggle_tool_status(db, tool.id, activate, reachable) logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}") - + # Structured logging: Audit trail for gateway status toggle audit_trail.log_action( user_id=user_email or "system", @@ -1933,7 +1933,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo }, db=db, ) - + # Structured logging: Log successful gateway status toggle structured_logger.log( level="INFO", @@ -1971,7 +1971,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo raise e except Exception as e: db.rollback() - + # Structured logging: Log generic gateway status toggle failure structured_logger.log( level="ERROR", @@ -2067,7 +2067,7 @@ async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optiona await self._notify_gateway_deleted(gateway_info) logger.info(f"Permanently deleted gateway: {gateway.name}") - + # Structured logging: Audit trail for gateway deletion audit_trail.log_action( user_id=user_email or "system", @@ -2083,7 +2083,7 @@ async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optiona }, db=db, ) - + # Structured logging: Log successful gateway deletion structured_logger.log( level="INFO", @@ -2103,7 +2103,7 @@ async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optiona except PermissionError as pe: db.rollback() - + # Structured logging: Log permission error structured_logger.log( level="WARNING", @@ -2119,7 +2119,7 @@ async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optiona raise except Exception as e: db.rollback() - + # Structured logging: Log generic gateway deletion failure structured_logger.log( level="ERROR", diff --git a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py index 69f4e279a..9b70317dc 100644 --- a/mcpgateway/services/log_aggregator.py +++ b/mcpgateway/services/log_aggregator.py @@ -29,29 +29,24 @@ class LogAggregator: """Aggregates structured logs into performance metrics.""" - + def __init__(self): """Initialize log aggregator.""" self.aggregation_window_minutes = getattr(settings, "metrics_aggregation_window_minutes", 5) self.enabled = getattr(settings, "metrics_aggregation_enabled", True) - + def aggregate_performance_metrics( - self, - component: Optional[str], - operation_type: Optional[str], - window_start: Optional[datetime] = None, - window_end: Optional[datetime] = None, - db: Optional[Session] = None + self, component: Optional[str], operation_type: Optional[str], window_start: Optional[datetime] = None, window_end: Optional[datetime] = None, db: Optional[Session] = None ) -> Optional[PerformanceMetric]: """Aggregate performance metrics for a component and operation. 
- + Args: component: Component name operation_type: Operation name window_start: Start of aggregation window (defaults to N minutes ago) window_end: End of aggregation window (defaults to now) db: Optional database session - + Returns: Created PerformanceMetric or None if no data """ @@ -59,14 +54,14 @@ def aggregate_performance_metrics( return None if not component or not operation_type: return None - + window_start, window_end = self._resolve_window_bounds(window_start, window_end) - + should_close = False if db is None: db = SessionLocal() should_close = True - + try: # Query structured logs for this component/operation in time window stmt = select(StructuredLogEntry).where( @@ -75,36 +70,36 @@ def aggregate_performance_metrics( StructuredLogEntry.operation_type == operation_type, StructuredLogEntry.timestamp >= window_start, StructuredLogEntry.timestamp < window_end, - StructuredLogEntry.duration_ms.isnot(None) + StructuredLogEntry.duration_ms.isnot(None), ) ) - + results = db.execute(stmt).scalars().all() - + if not results: return None - + # Extract durations durations = sorted(r.duration_ms for r in results if r.duration_ms is not None) - + if not durations: return None - + # Calculate statistics count = len(durations) avg_duration = statistics.fmean(durations) if hasattr(statistics, "fmean") else statistics.mean(durations) min_duration = durations[0] max_duration = durations[-1] - + # Calculate percentiles p50 = self._percentile(durations, 0.50) p95 = self._percentile(durations, 0.95) p99 = self._percentile(durations, 0.99) - + # Count errors error_count = self._calculate_error_count(results) error_rate = error_count / count if count > 0 else 0.0 - + metric = self._upsert_metric( component=component, operation_type=operation_type, @@ -125,99 +120,80 @@ def aggregate_performance_metrics( }, db=db, ) - - logger.info( - f"Aggregated performance metrics for {component}.{operation_type}: " - f"{count} requests, {avg_duration:.2f}ms avg, {error_rate:.2%} error rate" - ) - + + logger.info(f"Aggregated performance metrics for {component}.{operation_type}: " f"{count} requests, {avg_duration:.2f}ms avg, {error_rate:.2%} error rate") + return metric - + except Exception as e: logger.error(f"Failed to aggregate performance metrics: {e}") if db: db.rollback() return None - + finally: if should_close: db.close() - - def aggregate_all_components( - self, - window_start: Optional[datetime] = None, - window_end: Optional[datetime] = None, - db: Optional[Session] = None - ) -> List[PerformanceMetric]: + + def aggregate_all_components(self, window_start: Optional[datetime] = None, window_end: Optional[datetime] = None, db: Optional[Session] = None) -> List[PerformanceMetric]: """Aggregate metrics for all components and operations.
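The percentile calls above delegate to a `_percentile` helper that sits outside this hunk; a minimal sketch of the usual linear-interpolation computation it likely performs (consistent with the inline `percentile` closure in performance_tracker.py further below), with made-up durations:

from math import floor
from typing import List

def percentile(sorted_values: List[float], p: float) -> float:
    """Linear-interpolation percentile over an already-sorted sample."""
    k = (len(sorted_values) - 1) * p  # fractional rank into the sorted list
    f = floor(k)                      # index of the lower neighbour
    c = k - f                         # interpolation weight for the upper neighbour
    if f + 1 < len(sorted_values):
        return sorted_values[f] * (1 - c) + sorted_values[f + 1] * c
    return sorted_values[f]

durations = sorted([12.0, 15.0, 18.0, 22.0, 95.0])     # hypothetical ms samples
assert percentile(durations, 0.50) == 18.0             # median
assert round(percentile(durations, 0.95), 1) == 80.4   # 22*0.2 + 95*0.8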
- + Args: window_start: Start of aggregation window window_end: End of aggregation window db: Optional database session - + Returns: List of created PerformanceMetric records """ if not self.enabled: return [] - + should_close = False if db is None: db = SessionLocal() should_close = True - + try: window_start, window_end = self._resolve_window_bounds(window_start, window_end) - stmt = select( - StructuredLogEntry.component, - StructuredLogEntry.operation_type - ).where( - and_( - StructuredLogEntry.timestamp >= window_start, - StructuredLogEntry.timestamp < window_end, - StructuredLogEntry.duration_ms.isnot(None), - StructuredLogEntry.operation_type.isnot(None) + stmt = ( + select(StructuredLogEntry.component, StructuredLogEntry.operation_type) + .where( + and_( + StructuredLogEntry.timestamp >= window_start, + StructuredLogEntry.timestamp < window_end, + StructuredLogEntry.duration_ms.isnot(None), + StructuredLogEntry.operation_type.isnot(None), + ) ) - ).distinct() - + .distinct() + ) + pairs = db.execute(stmt).all() - + metrics = [] for component, operation in pairs: if component and operation: - metric = self.aggregate_performance_metrics( - component=component, - operation_type=operation, - window_start=window_start, - window_end=window_end, - db=db - ) + metric = self.aggregate_performance_metrics(component=component, operation_type=operation, window_start=window_start, window_end=window_end, db=db) if metric: metrics.append(metric) - + return metrics - + finally: if should_close: db.close() - - def get_recent_metrics( - self, - component: Optional[str] = None, - operation: Optional[str] = None, - hours: int = 24, - db: Optional[Session] = None - ) -> List[PerformanceMetric]: + + def get_recent_metrics(self, component: Optional[str] = None, operation: Optional[str] = None, hours: int = 24, db: Optional[Session] = None) -> List[PerformanceMetric]: """Get recent performance metrics. - + Args: component: Optional component filter operation: Optional operation filter hours: Hours of history to retrieve db: Optional database session - + Returns: List of PerformanceMetric records """ @@ -225,40 +201,33 @@ def get_recent_metrics( if db is None: db = SessionLocal() should_close = True - + try: since = datetime.now(timezone.utc) - timedelta(hours=hours) - - stmt = select(PerformanceMetric).where( - PerformanceMetric.window_start >= since - ) - + + stmt = select(PerformanceMetric).where(PerformanceMetric.window_start >= since) + if component: stmt = stmt.where(PerformanceMetric.component == component) if operation: stmt = stmt.where(PerformanceMetric.operation_type == operation) - + stmt = stmt.order_by(PerformanceMetric.window_start.desc()) - + return db.execute(stmt).scalars().all() - + finally: if should_close: db.close() - - def get_degradation_alerts( - self, - threshold_multiplier: float = 1.5, - hours: int = 24, - db: Optional[Session] = None - ) -> List[Dict[str, Any]]: + + def get_degradation_alerts(self, threshold_multiplier: float = 1.5, hours: int = 24, db: Optional[Session] = None) -> List[Dict[str, Any]]: """Identify performance degradations by comparing recent vs baseline. 
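How the roll-up might be driven in practice; `get_log_aggregator` is defined at the bottom of this file, but the periodic loop and its cadence are illustrative assumptions, not part of the patch:

import asyncio

from mcpgateway.services.log_aggregator import get_log_aggregator

async def aggregation_loop() -> None:
    aggregator = get_log_aggregator()
    while True:
        # One pass rolls up every component/operation pair that emitted
        # logs carrying a duration_ms value during the window.
        for metric in aggregator.aggregate_all_components():
            print(metric.component, metric.operation_type, metric.avg_duration_ms)
        await asyncio.sleep(aggregator.aggregation_window_minutes * 60)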
- + Args: threshold_multiplier: Alert if recent is X times slower than baseline hours: Hours of recent data to check db: Optional database session - + Returns: List of degradation alerts with details """ @@ -266,65 +235,60 @@ def get_degradation_alerts( if db is None: db = SessionLocal() should_close = True - + try: recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) baseline_cutoff = recent_cutoff - timedelta(hours=hours * 2) - + # Get unique component/operation pairs - stmt = select( - PerformanceMetric.component, - PerformanceMetric.operation_type - ).distinct() - + stmt = select(PerformanceMetric.component, PerformanceMetric.operation_type).distinct() + pairs = db.execute(stmt).all() - + alerts = [] for component, operation in pairs: # Get recent metrics recent_stmt = select(PerformanceMetric).where( - and_( - PerformanceMetric.component == component, - PerformanceMetric.operation_type == operation, - PerformanceMetric.window_start >= recent_cutoff - ) + and_(PerformanceMetric.component == component, PerformanceMetric.operation_type == operation, PerformanceMetric.window_start >= recent_cutoff) ) recent_metrics = db.execute(recent_stmt).scalars().all() - + # Get baseline metrics baseline_stmt = select(PerformanceMetric).where( and_( PerformanceMetric.component == component, PerformanceMetric.operation_type == operation, PerformanceMetric.window_start >= baseline_cutoff, - PerformanceMetric.window_start < recent_cutoff + PerformanceMetric.window_start < recent_cutoff, ) ) baseline_metrics = db.execute(baseline_stmt).scalars().all() - + if not recent_metrics or not baseline_metrics: continue - + recent_avg = statistics.mean([m.avg_duration_ms for m in recent_metrics]) baseline_avg = statistics.mean([m.avg_duration_ms for m in baseline_metrics]) - + if recent_avg > baseline_avg * threshold_multiplier: - alerts.append({ - "component": component, - "operation": operation, - "recent_avg_ms": recent_avg, - "baseline_avg_ms": baseline_avg, - "degradation_ratio": recent_avg / baseline_avg, - "recent_error_rate": statistics.mean([m.error_rate for m in recent_metrics]), - "baseline_error_rate": statistics.mean([m.error_rate for m in baseline_metrics]), - }) - + alerts.append( + { + "component": component, + "operation": operation, + "recent_avg_ms": recent_avg, + "baseline_avg_ms": baseline_avg, + "degradation_ratio": recent_avg / baseline_avg, + "recent_error_rate": statistics.mean([m.error_rate for m in recent_metrics]), + "baseline_error_rate": statistics.mean([m.error_rate for m in baseline_metrics]), + } + ) + return alerts - + finally: if should_close: db.close() - + def backfill(self, hours: float, db: Optional[Session] = None) -> int: """Backfill metrics for a historical time range. @@ -508,7 +472,7 @@ def _upsert_metric( def get_log_aggregator() -> LogAggregator: """Get or create the global log aggregator instance. - + Returns: Global LogAggregator instance """ diff --git a/mcpgateway/services/performance_tracker.py b/mcpgateway/services/performance_tracker.py index db78c6382..cfa30bcdf 100644 --- a/mcpgateway/services/performance_tracker.py +++ b/mcpgateway/services/performance_tracker.py @@ -13,7 +13,6 @@ # Standard from collections import defaultdict from contextlib import contextmanager -from datetime import datetime, timezone import logging import statistics import time @@ -28,15 +27,15 @@ class PerformanceTracker: """Tracks and analyzes performance metrics across requests. 
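The alert rule above boils down to one comparison per component/operation pair; a self-contained restatement with hypothetical timings:

import statistics

def is_degraded(recent_ms, baseline_ms, threshold_multiplier=1.5):
    """Flag when the recent average exceeds the baseline average by the multiplier."""
    recent_avg = statistics.mean(recent_ms)
    baseline_avg = statistics.mean(baseline_ms)
    return recent_avg > baseline_avg * threshold_multiplier, recent_avg / baseline_avg

degraded, ratio = is_degraded([130.0, 145.0, 160.0], [80.0, 90.0, 100.0])
# mean(recent) = 145 ms vs mean(baseline) = 90 ms -> ratio ~ 1.61 > 1.5
assert degraded and round(ratio, 2) == 1.61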
- + Provides context managers for tracking operation timing, aggregation of metrics, and threshold-based alerting. """ - + def __init__(self): """Initialize performance tracker.""" self.operation_timings: Dict[str, List[float]] = defaultdict(list) - + # Performance thresholds (seconds) from settings or defaults self.performance_thresholds = { "database_query": getattr(settings, "perf_threshold_database_query", 0.1), @@ -48,29 +47,23 @@ def __init__(self): "resource_fetch": getattr(settings, "perf_threshold_resource_fetch", 1.0), "prompt_processing": getattr(settings, "perf_threshold_prompt_processing", 0.5), } - + # Max buffer size per operation type self.max_samples = getattr(settings, "perf_max_samples_per_operation", 1000) - + @contextmanager - def track_operation( - self, - operation_name: str, - component: Optional[str] = None, - log_slow: bool = True, - extra_context: Optional[Dict[str, Any]] = None - ) -> Generator[None, None, None]: + def track_operation(self, operation_name: str, component: Optional[str] = None, log_slow: bool = True, extra_context: Optional[Dict[str, Any]] = None) -> Generator[None, None, None]: """Context manager to track operation performance. - + Args: operation_name: Name of the operation being tracked component: Component/module name for context log_slow: Whether to log operations exceeding thresholds extra_context: Additional context to include in logs - + Yields: None - + Example: >>> tracker = PerformanceTracker() >>> with tracker.track_operation("database_query", component="tool_service"): @@ -80,7 +73,7 @@ def track_operation( start_time = time.time() correlation_id = get_correlation_id() error_occurred = False - + try: yield except Exception: @@ -88,18 +81,18 @@ def track_operation( raise finally: duration = time.time() - start_time - + # Record timing self.operation_timings[operation_name].append(duration) - + # Limit buffer size if len(self.operation_timings[operation_name]) > self.max_samples: self.operation_timings[operation_name].pop(0) - + # Check threshold and log if needed - threshold = self.performance_thresholds.get(operation_name, float('inf')) + threshold = self.performance_thresholds.get(operation_name, float("inf")) threshold_exceeded = duration > threshold - + if log_slow and threshold_exceeded: context = { "operation": operation_name, @@ -112,22 +105,12 @@ def track_operation( } if extra_context: context.update(extra_context) - - logger.warning( - f"Slow operation detected: {operation_name} took {duration*1000:.2f}ms " - f"(threshold: {threshold*1000:.2f}ms)", - extra=context - ) - - def record_timing( - self, - operation_name: str, - duration: float, - component: Optional[str] = None, - extra_context: Optional[Dict[str, Any]] = None - ) -> None: + + logger.warning(f"Slow operation detected: {operation_name} took {duration*1000:.2f}ms " f"(threshold: {threshold*1000:.2f}ms)", extra=context) + + def record_timing(self, operation_name: str, duration: float, component: Optional[str] = None, extra_context: Optional[Dict[str, Any]] = None) -> None: """Manually record a timing measurement. 
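A usage sketch of the context manager, assuming the default 0.1 s `database_query` threshold above and the module-level factory defined at the end of this file; the sleep stands in for real work:

import time

from mcpgateway.services.performance_tracker import get_performance_tracker

tracker = get_performance_tracker()

with tracker.track_operation("database_query", component="tool_service"):
    time.sleep(0.15)  # exceeds the 0.1 s threshold -> slow-operation warning is logged

print(tracker.get_operation_stats("database_query"))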
- + Args: operation_name: Name of the operation duration: Duration in seconds @@ -135,13 +118,13 @@ def record_timing( extra_context: Additional context """ self.operation_timings[operation_name].append(duration) - + # Limit buffer size if len(self.operation_timings[operation_name]) > self.max_samples: self.operation_timings[operation_name].pop(0) - + # Check threshold - threshold = self.performance_thresholds.get(operation_name, float('inf')) + threshold = self.performance_thresholds.get(operation_name, float("inf")) if duration > threshold: context = { "operation": operation_name, @@ -152,26 +135,19 @@ def record_timing( } if extra_context: context.update(extra_context) - - logger.warning( - f"Slow operation: {operation_name} took {duration*1000:.2f}ms", - extra=context - ) - - def get_performance_summary( - self, - operation_name: Optional[str] = None, - min_samples: int = 1 - ) -> Dict[str, Any]: + + logger.warning(f"Slow operation: {operation_name} took {duration*1000:.2f}ms", extra=context) + + def get_performance_summary(self, operation_name: Optional[str] = None, min_samples: int = 1) -> Dict[str, Any]: """Get performance summary for analytics. - + Args: operation_name: Specific operation to summarize (None for all) min_samples: Minimum samples required to include in summary - + Returns: Dictionary containing performance statistics - + Example: >>> tracker = PerformanceTracker() >>> summary = tracker.get_performance_summary() @@ -179,21 +155,17 @@ def get_performance_summary( True """ summary = {} - - operations = ( - {operation_name: self.operation_timings[operation_name]} - if operation_name and operation_name in self.operation_timings - else self.operation_timings - ) - + + operations = {operation_name: self.operation_timings[operation_name]} if operation_name and operation_name in self.operation_timings else self.operation_timings + for op_name, timings in operations.items(): if len(timings) < min_samples: continue - + # Calculate percentiles sorted_timings = sorted(timings) count = len(sorted_timings) - + def percentile(p: float) -> float: """Calculate percentile value.""" k = (count - 1) * p @@ -202,7 +174,7 @@ def percentile(p: float) -> float: if f + 1 < count: return sorted_timings[f] * (1 - c) + sorted_timings[f + 1] * c return sorted_timings[f] - + summary[op_name] = { "count": count, "avg_duration_ms": statistics.mean(timings) * 1000, @@ -211,29 +183,29 @@ def percentile(p: float) -> float: "p50_duration_ms": percentile(0.5) * 1000, "p95_duration_ms": percentile(0.95) * 1000, "p99_duration_ms": percentile(0.99) * 1000, - "threshold_ms": self.performance_thresholds.get(op_name, float('inf')) * 1000, - "threshold_violations": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float('inf'))), - "violation_rate": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float('inf'))) / count, + "threshold_ms": self.performance_thresholds.get(op_name, float("inf")) * 1000, + "threshold_violations": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float("inf"))), + "violation_rate": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float("inf"))) / count, } - + return summary - + def get_operation_stats(self, operation_name: str) -> Optional[Dict[str, Any]]: """Get statistics for a specific operation. 
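Feeding the tracker manually and reading the rolled-up view back; the durations are invented, and the assertions assume the default 0.1 s threshold for `database_query`:

from mcpgateway.services.performance_tracker import PerformanceTracker

tracker = PerformanceTracker()
for duration in (0.02, 0.03, 0.05, 0.4):  # seconds; only 0.4 s breaches 0.1 s
    tracker.record_timing("database_query", duration, component="tool_service")

stats = tracker.get_performance_summary(min_samples=1)["database_query"]
assert stats["count"] == 4
assert stats["violation_rate"] == 0.25  # 1 violating sample out of 4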
- + Args: operation_name: Name of the operation - + Returns: Statistics dictionary or None if no data """ if operation_name not in self.operation_timings: return None - + timings = self.operation_timings[operation_name] if not timings: return None - + return { "operation": operation_name, "sample_count": len(timings), @@ -241,12 +213,12 @@ def get_operation_stats(self, operation_name: str) -> Optional[Dict[str, Any]]: "min_duration_ms": min(timings) * 1000, "max_duration_ms": max(timings) * 1000, "total_time_ms": sum(timings) * 1000, - "threshold_ms": self.performance_thresholds.get(operation_name, float('inf')) * 1000, + "threshold_ms": self.performance_thresholds.get(operation_name, float("inf")) * 1000, } - + def clear_stats(self, operation_name: Optional[str] = None) -> None: """Clear performance statistics. - + Args: operation_name: Specific operation to clear (None for all) """ @@ -255,50 +227,46 @@ def clear_stats(self, operation_name: Optional[str] = None) -> None: self.operation_timings[operation_name].clear() else: self.operation_timings.clear() - + def set_threshold(self, operation_name: str, threshold_seconds: float) -> None: """Set or update performance threshold for an operation. - + Args: operation_name: Name of the operation threshold_seconds: Threshold in seconds """ self.performance_thresholds[operation_name] = threshold_seconds - - def check_performance_degradation( - self, - operation_name: str, - baseline_multiplier: float = 2.0 - ) -> Dict[str, Any]: + + def check_performance_degradation(self, operation_name: str, baseline_multiplier: float = 2.0) -> Dict[str, Any]: """Check if performance has degraded compared to baseline. - + Args: operation_name: Name of the operation to check baseline_multiplier: Multiplier for degradation detection - + Returns: Dictionary with degradation analysis """ if operation_name not in self.operation_timings: return {"degraded": False, "reason": "no_data"} - + timings = self.operation_timings[operation_name] if len(timings) < 10: return {"degraded": False, "reason": "insufficient_samples"} - + # Compare recent timings to overall average recent_count = min(10, len(timings)) recent_timings = timings[-recent_count:] historical_timings = timings[:-recent_count] if len(timings) > recent_count else timings - + if not historical_timings: return {"degraded": False, "reason": "insufficient_historical_data"} - + recent_avg = statistics.mean(recent_timings) historical_avg = statistics.mean(historical_timings) - + degraded = recent_avg > (historical_avg * baseline_multiplier) - + return { "degraded": degraded, "recent_avg_ms": recent_avg * 1000, @@ -314,7 +282,7 @@ def check_performance_degradation( def get_performance_tracker() -> PerformanceTracker: """Get or create the global performance tracker instance. 
- + Returns: Global PerformanceTracker instance """ diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 4f4267b35..de3b1a573 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -407,7 +407,7 @@ async def register_prompt( await self._notify_prompt_added(db_prompt) logger.info(f"Registered prompt: {prompt.name}") - + # Structured logging: Audit trail for prompt creation audit_trail.log_action( user_id=created_by or "system", @@ -430,7 +430,7 @@ async def register_prompt( }, db=db, ) - + # Structured logging: Log successful prompt creation structured_logger.log( level="INFO", @@ -448,14 +448,14 @@ async def register_prompt( }, db=db, ) - + db_prompt.team = self._get_team_name(db, db_prompt.team_id) prompt_dict = self._convert_db_prompt(db_prompt) return PromptRead.model_validate(prompt_dict) except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") - + structured_logger.log( level="ERROR", message="Prompt creation failed due to database integrity error", @@ -470,7 +470,7 @@ async def register_prompt( raise ie except PromptNameConflictError as se: db.rollback() - + structured_logger.log( level="WARNING", message="Prompt creation failed due to name conflict", @@ -484,7 +484,7 @@ async def register_prompt( raise se except Exception as e: db.rollback() - + structured_logger.log( level="ERROR", message="Prompt creation failed", @@ -1110,7 +1110,7 @@ async def update_prompt( db.refresh(prompt) await self._notify_prompt_updated(prompt) - + # Structured logging: Audit trail for prompt update audit_trail.log_action( user_id=user_email or modified_by or "system", @@ -1126,7 +1126,7 @@ async def update_prompt( context={"modified_via": modified_via}, db=db, ) - + structured_logger.log( level="INFO", message="Prompt updated successfully", @@ -1140,13 +1140,13 @@ async def update_prompt( custom_fields={"prompt_name": prompt.name, "version": prompt.version}, db=db, ) - + prompt.team = self._get_team_name(db, prompt.team_id) return PromptRead.model_validate(self._convert_db_prompt(prompt)) except PermissionError as pe: db.rollback() - + structured_logger.log( level="WARNING", message="Prompt update failed due to permission error", @@ -1162,7 +1162,7 @@ async def update_prompt( except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") - + structured_logger.log( level="ERROR", message="Prompt update failed due to database integrity error", @@ -1178,7 +1178,7 @@ async def update_prompt( except PromptNotFoundError as e: db.rollback() logger.error(f"Prompt not found: {e}") - + structured_logger.log( level="ERROR", message="Prompt update failed - prompt not found", @@ -1194,7 +1194,7 @@ async def update_prompt( except PromptNameConflictError as pnce: db.rollback() logger.error(f"Prompt name conflict: {pnce}") - + structured_logger.log( level="WARNING", message="Prompt update failed due to name conflict", @@ -1209,7 +1209,7 @@ async def update_prompt( raise pnce except Exception as e: db.rollback() - + structured_logger.log( level="ERROR", message="Prompt update failed", @@ -1282,7 +1282,7 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool else: await self._notify_prompt_deactivated(prompt) logger.info(f"Prompt {prompt.name} {'activated' if activate else 'deactivated'}") - + # Structured logging: Audit trail for prompt status toggle audit_trail.log_action( user_id=user_email or "system", @@ -1296,7 +1296,7 @@ async def 
toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool context={"action": "activate" if activate else "deactivate"}, db=db, ) - + structured_logger.log( level="INFO", message=f"Prompt {'activated' if activate else 'deactivated'} successfully", @@ -1309,7 +1309,7 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool custom_fields={"prompt_name": prompt.name, "is_active": prompt.is_active}, db=db, ) - + prompt.team = self._get_team_name(db, prompt.team_id) return PromptRead.model_validate(self._convert_db_prompt(prompt)) except PermissionError as e: @@ -1327,7 +1327,7 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool raise e except Exception as e: db.rollback() - + structured_logger.log( level="ERROR", message="Prompt status toggle failed", @@ -1453,12 +1453,12 @@ async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_emai prompt_info = {"id": prompt.id, "name": prompt.name} prompt_name = prompt.name prompt_team_id = prompt.team_id - + db.delete(prompt) db.commit() await self._notify_prompt_deleted(prompt_info) logger.info(f"Deleted prompt: {prompt_info['name']}") - + # Structured logging: Audit trail for prompt deletion audit_trail.log_action( user_id=user_email or "system", @@ -1471,7 +1471,7 @@ async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_emai old_values={"name": prompt_name}, db=db, ) - + # Structured logging: Log successful prompt deletion structured_logger.log( level="INFO", @@ -1487,7 +1487,7 @@ async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_emai ) except PermissionError as pe: db.rollback() - + # Structured logging: Log permission error structured_logger.log( level="WARNING", @@ -1517,7 +1517,7 @@ async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_emai db=db, ) raise e - + # Structured logging: Log generic prompt deletion failure structured_logger.log( level="ERROR", diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 2fb295d16..c17da434f 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -414,7 +414,7 @@ async def register_resource( await self._notify_resource_added(db_resource) logger.info(f"Registered resource: {resource.uri}") - + # Structured logging: Audit trail for resource creation audit_trail.log_action( user_id=created_by or "system", @@ -439,7 +439,7 @@ async def register_resource( }, db=db, ) - + # Structured logging: Log successful resource creation structured_logger.log( level="INFO", @@ -458,12 +458,12 @@ async def register_resource( }, db=db, ) - + db_resource.team = self._get_team_name(db, db_resource.team_id) return self._convert_resource_to_read(db_resource) except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") - + # Structured logging: Log database integrity error structured_logger.log( level="ERROR", @@ -481,7 +481,7 @@ async def register_resource( raise ie except ResourceURIConflictError as rce: logger.error(f"ResourceURIConflictError in group: {resource.uri}") - + # Structured logging: Log URI conflict error structured_logger.log( level="WARNING", @@ -499,7 +499,7 @@ async def register_resource( raise rce except Exception as e: db.rollback() - + # Structured logging: Log generic resource creation failure structured_logger.log( level="ERROR", @@ -1569,7 +1569,7 @@ async def toggle_resource_status(self, db: Session, resource_id: int, activate: await 
self._notify_resource_deactivated(resource) logger.info(f"Resource {resource.uri} {'activated' if activate else 'deactivated'}") - + # Structured logging: Audit trail for resource status toggle audit_trail.log_action( user_id=user_email or "system", @@ -1587,7 +1587,7 @@ async def toggle_resource_status(self, db: Session, resource_id: int, activate: }, db=db, ) - + # Structured logging: Log successful resource status toggle structured_logger.log( level="INFO", @@ -1623,7 +1623,7 @@ async def toggle_resource_status(self, db: Session, resource_id: int, activate: raise e except Exception as e: db.rollback() - + # Structured logging: Log generic resource status toggle failure structured_logger.log( level="ERROR", @@ -1851,7 +1851,7 @@ async def update_resource( await self._notify_resource_updated(resource) logger.info(f"Updated resource: {resource.uri}") - + # Structured logging: Audit trail for resource update changes = [] if resource_update.uri: @@ -1860,7 +1860,7 @@ async def update_resource( changes.append(f"visibility: {resource_update.visibility}") if resource_update.description: changes.append("description updated") - + audit_trail.log_action( user_id=user_email or modified_by or "system", action="update_resource", @@ -1882,7 +1882,7 @@ async def update_resource( }, db=db, ) - + # Structured logging: Log successful resource update structured_logger.log( level="INFO", @@ -1900,11 +1900,11 @@ async def update_resource( }, db=db, ) - + return self._convert_resource_to_read(resource) except PermissionError as pe: db.rollback() - + # Structured logging: Log permission error structured_logger.log( level="WARNING", @@ -1921,7 +1921,7 @@ async def update_resource( except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") - + # Structured logging: Log database integrity error structured_logger.log( level="ERROR", @@ -1938,7 +1938,7 @@ async def update_resource( raise ie except ResourceURIConflictError as pe: logger.error(f"Resource URI conflict: {pe}") - + # Structured logging: Log URI conflict error structured_logger.log( level="WARNING", @@ -1969,7 +1969,7 @@ async def update_resource( db=db, ) raise e - + # Structured logging: Log generic resource update failure structured_logger.log( level="ERROR", @@ -2044,7 +2044,7 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ resource_uri = resource.uri resource_name = resource.name resource_team_id = resource.team_id - + db.delete(resource) db.commit() @@ -2052,7 +2052,7 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ await self._notify_resource_deleted(resource_info) logger.info(f"Permanently deleted resource: {resource.uri}") - + # Structured logging: Audit trail for resource deletion audit_trail.log_action( user_id=user_email or "system", @@ -2068,7 +2068,7 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ }, db=db, ) - + # Structured logging: Log successful resource deletion structured_logger.log( level="INFO", @@ -2087,7 +2087,7 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ except PermissionError as pe: db.rollback() - + # Structured logging: Log permission error structured_logger.log( level="WARNING", @@ -2118,7 +2118,7 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ raise except Exception as e: db.rollback() - + # Structured logging: Log generic resource deletion failure structured_logger.log( level="ERROR", diff --git 
a/mcpgateway/services/security_logger.py b/mcpgateway/services/security_logger.py index 2f846b44a..e4965fc20 100644 --- a/mcpgateway/services/security_logger.py +++ b/mcpgateway/services/security_logger.py @@ -22,12 +22,14 @@ # First-Party from mcpgateway.db import SecurityEvent, AuditTrail, SessionLocal from mcpgateway.utils.correlation_id import get_correlation_id +from mcpgateway.config import settings logger = logging.getLogger(__name__) class SecuritySeverity(str, Enum): """Security event severity levels.""" + LOW = "LOW" MEDIUM = "MEDIUM" HIGH = "HIGH" @@ -36,6 +38,7 @@ class SecuritySeverity(str, Enum): class SecurityEventType(str, Enum): """Types of security events.""" + AUTHENTICATION_FAILURE = "authentication_failure" AUTHENTICATION_SUCCESS = "authentication_success" AUTHORIZATION_FAILURE = "authorization_failure" @@ -51,17 +54,17 @@ class SecurityEventType(str, Enum): class SecurityLogger: """Specialized logger for security events and audit trails. - + Provides threat detection, security event logging, and audit trail management with automated analysis and alerting capabilities. """ - + def __init__(self): """Initialize security logger.""" self.failed_auth_threshold = getattr(settings, "security_failed_auth_threshold", 5) self.threat_score_alert_threshold = getattr(settings, "security_threat_score_alert", 0.7) self.rate_limit_window_minutes = getattr(settings, "security_rate_limit_window", 5) - + def log_authentication_attempt( self, user_id: str, @@ -72,10 +75,10 @@ def log_authentication_attempt( user_agent: Optional[str] = None, failure_reason: Optional[str] = None, additional_context: Optional[Dict[str, Any]] = None, - db: Optional[Session] = None + db: Optional[Session] = None, ) -> Optional[SecurityEvent]: """Log authentication attempts with security analysis. 
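Call-site sketch for the authentication hook; all values are hypothetical, and `get_security_logger` is the factory defined at the bottom of this file:

from mcpgateway.services.security_logger import get_security_logger

security_logger = get_security_logger()
event = security_logger.log_authentication_attempt(
    user_id="u-123",                  # hypothetical identifiers
    user_email="user@example.com",
    success=False,
    auth_method="jwt",
    client_ip="203.0.113.7",
    failure_reason="expired token",
)
# On failure the returned SecurityEvent carries the computed severity and
# threat score; None means persistence was skipped or failed.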
- + Args: user_id: User identifier user_email: User email address @@ -86,26 +89,18 @@ def log_authentication_attempt( failure_reason: Reason for failure if applicable additional_context: Additional event context db: Optional database session - + Returns: Created SecurityEvent or None if logging disabled """ correlation_id = get_correlation_id() - + # Count recent failed attempts - failed_attempts = self._count_recent_failures( - user_id=user_id, - client_ip=client_ip, - db=db - ) - + failed_attempts = self._count_recent_failures(user_id=user_id, client_ip=client_ip, db=db) + # Calculate threat score - threat_score = self._calculate_auth_threat_score( - success=success, - failed_attempts=failed_attempts, - auth_method=auth_method - ) - + threat_score = self._calculate_auth_threat_score(success=success, failed_attempts=failed_attempts, auth_method=auth_method) + # Determine severity if not success: if failed_attempts >= self.failed_auth_threshold: @@ -116,20 +111,15 @@ def log_authentication_attempt( severity = SecuritySeverity.LOW else: severity = SecuritySeverity.LOW - + # Build event description description = f"Authentication {'successful' if success else 'failed'} for user {user_id}" if not success and failure_reason: description += f": {failure_reason}" - + # Build context - context = { - "auth_method": auth_method, - "failed_attempts_recent": failed_attempts, - "user_agent": user_agent, - **(additional_context or {}) - } - + context = {"auth_method": auth_method, "failed_attempts_recent": failed_attempts, "user_agent": user_agent, **(additional_context or {})} + # Create security event event = self._create_security_event( event_type=SecurityEventType.AUTHENTICATION_SUCCESS if success else SecurityEventType.AUTHENTICATION_FAILURE, @@ -145,9 +135,9 @@ def log_authentication_attempt( context=context, action_taken="allowed" if success else "denied", correlation_id=correlation_id, - db=db + db=db, ) - + # Log to standard logger as well log_level = logging.WARNING if not success else logging.INFO logger.log( @@ -159,11 +149,11 @@ def log_authentication_attempt( "severity": severity.value, "threat_score": threat_score, "correlation_id": correlation_id, - } + }, ) - + return event - + def log_data_access( self, action: str, @@ -181,10 +171,10 @@ def log_data_access( new_values: Optional[Dict[str, Any]] = None, error_message: Optional[str] = None, additional_context: Optional[Dict[str, Any]] = None, - db: Optional[Session] = None + db: Optional[Session] = None, ) -> Optional[AuditTrail]: """Log data access for audit trails. 
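The scoring used above is the step function `_calculate_auth_threat_score` defined later in this file; an illustrative standalone copy with one worked value:

def auth_threat_score(success: bool, failed_attempts: int) -> float:
    """Mirror of SecurityLogger's step function (illustrative copy)."""
    if success:
        return 0.0
    score = 0.3                # base score for any failure
    if failed_attempts >= 10:
        score += 0.5
    elif failed_attempts >= 5:
        score += 0.3
    elif failed_attempts >= 3:
        score += 0.2
    return min(score, 1.0)

# Six recent failures score 0.3 + 0.3 = 0.6, just under the default
# security_threat_score_alert threshold of 0.7.
assert auth_threat_score(False, 6) == 0.6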
- + Args: action: Action performed (create, read, update, delete, execute) resource_type: Type of resource accessed @@ -202,29 +192,20 @@ def log_data_access( error_message: Error message if failed additional_context: Additional context db: Optional database session - + Returns: Created AuditTrail entry or None """ correlation_id = get_correlation_id() - + # Determine if audit requires review - requires_review = self._requires_audit_review( - action=action, - resource_type=resource_type, - data_classification=data_classification, - success=success - ) - + requires_review = self._requires_audit_review(action=action, resource_type=resource_type, data_classification=data_classification, success=success) + # Calculate changes changes = None if old_values and new_values: - changes = { - k: {"old": old_values.get(k), "new": new_values.get(k)} - for k in set(old_values.keys()) | set(new_values.keys()) - if old_values.get(k) != new_values.get(k) - } - + changes = {k: {"old": old_values.get(k), "new": new_values.get(k)} for k in set(old_values.keys()) | set(new_values.keys()) if old_values.get(k) != new_values.get(k)} + # Create audit trail audit = self._create_audit_trail( action=action, @@ -245,9 +226,9 @@ def log_data_access( error_message=error_message, context=additional_context, correlation_id=correlation_id, - db=db + db=db, ) - + # Log sensitive data access as security event if data_classification in ["confidential", "restricted", "sensitive"]: self._create_security_event( @@ -267,11 +248,11 @@ def log_data_access( "data_classification": data_classification, }, correlation_id=correlation_id, - db=db + db=db, ) - + return audit - + def log_suspicious_activity( self, activity_type: str, @@ -285,10 +266,10 @@ def log_suspicious_activity( threat_indicators: Dict[str, Any], action_taken: str, additional_context: Optional[Dict[str, Any]] = None, - db: Optional[Session] = None + db: Optional[Session] = None, ) -> Optional[SecurityEvent]: """Log suspicious activity with threat analysis. - + Args: activity_type: Type of suspicious activity description: Event description @@ -302,12 +283,12 @@ def log_suspicious_activity( action_taken: Action taken in response additional_context: Additional context db: Optional database session - + Returns: Created SecurityEvent or None """ correlation_id = get_correlation_id() - + event = self._create_security_event( event_type=SecurityEventType.SUSPICIOUS_ACTIVITY, severity=severity, @@ -322,9 +303,9 @@ def log_suspicious_activity( action_taken=action_taken, context=additional_context, correlation_id=correlation_id, - db=db + db=db, ) - + logger.warning( f"Suspicious activity detected: {description}", extra={ @@ -334,80 +315,66 @@ def log_suspicious_activity( "threat_score": threat_score, "action_taken": action_taken, "correlation_id": correlation_id, - } + }, ) - + return event - - def _count_recent_failures( - self, - user_id: Optional[str] = None, - client_ip: Optional[str] = None, - minutes: Optional[int] = None, - db: Optional[Session] = None - ) -> int: + + def _count_recent_failures(self, user_id: Optional[str] = None, client_ip: Optional[str] = None, minutes: Optional[int] = None, db: Optional[Session] = None) -> int: """Count recent authentication failures. 
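The audit `changes` payload above is a single dict comprehension over the union of keys; a worked example with hypothetical field values:

old_values = {"name": "weather_tool", "visibility": "private", "team_id": "t1"}
new_values = {"name": "weather_tool", "visibility": "public", "team_id": "t1"}

changes = {
    k: {"old": old_values.get(k), "new": new_values.get(k)}
    for k in set(old_values) | set(new_values)
    if old_values.get(k) != new_values.get(k)
}
# Only fields whose value actually changed are recorded:
assert changes == {"visibility": {"old": "private", "new": "public"}}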
- + Args: user_id: User identifier client_ip: Client IP address minutes: Time window in minutes db: Optional database session - + Returns: Count of recent failures """ if not user_id and not client_ip: return 0 - + window_minutes = minutes or self.rate_limit_window_minutes since = datetime.now(timezone.utc) - timedelta(minutes=window_minutes) - + should_close = False if db is None: db = SessionLocal() should_close = True - + try: - stmt = select(func.count(SecurityEvent.id)).where( - SecurityEvent.event_type == SecurityEventType.AUTHENTICATION_FAILURE, - SecurityEvent.timestamp >= since - ) - + stmt = select(func.count(SecurityEvent.id)).where(SecurityEvent.event_type == SecurityEventType.AUTHENTICATION_FAILURE, SecurityEvent.timestamp >= since) + if user_id: stmt = stmt.where(SecurityEvent.user_id == user_id) if client_ip: stmt = stmt.where(SecurityEvent.client_ip == client_ip) - + result = db.execute(stmt).scalar() return result or 0 - + finally: if should_close: db.close() - - def _calculate_auth_threat_score( - self, - success: bool, - failed_attempts: int, - auth_method: str - ) -> float: + + def _calculate_auth_threat_score(self, success: bool, failed_attempts: int, auth_method: str) -> float: """Calculate threat score for authentication attempt. - + Args: success: Whether authentication succeeded failed_attempts: Count of recent failures auth_method: Authentication method used - + Returns: Threat score from 0.0 to 1.0 """ if success: return 0.0 - + # Base score for failure score = 0.3 - + # Increase based on failed attempts if failed_attempts >= 10: score += 0.5 @@ -415,42 +382,36 @@ def _calculate_auth_threat_score( score += 0.3 elif failed_attempts >= 3: score += 0.2 - + # Cap at 1.0 return min(score, 1.0) - - def _requires_audit_review( - self, - action: str, - resource_type: str, - data_classification: Optional[str], - success: bool - ) -> bool: + + def _requires_audit_review(self, action: str, resource_type: str, data_classification: Optional[str], success: bool) -> bool: """Determine if audit entry requires manual review. - + Args: action: Action performed resource_type: Resource type data_classification: Data classification success: Whether action succeeded - + Returns: True if review required """ # Failed actions on sensitive data require review if not success and data_classification in ["confidential", "restricted"]: return True - + # Deletions of sensitive data require review if action == "delete" and data_classification in ["confidential", "restricted"]: return True - + # Privilege modifications require review if resource_type in ["role", "permission", "team_member"]: return True - + return False - + def _create_security_event( self, event_type: str, @@ -467,10 +428,10 @@ def _create_security_event( threat_indicators: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None, correlation_id: Optional[str] = None, - db: Optional[Session] = None + db: Optional[Session] = None, ) -> Optional[SecurityEvent]: """Create a security event record. 
- + Args: event_type: Type of security event severity: Event severity @@ -487,7 +448,7 @@ def _create_security_event( context: Additional context correlation_id: Correlation ID db: Optional database session - + Returns: Created SecurityEvent or None """ @@ -495,7 +456,7 @@ def _create_security_event( if db is None: db = SessionLocal() should_close = True - + try: event = SecurityEvent( event_type=event_type, @@ -513,22 +474,22 @@ def _create_security_event( context=context, correlation_id=correlation_id, ) - + db.add(event) db.commit() db.refresh(event) - + return event - + except Exception as e: logger.error(f"Failed to create security event: {e}") db.rollback() return None - + finally: if should_close: db.close() - + def _create_audit_trail( self, action: str, @@ -549,10 +510,10 @@ def _create_audit_trail( error_message: Optional[str] = None, context: Optional[Dict[str, Any]] = None, correlation_id: Optional[str] = None, - db: Optional[Session] = None + db: Optional[Session] = None, ) -> Optional[AuditTrail]: """Create an audit trail record. - + Args: action: Action performed resource_type: Resource type @@ -573,7 +534,7 @@ def _create_audit_trail( context: Additional context correlation_id: Correlation ID db: Optional database session - + Returns: Created AuditTrail or None """ @@ -581,7 +542,7 @@ def _create_audit_trail( if db is None: db = SessionLocal() should_close = True - + try: audit = AuditTrail( action=action, @@ -603,18 +564,18 @@ def _create_audit_trail( context=context, correlation_id=correlation_id, ) - + db.add(audit) db.commit() db.refresh(audit) - + return audit - + except Exception as e: logger.error(f"Failed to create audit trail: {e}") db.rollback() return None - + finally: if should_close: db.close() @@ -626,7 +587,7 @@ def _create_audit_trail( def get_security_logger() -> SecurityLogger: """Get or create the global security logger instance. 
- + Returns: Global SecurityLogger instance """ @@ -634,7 +595,3 @@ def get_security_logger() -> SecurityLogger: if _security_logger is None: _security_logger = SecurityLogger() return _security_logger - - -# Import settings here to avoid circular imports -from mcpgateway.config import settings diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index 8ec5147ef..b10621725 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -555,7 +555,7 @@ async def register_server( logger.debug(f"Server Data: {server_data}") await self._notify_server_added(db_server) logger.info(f"Registered server: {server_in.name}") - + # Structured logging: Audit trail for server creation self._audit_trail.log_action( user_id=created_by or "system", @@ -577,7 +577,7 @@ async def register_server( "created_user_agent": created_user_agent, }, ) - + # Structured logging: Log successful server creation self._structured_logger.log( level="INFO", @@ -590,13 +590,13 @@ async def register_server( created_by=created_by, user_email=created_by, ) - + db_server.team = self._get_team_name(db, db_server.team_id) return self._convert_server_to_read(db_server) except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") - + # Structured logging: Log database integrity error self._structured_logger.log( level="ERROR", @@ -612,7 +612,7 @@ async def register_server( raise ie except ServerNameConflictError as se: db.rollback() - + # Structured logging: Log name conflict error self._structured_logger.log( level="WARNING", @@ -627,7 +627,7 @@ async def register_server( raise se except Exception as ex: db.rollback() - + # Structured logging: Log generic server creation failure self._structured_logger.log( level="ERROR", @@ -1047,7 +1047,7 @@ async def update_server( changes.append(f"visibility: {server_update.visibility}") if server_update.team_id: changes.append(f"team_id: {server_update.team_id}") - + self._audit_trail.log_action( user_id=user_email or "system", action="update_server", @@ -1064,7 +1064,7 @@ async def update_server( "modified_user_agent": modified_user_agent, }, ) - + # Structured logging: Log successful server update self._structured_logger.log( level="INFO", @@ -1096,7 +1096,7 @@ async def update_server( except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") - + # Structured logging: Log database integrity error self._structured_logger.log( level="ERROR", @@ -1113,7 +1113,7 @@ async def update_server( except ServerNameConflictError as snce: db.rollback() logger.error(f"Server name conflict: {snce}") - + # Structured logging: Log name conflict error self._structured_logger.log( level="WARNING", @@ -1127,7 +1127,7 @@ async def update_server( raise snce except Exception as e: db.rollback() - + # Structured logging: Log generic server update failure self._structured_logger.log( level="ERROR", @@ -1200,7 +1200,7 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool else: await self._notify_server_deactivated(server) logger.info(f"Server {server.name} {'activated' if activate else 'deactivated'}") - + # Structured logging: Audit trail for server status toggle self._audit_trail.log_action( user_id=user_email or "system", @@ -1212,7 +1212,7 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool "new_status": "active" if activate else "inactive", }, ) - + # Structured logging: Log server status change 
self._structured_logger.log( level="INFO", @@ -1254,7 +1254,7 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool raise e except Exception as e: db.rollback() - + # Structured logging: Log generic server status toggle failure self._structured_logger.log( level="ERROR", @@ -1314,7 +1314,7 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ await self._notify_server_deleted(server_info) logger.info(f"Deleted server: {server_info['name']}") - + # Structured logging: Audit trail for server deletion self._audit_trail.log_action( user_id=user_email or "system", @@ -1325,7 +1325,7 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ "server_name": server_info["name"], }, ) - + # Structured logging: Log successful server deletion self._structured_logger.log( level="INFO", @@ -1339,7 +1339,7 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ ) except PermissionError as pe: db.rollback() - + # Structured logging: Log permission error self._structured_logger.log( level="WARNING", @@ -1352,7 +1352,7 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ raise pe except Exception as e: db.rollback() - + # Structured logging: Log generic server deletion failure self._structured_logger.log( level="ERROR", diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py index 92814c253..99d42a967 100644 --- a/mcpgateway/services/structured_logger.py +++ b/mcpgateway/services/structured_logger.py @@ -15,7 +15,6 @@ import logging import os import socket -import sys import traceback from typing import Any, Dict, List, Optional, Union @@ -33,6 +32,7 @@ class LogLevel(str, Enum): """Log levels matching Python logging.""" + DEBUG = "DEBUG" INFO = "INFO" WARNING = "WARNING" @@ -42,6 +42,7 @@ class LogLevel(str, Enum): class LogCategory(str, Enum): """Log categories for classification.""" + APPLICATION = "application" REQUEST = "request" SECURITY = "security" @@ -56,14 +57,14 @@ class LogCategory(str, Enum): class LogEnricher: """Enriches log entries with contextual information.""" - + @staticmethod def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: """Enrich log entry with system and context information. 
- + Args: entry: Base log entry - + Returns: Enriched log entry """ @@ -71,29 +72,30 @@ def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: correlation_id = get_correlation_id() if correlation_id: entry["correlation_id"] = correlation_id - + # Add hostname and process info entry.setdefault("hostname", socket.gethostname()) entry.setdefault("process_id", os.getpid()) - + # Add timestamp if not present if "timestamp" not in entry: entry["timestamp"] = datetime.now(timezone.utc) - + # Add performance metrics if available try: perf_tracker = get_performance_tracker() - if correlation_id and perf_tracker and hasattr(perf_tracker, 'get_current_operations'): + if correlation_id and perf_tracker and hasattr(perf_tracker, "get_current_operations"): current_ops = perf_tracker.get_current_operations(correlation_id) if current_ops: entry["active_operations"] = len(current_ops) except Exception: # Silently skip if performance tracker is unavailable or method doesn't exist pass - + # Add OpenTelemetry trace context if available try: from opentelemetry import trace + span = trace.get_current_span() if span and span.get_span_context().is_valid: ctx = span.get_span_context() @@ -101,61 +103,58 @@ def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: entry["span_id"] = format(ctx.span_id, "016x") except (ImportError, Exception): pass - + return entry class LogRouter: """Routes log entries to appropriate destinations.""" - + def __init__(self): """Initialize log router.""" self.database_enabled = getattr(settings, "structured_logging_database_enabled", True) self.external_enabled = getattr(settings, "structured_logging_external_enabled", False) - + def route(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None: """Route log entry to configured destinations. - + Args: entry: Log entry to route db: Optional database session """ # Always log to standard Python logger self._log_to_python_logger(entry) - + # Persist to database if enabled if self.database_enabled: self._persist_to_database(entry, db) - + # Send to external systems if enabled if self.external_enabled: self._send_to_external(entry) - + def _log_to_python_logger(self, entry: Dict[str, Any]) -> None: """Log to standard Python logger. - + Args: entry: Log entry """ level_str = entry.get("level", "INFO") level = getattr(logging, level_str, logging.INFO) - + message = entry.get("message", "") component = entry.get("component", "") - + log_message = f"[{component}] {message}" if component else message - + # Build extra dict for structured logging - extra = { - k: v for k, v in entry.items() - if k not in ["message", "level"] - } - + extra = {k: v for k, v in entry.items() if k not in ["message", "level"]} + logger.log(level, log_message, extra=extra) - + def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None: """Persist log entry to database. 
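What enrichment yields for a bare entry; hostname and PID are environment-specific, and correlation/trace IDs only appear when those contexts are active:

from mcpgateway.services.structured_logger import LogEnricher

entry = LogEnricher.enrich({"level": "INFO", "message": "tool invoked"})
# setdefault-style enrichment adds hostname, process_id and a timestamp;
# correlation_id / trace_id / span_id join them when middleware or an
# OpenTelemetry span has set the corresponding context.
assert {"hostname", "process_id", "timestamp"} <= set(entry)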
- + Args: entry: Log entry db: Optional database session @@ -164,7 +163,7 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No if db is None: db = SessionLocal() should_close = True - + try: # Build error_details JSON from error-related fields error_details = None @@ -175,7 +174,7 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No "error_stack_trace": entry.get("error_stack_trace"), "error_context": entry.get("error_context"), } - + # Build performance_metrics JSON from performance-related fields performance_metrics = None perf_fields = { @@ -190,7 +189,7 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No } if any(v is not None for v in perf_fields.values()): performance_metrics = {k: v for k, v in perf_fields.items() if v is not None} - + # Build threat_indicators JSON from security-related fields threat_indicators = None security_fields = { @@ -200,7 +199,7 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No } if any(v is not None for v in security_fields.values()): threat_indicators = {k: v for k, v in security_fields.items() if v is not None} - + # Build context JSON from remaining fields context_fields = { "team_id": entry.get("team_id"), @@ -222,11 +221,11 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No "metadata": entry.get("metadata"), } context = {k: v for k, v in context_fields.items() if v is not None} - + # Determine if this is a security event is_security_event = entry.get("is_security_event", False) or bool(threat_indicators) security_severity = entry.get("security_severity") - + log_entry = StructuredLogEntry( timestamp=entry.get("timestamp", datetime.now(timezone.utc)), level=entry.get("level", "INFO"), @@ -256,47 +255,47 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No environment=entry.get("environment", getattr(settings, "environment", "development")), version=entry.get("version", getattr(settings, "version", "unknown")), ) - + db.add(log_entry) db.commit() - + except Exception as e: logger.error(f"Failed to persist log entry to database: {e}", exc_info=True) # Also print to console for immediate visibility import traceback + print(f"ERROR persisting log to database: {e}") traceback.print_exc() if db: db.rollback() - + finally: if should_close: db.close() - + def _send_to_external(self, entry: Dict[str, Any]) -> None: """Send log entry to external systems. - + Args: entry: Log entry """ # Placeholder for external logging integration # Will be implemented in log exporters - pass class StructuredLogger: """Main structured logger with enrichment and routing.""" - + def __init__(self, component: str): """Initialize structured logger. - + Args: component: Component name for log entries """ self.component = component self.enricher = LogEnricher() self.router = LogRouter() - + def log( self, level: Union[LogLevel, str], @@ -310,10 +309,10 @@ def log( custom_fields: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, db: Optional[Session] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """Log a structured message. 
- + Args: level: Log level message: Log message @@ -341,38 +340,38 @@ def log( "custom_fields": custom_fields, "tags": tags, } - + # Add error information if present if error: entry["error_type"] = type(error).__name__ entry["error_message"] = str(error) entry["error_stack_trace"] = "".join(traceback.format_exception(type(error), error, error.__traceback__)) - + # Add any additional kwargs entry.update(kwargs) - + # Enrich entry with context entry = self.enricher.enrich(entry) - + # Route to destinations self.router.route(entry, db) - + def debug(self, message: str, **kwargs: Any) -> None: """Log debug message.""" self.log(LogLevel.DEBUG, message, **kwargs) - + def info(self, message: str, **kwargs: Any) -> None: """Log info message.""" self.log(LogLevel.INFO, message, **kwargs) - + def warning(self, message: str, **kwargs: Any) -> None: """Log warning message.""" self.log(LogLevel.WARNING, message, **kwargs) - + def error(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: """Log error message.""" self.log(LogLevel.ERROR, message, error=error, **kwargs) - + def critical(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: """Log critical message.""" self.log(LogLevel.CRITICAL, message, error=error, **kwargs) @@ -380,23 +379,23 @@ def critical(self, message: str, error: Optional[Exception] = None, **kwargs: An class ComponentLogger: """Logger factory for component-specific loggers.""" - + _loggers: Dict[str, StructuredLogger] = {} - + @classmethod def get_logger(cls, component: str) -> StructuredLogger: """Get or create a logger for a specific component. - + Args: component: Component name - + Returns: StructuredLogger instance for the component """ if component not in cls._loggers: cls._loggers[component] = StructuredLogger(component) return cls._loggers[component] - + @classmethod def clear_loggers(cls) -> None: """Clear all cached loggers (useful for testing).""" @@ -406,10 +405,10 @@ def clear_loggers(cls) -> None: # Global structured logger instance for backward compatibility def get_structured_logger(component: str = "mcpgateway") -> StructuredLogger: """Get a structured logger instance. 
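Typical call-site usage of the cached factory; the component name, message, and field values are illustrative:

from mcpgateway.services.structured_logger import get_structured_logger

structured_logger = get_structured_logger("gateway_service")

# Convenience wrappers pre-fill the level; extra keyword arguments flow
# into the structured entry via **kwargs.
structured_logger.info(
    "Gateway registered successfully",
    entity_type="gateway",
    operation_type="create",
    user_email="admin@example.com",
)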
- + Args: component: Component name - + Returns: StructuredLogger instance """ diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index b63ff4b60..b186b0bb3 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -69,7 +69,7 @@ from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.performance_tracker import get_performance_tracker -from mcpgateway.services.structured_logger import get_structured_logger, LogCategory +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.display_name import generate_display_name @@ -719,7 +719,7 @@ async def register_tool( db.commit() db.refresh(db_tool) await self._notify_tool_added(db_tool) - + # Structured logging: Audit trail for tool creation audit_trail.log_action( user_id=created_by or "system", @@ -744,7 +744,7 @@ async def register_tool( }, db=db, ) - + # Structured logging: Log successful tool creation structured_logger.log( level="INFO", @@ -763,12 +763,12 @@ async def register_tool( }, db=db, ) - + return self._convert_tool_to_read(db_tool) except IntegrityError as ie: db.rollback() logger.error(f"IntegrityError during tool registration: {ie}") - + # Structured logging: Log database integrity error structured_logger.log( level="ERROR", @@ -787,7 +787,7 @@ async def register_tool( except ToolNameConflictError as tnce: db.rollback() logger.error(f"ToolNameConflictError during tool registration: {tnce}") - + # Structured logging: Log name conflict error structured_logger.log( level="WARNING", @@ -805,7 +805,7 @@ async def register_tool( raise tnce except Exception as e: db.rollback() - + # Structured logging: Log generic tool creation failure structured_logger.log( level="ERROR", @@ -1172,12 +1172,12 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] tool_info = {"id": tool.id, "name": tool.name} tool_name = tool.name tool_team_id = tool.team_id - + db.delete(tool) db.commit() await self._notify_tool_deleted(tool_info) logger.info(f"Permanently deleted tool: {tool_info['name']}") - + # Structured logging: Audit trail for tool deletion audit_trail.log_action( user_id=user_email or "system", @@ -1192,7 +1192,7 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] }, db=db, ) - + # Structured logging: Log successful tool deletion structured_logger.log( level="INFO", @@ -1210,7 +1210,7 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] ) except PermissionError as pe: db.rollback() - + # Structured logging: Log permission error structured_logger.log( level="WARNING", @@ -1226,7 +1226,7 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] raise except Exception as e: db.rollback() - + # Structured logging: Log generic tool deletion failure structured_logger.log( level="ERROR", @@ -1317,7 +1317,7 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re await self._notify_tool_activated(tool) logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}") - + # Structured logging: Audit trail for tool status toggle audit_trail.log_action( user_id=user_email or "system", @@ -1336,7 +1336,7 @@ async def 
toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re }, db=db, ) - + # Structured logging: Log successful tool status toggle structured_logger.log( level="INFO", @@ -1354,7 +1354,7 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re }, db=db, ) - + return self._convert_tool_to_read(tool) except PermissionError as e: # Structured logging: Log permission error @@ -1372,7 +1372,7 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re raise e except Exception as e: db.rollback() - + # Structured logging: Log generic tool status toggle failure structured_logger.log( level="ERROR", @@ -1689,11 +1689,11 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): """ # Get correlation ID for distributed tracing correlation_id = get_correlation_id() - + # Add correlation ID to headers if correlation_id and headers: headers["X-Correlation-ID"] = correlation_id - + # Log MCP call start mcp_start_time = time.time() structured_logger.log( @@ -1701,21 +1701,15 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): message=f"MCP tool call started: {tool.original_name}", component="tool_service", correlation_id=correlation_id, - metadata={ - "event": "mcp_call_started", - "tool_name": tool.original_name, - "tool_id": tool.id, - "server_url": server_url, - "transport": "sse" - } + metadata={"event": "mcp_call_started", "tool_name": tool.original_name, "tool_id": tool.id, "server_url": server_url, "transport": "sse"}, ) - + try: async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams: async with ClientSession(*streams) as session: await session.initialize() tool_call_result = await session.call_tool(tool.original_name, arguments) - + # Log successful MCP call mcp_duration_ms = (time.time() - mcp_start_time) * 1000 structured_logger.log( @@ -1724,15 +1718,9 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): component="tool_service", correlation_id=correlation_id, duration_ms=mcp_duration_ms, - metadata={ - "event": "mcp_call_completed", - "tool_name": tool.original_name, - "tool_id": tool.id, - "transport": "sse", - "success": True - } + metadata={"event": "mcp_call_completed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "sse", "success": True}, ) - + return tool_call_result except Exception as e: # Log failed MCP call @@ -1743,16 +1731,8 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): component="tool_service", correlation_id=correlation_id, duration_ms=mcp_duration_ms, - error_details={ - "error_type": type(e).__name__, - "error_message": str(e) - }, - metadata={ - "event": "mcp_call_failed", - "tool_name": tool.original_name, - "tool_id": tool.id, - "transport": "sse" - } + error_details={"error_type": type(e).__name__, "error_message": str(e)}, + metadata={"event": "mcp_call_failed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "sse"}, ) raise @@ -1768,11 +1748,11 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head """ # Get correlation ID for distributed tracing correlation_id = get_correlation_id() - + # Add correlation ID to headers if correlation_id and headers: headers["X-Correlation-ID"] = correlation_id - + # Log MCP call start mcp_start_time = time.time() structured_logger.log( @@ -1780,21 +1760,15 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head message=f"MCP 
tool call started: {tool.original_name}", component="tool_service", correlation_id=correlation_id, - metadata={ - "event": "mcp_call_started", - "tool_name": tool.original_name, - "tool_id": tool.id, - "server_url": server_url, - "transport": "streamablehttp" - } + metadata={"event": "mcp_call_started", "tool_name": tool.original_name, "tool_id": tool.id, "server_url": server_url, "transport": "streamablehttp"}, ) - + try: async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() tool_call_result = await session.call_tool(tool.original_name, arguments) - + # Log successful MCP call mcp_duration_ms = (time.time() - mcp_start_time) * 1000 structured_logger.log( @@ -1803,15 +1777,9 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head component="tool_service", correlation_id=correlation_id, duration_ms=mcp_duration_ms, - metadata={ - "event": "mcp_call_completed", - "tool_name": tool.original_name, - "tool_id": tool.id, - "transport": "streamablehttp", - "success": True - } + metadata={"event": "mcp_call_completed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "streamablehttp", "success": True}, ) - + return tool_call_result except Exception as e: # Log failed MCP call @@ -1822,16 +1790,8 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head component="tool_service", correlation_id=correlation_id, duration_ms=mcp_duration_ms, - error_details={ - "error_type": type(e).__name__, - "error_message": str(e) - }, - metadata={ - "event": "mcp_call_failed", - "tool_name": tool.original_name, - "tool_id": tool.id, - "transport": "streamablehttp" - } + error_details={"error_type": type(e).__name__, "error_message": str(e)}, + metadata={"event": "mcp_call_failed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "streamablehttp"}, ) raise @@ -1915,15 +1875,15 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head finally: # Calculate duration duration_ms = (time.monotonic() - start_time) * 1000 - + # Add final span attributes if span: span.set_attribute("success", success) span.set_attribute("duration.ms", duration_ms) - + # Record tool metric await self._record_tool_metric(db, tool, start_time, success, error_message) - + # Log structured message with performance tracking if success: structured_logger.info( @@ -1933,11 +1893,7 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head resource_id=str(tool.id), resource_action="invoke", duration_ms=duration_ms, - custom_fields={ - "tool_name": name, - "integration_type": tool.integration_type, - "arguments_count": len(arguments) if arguments else 0 - } + custom_fields={"tool_name": name, "integration_type": tool.integration_type, "arguments_count": len(arguments) if arguments else 0}, ) else: structured_logger.error( @@ -1948,13 +1904,9 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head resource_id=str(tool.id), resource_action="invoke", duration_ms=duration_ms, - custom_fields={ - "tool_name": name, - "integration_type": tool.integration_type, - "error_message": error_message - } + custom_fields={"tool_name": name, "integration_type": tool.integration_type, "error_message": error_message}, ) - + # Track performance with threshold checking with 
perf_tracker.track_operation("tool_invocation", name): pass # Duration already captured above @@ -2103,7 +2055,7 @@ async def update_tool( db.refresh(tool) await self._notify_tool_updated(tool) logger.info(f"Updated tool: {tool.name}") - + # Structured logging: Audit trail for tool update changes = [] if tool_update.name: @@ -2111,8 +2063,8 @@ async def update_tool( if tool_update.visibility: changes.append(f"visibility: {tool_update.visibility}") if tool_update.description: - changes.append(f"description updated") - + changes.append("description updated") + audit_trail.log_action( user_id=user_email or modified_by or "system", action="update_tool", @@ -2134,7 +2086,7 @@ async def update_tool( }, db=db, ) - + # Structured logging: Log successful tool update structured_logger.log( level="INFO", @@ -2152,11 +2104,11 @@ async def update_tool( }, db=db, ) - + return self._convert_tool_to_read(tool) except PermissionError as pe: db.rollback() - + # Structured logging: Log permission error structured_logger.log( level="WARNING", @@ -2173,7 +2125,7 @@ async def update_tool( except IntegrityError as ie: db.rollback() logger.error(f"IntegrityError during tool update: {ie}") - + # Structured logging: Log database integrity error structured_logger.log( level="ERROR", @@ -2191,7 +2143,7 @@ async def update_tool( except ToolNotFoundError as tnfe: db.rollback() logger.error(f"Tool not found during update: {tnfe}") - + # Structured logging: Log not found error structured_logger.log( level="ERROR", @@ -2208,7 +2160,7 @@ async def update_tool( except ToolNameConflictError as tnce: db.rollback() logger.error(f"Tool name conflict during update: {tnce}") - + # Structured logging: Log name conflict error structured_logger.log( level="WARNING", @@ -2225,7 +2177,7 @@ async def update_tool( raise tnce except Exception as ex: db.rollback() - + # Structured logging: Log generic tool update failure structured_logger.log( level="ERROR", diff --git a/mcpgateway/utils/correlation_id.py b/mcpgateway/utils/correlation_id.py index 2ff58f2bc..6701405e3 100644 --- a/mcpgateway/utils/correlation_id.py +++ b/mcpgateway/utils/correlation_id.py @@ -29,7 +29,7 @@ # Context variable for storing correlation ID (request ID) per-request # This is async-safe and provides automatic isolation between concurrent requests -_correlation_id_context: ContextVar[Optional[str]] = ContextVar('correlation_id', default=None) +_correlation_id_context: ContextVar[Optional[str]] = ContextVar("correlation_id", default=None) def get_correlation_id() -> Optional[str]: @@ -168,7 +168,7 @@ def validate_correlation_id(correlation_id: Optional[str], max_length: int = 255 return False # Allow alphanumeric, hyphens, and underscores only - if not all(c.isalnum() or c in ('-', '_') for c in correlation_id): + if not all(c.isalnum() or c in ("-", "_") for c in correlation_id): logger.warning(f"Correlation ID contains invalid characters: {correlation_id}") return False From 627610e3efc1cc5fbeb61f51ad55cf00cd5895e1 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Wed, 26 Nov 2025 18:33:09 +0530 Subject: [PATCH 16/34] test fixes Signed-off-by: Shoumi --- .../middleware/request_logging_middleware.py | 7 ++--- .../framework/external/mcp/tls_utils.py | 2 +- mcpgateway/routers/log_search.py | 8 +++++ mcpgateway/services/a2a_service.py | 3 +- mcpgateway/services/audit_trail_service.py | 4 +-- mcpgateway/services/log_aggregator.py | 2 +- mcpgateway/services/performance_tracker.py | 12 +++---- mcpgateway/services/security_logger.py | 10 +++--- 
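The request_logging_middleware change listed in this patch replaces an elif chain with early returns; the bucketing logic is small enough to show whole. A sketch with the thresholds taken from the diff below:

# Response-time bucketing as patched in request_logging_middleware (sketch).
def categorize_response_time(duration_ms: float) -> str:
    if duration_ms < 100:
        return "fast"
    if duration_ms < 500:
        return "normal"
    if duration_ms < 2000:
        return "slow"
    return "very_slow"

assert categorize_response_time(42.0) == "fast"
assert categorize_response_time(1500.0) == "slow"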
mcpgateway/services/structured_logger.py | 6 ++-- mcpgateway/static/admin.js | 31 ------------------- mcpgateway/translate_grpc.py | 2 +- mcpgateway/utils/retry_manager.py | 2 +- 12 files changed, 31 insertions(+), 58 deletions(-) diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index d16e276a8..499effb74 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -393,9 +393,8 @@ def _categorize_response_time(duration_ms: float) -> str: """ if duration_ms < 100: return "fast" - elif duration_ms < 500: + if duration_ms < 500: return "normal" - elif duration_ms < 2000: + if duration_ms < 2000: return "slow" - else: - return "very_slow" + return "very_slow" diff --git a/mcpgateway/plugins/framework/external/mcp/tls_utils.py b/mcpgateway/plugins/framework/external/mcp/tls_utils.py index 91b04cfb0..befac4e51 100644 --- a/mcpgateway/plugins/framework/external/mcp/tls_utils.py +++ b/mcpgateway/plugins/framework/external/mcp/tls_utils.py @@ -86,7 +86,7 @@ def create_ssl_context(tls_config: MCPClientTLSConfig, plugin_name: str) -> ssl. # Disable certificate verification (not recommended for production) logger.warning(f"Certificate verification disabled for plugin '{plugin_name}'. This is not recommended for production use.") ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE # noqa: DUO122 + ssl_context.verify_mode = ssl.CERT_NONE # nosec else: # Enable strict certificate verification (production mode) # Load CA certificate bundle for server certificate validation diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py index 1fe74a15c..b2ffc611e 100644 --- a/mcpgateway/routers/log_search.py +++ b/mcpgateway/routers/log_search.py @@ -171,6 +171,8 @@ class LogEntry(BaseModel): error_details: Optional[Dict[str, Any]] = None class Config: + """Pydantic configuration.""" + from_attributes = True @@ -216,6 +218,8 @@ class SecurityEventResponse(BaseModel): resolved: bool class Config: + """Pydantic configuration.""" + from_attributes = True @@ -236,6 +240,8 @@ class AuditTrailResponse(BaseModel): data_classification: Optional[str] class Config: + """Pydantic configuration.""" + from_attributes = True @@ -259,6 +265,8 @@ class PerformanceMetricResponse(BaseModel): p99_duration_ms: float class Config: + """Pydantic configuration.""" + from_attributes = True diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py index 526c478b1..6fa2b5774 100644 --- a/mcpgateway/services/a2a_service.py +++ b/mcpgateway/services/a2a_service.py @@ -29,6 +29,7 @@ from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.services_auth import encode_auth # ,decode_auth @@ -864,8 +865,6 @@ async def invoke_agent( headers["Authorization"] = f"Bearer {token_value}" # Add correlation ID to outbound headers for distributed tracing - from mcpgateway.utils.correlation_id import get_correlation_id - correlation_id = get_correlation_id() if correlation_id: headers["X-Correlation-ID"] = correlation_id diff --git a/mcpgateway/services/audit_trail_service.py b/mcpgateway/services/audit_trail_service.py index 1e2b83ee5..c4bcf83a9 100644 --- 
a/mcpgateway/services/audit_trail_service.py +++ b/mcpgateway/services/audit_trail_service.py @@ -68,7 +68,7 @@ class AuditTrailService: def __init__(self): """Initialize audit trail service.""" - def log_action( + def log_action( # pylint: disable=too-many-positional-arguments self, action: str, resource_type: str, @@ -430,7 +430,7 @@ def get_audit_trail_service() -> AuditTrailService: Returns: AuditTrailService instance """ - global _audit_trail_service + global _audit_trail_service # pylint: disable=global-statement if _audit_trail_service is None: _audit_trail_service = AuditTrailService() return _audit_trail_service diff --git a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py index 9b70317dc..9a8bd5bf2 100644 --- a/mcpgateway/services/log_aggregator.py +++ b/mcpgateway/services/log_aggregator.py @@ -476,7 +476,7 @@ def get_log_aggregator() -> LogAggregator: Returns: Global LogAggregator instance """ - global _log_aggregator + global _log_aggregator # pylint: disable=global-statement if _log_aggregator is None: _log_aggregator = LogAggregator() return _log_aggregator diff --git a/mcpgateway/services/performance_tracker.py b/mcpgateway/services/performance_tracker.py index cfa30bcdf..ef0996c0b 100644 --- a/mcpgateway/services/performance_tracker.py +++ b/mcpgateway/services/performance_tracker.py @@ -166,14 +166,14 @@ def get_performance_summary(self, operation_name: Optional[str] = None, min_samp sorted_timings = sorted(timings) count = len(sorted_timings) - def percentile(p: float) -> float: + def percentile(p: float, *, sorted_vals=sorted_timings, n=count) -> float: """Calculate percentile value.""" - k = (count - 1) * p + k = (n - 1) * p f = int(k) c = k - f - if f + 1 < count: - return sorted_timings[f] * (1 - c) + sorted_timings[f + 1] * c - return sorted_timings[f] + if f + 1 < n: + return sorted_vals[f] * (1 - c) + sorted_vals[f + 1] * c + return sorted_vals[f] summary[op_name] = { "count": count, @@ -286,7 +286,7 @@ def get_performance_tracker() -> PerformanceTracker: Returns: Global PerformanceTracker instance """ - global _performance_tracker + global _performance_tracker # pylint: disable=global-statement if _performance_tracker is None: _performance_tracker = PerformanceTracker() return _performance_tracker diff --git a/mcpgateway/services/security_logger.py b/mcpgateway/services/security_logger.py index e4965fc20..0d93f092d 100644 --- a/mcpgateway/services/security_logger.py +++ b/mcpgateway/services/security_logger.py @@ -154,7 +154,7 @@ def log_authentication_attempt( return event - def log_data_access( + def log_data_access( # pylint: disable=too-many-positional-arguments self, action: str, resource_type: str, @@ -344,7 +344,7 @@ def _count_recent_failures(self, user_id: Optional[str] = None, client_ip: Optio should_close = True try: - stmt = select(func.count(SecurityEvent.id)).where(SecurityEvent.event_type == SecurityEventType.AUTHENTICATION_FAILURE, SecurityEvent.timestamp >= since) + stmt = select(func.count(SecurityEvent.id)).where(SecurityEvent.event_type == SecurityEventType.AUTHENTICATION_FAILURE, SecurityEvent.timestamp >= since) # pylint: disable=not-callable if user_id: stmt = stmt.where(SecurityEvent.user_id == user_id) @@ -358,7 +358,7 @@ def _count_recent_failures(self, user_id: Optional[str] = None, client_ip: Optio if should_close: db.close() - def _calculate_auth_threat_score(self, success: bool, failed_attempts: int, auth_method: str) -> float: + def _calculate_auth_threat_score(self, success: bool, failed_attempts: int, 
auth_method: str) -> float: # pylint: disable=unused-argument """Calculate threat score for authentication attempt. Args: @@ -490,7 +490,7 @@ def _create_security_event( if should_close: db.close() - def _create_audit_trail( + def _create_audit_trail( # pylint: disable=too-many-positional-arguments self, action: str, resource_type: str, @@ -591,7 +591,7 @@ def get_security_logger() -> SecurityLogger: Returns: Global SecurityLogger instance """ - global _security_logger + global _security_logger # pylint: disable=global-statement if _security_logger is None: _security_logger = SecurityLogger() return _security_logger diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py index 99d42a967..4bff2094f 100644 --- a/mcpgateway/services/structured_logger.py +++ b/mcpgateway/services/structured_logger.py @@ -85,7 +85,7 @@ def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: try: perf_tracker = get_performance_tracker() if correlation_id and perf_tracker and hasattr(perf_tracker, "get_current_operations"): - current_ops = perf_tracker.get_current_operations(correlation_id) + current_ops = perf_tracker.get_current_operations(correlation_id) # pylint: disable=no-member if current_ops: entry["active_operations"] = len(current_ops) except Exception: @@ -94,7 +94,7 @@ def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: # Add OpenTelemetry trace context if available try: - from opentelemetry import trace + from opentelemetry import trace # pylint: disable=import-outside-toplevel span = trace.get_current_span() if span and span.get_span_context().is_valid: @@ -262,8 +262,6 @@ def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = No except Exception as e: logger.error(f"Failed to persist log entry to database: {e}", exc_info=True) # Also print to console for immediate visibility - import traceback - print(f"ERROR persisting log to database: {e}") traceback.print_exc() if db: diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index 4eb0975ad..dc5926dff 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -24351,43 +24351,12 @@ function generateStatusBadgeHtml(enabled, reachable, typeLabel) {
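The percentile helper patched in performance_tracker above now binds its data through keyword-only defaults, so the closure keeps its own snapshot of the sorted timings rather than capturing outer variables late. A self-contained sketch of that interpolating percentile; the make_percentile wrapper is assumed here for illustration and is not part of the patch:

# Linear-interpolation percentile with inputs bound at definition time.
def make_percentile(timings):
    sorted_vals = sorted(timings)
    n = len(sorted_vals)

    def percentile(p: float, *, vals=sorted_vals, count=n) -> float:
        k = (count - 1) * p  # fractional rank into the sorted data
        f = int(k)           # lower neighbour index
        c = k - f            # interpolation weight
        if f + 1 < count:
            return vals[f] * (1 - c) + vals[f + 1] * c
        return vals[f]

    return percentile

p = make_percentile([10.0, 20.0, 30.0, 40.0])
assert p(0.5) == 25.0  # midpoint interpolates between 20 and 30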
`; - * Restore default log table headers - */ -function restoreLogTableHeaders() { - const thead = document.getElementById('logs-thead'); - if (thead) { - thead.innerHTML = ` - - - - - - - - - - `; } } /** * Dynamically updates the action buttons (Activate/Deactivate) inside the table cell */ - function updateEntityActionButtons(cell, type, id, isEnabled) { // We look for the form that toggles activation inside the cell const form = cell.querySelector('form[action*="/toggle"]'); diff --git a/mcpgateway/translate_grpc.py b/mcpgateway/translate_grpc.py index f1aa9a4ba..58b80ab8f 100644 --- a/mcpgateway/translate_grpc.py +++ b/mcpgateway/translate_grpc.py @@ -173,7 +173,7 @@ async def _discover_service_details(self, stub, service_name: str) -> None: # Add to pool (ignore if already exists) try: self._pool.Add(file_desc_proto) - except Exception as e: # noqa: B110 + except Exception as e: # pylint: disable=broad-except # Descriptor already in pool, safe to skip logger.debug(f"Descriptor already in pool: {e}") diff --git a/mcpgateway/utils/retry_manager.py b/mcpgateway/utils/retry_manager.py index c8cb8283f..613d5736d 100644 --- a/mcpgateway/utils/retry_manager.py +++ b/mcpgateway/utils/retry_manager.py @@ -301,7 +301,7 @@ async def _sleep_with_jitter(self, base: float, jitter_range: float): True """ # random.uniform() is safe here as jitter is only used for retry timing, not security - delay = base + random.uniform(0, jitter_range) # noqa: DUO102 # nosec B311 + delay = base + random.uniform(0, jitter_range) # nosec B311 # Ensure delay doesn't exceed the max allowed delay = min(delay, self.max_delay) await asyncio.sleep(delay) From 6a59bc32e1a511a53a712ba53fb5a590c73caaba Mon Sep 17 00:00:00 2001 From: Shoumi Date: Wed, 26 Nov 2025 19:03:09 +0530 Subject: [PATCH 17/34] lint fixes Signed-off-by: Shoumi --- mcpgateway/main.py | 2 + mcpgateway/static/admin.js | 649 ++++++++++++++++++++----------------- 2 files changed, 357 insertions(+), 294 deletions(-) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 4474afbd4..96e5a6261 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -472,6 +472,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: log_aggregator = get_log_aggregator() async def run_log_backfill() -> None: + """Backfill log aggregation metrics for configured hours.""" hours = getattr(settings, "metrics_aggregation_backfill_hours", 0) if hours <= 0: return @@ -482,6 +483,7 @@ async def run_log_backfill() -> None: logger.warning("Log aggregation backfill failed: %s", backfill_error) async def run_log_aggregation_loop() -> None: + """Run continuous log aggregation at configured intervals.""" interval_seconds = max(1, int(settings.metrics_aggregation_window_minutes)) * 60 logger.info( "Starting log aggregation loop (window=%s min)", diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index dc5926dff..bdfd2877b 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -23991,51 +23991,67 @@ function updateEntityStatus(type, data) { // Current log search state let currentLogPage = 0; -let currentLogLimit = 50; +const currentLogLimit = 50; +// eslint-disable-next-line no-unused-vars let currentLogFilters = {}; const PERFORMANCE_HISTORY_HOURS = 24; const PERFORMANCE_AGGREGATION_OPTIONS = { - '5m': { label: '5-minute aggregation', query: '5m' }, - '24h': { label: '24-hour aggregation', query: '24h' } + "5m": { label: "5-minute aggregation", query: "5m" }, + "24h": { label: "24-hour aggregation", query: "24h" }, }; -let 
currentPerformanceAggregationKey = '5m'; +let currentPerformanceAggregationKey = "5m"; -function getPerformanceAggregationConfig(rangeKey = currentPerformanceAggregationKey) { - return PERFORMANCE_AGGREGATION_OPTIONS[rangeKey] || PERFORMANCE_AGGREGATION_OPTIONS['5m']; +function getPerformanceAggregationConfig( + rangeKey = currentPerformanceAggregationKey, +) { + return ( + PERFORMANCE_AGGREGATION_OPTIONS[rangeKey] || + PERFORMANCE_AGGREGATION_OPTIONS["5m"] + ); } -function getPerformanceAggregationLabel(rangeKey = currentPerformanceAggregationKey) { +function getPerformanceAggregationLabel( + rangeKey = currentPerformanceAggregationKey, +) { return getPerformanceAggregationConfig(rangeKey).label; } -function getPerformanceAggregationQuery(rangeKey = currentPerformanceAggregationKey) { +function getPerformanceAggregationQuery( + rangeKey = currentPerformanceAggregationKey, +) { return getPerformanceAggregationConfig(rangeKey).query; } function syncPerformanceAggregationSelect() { - const select = document.getElementById('performance-aggregation-select'); + const select = document.getElementById("performance-aggregation-select"); if (select && select.value !== currentPerformanceAggregationKey) { select.value = currentPerformanceAggregationKey; } } function setPerformanceAggregationVisibility(shouldShow) { - const controls = document.getElementById('performance-aggregation-controls'); - if (!controls) return; + const controls = document.getElementById( + "performance-aggregation-controls", + ); + if (!controls) { + return; + } if (shouldShow) { - controls.classList.remove('hidden'); + controls.classList.remove("hidden"); } else { - controls.classList.add('hidden'); + controls.classList.add("hidden"); } } function setLogFiltersVisibility(shouldShow) { - const filters = document.getElementById('log-filters'); - if (!filters) return; + const filters = document.getElementById("log-filters"); + if (!filters) { + return; + } if (shouldShow) { - filters.classList.remove('hidden'); + filters.classList.remove("hidden"); } else { - filters.classList.add('hidden'); + filters.classList.add("hidden"); } } @@ -24052,64 +24068,69 @@ function handlePerformanceAggregationChange(event) { async function searchStructuredLogs() { setPerformanceAggregationVisibility(false); setLogFiltersVisibility(true); - const levelFilter = document.getElementById('log-level-filter')?.value; - const componentFilter = document.getElementById('log-component-filter')?.value; - const searchQuery = document.getElementById('log-search')?.value; - + const levelFilter = document.getElementById("log-level-filter")?.value; + const componentFilter = document.getElementById( + "log-component-filter", + )?.value; + const searchQuery = document.getElementById("log-search")?.value; + // Restore default log table headers (in case we're coming from performance metrics view) restoreLogTableHeaders(); - + // Build search request const searchRequest = { limit: currentLogLimit, offset: currentLogPage * currentLogLimit, - sort_by: 'timestamp', - sort_order: 'desc' + sort_by: "timestamp", + sort_order: "desc", }; - + // Only add filters if they have actual values (not empty strings) - if (searchQuery && searchQuery.trim() !== '') { + if (searchQuery && searchQuery.trim() !== "") { const trimmedSearch = searchQuery.trim(); // Check if search is a correlation ID (32 hex chars or UUID format) or text search - const correlationIdPattern = /^([0-9a-f]{32}|[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})$/i; + const correlationIdPattern = + 
/^([0-9a-f]{32}|[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})$/i; if (correlationIdPattern.test(trimmedSearch)) { searchRequest.correlation_id = trimmedSearch; } else { searchRequest.search_text = trimmedSearch; } } - if (levelFilter && levelFilter !== '') { + if (levelFilter && levelFilter !== "") { searchRequest.level = [levelFilter]; } - if (componentFilter && componentFilter !== '') { + if (componentFilter && componentFilter !== "") { searchRequest.component = [componentFilter]; } - + // Store filters for pagination currentLogFilters = searchRequest; - + try { const response = await fetch(`${getRootPath()}/api/logs/search`, { - method: 'POST', + method: "POST", headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${getAuthToken()}` + "Content-Type": "application/json", + Authorization: `Bearer ${getAuthToken()}`, }, - body: JSON.stringify(searchRequest) + body: JSON.stringify(searchRequest), }); - + if (!response.ok) { const errorText = await response.text(); - console.error('API Error Response:', errorText); - throw new Error(`Failed to search logs: ${response.statusText} - ${errorText}`); + console.error("API Error Response:", errorText); + throw new Error( + `Failed to search logs: ${response.statusText} - ${errorText}`, + ); } - + const data = await response.json(); displayLogResults(data); } catch (error) { - console.error('Error searching logs:', error); - showToast('Failed to search logs: ' + error.message, 'error'); - document.getElementById('logs-tbody').innerHTML = ` + console.error("Error searching logs:", error); + showToast("Failed to search logs: " + error.message, "error"); + document.getElementById("logs-tbody").innerHTML = ` @@ -24121,26 +24142,26 @@ async function searchStructuredLogs() { * Display log search results */ function displayLogResults(data) { - const tbody = document.getElementById('logs-tbody'); - const logCount = document.getElementById('log-count'); - const logStats = document.getElementById('log-stats'); - const prevButton = document.getElementById('prev-page'); - const nextButton = document.getElementById('next-page'); - + const tbody = document.getElementById("logs-tbody"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + const prevButton = document.getElementById("prev-page"); + const nextButton = document.getElementById("next-page"); + // Ensure default headers are shown for log view restoreLogTableHeaders(); - + if (!data.results || data.results.length === 0) { tbody.innerHTML = ` `; - logCount.textContent = '0 logs'; + logCount.textContent = "0 logs"; logStats.innerHTML = 'No results'; return; } - + // Update stats logCount.textContent = `${data.total.toLocaleString()} logs`; const start = currentLogPage * currentLogLimit + 1; @@ -24150,21 +24171,24 @@ function displayLogResults(data) { Showing ${start}-${end} of ${data.total.toLocaleString()} logs `; - + // Update pagination buttons prevButton.disabled = currentLogPage === 0; nextButton.disabled = end >= data.total; - + // Render log entries - tbody.innerHTML = data.results.map(log => { - const levelClass = getLogLevelClass(log.level); - const durationDisplay = log.duration_ms ? `${log.duration_ms.toFixed(2)}ms` : '-'; - const correlationId = log.correlation_id || '-'; - const userDisplay = log.user_email || log.user_id || '-'; - - return ` + tbody.innerHTML = data.results + .map((log) => { + const levelClass = getLogLevelClass(log.level); + const durationDisplay = log.duration_ms + ? 
`${log.duration_ms.toFixed(2)}ms` + : "-"; + const correlationId = log.correlation_id || "-"; + const userDisplay = log.user_email || log.user_id || "-"; + + return ` + onclick="showLogDetails('${log.id}', '${escapeHtml(log.correlation_id || "")}')"> @@ -24174,11 +24198,11 @@ function displayLogResults(data) { `; - }).join(''); + }) + .join(""); } /** @@ -24204,13 +24233,15 @@ function displayLogResults(data) { */ function getLogLevelClass(level) { const classes = { - 'DEBUG': 'bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200', - 'INFO': 'bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200', - 'WARNING': 'bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200', - 'ERROR': 'bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200', - 'CRITICAL': 'bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200' + DEBUG: "bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200", + INFO: "bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200", + WARNING: + "bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200", + ERROR: "bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200", + CRITICAL: + "bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200", }; - return classes[level] || classes['INFO']; + return classes[level] || classes.INFO; } /** @@ -24218,12 +24249,12 @@ function getLogLevelClass(level) { */ function formatTimestamp(timestamp) { const date = new Date(timestamp); - return date.toLocaleString('en-US', { - month: 'short', - day: 'numeric', - hour: '2-digit', - minute: '2-digit', - second: '2-digit' + return date.toLocaleString("en-US", { + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", + second: "2-digit", }); } @@ -24231,8 +24262,12 @@ function formatTimestamp(timestamp) { * Truncate text with ellipsis */ function truncateText(text, maxLength) { - if (!text) return ''; - return text.length > maxLength ? text.substring(0, maxLength) + '...' : text; + if (!text) { + return ""; + } + return text.length > maxLength + ? text.substring(0, maxLength) + "..." 
+ : text; } /** @@ -24242,8 +24277,8 @@ function showLogDetails(logId, correlationId) { if (correlationId) { showCorrelationTrace(correlationId); } else { - console.log('Log details:', logId); - showToast('Full log details view coming soon', 'info'); + console.log("Log details:", logId); + showToast("Full log details view coming soon", "info"); } } @@ -24251,7 +24286,7 @@ function showLogDetails(logId, correlationId) { * Restore default log table headers */ function restoreLogTableHeaders() { - const thead = document.getElementById('logs-thead'); + const thead = document.getElementById("logs-thead"); if (thead) { thead.innerHTML = ` @@ -24288,28 +24323,39 @@ async function showCorrelationTrace(correlationId) { setPerformanceAggregationVisibility(false); setLogFiltersVisibility(true); if (!correlationId) { - const searchInput = document.getElementById('log-search'); - correlationId = prompt('Enter Correlation ID to trace:', searchInput?.value || ''); - if (!correlationId) return; + const searchInput = document.getElementById("log-search"); + correlationId = prompt( + "Enter Correlation ID to trace:", + searchInput?.value || "", + ); + if (!correlationId) { + return; + } } - + try { - const response = await fetch(`${getRootPath()}/api/logs/trace/${encodeURIComponent(correlationId)}`, { - method: 'GET', - headers: { - 'Authorization': `Bearer ${getAuthToken()}` - } - }); - + const response = await fetch( + `${getRootPath()}/api/logs/trace/${encodeURIComponent(correlationId)}`, + { + method: "GET", + headers: { + Authorization: `Bearer ${getAuthToken()}`, + }, + }, + ); + if (!response.ok) { throw new Error(`Failed to fetch trace: ${response.statusText}`); } - + const trace = await response.json(); displayCorrelationTrace(trace); } catch (error) { - console.error('Error fetching correlation trace:', error); - showToast('Failed to fetch correlation trace: ' + error.message, 'error'); + console.error("Error fetching correlation trace:", error); + showToast( + "Failed to fetch correlation trace: " + error.message, + "error", + ); } } @@ -24474,16 +24520,17 @@ console.log("💡 Use: window.debugMCPSearchState() to check current state"); * Display correlation trace results */ function displayCorrelationTrace(trace) { - const tbody = document.getElementById('logs-tbody'); - const thead = document.getElementById('logs-thead'); - const logCount = document.getElementById('log-count'); - const logStats = document.getElementById('log-stats'); - + const tbody = document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + // Calculate total events - const totalEvents = (trace.logs?.length || 0) + - (trace.security_events?.length || 0) + - (trace.audit_trails?.length || 0); - + const totalEvents = + (trace.logs?.length || 0) + + (trace.security_events?.length || 0) + + (trace.audit_trails?.length || 0); + // Update table headers for trace view if (thead) { thead.innerHTML = ` @@ -24512,7 +24559,7 @@ function displayCorrelationTrace(trace) { `; } - + // Update stats logCount.textContent = `${totalEvents} events`; logStats.innerHTML = ` @@ -24531,11 +24578,11 @@ function displayCorrelationTrace(trace) { Audit: ${trace.audit_trails?.length || 0}
- Duration: ${trace.total_duration_ms ? trace.total_duration_ms.toFixed(2) + 'ms' : 'N/A'} + Duration: ${trace.total_duration_ms ? trace.total_duration_ms.toFixed(2) + "ms" : "N/A"}
`; - + if (totalEvents === 0) { tbody.innerHTML = `
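The correlation IDs rendered in this trace view originate from the ContextVar-based utility added earlier in the series. A condensed sketch of that pattern; set_correlation_id below is a simplified stand-in for the real module, not its exact API:

# Async-safe correlation ID storage via ContextVar (simplified sketch).
import uuid
from contextvars import ContextVar
from typing import Optional

_correlation_id: ContextVar[Optional[str]] = ContextVar("correlation_id", default=None)

def set_correlation_id(value: Optional[str] = None) -> str:
    # Preserve a caller-supplied ID, otherwise mint a new hex UUID.
    cid = value or uuid.uuid4().hex
    _correlation_id.set(cid)
    return cid

def get_correlation_id() -> Optional[str]:
    return _correlation_id.get()

# Outbound propagation: copy the current ID into request headers.
headers = {}
cid = set_correlation_id()
if cid:
    headers["X-Correlation-ID"] = cid
assert headers["X-Correlation-ID"] == get_correlation_id()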
- ` + `, }); }); - + // Add security events - (trace.security_events || []).forEach(event => { + (trace.security_events || []).forEach((event) => { const severityClass = getSeverityClass(event.severity); - const threatScore = event.threat_score ? (event.threat_score * 100).toFixed(0) : 0; + const threatScore = event.threat_score + ? (event.threat_score * 100).toFixed(0) + : 0; allEvents.push({ timestamp: new Date(event.timestamp), html: ` @@ -24603,13 +24652,13 @@ function displayCorrelationTrace(trace) { - ` + `, }); }); - + // Add audit trails - (trace.audit_trails || []).forEach(audit => { + (trace.audit_trails || []).forEach((audit) => { const actionBadgeColors = { - 'create': 'bg-green-200 text-green-800', - 'update': 'bg-blue-200 text-blue-800', - 'delete': 'bg-red-200 text-red-800', - 'read': 'bg-gray-200 text-gray-800' + create: "bg-green-200 text-green-800", + update: "bg-blue-200 text-blue-800", + delete: "bg-red-200 text-red-800", + read: "bg-gray-200 text-gray-800", }; - const actionBadge = actionBadgeColors[audit.action?.toLowerCase()] || 'bg-purple-200 text-purple-800'; - const statusIcon = audit.success ? '✓' : '✗'; - const statusClass = audit.success ? 'text-green-600' : 'text-red-600'; - const statusBg = audit.success ? 'bg-green-100 dark:bg-green-900' : 'bg-red-100 dark:bg-red-900'; - + const actionBadge = + actionBadgeColors[audit.action?.toLowerCase()] || + "bg-purple-200 text-purple-800"; + const statusIcon = audit.success ? "✓" : "✗"; + const statusClass = audit.success ? "text-green-600" : "text-red-600"; + const statusBg = audit.success + ? "bg-green-100 dark:bg-green-900" + : "bg-red-100 dark:bg-red-900"; + allEvents.push({ timestamp: new Date(audit.timestamp), html: ` @@ -24659,33 +24712,33 @@ function displayCorrelationTrace(trace) { - ` + `, }); }); - + // Sort all events chronologically allEvents.sort((a, b) => a.timestamp - b.timestamp); - + // Render sorted events - tbody.innerHTML = allEvents.map(event => event.html).join(''); + tbody.innerHTML = allEvents.map((event) => event.html).join(""); } /** @@ -24695,22 +24748,27 @@ async function showSecurityEvents() { setPerformanceAggregationVisibility(false); setLogFiltersVisibility(false); try { - const response = await fetch(`${getRootPath()}/api/logs/security-events?limit=50&resolved=false`, { - method: 'GET', - headers: { - 'Authorization': `Bearer ${getAuthToken()}` - } - }); - + const response = await fetch( + `${getRootPath()}/api/logs/security-events?limit=50&resolved=false`, + { + method: "GET", + headers: { + Authorization: `Bearer ${getAuthToken()}`, + }, + }, + ); + if (!response.ok) { - throw new Error(`Failed to fetch security events: ${response.statusText}`); + throw new Error( + `Failed to fetch security events: ${response.statusText}`, + ); } - + const events = await response.json(); displaySecurityEvents(events); } catch (error) { - console.error('Error fetching security events:', error); - showToast('Failed to fetch security events: ' + error.message, 'error'); + console.error("Error fetching security events:", error); + showToast("Failed to fetch security events: " + error.message, "error"); } } @@ -24718,11 +24776,11 @@ async function showSecurityEvents() { * Display security events */ function displaySecurityEvents(events) { - const tbody = document.getElementById('logs-tbody'); - const thead = document.getElementById('logs-thead'); - const logCount = document.getElementById('log-count'); - const logStats = document.getElementById('log-stats'); - + const tbody = 
document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + // Update table headers for security events if (thead) { thead.innerHTML = ` @@ -24751,14 +24809,14 @@ function displaySecurityEvents(events) { `; } - + logCount.textContent = `${events.length} security events`; logStats.innerHTML = ` 🛡️ Unresolved Security Events `; - + if (events.length === 0) { tbody.innerHTML = ` `; - }).join(''); + }) + .join(""); } /** @@ -24817,12 +24881,12 @@ function displaySecurityEvents(events) { */ function getSeverityClass(severity) { const classes = { - 'LOW': 'bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200', - 'MEDIUM': 'bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200', - 'HIGH': 'bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200', - 'CRITICAL': 'bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200' + LOW: "bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200", + MEDIUM: "bg-yellow-200 text-yellow-800 dark:bg-yellow-800 dark:text-yellow-200", + HIGH: "bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200", + CRITICAL: "bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200", }; - return classes[severity] || classes['MEDIUM']; + return classes[severity] || classes.MEDIUM; } /** @@ -24832,22 +24896,27 @@ async function showAuditTrail() { setPerformanceAggregationVisibility(false); setLogFiltersVisibility(false); try { - const response = await fetch(`${getRootPath()}/api/logs/audit-trails?limit=50&requires_review=true`, { - method: 'GET', - headers: { - 'Authorization': `Bearer ${getAuthToken()}` - } - }); - + const response = await fetch( + `${getRootPath()}/api/logs/audit-trails?limit=50&requires_review=true`, + { + method: "GET", + headers: { + Authorization: `Bearer ${getAuthToken()}`, + }, + }, + ); + if (!response.ok) { - throw new Error(`Failed to fetch audit trails: ${response.statusText}`); + throw new Error( + `Failed to fetch audit trails: ${response.statusText}`, + ); } - + const trails = await response.json(); displayAuditTrail(trails); } catch (error) { - console.error('Error fetching audit trails:', error); - showToast('Failed to fetch audit trails: ' + error.message, 'error'); + console.error("Error fetching audit trails:", error); + showToast("Failed to fetch audit trails: " + error.message, "error"); } } @@ -24855,11 +24924,11 @@ async function showAuditTrail() { * Display audit trail entries */ function displayAuditTrail(trails) { - const tbody = document.getElementById('logs-tbody'); - const thead = document.getElementById('logs-thead'); - const logCount = document.getElementById('log-count'); - const logStats = document.getElementById('log-stats'); - + const tbody = document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); + // Update table headers for audit trail if (thead) { thead.innerHTML = ` @@ -24888,14 +24957,14 @@ function displayAuditTrail(trails) { `; } - + logCount.textContent = `${trails.length} audit entries`; logStats.innerHTML = ` 📝 Audit Trail Entries Requiring Review `; - + if (trails.length === 0) { tbody.innerHTML = ` `; - }).join(''); + }) + .join(""); } /** @@ -24970,7 +25052,9 @@ async function showPerformanceMetrics(rangeKey) { if (rangeKey && 
PERFORMANCE_AGGREGATION_OPTIONS[rangeKey]) { currentPerformanceAggregationKey = rangeKey; } else { - const select = document.getElementById('performance-aggregation-select'); + const select = document.getElementById( + "performance-aggregation-select", + ); if (select?.value && PERFORMANCE_AGGREGATION_OPTIONS[select.value]) { currentPerformanceAggregationKey = select.value; } @@ -24980,25 +25064,35 @@ async function showPerformanceMetrics(rangeKey) { setPerformanceAggregationVisibility(true); setLogFiltersVisibility(false); const hoursParam = encodeURIComponent(PERFORMANCE_HISTORY_HOURS.toString()); - const aggregationParam = encodeURIComponent(getPerformanceAggregationQuery()); + const aggregationParam = encodeURIComponent( + getPerformanceAggregationQuery(), + ); try { - const response = await fetch(`${getRootPath()}/api/logs/performance-metrics?hours=${hoursParam}&aggregation=${aggregationParam}`, { - method: 'GET', - headers: { - 'Authorization': `Bearer ${getAuthToken()}` - } - }); - + const response = await fetch( + `${getRootPath()}/api/logs/performance-metrics?hours=${hoursParam}&aggregation=${aggregationParam}`, + { + method: "GET", + headers: { + Authorization: `Bearer ${getAuthToken()}`, + }, + }, + ); + if (!response.ok) { - throw new Error(`Failed to fetch performance metrics: ${response.statusText}`); + throw new Error( + `Failed to fetch performance metrics: ${response.statusText}`, + ); } - + const metrics = await response.json(); displayPerformanceMetrics(metrics); } catch (error) { - console.error('Error fetching performance metrics:', error); - showToast('Failed to fetch performance metrics: ' + error.message, 'error'); + console.error("Error fetching performance metrics:", error); + showToast( + "Failed to fetch performance metrics: " + error.message, + "error", + ); } } @@ -25006,12 +25100,12 @@ async function showPerformanceMetrics(rangeKey) { * Display performance metrics */ function displayPerformanceMetrics(metrics) { - const tbody = document.getElementById('logs-tbody'); - const thead = document.getElementById('logs-thead'); - const logCount = document.getElementById('log-count'); - const logStats = document.getElementById('log-stats'); + const tbody = document.getElementById("logs-tbody"); + const thead = document.getElementById("logs-thead"); + const logCount = document.getElementById("log-count"); + const logStats = document.getElementById("log-stats"); const aggregationLabel = getPerformanceAggregationLabel(); - + // Update table headers for performance metrics if (thead) { thead.innerHTML = ` @@ -25040,14 +25134,14 @@ function displayPerformanceMetrics(metrics) { `; } - + logCount.textContent = `${metrics.length} metrics`; logStats.innerHTML = ` ⚡ Performance Metrics (${aggregationLabel}) `; - + if (metrics.length === 0) { tbody.innerHTML = ` `; - }).join(''); + }) + .join(""); } /** @@ -25113,57 +25210,21 @@ function nextLogPage() { searchStructuredLogs(); } -/** - * Get auth token from session - */ -function getAuthToken() { - // Check cookie first (matches HTMX authentication) - const jwtToken = getCookie('jwt_token') || getCookie('access_token') || getCookie('token'); - if (jwtToken) { - return jwtToken; - } - - // Fallback: check localStorage - const localToken = localStorage.getItem('auth_token'); - if (localToken) { - return localToken; - } - - // Last resort: check input field - const tokenInput = document.querySelector('input[name="auth_token"]'); - if (tokenInput && tokenInput.value) { - return tokenInput.value; - } - - // No token found - log 
warning for debugging - console.warn('No authentication token found for API request'); - return ''; -} - /** * Get root path for API calls */ function getRootPath() { - return window.ROOT_PATH || ''; -} - -/** - * Escape HTML to prevent XSS - */ -function escapeHtml(text) { - if (!text) return ''; - const div = document.createElement('div'); - div.textContent = text; - return div.innerHTML; + return window.ROOT_PATH || ""; } /** * Show toast notification */ -function showToast(message, type = 'info') { +function showToast(message, type = "info") { // Check if showMessage function exists (from existing admin.js) - if (typeof showMessage === 'function') { - showMessage(message, type === 'error' ? 'danger' : type); + if (typeof showMessage === "function") { + // eslint-disable-next-line no-undef + showMessage(message, type === "error" ? "danger" : type); } else { console.log(`[${type.toUpperCase()}] ${message}`); } From 99253a03fbe3680a1d803aa4399d9e1ec1629fdc Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 27 Nov 2025 11:12:48 +0530 Subject: [PATCH 18/34] fix for doctest Signed-off-by: Shoumi --- mcpgateway/services/tool_service.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index b186b0bb3..1cdfafb0c 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -1422,15 +1422,17 @@ async def invoke_tool( Examples: >>> from mcpgateway.services.tool_service import ToolService - >>> from unittest.mock import MagicMock + >>> from unittest.mock import MagicMock, patch >>> service = ToolService() >>> db = MagicMock() >>> tool = MagicMock() >>> db.execute.return_value.scalar_one_or_none.side_effect = [tool, None] >>> tool.reachable = True >>> import asyncio - >>> result = asyncio.run(service.invoke_tool(db, 'tool_name', {})) - >>> isinstance(result, object) + >>> # Mock structured_logger to prevent database writes during doctest + >>> with patch('mcpgateway.services.tool_service.structured_logger'): + ... result = asyncio.run(service.invoke_tool(db, 'tool_name', {})) + ... 
isinstance(result, object) True """ # pylint: disable=comparison-with-callable From f9d37e906880fac1807005767092006d5bee6f84 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 27 Nov 2025 13:23:30 +0530 Subject: [PATCH 19/34] auth issue fixes Signed-off-by: Shoumi --- mcpgateway/services/server_service.py | 16 ++++++-- mcpgateway/static/admin.js | 58 +++++++++++++++------------ 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index b10621725..2ed3bcaa1 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -400,7 +400,7 @@ async def register_server( Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = ServerService() >>> db = MagicMock() @@ -412,6 +412,8 @@ async def register_server( >>> db.refresh = MagicMock() >>> service._notify_server_added = AsyncMock() >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> import asyncio >>> asyncio.run(service.register_server(db, server_in)) @@ -881,7 +883,7 @@ async def update_server( Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = ServerService() >>> db = MagicMock() @@ -895,6 +897,8 @@ async def update_server( >>> db.refresh = MagicMock() >>> db.execute.return_value.scalar_one_or_none.return_value = None >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> server_update = MagicMock() >>> server_update.id = None # No UUID change @@ -1161,7 +1165,7 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = ServerService() >>> db = MagicMock() @@ -1172,6 +1176,8 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool >>> service._notify_server_activated = AsyncMock() >>> service._notify_server_deactivated = AsyncMock() >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> import asyncio >>> asyncio.run(service.toggle_server_status(db, 'server_id', True)) @@ -1283,7 +1289,7 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ Examples: >>> from 
mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> service = ServerService() >>> db = MagicMock() >>> server = MagicMock() @@ -1291,6 +1297,8 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ >>> db.delete = MagicMock() >>> db.commit = MagicMock() >>> service._notify_server_deleted = AsyncMock() + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> import asyncio >>> asyncio.run(service.delete_server(db, 'server_id', 'user@example.com')) """ diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index bdfd2877b..a62a85671 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -18514,11 +18514,29 @@ async function getAuthToken() { if (!token) { token = localStorage.getItem("auth_token"); } - console.log("MY TOKEN GENERATED:", token); - return token || ""; } +/** + * Fetch helper that always includes auth context. + * Ensures HTTP-only cookies are sent even when JS cannot read them. + */ +async function fetchWithAuth(url, options = {}) { + const opts = { ...options }; + // Always send same-origin cookies unless caller overrides explicitly + opts.credentials = options.credentials || "same-origin"; + + // Clone headers to avoid mutating caller-provided object + const headers = new Headers(options.headers || {}); + const token = await getAuthToken(); + if (token) { + headers.set("Authorization", `Bearer ${token}`); + } + opts.headers = headers; + + return fetch(url, opts); +} + // Expose token management functions to global scope window.loadTokensList = loadTokensList; window.setupCreateTokenForm = setupCreateTokenForm; @@ -24108,14 +24126,16 @@ async function searchStructuredLogs() { currentLogFilters = searchRequest; try { - const response = await fetch(`${getRootPath()}/api/logs/search`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${getAuthToken()}`, + const response = await fetchWithAuth( + `${getRootPath()}/api/logs/search`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(searchRequest), }, - body: JSON.stringify(searchRequest), - }); + ); if (!response.ok) { const errorText = await response.text(); @@ -24334,13 +24354,10 @@ async function showCorrelationTrace(correlationId) { } try { - const response = await fetch( + const response = await fetchWithAuth( `${getRootPath()}/api/logs/trace/${encodeURIComponent(correlationId)}`, { method: "GET", - headers: { - Authorization: `Bearer ${getAuthToken()}`, - }, }, ); @@ -24748,13 +24765,10 @@ async function showSecurityEvents() { setPerformanceAggregationVisibility(false); setLogFiltersVisibility(false); try { - const response = await fetch( + const response = await fetchWithAuth( `${getRootPath()}/api/logs/security-events?limit=50&resolved=false`, { method: "GET", - headers: { - Authorization: `Bearer ${getAuthToken()}`, - }, }, ); @@ -24896,13 +24910,10 @@ async function showAuditTrail() { setPerformanceAggregationVisibility(false); setLogFiltersVisibility(false); try { - const response = await fetch( + const response = await fetchWithAuth( `${getRootPath()}/api/logs/audit-trails?limit=50&requires_review=true`, { method: "GET", - headers: { - Authorization: `Bearer ${getAuthToken()}`, - }, }, ); @@ -25069,13 +25080,10 @@ 
async function showPerformanceMetrics(rangeKey) { ); try { - const response = await fetch( + const response = await fetchWithAuth( `${getRootPath()}/api/logs/performance-metrics?hours=${hoursParam}&aggregation=${aggregationParam}`, { method: "GET", - headers: { - Authorization: `Bearer ${getAuthToken()}`, - }, }, ); From f8d3efbda5a21927e745c9a0b25cfad7fc64d928 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 27 Nov 2025 16:35:08 +0530 Subject: [PATCH 20/34] update for failing tests Signed-off-by: Shoumi --- tests/conftest.py | 24 ++++++-- tests/e2e/test_main_apis.py | 9 +++ tests/fuzz/conftest.py | 55 +++++++++++++++++++ tests/fuzz/test_security_fuzz.py | 4 +- .../security/test_rpc_endpoint_validation.py | 22 ++++++-- .../middleware/test_auth_middleware.py | 10 +++- .../test_request_logging_middleware.py | 36 +++++++----- .../mcpgateway/services/test_a2a_service.py | 15 +++++ .../services/test_gateway_service.py | 10 ++++ .../services/test_prompt_service.py | 10 ++++ .../services/test_resource_ownership.py | 20 +++++++ .../services/test_resource_service.py | 10 ++++ .../mcpgateway/services/test_tool_service.py | 10 ++++ tests/unit/mcpgateway/test_main.py | 10 +++- tests/unit/mcpgateway/test_main_extended.py | 10 +++- .../utils/test_verify_credentials.py | 13 ++++- 16 files changed, 238 insertions(+), 30 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5c813749f..a1cce45df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,8 +121,16 @@ def app(): import mcpgateway.main as main_mod mp.setattr(main_mod, "SessionLocal", TestSessionLocal, raising=False) - # (patch engine too if your code references it) - mp.setattr(main_mod, "engine", engine, raising=False) + + # Also patch security_logger and auth_middleware's SessionLocal + # First-Party + import mcpgateway.middleware.auth_middleware as auth_middleware_mod + import mcpgateway.services.security_logger as sec_logger_mod + import mcpgateway.services.structured_logger as struct_logger_mod + + mp.setattr(auth_middleware_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal, raising=False) # 4) create schema db_mod.Base.metadata.create_all(bind=engine) @@ -186,8 +194,16 @@ def app_with_temp_db(): import mcpgateway.main as main_mod mp.setattr(main_mod, "SessionLocal", TestSessionLocal, raising=False) - # (patch engine too if your code references it) - mp.setattr(main_mod, "engine", engine, raising=False) + + # Also patch security_logger and auth_middleware's SessionLocal + # First-Party + import mcpgateway.middleware.auth_middleware as auth_middleware_mod + import mcpgateway.services.security_logger as sec_logger_mod + import mcpgateway.services.structured_logger as struct_logger_mod + + mp.setattr(auth_middleware_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal, raising=False) # 4) create schema db_mod.Base.metadata.create_all(bind=engine) diff --git a/tests/e2e/test_main_apis.py b/tests/e2e/test_main_apis.py index bb8c6e29c..ab22b1126 100644 --- a/tests/e2e/test_main_apis.py +++ b/tests/e2e/test_main_apis.py @@ -218,9 +218,18 @@ def mock_get_permission_service(*args, **kwargs): app.dependency_overrides[get_permission_service] = mock_get_permission_service app.dependency_overrides[get_db] = override_get_db + # Mock 
security_logger to prevent database access issues + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + # Patch at the middleware level where security_logger is used + sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + sec_patcher.start() + yield engine # Cleanup + sec_patcher.stop() app.dependency_overrides.clear() os.close(db_fd) os.unlink(db_path) diff --git a/tests/fuzz/conftest.py b/tests/fuzz/conftest.py index 6b9326b4b..a92cd87ca 100644 --- a/tests/fuzz/conftest.py +++ b/tests/fuzz/conftest.py @@ -7,13 +7,68 @@ Fuzzing test configuration. """ +# Standard +import os +import tempfile + # Third-Party +from _pytest.monkeypatch import MonkeyPatch from hypothesis import HealthCheck, settings, Verbosity import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool # Mark all tests in this directory as fuzz tests pytestmark = pytest.mark.fuzz + +@pytest.fixture(autouse=True) +def mock_logging_services(monkeypatch): + """Mock logging services to prevent database access during fuzz tests. + + This fixture patches SessionLocal in the db module and all modules that + import it, ensuring they use a test database with all tables created. + """ + # Create a temp database for the fuzz tests + fd, path = tempfile.mkstemp(suffix=".db") + url = f"sqlite:///{path}" + + engine = create_engine(url, connect_args={"check_same_thread": False}, poolclass=StaticPool) + TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + # First-Party + import mcpgateway.db as db_mod + from mcpgateway.db import Base + import mcpgateway.main as main_mod + import mcpgateway.middleware.auth_middleware as auth_middleware_mod + import mcpgateway.services.security_logger as sec_logger_mod + import mcpgateway.services.structured_logger as struct_logger_mod + + # Patch the core db module + monkeypatch.setattr(db_mod, "engine", engine) + monkeypatch.setattr(db_mod, "SessionLocal", TestSessionLocal) + + # Patch main module's SessionLocal (it imports SessionLocal from db) + monkeypatch.setattr(main_mod, "SessionLocal", TestSessionLocal) + + # Patch auth_middleware's SessionLocal + monkeypatch.setattr(auth_middleware_mod, "SessionLocal", TestSessionLocal) + + # Patch security_logger and structured_logger SessionLocal + monkeypatch.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal) + + # Create all tables + Base.metadata.create_all(bind=engine) + + yield + + # Cleanup + engine.dispose() + os.close(fd) + os.unlink(path) + # Configure Hypothesis profiles for different environments settings.register_profile("dev", max_examples=100, verbosity=Verbosity.normal, suppress_health_check=[HealthCheck.too_slow]) diff --git a/tests/fuzz/test_security_fuzz.py b/tests/fuzz/test_security_fuzz.py index 7da56e4c5..b4494d5d9 100644 --- a/tests/fuzz/test_security_fuzz.py +++ b/tests/fuzz/test_security_fuzz.py @@ -99,7 +99,7 @@ def test_integer_overflow_handling(self, large_int): response = client.post("/admin/tools", json=payload, headers={"Authorization": "Basic YWRtaW46Y2hhbmdlbWU="}) - assert response.status_code in [200, 201, 400, 422] + assert response.status_code in [200, 201, 400, 401, 422] def test_path_traversal_resistance(self): """Test resistance to path traversal attacks.""" @@ -330,4 +330,4 @@ def 
test_rate_limiting_behavior(self): # Should either accept all or start rate limiting # Rate limiting typically returns 429 for status in responses: - assert status in [200, 201, 400, 422, 429, 409] + assert status in [200, 201, 400, 401, 422, 429, 409] diff --git a/tests/security/test_rpc_endpoint_validation.py b/tests/security/test_rpc_endpoint_validation.py index 2ec390eee..71af40285 100644 --- a/tests/security/test_rpc_endpoint_validation.py +++ b/tests/security/test_rpc_endpoint_validation.py @@ -14,6 +14,7 @@ # Standard import logging +from unittest.mock import MagicMock, patch # Third-Party from fastapi.testclient import TestClient @@ -37,9 +38,14 @@ class TestRPCEndpointValidation: """ @pytest.fixture - def client(self): - """Create a test client for the FastAPI app.""" - return TestClient(app) + def client(self, app): + """Create a test client for the FastAPI app with mocked security_logger.""" + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + with patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger): + yield TestClient(app) @pytest.fixture def auth_headers(self): @@ -269,8 +275,14 @@ class TestRPCValidationBypass: """Test various techniques to bypass RPC validation.""" @pytest.fixture - def client(self): - return TestClient(app) + def client(self, app): + """Create a test client for the FastAPI app with mocked security_logger.""" + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + with patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger): + yield TestClient(app) def test_bypass_techniques(self, client): """Test various bypass techniques.""" diff --git a/tests/unit/mcpgateway/middleware/test_auth_middleware.py b/tests/unit/mcpgateway/middleware/test_auth_middleware.py index cf8b85aa3..5882e20af 100644 --- a/tests/unit/mcpgateway/middleware/test_auth_middleware.py +++ b/tests/unit/mcpgateway/middleware/test_auth_middleware.py @@ -103,10 +103,18 @@ async def test_authentication_failure(monkeypatch): request.url.path = "/api/data" request.cookies = {"jwt_token": "bad_token"} request.headers = {} + # Mock request.client for security_logger + request.client = MagicMock() + request.client.host = "127.0.0.1" + + # Mock security_logger to prevent database operations + mock_security_logger = MagicMock() + mock_security_logger.log_authentication_attempt = MagicMock(return_value=None) with patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=MagicMock()) as mock_session, \ patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=Exception("Invalid token"))), \ - patch("mcpgateway.middleware.auth_middleware.logger") as mock_logger: + patch("mcpgateway.middleware.auth_middleware.logger") as mock_logger, \ + patch("mcpgateway.middleware.auth_middleware.security_logger", mock_security_logger): response = await middleware.dispatch(request, call_next) call_next.assert_awaited_once_with(request) diff --git a/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py b/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py index 30a2a3c26..e905d9716 100644 --- 
a/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py +++ b/tests/unit/mcpgateway/middleware/test_request_logging_middleware.py @@ -7,6 +7,7 @@ """ import json import pytest +from unittest.mock import MagicMock from fastapi import Request, Response from starlette.datastructures import Headers from starlette.types import Scope @@ -28,7 +29,7 @@ def __init__(self): def isEnabledFor(self, level): return self.enabled - def log(self, level, msg): + def log(self, level, msg, extra=None): self.logged.append((level, msg)) def warning(self, msg): @@ -40,6 +41,15 @@ def dummy_logger(monkeypatch): monkeypatch.setattr("mcpgateway.middleware.request_logging_middleware.logger", logger) return logger + +@pytest.fixture +def mock_structured_logger(monkeypatch): + """Mock the structured_logger to prevent database writes.""" + mock_logger = MagicMock() + mock_logger.log = MagicMock() + monkeypatch.setattr("mcpgateway.middleware.request_logging_middleware.structured_logger", mock_logger) + return mock_logger + @pytest.fixture def dummy_call_next(): async def _call_next(request): @@ -112,8 +122,8 @@ def test_mask_sensitive_headers_non_sensitive(): # --- RequestLoggingMiddleware tests --- @pytest.mark.asyncio -async def test_dispatch_logs_json_body(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None) +async def test_dispatch_logs_json_body(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) body = json.dumps({"password": "123", "data": "ok"}).encode() request = make_request(body=body, headers={"Authorization": "Bearer abc"}) response = await middleware.dispatch(request, dummy_call_next) @@ -122,8 +132,8 @@ async def test_dispatch_logs_json_body(dummy_logger, dummy_call_next): assert "******" in dummy_logger.logged[0][1] @pytest.mark.asyncio -async def test_dispatch_logs_non_json_body(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None) +async def test_dispatch_logs_non_json_body(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) body = b"token=abc" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -131,8 +141,8 @@ async def test_dispatch_logs_non_json_body(dummy_logger, dummy_call_next): assert any("" in msg for _, msg in dummy_logger.logged) @pytest.mark.asyncio -async def test_dispatch_large_body_truncated(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None, max_body_size=10) +async def test_dispatch_large_body_truncated(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True, max_body_size=10) body = b"{" + b"a" * 100 + b"}" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -140,8 +150,8 @@ async def test_dispatch_large_body_truncated(dummy_logger, dummy_call_next): assert any("[truncated]" in msg for _, msg in dummy_logger.logged) @pytest.mark.asyncio -async def test_dispatch_logging_disabled(dummy_logger, dummy_call_next): - middleware = RequestLoggingMiddleware(app=None, log_requests=False) +async def test_dispatch_logging_disabled(dummy_logger, mock_structured_logger, dummy_call_next): + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, 
log_detailed_requests=False) body = b"{}" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -149,9 +159,9 @@ async def test_dispatch_logging_disabled(dummy_logger, dummy_call_next): assert dummy_logger.logged == [] @pytest.mark.asyncio -async def test_dispatch_logger_disabled(dummy_logger, dummy_call_next): +async def test_dispatch_logger_disabled(dummy_logger, mock_structured_logger, dummy_call_next): dummy_logger.enabled = False - middleware = RequestLoggingMiddleware(app=None) + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) body = b"{}" request = make_request(body=body) response = await middleware.dispatch(request, dummy_call_next) @@ -159,12 +169,12 @@ async def test_dispatch_logger_disabled(dummy_logger, dummy_call_next): assert dummy_logger.logged == [] @pytest.mark.asyncio -async def test_dispatch_exception_handling(dummy_logger, dummy_call_next, monkeypatch): +async def test_dispatch_exception_handling(dummy_logger, mock_structured_logger, dummy_call_next, monkeypatch): async def bad_body(): raise ValueError("fail") request = make_request() monkeypatch.setattr(request, "body", bad_body) - middleware = RequestLoggingMiddleware(app=None) + middleware = RequestLoggingMiddleware(app=None, enable_gateway_logging=False, log_detailed_requests=True) response = await middleware.dispatch(request, dummy_call_next) assert response.status_code == 200 assert any("Failed to log request body" in msg for msg in dummy_logger.warnings) diff --git a/tests/unit/mcpgateway/services/test_a2a_service.py b/tests/unit/mcpgateway/services/test_a2a_service.py index 34a2e34b2..0b45fe87c 100644 --- a/tests/unit/mcpgateway/services/test_a2a_service.py +++ b/tests/unit/mcpgateway/services/test_a2a_service.py @@ -21,6 +21,21 @@ from mcpgateway.schemas import A2AAgentCreate, A2AAgentUpdate from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService + +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock structured_logger and audit_trail to prevent database writes during tests.""" + with patch("mcpgateway.services.a2a_service.structured_logger") as mock_a2a_logger, \ + patch("mcpgateway.services.tool_service.structured_logger") as mock_tool_logger, \ + patch("mcpgateway.services.tool_service.audit_trail") as mock_tool_audit: + mock_a2a_logger.log = MagicMock(return_value=None) + mock_a2a_logger.info = MagicMock(return_value=None) + mock_tool_logger.log = MagicMock(return_value=None) + mock_tool_logger.info = MagicMock(return_value=None) + mock_tool_audit.log_action = MagicMock(return_value=None) + yield {"structured_logger": mock_a2a_logger, "tool_logger": mock_tool_logger, "tool_audit": mock_tool_audit} + + class TestA2AAgentService: """Test suite for A2A Agent Service.""" diff --git a/tests/unit/mcpgateway/services/test_gateway_service.py b/tests/unit/mcpgateway/services/test_gateway_service.py index 0a6d4e06f..7e443279b 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service.py +++ b/tests/unit/mcpgateway/services/test_gateway_service.py @@ -67,6 +67,16 @@ def _make_execute_result(*, scalar: _R | None = None, scalars_list: list[_R] | N return result +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.gateway_service.audit_trail") as mock_audit, \ + 
patch("mcpgateway.services.gateway_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture(autouse=True) def _bypass_gatewayread_validation(monkeypatch): """ diff --git a/tests/unit/mcpgateway/services/test_prompt_service.py b/tests/unit/mcpgateway/services/test_prompt_service.py index a26a90240..dbbcb806a 100644 --- a/tests/unit/mcpgateway/services/test_prompt_service.py +++ b/tests/unit/mcpgateway/services/test_prompt_service.py @@ -44,6 +44,16 @@ # --------------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.prompt_service.audit_trail") as mock_audit, \ + patch("mcpgateway.services.prompt_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture def mock_prompt(): """Create a mock prompt model.""" diff --git a/tests/unit/mcpgateway/services/test_resource_ownership.py b/tests/unit/mcpgateway/services/test_resource_ownership.py index 6c70cb399..c3e4f82d6 100644 --- a/tests/unit/mcpgateway/services/test_resource_ownership.py +++ b/tests/unit/mcpgateway/services/test_resource_ownership.py @@ -26,6 +26,26 @@ from mcpgateway.services.a2a_service import A2AAgentService +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.gateway_service.audit_trail") as mock_gw_audit, \ + patch("mcpgateway.services.gateway_service.structured_logger") as mock_gw_logger, \ + patch("mcpgateway.services.tool_service.audit_trail") as mock_tool_audit, \ + patch("mcpgateway.services.tool_service.structured_logger") as mock_tool_logger, \ + patch("mcpgateway.services.resource_service.audit_trail") as mock_res_audit, \ + patch("mcpgateway.services.resource_service.structured_logger") as mock_res_logger, \ + patch("mcpgateway.services.prompt_service.audit_trail") as mock_prompt_audit, \ + patch("mcpgateway.services.prompt_service.structured_logger") as mock_prompt_logger, \ + patch("mcpgateway.services.a2a_service.structured_logger") as mock_a2a_logger: + for mock in [mock_gw_audit, mock_tool_audit, mock_res_audit, mock_prompt_audit]: + mock.log_action = MagicMock(return_value=None) + for mock in [mock_gw_logger, mock_tool_logger, mock_res_logger, mock_prompt_logger, mock_a2a_logger]: + mock.log = MagicMock(return_value=None) + mock.info = MagicMock(return_value=None) + yield + + @pytest.fixture def mock_db_session(): """Create a mock database session.""" diff --git a/tests/unit/mcpgateway/services/test_resource_service.py b/tests/unit/mcpgateway/services/test_resource_service.py index 23fecfc64..2c9cb1517 100644 --- a/tests/unit/mcpgateway/services/test_resource_service.py +++ b/tests/unit/mcpgateway/services/test_resource_service.py @@ -37,6 +37,16 @@ # --------------------------------------------------------------------------- # +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.resource_service.audit_trail") 
as mock_audit, \ + patch("mcpgateway.services.resource_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture def resource_service(monkeypatch): """Create a ResourceService instance.""" diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 5beeeab27..46c520906 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -36,6 +36,16 @@ from mcpgateway.utils.services_auth import encode_auth +@pytest.fixture(autouse=True) +def mock_logging_services(): + """Mock audit_trail and structured_logger to prevent database writes during tests.""" + with patch("mcpgateway.services.tool_service.audit_trail") as mock_audit, \ + patch("mcpgateway.services.tool_service.structured_logger") as mock_logger: + mock_audit.log_action = MagicMock(return_value=None) + mock_logger.log = MagicMock(return_value=None) + yield {"audit_trail": mock_audit, "structured_logger": mock_logger} + + @pytest.fixture def tool_service(): """Create a tool service instance.""" diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index 94e0beb80..fabf401bc 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -193,13 +193,20 @@ def test_client(app): # Patch the auth function used by DocsAuthMiddleware # Standard - from unittest.mock import patch + from unittest.mock import MagicMock, patch # Third-Party from fastapi import HTTPException, status # First-Party + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + sec_patcher.start() + # Create a mock that validates JWT tokens properly async def mock_require_auth_override(auth_header=None, jwt_token=None): # Third-Party @@ -270,6 +277,7 @@ async def mock_check_permission(self, user_email: str, permission: str, resource app.dependency_overrides.pop(get_current_user, None) app.dependency_overrides.pop(get_current_user_with_permissions, None) patcher.stop() # Stop the require_auth_override patch + sec_patcher.stop() # Stop the security_logger patch if hasattr(PermissionService, "_original_check_permission"): PermissionService.check_permission = PermissionService._original_check_permission diff --git a/tests/unit/mcpgateway/test_main_extended.py b/tests/unit/mcpgateway/test_main_extended.py index ceb40f763..2079b692d 100644 --- a/tests/unit/mcpgateway/test_main_extended.py +++ b/tests/unit/mcpgateway/test_main_extended.py @@ -324,7 +324,7 @@ def test_server_toggle_edge_cases(self, test_client, auth_headers): def test_client(app): """Test client with auth override for testing protected endpoints.""" # Standard - from unittest.mock import patch + from unittest.mock import MagicMock, patch # First-Party from mcpgateway.auth import get_current_user @@ -341,6 +341,13 @@ def test_client(app): auth_provider="test", ) + # Mock security_logger to prevent database access + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + 
sec_patcher = patch("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + sec_patcher.start() + # Mock require_auth_override function def mock_require_auth_override(user: str) -> str: return user @@ -390,6 +397,7 @@ async def mock_check_permission( app.dependency_overrides.pop(get_current_user, None) app.dependency_overrides.pop(get_current_user_with_permissions, None) patcher.stop() # Stop the require_auth_override patch + sec_patcher.stop() # Stop the security_logger patch if hasattr(PermissionService, "_original_check_permission"): PermissionService.check_permission = PermissionService._original_check_permission diff --git a/tests/unit/mcpgateway/utils/test_verify_credentials.py b/tests/unit/mcpgateway/utils/test_verify_credentials.py index dabf49f63..ada73ecc4 100644 --- a/tests/unit/mcpgateway/utils/test_verify_credentials.py +++ b/tests/unit/mcpgateway/utils/test_verify_credentials.py @@ -281,9 +281,16 @@ async def test_require_auth_override_basic_auth_disabled(monkeypatch): @pytest.fixture -def test_client(): - if app is None: - pytest.skip("FastAPI app not importable") +def test_client(app, monkeypatch): + """Create a test client with the properly configured app fixture from conftest.""" + from unittest.mock import MagicMock + + # Patch security_logger at the middleware level where it's imported and called + mock_sec_logger = MagicMock() + mock_sec_logger.log_authentication_attempt = MagicMock(return_value=None) + mock_sec_logger.log_security_event = MagicMock(return_value=None) + monkeypatch.setattr("mcpgateway.middleware.auth_middleware.security_logger", mock_sec_logger) + return TestClient(app) From 8c0771214c3e75162a9785a0c6200cc42feea6a8 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 27 Nov 2025 17:05:36 +0530 Subject: [PATCH 21/34] flake8 fixes Signed-off-by: Shoumi --- mcpgateway/main.py | 6 ++- .../middleware/request_logging_middleware.py | 12 ++++- .../framework/external/mcp/tls_utils.py | 2 +- mcpgateway/routers/log_search.py | 53 +++++++++++++++--- mcpgateway/services/audit_trail_service.py | 11 +++- mcpgateway/services/log_aggregator.py | 54 +++++++++++++++++-- mcpgateway/services/performance_tracker.py | 14 ++++- mcpgateway/services/structured_logger.py | 37 +++++++++++-- mcpgateway/services/tool_service.py | 6 +++ mcpgateway/utils/retry_manager.py | 2 +- 10 files changed, 173 insertions(+), 24 deletions(-) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 96e5a6261..09b0dff20 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -483,7 +483,11 @@ async def run_log_backfill() -> None: logger.warning("Log aggregation backfill failed: %s", backfill_error) async def run_log_aggregation_loop() -> None: - """Run continuous log aggregation at configured intervals.""" + """Run continuous log aggregation at configured intervals. 
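# --- editor's note ---------------------------------------------------------
# The test patches above all repeat one idea: swap the module-level
# security_logger in mcpgateway.middleware.auth_middleware for a MagicMock so
# authentication paths never touch the database. A hedged sketch of that
# pattern as a single reusable pytest fixture (the fixture name is an
# assumption; the target path and method names are taken from the patches):

import pytest
from unittest.mock import MagicMock

@pytest.fixture
def mock_security_logger(monkeypatch):
    sec = MagicMock()
    sec.log_authentication_attempt = MagicMock(return_value=None)
    sec.log_security_event = MagicMock(return_value=None)
    # Same patch target the fixtures above use
    monkeypatch.setattr("mcpgateway.middleware.auth_middleware.security_logger", sec)
    return sec
# ----------------------------------------------------------------------------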
+ + Raises: + asyncio.CancelledError: When aggregation is stopped + """ interval_seconds = max(1, int(settings.metrics_aggregation_window_minutes)) * 60 logger.info( "Starting log aggregation loop (window=%s min)", diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index 499effb74..cdb75e413 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -133,7 +133,14 @@ def __init__(self, app, enable_gateway_logging: bool = True, log_detailed_reques self.max_body_size = max_body_size # Expected to be in bytes async def _resolve_user_identity(self, request: Request): - """Best-effort extraction of user identity for request logs.""" + """Best-effort extraction of user identity for request logs. + + Args: + request: The incoming HTTP request + + Returns: + Tuple[Optional[str], Optional[str]]: User ID and email + """ # Prefer context injected by upstream middleware if hasattr(request.state, "user") and request.state.user is not None: raw_user_id = getattr(request.state.user, "id", None) @@ -179,6 +186,9 @@ async def dispatch(self, request: Request, call_next: Callable): Returns: Response: The HTTP response from downstream handlers + + Raises: + Exception: Any exception from downstream handlers is re-raised """ # Track start time for total duration start_time = time.time() diff --git a/mcpgateway/plugins/framework/external/mcp/tls_utils.py b/mcpgateway/plugins/framework/external/mcp/tls_utils.py index befac4e51..137bf0950 100644 --- a/mcpgateway/plugins/framework/external/mcp/tls_utils.py +++ b/mcpgateway/plugins/framework/external/mcp/tls_utils.py @@ -86,7 +86,7 @@ def create_ssl_context(tls_config: MCPClientTLSConfig, plugin_name: str) -> ssl. # Disable certificate verification (not recommended for production) logger.warning(f"Certificate verification disabled for plugin '{plugin_name}'. This is not recommended for production use.") ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE # nosec + ssl_context.verify_mode = ssl.CERT_NONE # nosec B502 else: # Enable strict certificate verification (production mode) # Load CA certificate bundle for server certificate validation diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py index b2ffc611e..ab8cdc919 100644 --- a/mcpgateway/routers/log_search.py +++ b/mcpgateway/routers/log_search.py @@ -46,7 +46,15 @@ def _align_to_window(dt: datetime, window_minutes: int) -> datetime: - """Align a datetime down to the nearest aggregation window boundary.""" + """Align a datetime down to the nearest aggregation window boundary. + + Args: + dt: Datetime to align + window_minutes: Aggregation window size in minutes + + Returns: + datetime: Aligned datetime at window boundary + """ timestamp = dt.astimezone(timezone.utc) total_minutes = int(timestamp.timestamp() // 60) aligned_minutes = (total_minutes // window_minutes) * window_minutes @@ -54,7 +62,14 @@ def _align_to_window(dt: datetime, window_minutes: int) -> datetime: def _deduplicate_metrics(metrics: List[PerformanceMetric]) -> List[PerformanceMetric]: - """Ensure a single metric per component/operation/window.""" + """Ensure a single metric per component/operation/window. 
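# --- editor's note ---------------------------------------------------------
# Worked example of the flooring arithmetic in _align_to_window above: the
# datetime is converted to epoch minutes and floored to the window size. The
# return statement falls outside the visible hunk, so its reconstruction here
# is an assumption:

from datetime import datetime, timezone

def align_sketch(dt: datetime, window_minutes: int) -> datetime:
    total_minutes = int(dt.astimezone(timezone.utc).timestamp() // 60)
    aligned_minutes = (total_minutes // window_minutes) * window_minutes
    return datetime.fromtimestamp(aligned_minutes * 60, tz=timezone.utc)

# align_sketch(datetime(2025, 1, 15, 10, 7, 42, tzinfo=timezone.utc), 5)
# -> 2025-01-15 10:05:00+00:00
# ----------------------------------------------------------------------------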
+ + Args: + metrics: List of performance metrics to deduplicate + + Returns: + List[PerformanceMetric]: Deduplicated metrics sorted by window_start + """ if not metrics: return [] @@ -75,7 +90,13 @@ def _aggregate_custom_windows( window_minutes: int, db: Session, ) -> None: - """Aggregate metrics using custom window duration.""" + """Aggregate metrics using custom window duration. + + Args: + aggregator: Log aggregator instance + window_minutes: Window size in minutes + db: Database session + """ window_delta = timedelta(minutes=window_minutes) window_duration_seconds = window_minutes * 60 @@ -278,11 +299,14 @@ async def search_logs(request: LogSearchRequest, user=Depends(get_current_user_w Args: request: Search parameters + user: Current authenticated user db: Database session - _: Permission check dependency Returns: Search results with pagination + + Raises: + HTTPException: On database or validation errors """ try: # Build base query @@ -383,11 +407,14 @@ async def trace_correlation_id(correlation_id: str, user=Depends(get_current_use Args: correlation_id: Correlation ID to trace + user: Current authenticated user db: Database session - _: Permission check dependency Returns: Complete trace of all related logs and events + + Raises: + HTTPException: On database or validation errors """ try: # Get structured logs @@ -509,11 +536,14 @@ async def get_security_events( end_time: End timestamp limit: Maximum results offset: Result offset + user: Current authenticated user db: Database session - _: Permission check dependency Returns: List of security events + + Raises: + HTTPException: On database or validation errors """ try: stmt = select(SecurityEvent) @@ -584,11 +614,14 @@ async def get_audit_trails( end_time: End timestamp limit: Maximum results offset: Result offset + user: Current authenticated user db: Database session - _: Permission check dependency Returns: List of audit trail entries + + Raises: + HTTPException: On database or validation errors """ try: stmt = select(AuditTrail) @@ -652,12 +685,16 @@ async def get_performance_metrics( Args: component: Filter by component operation: Filter by operation + aggregation: Aggregation level (5m, 1h, 1d, 7d) hours: Hours of history + user: Current authenticated user db: Database session - _: Permission check dependency Returns: List of performance metrics + + Raises: + HTTPException: On database or validation errors """ try: aggregation_config = _AGGREGATION_LEVELS.get(aggregation, _AGGREGATION_LEVELS[_DEFAULT_AGGREGATION_KEY]) diff --git a/mcpgateway/services/audit_trail_service.py b/mcpgateway/services/audit_trail_service.py index c4bcf83a9..3d9023bfe 100644 --- a/mcpgateway/services/audit_trail_service.py +++ b/mcpgateway/services/audit_trail_service.py @@ -196,7 +196,16 @@ def _determine_requires_review( data_classification: Optional[str], requires_review_param: Optional[bool], ) -> bool: - """Resolve whether an audit entry should require review.""" + """Resolve whether an audit entry should require review. 
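# --- editor's note ---------------------------------------------------------
# Sketch of the precedence _determine_requires_review documents: an explicit
# requires_review_param always wins, and only then do action/classification
# heuristics apply. Only the explicit branch is visible in this hunk; the
# fallback rules and sets below are illustrative assumptions, not the
# service's actual policy:

SENSITIVE_CLASSIFICATIONS = {"confidential", "restricted"}  # hypothetical

def requires_review_sketch(action, data_classification, requires_review_param):
    if requires_review_param is not None:
        return requires_review_param  # caller's explicit choice wins
    if data_classification in SENSITIVE_CLASSIFICATIONS:
        return True  # sensitive data always gets a second pair of eyes
    return action in {"delete", "permission_change"}  # hypothetical action set
# ----------------------------------------------------------------------------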
+ + Args: + action: Action being performed + data_classification: Data classification level + requires_review_param: Explicit review requirement + + Returns: + bool: Whether the audit entry requires review + """ if requires_review_param is not None: return requires_review_param diff --git a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py index 9a8bd5bf2..7e4443680 100644 --- a/mcpgateway/services/log_aggregator.py +++ b/mcpgateway/services/log_aggregator.py @@ -42,7 +42,7 @@ def aggregate_performance_metrics( Args: component: Component name - operation: Operation name + operation_type: Operation name window_start: Start of aggregation window (defaults to N minutes ago) window_end: End of aggregation window (defaults to now) db: Optional database session @@ -335,7 +335,15 @@ def backfill(self, hours: float, db: Optional[Session] = None) -> int: @staticmethod def _percentile(sorted_values: List[float], percentile: float) -> float: - """Calculate percentile from sorted values.""" + """Calculate percentile from sorted values. + + Args: + sorted_values: Sorted list of values + percentile: Percentile to calculate (0.0 to 1.0) + + Returns: + float: Calculated percentile value + """ if not sorted_values: return 0.0 @@ -355,7 +363,14 @@ def _percentile(sorted_values: List[float], percentile: float) -> float: @staticmethod def _calculate_error_count(entries: List[StructuredLogEntry]) -> int: - """Calculate error occurrences for a batch of log entries.""" + """Calculate error occurrences for a batch of log entries. + + Args: + entries: List of log entries to analyze + + Returns: + int: Count of error entries + """ error_levels = {"ERROR", "CRITICAL"} return sum(1 for entry in entries if (entry.level and entry.level.upper() in error_levels) or entry.error_details) @@ -364,7 +379,15 @@ def _resolve_window_bounds( window_start: Optional[datetime], window_end: Optional[datetime], ) -> Tuple[datetime, datetime]: - """Resolve and normalize aggregation window bounds.""" + """Resolve and normalize aggregation window bounds. + + Args: + window_start: Start of window or None to calculate + window_end: End of window or None for current time + + Returns: + Tuple[datetime, datetime]: Resolved window start and end + """ window_delta = timedelta(minutes=self.aggregation_window_minutes) if window_start is not None and window_end is not None: @@ -414,7 +437,28 @@ def _upsert_metric( metric_metadata: Optional[Dict[str, Any]], db: Session, ) -> PerformanceMetric: - """Create or update a performance metric window.""" + """Create or update a performance metric window. 
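# --- editor's note ---------------------------------------------------------
# _percentile above guards the empty list (returning 0.0); the interpolation
# itself sits outside the visible hunk. A standard linear-interpolation
# version, offered as an assumption about the hidden part, with one worked
# value:

def percentile_sketch(sorted_values, percentile: float) -> float:
    if not sorted_values:
        return 0.0  # same empty-input guard as the patch
    n = len(sorted_values)
    k = (n - 1) * percentile  # fractional rank
    f = int(k)
    c = k - f
    if f + 1 < n:
        return sorted_values[f] + c * (sorted_values[f + 1] - sorted_values[f])
    return sorted_values[f]

# percentile_sketch([10.0, 20.0, 30.0, 40.0], 0.95) -> rank 2.85 -> 38.5
# ----------------------------------------------------------------------------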
+ + Args: + component: Component name + operation_type: Operation type + window_start: Window start time + window_end: Window end time + request_count: Total request count + error_count: Total error count + error_rate: Error rate (0.0-1.0) + avg_duration_ms: Average duration in milliseconds + min_duration_ms: Minimum duration in milliseconds + max_duration_ms: Maximum duration in milliseconds + p50_duration_ms: 50th percentile duration + p95_duration_ms: 95th percentile duration + p99_duration_ms: 99th percentile duration + metric_metadata: Additional metadata + db: Database session + + Returns: + PerformanceMetric: Created or updated metric + """ existing_stmt = select(PerformanceMetric).where( and_( diff --git a/mcpgateway/services/performance_tracker.py b/mcpgateway/services/performance_tracker.py index ef0996c0b..dcf813979 100644 --- a/mcpgateway/services/performance_tracker.py +++ b/mcpgateway/services/performance_tracker.py @@ -64,6 +64,9 @@ def track_operation(self, operation_name: str, component: Optional[str] = None, Yields: None + Raises: + Exception: Any exception from the tracked operation is re-raised + Example: >>> tracker = PerformanceTracker() >>> with tracker.track_operation("database_query", component="tool_service"): @@ -167,7 +170,16 @@ def get_performance_summary(self, operation_name: Optional[str] = None, min_samp count = len(sorted_timings) def percentile(p: float, *, sorted_vals=sorted_timings, n=count) -> float: - """Calculate percentile value.""" + """Calculate percentile value. + + Args: + p: Percentile to calculate (0.0 to 1.0) + sorted_vals: Sorted list of values + n: Number of values + + Returns: + float: Calculated percentile value + """ k = (n - 1) * p f = int(k) c = k - f diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py index 4bff2094f..324d708c4 100644 --- a/mcpgateway/services/structured_logger.py +++ b/mcpgateway/services/structured_logger.py @@ -355,23 +355,50 @@ def log( self.router.route(entry, db) def debug(self, message: str, **kwargs: Any) -> None: - """Log debug message.""" + """Log debug message. + + Args: + message: Log message + **kwargs: Additional context fields + """ self.log(LogLevel.DEBUG, message, **kwargs) def info(self, message: str, **kwargs: Any) -> None: - """Log info message.""" + """Log info message. + + Args: + message: Log message + **kwargs: Additional context fields + """ self.log(LogLevel.INFO, message, **kwargs) def warning(self, message: str, **kwargs: Any) -> None: - """Log warning message.""" + """Log warning message. + + Args: + message: Log message + **kwargs: Additional context fields + """ self.log(LogLevel.WARNING, message, **kwargs) def error(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: - """Log error message.""" + """Log error message. + + Args: + message: Log message + error: Exception object if available + **kwargs: Additional context fields + """ self.log(LogLevel.ERROR, message, error=error, **kwargs) def critical(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None: - """Log critical message.""" + """Log critical message. 
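# --- editor's note ---------------------------------------------------------
# The _upsert_metric helper earlier in this patch opens with
# `existing_stmt = select(PerformanceMetric).where(and_(...))`, i.e. the
# classic select-then-update-or-insert upsert keyed on one aggregation window.
# A minimal sketch of that shape; the key columns come from the helper's
# signature, while the `values` dict contents are an assumption:

from sqlalchemy import select

def upsert_window_sketch(db, model, component, operation_type, window_start, values):
    row = db.execute(
        select(model).where(
            model.component == component,
            model.operation_type == operation_type,
            model.window_start == window_start,
        )
    ).scalar_one_or_none()
    if row is None:
        row = model(component=component, operation_type=operation_type,
                    window_start=window_start, **values)
        db.add(row)  # first metric for this window: insert
    else:
        for key, val in values.items():
            setattr(row, key, val)  # window already exists: overwrite in place
    return row
# ----------------------------------------------------------------------------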
+ + Args: + message: Log message + error: Exception object if available + **kwargs: Additional context fields + """ self.log(LogLevel.CRITICAL, message, error=error, **kwargs) diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 1cdfafb0c..6c0778015 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -1688,6 +1688,9 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): Returns: ToolResult: Result of tool call + + Raises: + Exception: On connection or communication errors """ # Get correlation ID for distributed tracing correlation_id = get_correlation_id() @@ -1747,6 +1750,9 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head Returns: ToolResult: Result of tool call + + Raises: + Exception: On connection or communication errors """ # Get correlation ID for distributed tracing correlation_id = get_correlation_id() diff --git a/mcpgateway/utils/retry_manager.py b/mcpgateway/utils/retry_manager.py index 613d5736d..3e721167e 100644 --- a/mcpgateway/utils/retry_manager.py +++ b/mcpgateway/utils/retry_manager.py @@ -301,7 +301,7 @@ async def _sleep_with_jitter(self, base: float, jitter_range: float): True """ # random.uniform() is safe here as jitter is only used for retry timing, not security - delay = base + random.uniform(0, jitter_range) # nosec B311 + delay = base + random.uniform(0, jitter_range) # nosec B311 # noqa: DUO102 # Ensure delay doesn't exceed the max allowed delay = min(delay, self.max_delay) await asyncio.sleep(delay) From 8f3e2196b1f301b97430c953080bb89ef970c4d8 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 27 Nov 2025 17:42:22 +0530 Subject: [PATCH 22/34] flake8 issue Signed-off-by: Shoumi --- mcpgateway/plugins/framework/external/mcp/tls_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/plugins/framework/external/mcp/tls_utils.py b/mcpgateway/plugins/framework/external/mcp/tls_utils.py index 137bf0950..11d4d0acf 100644 --- a/mcpgateway/plugins/framework/external/mcp/tls_utils.py +++ b/mcpgateway/plugins/framework/external/mcp/tls_utils.py @@ -86,7 +86,7 @@ def create_ssl_context(tls_config: MCPClientTLSConfig, plugin_name: str) -> ssl. # Disable certificate verification (not recommended for production) logger.warning(f"Certificate verification disabled for plugin '{plugin_name}'. 
This is not recommended for production use.") ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE # nosec B502 + ssl_context.verify_mode = ssl.CERT_NONE # nosec B502 # noqa: DUO122 else: # Enable strict certificate verification (production mode) # Load CA certificate bundle for server certificate validation From 2bad146f9a87e99ebd7d0ef74ab996fcbbf04460 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 27 Nov 2025 18:39:30 +0530 Subject: [PATCH 23/34] prevent SQLite rollback error on validation failures Signed-off-by: Shoumi --- mcpgateway/admin.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 0befe2593..abc0ae5de 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -47,7 +47,7 @@ from pydantic import SecretStr, ValidationError from pydantic_core import ValidationError as CoreValidationError from sqlalchemy import and_, case, cast, desc, func, or_, select, String -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, InvalidRequestError, OperationalError from sqlalchemy.orm import joinedload, Session from sqlalchemy.sql.functions import coalesce from starlette.datastructures import UploadFile as StarletteUploadFile @@ -8371,6 +8371,17 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us status_code=200, ) except Exception as ex: + # Roll back only when a transaction is active to avoid sqlite3 "no transaction" errors. + try: + active_transaction = db.get_transaction() if hasattr(db, "get_transaction") else None + if db.is_active and active_transaction is not None: + db.rollback() + except (InvalidRequestError, OperationalError) as rollback_error: + LOGGER.warning( + "Rollback failed (ignoring for SQLite compatibility): %s", + rollback_error, + ) + if isinstance(ex, ValidationError): LOGGER.error(f"ValidationError in admin_add_resource: {ErrorFormatter.format_validation_error(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) From f598f22c6f4f3f8fba6103f7c1d52bd6d9d17d42 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Thu, 27 Nov 2025 19:09:19 +0530 Subject: [PATCH 24/34] false positive issues Signed-off-by: Shoumi --- mcpgateway/middleware/request_logging_middleware.py | 2 +- mcpgateway/services/security_logger.py | 2 +- mcpgateway/services/structured_logger.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index cdb75e413..8506ed2b5 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -174,7 +174,7 @@ async def _resolve_user_identity(self, request: Request): if db: try: db.close() - except Exception: + except Exception: # nosec B110 - Silently handle db.close() failures during cleanup pass async def dispatch(self, request: Request, call_next: Callable): diff --git a/mcpgateway/services/security_logger.py b/mcpgateway/services/security_logger.py index 0d93f092d..0a2f1f7b6 100644 --- a/mcpgateway/services/security_logger.py +++ b/mcpgateway/services/security_logger.py @@ -45,7 +45,7 @@ class SecurityEventType(str, Enum): SUSPICIOUS_ACTIVITY = "suspicious_activity" RATE_LIMIT_EXCEEDED = "rate_limit_exceeded" BRUTE_FORCE_ATTEMPT = "brute_force_attempt" - TOKEN_MANIPULATION = "token_manipulation" + TOKEN_MANIPULATION = "token_manipulation" # nosec B105 - Not a password, security 
event type constant DATA_EXFILTRATION = "data_exfiltration" PRIVILEGE_ESCALATION = "privilege_escalation" INJECTION_ATTEMPT = "injection_attempt" diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py index 324d708c4..6f7d3afd7 100644 --- a/mcpgateway/services/structured_logger.py +++ b/mcpgateway/services/structured_logger.py @@ -88,7 +88,7 @@ def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: current_ops = perf_tracker.get_current_operations(correlation_id) # pylint: disable=no-member if current_ops: entry["active_operations"] = len(current_ops) - except Exception: + except Exception: # nosec B110 - Graceful degradation if performance tracker unavailable # Silently skip if performance tracker is unavailable or method doesn't exist pass From 293bbf13a607f941e5dbceee12b2645ffbfceba2 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Fri, 28 Nov 2025 14:49:50 +0530 Subject: [PATCH 25/34] fix lint issue Signed-off-by: Shoumi --- mcpgateway/static/admin.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index a62a85671..61bd603d1 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -24003,6 +24003,8 @@ function updateEntityStatus(type, data) { const isEnabled = data.enabled !== undefined ? data.enabled : data.isActive; updateEntityActionButtons(actionCell, type, data.id, isEnabled); + } +} // ============================================================================ // Structured Logging UI Functions // ============================================================================ From 547c1f435a760b89656e7cf55883c418096cdbc4 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Mon, 1 Dec 2025 18:03:26 +0530 Subject: [PATCH 26/34] update alembic file Signed-off-by: Shoumi --- .../versions/k5e6f7g8h9i0_add_structured_logging_tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py index 951840ba2..98b45eafe 100644 --- a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. revision = "k5e6f7g8h9i0" -down_revision = "z1a2b3c4d5e6" +down_revision = "add_toolops_test_cases_table" branch_labels = None depends_on = None From 0a38f73ce6be1ede0b06e631533fea1aca927af7 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Mon, 1 Dec 2025 19:25:50 +0530 Subject: [PATCH 27/34] updated alembic revision Signed-off-by: Shoumi --- .../versions/k5e6f7g8h9i0_add_structured_logging_tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py index 98b45eafe..8b41788f9 100644 --- a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. 
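# --- editor's note ---------------------------------------------------------
# Patches 26 and 27 re-point this migration's parent twice, and patch 32 later
# moves it again after a rebase. Alembic only reads the down_revision
# assignment; the "Revises:" line in the module docstring is plain
# documentation, which is why the two can drift apart across these patches
# (the docstring said f3a3a3d901b8 while the code said z1a2b3c4d5e6) without
# breaking migrations, though it confuses readers. The header fields Alembic
# actually consumes, as of patch 27:

revision = "k5e6f7g8h9i0"       # this migration's own identifier
down_revision = "9e028ecf59c4"  # parent revision; must name the current head
branch_labels = None
depends_on = None
# ----------------------------------------------------------------------------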
revision = "k5e6f7g8h9i0" -down_revision = "add_toolops_test_cases_table" +down_revision = "9e028ecf59c4" branch_labels = None depends_on = None From 2284ba886c8cceae1a6d0e519268e7d5f8ff9acb Mon Sep 17 00:00:00 2001 From: Shoumi Date: Tue, 2 Dec 2025 11:02:48 +0530 Subject: [PATCH 28/34] changes in table schema Signed-off-by: Shoumi --- ...5e6f7g8h9i0_add_structured_logging_tables.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py index 8b41788f9..c9eaf8a8b 100644 --- a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -113,9 +113,11 @@ def upgrade() -> None: sa.Column("id", sa.String(36), nullable=False), sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), sa.Column("detected_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("correlation_id", sa.String(64), nullable=True), + sa.Column("log_entry_id", sa.String(36), nullable=True), sa.Column("event_type", sa.String(100), nullable=False), sa.Column("severity", sa.String(20), nullable=False), - sa.Column("category", sa.String(100), nullable=False), + sa.Column("category", sa.String(50), nullable=False), sa.Column("user_id", sa.String(255), nullable=True), sa.Column("user_email", sa.String(255), nullable=True), sa.Column("client_ip", sa.String(45), nullable=False), @@ -123,17 +125,18 @@ def upgrade() -> None: sa.Column("description", sa.Text(), nullable=False), sa.Column("action_taken", sa.String(100), nullable=True), sa.Column("threat_score", sa.Float(), nullable=False, server_default="0.0"), - sa.Column("threat_indicators", sa.JSON(), nullable=True), + sa.Column("threat_indicators", sa.JSON(), nullable=False, server_default="{}"), sa.Column("failed_attempts_count", sa.Integer(), nullable=False, server_default="0"), - sa.Column("context", sa.JSON(), nullable=True), - sa.Column("correlation_id", sa.String(255), nullable=True), sa.Column("resolved", sa.Boolean(), nullable=False, server_default="false"), sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True), sa.Column("resolved_by", sa.String(255), nullable=True), sa.Column("resolution_notes", sa.Text(), nullable=True), sa.Column("alert_sent", sa.Boolean(), nullable=False, server_default="false"), sa.Column("alert_sent_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("alert_recipients", sa.JSON(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["log_entry_id"], ["structured_log_entries.id"]), ) # Create indexes for security_events @@ -146,9 +149,13 @@ def upgrade() -> None: op.create_index("ix_security_events_user_id", "security_events", ["user_id"], unique=False) op.create_index("ix_security_events_user_email", "security_events", ["user_email"], unique=False) op.create_index("ix_security_events_client_ip", "security_events", ["client_ip"], unique=False) - op.create_index("idx_security_event_time", "security_events", ["event_type", "timestamp"], unique=False) + op.create_index("ix_security_events_log_entry_id", "security_events", ["log_entry_id"], unique=False) + op.create_index("ix_security_events_resolved", "security_events", ["resolved"], unique=False) + op.create_index("idx_security_type_time", "security_events", ["event_type", "timestamp"], unique=False) 
op.create_index("idx_security_severity_time", "security_events", ["severity", "timestamp"], unique=False) op.create_index("idx_security_user_time", "security_events", ["user_id", "timestamp"], unique=False) + op.create_index("idx_security_ip_time", "security_events", ["client_ip", "timestamp"], unique=False) + op.create_index("idx_security_unresolved", "security_events", ["resolved", "severity", "timestamp"], unique=False) # Create audit_trails table op.create_table( From 9a8525ea4b3778d29ad7eb6f42bac1272b02cd14 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Wed, 3 Dec 2025 15:42:20 +0530 Subject: [PATCH 29/34] gateway service fixes Signed-off-by: Shoumi --- mcpgateway/admin.py | 49 ++++++++++++++++++- .../middleware/request_logging_middleware.py | 8 +-- mcpgateway/services/gateway_service.py | 20 +++++++- mcpgateway/templates/admin.html | 3 +- 4 files changed, 72 insertions(+), 8 deletions(-) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index abc0ae5de..10fb9b0a0 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -105,6 +105,7 @@ ) from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService from mcpgateway.services.argon2_service import Argon2PasswordService +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.catalog_service import catalog_service from mcpgateway.services.email_auth_service import AuthenticationError, EmailAuthService, PasswordValidationError from mcpgateway.services.encryption_service import get_encryption_service @@ -118,10 +119,9 @@ from mcpgateway.services.plugin_service import get_plugin_service from mcpgateway.services.prompt_service import PromptNameConflictError, PromptNotFoundError, PromptService from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService, ResourceURIConflictError -from mcpgateway.services.structured_logger import get_structured_logger -from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.root_service import RootService from mcpgateway.services.server_service import ServerError, ServerNameConflictError, ServerNotFoundError, ServerService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.tag_service import TagService from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolError, ToolNameConflictError, ToolNotFoundError, ToolService @@ -9863,11 +9863,56 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] except json.JSONDecodeError: response_body = {"details": response.text} + # Structured logging: Log successful gateway test + structured_logger = get_structured_logger("gateway_service") + structured_logger.log( + level="INFO", + message=f"Gateway test completed: {request.base_url}", + event_type="gateway_tested", + component="gateway_service", + user_email=get_user_email(user), + team_id=team_id, + resource_type="gateway", + resource_id=gateway.id if gateway else None, + custom_fields={ + "gateway_name": gateway.name if gateway else None, + "gateway_url": str(request.base_url), + "test_method": request.method, + "test_path": request.path, + "status_code": response.status_code, + "latency_ms": latency_ms, + }, + db=db, + ) + return GatewayTestResponse(status_code=response.status_code, latency_ms=latency_ms, body=response_body) except httpx.RequestError as e: LOGGER.warning(f"Gateway test failed: 
{e}") latency_ms = int((time.monotonic() - start_time) * 1000) + + # Structured logging: Log failed gateway test + structured_logger = get_structured_logger("gateway_service") + structured_logger.log( + level="ERROR", + message=f"Gateway test failed: {request.base_url}", + event_type="gateway_test_failed", + component="gateway_service", + user_email=get_user_email(user), + team_id=team_id, + resource_type="gateway", + resource_id=gateway.id if gateway else None, + error=e, + custom_fields={ + "gateway_name": gateway.name if gateway else None, + "gateway_url": str(request.base_url), + "test_method": request.method, + "test_path": request.path, + "latency_ms": latency_ms, + }, + db=db, + ) + return GatewayTestResponse(status_code=502, latency_ms=latency_ms, body={"error": "Request failed", "details": str(e)}) diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index 8506ed2b5..f241197ab 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -20,9 +20,9 @@ # Third-Party from fastapi.security import HTTPAuthorizationCredentials +from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response -from starlette.middleware.base import BaseHTTPMiddleware # First-Party from mcpgateway.auth import get_current_user @@ -36,7 +36,7 @@ logger = logging_service.get_logger(__name__) # Initialize structured logger for gateway boundary logging -structured_logger = get_structured_logger("gateway") +structured_logger = get_structured_logger("http_gateway") SENSITIVE_KEYS = {"password", "secret", "token", "apikey", "access_token", "refresh_token", "client_secret", "authorization", "jwt_token"} @@ -211,7 +211,7 @@ async def dispatch(self, request: Request, call_next: Callable): structured_logger.log( level="INFO", message=f"Request started: {method} {path}", - component="gateway", + component="http_gateway", correlation_id=correlation_id, user_email=user_email, user_id=user_id, @@ -237,7 +237,7 @@ async def dispatch(self, request: Request, call_next: Callable): structured_logger.log( level=log_level, message=f"Request completed: {method} {path} - {response.status_code}", - component="gateway", + component="http_gateway", correlation_id=correlation_id, user_email=user_email, user_id=user_id, diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 38ad4770a..f2cbe81da 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -80,10 +80,10 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.observability import create_span from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate -from mcpgateway.services.event_service import EventService # logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks from mcpgateway.services.audit_trail_service import get_audit_trail_service +from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.structured_logger import get_structured_logger @@ -1767,6 +1767,24 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool if gateway.enabled or include_inactive: gateway.team = self._get_team_name(db, getattr(gateway, 
"team_id", None)) + + # Structured logging: Log gateway view + structured_logger.log( + level="INFO", + message="Gateway retrieved successfully", + event_type="gateway_viewed", + component="gateway_service", + team_id=getattr(gateway, "team_id", None), + resource_type="gateway", + resource_id=str(gateway.id), + custom_fields={ + "gateway_name": gateway.name, + "gateway_url": gateway.url, + "include_inactive": include_inactive, + }, + db=db, + ) + return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index bba814c0a..c983421c5 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -679,7 +679,8 @@ onchange="searchStructuredLogs()" > - + + From 275685633a09b0f3563b58c1c10e72e31cbe699a Mon Sep 17 00:00:00 2001 From: Shoumi Date: Wed, 3 Dec 2025 16:12:16 +0530 Subject: [PATCH 30/34] updated tests Signed-off-by: Shoumi --- mcpgateway/admin.py | 4 ++-- mcpgateway/services/gateway_service.py | 4 ++-- tests/unit/mcpgateway/test_admin.py | 12 ++++++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 10fb9b0a0..c945be5e4 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -9890,7 +9890,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] except httpx.RequestError as e: LOGGER.warning(f"Gateway test failed: {e}") latency_ms = int((time.monotonic() - start_time) * 1000) - + # Structured logging: Log failed gateway test structured_logger = get_structured_logger("gateway_service") structured_logger.log( @@ -9912,7 +9912,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] }, db=db, ) - + return GatewayTestResponse(status_code=502, latency_ms=latency_ms, body={"error": "Request failed", "details": str(e)}) diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index f2cbe81da..9d47fea1e 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -1767,7 +1767,7 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool if gateway.enabled or include_inactive: gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) - + # Structured logging: Log gateway view structured_logger.log( level="INFO", @@ -1784,7 +1784,7 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool }, db=db, ) - + return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index 3079044bb..5eb7d7b13 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -1359,7 +1359,8 @@ async def test_admin_test_gateway_various_methods(self): mock_client_class.return_value = mock_client - result = await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + result = await admin_test_gateway(request, None, "test-user", mock_db) assert result.status_code == 200 mock_client.request.assert_called_once() @@ -1398,7 +1399,8 @@ async def test_admin_test_gateway_url_construction(self): mock_client_class.return_value = mock_client - await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + await admin_test_gateway(request, None, "test-user", 
mock_db) call_args = mock_client.request.call_args assert call_args[1]["url"] == expected_url @@ -1424,7 +1426,8 @@ async def test_admin_test_gateway_timeout_handling(self): mock_client_class.return_value = mock_client - result = await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + result = await admin_test_gateway(request, None, "test-user", mock_db) assert result.status_code == 502 assert "Request timed out" in str(result.body) @@ -1461,7 +1464,8 @@ async def test_admin_test_gateway_non_json_response(self): mock_client_class.return_value = mock_client - result = await admin_test_gateway(request, "test-user") + mock_db = MagicMock() + result = await admin_test_gateway(request, None, "test-user", mock_db) assert result.status_code == 200 assert result.body["details"] == response_text From 673f4892c6d96e472426789971eb1645118e20c5 Mon Sep 17 00:00:00 2001 From: Shoumi Date: Wed, 3 Dec 2025 16:34:19 +0530 Subject: [PATCH 31/34] fix doctest coverage Signed-off-by: Shoumi --- mcpgateway/admin.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index c945be5e4..123bee450 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -9693,7 +9693,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClient() - ... response = await admin_test_gateway(mock_request, mock_user) + ... response = await admin_test_gateway(mock_request, None, mock_user, mock_db) ... return isinstance(response, GatewayTestResponse) and response.status_code == 200 >>> >>> result = asyncio.run(test_admin_test_gateway()) @@ -9719,7 +9719,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_text_response(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClientTextOnly() - ... response = await admin_test_gateway(mock_request, mock_user) + ... response = await admin_test_gateway(mock_request, None, mock_user, mock_db) ... return isinstance(response, GatewayTestResponse) and response.body.get("details") == "plain text response" >>> >>> asyncio.run(test_admin_test_gateway_text_response()) @@ -9737,7 +9737,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_network_error(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClientError() - ... response = await admin_test_gateway(mock_request, mock_user) + ... response = await admin_test_gateway(mock_request, None, mock_user, mock_db) ... return response.status_code == 502 and "Network error" in str(response.body) >>> >>> asyncio.run(test_admin_test_gateway_network_error()) @@ -9755,7 +9755,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_post(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClient() - ... response = await admin_test_gateway(mock_request_post, mock_user) + ... response = await admin_test_gateway(mock_request_post, None, mock_user, mock_db) ... 
return isinstance(response, GatewayTestResponse) and response.status_code == 200 >>> >>> asyncio.run(test_admin_test_gateway_post()) @@ -9773,7 +9773,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str] >>> async def test_admin_test_gateway_trailing_slash(): ... with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class: ... mock_client_class.return_value = MockClient() - ... response = await admin_test_gateway(mock_request_trailing, mock_user) + ... response = await admin_test_gateway(mock_request_trailing, None, mock_user, mock_db) ... return isinstance(response, GatewayTestResponse) and response.status_code == 200 >>> >>> asyncio.run(test_admin_test_gateway_trailing_slash()) From 0232a625f8988d5d1f3678a4f0653a249cb67f0e Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Fri, 12 Dec 2025 09:35:59 +0000 Subject: [PATCH 32/34] fix: resolve rebase conflicts and fix test issues for correlation ID PR - Fix Alembic migration to chain after main branch head (356a2d4eed6f) - Fix is_active/enabled attribute access in services (server, prompt, resource, export) - Update export_service to use getattr with fallback for backwards compatibility - Add db.refresh before return in tool_service.register_tool to handle session expiry after audit/logging commits - Add SessionLocal patches in conftest.py for audit_trail_service and log_aggregator - Update test assertions for expected db.refresh call count - Apply isort import ordering fixes across service files Signed-off-by: Mihai Criveti --- .../k5e6f7g8h9i0_add_structured_logging_tables.py | 5 +++-- mcpgateway/main.py | 2 +- mcpgateway/observability.py | 1 + mcpgateway/routers/log_search.py | 6 +++--- mcpgateway/services/export_service.py | 4 ++-- mcpgateway/services/log_aggregator.py | 2 +- mcpgateway/services/logging_service.py | 2 ++ mcpgateway/services/prompt_service.py | 6 +++--- mcpgateway/services/resource_service.py | 6 +++--- mcpgateway/services/security_logger.py | 4 ++-- mcpgateway/services/server_service.py | 8 ++++---- mcpgateway/services/structured_logger.py | 3 ++- mcpgateway/services/tool_service.py | 6 ++++-- tests/conftest.py | 8 ++++++++ tests/unit/mcpgateway/services/test_export_service.py | 4 ++-- tests/unit/mcpgateway/services/test_tool_service.py | 3 ++- 16 files changed, 43 insertions(+), 27 deletions(-) diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py index c9eaf8a8b..7750d14c4 100644 --- a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -1,17 +1,18 @@ """Add structured logging tables Revision ID: k5e6f7g8h9i0 -Revises: f3a3a3d901b8 +Revises: 356a2d4eed6f Create Date: 2025-01-15 12:00:00.000000 """ +# Third-Party from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "k5e6f7g8h9i0" -down_revision = "9e028ecf59c4" +down_revision = "356a2d4eed6f" branch_labels = None depends_on = None diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 09b0dff20..1b0e39eae 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -113,8 +113,8 @@ from mcpgateway.services.import_service import ConflictStrategy, ImportConflictError from mcpgateway.services.import_service import ImportError as ImportServiceError from mcpgateway.services.import_service import ImportService, ImportValidationError -from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.log_aggregator import get_log_aggregator +from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.metrics import setup_metrics from mcpgateway.services.prompt_service import PromptError, PromptNameConflictError, PromptNotFoundError, PromptService from mcpgateway.services.resource_service import ResourceError, ResourceNotFoundError, ResourceService, ResourceURIConflictError diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py index 459e57f18..25b28ef5c 100644 --- a/mcpgateway/observability.py +++ b/mcpgateway/observability.py @@ -18,6 +18,7 @@ # Third-Party - Try to import OpenTelemetry core components - make them truly optional OTEL_AVAILABLE = False try: + # Third-Party from opentelemetry import trace from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py index ab8cdc919..5023ea614 100644 --- a/mcpgateway/routers/log_search.py +++ b/mcpgateway/routers/log_search.py @@ -17,7 +17,7 @@ # Third-Party from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field -from sqlalchemy import and_, or_, desc, select, delete +from sqlalchemy import and_, delete, desc, or_, select from sqlalchemy.orm import Session from sqlalchemy.sql import func as sa_func @@ -25,12 +25,12 @@ from mcpgateway.config import settings from mcpgateway.db import ( AuditTrail, + get_db, PerformanceMetric, SecurityEvent, StructuredLogEntry, - get_db, ) -from mcpgateway.middleware.rbac import require_permission, get_current_user_with_permissions +from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission from mcpgateway.services.log_aggregator import get_log_aggregator logger = logging.getLogger(__name__) diff --git a/mcpgateway/services/export_service.py b/mcpgateway/services/export_service.py index d5806dd59..78a5a5763 100644 --- a/mcpgateway/services/export_service.py +++ b/mcpgateway/services/export_service.py @@ -399,7 +399,7 @@ async def _export_servers(self, db: Session, tags: Optional[List[str]], include_ "websocket_endpoint": f"{root_path}/servers/{server.id}/ws", "jsonrpc_endpoint": f"{root_path}/servers/{server.id}/jsonrpc", "capabilities": {"tools": {"list_changed": True}, "prompts": {"list_changed": True}}, - "is_active": server.is_active, + "is_active": getattr(server, "enabled", getattr(server, "is_active", False)), "tags": server.tags or [], } @@ -469,7 +469,7 @@ async def _export_resources(self, db: Session, tags: Optional[List[str]], includ "description": resource.description, "mime_type": resource.mime_type, "tags": resource.tags or [], - "is_active": resource.is_active, + "is_active": getattr(resource, "enabled", getattr(resource, "is_active", False)), "last_modified": resource.updated_at.isoformat() if resource.updated_at else None, } diff --git 
a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py index 7e4443680..2d7f0f293 100644 --- a/mcpgateway/services/log_aggregator.py +++ b/mcpgateway/services/log_aggregator.py @@ -21,8 +21,8 @@ from sqlalchemy.orm import Session # First-Party -from mcpgateway.db import PerformanceMetric, StructuredLogEntry, SessionLocal from mcpgateway.config import settings +from mcpgateway.db import PerformanceMetric, SessionLocal, StructuredLogEntry logger = logging.getLogger(__name__) diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index ef39abfde..f18f826f9 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -30,6 +30,7 @@ # Optional OpenTelemetry support (Third-Party) try: + # Third-Party from opentelemetry import trace # type: ignore[import-untyped] except ImportError: trace = None # type: ignore[assignment] @@ -37,6 +38,7 @@ AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name try: # Optional import; only used for filtering a known benign upstream error (Third-Party) + # Third-Party from anyio import ClosedResourceError as AnyioClosedResourceError # pylint: disable=invalid-name except Exception: # pragma: no cover - environment without anyio AnyioClosedResourceError = None # pylint: disable=invalid-name diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index de3b1a573..cd3841563 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -37,8 +37,8 @@ from mcpgateway.observability import create_span from mcpgateway.plugins.framework import GlobalContext, PluginContextTable, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer -from mcpgateway.services.event_service import EventService from mcpgateway.services.audit_trail_service import get_audit_trail_service +from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.observability_service import current_trace_id, ObservabilityService from mcpgateway.services.structured_logger import get_structured_logger @@ -1292,7 +1292,7 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool resource_name=prompt.name, user_email=user_email, team_id=prompt.team_id, - new_values={"is_active": prompt.is_active}, + new_values={"enabled": prompt.enabled}, context={"action": "activate" if activate else "deactivate"}, db=db, ) @@ -1306,7 +1306,7 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool team_id=prompt.team_id, resource_type="prompt", resource_id=str(prompt.id), - custom_fields={"prompt_name": prompt.name, "is_active": prompt.is_active}, + custom_fields={"prompt_name": prompt.name, "enabled": prompt.enabled}, db=db, ) diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index c17da434f..1b8136a51 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -51,8 +51,8 @@ from mcpgateway.db import server_resource_association from mcpgateway.observability import create_span from mcpgateway.schemas import ResourceCreate, ResourceMetrics, ResourceRead, ResourceSubscription, ResourceUpdate, TopPerformer -from mcpgateway.services.event_service import EventService from mcpgateway.services.audit_trail_service import 
get_audit_trail_service +from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.observability_service import current_trace_id, ObservabilityService @@ -1580,7 +1580,7 @@ async def toggle_resource_status(self, db: Session, resource_id: int, activate: user_email=user_email, team_id=resource.team_id, new_values={ - "is_active": resource.is_active, + "enabled": resource.enabled, }, context={ "action": "activate" if activate else "deactivate", @@ -1600,7 +1600,7 @@ async def toggle_resource_status(self, db: Session, resource_id: int, activate: resource_id=str(resource.id), custom_fields={ "resource_uri": resource.uri, - "is_active": resource.is_active, + "enabled": resource.enabled, }, db=db, ) diff --git a/mcpgateway/services/security_logger.py b/mcpgateway/services/security_logger.py index 0a2f1f7b6..1b2470691 100644 --- a/mcpgateway/services/security_logger.py +++ b/mcpgateway/services/security_logger.py @@ -20,9 +20,9 @@ from sqlalchemy.orm import Session # First-Party -from mcpgateway.db import SecurityEvent, AuditTrail, SessionLocal -from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.config import settings +from mcpgateway.db import AuditTrail, SecurityEvent, SessionLocal +from mcpgateway.utils.correlation_id import get_correlation_id logger = logging.getLogger(__name__) diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index 2ed3bcaa1..01f20b304 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -33,10 +33,10 @@ from mcpgateway.db import ServerMetric from mcpgateway.db import Tool as DbTool from mcpgateway.schemas import ServerCreate, ServerMetrics, ServerRead, ServerUpdate, TopPerformer -from mcpgateway.services.logging_service import LoggingService -from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.audit_trail_service import get_audit_trail_service +from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.performance_tracker import get_performance_tracker +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -826,7 +826,7 @@ async def get_server(self, db: Session, server_id: str) -> ServerRead: resource_type="server", resource_id=server.id, custom_fields={ - "is_active": server.is_active, + "enabled": server.enabled, "tool_count": len(getattr(server, "tools", []) or []), "resource_count": len(getattr(server, "resources", []) or []), "prompt_count": len(getattr(server, "prompts", []) or []), @@ -841,7 +841,7 @@ async def get_server(self, db: Session, server_id: str) -> ServerRead: resource_name=server.name, user_id="system", team_id=getattr(server, "team_id", None), - context={"is_active": server.is_active}, + context={"enabled": server.enabled}, db=db, ) diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py index 6f7d3afd7..0d8a4a599 100644 --- a/mcpgateway/services/structured_logger.py +++ b/mcpgateway/services/structured_logger.py @@ -23,7 +23,7 @@ # First-Party from mcpgateway.config import settings -from mcpgateway.db import StructuredLogEntry, SessionLocal +from 
mcpgateway.db import SessionLocal, StructuredLogEntry from mcpgateway.services.performance_tracker import get_performance_tracker from mcpgateway.utils.correlation_id import get_correlation_id @@ -94,6 +94,7 @@ def enrich(entry: Dict[str, Any]) -> Dict[str, Any]: # Add OpenTelemetry trace context if available try: + # Third-Party from opentelemetry import trace # pylint: disable=import-outside-toplevel span = trace.get_current_span() diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 6c0778015..5616e0ff4 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -43,7 +43,6 @@ from mcpgateway.common.models import Tool as PydanticTool from mcpgateway.common.models import ToolResult from mcpgateway.config import settings -from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.db import A2AAgent as DbA2AAgent from mcpgateway.db import EmailTeam from mcpgateway.db import Gateway as DbGateway @@ -64,13 +63,14 @@ ) from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer -from mcpgateway.services.event_service import EventService from mcpgateway.services.audit_trail_service import get_audit_trail_service +from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.performance_tracker import get_performance_tracker from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.display_name import generate_display_name from mcpgateway.utils.metrics_common import build_top_performers @@ -764,6 +764,8 @@ async def register_tool( db=db, ) + # Refresh db_tool after logging commits (they expire the session objects) + db.refresh(db_tool) return self._convert_tool_to_read(db_tool) except IntegrityError as ie: db.rollback() diff --git a/tests/conftest.py b/tests/conftest.py index a1cce45df..69b3a0e31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -127,10 +127,14 @@ def app(): import mcpgateway.middleware.auth_middleware as auth_middleware_mod import mcpgateway.services.security_logger as sec_logger_mod import mcpgateway.services.structured_logger as struct_logger_mod + import mcpgateway.services.audit_trail_service as audit_trail_mod + import mcpgateway.services.log_aggregator as log_aggregator_mod mp.setattr(auth_middleware_mod, "SessionLocal", TestSessionLocal, raising=False) mp.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal, raising=False) mp.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(audit_trail_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(log_aggregator_mod, "SessionLocal", TestSessionLocal, raising=False) # 4) create schema db_mod.Base.metadata.create_all(bind=engine) @@ -200,10 +204,14 @@ def app_with_temp_db(): import mcpgateway.middleware.auth_middleware as auth_middleware_mod import mcpgateway.services.security_logger as sec_logger_mod import mcpgateway.services.structured_logger as struct_logger_mod + import mcpgateway.services.audit_trail_service as audit_trail_mod + import mcpgateway.services.log_aggregator as log_aggregator_mod 
mp.setattr(auth_middleware_mod, "SessionLocal", TestSessionLocal, raising=False) mp.setattr(sec_logger_mod, "SessionLocal", TestSessionLocal, raising=False) mp.setattr(struct_logger_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(audit_trail_mod, "SessionLocal", TestSessionLocal, raising=False) + mp.setattr(log_aggregator_mod, "SessionLocal", TestSessionLocal, raising=False) # 4) create schema db_mod.Base.metadata.create_all(bind=engine) diff --git a/tests/unit/mcpgateway/services/test_export_service.py b/tests/unit/mcpgateway/services/test_export_service.py index 0c60f803b..4f921a140 100644 --- a/tests/unit/mcpgateway/services/test_export_service.py +++ b/tests/unit/mcpgateway/services/test_export_service.py @@ -726,7 +726,7 @@ async def test_export_servers_with_data(export_service, mock_db): mock_server.name = "test_server" mock_server.description = "Test server" mock_server.associated_tools = ["tool1", "tool2"] - mock_server.is_active = True + mock_server.enabled = True mock_server.tags = ["test", "api"] export_service.server_service.list_servers.return_value = [mock_server] @@ -803,7 +803,7 @@ async def test_export_resources_with_data(export_service, mock_db): mock_resource.uri = "file:///workspace/test.txt" mock_resource.description = "Test resource file" mock_resource.mime_type = "text/plain" - mock_resource.is_active = True + mock_resource.enabled = True mock_resource.tags = ["file", "text"] mock_resource.updated_at = datetime.now(timezone.utc) diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 46c520906..f315560c4 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -300,7 +300,8 @@ async def test_register_tool(self, tool_service, mock_tool, test_db): # Verify DB operations test_db.add.assert_called_once() test_db.commit.assert_called_once() - test_db.refresh.assert_called_once() + # refresh is called twice: once after commit and once after logging commits + assert test_db.refresh.call_count == 2 # Verify result assert result.name == "test-gateway-test-tool" From 4821f0a54653eb473619dbef79e7c9e035bf4f6d Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Fri, 12 Dec 2025 09:45:41 +0000 Subject: [PATCH 33/34] Linting Signed-off-by: Mihai Criveti --- .env.example | 2 +- README.md | 2 +- docs/docs/deployment/container.md | 4 +- gunicorn.config.py | 19 ++-- ...6f_uuid_change_for_prompt_and_resources.py | 1 + ...4_tag_records_changes_list_str_to_list_.py | 1 + ...6f7g8h9i0_add_structured_logging_tables.py | 1 + mcpgateway/static/admin.js | 14 +-- mcpgateway/templates/admin.html | 46 +++++----- mcpgateway/toolops/README.md | 4 +- plugins/vault/README.md | 1 - .../services/test_gateway_service_extended.py | 2 +- .../services/test_resource_service_plugins.py | 6 +- .../services/test_server_service.py | 12 +-- .../mcpgateway/services/test_tool_service.py | 6 +- tests/unit/mcpgateway/test_admin.py | 2 +- .../mcpgateway/utils/test_ssl_key_manager.py | 92 +++++++++---------- .../utils/test_verify_credentials.py | 4 +- 18 files changed, 110 insertions(+), 109 deletions(-) diff --git a/.env.example b/.env.example index 87e5b81d4..8afdf1313 100644 --- a/.env.example +++ b/.env.example @@ -537,7 +537,7 @@ SECURITY_HEADERS_ENABLED=true # null or none: Completely removes iframe restrictions (no headers sent) # ALLOW-FROM uri: Allows specific domain (deprecated, use CSP instead) # ALLOW-ALL uri: Allows all (*, http, https) -# +# # Both 
X-Frame-Options header and CSP frame-ancestors directive are automatically synced.
 # Modern browsers prioritize CSP frame-ancestors over X-Frame-Options.
 X_FRAME_OPTIONS=DENY
diff --git a/README.md b/README.md
index fae6a2b60..ff0390d3b 100644
--- a/README.md
+++ b/README.md
@@ -1619,7 +1619,7 @@ ContextForge implements **OAuth 2.0 Dynamic Client Registration (RFC 7591)** and
 > > **iframe Embedding**: The gateway controls iframe embedding through both `X-Frame-Options` header and CSP `frame-ancestors` directive (both are automatically synced). Options:
 > - `X_FRAME_OPTIONS=DENY` (default): Blocks all iframe embedding
-> - `X_FRAME_OPTIONS=SAMEORIGIN`: Allows embedding from same domain only 
+> - `X_FRAME_OPTIONS=SAMEORIGIN`: Allows embedding from same domain only
 > - `X_FRAME_OPTIONS="ALLOW-ALL"`: Allows embedding from all sources (sets `frame-ancestors * file: http: https:`)
 > - `X_FRAME_OPTIONS=null` or `none`: Completely removes iframe restrictions (no headers sent)
 >
diff --git a/docs/docs/deployment/container.md b/docs/docs/deployment/container.md
index 775aeb430..8342e4680 100644
--- a/docs/docs/deployment/container.md
+++ b/docs/docs/deployment/container.md
@@ -31,12 +31,12 @@ docker logs mcpgateway
 You can now access the UI at [http://localhost:4444/admin](http://localhost:4444/admin)
 ### Multi-architecture containers
-Note: the container build process creates container images for 'amd64', 'arm64' and 's390x' architectures. The version `ghcr.io/ibm/mcp-context-forge:VERSION` 
+Note: the container build process creates container images for 'amd64', 'arm64' and 's390x' architectures. The version `ghcr.io/ibm/mcp-context-forge:VERSION`
 now points to a manifest, so all commands will pull the correct image for the architecture being used (whether that be locally or on Kubernetes or OpenShift). If the image for one architecture is needed on a different architecture, use the appropriate arguments for your given container execution tool:
-With docker run: 
+With docker run:
 ```
 docker run [... all your options...] --platform linux/arm64 ghcr.io/ibm/mcp-context-forge:VERSION
 ```
diff --git a/gunicorn.config.py b/gunicorn.config.py
index f6158672f..df888da42 100644
--- a/gunicorn.config.py
+++ b/gunicorn.config.py
@@ -65,37 +65,37 @@ def on_starting(server):
     """Called just before the master process is initialized.
-    
+
     This is where we handle passphrase-protected SSL keys by decrypting them to a temporary file before Gunicorn workers start. 
""" global _prepared_key_file - + # Check if SSL is enabled via environment variable (set by run-gunicorn.sh) # and a passphrase is provided ssl_enabled = os.environ.get("SSL", "false").lower() == "true" ssl_key_password = os.environ.get("SSL_KEY_PASSWORD") - + if ssl_enabled and ssl_key_password: try: from mcpgateway.utils.ssl_key_manager import prepare_ssl_key - + # Get the key file path from environment (set by run-gunicorn.sh) key_file = os.environ.get("KEY_FILE", "certs/key.pem") - + server.log.info(f"Preparing passphrase-protected SSL key: {key_file}") - + # Decrypt the key and get the temporary file path _prepared_key_file = prepare_ssl_key(key_file, ssl_key_password) - + server.log.info(f"SSL key prepared successfully: {_prepared_key_file}") - + # Update the keyfile setting to use the decrypted temporary file # This is a bit of a hack, but Gunicorn doesn't provide a better way # to modify the keyfile after it's been set via command line if hasattr(server, 'cfg'): server.cfg.set('keyfile', _prepared_key_file) - + except Exception as e: server.log.error(f"Failed to prepare SSL key: {e}") raise @@ -127,4 +127,3 @@ def worker_exit(server, worker): def child_exit(server, worker): server.log.info("Worker child exit (pid: %s)", worker.pid) - diff --git a/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py b/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py index b1e49a6f0..b616a0892 100644 --- a/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py +++ b/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """UUID Change for Prompt and Resources Revision ID: 356a2d4eed6f diff --git a/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py b/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py index 61ba1ed7c..481f303f5 100644 --- a/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py +++ b/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """tag records changes list[str] to list[Dict[str,str]] Revision ID: 9e028ecf59c4 diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py index 7750d14c4..b54a2a2a2 100644 --- a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py +++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """Add structured logging tables Revision ID: k5e6f7g8h9i0 diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index 61bd603d1..f3e73d21a 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -3655,7 +3655,7 @@ function openResourceTestModal(resource) { // 2️⃣ If no template → show a simple message fieldsContainer.innerHTML = `
- This resource has no URI template. + This resource has no URI template. Click "Invoke Resource" to test directly.
`; @@ -9591,7 +9591,7 @@ async function loadTools() { console.log("Loading tools..."); try { if (toolBody !== null) { - toolBody.innerHTML = ` + toolBody.innerHTML = `
@@ -24209,7 +24209,7 @@ function displayLogResults(data) { const userDisplay = log.user_email || log.user_id || "-"; return ` -
Time Level - Entity + Component Message + User + + Duration + + Correlation ID +
- ${timestamp} - - - ${log.level} - - - ${log.entity_type ? `${log.entity_type}: ${entity}` : entity} - - ${escapeHtml(log.message)} -
- 📊 No performance metrics available + 📊 No performance metrics available for ${aggregationLabel.toLowerCase()}
- Time - - Level - - Component - - Message - - User - - Duration - - Correlation ID -
❌ Error: ${escapeHtml(error.message)}
📭 No logs found matching your criteria
${formatTimestamp(log.timestamp)} - ${escapeHtml(log.component || '-')} + ${escapeHtml(log.component || "-")} ${escapeHtml(truncateText(log.message, 80))} - ${log.error_details ? '⚠️' : ''} + ${log.error_details ? '⚠️' : ""} ${escapeHtml(userDisplay)} @@ -24187,16 +24211,21 @@ function displayLogResults(data) { ${durationDisplay} - ${correlationId !== '-' ? ` + ${ + correlationId !== "-" + ? ` - ` : '-'} + ` + : "-" + }
@@ -24544,12 +24591,12 @@ function displayCorrelationTrace(trace) { `; return; } - + // Combine all events into a unified timeline const allEvents = []; - + // Add logs - (trace.logs || []).forEach(log => { + (trace.logs || []).forEach((log) => { const levelClass = getLogLevelClass(log.level); allEvents.push({ timestamp: new Date(log.timestamp), @@ -24564,17 +24611,17 @@ function displayCorrelationTrace(trace) { - ${escapeHtml(log.component || '-')} + ${escapeHtml(log.component || "-")} ${escapeHtml(log.message)} - ${log.error_details ? `
⚠️ ${escapeHtml(log.error_details.error_message || JSON.stringify(log.error_details))}` : ''} + ${log.error_details ? `
⚠️ ${escapeHtml(log.error_details.error_message || JSON.stringify(log.error_details))}` : ""}
- ${escapeHtml(log.user_email || log.user_id || '-')} + ${escapeHtml(log.user_email || log.user_id || "-")} - ${log.duration_ms ? log.duration_ms.toFixed(2) + 'ms' : '-'} + ${log.duration_ms ? log.duration_ms.toFixed(2) + "ms" : "-"} @@ -24582,14 +24629,16 @@ function displayCorrelationTrace(trace) {
- ${escapeHtml(event.event_type || '-')} + ${escapeHtml(event.event_type || "-")} - ${escapeHtml(event.description || '-')} + ${escapeHtml(event.description || "-")} - ${escapeHtml(event.user_email || event.user_id || '-')} + ${escapeHtml(event.user_email || event.user_id || "-")} - @@ -24629,23 +24678,27 @@ function displayCorrelationTrace(trace) {
- ${escapeHtml(audit.resource_type || '-')} + ${escapeHtml(audit.resource_type || "-")} ${audit.action}: ${audit.resource_type} - ${escapeHtml(audit.resource_id || '-')} + ${escapeHtml(audit.resource_id || "-")} - ${escapeHtml(audit.user_email || audit.user_id || '-')} + ${escapeHtml(audit.user_email || audit.user_id || "-")} - - ${statusIcon} ${audit.success ? 'Success' : 'Failed'} + ${statusIcon} ${audit.success ? "Success" : "Failed"}
@@ -24767,12 +24825,13 @@ function displaySecurityEvents(events) { `; return; } - - tbody.innerHTML = events.map(event => { - const severityClass = getSeverityClass(event.severity); - const threatScore = (event.threat_score * 100).toFixed(0); - - return ` + + tbody.innerHTML = events + .map((event) => { + const severityClass = getSeverityClass(event.severity); + const threatScore = (event.threat_score * 100).toFixed(0); + + return `
${formatTimestamp(event.timestamp)} @@ -24789,7 +24848,7 @@ function displaySecurityEvents(events) { ${escapeHtml(event.description)} - ${escapeHtml(event.user_email || event.user_id || '-')} + ${escapeHtml(event.user_email || event.user_id || "-")}
@@ -24800,16 +24859,21 @@ function displaySecurityEvents(events) {
- ${event.correlation_id ? ` + ${ + event.correlation_id + ? ` - ` : '-'} + ` + : "-" + }
@@ -24904,31 +24973,39 @@ function displayAuditTrail(trails) { `; return; } - - tbody.innerHTML = trails.map(trail => { - const actionClass = trail.success ? 'text-green-600' : 'text-red-600'; - const actionIcon = trail.success ? '✓' : '✗'; - - // Determine action badge color - const actionBadgeColors = { - 'create': 'bg-green-200 text-green-800 dark:bg-green-800 dark:text-green-200', - 'update': 'bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200', - 'delete': 'bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200', - 'read': 'bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200', - 'activate': 'bg-teal-200 text-teal-800 dark:bg-teal-800 dark:text-teal-200', - 'deactivate': 'bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200' - }; - const actionBadge = actionBadgeColors[trail.action.toLowerCase()] || 'bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200'; - - // Format resource name with ID - const resourceName = trail.resource_name || trail.resource_id || '-'; - const resourceDisplay = ` + + tbody.innerHTML = trails + .map((trail) => { + const actionClass = trail.success + ? "text-green-600" + : "text-red-600"; + const actionIcon = trail.success ? "✓" : "✗"; + + // Determine action badge color + const actionBadgeColors = { + create: "bg-green-200 text-green-800 dark:bg-green-800 dark:text-green-200", + update: "bg-blue-200 text-blue-800 dark:bg-blue-800 dark:text-blue-200", + delete: "bg-red-200 text-red-800 dark:bg-red-800 dark:text-red-200", + read: "bg-gray-200 text-gray-800 dark:bg-gray-600 dark:text-gray-200", + activate: + "bg-teal-200 text-teal-800 dark:bg-teal-800 dark:text-teal-200", + deactivate: + "bg-orange-200 text-orange-800 dark:bg-orange-800 dark:text-orange-200", + }; + const actionBadge = + actionBadgeColors[trail.action.toLowerCase()] || + "bg-purple-200 text-purple-800 dark:bg-purple-800 dark:text-purple-200"; + + // Format resource name with ID + const resourceName = + trail.resource_name || trail.resource_id || "-"; + const resourceDisplay = `
${escapeHtml(resourceName)}
- ${trail.resource_id && trail.resource_name ? `
UUID: ${escapeHtml(trail.resource_id)}
` : ''} - ${trail.data_classification ? `
🔒 ${escapeHtml(trail.data_classification)}
` : ''} + ${trail.resource_id && trail.resource_name ? `
UUID: ${escapeHtml(trail.resource_id)}
` : ""} + ${trail.data_classification ? `
🔒 ${escapeHtml(trail.data_classification)}
` : ""} `; - - return ` + + return `
${formatTimestamp(trail.timestamp)} @@ -24939,28 +25016,33 @@ function displayAuditTrail(trails) { - ${escapeHtml(trail.resource_type || '-')} + ${escapeHtml(trail.resource_type || "-")} ${resourceDisplay} - ${escapeHtml(trail.user_email || trail.user_id || '-')} + ${escapeHtml(trail.user_email || trail.user_id || "-")} - ${actionIcon} ${trail.success ? 'Success' : 'Failed'} + ${actionIcon} ${trail.success ? "Success" : "Failed"} - ${trail.correlation_id ? ` + ${ + trail.correlation_id + ? ` - ` : '-'} + ` + : "-" + }
@@ -25056,21 +25150,23 @@ function displayPerformanceMetrics(metrics) { `; return; } - - tbody.innerHTML = metrics.map(metric => { - const errorRatePercent = (metric.error_rate * 100).toFixed(2); - const errorClass = metric.error_rate > 0.1 ? 'text-red-600' : 'text-green-600'; - - return ` + + tbody.innerHTML = metrics + .map((metric) => { + const errorRatePercent = (metric.error_rate * 100).toFixed(2); + const errorClass = + metric.error_rate > 0.1 ? "text-red-600" : "text-green-600"; + + return `
${formatTimestamp(metric.window_start)} - ${escapeHtml(metric.component || '-')} + ${escapeHtml(metric.component || "-")} - ${escapeHtml(metric.operation_type || '-')} + ${escapeHtml(metric.operation_type || "-")}
@@ -25083,7 +25179,7 @@ function displayPerformanceMetrics(metrics) {
${errorRatePercent}% - ${metric.error_rate > 0.1 ? '⚠️' : ''} + ${metric.error_rate > 0.1 ? "⚠️" : ""}
@@ -25092,7 +25188,8 @@ function displayPerformanceMetrics(metrics) {
Loading tools...
${formatTimestamp(log.timestamp)} @@ -24236,7 +24236,7 @@ function displayLogResults(data) { ${ correlationId !== "-" ? ` - @@ -24734,7 +24734,7 @@ function displayCorrelationTrace(trace) { ${escapeHtml(audit.resource_type || "-")} - ${audit.action}: ${audit.resource_type} + ${audit.action}: ${audit.resource_type} ${escapeHtml(audit.resource_id || "-")} @@ -24878,7 +24878,7 @@ function displaySecurityEvents(events) { ${ event.correlation_id ? ` - @@ -25044,7 +25044,7 @@ function displayAuditTrail(trails) { ${ trail.correlation_id ? ` - diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index c983421c5..b54474f16 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -1756,7 +1756,7 @@

Virtual MCP Servers

- +
@@ -2615,7 +2615,7 @@

MCP Tools

Clear
- +
@@ -3613,7 +3613,7 @@

MCP Resources

Clear
- + @@ -3833,7 +3833,7 @@

MCP Prompts

Clear - + @@ -4084,7 +4084,7 @@

Clear - + @@ -5501,7 +5501,7 @@

Clear - + @@ -7653,7 +7653,7 @@

class="inline-block align-bottom bg-white dark:bg-gray-900 rounded-lg text-left overflow-hidden shadow-xl transform transition-all sm:my-8 sm:align-middle sm:max-w-4xl sm:w-full" >
- +

- -
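
A note on the session-expiry fix in PATCH 32 above: `register_tool` now calls `db.refresh(db_tool)` before converting the ORM object because the audit-trail and structured-logging writes commit on the same session, and every commit expires loaded attributes. The sketch below is a minimal, self-contained illustration of that behavior, assuming SQLAlchemy 2.x defaults (`expire_on_commit=True`); the `Tool` model and in-memory engine are illustrative stand-ins, not the gateway's actual schema.

```python
# Minimal sketch (illustrative, not project code): why a second commit on the
# same session motivates an explicit refresh before serializing an ORM object.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Tool(Base):
    __tablename__ = "tools"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:  # expire_on_commit=True by default
    tool = Tool(name="demo")
    session.add(tool)
    session.commit()  # registration commit
    session.commit()  # e.g. a follow-up audit/logging commit on the same session

    # Both commits expired `tool`: its attributes are stale and will lazy-load
    # on next access. That works only while the session is usable; if the
    # object is handed to a serializer after the session closes, attribute
    # access raises DetachedInstanceError instead. An eager refresh avoids
    # that, mirroring the db.refresh(db_tool) added in PATCH 32.
    session.refresh(tool)
    assert tool.name == "demo"
```

If this reading is right, the updated test expectation in PATCH 32 (`test_db.refresh.call_count == 2`) follows directly: one refresh after the registration commit and one after the logging commits.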