diff --git a/.env.example b/.env.example
index 693715cf2..8afdf1313 100644
--- a/.env.example
+++ b/.env.example
@@ -537,7 +537,7 @@ SECURITY_HEADERS_ENABLED=true
 # null or none: Completely removes iframe restrictions (no headers sent)
 # ALLOW-FROM uri: Allows specific domain (deprecated, use CSP instead)
 # ALLOW-ALL uri: Allows all (*, http, https)
-# 
+#
 # Both X-Frame-Options header and CSP frame-ancestors directive are automatically synced.
 # Modern browsers prioritize CSP frame-ancestors over X-Frame-Options.
 X_FRAME_OPTIONS=DENY
@@ -659,6 +659,17 @@ LOG_MAX_SIZE_MB=1
 LOG_BACKUP_COUNT=5
 LOG_BUFFER_SIZE_MB=1.0
 
+# Correlation ID / Request Tracking
+# Enable automatic correlation ID tracking for unified request tracing
+# Options: true (default), false
+CORRELATION_ID_ENABLED=true
+# HTTP header name for correlation ID (default: X-Correlation-ID)
+CORRELATION_ID_HEADER=X-Correlation-ID
+# Preserve incoming correlation IDs from clients (default: true)
+CORRELATION_ID_PRESERVE=true
+# Include correlation ID in HTTP response headers (default: true)
+CORRELATION_ID_RESPONSE_HEADER=true
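+# Illustrative usage (assumes the default header name above): a client that
+# sends its own ID gets the same ID echoed back in the response headers,
+# while requests without the header get a freshly generated correlation ID.
+#   curl -i -H "X-Correlation-ID: my-trace-123" http://localhost:4444/admin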
+
 # Transport Protocol Configuration
 # Options: all (default), sse, streamablehttp, http
 # - all: Enable all transport protocols
@@ -1193,6 +1204,16 @@ PAGINATION_INCLUDE_LINKS=true
 # Enable TLS for gRPC connections by default
 # MCPGATEWAY_GRPC_TLS_ENABLED=false
 
+#####################################
+# Security Event Logging
+#####################################
+
+# Enable security event logging (authentication attempts, authorization failures, etc.)
+# Options: true (default), false
+# When enabled, the AuthContextMiddleware will log all authentication attempts to the database
+# This is INDEPENDENT of observability settings - security logging is critical for audit trails
+# SECURITY_LOGGING_ENABLED=true
+
 #####################################
 # Observability Settings
 #####################################
diff --git a/README.md b/README.md
index fae6a2b60..ff0390d3b 100644
--- a/README.md
+++ b/README.md
@@ -1619,7 +1619,7 @@ ContextForge implements **OAuth 2.0 Dynamic Client Registration (RFC 7591)** and
 >
 > **iframe Embedding**: The gateway controls iframe embedding through both `X-Frame-Options` header and CSP `frame-ancestors` directive (both are automatically synced). Options:
 > - `X_FRAME_OPTIONS=DENY` (default): Blocks all iframe embedding
-> - `X_FRAME_OPTIONS=SAMEORIGIN`: Allows embedding from same domain only 
+> - `X_FRAME_OPTIONS=SAMEORIGIN`: Allows embedding from same domain only
 > - `X_FRAME_OPTIONS="ALLOW-ALL"`: Allows embedding from all sources (sets `frame-ancestors * file: http: https:`)
 > - `X_FRAME_OPTIONS=null` or `none`: Completely removes iframe restrictions (no headers sent)
 >
diff --git a/docs/docs/deployment/container.md b/docs/docs/deployment/container.md
index 775aeb430..8342e4680 100644
--- a/docs/docs/deployment/container.md
+++ b/docs/docs/deployment/container.md
@@ -31,12 +31,12 @@ docker logs mcpgateway
 You can now access the UI at [http://localhost:4444/admin](http://localhost:4444/admin)
 
 ### Multi-architecture containers
-Note: the container build process creates container images for 'amd64', 'arm64' and 's390x' architectures. The version `ghcr.io/ibm/mcp-context-forge:VERSION`
+Note: the container build process creates container images for 'amd64', 'arm64' and 's390x' architectures.
+The version `ghcr.io/ibm/mcp-context-forge:VERSION` now points to a manifest so that all commands will pull the correct image for the architecture being used (whether that be locally or on Kubernetes or OpenShift).
 If the specific image is needed for one architecture on a different architecture, use the appropriate arguments for your given container execution tool:
 
-With docker run: 
+With docker run:
 ```
 docker run [... all your options...] --platform linux/arm64 ghcr.io/ibm/mcp-context-forge:VERSION
 ```
diff --git a/gunicorn.config.py b/gunicorn.config.py
index f6158672f..df888da42 100644
--- a/gunicorn.config.py
+++ b/gunicorn.config.py
@@ -65,37 +65,37 @@
 def on_starting(server):
     """Called just before the master process is initialized.
-    
+
     This is where we handle passphrase-protected SSL keys by decrypting
     them to a temporary file before Gunicorn workers start.
     """
     global _prepared_key_file
-    
+
     # Check if SSL is enabled via environment variable (set by run-gunicorn.sh)
     # and a passphrase is provided
     ssl_enabled = os.environ.get("SSL", "false").lower() == "true"
     ssl_key_password = os.environ.get("SSL_KEY_PASSWORD")
-    
+
     if ssl_enabled and ssl_key_password:
         try:
             from mcpgateway.utils.ssl_key_manager import prepare_ssl_key
-    
+
             # Get the key file path from environment (set by run-gunicorn.sh)
             key_file = os.environ.get("KEY_FILE", "certs/key.pem")
-    
+
             server.log.info(f"Preparing passphrase-protected SSL key: {key_file}")
-    
+
             # Decrypt the key and get the temporary file path
             _prepared_key_file = prepare_ssl_key(key_file, ssl_key_password)
-    
+
             server.log.info(f"SSL key prepared successfully: {_prepared_key_file}")
-    
+
             # Update the keyfile setting to use the decrypted temporary file
             # This is a bit of a hack, but Gunicorn doesn't provide a better way
             # to modify the keyfile after it's been set via command line
             if hasattr(server, 'cfg'):
                 server.cfg.set('keyfile', _prepared_key_file)
-    
+
         except Exception as e:
             server.log.error(f"Failed to prepare SSL key: {e}")
             raise
@@ -127,4 +127,3 @@ def worker_exit(server, worker):
 
 def child_exit(server, worker):
     server.log.info("Worker child exit (pid: %s)", worker.pid)
-
diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py
index d5f981875..123bee450 100644
--- a/mcpgateway/admin.py
+++ b/mcpgateway/admin.py
@@ -47,7 +47,7 @@
 from pydantic import SecretStr, ValidationError
 from pydantic_core import ValidationError as CoreValidationError
 from sqlalchemy import and_, case, cast, desc, func, or_, select, String
-from sqlalchemy.exc import IntegrityError
+from sqlalchemy.exc import IntegrityError, InvalidRequestError, OperationalError
 from sqlalchemy.orm import joinedload, Session
 from sqlalchemy.sql.functions import coalesce
 from starlette.datastructures import UploadFile as StarletteUploadFile
@@ -105,6 +105,7 @@
 )
 from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService
 from mcpgateway.services.argon2_service import Argon2PasswordService
+from mcpgateway.services.audit_trail_service import get_audit_trail_service
 from mcpgateway.services.catalog_service import catalog_service
 from mcpgateway.services.email_auth_service import AuthenticationError, EmailAuthService, PasswordValidationError
 from mcpgateway.services.encryption_service import get_encryption_service
@@ -120,6 +121,7 @@
 from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService, ResourceURIConflictError
 from mcpgateway.services.root_service import RootService
 from mcpgateway.services.server_service import ServerError, ServerNameConflictError, ServerNotFoundError, ServerService
+from mcpgateway.services.structured_logger import get_structured_logger
 from mcpgateway.services.tag_service import TagService
 from mcpgateway.services.team_management_service import TeamManagementService
 from mcpgateway.services.tool_service import ToolError, ToolNameConflictError, ToolNotFoundError, ToolService
@@ -8369,6 +8371,17 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us
             status_code=200,
         )
     except Exception as ex:
+        # Roll back only when a transaction is active to avoid sqlite3 "no transaction" errors.
+        try:
+            active_transaction = db.get_transaction() if hasattr(db, "get_transaction") else None
+            if db.is_active and active_transaction is not None:
+                db.rollback()
+        except (InvalidRequestError, OperationalError) as rollback_error:
+            LOGGER.warning(
+                "Rollback failed (ignoring for SQLite compatibility): %s",
+                rollback_error,
+            )
+
         if isinstance(ex, ValidationError):
             LOGGER.error(f"ValidationError in admin_add_resource: {ErrorFormatter.format_validation_error(ex)}")
             return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422)
@@ -8587,7 +8600,11 @@ async def admin_delete_resource(resource_id: str, request: Request, db: Session
     LOGGER.debug(f"User {get_user_email(user)} is deleting resource ID {resource_id}")
     error_message = None
     try:
-        await resource_service.delete_resource(user["db"] if isinstance(user, dict) else db, resource_id)
+        await resource_service.delete_resource(
+            user["db"] if isinstance(user, dict) else db,
+            resource_id,
+            user_email=user_email,
+        )
     except PermissionError as e:
         LOGGER.warning(f"Permission denied for user {user_email} deleting resource {resource_id}: {e}")
         error_message = str(e)
@@ -9676,7 +9693,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str]
     >>> async def test_admin_test_gateway():
     ...     with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class:
     ...         mock_client_class.return_value = MockClient()
-    ...         response = await admin_test_gateway(mock_request, mock_user)
+    ...         response = await admin_test_gateway(mock_request, None, mock_user, mock_db)
     ...         return isinstance(response, GatewayTestResponse) and response.status_code == 200
     >>>
     >>> result = asyncio.run(test_admin_test_gateway())
@@ -9702,7 +9719,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str]
     >>> async def test_admin_test_gateway_text_response():
     ...     with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class:
     ...         mock_client_class.return_value = MockClientTextOnly()
-    ...         response = await admin_test_gateway(mock_request, mock_user)
+    ...         response = await admin_test_gateway(mock_request, None, mock_user, mock_db)
     ...         return isinstance(response, GatewayTestResponse) and response.body.get("details") == "plain text response"
     >>>
     >>> asyncio.run(test_admin_test_gateway_text_response())
@@ -9720,7 +9737,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str]
     >>> async def test_admin_test_gateway_network_error():
     ...     with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class:
     ...         mock_client_class.return_value = MockClientError()
-    ...         response = await admin_test_gateway(mock_request, mock_user)
+    ...         response = await admin_test_gateway(mock_request, None, mock_user, mock_db)
     ...         return response.status_code == 502 and "Network error" in str(response.body)
     >>>
     >>> asyncio.run(test_admin_test_gateway_network_error())
@@ -9738,7 +9755,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str]
     >>> async def test_admin_test_gateway_post():
     ...     with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class:
     ...         mock_client_class.return_value = MockClient()
-    ...         response = await admin_test_gateway(mock_request_post, mock_user)
+    ...         response = await admin_test_gateway(mock_request_post, None, mock_user, mock_db)
     ...         return isinstance(response, GatewayTestResponse) and response.status_code == 200
     >>>
     >>> asyncio.run(test_admin_test_gateway_post())
@@ -9756,7 +9773,7 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str]
     >>> async def test_admin_test_gateway_trailing_slash():
     ...     with patch('mcpgateway.admin.ResilientHttpClient') as mock_client_class:
     ...         mock_client_class.return_value = MockClient()
-    ...         response = await admin_test_gateway(mock_request_trailing, mock_user)
+    ...         response = await admin_test_gateway(mock_request_trailing, None, mock_user, mock_db)
     ...         return isinstance(response, GatewayTestResponse) and response.status_code == 200
     >>>
     >>> asyncio.run(test_admin_test_gateway_trailing_slash())
@@ -9846,11 +9863,56 @@ async def admin_test_gateway(request: GatewayTestRequest, team_id: Optional[str]
         except json.JSONDecodeError:
             response_body = {"details": response.text}
 
+        # Structured logging: Log successful gateway test
+        structured_logger = get_structured_logger("gateway_service")
+        structured_logger.log(
+            level="INFO",
+            message=f"Gateway test completed: {request.base_url}",
+            event_type="gateway_tested",
+            component="gateway_service",
+            user_email=get_user_email(user),
+            team_id=team_id,
+            resource_type="gateway",
+            resource_id=gateway.id if gateway else None,
+            custom_fields={
+                "gateway_name": gateway.name if gateway else None,
+                "gateway_url": str(request.base_url),
+                "test_method": request.method,
+                "test_path": request.path,
+                "status_code": response.status_code,
+                "latency_ms": latency_ms,
+            },
+            db=db,
+        )
+
         return GatewayTestResponse(status_code=response.status_code, latency_ms=latency_ms, body=response_body)
 
     except httpx.RequestError as e:
         LOGGER.warning(f"Gateway test failed: {e}")
         latency_ms = int((time.monotonic() - start_time) * 1000)
+
+        # Structured logging: Log failed gateway test
+        structured_logger = get_structured_logger("gateway_service")
+        structured_logger.log(
+            level="ERROR",
+            message=f"Gateway test failed: {request.base_url}",
+            event_type="gateway_test_failed",
+            component="gateway_service",
+            user_email=get_user_email(user),
+            team_id=team_id,
+            resource_type="gateway",
+            resource_id=gateway.id if gateway else None,
+            error=e,
+            custom_fields={
+                "gateway_name": gateway.name if gateway else None,
+                "gateway_url": str(request.base_url),
+                "test_method": request.method,
+                "test_path": request.path,
+                "latency_ms": latency_ms,
+            },
+            db=db,
+        )
+
         return GatewayTestResponse(status_code=502, latency_ms=latency_ms, body={"error": "Request failed", "details": str(e)})
@@ -11896,6 +11958,7 @@ async def admin_test_a2a_agent(
         return JSONResponse(content={"success": False, "error": "A2A features are disabled"}, status_code=403)
 
     try:
+        user_email = get_user_email(user)
         # Get the agent by ID
         agent = await a2a_service.get_agent(db, agent_id)
@@ -11911,7 +11974,14 @@ async def admin_test_a2a_agent(
         test_params = {"message": "Hello from MCP Gateway Admin UI test!", "test": True, "timestamp": int(time.time())}
 
         # Invoke the agent
-        result = await a2a_service.invoke_agent(db, agent.name, test_params, "admin_test")
+        result = await a2a_service.invoke_agent(
+            db,
+            agent.name,
+            test_params,
+            "admin_test",
+            user_email=user_email,
+            user_id=user_email,
+        )
 
         return JSONResponse(content={"success": True, "result": result, "agent_name": agent.name, "test_timestamp": time.time()})
@@ -12536,6 +12606,7 @@ async def list_plugins(
         HTTPException: If there's an error retrieving plugins
     """
     LOGGER.debug(f"User {get_user_email(user)} requested plugin list")
+    structured_logger = get_structured_logger()
 
     try:
         # Get plugin service
@@ -12556,10 +12627,35 @@ async def list_plugins(
         enabled_count = sum(1 for p in plugins if p["status"] == "enabled")
         disabled_count = sum(1 for p in plugins if p["status"] == "disabled")
 
+        # Log plugin marketplace browsing activity
+        structured_logger.info(
+            "User browsed plugin marketplace",
+            user_id=str(user.id),
+            user_email=get_user_email(user),
+            component="plugin_marketplace",
+            category="business_logic",
+            resource_type="plugin_list",
+            resource_action="browse",
+            custom_fields={
+                "search_query": search,
+                "filter_mode": mode,
+                "filter_hook": hook,
+                "filter_tag": tag,
+                "results_count": len(plugins),
+                "enabled_count": enabled_count,
+                "disabled_count": disabled_count,
+                "has_filters": any([search, mode, hook, tag]),
+            },
+            db=db,
+        )
+
         return PluginListResponse(plugins=plugins, total=len(plugins), enabled_count=enabled_count, disabled_count=disabled_count)
 
     except Exception as e:
         LOGGER.error(f"Error listing plugins: {e}")
+        structured_logger.error(
+            "Failed to list plugins in marketplace", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db
+        )
         raise HTTPException(status_code=500, detail=str(e))
@@ -12579,6 +12675,7 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user
         HTTPException: If there's an error getting plugin statistics
     """
     LOGGER.debug(f"User {get_user_email(user)} requested plugin statistics")
+    structured_logger = get_structured_logger()
 
     try:
         # Get plugin service
@@ -12592,10 +12689,33 @@ async def get_plugin_stats(request: Request, db: Session = Depends(get_db), user
         # Get statistics
         stats = plugin_service.get_plugin_statistics()
 
+        # Log marketplace analytics access
+        structured_logger.info(
+            "User accessed plugin marketplace statistics",
+            user_id=str(user.id),
+            user_email=get_user_email(user),
+            component="plugin_marketplace",
+            category="business_logic",
+            resource_type="plugin_stats",
+            resource_action="view",
+            custom_fields={
+                "total_plugins": stats.get("total_plugins", 0),
+                "enabled_plugins": stats.get("enabled_plugins", 0),
+                "disabled_plugins": stats.get("disabled_plugins", 0),
+                "hooks_count": len(stats.get("plugins_by_hook", {})),
+                "tags_count": len(stats.get("plugins_by_tag", {})),
+                "authors_count": len(stats.get("plugins_by_author", {})),
+            },
+            db=db,
+        )
+
         return PluginStatsResponse(**stats)
 
     except Exception as e:
         LOGGER.error(f"Error getting plugin statistics: {e}")
+        structured_logger.error(
+            "Failed to get plugin marketplace statistics", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db
+        )
         raise HTTPException(status_code=500, detail=str(e))
@@ -12616,6 +12736,8 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends(
         HTTPException: If plugin not found
     """
     LOGGER.debug(f"User {get_user_email(user)} requested details for plugin {name}")
+    structured_logger = get_structured_logger()
+    audit_service = get_audit_trail_service()
 
     try:
         # Get plugin service
@@ -12630,14 +12752,53 @@ async def get_plugin_details(name: str, request: Request, db: Session = Depends(
         plugin = plugin_service.get_plugin_by_name(name)
 
         if not plugin:
+            structured_logger.warning(
+                f"Plugin '{name}' not found in marketplace",
+                user_id=str(user.id),
+                user_email=get_user_email(user),
+                component="plugin_marketplace",
+                category="business_logic",
+                custom_fields={"plugin_name": name, "action": "view_details"},
+                db=db,
+            )
             raise HTTPException(status_code=404, detail=f"Plugin '{name}' not found")
 
+        # Log plugin view activity
+        structured_logger.info(
+            f"User viewed plugin details: '{name}'",
+            user_id=str(user.id),
+            user_email=get_user_email(user),
+            component="plugin_marketplace",
+            category="business_logic",
+            resource_type="plugin",
+            resource_id=name,
+            resource_action="view_details",
+            custom_fields={
+                "plugin_name": name,
+                "plugin_version": plugin.get("version"),
+                "plugin_author": plugin.get("author"),
+                "plugin_status": plugin.get("status"),
+                "plugin_mode": plugin.get("mode"),
+                "plugin_hooks": plugin.get("hooks", []),
+                "plugin_tags": plugin.get("tags", []),
+            },
+            db=db,
+        )
+
+        # Create audit trail for plugin access
+        audit_service.log_audit(
+            user_id=str(user.id), user_email=get_user_email(user), resource_type="plugin", resource_id=name, action="view", description=f"Viewed plugin '{name}' details in marketplace", db=db
+        )
+
         return PluginDetail(**plugin)
 
     except HTTPException:
         raise
     except Exception as e:
         LOGGER.error(f"Error getting plugin details: {e}")
+        structured_logger.error(
+            f"Failed to get plugin details: '{name}'", user_id=str(user.id), user_email=get_user_email(user), error=e, component="plugin_marketplace", category="business_logic", db=db
+        )
         raise HTTPException(status_code=500, detail=str(e))
diff --git a/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py b/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py
index b1e49a6f0..b616a0892 100644
--- a/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py
+++ b/mcpgateway/alembic/versions/356a2d4eed6f_uuid_change_for_prompt_and_resources.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 """UUID Change for Prompt and Resources
 
 Revision ID: 356a2d4eed6f
diff --git a/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py b/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py
index 61ba1ed7c..481f303f5 100644
--- a/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py
+++ b/mcpgateway/alembic/versions/9e028ecf59c4_tag_records_changes_list_str_to_list_.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 """tag records changes list[str] to list[Dict[str,str]]
 
 Revision ID: 9e028ecf59c4
diff --git a/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py
new file mode 100644
index 000000000..a83afbd28
--- /dev/null
+++ b/mcpgateway/alembic/versions/k5e6f7g8h9i0_add_structured_logging_tables.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+"""Add structured logging tables
+
+Revision ID: k5e6f7g8h9i0
+Revises: 356a2d4eed6f
+Create Date: 2025-01-15 12:00:00.000000
+
+"""
+
+# Third-Party
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
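+# NOTE: the tables and indexes below intentionally mirror the ORM models in
+# mcpgateway/db.py (see the "Composite indexes matching db.py" marker); keep
+# the two definitions in sync when altering either side.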
+revision = "k5e6f7g8h9i0"
+down_revision = "356a2d4eed6f"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    """Add structured logging tables."""
+    # Create structured_log_entries table
+    op.create_table(
+        "structured_log_entries",
+        sa.Column("id", sa.String(36), nullable=False),
+        sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("correlation_id", sa.String(64), nullable=True),
+        sa.Column("request_id", sa.String(64), nullable=True),
+        sa.Column("level", sa.String(20), nullable=False),
+        sa.Column("component", sa.String(100), nullable=False),
+        sa.Column("message", sa.Text(), nullable=False),
+        sa.Column("logger", sa.String(255), nullable=True),
+        sa.Column("user_id", sa.String(255), nullable=True),
+        sa.Column("user_email", sa.String(255), nullable=True),
+        sa.Column("client_ip", sa.String(45), nullable=True),
+        sa.Column("user_agent", sa.Text(), nullable=True),
+        sa.Column("request_path", sa.String(500), nullable=True),
+        sa.Column("request_method", sa.String(10), nullable=True),
+        sa.Column("duration_ms", sa.Float(), nullable=True),
+        sa.Column("operation_type", sa.String(100), nullable=True),
+        sa.Column("is_security_event", sa.Boolean(), nullable=False, server_default=sa.false()),
+        sa.Column("security_severity", sa.String(20), nullable=True),
+        sa.Column("threat_indicators", sa.JSON(), nullable=True),
+        sa.Column("context", sa.JSON(), nullable=True),
+        sa.Column("error_details", sa.JSON(), nullable=True),
+        sa.Column("performance_metrics", sa.JSON(), nullable=True),
+        sa.Column("hostname", sa.String(255), nullable=False),
+        sa.Column("process_id", sa.Integer(), nullable=False),
+        sa.Column("thread_id", sa.Integer(), nullable=True),
+        sa.Column("version", sa.String(50), nullable=False),
+        sa.Column("environment", sa.String(50), nullable=False, server_default="production"),
+        sa.Column("trace_id", sa.String(32), nullable=True),
+        sa.Column("span_id", sa.String(16), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+    )
+
+    # Create indexes for structured_log_entries
+    op.create_index("ix_structured_log_entries_timestamp", "structured_log_entries", ["timestamp"], unique=False)
+    op.create_index("ix_structured_log_entries_level", "structured_log_entries", ["level"], unique=False)
+    op.create_index("ix_structured_log_entries_component", "structured_log_entries", ["component"], unique=False)
+    op.create_index("ix_structured_log_entries_correlation_id", "structured_log_entries", ["correlation_id"], unique=False)
+    op.create_index("ix_structured_log_entries_request_id", "structured_log_entries", ["request_id"], unique=False)
+    op.create_index("ix_structured_log_entries_user_id", "structured_log_entries", ["user_id"], unique=False)
+    op.create_index("ix_structured_log_entries_user_email", "structured_log_entries", ["user_email"], unique=False)
+    op.create_index("ix_structured_log_entries_operation_type", "structured_log_entries", ["operation_type"], unique=False)
+    op.create_index("ix_structured_log_entries_is_security_event", "structured_log_entries", ["is_security_event"], unique=False)
+    op.create_index("ix_structured_log_entries_security_severity", "structured_log_entries", ["security_severity"], unique=False)
+    op.create_index("ix_structured_log_entries_trace_id", "structured_log_entries", ["trace_id"], unique=False)
+
+    # Composite indexes matching db.py
+    op.create_index("idx_log_correlation_time", "structured_log_entries", ["correlation_id", "timestamp"], unique=False)
+    op.create_index("idx_log_user_time", "structured_log_entries", ["user_id", "timestamp"], unique=False)
+    op.create_index("idx_log_level_time", "structured_log_entries", ["level", "timestamp"], unique=False)
+    op.create_index("idx_log_component_time", "structured_log_entries", ["component", "timestamp"], unique=False)
+    op.create_index("idx_log_security", "structured_log_entries", ["is_security_event", "security_severity", "timestamp"], unique=False)
+    op.create_index("idx_log_operation", "structured_log_entries", ["operation_type", "timestamp"], unique=False)
+    op.create_index("idx_log_trace", "structured_log_entries", ["trace_id", "timestamp"], unique=False)
+
+    # Create performance_metrics table
+    op.create_table(
+        "performance_metrics",
+        sa.Column("id", sa.String(36), nullable=False),
+        sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("operation_type", sa.String(100), nullable=False),
+        sa.Column("component", sa.String(100), nullable=False),
+        sa.Column("request_count", sa.Integer(), nullable=False, server_default="0"),
+        sa.Column("error_count", sa.Integer(), nullable=False, server_default="0"),
+        sa.Column("error_rate", sa.Float(), nullable=False, server_default="0.0"),
+        sa.Column("avg_duration_ms", sa.Float(), nullable=False),
+        sa.Column("min_duration_ms", sa.Float(), nullable=False),
+        sa.Column("max_duration_ms", sa.Float(), nullable=False),
+        sa.Column("p50_duration_ms", sa.Float(), nullable=False),
+        sa.Column("p95_duration_ms", sa.Float(), nullable=False),
+        sa.Column("p99_duration_ms", sa.Float(), nullable=False),
+        sa.Column("window_start", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("window_end", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("window_duration_seconds", sa.Integer(), nullable=False),
+        sa.Column("metric_metadata", sa.JSON(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+    )
+
+    # Create indexes for performance_metrics
+    op.create_index("ix_performance_metrics_timestamp", "performance_metrics", ["timestamp"], unique=False)
+    op.create_index("ix_performance_metrics_component", "performance_metrics", ["component"], unique=False)
+    op.create_index("ix_performance_metrics_operation_type", "performance_metrics", ["operation_type"], unique=False)
+    op.create_index("ix_performance_metrics_window_start", "performance_metrics", ["window_start"], unique=False)
+    op.create_index("idx_perf_operation_time", "performance_metrics", ["operation_type", "window_start"], unique=False)
+    op.create_index("idx_perf_component_time", "performance_metrics", ["component", "window_start"], unique=False)
+    op.create_index("idx_perf_window", "performance_metrics", ["window_start", "window_end"], unique=False)
sa.Column("threat_score", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("threat_indicators", sa.JSON(), nullable=False, server_default=sa.text("'{}'")), + sa.Column("failed_attempts_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("resolved", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("resolved_by", sa.String(255), nullable=True), + sa.Column("resolution_notes", sa.Text(), nullable=True), + sa.Column("alert_sent", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("alert_sent_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("alert_recipients", sa.JSON(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["log_entry_id"], ["structured_log_entries.id"]), + ) + + # Create indexes for security_events + op.create_index("ix_security_events_timestamp", "security_events", ["timestamp"], unique=False) + op.create_index("ix_security_events_detected_at", "security_events", ["detected_at"], unique=False) + op.create_index("ix_security_events_correlation_id", "security_events", ["correlation_id"], unique=False) + op.create_index("ix_security_events_event_type", "security_events", ["event_type"], unique=False) + op.create_index("ix_security_events_severity", "security_events", ["severity"], unique=False) + op.create_index("ix_security_events_category", "security_events", ["category"], unique=False) + op.create_index("ix_security_events_user_id", "security_events", ["user_id"], unique=False) + op.create_index("ix_security_events_user_email", "security_events", ["user_email"], unique=False) + op.create_index("ix_security_events_client_ip", "security_events", ["client_ip"], unique=False) + op.create_index("ix_security_events_log_entry_id", "security_events", ["log_entry_id"], unique=False) + op.create_index("ix_security_events_resolved", "security_events", ["resolved"], unique=False) + op.create_index("idx_security_type_time", "security_events", ["event_type", "timestamp"], unique=False) + op.create_index("idx_security_severity_time", "security_events", ["severity", "timestamp"], unique=False) + op.create_index("idx_security_user_time", "security_events", ["user_id", "timestamp"], unique=False) + op.create_index("idx_security_ip_time", "security_events", ["client_ip", "timestamp"], unique=False) + op.create_index("idx_security_unresolved", "security_events", ["resolved", "severity", "timestamp"], unique=False) + + # Create audit_trails table + op.create_table( + "audit_trails", + sa.Column("id", sa.String(36), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("correlation_id", sa.String(64), nullable=True), + sa.Column("request_id", sa.String(64), nullable=True), + sa.Column("action", sa.String(100), nullable=False), + sa.Column("resource_type", sa.String(100), nullable=False), + sa.Column("resource_id", sa.String(255), nullable=False), + sa.Column("resource_name", sa.String(500), nullable=True), + sa.Column("user_id", sa.String(255), nullable=False), + sa.Column("user_email", sa.String(255), nullable=True), + sa.Column("team_id", sa.String(36), nullable=True), + sa.Column("client_ip", sa.String(45), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("request_path", sa.String(500), nullable=True), + sa.Column("request_method", sa.String(10), nullable=True), + sa.Column("old_values", 
sa.JSON(), nullable=True), + sa.Column("new_values", sa.JSON(), nullable=True), + sa.Column("changes", sa.JSON(), nullable=True), + sa.Column("data_classification", sa.String(50), nullable=True), + sa.Column("requires_review", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("success", sa.Boolean(), nullable=False), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("context", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + # Create indexes for audit_trails + op.create_index("ix_audit_trails_timestamp", "audit_trails", ["timestamp"], unique=False) + op.create_index("ix_audit_trails_correlation_id", "audit_trails", ["correlation_id"], unique=False) + op.create_index("ix_audit_trails_request_id", "audit_trails", ["request_id"], unique=False) + op.create_index("ix_audit_trails_action", "audit_trails", ["action"], unique=False) + op.create_index("ix_audit_trails_resource_type", "audit_trails", ["resource_type"], unique=False) + op.create_index("ix_audit_trails_resource_id", "audit_trails", ["resource_id"], unique=False) + op.create_index("ix_audit_trails_user_id", "audit_trails", ["user_id"], unique=False) + op.create_index("ix_audit_trails_user_email", "audit_trails", ["user_email"], unique=False) + op.create_index("ix_audit_trails_team_id", "audit_trails", ["team_id"], unique=False) + op.create_index("ix_audit_trails_data_classification", "audit_trails", ["data_classification"], unique=False) + op.create_index("ix_audit_trails_requires_review", "audit_trails", ["requires_review"], unique=False) + op.create_index("ix_audit_trails_success", "audit_trails", ["success"], unique=False) + op.create_index("idx_audit_action_time", "audit_trails", ["action", "timestamp"], unique=False) + op.create_index("idx_audit_resource_time", "audit_trails", ["resource_type", "resource_id", "timestamp"], unique=False) + op.create_index("idx_audit_user_time", "audit_trails", ["user_id", "timestamp"], unique=False) + op.create_index("idx_audit_classification", "audit_trails", ["data_classification", "timestamp"], unique=False) + op.create_index("idx_audit_review", "audit_trails", ["requires_review", "timestamp"], unique=False) + + +def downgrade() -> None: + """Remove structured logging tables.""" + op.drop_table("audit_trails") + op.drop_table("security_events") + op.drop_table("performance_metrics") + op.drop_table("structured_log_entries") diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index ea633ee5d..c0300a124 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -26,11 +26,63 @@ from mcpgateway.config import settings from mcpgateway.db import EmailUser, SessionLocal from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, HttpAuthResolveUserPayload, HttpHeaderPayload, HttpHookType, PluginViolationError -from mcpgateway.services.team_management_service import TeamManagementService +from mcpgateway.services.team_management_service import TeamManagementService # pylint: disable=import-outside-toplevel +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.verify_credentials import verify_jwt_token # Security scheme -bearer_scheme = HTTPBearer(auto_error=False) +security = HTTPBearer(auto_error=False) + + +def _log_auth_event( + logger: logging.Logger, + message: str, + level: int = logging.INFO, + user_id: Optional[str] = None, + auth_method: Optional[str] = None, + auth_success: bool = False, + security_event: Optional[str] = None, + security_severity: str = "low", + **extra_context, +) -> 
None: + """Log authentication event with structured context and request_id. + + This helper creates structured log records that include request_id from the + correlation ID context, enabling end-to-end tracing of authentication flows. + + Args: + logger: Logger instance to use + message: Log message + level: Log level (default: INFO) + user_id: User identifier + auth_method: Authentication method used (jwt, api_token, etc.) + auth_success: Whether authentication succeeded + security_event: Type of security event (authentication, authorization, etc.) + security_severity: Severity level (low, medium, high, critical) + **extra_context: Additional context fields + """ + # Get request_id from correlation ID context + request_id = get_correlation_id() + + # Build structured log record + extra = { + "request_id": request_id, + "entity_type": "auth", + "auth_success": auth_success, + "security_event": security_event or "authentication", + "security_severity": security_severity, + } + + if user_id: + extra["user_id"] = user_id + if auth_method: + extra["auth_method"] = auth_method + + # Add any additional context + extra.update(extra_context) + + # Log with structured context + logger.log(level, message, extra=extra) def get_db() -> Generator[Session, Never, None]: @@ -119,7 +171,7 @@ async def get_team_from_token(payload: Dict[str, Any], db: Session) -> Optional[ async def get_current_user( - credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme), + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), db: Session = Depends(get_db), request: Optional[object] = None, ) -> EmailUser: @@ -169,10 +221,15 @@ async def get_current_user( if request and hasattr(request, "headers"): headers = dict(request.headers) - # Get request ID from request state (set by middleware) or generate new one - request_id = getattr(request.state, "request_id", None) if request else None + # Get request ID from correlation ID context (set by CorrelationIDMiddleware) + request_id = get_correlation_id() if not request_id: - request_id = uuid.uuid4().hex + # Fallback chain for safety + if request and hasattr(request, "state") and hasattr(request.state, "request_id"): + request_id = request.state.request_id + else: + request_id = uuid.uuid4().hex + logger.debug(f"Generated fallback request ID in get_current_user: {request_id}") # Get plugin contexts from request state if available global_context = getattr(request.state, "plugin_global_context", None) if request else None diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 9d3017876..49b367484 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -776,6 +776,51 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]: # Enable span events observability_events_enabled: bool = Field(default=True, description="Enable event logging within spans") + # Correlation ID Settings + correlation_id_enabled: bool = Field(default=True, description="Enable automatic correlation ID tracking for requests") + correlation_id_header: str = Field(default="X-Correlation-ID", description="HTTP header name for correlation ID") + correlation_id_preserve: bool = Field(default=True, description="Preserve correlation IDs from incoming requests") + correlation_id_response_header: bool = Field(default=True, description="Include correlation ID in response headers") + + # Structured Logging Configuration + structured_logging_enabled: bool = Field(default=True, description="Enable structured JSON logging with database persistence") + 
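+
+# Illustrative call (values hypothetical): record a failed JWT verification as
+# a medium-severity security event tied to the current correlation ID:
+#
+#   _log_auth_event(logger, "JWT verification failed", level=logging.WARNING,
+#                   user_id="user@example.com", auth_method="jwt",
+#                   auth_success=False, security_severity="medium")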
 
 
 def get_db() -> Generator[Session, Never, None]:
@@ -119,7 +171,7 @@ async def get_team_from_token(payload: Dict[str, Any], db: Session) -> Optional[
 
 async def get_current_user(
-    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme),
+    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
     db: Session = Depends(get_db),
     request: Optional[object] = None,
 ) -> EmailUser:
@@ -169,10 +221,15 @@ async def get_current_user(
     if request and hasattr(request, "headers"):
         headers = dict(request.headers)
 
-    # Get request ID from request state (set by middleware) or generate new one
-    request_id = getattr(request.state, "request_id", None) if request else None
+    # Get request ID from correlation ID context (set by CorrelationIDMiddleware)
+    request_id = get_correlation_id()
     if not request_id:
-        request_id = uuid.uuid4().hex
+        # Fallback chain for safety
+        if request and hasattr(request, "state") and hasattr(request.state, "request_id"):
+            request_id = request.state.request_id
+        else:
+            request_id = uuid.uuid4().hex
+            logger.debug(f"Generated fallback request ID in get_current_user: {request_id}")
 
     # Get plugin contexts from request state if available
     global_context = getattr(request.state, "plugin_global_context", None) if request else None
diff --git a/mcpgateway/config.py b/mcpgateway/config.py
index 9d3017876..49b367484 100644
--- a/mcpgateway/config.py
+++ b/mcpgateway/config.py
@@ -776,6 +776,51 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]:
     # Enable span events
     observability_events_enabled: bool = Field(default=True, description="Enable event logging within spans")
 
+    # Correlation ID Settings
+    correlation_id_enabled: bool = Field(default=True, description="Enable automatic correlation ID tracking for requests")
+    correlation_id_header: str = Field(default="X-Correlation-ID", description="HTTP header name for correlation ID")
+    correlation_id_preserve: bool = Field(default=True, description="Preserve correlation IDs from incoming requests")
+    correlation_id_response_header: bool = Field(default=True, description="Include correlation ID in response headers")
+
+    # Structured Logging Configuration
+    structured_logging_enabled: bool = Field(default=True, description="Enable structured JSON logging with database persistence")
+    structured_logging_database_enabled: bool = Field(default=True, description="Persist structured logs to database")
+    structured_logging_external_enabled: bool = Field(default=False, description="Send logs to external systems")
+
+    # Performance Tracking Configuration
+    performance_tracking_enabled: bool = Field(default=True, description="Enable performance tracking and metrics")
+    performance_threshold_database_query_ms: float = Field(default=100.0, description="Alert threshold for database queries (ms)")
+    performance_threshold_tool_invocation_ms: float = Field(default=2000.0, description="Alert threshold for tool invocations (ms)")
+    performance_threshold_resource_read_ms: float = Field(default=1000.0, description="Alert threshold for resource reads (ms)")
+    performance_threshold_http_request_ms: float = Field(default=500.0, description="Alert threshold for HTTP requests (ms)")
+    performance_degradation_multiplier: float = Field(default=1.5, description="Alert if performance degrades by this multiplier vs baseline")
+
+    # Security Logging Configuration
+    security_logging_enabled: bool = Field(default=True, description="Enable security event logging")
+    security_failed_auth_threshold: int = Field(default=5, description="Failed auth attempts before high severity alert")
+    security_threat_score_alert: float = Field(default=0.7, description="Threat score threshold for alerts (0.0-1.0)")
+    security_rate_limit_window_minutes: int = Field(default=5, description="Time window for rate limit checks (minutes)")
+
+    # Metrics Aggregation Configuration
+    metrics_aggregation_enabled: bool = Field(default=True, description="Enable automatic log aggregation into performance metrics")
+    metrics_aggregation_backfill_hours: int = Field(default=6, ge=0, le=168, description="Hours of structured logs to backfill into performance metrics on startup")
+    metrics_aggregation_window_minutes: int = Field(default=5, description="Time window for metrics aggregation (minutes)")
+    metrics_aggregation_auto_start: bool = Field(default=False, description="Automatically run the log aggregation loop on application startup")
+
+    # Log Search Configuration
+    log_search_max_results: int = Field(default=1000, description="Maximum results per log search query")
+    log_retention_days: int = Field(default=30, description="Number of days to retain logs in database")
+
+    # External Log Integration Configuration
+    elasticsearch_enabled: bool = Field(default=False, description="Send logs to Elasticsearch")
+    elasticsearch_url: Optional[str] = Field(default=None, description="Elasticsearch cluster URL")
+    elasticsearch_index_prefix: str = Field(default="mcpgateway-logs", description="Elasticsearch index prefix")
+    syslog_enabled: bool = Field(default=False, description="Send logs to syslog")
+    syslog_host: Optional[str] = Field(default=None, description="Syslog server host")
+    syslog_port: int = Field(default=514, description="Syslog server port")
+    webhook_logging_enabled: bool = Field(default=False, description="Send logs to webhook endpoints")
+    webhook_logging_urls: List[str] = Field(default_factory=list, description="Webhook URLs for log delivery")
+
     @field_validator("log_level", mode="before")
     @classmethod
     def validate_log_level(cls, v: str) -> str:
diff --git a/mcpgateway/db.py b/mcpgateway/db.py
index f5289951d..f548c8af4 100644
--- a/mcpgateway/db.py
+++ b/mcpgateway/db.py
@@ -3797,6 +3797,252 @@ def init_db():
         raise Exception(f"Failed to initialize database: {str(e)}")
 
 
+# ============================================================================
+# Structured Logging Models
+# ============================================================================
+
+
+class StructuredLogEntry(Base):
+    """Structured log entry for comprehensive logging and analysis.
+
+    Stores all log entries with correlation IDs, performance metrics,
+    and security context for advanced search and analytics.
+    """
+
+    __tablename__ = "structured_log_entries"
+
+    # Primary key
+    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex)
+
+    # Timestamps
+    timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now)
+
+    # Correlation and request tracking
+    correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True)
+    request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True)
+
+    # Log metadata
+    level: Mapped[str] = mapped_column(String(20), nullable=False, index=True)  # DEBUG, INFO, WARNING, ERROR, CRITICAL
+    component: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
+    message: Mapped[str] = mapped_column(Text, nullable=False)
+    logger: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+
+    # User and request context
+    user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True)
+    user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True)
+    client_ip: Mapped[Optional[str]] = mapped_column(String(45), nullable=True)  # IPv6 max length
+    user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
+    request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True)
+
+    # Performance data
+    duration_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+    operation_type: Mapped[Optional[str]] = mapped_column(String(100), index=True, nullable=True)
+
+    # Security context
+    is_security_event: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False)
+    security_severity: Mapped[Optional[str]] = mapped_column(String(20), index=True, nullable=True)  # LOW, MEDIUM, HIGH, CRITICAL
+    threat_indicators: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)
+
+    # Structured context data
+    context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)
+    error_details: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)
+    performance_metrics: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)
+
+    # System information
+    hostname: Mapped[str] = mapped_column(String(255), nullable=False)
+    process_id: Mapped[int] = mapped_column(Integer, nullable=False)
+    thread_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+    version: Mapped[str] = mapped_column(String(50), nullable=False)
+    environment: Mapped[str] = mapped_column(String(50), nullable=False, default="production")
+
+    # OpenTelemetry trace context
+    trace_id: Mapped[Optional[str]] = mapped_column(String(32), index=True, nullable=True)
+    span_id: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
+
+    # Indexes for performance
+    __table_args__ = (
+        Index("idx_log_correlation_time", "correlation_id", "timestamp"),
+        Index("idx_log_user_time", "user_id", "timestamp"),
+        Index("idx_log_level_time", "level", "timestamp"),
+        Index("idx_log_component_time", "component", "timestamp"),
+        Index("idx_log_security", "is_security_event", "security_severity", "timestamp"),
+        Index("idx_log_operation", "operation_type", "timestamp"),
+        Index("idx_log_trace", "trace_id", "timestamp"),
+    )
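+
+# Illustrative query (names hypothetical): the composite idx_log_correlation_time
+# index above makes per-request timelines cheap to fetch:
+#
+#   rows = (
+#       session.query(StructuredLogEntry)
+#       .filter(StructuredLogEntry.correlation_id == cid)
+#       .order_by(StructuredLogEntry.timestamp.desc())
+#       .limit(100)
+#       .all()
+#   )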
"timestamp"), + Index("idx_log_security", "is_security_event", "security_severity", "timestamp"), + Index("idx_log_operation", "operation_type", "timestamp"), + Index("idx_log_trace", "trace_id", "timestamp"), + ) + + +class PerformanceMetric(Base): + """Aggregated performance metrics from log analysis. + + Stores time-windowed aggregations of operation performance + for analytics and trend analysis. + """ + + __tablename__ = "performance_metrics" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamp + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Metric identification + operation_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + component: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + + # Aggregated metrics + request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_rate: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + + # Duration metrics (in milliseconds) + avg_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + min_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + max_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p50_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p95_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + p99_duration_ms: Mapped[float] = mapped_column(Float, nullable=False) + + # Time window + window_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) + window_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + window_duration_seconds: Mapped[int] = mapped_column(Integer, nullable=False) + + # Additional context + metric_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index("idx_perf_operation_time", "operation_type", "window_start"), + Index("idx_perf_component_time", "component", "window_start"), + Index("idx_perf_window", "window_start", "window_end"), + ) + + +class SecurityEvent(Base): + """Security event logging for threat detection and audit trails. + + Specialized table for security events with enhanced context + and threat analysis capabilities. + """ + + __tablename__ = "security_events" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + detected_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=utc_now) + + # Correlation tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + log_entry_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("structured_log_entries.id"), index=True, nullable=True) + + # Event classification + event_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # auth_failure, suspicious_activity, rate_limit, etc. + severity: Mapped[str] = mapped_column(String(20), nullable=False, index=True) # LOW, MEDIUM, HIGH, CRITICAL + category: Mapped[str] = mapped_column(String(50), nullable=False, index=True) # authentication, authorization, data_access, etc. 
+ + # User and request context + user_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + client_ip: Mapped[str] = mapped_column(String(45), nullable=False, index=True) + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Event details + description: Mapped[str] = mapped_column(Text, nullable=False) + action_taken: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) # blocked, allowed, flagged, etc. + + # Threat analysis + threat_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) # 0.0-1.0 + threat_indicators: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False, default=dict) + failed_attempts_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + # Resolution tracking + resolved: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + resolved_by: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + resolution_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Alert tracking + alert_sent: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + alert_sent_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + alert_recipients: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + + # Additional context + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index("idx_security_type_time", "event_type", "timestamp"), + Index("idx_security_severity_time", "severity", "timestamp"), + Index("idx_security_user_time", "user_id", "timestamp"), + Index("idx_security_ip_time", "client_ip", "timestamp"), + Index("idx_security_unresolved", "resolved", "severity", "timestamp"), + ) + + +class AuditTrail(Base): + """Comprehensive audit trail for data access and changes. + + Tracks all significant system changes and data access for + compliance and security auditing. + """ + + __tablename__ = "audit_trails" + + # Primary key + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) + + # Timestamps + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True, default=utc_now) + + # Correlation tracking + correlation_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + request_id: Mapped[Optional[str]] = mapped_column(String(64), index=True, nullable=True) + + # Action details + action: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # create, read, update, delete, execute, etc. + resource_type: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # tool, resource, prompt, user, etc. 
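+    # resource_id is nullable so that actions aimed at a resource *type*
+    # (rather than one specific record) can still be audited.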
+ resource_id: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + resource_name: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + + # User context + user_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + user_email: Mapped[Optional[str]] = mapped_column(String(255), index=True, nullable=True) + team_id: Mapped[Optional[str]] = mapped_column(String(36), index=True, nullable=True) + + # Request context + client_ip: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + request_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + request_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) + + # Change tracking + old_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + new_values: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + changes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + # Data classification + data_classification: Mapped[Optional[str]] = mapped_column(String(50), index=True, nullable=True) # public, internal, confidential, restricted + requires_review: Mapped[bool] = mapped_column(Boolean, default=False, index=True, nullable=False) + + # Result + success: Mapped[bool] = mapped_column(Boolean, nullable=False, index=True) + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Additional context + context: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + + __table_args__ = ( + Index("idx_audit_action_time", "action", "timestamp"), + Index("idx_audit_resource_time", "resource_type", "resource_id", "timestamp"), + Index("idx_audit_user_time", "user_id", "timestamp"), + Index("idx_audit_classification", "data_classification", "timestamp"), + Index("idx_audit_review", "requires_review", "timestamp"), + ) + + if __name__ == "__main__": # Wait for database to be ready before initializing wait_for_db_ready(max_tries=int(settings.db_max_retries), interval=int(settings.db_retry_interval_ms) / 1000, sync=True) # Converting ms to s diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 6eb6e0f2d..1b0e39eae 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -27,7 +27,7 @@ # Standard import asyncio -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from datetime import datetime import json import os as _os # local alias to avoid collisions @@ -70,6 +70,7 @@ from mcpgateway.db import refresh_slugs_on_startup, SessionLocal from mcpgateway.db import Tool as DbTool from mcpgateway.handlers.sampling import SamplingHandler +from mcpgateway.middleware.correlation_id import CorrelationIDMiddleware from mcpgateway.middleware.http_auth_middleware import HttpAuthMiddleware from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission @@ -112,6 +113,7 @@ from mcpgateway.services.import_service import ConflictStrategy, ImportConflictError from mcpgateway.services.import_service import ImportError as ImportServiceError from mcpgateway.services.import_service import ImportService, ImportValidationError +from mcpgateway.services.log_aggregator import get_log_aggregator from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.metrics import setup_metrics from mcpgateway.services.prompt_service 
import PromptError, PromptNameConflictError, PromptNotFoundError, PromptService @@ -406,6 +408,10 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: Exception: Any unhandled error that occurs during service initialisation or shutdown is re-raised to the caller. """ + aggregation_stop_event: Optional[asyncio.Event] = None + aggregation_loop_task: Optional[asyncio.Task] = None + aggregation_backfill_task: Optional[asyncio.Task] = None + # Initialize logging service FIRST to ensure all logging goes to dual output await logging_service.initialize() logger.info("Starting MCP Gateway services") @@ -461,6 +467,54 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: # Reconfigure uvicorn loggers after startup to capture access logs in dual output logging_service.configure_uvicorn_after_startup() + if settings.metrics_aggregation_enabled and settings.metrics_aggregation_auto_start: + aggregation_stop_event = asyncio.Event() + log_aggregator = get_log_aggregator() + + async def run_log_backfill() -> None: + """Backfill log aggregation metrics for configured hours.""" + hours = getattr(settings, "metrics_aggregation_backfill_hours", 0) + if hours <= 0: + return + try: + await asyncio.to_thread(log_aggregator.backfill, hours) + logger.info("Log aggregation backfill completed for last %s hour(s)", hours) + except Exception as backfill_error: # pragma: no cover - defensive logging + logger.warning("Log aggregation backfill failed: %s", backfill_error) + + async def run_log_aggregation_loop() -> None: + """Run continuous log aggregation at configured intervals. + + Raises: + asyncio.CancelledError: When aggregation is stopped + """ + interval_seconds = max(1, int(settings.metrics_aggregation_window_minutes)) * 60 + logger.info( + "Starting log aggregation loop (window=%s min)", + log_aggregator.aggregation_window_minutes, + ) + try: + while not aggregation_stop_event.is_set(): + try: + await asyncio.to_thread(log_aggregator.aggregate_all_components) + except Exception as agg_error: # pragma: no cover - defensive logging + logger.warning("Log aggregation loop iteration failed: %s", agg_error) + + try: + await asyncio.wait_for(aggregation_stop_event.wait(), timeout=interval_seconds) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + logger.debug("Log aggregation loop cancelled") + raise + finally: + logger.info("Log aggregation loop stopped") + + aggregation_backfill_task = asyncio.create_task(run_log_backfill()) + aggregation_loop_task = asyncio.create_task(run_log_aggregation_loop()) + elif settings.metrics_aggregation_enabled: + logger.info("Metrics aggregation auto-start disabled; performance metrics will be generated on-demand when requested.") + yield except Exception as e: logger.error(f"Error during startup: {str(e)}") @@ -474,6 +528,14 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: raise SystemExit(1) raise finally: + if aggregation_stop_event is not None: + aggregation_stop_event.set() + for task in (aggregation_backfill_task, aggregation_loop_task): + if task: + task.cancel() + with suppress(asyncio.CancelledError): + await task + # Shutdown plugin manager if plugin_manager: try: @@ -1169,6 +1231,15 @@ async def _call_streamable_http(self, scope, receive, send): # Add HTTP authentication hook middleware for plugins (before auth dependencies) if plugin_manager: app.add_middleware(HttpAuthMiddleware, plugin_manager=plugin_manager) + logger.info("🔌 HTTP authentication hooks enabled for plugins") + +# Add request logging middleware FIRST 
(always enabled for gateway boundary logging)
+# IMPORTANT: Must be registered BEFORE CorrelationIDMiddleware so it executes AFTER correlation ID is set
+# Gateway boundary logging (request_started/completed) runs regardless of log_requests setting
+# Detailed payload logging only runs if log_detailed_requests=True
+app.add_middleware(
+    RequestLoggingMiddleware, enable_gateway_logging=True, log_detailed_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024
+)  # Convert MB to bytes

 # Add custom DocsAuthMiddleware
 app.add_middleware(DocsAuthMiddleware)
@@ -1176,13 +1247,27 @@ async def _call_streamable_http(self, scope, receive, send):
 # Trust all proxies (or lock down with a list of host patterns)
 app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")

-# Add request logging middleware if enabled
-if settings.log_requests:
-    app.add_middleware(RequestLoggingMiddleware, log_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024)  # Convert MB to bytes
+# Add correlation ID middleware if enabled
+# Note: Registered AFTER RequestLoggingMiddleware so correlation ID is available when RequestLoggingMiddleware executes
+if settings.correlation_id_enabled:
+    app.add_middleware(CorrelationIDMiddleware)
+    logger.info(f"✅ Correlation ID tracking enabled (header: {settings.correlation_id_header})")
+
+# Add authentication context middleware if security logging is enabled
+# This middleware extracts user context and logs security events (authentication attempts)
+# Note: This is independent of observability - security logging is always important
+if settings.security_logging_enabled:
+    # First-Party
+    from mcpgateway.middleware.auth_middleware import AuthContextMiddleware
+
+    app.add_middleware(AuthContextMiddleware)
+    logger.info("🔐 Authentication context middleware enabled - logging security events")
+else:
+    logger.info("🔐 Security event logging disabled")

 # Add observability middleware if enabled
 # Note: Middleware runs in REVERSE order (last added runs first)
-# We add ObservabilityMiddleware first so it wraps AuthContextMiddleware
-# Execution order will be: AuthContext -> Observability -> Request Handler
+# If AuthContextMiddleware is already registered, ObservabilityMiddleware wraps it
+# Execution order will be: Observability -> AuthContext -> Request Handler
 if settings.observability_enabled:
     # First-Party
@@ -1190,13 +1275,6 @@ async def _call_streamable_http(self, scope, receive, send):
     app.add_middleware(ObservabilityMiddleware, enabled=True)
     logger.info("🔍 Observability middleware enabled - tracing all HTTP requests")
-
-    # Add authentication context middleware (runs BEFORE observability in execution)
-    # First-Party
-    from mcpgateway.middleware.auth_middleware import AuthContextMiddleware
-
-    app.add_middleware(AuthContextMiddleware)
-    logger.info("🔐 Authentication context middleware enabled - extracting user info for observability")
 else:
     logger.info("🔍 Observability middleware disabled")
@@ -2402,7 +2480,20 @@ async def invoke_a2a_agent(
         logger.debug(f"User {user} is invoking A2A agent '{agent_name}' with type '{interaction_type}'")
         if a2a_service is None:
             raise HTTPException(status_code=503, detail="A2A service not available")
-        return await a2a_service.invoke_agent(db, agent_name, parameters, interaction_type)
+        user_email = get_user_email(user)
+        user_id = None
+        if isinstance(user, dict):
+            user_id = str(user.get("id") or user.get("sub") or user_email)
+        else:
+            user_id = str(user)
+        return await a2a_service.invoke_agent(
+            db,
+
agent_name, + parameters, + interaction_type, + user_id=user_id, + user_email=user_email, + ) except A2AAgentNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except A2AAgentError as e: @@ -4980,6 +5071,19 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr app.include_router(tag_router) app.include_router(export_import_router) +# Include log search router if structured logging is enabled +if getattr(settings, "structured_logging_enabled", True): + try: + # First-Party + from mcpgateway.routers.log_search import router as log_search_router + + app.include_router(log_search_router) + logger.info("Log search router included - structured logging enabled") + except ImportError as e: + logger.warning(f"Failed to import log search router: {e}") +else: + logger.info("Log search router not included - structured logging disabled") + # Conditionally include observability router if enabled if settings.observability_enabled: # First-Party diff --git a/mcpgateway/middleware/auth_middleware.py b/mcpgateway/middleware/auth_middleware.py index a8868ccbe..1c2dc7a6c 100644 --- a/mcpgateway/middleware/auth_middleware.py +++ b/mcpgateway/middleware/auth_middleware.py @@ -28,8 +28,10 @@ # First-Party from mcpgateway.auth import get_current_user from mcpgateway.db import SessionLocal +from mcpgateway.services.security_logger import get_security_logger logger = logging.getLogger(__name__) +security_logger = get_security_logger() class AuthContextMiddleware(BaseHTTPMiddleware): @@ -85,14 +87,47 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) user = await get_current_user(credentials, db) + # Eagerly access user attributes before session closes to prevent DetachedInstanceError + # This forces SQLAlchemy to load the data while the session is still active + # Note: EmailUser uses 'email' as primary key, not 'id' + user_email = user.email + user_id = user_email # For EmailUser, email IS the ID + + # Expunge the user from the session so it can be used after session closes + # This makes the object detached but with all attributes already loaded + db.expunge(user) + # Store user in request state for downstream use request.state.user = user - logger.info(f"✓ Authenticated user for observability: {user.email}") + logger.info(f"✓ Authenticated user: {user_email if user_email else user_id}") + + # Log successful authentication + security_logger.log_authentication_attempt( + user_id=user_id, + user_email=user_email, + auth_method="bearer_token", + success=True, + client_ip=request.client.host if request.client else "unknown", + user_agent=request.headers.get("user-agent"), + db=db, + ) except Exception as e: # Silently fail - let route handlers enforce auth if needed logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}") + # Log failed authentication attempt + security_logger.log_authentication_attempt( + user_id="unknown", + user_email=None, + auth_method="bearer_token", + success=False, + client_ip=request.client.host if request.client else "unknown", + user_agent=request.headers.get("user-agent"), + failure_reason=str(e), + db=db if db else None, + ) + finally: # Always close database session if db: diff --git a/mcpgateway/middleware/correlation_id.py b/mcpgateway/middleware/correlation_id.py new file mode 100644 index 000000000..7d9a31193 --- /dev/null +++ b/mcpgateway/middleware/correlation_id.py @@ -0,0 +1,118 @@ +# -*- coding: 
utf-8 -*- +"""Location: ./mcpgateway/middleware/correlation_id.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: MCP Gateway Contributors + +Correlation ID (Request ID) Middleware. + +This middleware handles X-Correlation-ID HTTP headers and maps them to the internal +request_id used throughout the system for unified request tracing. + +Key concept: HTTP X-Correlation-ID header → Internal request_id field (single ID for entire request flow) + +The middleware automatically extracts or generates request IDs for every HTTP request, +stores them in context variables for async-safe propagation across services, and +injects them back into response headers for client-side correlation. + +This enables end-to-end tracing: HTTP → Middleware → Services → Plugins → Logs (all with same request_id) +""" + +# Standard +import logging +from typing import Callable + +# Third-Party +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + +# First-Party +from mcpgateway.config import settings +from mcpgateway.utils.correlation_id import ( + clear_correlation_id, + extract_correlation_id_from_headers, + generate_correlation_id, + set_correlation_id, +) + +logger = logging.getLogger(__name__) + + +class CorrelationIDMiddleware(BaseHTTPMiddleware): + """Middleware for automatic request ID (correlation ID) handling. + + This middleware: + 1. Extracts request ID from X-Correlation-ID header in incoming requests + 2. Generates a new UUID if no correlation ID is present + 3. Stores the ID in context variables for the request lifecycle (used as request_id throughout system) + 4. Injects the request ID into X-Correlation-ID response header + 5. Cleans up context after request completion + + The request ID extracted/generated here becomes the unified request_id used in: + - All log entries (request_id field) + - GlobalContext.request_id (when plugins execute) + - Service method calls for tracing + - Database queries for request tracking + + Configuration is controlled via settings: + - correlation_id_enabled: Enable/disable the middleware + - correlation_id_header: Header name to use (default: X-Correlation-ID) + - correlation_id_preserve: Whether to preserve incoming IDs (default: True) + - correlation_id_response_header: Whether to add ID to responses (default: True) + """ + + def __init__(self, app): + """Initialize the correlation ID (request ID) middleware. + + Args: + app: The FastAPI application instance + """ + super().__init__(app) + self.header_name = getattr(settings, "correlation_id_header", "X-Correlation-ID") + self.preserve_incoming = getattr(settings, "correlation_id_preserve", True) + self.add_to_response = getattr(settings, "correlation_id_response_header", True) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process the request and manage request ID (correlation ID) lifecycle. + + Extracts or generates a request ID, stores it in context variables for use throughout + the request lifecycle (becomes request_id in logs, services, plugins), and injects + it back into the X-Correlation-ID response header. 
+ + Args: + request: The incoming HTTP request + call_next: The next middleware or route handler + + Returns: + Response: The HTTP response with correlation ID header added + """ + # Extract correlation ID from incoming request headers + correlation_id = None + if self.preserve_incoming: + correlation_id = extract_correlation_id_from_headers(dict(request.headers), self.header_name) + + # Generate new correlation ID if none was provided + if not correlation_id: + correlation_id = generate_correlation_id() + logger.debug(f"Generated new correlation ID: {correlation_id}") + else: + logger.debug(f"Using client-provided correlation ID: {correlation_id}") + + # Store correlation ID in context variable for this request + # This makes it available to all downstream code (auth, services, plugins, logs) + set_correlation_id(correlation_id) + + try: + # Process the request + response = await call_next(request) + + # Add correlation ID to response headers if enabled + if self.add_to_response: + response.headers[self.header_name] = correlation_id + + return response + + finally: + # Clean up context after request completes + # Note: ContextVar automatically cleans up, but explicit cleanup is good practice + clear_correlation_id() diff --git a/mcpgateway/middleware/http_auth_middleware.py b/mcpgateway/middleware/http_auth_middleware.py index 84058641f..8b73ffacd 100644 --- a/mcpgateway/middleware/http_auth_middleware.py +++ b/mcpgateway/middleware/http_auth_middleware.py @@ -8,7 +8,6 @@ # Standard import logging -import uuid # Third-Party from fastapi import Request @@ -17,6 +16,7 @@ # First-Party from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, HttpHookType, HttpPostRequestPayload, HttpPreRequestPayload, PluginManager +from mcpgateway.utils.correlation_id import generate_correlation_id, get_correlation_id logger = logging.getLogger(__name__) @@ -60,9 +60,14 @@ async def dispatch(self, request: Request, call_next): if not self.plugin_manager: return await call_next(request) - # Generate request ID for tracing and store in request state - # This ensures all hooks and downstream code see the same request ID - request_id = uuid.uuid4().hex + # Use correlation ID from CorrelationIDMiddleware if available + # This ensures all hooks and downstream code see the same unified request ID + request_id = get_correlation_id() + if not request_id: + # Fallback if correlation ID middleware is disabled + request_id = generate_correlation_id() + logger.debug(f"Correlation ID not found, generated fallback: {request_id}") + request.state.request_id = request_id # Create global context for hooks diff --git a/mcpgateway/middleware/request_logging_middleware.py b/mcpgateway/middleware/request_logging_middleware.py index db286b20f..f241197ab 100644 --- a/mcpgateway/middleware/request_logging_middleware.py +++ b/mcpgateway/middleware/request_logging_middleware.py @@ -15,19 +15,29 @@ # Standard import json import logging +import time from typing import Callable # Third-Party -from fastapi import Request, Response +from fastapi.security import HTTPAuthorizationCredentials from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response # First-Party +from mcpgateway.auth import get_current_user +from mcpgateway.db import SessionLocal from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.structured_logger import get_structured_logger +from mcpgateway.utils.correlation_id import get_correlation_id 
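Several files in this diff import helpers from `mcpgateway.utils.correlation_id`, but that module itself is not part of the patch. As review context, here is a minimal sketch of what a `ContextVar`-based implementation exposing exactly the imported names could look like; the function bodies below are assumptions for illustration, not the shipped code:

```python
# -*- coding: utf-8 -*-
"""Hypothetical sketch of mcpgateway/utils/correlation_id.py (not included in this diff)."""

# Standard
from contextvars import ContextVar
from typing import Dict, Optional
import uuid

# A ContextVar gives async-safe, per-request storage without thread-locals
_correlation_id: ContextVar[Optional[str]] = ContextVar("correlation_id", default=None)


def generate_correlation_id() -> str:
    """Generate a new random correlation ID."""
    return uuid.uuid4().hex


def get_correlation_id() -> Optional[str]:
    """Return the correlation ID bound to the current context, if any."""
    return _correlation_id.get()


def set_correlation_id(value: str) -> None:
    """Bind a correlation ID to the current context."""
    _correlation_id.set(value)


def clear_correlation_id() -> None:
    """Reset the correlation ID for the current context."""
    _correlation_id.set(None)


def get_or_generate_correlation_id() -> str:
    """Return the current correlation ID, generating one if absent."""
    return get_correlation_id() or generate_correlation_id()


def extract_correlation_id_from_headers(headers: Dict[str, str], header_name: str) -> Optional[str]:
    """Case-insensitive lookup of the correlation ID header."""
    wanted = header_name.lower()
    for key, value in headers.items():
        if key.lower() == wanted and value:
            return value.strip()
    return None
```

A `ContextVar` (rather than a module-level global or a thread-local) is what lets the ID propagate safely across `await` boundaries while staying isolated between concurrent requests.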
# Initialize logging service first logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger for gateway boundary logging +structured_logger = get_structured_logger("http_gateway") + SENSITIVE_KEYS = {"password", "secret", "token", "apikey", "access_token", "refresh_token", "client_secret", "authorization", "jwt_token"} @@ -106,20 +116,67 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): masking sensitive information like passwords, tokens, and authorization headers. """ - def __init__(self, app, log_requests: bool = True, log_level: str = "DEBUG", max_body_size: int = 4096): + def __init__(self, app, enable_gateway_logging: bool = True, log_detailed_requests: bool = False, log_level: str = "DEBUG", max_body_size: int = 4096): """Initialize the request logging middleware. Args: app: The FastAPI application instance - log_requests: Whether to enable request logging + enable_gateway_logging: Whether to enable gateway boundary logging (request_started/completed) + log_detailed_requests: Whether to enable detailed request/response payload logging log_level: The log level for requests (not used, logs at INFO) max_body_size: Maximum request body size to log in bytes """ super().__init__(app) - self.log_requests = log_requests + self.enable_gateway_logging = enable_gateway_logging + self.log_detailed_requests = log_detailed_requests self.log_level = log_level.upper() self.max_body_size = max_body_size # Expected to be in bytes + async def _resolve_user_identity(self, request: Request): + """Best-effort extraction of user identity for request logs. + + Args: + request: The incoming HTTP request + + Returns: + Tuple[Optional[str], Optional[str]]: User ID and email + """ + # Prefer context injected by upstream middleware + if hasattr(request.state, "user") and request.state.user is not None: + raw_user_id = getattr(request.state.user, "id", None) + user_email = getattr(request.state.user, "email", None) + return (str(raw_user_id) if raw_user_id is not None else None, user_email) + + # Fallback: try to authenticate using cookies/headers (matches AuthContextMiddleware) + token = None + if request.cookies: + token = request.cookies.get("jwt_token") or request.cookies.get("access_token") or request.cookies.get("token") + + if not token: + auth_header = request.headers.get("authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.replace("Bearer ", "") + + if not token: + return (None, None) + + db = None + try: + db = SessionLocal() + credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) + user = await get_current_user(credentials, db) + raw_user_id = getattr(user, "id", None) + user_email = getattr(user, "email", None) + return (str(raw_user_id) if raw_user_id is not None else None, user_email) + except Exception: + return (None, None) + finally: + if db: + try: + db.close() + except Exception: # nosec B110 - Silently handle db.close() failures during cleanup + pass + async def dispatch(self, request: Request, call_next: Callable): """Process incoming request and log details with sensitive data masked. 
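The `dispatch` implementation in the next hunk only sees a non-empty `get_correlation_id()` because Starlette runs middleware in reverse registration order: the last `add_middleware` call becomes the outermost layer. A self-contained sketch demonstrating that rule (the `Tag` middleware and its names are illustrative only):

```python
# Illustrative only: shows Starlette's "last added runs first" middleware ordering,
# which is why main.py registers RequestLoggingMiddleware before CorrelationIDMiddleware.
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.testclient import TestClient


class Tag(BaseHTTPMiddleware):
    """Prints enter/exit markers so the onion ordering is visible."""

    def __init__(self, app, name: str):
        super().__init__(app)
        self.name = name

    async def dispatch(self, request, call_next):
        print(f"enter {self.name}")
        response = await call_next(request)
        print(f"exit {self.name}")
        return response


app = Starlette()
app.add_middleware(Tag, name="RequestLogging")  # added first -> innermost
app.add_middleware(Tag, name="CorrelationID")   # runs before RequestLogging
app.add_middleware(Tag, name="Observability")   # added last -> outermost, runs first

TestClient(app).get("/")
# Prints: enter Observability, enter CorrelationID, enter RequestLogging,
# then exit RequestLogging, exit CorrelationID, exit Observability.
```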
@@ -129,10 +186,74 @@ async def dispatch(self, request: Request, call_next: Callable):

         Returns:
             Response: The HTTP response from downstream handlers
+
+        Raises:
+            Exception: Any exception from downstream handlers is re-raised
         """
-        # Skip logging if disabled
-        if not self.log_requests:
-            return await call_next(request)
+        # Track start time for total duration
+        start_time = time.time()
+
+        # Get correlation ID and request metadata for boundary logging
+        correlation_id = get_correlation_id()
+        path = request.url.path
+        method = request.method
+        user_agent = request.headers.get("user-agent", "unknown")
+        client_ip = request.client.host if request.client else "unknown"
+        user_id, user_email = await self._resolve_user_identity(request)
+
+        # Skip boundary logging for health checks and static assets
+        skip_paths = ["/health", "/healthz", "/static", "/favicon.ico"]
+        should_log_boundary = self.enable_gateway_logging and not any(path.startswith(skip_path) for skip_path in skip_paths)
+
+        # Log gateway request started
+        if should_log_boundary:
+            try:
+                structured_logger.log(
+                    level="INFO",
+                    message=f"Request started: {method} {path}",
+                    component="http_gateway",
+                    correlation_id=correlation_id,
+                    user_email=user_email,
+                    user_id=user_id,
+                    operation_type="http_request",
+                    request_method=method,
+                    request_path=path,
+                    user_agent=user_agent,
+                    client_ip=client_ip,
+                    metadata={"event": "request_started", "query_params": str(request.query_params) if request.query_params else None},
+                )
+            except Exception as e:
+                logger.warning(f"Failed to log request start: {e}")
+
+        # Skip detailed logging if disabled
+        if not self.log_detailed_requests:
+            response = await call_next(request)
+
+            # Still log request completed even if detailed logging is disabled
+            if should_log_boundary:
+                duration_ms = (time.time() - start_time) * 1000
+                try:
+                    log_level = "ERROR" if response.status_code >= 500 else "WARNING" if response.status_code >= 400 else "INFO"
+                    structured_logger.log(
+                        level=log_level,
+                        message=f"Request completed: {method} {path} - {response.status_code}",
+                        component="http_gateway",
+                        correlation_id=correlation_id,
+                        user_email=user_email,
+                        user_id=user_id,
+                        operation_type="http_request",
+                        request_method=method,
+                        request_path=path,
+                        response_status_code=response.status_code,
+                        user_agent=user_agent,
+                        client_ip=client_ip,
+                        duration_ms=duration_ms,
+                        metadata={"event": "request_completed", "response_time_category": self._categorize_response_time(duration_ms)},
+                    )
+                except Exception as e:
+                    logger.warning(f"Failed to log request completion: {e}")
+
+            return response

         # Always log at INFO level for request payloads to ensure visibility
         log_level = logging.INFO
@@ -171,13 +292,28 @@
         # Mask sensitive headers
         masked_headers = mask_sensitive_headers(dict(request.headers))

-        logger.log(
-            log_level,
-            f"📩 Incoming request: {request.method} {request.url.path}\n"
-            f"Query params: {dict(request.query_params)}\n"
-            f"Headers: {masked_headers}\n"
-            f"Body: {payload_str}{'... [truncated]' if truncated else ''}",
-        )
+        # Get correlation ID for request tracking
+        request_id = get_correlation_id()
+
+        # Try to log with extra parameter, fall back to without if not supported
+        try:
+            logger.log(
+                log_level,
+                f"📩 Incoming request: {request.method} {request.url.path}\n"
+                f"Query params: {dict(request.query_params)}\n"
+                f"Headers: {masked_headers}\n"
+                f"Body: {payload_str}{'... [truncated]' if truncated else ''}",
+                extra={"request_id": request_id},
+            )
+        except TypeError:
+            # Fall back for test loggers that don't accept extra parameter
+            logger.log(
+                log_level,
+                f"📩 Incoming request: {request.method} {request.url.path}\n"
+                f"Query params: {dict(request.query_params)}\n"
+                f"Headers: {masked_headers}\n"
+                f"Body: {payload_str}{'... [truncated]' if truncated else ''}",
+            )

         except Exception as e:
             logger.warning(f"Failed to log request body: {e}")
@@ -195,5 +331,80 @@ async def receive():
             new_scope = request.scope.copy()
             new_request = Request(new_scope, receive=receive)

-        response: Response = await call_next(new_request)
+        # Process request
+        try:
+            response: Response = await call_next(new_request)
+            status_code = response.status_code
+        except Exception as e:
+            duration_ms = (time.time() - start_time) * 1000
+
+            # Log request failed
+            if should_log_boundary:
+                try:
+                    structured_logger.log(
+                        level="ERROR",
+                        message=f"Request failed: {method} {path}",
+                        component="http_gateway",
+                        correlation_id=correlation_id,
+                        user_email=user_email,
+                        user_id=user_id,
+                        operation_type="http_request",
+                        request_method=method,
+                        request_path=path,
+                        user_agent=user_agent,
+                        client_ip=client_ip,
+                        duration_ms=duration_ms,
+                        error=e,
+                        metadata={"event": "request_failed"},
+                    )
+                except Exception as log_error:
+                    logger.warning(f"Failed to log request failure: {log_error}")
+
+            raise
+
+        # Calculate total duration
+        duration_ms = (time.time() - start_time) * 1000
+
+        # Log gateway request completed
+        if should_log_boundary:
+            try:
+                log_level = "ERROR" if status_code >= 500 else "WARNING" if status_code >= 400 else "INFO"
+
+                structured_logger.log(
+                    level=log_level,
+                    message=f"Request completed: {method} {path} - {status_code}",
+                    component="http_gateway",
+                    correlation_id=correlation_id,
+                    user_email=user_email,
+                    user_id=user_id,
+                    operation_type="http_request",
+                    request_method=method,
+                    request_path=path,
+                    response_status_code=status_code,
+                    user_agent=user_agent,
+                    client_ip=client_ip,
+                    duration_ms=duration_ms,
+                    metadata={"event": "request_completed", "response_time_category": self._categorize_response_time(duration_ms)},
+                )
+            except Exception as e:
+                logger.warning(f"Failed to log request completion: {e}")
+        return response
+
+    @staticmethod
+    def _categorize_response_time(duration_ms: float) -> str:
+        """Categorize response time for analytics.
+
+        Args:
+            duration_ms: Response time in milliseconds
+
+        Returns:
+            Category string
+        """
+        if duration_ms < 100:
+            return "fast"
+        if duration_ms < 500:
+            return "normal"
+        if duration_ms < 2000:
+            return "slow"
+        return "very_slow"
diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py
index 6714dd392..25b28ef5c 100644
--- a/mcpgateway/observability.py
+++ b/mcpgateway/observability.py
@@ -15,7 +15,7 @@
 import os
 from typing import Any, Callable, cast, Dict, Optional

-# Try to import OpenTelemetry core components - make them truly optional
+# Third-Party - Try to import OpenTelemetry core components - make them truly optional
 OTEL_AVAILABLE = False
 try:
     # Third-Party
@@ -93,6 +93,9 @@ class _ConsoleSpanExporterStub:  # pragma: no cover - test patch replaces this
         # Shimming is a non-critical, best-effort step for tests; log and continue.
logging.getLogger(__name__).debug("Skipping OpenTelemetry shim setup: %s", exc) +# First-Party +from mcpgateway.utils.correlation_id import get_correlation_id # noqa: E402 # pylint: disable=wrong-import-position + # Try to import optional exporters try: OTLP_SPAN_EXPORTER = getattr(_im("opentelemetry.exporter.otlp.proto.grpc.trace_exporter"), "OTLPSpanExporter") @@ -440,6 +443,21 @@ def create_span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Any: # Return a no-op context manager if tracing is not configured or available return nullcontext() + # Auto-inject correlation ID into all spans for request tracing + try: + correlation_id = get_correlation_id() + if correlation_id: + if attributes is None: + attributes = {} + # Add correlation ID if not already present + if "correlation_id" not in attributes: + attributes["correlation_id"] = correlation_id + if "request_id" not in attributes: + attributes["request_id"] = correlation_id # Alias for compatibility + except Exception as exc: + # Correlation ID not available or error getting it, continue without it + logger.debug("Failed to add correlation_id to span: %s", exc) + # Start span and return the context manager span_context = _TRACER.start_as_current_span(name) diff --git a/mcpgateway/plugins/framework/external/mcp/tls_utils.py b/mcpgateway/plugins/framework/external/mcp/tls_utils.py index 91b04cfb0..11d4d0acf 100644 --- a/mcpgateway/plugins/framework/external/mcp/tls_utils.py +++ b/mcpgateway/plugins/framework/external/mcp/tls_utils.py @@ -86,7 +86,7 @@ def create_ssl_context(tls_config: MCPClientTLSConfig, plugin_name: str) -> ssl. # Disable certificate verification (not recommended for production) logger.warning(f"Certificate verification disabled for plugin '{plugin_name}'. This is not recommended for production use.") ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE # noqa: DUO122 + ssl_context.verify_mode = ssl.CERT_NONE # nosec B502 # noqa: DUO122 else: # Enable strict certificate verification (production mode) # Load CA certificate bundle for server certificate validation diff --git a/mcpgateway/routers/log_search.py b/mcpgateway/routers/log_search.py new file mode 100644 index 000000000..5023ea614 --- /dev/null +++ b/mcpgateway/routers/log_search.py @@ -0,0 +1,754 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/routers/log_search.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Log Search API Router. + +This module provides REST API endpoints for searching and analyzing structured logs, +security events, audit trails, and performance metrics. 
+""" + +# Standard +from datetime import datetime, timedelta, timezone +import logging +from typing import Any, Dict, List, Optional, Tuple + +# Third-Party +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field +from sqlalchemy import and_, delete, desc, or_, select +from sqlalchemy.orm import Session +from sqlalchemy.sql import func as sa_func + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import ( + AuditTrail, + get_db, + PerformanceMetric, + SecurityEvent, + StructuredLogEntry, +) +from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission +from mcpgateway.services.log_aggregator import get_log_aggregator + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/logs", tags=["logs"]) + +MIN_PERFORMANCE_RANGE_HOURS = 5.0 / 60.0 +_DEFAULT_AGGREGATION_KEY = "5m" +_AGGREGATION_LEVELS: Dict[str, Dict[str, Any]] = { + "5m": {"minutes": 5, "label": "5-minute windows"}, + "24h": {"minutes": 24 * 60, "label": "24-hour windows"}, +} + + +def _align_to_window(dt: datetime, window_minutes: int) -> datetime: + """Align a datetime down to the nearest aggregation window boundary. + + Args: + dt: Datetime to align + window_minutes: Aggregation window size in minutes + + Returns: + datetime: Aligned datetime at window boundary + """ + timestamp = dt.astimezone(timezone.utc) + total_minutes = int(timestamp.timestamp() // 60) + aligned_minutes = (total_minutes // window_minutes) * window_minutes + return datetime.fromtimestamp(aligned_minutes * 60, tz=timezone.utc) + + +def _deduplicate_metrics(metrics: List[PerformanceMetric]) -> List[PerformanceMetric]: + """Ensure a single metric per component/operation/window. + + Args: + metrics: List of performance metrics to deduplicate + + Returns: + List[PerformanceMetric]: Deduplicated metrics sorted by window_start + """ + if not metrics: + return [] + + deduped: Dict[Tuple[str, str, datetime], PerformanceMetric] = {} + for metric in metrics: + component = metric.component or "" + operation = metric.operation_type or "" + key = (component, operation, metric.window_start) + existing = deduped.get(key) + if existing is None or metric.timestamp > existing.timestamp: + deduped[key] = metric + + return sorted(deduped.values(), key=lambda m: m.window_start, reverse=True) + + +def _aggregate_custom_windows( + aggregator, + window_minutes: int, + db: Session, +) -> None: + """Aggregate metrics using custom window duration. 
+ + Args: + aggregator: Log aggregator instance + window_minutes: Window size in minutes + db: Database session + """ + window_delta = timedelta(minutes=window_minutes) + window_duration_seconds = window_minutes * 60 + + sample_row = db.execute( + select(PerformanceMetric.window_start, PerformanceMetric.window_end) + .where(PerformanceMetric.window_duration_seconds == window_duration_seconds) + .order_by(desc(PerformanceMetric.window_start)) + .limit(1) + ).first() + + needs_rebuild = False + if sample_row: + sample_start, sample_end = sample_row + if sample_start is not None and sample_end is not None: + start_utc = sample_start if sample_start.tzinfo else sample_start.replace(tzinfo=timezone.utc) + end_utc = sample_end if sample_end.tzinfo else sample_end.replace(tzinfo=timezone.utc) + duration = int((end_utc - start_utc).total_seconds()) + if duration != window_duration_seconds: + needs_rebuild = True + aligned_start = _align_to_window(start_utc, window_minutes) + if aligned_start != start_utc: + needs_rebuild = True + + if needs_rebuild: + db.execute(delete(PerformanceMetric).where(PerformanceMetric.window_duration_seconds == window_duration_seconds)) + db.commit() + sample_row = None + + max_existing = None + if not needs_rebuild: + max_existing = db.execute(select(sa_func.max(PerformanceMetric.window_start)).where(PerformanceMetric.window_duration_seconds == window_duration_seconds)).scalar() + + if max_existing: + current_start = max_existing if max_existing.tzinfo else max_existing.replace(tzinfo=timezone.utc) + current_start = current_start + window_delta + else: + earliest_log = db.execute(select(sa_func.min(StructuredLogEntry.timestamp))).scalar() + if not earliest_log: + return + if earliest_log.tzinfo is None: + earliest_log = earliest_log.replace(tzinfo=timezone.utc) + current_start = _align_to_window(earliest_log, window_minutes) + + reference_end = datetime.now(timezone.utc) + + while current_start < reference_end: + current_end = current_start + window_delta + aggregator.aggregate_all_components( + window_start=current_start, + window_end=current_end, + db=db, + ) + current_start = current_end + + +# Request/Response Models +class LogSearchRequest(BaseModel): + """Log search request parameters.""" + + search_text: Optional[str] = Field(None, description="Text search query") + level: Optional[List[str]] = Field(None, description="Log levels to filter") + component: Optional[List[str]] = Field(None, description="Components to filter") + category: Optional[List[str]] = Field(None, description="Categories to filter") + correlation_id: Optional[str] = Field(None, description="Correlation ID to filter") + user_id: Optional[str] = Field(None, description="User ID to filter") + start_time: Optional[datetime] = Field(None, description="Start timestamp") + end_time: Optional[datetime] = Field(None, description="End timestamp") + min_duration_ms: Optional[float] = Field(None, description="Minimum duration") + max_duration_ms: Optional[float] = Field(None, description="Maximum duration") + has_error: Optional[bool] = Field(None, description="Filter for errors") + limit: int = Field(100, ge=1, le=1000, description="Maximum results") + offset: int = Field(0, ge=0, description="Result offset") + sort_by: str = Field("timestamp", description="Field to sort by") + sort_order: str = Field("desc", description="Sort order (asc/desc)") + + +class LogEntry(BaseModel): + """Log entry response model.""" + + id: str + timestamp: datetime + level: str + component: str + message: str + 
correlation_id: Optional[str] = None + user_id: Optional[str] = None + user_email: Optional[str] = None + duration_ms: Optional[float] = None + operation_type: Optional[str] = None + request_path: Optional[str] = None + request_method: Optional[str] = None + is_security_event: bool = False + error_details: Optional[Dict[str, Any]] = None + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +class LogSearchResponse(BaseModel): + """Log search response.""" + + total: int + results: List[LogEntry] + + +class CorrelationTraceRequest(BaseModel): + """Correlation trace request.""" + + correlation_id: str + + +class CorrelationTraceResponse(BaseModel): + """Correlation trace response with all related logs.""" + + correlation_id: str + total_duration_ms: Optional[float] + log_count: int + error_count: int + logs: List[LogEntry] + security_events: List[Dict[str, Any]] + audit_trails: List[Dict[str, Any]] + performance_metrics: Optional[Dict[str, Any]] + + +class SecurityEventResponse(BaseModel): + """Security event response model.""" + + id: str + timestamp: datetime + event_type: str + severity: str + category: str + user_id: Optional[str] + client_ip: str + description: str + threat_score: float + action_taken: Optional[str] + resolved: bool + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +class AuditTrailResponse(BaseModel): + """Audit trail response model.""" + + id: str + timestamp: datetime + correlation_id: Optional[str] = None + action: str + resource_type: str + resource_id: Optional[str] + resource_name: Optional[str] = None + user_id: str + user_email: Optional[str] = None + success: bool + requires_review: bool + data_classification: Optional[str] + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +class PerformanceMetricResponse(BaseModel): + """Performance metric response model.""" + + id: str + timestamp: datetime + component: str + operation_type: str + window_start: datetime + window_end: datetime + request_count: int + error_count: int + error_rate: float + avg_duration_ms: float + min_duration_ms: float + max_duration_ms: float + p50_duration_ms: float + p95_duration_ms: float + p99_duration_ms: float + + class Config: + """Pydantic configuration.""" + + from_attributes = True + + +# API Endpoints +@router.post("/search", response_model=LogSearchResponse) +@require_permission("logs:read") +async def search_logs(request: LogSearchRequest, user=Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> LogSearchResponse: + """Search structured logs with filters and pagination. 
+ + Args: + request: Search parameters + user: Current authenticated user + db: Database session + + Returns: + Search results with pagination + + Raises: + HTTPException: On database or validation errors + """ + try: + # Build base query + stmt = select(StructuredLogEntry) + + # Apply filters + conditions = [] + + if request.search_text: + conditions.append(or_(StructuredLogEntry.message.ilike(f"%{request.search_text}%"), StructuredLogEntry.component.ilike(f"%{request.search_text}%"))) + + if request.level: + conditions.append(StructuredLogEntry.level.in_(request.level)) + + if request.component: + conditions.append(StructuredLogEntry.component.in_(request.component)) + + # Note: category field doesn't exist in StructuredLogEntry + # if request.category: + # conditions.append(StructuredLogEntry.category.in_(request.category)) + + if request.correlation_id: + conditions.append(StructuredLogEntry.correlation_id == request.correlation_id) + + if request.user_id: + conditions.append(StructuredLogEntry.user_id == request.user_id) + + if request.start_time: + conditions.append(StructuredLogEntry.timestamp >= request.start_time) + + if request.end_time: + conditions.append(StructuredLogEntry.timestamp <= request.end_time) + + if request.min_duration_ms is not None: + conditions.append(StructuredLogEntry.duration_ms >= request.min_duration_ms) + + if request.max_duration_ms is not None: + conditions.append(StructuredLogEntry.duration_ms <= request.max_duration_ms) + + if request.has_error is not None: + if request.has_error: + conditions.append(StructuredLogEntry.error_details.isnot(None)) + else: + conditions.append(StructuredLogEntry.error_details.is_(None)) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + # Get total count + count_stmt = select(sa_func.count()).select_from(stmt.subquery()) + total = db.execute(count_stmt).scalar() or 0 + + # Apply sorting + sort_column = getattr(StructuredLogEntry, request.sort_by, StructuredLogEntry.timestamp) + if request.sort_order == "desc": + stmt = stmt.order_by(desc(sort_column)) + else: + stmt = stmt.order_by(sort_column) + + # Apply pagination + stmt = stmt.limit(request.limit).offset(request.offset) + + # Execute query + results = db.execute(stmt).scalars().all() + + # Convert to response models + log_entries = [ + LogEntry( + id=str(log.id), + timestamp=log.timestamp, + level=log.level, + component=log.component, + message=log.message, + correlation_id=log.correlation_id, + user_id=log.user_id, + user_email=log.user_email, + duration_ms=log.duration_ms, + operation_type=log.operation_type, + request_path=log.request_path, + request_method=log.request_method, + is_security_event=log.is_security_event, + error_details=log.error_details, + ) + for log in results + ] + + return LogSearchResponse(total=total, results=log_entries) + + except Exception as e: + logger.error(f"Log search failed: {e}") + raise HTTPException(status_code=500, detail="Log search failed") + + +@router.get("/trace/{correlation_id}", response_model=CorrelationTraceResponse) +@require_permission("logs:read") +async def trace_correlation_id(correlation_id: str, user=Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> CorrelationTraceResponse: + """Get all logs and events for a correlation ID. 
+ + Args: + correlation_id: Correlation ID to trace + user: Current authenticated user + db: Database session + + Returns: + Complete trace of all related logs and events + + Raises: + HTTPException: On database or validation errors + """ + try: + # Get structured logs + log_stmt = select(StructuredLogEntry).where(StructuredLogEntry.correlation_id == correlation_id).order_by(StructuredLogEntry.timestamp) + + logs = db.execute(log_stmt).scalars().all() + + # Get security events + security_stmt = select(SecurityEvent).where(SecurityEvent.correlation_id == correlation_id).order_by(SecurityEvent.timestamp) + + security_events = db.execute(security_stmt).scalars().all() + + # Get audit trails + audit_stmt = select(AuditTrail).where(AuditTrail.correlation_id == correlation_id).order_by(AuditTrail.timestamp) + + audit_trails = db.execute(audit_stmt).scalars().all() + + # Calculate metrics + durations = [log.duration_ms for log in logs if log.duration_ms is not None] + total_duration = sum(durations) if durations else None + error_count = sum(1 for log in logs if log.error_details) + + # Get performance metrics (if any aggregations exist) + perf_metrics = None + if logs: + component = logs[0].component + operation = logs[0].operation_type + if component and operation: + perf_stmt = ( + select(PerformanceMetric) + .where(and_(PerformanceMetric.component == component, PerformanceMetric.operation_type == operation)) + .order_by(desc(PerformanceMetric.window_start)) + .limit(1) + ) + + perf = db.execute(perf_stmt).scalar_one_or_none() + if perf: + perf_metrics = { + "avg_duration_ms": perf.avg_duration_ms, + "p95_duration_ms": perf.p95_duration_ms, + "p99_duration_ms": perf.p99_duration_ms, + "error_rate": perf.error_rate, + } + + return CorrelationTraceResponse( + correlation_id=correlation_id, + total_duration_ms=total_duration, + log_count=len(logs), + error_count=error_count, + logs=[ + LogEntry( + id=str(log.id), + timestamp=log.timestamp, + level=log.level, + component=log.component, + message=log.message, + correlation_id=log.correlation_id, + user_id=log.user_id, + user_email=log.user_email, + duration_ms=log.duration_ms, + operation_type=log.operation_type, + request_path=log.request_path, + request_method=log.request_method, + is_security_event=log.is_security_event, + error_details=log.error_details, + ) + for log in logs + ], + security_events=[ + { + "id": str(event.id), + "timestamp": event.timestamp.isoformat(), + "event_type": event.event_type, + "severity": event.severity, + "description": event.description, + "threat_score": event.threat_score, + } + for event in security_events + ], + audit_trails=[ + { + "id": str(audit.id), + "timestamp": audit.timestamp.isoformat(), + "action": audit.action, + "resource_type": audit.resource_type, + "resource_id": audit.resource_id, + "success": audit.success, + } + for audit in audit_trails + ], + performance_metrics=perf_metrics, + ) + + except Exception as e: + logger.error(f"Correlation trace failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Correlation trace failed: {str(e)}") + + +@router.get("/security-events", response_model=List[SecurityEventResponse]) +@require_permission("security:read") +async def get_security_events( + severity: Optional[List[str]] = Query(None), + event_type: Optional[List[str]] = Query(None), + resolved: Optional[bool] = Query(None), + start_time: Optional[datetime] = Query(None), + end_time: Optional[datetime] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = 
Query(0, ge=0), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> List[SecurityEventResponse]: + """Get security events with filters. + + Args: + severity: Filter by severity levels + event_type: Filter by event types + resolved: Filter by resolution status + start_time: Start timestamp + end_time: End timestamp + limit: Maximum results + offset: Result offset + user: Current authenticated user + db: Database session + + Returns: + List of security events + + Raises: + HTTPException: On database or validation errors + """ + try: + stmt = select(SecurityEvent) + + conditions = [] + if severity: + conditions.append(SecurityEvent.severity.in_(severity)) + if event_type: + conditions.append(SecurityEvent.event_type.in_(event_type)) + if resolved is not None: + conditions.append(SecurityEvent.resolved == resolved) + if start_time: + conditions.append(SecurityEvent.timestamp >= start_time) + if end_time: + conditions.append(SecurityEvent.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(SecurityEvent.timestamp)).limit(limit).offset(offset) + + events = db.execute(stmt).scalars().all() + + return [ + SecurityEventResponse( + id=str(event.id), + timestamp=event.timestamp, + event_type=event.event_type, + severity=event.severity, + category=event.category, + user_id=event.user_id, + client_ip=event.client_ip, + description=event.description, + threat_score=event.threat_score, + action_taken=event.action_taken, + resolved=event.resolved, + ) + for event in events + ] + + except Exception as e: + logger.error(f"Security events query failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Security events query failed: {str(e)}") + + +@router.get("/audit-trails", response_model=List[AuditTrailResponse]) +@require_permission("audit:read") +async def get_audit_trails( + action: Optional[List[str]] = Query(None), + resource_type: Optional[List[str]] = Query(None), + user_id: Optional[str] = Query(None), + requires_review: Optional[bool] = Query(None), + start_time: Optional[datetime] = Query(None), + end_time: Optional[datetime] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = Query(0, ge=0), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> List[AuditTrailResponse]: + """Get audit trails with filters. 
+ + Args: + action: Filter by actions + resource_type: Filter by resource types + user_id: Filter by user ID + requires_review: Filter by review requirement + start_time: Start timestamp + end_time: End timestamp + limit: Maximum results + offset: Result offset + user: Current authenticated user + db: Database session + + Returns: + List of audit trail entries + + Raises: + HTTPException: On database or validation errors + """ + try: + stmt = select(AuditTrail) + + conditions = [] + if action: + conditions.append(AuditTrail.action.in_(action)) + if resource_type: + conditions.append(AuditTrail.resource_type.in_(resource_type)) + if user_id: + conditions.append(AuditTrail.user_id == user_id) + if requires_review is not None: + conditions.append(AuditTrail.requires_review == requires_review) + if start_time: + conditions.append(AuditTrail.timestamp >= start_time) + if end_time: + conditions.append(AuditTrail.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(AuditTrail.timestamp)).limit(limit).offset(offset) + + trails = db.execute(stmt).scalars().all() + + return [ + AuditTrailResponse( + id=str(trail.id), + timestamp=trail.timestamp, + correlation_id=trail.correlation_id, + action=trail.action, + resource_type=trail.resource_type, + resource_id=trail.resource_id, + resource_name=trail.resource_name, + user_id=trail.user_id, + user_email=trail.user_email, + success=trail.success, + requires_review=trail.requires_review, + data_classification=trail.data_classification, + ) + for trail in trails + ] + + except Exception as e: + logger.error(f"Audit trails query failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Audit trails query failed: {str(e)}") + + +@router.get("/performance-metrics", response_model=List[PerformanceMetricResponse]) +@require_permission("metrics:read") +async def get_performance_metrics( + component: Optional[str] = Query(None), + operation: Optional[str] = Query(None), + hours: float = Query(24.0, ge=MIN_PERFORMANCE_RANGE_HOURS, le=1000.0, description="Historical window to display"), + aggregation: str = Query(_DEFAULT_AGGREGATION_KEY, regex="^(5m|24h)$", description="Aggregation level for metrics"), + user=Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> List[PerformanceMetricResponse]: + """Get performance metrics. 
+
+    Args:
+        component: Filter by component
+        operation: Filter by operation
+        hours: Hours of history to display
+        aggregation: Aggregation level ("5m" or "24h")
+        user: Current authenticated user
+        db: Database session
+
+    Returns:
+        List of performance metrics
+
+    Raises:
+        HTTPException: On database or validation errors
+    """
+    try:
+        aggregation_config = _AGGREGATION_LEVELS.get(aggregation, _AGGREGATION_LEVELS[_DEFAULT_AGGREGATION_KEY])
+        window_minutes = aggregation_config["minutes"]
+        window_duration_seconds = window_minutes * 60
+
+        if settings.metrics_aggregation_enabled:
+            try:
+                aggregator = get_log_aggregator()
+                if aggregation == "5m":
+                    aggregator.backfill(hours=hours, db=db)
+                else:
+                    _aggregate_custom_windows(
+                        aggregator=aggregator,
+                        window_minutes=window_minutes,
+                        db=db,
+                    )
+            except Exception as agg_error:  # pragma: no cover - defensive logging
+                logger.warning("On-demand metrics aggregation failed: %s", agg_error)
+
+        stmt = select(PerformanceMetric).where(PerformanceMetric.window_duration_seconds == window_duration_seconds)
+
+        if component:
+            stmt = stmt.where(PerformanceMetric.component == component)
+        if operation:
+            stmt = stmt.where(PerformanceMetric.operation_type == operation)
+
+        stmt = stmt.order_by(desc(PerformanceMetric.window_start), desc(PerformanceMetric.timestamp))
+
+        metrics = db.execute(stmt).scalars().all()
+
+        metrics = _deduplicate_metrics(metrics)
+
+        return [
+            PerformanceMetricResponse(
+                id=str(metric.id),
+                timestamp=metric.timestamp,
+                component=metric.component,
+                operation_type=metric.operation_type,
+                window_start=metric.window_start,
+                window_end=metric.window_end,
+                request_count=metric.request_count,
+                error_count=metric.error_count,
+                error_rate=metric.error_rate,
+                avg_duration_ms=metric.avg_duration_ms,
+                min_duration_ms=metric.min_duration_ms,
+                max_duration_ms=metric.max_duration_ms,
+                p50_duration_ms=metric.p50_duration_ms,
+                p95_duration_ms=metric.p95_duration_ms,
+                p99_duration_ms=metric.p99_duration_ms,
+            )
+            for metric in metrics
+        ]
+
+    except Exception as e:
+        logger.error(f"Performance metrics query failed: {e}")
+        raise HTTPException(status_code=500, detail="Performance metrics query failed")
diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py
index 33f3468d0..6fa2b5774 100644
--- a/mcpgateway/services/a2a_service.py
+++ b/mcpgateway/services/a2a_service.py
@@ -26,8 +26,10 @@
 from mcpgateway.db import A2AAgentMetric, EmailTeam
 from mcpgateway.schemas import A2AAgentCreate, A2AAgentMetrics, A2AAgentRead, A2AAgentUpdate
 from mcpgateway.services.logging_service import LoggingService
+from mcpgateway.services.structured_logger import get_structured_logger
 from mcpgateway.services.team_management_service import TeamManagementService
 from mcpgateway.services.tool_service import ToolService
+from mcpgateway.utils.correlation_id import get_correlation_id
 from mcpgateway.utils.create_slug import slugify
 from mcpgateway.utils.services_auth import encode_auth  # ,decode_auth
@@ -35,6 +37,9 @@
 logging_service = LoggingService()
 logger = logging_service.get_logger(__name__)

+# Initialize structured logger for A2A lifecycle tracking
+structured_logger = get_structured_logger("a2a_service")
+

 class A2AAgentError(Exception):
     """Base class for A2A agent-related errors.
@@ -279,6 +284,25 @@ async def register_agent( ) logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id})") + + # Log A2A agent registration for lifecycle tracking + structured_logger.info( + f"A2A agent '{new_agent.name}' registered successfully", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="a2a_agent", + resource_id=str(new_agent.id), + resource_action="create", + custom_fields={ + "agent_name": new_agent.name, + "agent_type": new_agent.agent_type, + "protocol_version": new_agent.protocol_version, + "visibility": visibility, + "endpoint_url": new_agent.endpoint_url, + }, + ) + return self._db_to_schema(db=db, db_agent=new_agent) except A2AAgentNameConflictError as ie: @@ -716,6 +740,21 @@ async def toggle_agent_status(self, db: Session, agent_id: str, activate: bool, status = "activated" if activate else "deactivated" logger.info(f"A2A agent {status}: {agent.name} (ID: {agent.id})") + structured_logger.log( + level="INFO", + message=f"A2A agent {status}", + event_type="a2a_agent_status_changed", + component="a2a_service", + user_email=user_email, + resource_type="a2a_agent", + resource_id=str(agent.id), + custom_fields={ + "agent_name": agent.name, + "enabled": agent.enabled, + "reachable": agent.reachable, + }, + ) + return self._db_to_schema(db=db, db_agent=agent) async def delete_agent(self, db: Session, agent_id: str, user_email: Optional[str] = None) -> None: @@ -751,11 +790,31 @@ async def delete_agent(self, db: Session, agent_id: str, user_email: Optional[st db.commit() logger.info(f"Deleted A2A agent: {agent_name} (ID: {agent_id})") + + structured_logger.log( + level="INFO", + message="A2A agent deleted", + event_type="a2a_agent_deleted", + component="a2a_service", + user_email=user_email, + resource_type="a2a_agent", + resource_id=str(agent_id), + custom_fields={"agent_name": agent_name}, + ) except PermissionError: db.rollback() raise - async def invoke_agent(self, db: Session, agent_name: str, parameters: Dict[str, Any], interaction_type: str = "query") -> Dict[str, Any]: + async def invoke_agent( + self, + db: Session, + agent_name: str, + parameters: Dict[str, Any], + interaction_type: str = "query", + *, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + ) -> Dict[str, Any]: """Invoke an A2A agent. Args: @@ -763,6 +822,8 @@ async def invoke_agent(self, db: Session, agent_name: str, parameters: Dict[str, agent_name: Name of the agent to invoke. parameters: Parameters for the interaction. interaction_type: Type of interaction. + user_id: Identifier of the user initiating the call. + user_email: Email of the user initiating the call. Returns: Agent response. 
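Before the next hunk adds call-site logging, a quick caller-side sketch of the extended `invoke_agent` signature above. This assumes `A2AAgentService`'s default constructor; the agent name and parameters are hypothetical:

```python
# Hypothetical caller-side sketch: how the gateway threads user identity into
# A2AAgentService.invoke_agent() via the new keyword-only parameters added above.
import asyncio

from mcpgateway.db import SessionLocal
from mcpgateway.services.a2a_service import A2AAgentService


async def demo() -> None:
    service = A2AAgentService()  # assumption: default constructor
    db = SessionLocal()
    try:
        result = await service.invoke_agent(
            db,
            "weather-agent",                      # hypothetical agent name
            {"query": "forecast for tomorrow"},   # hypothetical parameters
            interaction_type="query",
            user_id="alice@example.com",          # keyword-only: who initiated the call
            user_email="alice@example.com",
        )
        print(result)
    finally:
        db.close()


if __name__ == "__main__":
    asyncio.run(demo())
```

When the call arrives over HTTP instead, the `/a2a` endpoint shown earlier fills in these values itself, deriving `user_id` from the JWT claims (`id`, `sub`) or falling back to the user's e-mail.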
@@ -803,13 +864,64 @@ async def invoke_agent(self, db: Session, agent_name: str, parameters: Dict[str, if token_value: headers["Authorization"] = f"Bearer {token_value}" + # Add correlation ID to outbound headers for distributed tracing + correlation_id = get_correlation_id() + if correlation_id: + headers["X-Correlation-ID"] = correlation_id + + # Log A2A external call start + call_start_time = datetime.now(timezone.utc) + structured_logger.log( + level="INFO", + message=f"A2A external call started: {agent_name}", + component="a2a_service", + user_id=user_id, + user_email=user_email, + correlation_id=correlation_id, + metadata={ + "event": "a2a_call_started", + "agent_name": agent_name, + "agent_id": agent.id, + "endpoint_url": agent.endpoint_url, + "interaction_type": interaction_type, + "protocol_version": agent.protocol_version, + }, + ) + http_response = await client.post(agent.endpoint_url, json=request_data, headers=headers) + call_duration_ms = (datetime.now(timezone.utc) - call_start_time).total_seconds() * 1000 if http_response.status_code == 200: response = http_response.json() success = True + + # Log successful A2A call + structured_logger.log( + level="INFO", + message=f"A2A external call completed: {agent_name}", + component="a2a_service", + user_id=user_id, + user_email=user_email, + correlation_id=correlation_id, + duration_ms=call_duration_ms, + metadata={"event": "a2a_call_completed", "agent_name": agent_name, "agent_id": agent.id, "status_code": http_response.status_code, "success": True}, + ) else: error_message = f"HTTP {http_response.status_code}: {http_response.text}" + + # Log failed A2A call + structured_logger.log( + level="ERROR", + message=f"A2A external call failed: {agent_name}", + component="a2a_service", + user_id=user_id, + user_email=user_email, + correlation_id=correlation_id, + duration_ms=call_duration_ms, + error_details={"error_type": "A2AHTTPError", "error_message": error_message}, + metadata={"event": "a2a_call_failed", "agent_name": agent_name, "agent_id": agent.id, "status_code": http_response.status_code}, + ) + raise A2AAgentError(error_message) except Exception as e: diff --git a/mcpgateway/services/audit_trail_service.py b/mcpgateway/services/audit_trail_service.py new file mode 100644 index 000000000..3d9023bfe --- /dev/null +++ b/mcpgateway/services/audit_trail_service.py @@ -0,0 +1,445 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/audit_trail_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Audit Trail Service. + +This module provides audit trail management for CRUD operations, +data access tracking, and compliance logging. 
+""" + +# Standard +from datetime import datetime, timezone +from enum import Enum +import logging +from typing import Any, Dict, Optional + +# Third-Party +from sqlalchemy import select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.db import AuditTrail, SessionLocal +from mcpgateway.utils.correlation_id import get_or_generate_correlation_id + +logger = logging.getLogger(__name__) + + +class AuditAction(str, Enum): + """Audit trail action types.""" + + CREATE = "CREATE" + READ = "READ" + UPDATE = "UPDATE" + DELETE = "DELETE" + EXECUTE = "EXECUTE" + ACCESS = "ACCESS" + EXPORT = "EXPORT" + IMPORT = "IMPORT" + + +class DataClassification(str, Enum): + """Data classification levels.""" + + PUBLIC = "public" + INTERNAL = "internal" + CONFIDENTIAL = "confidential" + RESTRICTED = "restricted" + + +REVIEW_REQUIRED_ACTIONS = { + "delete_server", + "delete_tool", + "delete_resource", + "delete_gateway", + "update_sensitive_config", + "bulk_delete", +} + + +class AuditTrailService: + """Service for managing audit trails and compliance logging. + + Provides comprehensive audit trail management with data classification, + change tracking, and compliance reporting capabilities. + """ + + def __init__(self): + """Initialize audit trail service.""" + + def log_action( # pylint: disable=too-many-positional-arguments + self, + action: str, + resource_type: str, + resource_id: str, + user_id: str, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + client_ip: Optional[str] = None, + user_agent: Optional[str] = None, + request_path: Optional[str] = None, + request_method: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + changes: Optional[Dict[str, Any]] = None, + data_classification: Optional[str] = None, + requires_review: Optional[bool] = None, + success: bool = True, + error_message: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + details: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[AuditTrail]: + """Log an audit trail entry. + + Args: + action: Action performed (CREATE, READ, UPDATE, DELETE, etc.) + resource_type: Type of resource (tool, server, prompt, etc.) 
+ resource_id: ID of the resource + user_id: User who performed the action + user_email: User's email address + team_id: Team ID if applicable + resource_name: Name of the resource + client_ip: Client IP address + user_agent: Client user agent + request_path: HTTP request path + request_method: HTTP request method + old_values: Previous values before change + new_values: New values after change + changes: Specific changes made + data_classification: Data classification level + requires_review: Whether this action requires review (None = auto) + success: Whether the action succeeded + error_message: Error message if failed + context: Additional context + details: Extra key/value payload (stored under context.details) + metadata: Extra metadata payload (stored under context.metadata) + db: Optional database session + + Returns: + Created AuditTrail entry or None if logging disabled + """ + correlation_id = get_or_generate_correlation_id() + + # Use provided session or create new one + close_db = False + if db is None: + db = SessionLocal() + close_db = True + + try: + context_payload: Dict[str, Any] = dict(context) if context else {} + if details: + context_payload["details"] = details + if metadata: + context_payload["metadata"] = metadata + context_value = context_payload if context_payload else None + + requires_review_flag = self._determine_requires_review( + action=action, + data_classification=data_classification, + requires_review_param=requires_review, + ) + + # Create audit trail entry + audit_entry = AuditTrail( + timestamp=datetime.now(timezone.utc), + correlation_id=correlation_id, + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + request_path=request_path, + request_method=request_method, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review_flag, + success=success, + error_message=error_message, + context=context_value, + ) + + db.add(audit_entry) + db.commit() + db.refresh(audit_entry) + + logger.debug( + f"Audit trail logged: {action} {resource_type}/{resource_id} by {user_id}", + extra={"correlation_id": correlation_id, "action": action, "resource_type": resource_type, "resource_id": resource_id, "user_id": user_id, "success": success}, + ) + + return audit_entry + + except Exception as e: + logger.error(f"Failed to log audit trail: {e}", exc_info=True, extra={"correlation_id": correlation_id, "action": action, "resource_type": resource_type, "resource_id": resource_id}) + if close_db: + db.rollback() + return None + + finally: + if close_db: + db.close() + + def _determine_requires_review( + self, + action: Optional[str], + data_classification: Optional[str], + requires_review_param: Optional[bool], + ) -> bool: + """Resolve whether an audit entry should require review. 
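Within `log_action`, any `details` and `metadata` payloads are folded under the stored `context` column, and the review flag is resolved by the helper that follows. A hedged call sketch (resource IDs and values are invented for illustration):

```python
# Sketch: details/metadata are nested under the stored context column.
from mcpgateway.services.audit_trail_service import get_audit_trail_service

audit = get_audit_trail_service()
entry = audit.log_action(
    action="EXPORT",
    resource_type="server",
    resource_id="srv-7",                 # illustrative ID
    user_id="alice@example.com",
    data_classification="confidential",  # sensitive -> requires_review=True
    details={"format": "json"},
    metadata={"rows": 42},
)
# If the write succeeded (entry is None on failure):
# entry.context == {"details": {"format": "json"}, "metadata": {"rows": 42}}
# entry.requires_review is True (confidential classification)
```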
+ + Args: + action: Action being performed + data_classification: Data classification level + requires_review_param: Explicit review requirement + + Returns: + bool: Whether the audit entry requires review + """ + if requires_review_param is not None: + return requires_review_param + + if data_classification in {DataClassification.CONFIDENTIAL.value, DataClassification.RESTRICTED.value}: + return True + + normalized_action = (action or "").lower() + if normalized_action in REVIEW_REQUIRED_ACTIONS: + return True + + return False + + def log_crud_operation( + self, + operation: str, + resource_type: str, + resource_id: str, + user_id: str, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + success: bool = True, + error_message: Optional[str] = None, + db: Optional[Session] = None, + **kwargs, + ) -> Optional[AuditTrail]: + """Log a CRUD operation with change tracking. + + Args: + operation: CRUD operation (CREATE, READ, UPDATE, DELETE) + resource_type: Type of resource + resource_id: ID of the resource + user_id: User who performed the operation + user_email: User's email + team_id: Team ID if applicable + resource_name: Name of the resource + old_values: Previous values (for UPDATE/DELETE) + new_values: New values (for CREATE/UPDATE) + success: Whether the operation succeeded + error_message: Error message if failed + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + # Calculate changes for UPDATE operations + changes = None + if operation == "UPDATE" and old_values and new_values: + changes = {} + for key in set(old_values.keys()) | set(new_values.keys()): + old_val = old_values.get(key) + new_val = new_values.get(key) + if old_val != new_val: + changes[key] = {"old": old_val, "new": new_val} + + # Determine data classification based on resource type + data_classification = None + if resource_type in ["user", "team", "token", "credential"]: + data_classification = DataClassification.CONFIDENTIAL.value + elif resource_type in ["tool", "server", "prompt", "resource"]: + data_classification = DataClassification.INTERNAL.value + + # Determine if review is required + requires_review = False + if data_classification == DataClassification.CONFIDENTIAL.value: + requires_review = True + if operation == "DELETE" and resource_type in ["tool", "server", "gateway"]: + requires_review = True + + return self.log_action( + action=operation, + resource_type=resource_type, + resource_id=resource_id, + user_id=user_id, + user_email=user_email, + team_id=team_id, + resource_name=resource_name, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + success=success, + error_message=error_message, + db=db, + **kwargs, + ) + + def log_data_access( + self, + resource_type: str, + resource_id: str, + user_id: str, + access_type: str = "READ", + user_email: Optional[str] = None, + team_id: Optional[str] = None, + resource_name: Optional[str] = None, + data_classification: Optional[str] = None, + db: Optional[Session] = None, + **kwargs, + ) -> Optional[AuditTrail]: + """Log data access for compliance tracking. + + Args: + resource_type: Type of resource accessed + resource_id: ID of the resource + user_id: User who accessed the data + access_type: Type of access (READ, EXPORT, etc.) 
+ user_email: User's email + team_id: Team ID if applicable + resource_name: Name of the resource + data_classification: Data classification level + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + requires_review = data_classification in [DataClassification.CONFIDENTIAL.value, DataClassification.RESTRICTED.value] + + return self.log_action( + action=access_type, + resource_type=resource_type, + resource_id=resource_id, + user_id=user_id, + user_email=user_email, + team_id=team_id, + resource_name=resource_name, + data_classification=data_classification, + requires_review=requires_review, + success=True, + db=db, + **kwargs, + ) + + def log_audit( + self, user_id: str, resource_type: str, resource_id: str, action: str, user_email: Optional[str] = None, description: Optional[str] = None, db: Optional[Session] = None, **kwargs + ) -> Optional[AuditTrail]: + """Convenience method for simple audit logging. + + Args: + user_id: User who performed the action + resource_type: Type of resource + resource_id: ID of the resource + action: Action performed + user_email: User's email + description: Description of the action + db: Optional database session + **kwargs: Additional arguments passed to log_action + + Returns: + Created AuditTrail entry + """ + # Build context if description provided + context = kwargs.pop("context", {}) + if description: + context["description"] = description + + return self.log_action(action=action, resource_type=resource_type, resource_id=resource_id, user_id=user_id, user_email=user_email, context=context if context else None, db=db, **kwargs) + + def get_audit_trail( + self, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + user_id: Optional[str] = None, + action: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 100, + offset: int = 0, + db: Optional[Session] = None, + ) -> list[AuditTrail]: + """Query audit trail entries. + + Args: + resource_type: Filter by resource type + resource_id: Filter by resource ID + user_id: Filter by user ID + action: Filter by action + start_time: Filter by start time + end_time: Filter by end time + limit: Maximum number of results + offset: Offset for pagination + db: Optional database session + + Returns: + List of AuditTrail entries + """ + close_db = False + if db is None: + db = SessionLocal() + close_db = True + + try: + query = select(AuditTrail) + + if resource_type: + query = query.where(AuditTrail.resource_type == resource_type) + if resource_id: + query = query.where(AuditTrail.resource_id == resource_id) + if user_id: + query = query.where(AuditTrail.user_id == user_id) + if action: + query = query.where(AuditTrail.action == action) + if start_time: + query = query.where(AuditTrail.timestamp >= start_time) + if end_time: + query = query.where(AuditTrail.timestamp <= end_time) + + query = query.order_by(AuditTrail.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = db.execute(query) + return list(result.scalars().all()) + + finally: + if close_db: + db.close() + + +# Singleton instance +_audit_trail_service: Optional[AuditTrailService] = None + + +def get_audit_trail_service() -> AuditTrailService: + """Get or create the singleton audit trail service instance. 
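Taken together, the service supports a write-then-query round trip: `log_crud_operation` diffs old/new values into a `changes` map, and `get_audit_trail` pages the entries back, newest first. A hedged usage sketch (IDs and values are illustrative):

```python
# Sketch: UPDATE with change tracking, then query the trail back.
from mcpgateway.services.audit_trail_service import get_audit_trail_service

audit = get_audit_trail_service()

# Only keys whose values differ land in `changes`:
# here {"timeout": {"old": 30, "new": 60}}.
audit.log_crud_operation(
    operation="UPDATE",
    resource_type="tool",
    resource_id="tool-123",          # illustrative ID
    user_id="alice@example.com",
    old_values={"name": "fetch", "timeout": 30},
    new_values={"name": "fetch", "timeout": 60},
)

entries = audit.get_audit_trail(resource_type="tool",
                                resource_id="tool-123", limit=10)
for entry in entries:
    print(entry.action, entry.changes)
```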
+ + Returns: + AuditTrailService instance + """ + global _audit_trail_service # pylint: disable=global-statement + if _audit_trail_service is None: + _audit_trail_service = AuditTrailService() + return _audit_trail_service diff --git a/mcpgateway/services/export_service.py b/mcpgateway/services/export_service.py index d5806dd59..78a5a5763 100644 --- a/mcpgateway/services/export_service.py +++ b/mcpgateway/services/export_service.py @@ -399,7 +399,7 @@ async def _export_servers(self, db: Session, tags: Optional[List[str]], include_ "websocket_endpoint": f"{root_path}/servers/{server.id}/ws", "jsonrpc_endpoint": f"{root_path}/servers/{server.id}/jsonrpc", "capabilities": {"tools": {"list_changed": True}, "prompts": {"list_changed": True}}, - "is_active": server.is_active, + "is_active": getattr(server, "enabled", getattr(server, "is_active", False)), "tags": server.tags or [], } @@ -469,7 +469,7 @@ async def _export_resources(self, db: Session, tags: Optional[List[str]], includ "description": resource.description, "mime_type": resource.mime_type, "tags": resource.tags or [], - "is_active": resource.is_active, + "is_active": getattr(resource, "enabled", getattr(resource, "is_active", False)), "last_modified": resource.updated_at.isoformat() if resource.updated_at else None, } diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 194db6f61..9d47fea1e 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -80,11 +80,13 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.observability import create_span from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, PromptCreate, ResourceCreate, ToolCreate -from mcpgateway.services.event_service import EventService # logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks +from mcpgateway.services.audit_trail_service import get_audit_trail_service +from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService from mcpgateway.utils.create_slug import slugify @@ -98,6 +100,10 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger and audit trail for gateway operations +structured_logger = get_structured_logger("gateway_service") +audit_trail = get_audit_trail_service() + GW_FAILURE_THRESHOLD = settings.unhealthy_threshold GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval @@ -809,6 +815,54 @@ async def register_gateway( # Notify subscribers await self._notify_gateway_added(db_gateway) + logger.info(f"Registered gateway: {gateway.name}") + + # Structured logging: Audit trail for gateway creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_gateway", + resource_type="gateway", + resource_id=str(db_gateway.id), + resource_name=db_gateway.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "name": db_gateway.name, + "url": db_gateway.url, + "visibility": visibility, + "transport": db_gateway.transport, + "tools_count": len(tools), + "resources_count": len(db_resources), + "prompts_count": len(db_prompts), 
+ }, + context={ + "created_via": created_via, + }, + db=db, + ) + + # Structured logging: Log successful gateway creation + structured_logger.log( + level="INFO", + message="Gateway created successfully", + event_type="gateway_created", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="gateway", + resource_id=str(db_gateway.id), + custom_fields={ + "gateway_name": db_gateway.name, + "gateway_url": normalized_url, + "visibility": visibility, + "transport": db_gateway.transport, + }, + db=db, + ) + # Add team name for response db_gateway.team = self._get_team_name(db, db_gateway.team_id) return GatewayRead.model_validate(self._prepare_gateway_for_read(db_gateway)).masked() @@ -816,31 +870,101 @@ async def register_gateway( if TYPE_CHECKING: ge: ExceptionGroup[GatewayConnectionError] logger.error(f"GatewayConnectionError in group: {ge.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed due to connection error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=ge.exceptions[0], + custom_fields={"gateway_name": gateway.name, "gateway_url": str(gateway.url)}, + db=db, + ) raise ge.exceptions[0] except* GatewayNameConflictError as gnce: # pragma: no mutate if TYPE_CHECKING: gnce: ExceptionGroup[GatewayNameConflictError] logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}") + + structured_logger.log( + level="WARNING", + message="Gateway creation failed due to name conflict", + event_type="gateway_name_conflict", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + custom_fields={"gateway_name": gateway.name, "visibility": visibility}, + db=db, + ) raise gnce.exceptions[0] except* GatewayDuplicateConflictError as guce: # pragma: no mutate if TYPE_CHECKING: guce: ExceptionGroup[GatewayDuplicateConflictError] logger.error(f"GatewayDuplicateConflictError in group: {guce.exceptions}") + + structured_logger.log( + level="WARNING", + message="Gateway creation failed due to duplicate", + event_type="gateway_duplicate_conflict", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise guce.exceptions[0] except* ValueError as ve: # pragma: no mutate if TYPE_CHECKING: ve: ExceptionGroup[ValueError] logger.error(f"ValueErrors in group: {ve.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed due to validation error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=ve.exceptions[0], + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise ve.exceptions[0] except* RuntimeError as re: # pragma: no mutate if TYPE_CHECKING: re: ExceptionGroup[RuntimeError] logger.error(f"RuntimeErrors in group: {re.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed due to runtime error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=re.exceptions[0], + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise re.exceptions[0] except* IntegrityError as ie: # pragma: no mutate if TYPE_CHECKING: ie: ExceptionGroup[IntegrityError] logger.error(f"IntegrityErrors in group: {ie.exceptions}") + + structured_logger.log( + level="ERROR", + message="Gateway creation failed 
due to database integrity error", + event_type="gateway_creation_failed", + component="gateway_service", + user_id=created_by, + user_email=owner_email, + error=ie.exceptions[0], + custom_fields={"gateway_name": gateway.name}, + db=db, + ) raise ie.exceptions[0] except* BaseException as other: # catches every other sub-exception # pragma: no mutate if TYPE_CHECKING: @@ -1461,6 +1585,47 @@ async def update_gateway( await self._notify_gateway_updated(gateway) logger.info(f"Updated gateway: {gateway.name}") + + # Structured logging: Audit trail for gateway update + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_gateway", + resource_type="gateway", + resource_id=str(gateway.id), + resource_name=gateway.name, + user_email=user_email, + team_id=gateway.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={ + "name": gateway.name, + "url": gateway.url, + "version": gateway.version, + }, + context={ + "modified_via": modified_via, + }, + db=db, + ) + + # Structured logging: Log successful gateway update + structured_logger.log( + level="INFO", + message="Gateway updated successfully", + event_type="gateway_updated", + component="gateway_service", + user_id=modified_by, + user_email=user_email, + team_id=gateway.team_id, + resource_type="gateway", + resource_id=str(gateway.id), + custom_fields={ + "gateway_name": gateway.name, + "version": gateway.version, + }, + db=db, + ) + gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)) @@ -1468,18 +1633,78 @@ async def update_gateway( return None except GatewayNameConflictError as ge: logger.error(f"GatewayNameConflictError in group: {ge}") + + structured_logger.log( + level="WARNING", + message="Gateway update failed due to name conflict", + event_type="gateway_name_conflict", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=ge, + db=db, + ) raise ge except GatewayNotFoundError as gnfe: logger.error(f"GatewayNotFoundError: {gnfe}") + + structured_logger.log( + level="ERROR", + message="Gateway update failed - gateway not found", + event_type="gateway_not_found", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=gnfe, + db=db, + ) raise gnfe except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") + + structured_logger.log( + level="ERROR", + message="Gateway update failed due to database integrity error", + event_type="gateway_update_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=ie, + db=db, + ) raise ie - except PermissionError: + except PermissionError as pe: db.rollback() + + structured_logger.log( + level="WARNING", + message="Gateway update failed due to permission error", + event_type="gateway_update_permission_denied", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=pe, + db=db, + ) raise except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Gateway update failed", + event_type="gateway_update_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise GatewayError(f"Failed to update gateway: {str(e)}") async def get_gateway(self, db: Session, gateway_id: 
str, include_inactive: bool = True) -> GatewayRead: @@ -1542,6 +1767,24 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool if gateway.enabled or include_inactive: gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) + + # Structured logging: Log gateway view + structured_logger.log( + level="INFO", + message="Gateway retrieved successfully", + event_type="gateway_viewed", + component="gateway_service", + team_id=getattr(gateway, "team_id", None), + resource_type="gateway", + resource_id=str(gateway.id), + custom_fields={ + "gateway_name": gateway.name, + "gateway_url": gateway.url, + "include_inactive": include_inactive, + }, + db=db, + ) + return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") @@ -1689,13 +1932,76 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}") + # Structured logging: Audit trail for gateway status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_gateway_status", + resource_type="gateway", + resource_id=str(gateway.id), + resource_name=gateway.name, + user_email=user_email, + team_id=gateway.team_id, + new_values={ + "enabled": gateway.enabled, + "reachable": gateway.reachable, + }, + context={ + "action": "activate" if activate else "deactivate", + "only_update_reachable": only_update_reachable, + }, + db=db, + ) + + # Structured logging: Log successful gateway status toggle + structured_logger.log( + level="INFO", + message=f"Gateway {'activated' if activate else 'deactivated'} successfully", + event_type="gateway_status_toggled", + component="gateway_service", + user_email=user_email, + team_id=gateway.team_id, + resource_type="gateway", + resource_id=str(gateway.id), + custom_fields={ + "gateway_name": gateway.name, + "enabled": gateway.enabled, + "reachable": gateway.reachable, + }, + db=db, + ) + gateway.team = self._get_team_name(db, getattr(gateway, "team_id", None)) return GatewayRead.model_validate(self._prepare_gateway_for_read(gateway)).masked() except PermissionError as e: + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Gateway status toggle failed due to permission error", + event_type="gateway_toggle_permission_denied", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic gateway status toggle failure + structured_logger.log( + level="ERROR", + message="Gateway status toggle failed", + event_type="gateway_toggle_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise GatewayError(f"Failed to toggle gateway status: {str(e)}") async def _notify_gateway_updated(self, gateway: DbGateway) -> None: @@ -1765,6 +2071,8 @@ async def delete_gateway(self, db: Session, gateway_id: str, user_email: Optiona # Store gateway info for notification before deletion gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url} + gateway_name = gateway.name + gateway_team_id = gateway.team_id # Hard delete gateway db.delete(gateway) @@ -1778,11 +2086,70 @@ async def delete_gateway(self, db: Session, gateway_id: 
str, user_email: Optiona logger.info(f"Permanently deleted gateway: {gateway.name}") - except PermissionError: + # Structured logging: Audit trail for gateway deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_gateway", + resource_type="gateway", + resource_id=str(gateway_info["id"]), + resource_name=gateway_name, + user_email=user_email, + team_id=gateway_team_id, + old_values={ + "name": gateway_name, + "url": gateway_info["url"], + }, + db=db, + ) + + # Structured logging: Log successful gateway deletion + structured_logger.log( + level="INFO", + message="Gateway deleted successfully", + event_type="gateway_deleted", + component="gateway_service", + user_email=user_email, + team_id=gateway_team_id, + resource_type="gateway", + resource_id=str(gateway_info["id"]), + custom_fields={ + "gateway_name": gateway_name, + "gateway_url": gateway_info["url"], + }, + db=db, + ) + + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Gateway deletion failed due to permission error", + event_type="gateway_delete_permission_denied", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=pe, + db=db, + ) raise except Exception as e: db.rollback() + + # Structured logging: Log generic gateway deletion failure + structured_logger.log( + level="ERROR", + message="Gateway deletion failed", + event_type="gateway_deletion_failed", + component="gateway_service", + user_email=user_email, + resource_type="gateway", + resource_id=gateway_id, + error=e, + db=db, + ) raise GatewayError(f"Failed to delete gateway: {str(e)}") async def forward_request( diff --git a/mcpgateway/services/log_aggregator.py b/mcpgateway/services/log_aggregator.py new file mode 100644 index 000000000..2d7f0f293 --- /dev/null +++ b/mcpgateway/services/log_aggregator.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/log_aggregator.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Log Aggregation Service. + +This module provides aggregation of performance metrics from structured logs +into time-windowed statistics for analysis and monitoring. +""" + +# Standard +from datetime import datetime, timedelta, timezone +import logging +import math +import statistics +from typing import Any, Dict, List, Optional, Tuple + +# Third-Party +from sqlalchemy import and_, select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import PerformanceMetric, SessionLocal, StructuredLogEntry + +logger = logging.getLogger(__name__) + + +class LogAggregator: + """Aggregates structured logs into performance metrics.""" + + def __init__(self): + """Initialize log aggregator.""" + self.aggregation_window_minutes = getattr(settings, "metrics_aggregation_window_minutes", 5) + self.enabled = getattr(settings, "metrics_aggregation_enabled", True) + + def aggregate_performance_metrics( + self, component: Optional[str], operation_type: Optional[str], window_start: Optional[datetime] = None, window_end: Optional[datetime] = None, db: Optional[Session] = None + ) -> Optional[PerformanceMetric]: + """Aggregate performance metrics for a component and operation. 
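Stepping back to the gateway-service hunks above: every failure path pairs a `db.rollback()` with a structured event before re-raising, and every success path emits a matching event. Reduced to a sketch (the `structured_logger.log()` keyword set is inferred from those call sites; `perform_delete` is a hypothetical helper):

```python
# Sketch of the success/failure logging pattern repeated in the hunks above.
from mcpgateway.services.structured_logger import get_structured_logger

structured_logger = get_structured_logger("gateway_service")


def delete_with_logging(db, gateway_id, user_email):
    try:
        perform_delete(db, gateway_id)  # hypothetical helper
        structured_logger.log(
            level="INFO", message="Gateway deleted successfully",
            event_type="gateway_deleted", component="gateway_service",
            user_email=user_email, resource_type="gateway",
            resource_id=gateway_id, db=db,
        )
    except Exception as e:
        db.rollback()
        structured_logger.log(
            level="ERROR", message="Gateway deletion failed",
            event_type="gateway_deletion_failed", component="gateway_service",
            user_email=user_email, resource_type="gateway",
            resource_id=gateway_id, error=e, db=db,
        )
        raise
```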
+ + Args: + component: Component name + operation_type: Operation name + window_start: Start of aggregation window (defaults to N minutes ago) + window_end: End of aggregation window (defaults to now) + db: Optional database session + + Returns: + Created PerformanceMetric or None if no data + """ + if not self.enabled: + return None + if not component or not operation_type: + return None + + window_start, window_end = self._resolve_window_bounds(window_start, window_end) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + # Query structured logs for this component/operation in time window + stmt = select(StructuredLogEntry).where( + and_( + StructuredLogEntry.component == component, + StructuredLogEntry.operation_type == operation_type, + StructuredLogEntry.timestamp >= window_start, + StructuredLogEntry.timestamp < window_end, + StructuredLogEntry.duration_ms.isnot(None), + ) + ) + + results = db.execute(stmt).scalars().all() + + if not results: + return None + + # Extract durations + durations = sorted(r.duration_ms for r in results if r.duration_ms is not None) + + if not durations: + return None + + # Calculate statistics + count = len(durations) + avg_duration = statistics.fmean(durations) if hasattr(statistics, "fmean") else statistics.mean(durations) + min_duration = durations[0] + max_duration = durations[-1] + + # Calculate percentiles + p50 = self._percentile(durations, 0.50) + p95 = self._percentile(durations, 0.95) + p99 = self._percentile(durations, 0.99) + + # Count errors + error_count = self._calculate_error_count(results) + error_rate = error_count / count if count > 0 else 0.0 + + metric = self._upsert_metric( + component=component, + operation_type=operation_type, + window_start=window_start, + window_end=window_end, + request_count=count, + error_count=error_count, + error_rate=error_rate, + avg_duration_ms=avg_duration, + min_duration_ms=min_duration, + max_duration_ms=max_duration, + p50_duration_ms=p50, + p95_duration_ms=p95, + p99_duration_ms=p99, + metric_metadata={ + "sample_size": count, + "generated_at": datetime.now(timezone.utc).isoformat(), + }, + db=db, + ) + + logger.info(f"Aggregated performance metrics for {component}.{operation_type}: " f"{count} requests, {avg_duration:.2f}ms avg, {error_rate:.2%} error rate") + + return metric + + except Exception as e: + logger.error(f"Failed to aggregate performance metrics: {e}") + if db: + db.rollback() + return None + + finally: + if should_close: + db.close() + + def aggregate_all_components(self, window_start: Optional[datetime] = None, window_end: Optional[datetime] = None, db: Optional[Session] = None) -> List[PerformanceMetric]: + """Aggregate metrics for all components and operations. 
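A hedged usage sketch for the method above; with the window bounds omitted, `_resolve_window_bounds()` (defined later in this file) snaps to the most recent completed N-minute window:

```python
# Sketch: aggregate one component/operation over the default window.
from mcpgateway.services.log_aggregator import get_log_aggregator

aggregator = get_log_aggregator()
metric = aggregator.aggregate_performance_metrics(
    component="tool_service",      # illustrative component name
    operation_type="invoke_tool",  # illustrative operation name
)
if metric is not None:  # None when disabled or no samples in the window
    print(metric.request_count, metric.p95_duration_ms, metric.error_rate)
```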
+ + Args: + window_start: Start of aggregation window + window_end: End of aggregation window + db: Optional database session + + Returns: + List of created PerformanceMetric records + """ + if not self.enabled: + return [] + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + window_start, window_end = self._resolve_window_bounds(window_start, window_end) + + stmt = ( + select(StructuredLogEntry.component, StructuredLogEntry.operation_type) + .where( + and_( + StructuredLogEntry.timestamp >= window_start, + StructuredLogEntry.timestamp < window_end, + StructuredLogEntry.duration_ms.isnot(None), + StructuredLogEntry.operation_type.isnot(None), + ) + ) + .distinct() + ) + + pairs = db.execute(stmt).all() + + metrics = [] + for component, operation in pairs: + if component and operation: + metric = self.aggregate_performance_metrics(component=component, operation_type=operation, window_start=window_start, window_end=window_end, db=db) + if metric: + metrics.append(metric) + + return metrics + + finally: + if should_close: + db.close() + + def get_recent_metrics(self, component: Optional[str] = None, operation: Optional[str] = None, hours: int = 24, db: Optional[Session] = None) -> List[PerformanceMetric]: + """Get recent performance metrics. + + Args: + component: Optional component filter + operation: Optional operation filter + hours: Hours of history to retrieve + db: Optional database session + + Returns: + List of PerformanceMetric records + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + since = datetime.now(timezone.utc) - timedelta(hours=hours) + + stmt = select(PerformanceMetric).where(PerformanceMetric.window_start >= since) + + if component: + stmt = stmt.where(PerformanceMetric.component == component) + if operation: + stmt = stmt.where(PerformanceMetric.operation_type == operation) + + stmt = stmt.order_by(PerformanceMetric.window_start.desc()) + + return db.execute(stmt).scalars().all() + + finally: + if should_close: + db.close() + + def get_degradation_alerts(self, threshold_multiplier: float = 1.5, hours: int = 24, db: Optional[Session] = None) -> List[Dict[str, Any]]: + """Identify performance degradations by comparing recent vs baseline. 
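`aggregate_all_components` is meant to be driven on a schedule. One way to wire it up, purely illustrative since the method is synchronous and the project may already have its own scheduler:

```python
# Illustrative periodic driver for the aggregator above.
import asyncio

from mcpgateway.services.log_aggregator import get_log_aggregator


async def run_aggregation_loop() -> None:
    aggregator = get_log_aggregator()
    interval = aggregator.aggregation_window_minutes * 60
    loop = asyncio.get_running_loop()
    while True:
        # aggregate_all_components() is synchronous (SQLAlchemy session),
        # so push it off the event loop thread.
        metrics = await loop.run_in_executor(
            None, aggregator.aggregate_all_components)
        if metrics:
            print(f"aggregated {len(metrics)} metric windows")
        await asyncio.sleep(interval)
```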
+ + Args: + threshold_multiplier: Alert if recent is X times slower than baseline + hours: Hours of recent data to check + db: Optional database session + + Returns: + List of degradation alerts with details + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) + baseline_cutoff = recent_cutoff - timedelta(hours=hours * 2) + + # Get unique component/operation pairs + stmt = select(PerformanceMetric.component, PerformanceMetric.operation_type).distinct() + + pairs = db.execute(stmt).all() + + alerts = [] + for component, operation in pairs: + # Get recent metrics + recent_stmt = select(PerformanceMetric).where( + and_(PerformanceMetric.component == component, PerformanceMetric.operation_type == operation, PerformanceMetric.window_start >= recent_cutoff) + ) + recent_metrics = db.execute(recent_stmt).scalars().all() + + # Get baseline metrics + baseline_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation_type == operation, + PerformanceMetric.window_start >= baseline_cutoff, + PerformanceMetric.window_start < recent_cutoff, + ) + ) + baseline_metrics = db.execute(baseline_stmt).scalars().all() + + if not recent_metrics or not baseline_metrics: + continue + + recent_avg = statistics.mean([m.avg_duration_ms for m in recent_metrics]) + baseline_avg = statistics.mean([m.avg_duration_ms for m in baseline_metrics]) + + if recent_avg > baseline_avg * threshold_multiplier: + alerts.append( + { + "component": component, + "operation": operation, + "recent_avg_ms": recent_avg, + "baseline_avg_ms": baseline_avg, + "degradation_ratio": recent_avg / baseline_avg, + "recent_error_rate": statistics.mean([m.error_rate for m in recent_metrics]), + "baseline_error_rate": statistics.mean([m.error_rate for m in baseline_metrics]), + } + ) + + return alerts + + finally: + if should_close: + db.close() + + def backfill(self, hours: float, db: Optional[Session] = None) -> int: + """Backfill metrics for a historical time range. + + Args: + hours: Number of hours of history to aggregate (supports fractional hours) + db: Optional shared database session + + Returns: + Count of performance metric windows processed + """ + if not self.enabled or hours <= 0: + return 0 + + window_minutes = self.aggregation_window_minutes + window_delta = timedelta(minutes=window_minutes) + total_windows = max(1, math.ceil((hours * 60) / window_minutes)) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + _, latest_end = self._resolve_window_bounds(None, None) + current_start = latest_end - (window_delta * total_windows) + processed = 0 + + while current_start < latest_end: + current_end = current_start + window_delta + created = self.aggregate_all_components( + window_start=current_start, + window_end=current_end, + db=db, + ) + if created: + processed += 1 + current_start = current_end + + return processed + + finally: + if should_close: + db.close() + + @staticmethod + def _percentile(sorted_values: List[float], percentile: float) -> float: + """Calculate percentile from sorted values. 
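A hedged sketch exercising both methods above; the component and operation names in the output come straight from the stored metrics:

```python
# Sketch: surface slowdowns, then rebuild windows missed during downtime.
from mcpgateway.services.log_aggregator import get_log_aggregator

aggregator = get_log_aggregator()

# Alert when the last 24h average is 1.5x slower than the 24h before it.
for alert in aggregator.get_degradation_alerts(threshold_multiplier=1.5,
                                               hours=24):
    print(f"{alert['component']}.{alert['operation']}: "
          f"{alert['recent_avg_ms']:.1f}ms vs "
          f"{alert['baseline_avg_ms']:.1f}ms "
          f"({alert['degradation_ratio']:.2f}x)")

# After an outage, re-aggregate the last 6 hours of windows.
windows = aggregator.backfill(hours=6)
print(f"processed {windows} windows")
```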
+ + Args: + sorted_values: Sorted list of values + percentile: Percentile to calculate (0.0 to 1.0) + + Returns: + float: Calculated percentile value + """ + if not sorted_values: + return 0.0 + + if len(sorted_values) == 1: + return float(sorted_values[0]) + + k = (len(sorted_values) - 1) * percentile + f = math.floor(k) + c = math.ceil(k) + + if f == c: + return float(sorted_values[int(k)]) + + d0 = sorted_values[f] * (c - k) + d1 = sorted_values[c] * (k - f) + return float(d0 + d1) + + @staticmethod + def _calculate_error_count(entries: List[StructuredLogEntry]) -> int: + """Calculate error occurrences for a batch of log entries. + + Args: + entries: List of log entries to analyze + + Returns: + int: Count of error entries + """ + error_levels = {"ERROR", "CRITICAL"} + return sum(1 for entry in entries if (entry.level and entry.level.upper() in error_levels) or entry.error_details) + + def _resolve_window_bounds( + self, + window_start: Optional[datetime], + window_end: Optional[datetime], + ) -> Tuple[datetime, datetime]: + """Resolve and normalize aggregation window bounds. + + Args: + window_start: Start of window or None to calculate + window_end: End of window or None for current time + + Returns: + Tuple[datetime, datetime]: Resolved window start and end + """ + window_delta = timedelta(minutes=self.aggregation_window_minutes) + + if window_start is not None and window_end is not None: + resolved_start = window_start.astimezone(timezone.utc) + resolved_end = window_end.astimezone(timezone.utc) + if resolved_end <= resolved_start: + resolved_end = resolved_start + window_delta + return resolved_start, resolved_end + + if window_end is None: + reference = datetime.now(timezone.utc) + else: + reference = window_end.astimezone(timezone.utc) + + reference = reference.replace(second=0, microsecond=0) + minutes_offset = reference.minute % self.aggregation_window_minutes + if window_end is None and minutes_offset: + reference = reference - timedelta(minutes=minutes_offset) + + resolved_end = reference if window_end is None else reference + + if window_start is None: + resolved_start = resolved_end - window_delta + else: + resolved_start = window_start.astimezone(timezone.utc) + + if resolved_end <= resolved_start: + resolved_start = resolved_end - window_delta + + return resolved_start, resolved_end + + def _upsert_metric( + self, + component: str, + operation_type: str, + window_start: datetime, + window_end: datetime, + request_count: int, + error_count: int, + error_rate: float, + avg_duration_ms: float, + min_duration_ms: float, + max_duration_ms: float, + p50_duration_ms: float, + p95_duration_ms: float, + p99_duration_ms: float, + metric_metadata: Optional[Dict[str, Any]], + db: Session, + ) -> PerformanceMetric: + """Create or update a performance metric window. 
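`_percentile` linearly interpolates between the two nearest ranks. For ten samples, p95 gives k = 9 × 0.95 = 8.55, so the result sits 55% of the way from the 9th to the 10th value:

```python
# Worked example of the linear-interpolation percentile used above.
import math


def percentile(sorted_values, p):
    k = (len(sorted_values) - 1) * p
    f, c = math.floor(k), math.ceil(k)
    if f == c:
        return float(sorted_values[int(k)])
    return sorted_values[f] * (c - k) + sorted_values[c] * (k - f)


durations = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
# k = 8.55: interpolate 90 * 0.45 + 100 * 0.55 = 95.5
assert math.isclose(percentile(durations, 0.95), 95.5)
```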
+ + Args: + component: Component name + operation_type: Operation type + window_start: Window start time + window_end: Window end time + request_count: Total request count + error_count: Total error count + error_rate: Error rate (0.0-1.0) + avg_duration_ms: Average duration in milliseconds + min_duration_ms: Minimum duration in milliseconds + max_duration_ms: Maximum duration in milliseconds + p50_duration_ms: 50th percentile duration + p95_duration_ms: 95th percentile duration + p99_duration_ms: 99th percentile duration + metric_metadata: Additional metadata + db: Database session + + Returns: + PerformanceMetric: Created or updated metric + """ + + existing_stmt = select(PerformanceMetric).where( + and_( + PerformanceMetric.component == component, + PerformanceMetric.operation_type == operation_type, + PerformanceMetric.window_start == window_start, + PerformanceMetric.window_end == window_end, + ) + ) + + existing_metrics = db.execute(existing_stmt).scalars().all() + metric = existing_metrics[0] if existing_metrics else None + + if len(existing_metrics) > 1: + logger.warning( + "Found %s duplicate performance metric rows for %s.%s window %s-%s; pruning extras", + len(existing_metrics), + component, + operation_type, + window_start.isoformat(), + window_end.isoformat(), + ) + for duplicate in existing_metrics[1:]: + db.delete(duplicate) + + if metric is None: + metric = PerformanceMetric( + component=component, + operation_type=operation_type, + window_start=window_start, + window_end=window_end, + window_duration_seconds=int((window_end - window_start).total_seconds()), + ) + db.add(metric) + + metric.request_count = request_count + metric.error_count = error_count + metric.error_rate = error_rate + metric.avg_duration_ms = avg_duration_ms + metric.min_duration_ms = min_duration_ms + metric.max_duration_ms = max_duration_ms + metric.p50_duration_ms = p50_duration_ms + metric.p95_duration_ms = p95_duration_ms + metric.p99_duration_ms = p99_duration_ms + metric.metric_metadata = metric_metadata + + db.commit() + db.refresh(metric) + return metric + + +# Global log aggregator instance +_log_aggregator: Optional[LogAggregator] = None + + +def get_log_aggregator() -> LogAggregator: + """Get or create the global log aggregator instance. 
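Because `_upsert_metric` keys on `(component, operation_type, window_start, window_end)`, re-running the same window updates one row instead of inserting a duplicate. A hedged sketch (assumes the `PerformanceMetric` model exposes its primary key as `.id`):

```python
# Sketch: aggregation is idempotent per window.
from datetime import datetime, timedelta, timezone

from mcpgateway.services.log_aggregator import get_log_aggregator

aggregator = get_log_aggregator()
end = datetime.now(timezone.utc).replace(second=0, microsecond=0)
start = end - timedelta(minutes=aggregator.aggregation_window_minutes)

first = aggregator.aggregate_performance_metrics(
    "tool_service", "invoke_tool", window_start=start, window_end=end)
second = aggregator.aggregate_performance_metrics(
    "tool_service", "invoke_tool", window_start=start, window_end=end)
# If any samples existed, both calls touch the same row, not a duplicate.
if first and second:
    assert first.id == second.id
```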
+ + Returns: + Global LogAggregator instance + """ + global _log_aggregator # pylint: disable=global-statement + if _log_aggregator is None: + _log_aggregator = LogAggregator() + return _log_aggregator diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index 4f21111c0..f18f826f9 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -16,6 +16,7 @@ import logging from logging.handlers import RotatingFileHandler import os +import socket from typing import Any, AsyncGenerator, Dict, List, NotRequired, Optional, TextIO, TypedDict # Third-Party @@ -25,10 +26,18 @@ from mcpgateway.common.models import LogLevel from mcpgateway.config import settings from mcpgateway.services.log_storage_service import LogStorageService +from mcpgateway.utils.correlation_id import get_correlation_id + +# Optional OpenTelemetry support (Third-Party) +try: + # Third-Party + from opentelemetry import trace # type: ignore[import-untyped] +except ImportError: + trace = None # type: ignore[assignment] AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name try: - # Optional import; only used for filtering a known benign upstream error + # Optional import; only used for filtering a known benign upstream error (Third-Party) # Third-Party from anyio import ClosedResourceError as AnyioClosedResourceError # pylint: disable=invalid-name except Exception: # pragma: no cover - environment without anyio @@ -38,8 +47,52 @@ # Create a text formatter text_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -# Create a JSON formatter -json_formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s") + +class CorrelationIdJsonFormatter(jsonlogger.JsonFormatter): + """JSON formatter that includes correlation ID and OpenTelemetry trace context.""" + + def add_fields(self, log_record: dict, record: logging.LogRecord, message_dict: dict) -> None: # pylint: disable=arguments-renamed + """Add custom fields to the log record. 
+ + Args: + log_record: The dictionary that will be logged as JSON + record: The original LogRecord + message_dict: Additional message fields + + """ + super().add_fields(log_record, record, message_dict) + + # Add timestamp in ISO 8601 format with 'Z' suffix for UTC + dt = datetime.fromtimestamp(record.created, tz=timezone.utc) + log_record["@timestamp"] = dt.isoformat().replace("+00:00", "Z") + + # Add hostname and process ID for log aggregation + log_record["hostname"] = socket.gethostname() + log_record["process_id"] = os.getpid() + + # Add correlation ID from context + correlation_id = get_correlation_id() + if correlation_id: + log_record["request_id"] = correlation_id + + # Add OpenTelemetry trace context if available + if trace is not None: + try: + span = trace.get_current_span() + if span and span.is_recording(): + span_context = span.get_span_context() + if span_context.is_valid: + # Format trace_id and span_id as hex strings + log_record["trace_id"] = format(span_context.trace_id, "032x") + log_record["span_id"] = format(span_context.span_id, "016x") + log_record["trace_flags"] = format(span_context.trace_flags, "02x") + except Exception: # nosec B110 - intentionally catching all exceptions for optional tracing + # Error accessing span context, continue without trace fields + pass + + +# Create a JSON formatter with correlation ID support +json_formatter = CorrelationIdJsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s") # Note: Don't use basicConfig here as it conflicts with our custom dual logging setup # The LoggingService.initialize() method will properly configure all handlers diff --git a/mcpgateway/services/performance_tracker.py b/mcpgateway/services/performance_tracker.py new file mode 100644 index 000000000..dcf813979 --- /dev/null +++ b/mcpgateway/services/performance_tracker.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/performance_tracker.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Performance Tracking Service. + +This module provides performance tracking and analytics for all operations +across the MCP Gateway, enabling identification of bottlenecks and +optimization opportunities. +""" + +# Standard +from collections import defaultdict +from contextlib import contextmanager +import logging +import statistics +import time +from typing import Any, Dict, Generator, List, Optional + +# First-Party +from mcpgateway.config import settings +from mcpgateway.utils.correlation_id import get_correlation_id + +logger = logging.getLogger(__name__) + + +class PerformanceTracker: + """Tracks and analyzes performance metrics across requests. + + Provides context managers for tracking operation timing, + aggregation of metrics, and threshold-based alerting. 
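With the formatter above installed on a handler, each record carries the correlation and trace fields alongside the standard ones. A hedged wiring sketch; the emitted field values are illustrative:

```python
# Sketch: attach the correlation-aware JSON formatter to a handler.
import logging

from mcpgateway.services.logging_service import json_formatter

handler = logging.StreamHandler()
handler.setFormatter(json_formatter)  # CorrelationIdJsonFormatter from above
log = logging.getLogger("mcpgateway.demo")
log.addHandler(handler)
log.warning("tool invoked")
# Emits one JSON object per record, roughly (values illustrative):
# {"asctime": "...", "name": "mcpgateway.demo", "levelname": "WARNING",
#  "message": "tool invoked", "@timestamp": "2025-01-01T00:00:00Z",
#  "hostname": "gw-1", "process_id": 1234,
#  "request_id": "<correlation id, when bound>",
#  "trace_id": "<32 hex>", "span_id": "<16 hex>", "trace_flags": "01"}
```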
+ """ + + def __init__(self): + """Initialize performance tracker.""" + self.operation_timings: Dict[str, List[float]] = defaultdict(list) + + # Performance thresholds (seconds) from settings or defaults + self.performance_thresholds = { + "database_query": getattr(settings, "perf_threshold_database_query", 0.1), + "tool_invocation": getattr(settings, "perf_threshold_tool_invocation", 2.0), + "authentication": getattr(settings, "perf_threshold_authentication", 0.5), + "cache_operation": getattr(settings, "perf_threshold_cache_operation", 0.01), + "a2a_task": getattr(settings, "perf_threshold_a2a_task", 5.0), + "request_total": getattr(settings, "perf_threshold_request_total", 10.0), + "resource_fetch": getattr(settings, "perf_threshold_resource_fetch", 1.0), + "prompt_processing": getattr(settings, "perf_threshold_prompt_processing", 0.5), + } + + # Max buffer size per operation type + self.max_samples = getattr(settings, "perf_max_samples_per_operation", 1000) + + @contextmanager + def track_operation(self, operation_name: str, component: Optional[str] = None, log_slow: bool = True, extra_context: Optional[Dict[str, Any]] = None) -> Generator[None, None, None]: + """Context manager to track operation performance. + + Args: + operation_name: Name of the operation being tracked + component: Component/module name for context + log_slow: Whether to log operations exceeding thresholds + extra_context: Additional context to include in logs + + Yields: + None + + Raises: + Exception: Any exception from the tracked operation is re-raised + + Example: + >>> tracker = PerformanceTracker() + >>> with tracker.track_operation("database_query", component="tool_service"): + ... # Perform database operation + ... pass + """ + start_time = time.time() + correlation_id = get_correlation_id() + error_occurred = False + + try: + yield + except Exception: + error_occurred = True + raise + finally: + duration = time.time() - start_time + + # Record timing + self.operation_timings[operation_name].append(duration) + + # Limit buffer size + if len(self.operation_timings[operation_name]) > self.max_samples: + self.operation_timings[operation_name].pop(0) + + # Check threshold and log if needed + threshold = self.performance_thresholds.get(operation_name, float("inf")) + threshold_exceeded = duration > threshold + + if log_slow and threshold_exceeded: + context = { + "operation": operation_name, + "duration_ms": duration * 1000, + "threshold_ms": threshold * 1000, + "exceeded_by_ms": (duration - threshold) * 1000, + "component": component, + "correlation_id": correlation_id, + "error_occurred": error_occurred, + } + if extra_context: + context.update(extra_context) + + logger.warning(f"Slow operation detected: {operation_name} took {duration*1000:.2f}ms " f"(threshold: {threshold*1000:.2f}ms)", extra=context) + + def record_timing(self, operation_name: str, duration: float, component: Optional[str] = None, extra_context: Optional[Dict[str, Any]] = None) -> None: + """Manually record a timing measurement. 
+ + Args: + operation_name: Name of the operation + duration: Duration in seconds + component: Component/module name + extra_context: Additional context + """ + self.operation_timings[operation_name].append(duration) + + # Limit buffer size + if len(self.operation_timings[operation_name]) > self.max_samples: + self.operation_timings[operation_name].pop(0) + + # Check threshold + threshold = self.performance_thresholds.get(operation_name, float("inf")) + if duration > threshold: + context = { + "operation": operation_name, + "duration_ms": duration * 1000, + "threshold_ms": threshold * 1000, + "component": component, + "correlation_id": get_correlation_id(), + } + if extra_context: + context.update(extra_context) + + logger.warning(f"Slow operation: {operation_name} took {duration*1000:.2f}ms", extra=context) + + def get_performance_summary(self, operation_name: Optional[str] = None, min_samples: int = 1) -> Dict[str, Any]: + """Get performance summary for analytics. + + Args: + operation_name: Specific operation to summarize (None for all) + min_samples: Minimum samples required to include in summary + + Returns: + Dictionary containing performance statistics + + Example: + >>> tracker = PerformanceTracker() + >>> summary = tracker.get_performance_summary() + >>> isinstance(summary, dict) + True + """ + summary = {} + + operations = {operation_name: self.operation_timings[operation_name]} if operation_name and operation_name in self.operation_timings else self.operation_timings + + for op_name, timings in operations.items(): + if len(timings) < min_samples: + continue + + # Calculate percentiles + sorted_timings = sorted(timings) + count = len(sorted_timings) + + def percentile(p: float, *, sorted_vals=sorted_timings, n=count) -> float: + """Calculate percentile value. + + Args: + p: Percentile to calculate (0.0 to 1.0) + sorted_vals: Sorted list of values + n: Number of values + + Returns: + float: Calculated percentile value + """ + k = (n - 1) * p + f = int(k) + c = k - f + if f + 1 < n: + return sorted_vals[f] * (1 - c) + sorted_vals[f + 1] * c + return sorted_vals[f] + + summary[op_name] = { + "count": count, + "avg_duration_ms": statistics.mean(timings) * 1000, + "min_duration_ms": min(timings) * 1000, + "max_duration_ms": max(timings) * 1000, + "p50_duration_ms": percentile(0.5) * 1000, + "p95_duration_ms": percentile(0.95) * 1000, + "p99_duration_ms": percentile(0.99) * 1000, + "threshold_ms": self.performance_thresholds.get(op_name, float("inf")) * 1000, + "threshold_violations": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float("inf"))), + "violation_rate": sum(1 for t in timings if t > self.performance_thresholds.get(op_name, float("inf"))) / count, + } + + return summary + + def get_operation_stats(self, operation_name: str) -> Optional[Dict[str, Any]]: + """Get statistics for a specific operation. 
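Reading the aggregated view back, a hedged sketch:

```python
# Sketch: summarize every operation with at least 10 recorded samples.
from mcpgateway.services.performance_tracker import get_performance_tracker

tracker = get_performance_tracker()
for op, stats in tracker.get_performance_summary(min_samples=10).items():
    print(f"{op}: n={stats['count']} "
          f"p95={stats['p95_duration_ms']:.1f}ms "
          f"violations={stats['threshold_violations']} "
          f"({stats['violation_rate']:.1%})")
```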
+ + Args: + operation_name: Name of the operation + + Returns: + Statistics dictionary or None if no data + """ + if operation_name not in self.operation_timings: + return None + + timings = self.operation_timings[operation_name] + if not timings: + return None + + return { + "operation": operation_name, + "sample_count": len(timings), + "avg_duration_ms": statistics.mean(timings) * 1000, + "min_duration_ms": min(timings) * 1000, + "max_duration_ms": max(timings) * 1000, + "total_time_ms": sum(timings) * 1000, + "threshold_ms": self.performance_thresholds.get(operation_name, float("inf")) * 1000, + } + + def clear_stats(self, operation_name: Optional[str] = None) -> None: + """Clear performance statistics. + + Args: + operation_name: Specific operation to clear (None for all) + """ + if operation_name: + if operation_name in self.operation_timings: + self.operation_timings[operation_name].clear() + else: + self.operation_timings.clear() + + def set_threshold(self, operation_name: str, threshold_seconds: float) -> None: + """Set or update performance threshold for an operation. + + Args: + operation_name: Name of the operation + threshold_seconds: Threshold in seconds + """ + self.performance_thresholds[operation_name] = threshold_seconds + + def check_performance_degradation(self, operation_name: str, baseline_multiplier: float = 2.0) -> Dict[str, Any]: + """Check if performance has degraded compared to baseline. + + Args: + operation_name: Name of the operation to check + baseline_multiplier: Multiplier for degradation detection + + Returns: + Dictionary with degradation analysis + """ + if operation_name not in self.operation_timings: + return {"degraded": False, "reason": "no_data"} + + timings = self.operation_timings[operation_name] + if len(timings) < 10: + return {"degraded": False, "reason": "insufficient_samples"} + + # Compare recent timings to overall average + recent_count = min(10, len(timings)) + recent_timings = timings[-recent_count:] + historical_timings = timings[:-recent_count] if len(timings) > recent_count else timings + + if not historical_timings: + return {"degraded": False, "reason": "insufficient_historical_data"} + + recent_avg = statistics.mean(recent_timings) + historical_avg = statistics.mean(historical_timings) + + degraded = recent_avg > (historical_avg * baseline_multiplier) + + return { + "degraded": degraded, + "recent_avg_ms": recent_avg * 1000, + "historical_avg_ms": historical_avg * 1000, + "multiplier": recent_avg / historical_avg if historical_avg > 0 else 0, + "threshold_multiplier": baseline_multiplier, + } + + +# Global performance tracker instance +_performance_tracker: Optional[PerformanceTracker] = None + + +def get_performance_tracker() -> PerformanceTracker: + """Get or create the global performance tracker instance. 
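And a hedged sketch of the degradation check above, which compares the last ~10 samples against the earlier history:

```python
# Sketch: flag an operation whose recent average doubled vs. history.
from mcpgateway.services.performance_tracker import get_performance_tracker

tracker = get_performance_tracker()
result = tracker.check_performance_degradation("tool_invocation",
                                               baseline_multiplier=2.0)
if result["degraded"]:
    print(f"tool_invocation slowed to {result['recent_avg_ms']:.1f}ms "
          f"({result['multiplier']:.1f}x the historical average)")
```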
+ + Returns: + Global PerformanceTracker instance + """ + global _performance_tracker # pylint: disable=global-statement + if _performance_tracker is None: + _performance_tracker = PerformanceTracker() + return _performance_tracker diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index dac07f888..cd3841563 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -37,9 +37,11 @@ from mcpgateway.observability import create_span from mcpgateway.plugins.framework import GlobalContext, PluginContextTable, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.observability_service import current_trace_id, ObservabilityService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -48,6 +50,10 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger and audit trail for prompt operations +structured_logger = get_structured_logger("prompt_service") +audit_trail = get_audit_trail_service() + class PromptError(Exception): """Base class for prompt-related errors.""" @@ -401,18 +407,95 @@ async def register_prompt( await self._notify_prompt_added(db_prompt) logger.info(f"Registered prompt: {prompt.name}") + + # Structured logging: Audit trail for prompt creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_prompt", + resource_type="prompt", + resource_id=str(db_prompt.id), + resource_name=db_prompt.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "name": db_prompt.name, + "visibility": visibility, + }, + context={ + "created_via": created_via, + "import_batch_id": import_batch_id, + "federation_source": federation_source, + }, + db=db, + ) + + # Structured logging: Log successful prompt creation + structured_logger.log( + level="INFO", + message="Prompt created successfully", + event_type="prompt_created", + component="prompt_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="prompt", + resource_id=str(db_prompt.id), + custom_fields={ + "prompt_name": db_prompt.name, + "visibility": visibility, + }, + db=db, + ) + db_prompt.team = self._get_team_name(db, db_prompt.team_id) prompt_dict = self._convert_db_prompt(db_prompt) return PromptRead.model_validate(prompt_dict) except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") + + structured_logger.log( + level="ERROR", + message="Prompt creation failed due to database integrity error", + event_type="prompt_creation_failed", + component="prompt_service", + user_id=created_by, + user_email=owner_email, + error=ie, + custom_fields={"prompt_name": prompt.name}, + db=db, + ) raise ie except PromptNameConflictError as se: db.rollback() + + structured_logger.log( + level="WARNING", + message="Prompt creation failed due to name conflict", + event_type="prompt_name_conflict", + component="prompt_service", + 
user_id=created_by, + user_email=owner_email, + custom_fields={"prompt_name": prompt.name, "visibility": visibility}, + db=db, + ) raise se except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Prompt creation failed", + event_type="prompt_creation_failed", + component="prompt_service", + user_id=created_by, + user_email=owner_email, + error=e, + custom_fields={"prompt_name": prompt.name}, + db=db, + ) raise PromptError(f"Failed to register prompt: {str(e)}") async def list_prompts(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None) -> tuple[List[PromptRead], Optional[str]]: @@ -826,6 +909,43 @@ async def get_prompt( # Use modified payload if provided result = post_result.modified_payload.result if post_result.modified_payload else result + arguments_supplied = bool(arguments) + + audit_trail.log_action( + user_id=user or "anonymous", + action="view_prompt", + resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + team_id=prompt.team_id, + context={ + "tenant_id": tenant_id, + "server_id": server_id, + "arguments_provided": arguments_supplied, + "request_id": request_id, + }, + db=db, + ) + + structured_logger.log( + level="INFO", + message="Prompt retrieved successfully", + event_type="prompt_viewed", + component="prompt_service", + user_id=user, + team_id=prompt.team_id, + resource_type="prompt", + resource_id=str(prompt.id), + request_id=request_id, + custom_fields={ + "prompt_name": prompt.name, + "arguments_provided": arguments_supplied, + "tenant_id": tenant_id, + "server_id": server_id, + }, + db=db, + ) + # Set success attributes on span if span: span.set_attribute("success", True) @@ -990,26 +1110,117 @@ async def update_prompt( db.refresh(prompt) await self._notify_prompt_updated(prompt) + + # Structured logging: Audit trail for prompt update + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_prompt", + resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + user_email=user_email, + team_id=prompt.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={"name": prompt.name, "version": prompt.version}, + context={"modified_via": modified_via}, + db=db, + ) + + structured_logger.log( + level="INFO", + message="Prompt updated successfully", + event_type="prompt_updated", + component="prompt_service", + user_id=modified_by, + user_email=user_email, + team_id=prompt.team_id, + resource_type="prompt", + resource_id=str(prompt.id), + custom_fields={"prompt_name": prompt.name, "version": prompt.version}, + db=db, + ) + prompt.team = self._get_team_name(db, prompt.team_id) return PromptRead.model_validate(self._convert_db_prompt(prompt)) - except PermissionError: + except PermissionError as pe: db.rollback() + + structured_logger.log( + level="WARNING", + message="Prompt update failed due to permission error", + event_type="prompt_update_permission_denied", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=pe, + db=db, + ) raise except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + structured_logger.log( + level="ERROR", + message="Prompt update failed due to database integrity error", + event_type="prompt_update_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=ie, + db=db, + 
) raise ie except PromptNotFoundError as e: db.rollback() logger.error(f"Prompt not found: {e}") + + structured_logger.log( + level="ERROR", + message="Prompt update failed - prompt not found", + event_type="prompt_not_found", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise e except PromptNameConflictError as pnce: db.rollback() logger.error(f"Prompt name conflict: {pnce}") + + structured_logger.log( + level="WARNING", + message="Prompt update failed due to name conflict", + event_type="prompt_name_conflict", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=pnce, + db=db, + ) raise pnce except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Prompt update failed", + event_type="prompt_update_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise PromptError(f"Failed to update prompt: {str(e)}") async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool, user_email: Optional[str] = None) -> PromptRead: @@ -1071,12 +1282,63 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool else: await self._notify_prompt_deactivated(prompt) logger.info(f"Prompt {prompt.name} {'activated' if activate else 'deactivated'}") + + # Structured logging: Audit trail for prompt status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_prompt_status", + resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + user_email=user_email, + team_id=prompt.team_id, + new_values={"enabled": prompt.enabled}, + context={"action": "activate" if activate else "deactivate"}, + db=db, + ) + + structured_logger.log( + level="INFO", + message=f"Prompt {'activated' if activate else 'deactivated'} successfully", + event_type="prompt_status_toggled", + component="prompt_service", + user_email=user_email, + team_id=prompt.team_id, + resource_type="prompt", + resource_id=str(prompt.id), + custom_fields={"prompt_name": prompt.name, "enabled": prompt.enabled}, + db=db, + ) + prompt.team = self._get_team_name(db, prompt.team_id) return PromptRead.model_validate(self._convert_db_prompt(prompt)) except PermissionError as e: + structured_logger.log( + level="WARNING", + message="Prompt status toggle failed due to permission error", + event_type="prompt_toggle_permission_denied", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + structured_logger.log( + level="ERROR", + message="Prompt status toggle failed", + event_type="prompt_toggle_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise PromptError(f"Failed to toggle prompt status: {str(e)}") # Get prompt details for admin ui @@ -1113,7 +1375,35 @@ async def get_prompt_details(self, db: Session, prompt_id: Union[int, str], incl raise PromptNotFoundError(f"Prompt not found: {prompt_id}") # Return the fully converted prompt including metrics prompt.team = self._get_team_name(db, prompt.team_id) - return self._convert_db_prompt(prompt) + prompt_data = self._convert_db_prompt(prompt) + + audit_trail.log_action( + user_id="system", + action="view_prompt_details", + 
resource_type="prompt", + resource_id=str(prompt.id), + resource_name=prompt.name, + team_id=prompt.team_id, + context={"include_inactive": include_inactive}, + db=db, + ) + + structured_logger.log( + level="INFO", + message="Prompt details retrieved", + event_type="prompt_details_viewed", + component="prompt_service", + resource_type="prompt", + resource_id=str(prompt.id), + team_id=prompt.team_id, + custom_fields={ + "prompt_name": prompt.name, + "include_inactive": include_inactive, + }, + db=db, + ) + + return prompt_data async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_email: Optional[str] = None) -> None: """ @@ -1161,17 +1451,85 @@ async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_emai raise PermissionError("Only the owner can delete this prompt") prompt_info = {"id": prompt.id, "name": prompt.name} + prompt_name = prompt.name + prompt_team_id = prompt.team_id + db.delete(prompt) db.commit() await self._notify_prompt_deleted(prompt_info) logger.info(f"Deleted prompt: {prompt_info['name']}") - except PermissionError: + + # Structured logging: Audit trail for prompt deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_prompt", + resource_type="prompt", + resource_id=str(prompt_info["id"]), + resource_name=prompt_name, + user_email=user_email, + team_id=prompt_team_id, + old_values={"name": prompt_name}, + db=db, + ) + + # Structured logging: Log successful prompt deletion + structured_logger.log( + level="INFO", + message="Prompt deleted successfully", + event_type="prompt_deleted", + component="prompt_service", + user_email=user_email, + team_id=prompt_team_id, + resource_type="prompt", + resource_id=str(prompt_info["id"]), + custom_fields={"prompt_name": prompt_name}, + db=db, + ) + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Prompt deletion failed due to permission error", + event_type="prompt_delete_permission_denied", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=pe, + db=db, + ) raise except Exception as e: db.rollback() if isinstance(e, PromptNotFoundError): + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Prompt deletion failed - prompt not found", + event_type="prompt_not_found", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise e + + # Structured logging: Log generic prompt deletion failure + structured_logger.log( + level="ERROR", + message="Prompt deletion failed", + event_type="prompt_deletion_failed", + component="prompt_service", + user_email=user_email, + resource_type="prompt", + resource_id=str(prompt_id), + error=e, + db=db, + ) raise PromptError(f"Failed to delete prompt: {str(e)}") async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]: diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 2b0ef30b5..1b8136a51 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -51,10 +51,12 @@ from mcpgateway.db import server_resource_association from mcpgateway.observability import create_span from mcpgateway.schemas import ResourceCreate, ResourceMetrics, ResourceRead, ResourceSubscription, ResourceUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import 
get_audit_trail_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.observability_service import current_trace_id, ObservabilityService +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor from mcpgateway.utils.services_auth import decode_auth @@ -74,6 +76,10 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize structured logger and audit trail for resource operations +structured_logger = get_structured_logger("resource_service") +audit_trail = get_audit_trail_service() + class ResourceError(Exception): """Base class for resource-related errors.""" @@ -240,6 +246,17 @@ def _convert_resource_to_read(self, resource: DbResource) -> ResourceRead: resource_dict.pop("_sa_instance_state", None) resource_dict.pop("metrics", None) + # Ensure required base fields are present even if SQLAlchemy hasn't loaded them into __dict__ yet + resource_dict["id"] = getattr(resource, "id", resource_dict.get("id")) + resource_dict["uri"] = getattr(resource, "uri", resource_dict.get("uri")) + resource_dict["name"] = getattr(resource, "name", resource_dict.get("name")) + resource_dict["description"] = getattr(resource, "description", resource_dict.get("description")) + resource_dict["mime_type"] = getattr(resource, "mime_type", resource_dict.get("mime_type")) + resource_dict["size"] = getattr(resource, "size", resource_dict.get("size")) + resource_dict["created_at"] = getattr(resource, "created_at", resource_dict.get("created_at")) + resource_dict["updated_at"] = getattr(resource, "updated_at", resource_dict.get("updated_at")) + resource_dict["is_active"] = getattr(resource, "is_active", resource_dict.get("is_active")) + # Compute aggregated metrics from the resource's metrics list. 
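+        # A resource with no recorded invocations yields total == 0 below, so
+        # the success counter falls back to 0 rather than iterating metrics.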
total = len(resource.metrics) if hasattr(resource, "metrics") and resource.metrics is not None else 0 successful = sum(1 for m in resource.metrics if m.is_success) if total > 0 else 0 @@ -397,16 +414,106 @@ async def register_resource( await self._notify_resource_added(db_resource) logger.info(f"Registered resource: {resource.uri}") + + # Structured logging: Audit trail for resource creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_resource", + resource_type="resource", + resource_id=str(db_resource.id), + resource_name=db_resource.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "uri": db_resource.uri, + "name": db_resource.name, + "visibility": visibility, + "mime_type": db_resource.mime_type, + }, + context={ + "created_via": created_via, + "import_batch_id": import_batch_id, + "federation_source": federation_source, + }, + db=db, + ) + + # Structured logging: Log successful resource creation + structured_logger.log( + level="INFO", + message="Resource created successfully", + event_type="resource_created", + component="resource_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="resource", + resource_id=str(db_resource.id), + custom_fields={ + "resource_uri": db_resource.uri, + "resource_name": db_resource.name, + "visibility": visibility, + }, + db=db, + ) + db_resource.team = self._get_team_name(db, db_resource.team_id) return self._convert_resource_to_read(db_resource) except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Resource creation failed due to database integrity error", + event_type="resource_creation_failed", + component="resource_service", + user_id=created_by, + user_email=owner_email, + error=ie, + custom_fields={ + "resource_uri": resource.uri, + }, + db=db, + ) raise ie except ResourceURIConflictError as rce: logger.error(f"ResourceURIConflictError in group: {resource.uri}") + + # Structured logging: Log URI conflict error + structured_logger.log( + level="WARNING", + message="Resource creation failed due to URI conflict", + event_type="resource_uri_conflict", + component="resource_service", + user_id=created_by, + user_email=owner_email, + custom_fields={ + "resource_uri": resource.uri, + "visibility": visibility, + }, + db=db, + ) raise rce except Exception as e: db.rollback() + + # Structured logging: Log generic resource creation failure + structured_logger.log( + level="ERROR", + message="Resource creation failed", + event_type="resource_creation_failed", + component="resource_service", + user_id=created_by, + user_email=owner_email, + error=e, + custom_fields={ + "resource_uri": resource.uri, + }, + db=db, + ) raise ResourceError(f"Failed to register resource: {str(e)}") async def list_resources(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None, tags: Optional[List[str]] = None) -> tuple[List[ResourceRead], Optional[str]]: @@ -1463,12 +1570,72 @@ async def toggle_resource_status(self, db: Session, resource_id: int, activate: logger.info(f"Resource {resource.uri} {'activated' if activate else 'deactivated'}") + # Structured logging: Audit trail for resource status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_resource_status", + resource_type="resource", + resource_id=str(resource.id), + 
resource_name=resource.name, + user_email=user_email, + team_id=resource.team_id, + new_values={ + "enabled": resource.enabled, + }, + context={ + "action": "activate" if activate else "deactivate", + }, + db=db, + ) + + # Structured logging: Log successful resource status toggle + structured_logger.log( + level="INFO", + message=f"Resource {'activated' if activate else 'deactivated'} successfully", + event_type="resource_status_toggled", + component="resource_service", + user_email=user_email, + team_id=resource.team_id, + resource_type="resource", + resource_id=str(resource.id), + custom_fields={ + "resource_uri": resource.uri, + "enabled": resource.enabled, + }, + db=db, + ) + resource.team = self._get_team_name(db, resource.team_id) return self._convert_resource_to_read(resource) except PermissionError as e: + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Resource status toggle failed due to permission error", + event_type="resource_toggle_permission_denied", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic resource status toggle failure + structured_logger.log( + level="ERROR", + message="Resource status toggle failed", + event_type="resource_toggle_failed", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise ResourceError(f"Failed to toggle resource status: {str(e)}") async def subscribe_resource(self, db: Session, subscription: ResourceSubscription) -> None: @@ -1684,21 +1851,138 @@ async def update_resource( await self._notify_resource_updated(resource) logger.info(f"Updated resource: {resource.uri}") + + # Structured logging: Audit trail for resource update + changes = [] + if resource_update.uri: + changes.append(f"uri: {resource_update.uri}") + if resource_update.visibility: + changes.append(f"visibility: {resource_update.visibility}") + if resource_update.description: + changes.append("description updated") + + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_resource", + resource_type="resource", + resource_id=str(resource.id), + resource_name=resource.name, + user_email=user_email, + team_id=resource.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={ + "uri": resource.uri, + "name": resource.name, + "version": resource.version, + }, + context={ + "modified_via": modified_via, + "changes": ", ".join(changes) if changes else "metadata only", + }, + db=db, + ) + + # Structured logging: Log successful resource update + structured_logger.log( + level="INFO", + message="Resource updated successfully", + event_type="resource_updated", + component="resource_service", + user_id=modified_by, + user_email=user_email, + team_id=resource.team_id, + resource_type="resource", + resource_id=str(resource.id), + custom_fields={ + "resource_uri": resource.uri, + "version": resource.version, + }, + db=db, + ) + return self._convert_resource_to_read(resource) - except PermissionError: + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Resource update failed due to permission error", + event_type="resource_update_permission_denied", + component="resource_service", + user_email=user_email, + 
resource_type="resource", + resource_id=str(resource_id), + error=pe, + db=db, + ) raise except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Resource update failed due to database integrity error", + event_type="resource_update_failed", + component="resource_service", + user_id=modified_by, + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=ie, + db=db, + ) raise ie except ResourceURIConflictError as pe: logger.error(f"Resource URI conflict: {pe}") + + # Structured logging: Log URI conflict error + structured_logger.log( + level="WARNING", + message="Resource update failed due to URI conflict", + event_type="resource_uri_conflict", + component="resource_service", + user_id=modified_by, + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=pe, + db=db, + ) raise pe except Exception as e: db.rollback() if isinstance(e, ResourceNotFoundError): + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Resource update failed - resource not found", + event_type="resource_not_found", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise e + + # Structured logging: Log generic resource update failure + structured_logger.log( + level="ERROR", + message="Resource update failed", + event_type="resource_update_failed", + component="resource_service", + user_id=modified_by, + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise ResourceError(f"Failed to update resource: {str(e)}") async def delete_resource(self, db: Session, resource_id: Union[int, str], user_email: Optional[str] = None) -> None: @@ -1757,6 +2041,10 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ db.execute(delete(DbSubscription).where(DbSubscription.resource_id == resource.id)) # Hard delete the resource. 
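+            # Capture identifying fields before delete()/commit(): the ORM
+            # instance is expired on commit, so reading these attributes
+            # afterwards could trigger a refresh against a deleted row.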
+ resource_uri = resource.uri + resource_name = resource.name + resource_team_id = resource.team_id + db.delete(resource) db.commit() @@ -1765,14 +2053,84 @@ async def delete_resource(self, db: Session, resource_id: Union[int, str], user_ logger.info(f"Permanently deleted resource: {resource.uri}") - except PermissionError: + # Structured logging: Audit trail for resource deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_resource", + resource_type="resource", + resource_id=str(resource_info["id"]), + resource_name=resource_name, + user_email=user_email, + team_id=resource_team_id, + old_values={ + "uri": resource_uri, + "name": resource_name, + }, + db=db, + ) + + # Structured logging: Log successful resource deletion + structured_logger.log( + level="INFO", + message="Resource deleted successfully", + event_type="resource_deleted", + component="resource_service", + user_email=user_email, + team_id=resource_team_id, + resource_type="resource", + resource_id=str(resource_info["id"]), + custom_fields={ + "resource_uri": resource_uri, + }, + db=db, + ) + + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Resource deletion failed due to permission error", + event_type="resource_delete_permission_denied", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=pe, + db=db, + ) raise - except ResourceNotFoundError: + except ResourceNotFoundError as rnfe: # ResourceNotFoundError is re-raised to be handled in the endpoint. + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Resource deletion failed - resource not found", + event_type="resource_not_found", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=rnfe, + db=db, + ) raise except Exception as e: db.rollback() + + # Structured logging: Log generic resource deletion failure + structured_logger.log( + level="ERROR", + message="Resource deletion failed", + event_type="resource_deletion_failed", + component="resource_service", + user_email=user_email, + resource_type="resource", + resource_id=str(resource_id), + error=e, + db=db, + ) raise ResourceError(f"Failed to delete resource: {str(e)}") async def get_resource_by_id(self, db: Session, resource_id: str, include_inactive: bool = False) -> ResourceRead: @@ -1819,7 +2177,24 @@ async def get_resource_by_id(self, db: Session, resource_id: str, include_inacti raise ResourceNotFoundError(f"Resource not found: {resource_id}") - return self._convert_resource_to_read(resource) + resource_read = self._convert_resource_to_read(resource) + + structured_logger.log( + level="INFO", + message="Resource retrieved successfully", + event_type="resource_viewed", + component="resource_service", + team_id=getattr(resource, "team_id", None), + resource_type="resource", + resource_id=str(resource.id), + custom_fields={ + "resource_uri": resource.uri, + "include_inactive": include_inactive, + }, + db=db, + ) + + return resource_read async def _notify_resource_activated(self, resource: DbResource) -> None: """ diff --git a/mcpgateway/services/security_logger.py b/mcpgateway/services/security_logger.py new file mode 100644 index 000000000..1b2470691 --- /dev/null +++ b/mcpgateway/services/security_logger.py @@ -0,0 +1,597 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/security_logger.py 
+Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Security Logger Service. + +This module provides specialized logging for security events, threat detection, +and audit trail management with automated threat analysis and alerting. +""" + +# Standard +from datetime import datetime, timedelta, timezone +from enum import Enum +import logging +from typing import Any, Dict, Optional + +# Third-Party +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import AuditTrail, SecurityEvent, SessionLocal +from mcpgateway.utils.correlation_id import get_correlation_id + +logger = logging.getLogger(__name__) + + +class SecuritySeverity(str, Enum): + """Security event severity levels.""" + + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + CRITICAL = "CRITICAL" + + +class SecurityEventType(str, Enum): + """Types of security events.""" + + AUTHENTICATION_FAILURE = "authentication_failure" + AUTHENTICATION_SUCCESS = "authentication_success" + AUTHORIZATION_FAILURE = "authorization_failure" + SUSPICIOUS_ACTIVITY = "suspicious_activity" + RATE_LIMIT_EXCEEDED = "rate_limit_exceeded" + BRUTE_FORCE_ATTEMPT = "brute_force_attempt" + TOKEN_MANIPULATION = "token_manipulation" # nosec B105 - Not a password, security event type constant + DATA_EXFILTRATION = "data_exfiltration" + PRIVILEGE_ESCALATION = "privilege_escalation" + INJECTION_ATTEMPT = "injection_attempt" + ANOMALOUS_BEHAVIOR = "anomalous_behavior" + + +class SecurityLogger: + """Specialized logger for security events and audit trails. + + Provides threat detection, security event logging, and audit trail + management with automated analysis and alerting capabilities. + """ + + def __init__(self): + """Initialize security logger.""" + self.failed_auth_threshold = getattr(settings, "security_failed_auth_threshold", 5) + self.threat_score_alert_threshold = getattr(settings, "security_threat_score_alert", 0.7) + self.rate_limit_window_minutes = getattr(settings, "security_rate_limit_window", 5) + + def log_authentication_attempt( + self, + user_id: str, + user_email: Optional[str], + auth_method: str, + success: bool, + client_ip: str, + user_agent: Optional[str] = None, + failure_reason: Optional[str] = None, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[SecurityEvent]: + """Log authentication attempts with security analysis. 
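+
+        Example (an illustrative sketch: the persistence helpers are mocked so
+        no database is touched, and all identifiers are placeholders):
+
+            >>> from unittest.mock import MagicMock
+            >>> sec_logger = SecurityLogger()
+            >>> sec_logger._count_recent_failures = MagicMock(return_value=0)
+            >>> sec_logger._create_security_event = MagicMock(return_value=None)
+            >>> sec_logger.log_authentication_attempt(
+            ...     user_id="alice", user_email="alice@example.com", auth_method="basic",
+            ...     success=False, client_ip="203.0.113.7", failure_reason="bad password",
+            ... ) is None
+            True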
+ + Args: + user_id: User identifier + user_email: User email address + auth_method: Authentication method used + success: Whether authentication succeeded + client_ip: Client IP address + user_agent: Client user agent + failure_reason: Reason for failure if applicable + additional_context: Additional event context + db: Optional database session + + Returns: + Created SecurityEvent or None if logging disabled + """ + correlation_id = get_correlation_id() + + # Count recent failed attempts + failed_attempts = self._count_recent_failures(user_id=user_id, client_ip=client_ip, db=db) + + # Calculate threat score + threat_score = self._calculate_auth_threat_score(success=success, failed_attempts=failed_attempts, auth_method=auth_method) + + # Determine severity + if not success: + if failed_attempts >= self.failed_auth_threshold: + severity = SecuritySeverity.HIGH + elif failed_attempts >= 3: + severity = SecuritySeverity.MEDIUM + else: + severity = SecuritySeverity.LOW + else: + severity = SecuritySeverity.LOW + + # Build event description + description = f"Authentication {'successful' if success else 'failed'} for user {user_id}" + if not success and failure_reason: + description += f": {failure_reason}" + + # Build context + context = {"auth_method": auth_method, "failed_attempts_recent": failed_attempts, "user_agent": user_agent, **(additional_context or {})} + + # Create security event + event = self._create_security_event( + event_type=SecurityEventType.AUTHENTICATION_SUCCESS if success else SecurityEventType.AUTHENTICATION_FAILURE, + severity=severity, + category="authentication", + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + threat_score=threat_score, + failed_attempts_count=failed_attempts, + context=context, + action_taken="allowed" if success else "denied", + correlation_id=correlation_id, + db=db, + ) + + # Log to standard logger as well + log_level = logging.WARNING if not success else logging.INFO + logger.log( + log_level, + f"Authentication attempt: {description}", + extra={ + "security_event": True, + "event_type": event.event_type if event else None, + "severity": severity.value, + "threat_score": threat_score, + "correlation_id": correlation_id, + }, + ) + + return event + + def log_data_access( # pylint: disable=too-many-positional-arguments + self, + action: str, + resource_type: str, + resource_id: str, + resource_name: Optional[str], + user_id: str, + user_email: Optional[str], + team_id: Optional[str], + client_ip: Optional[str], + user_agent: Optional[str], + success: bool, + data_classification: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + error_message: Optional[str] = None, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[AuditTrail]: + """Log data access for audit trails. 
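+
+        Example (an illustrative sketch: ``_create_audit_trail`` is mocked so
+        the computed change set can be inspected without a database):
+
+            >>> from unittest.mock import MagicMock
+            >>> sec_logger = SecurityLogger()
+            >>> sec_logger._create_audit_trail = MagicMock(return_value=None)
+            >>> _ = sec_logger.log_data_access(
+            ...     action="update", resource_type="prompt", resource_id="42",
+            ...     resource_name="greeting", user_id="alice", user_email=None,
+            ...     team_id=None, client_ip="203.0.113.7", user_agent=None,
+            ...     success=True, old_values={"name": "old"}, new_values={"name": "new"},
+            ... )
+            >>> sec_logger._create_audit_trail.call_args.kwargs["changes"]
+            {'name': {'old': 'old', 'new': 'new'}}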
+ + Args: + action: Action performed (create, read, update, delete, execute) + resource_type: Type of resource accessed + resource_id: Resource identifier + resource_name: Resource name + user_id: User performing the action + user_email: User email + team_id: Team context + client_ip: Client IP address + user_agent: Client user agent + success: Whether action succeeded + data_classification: Data sensitivity classification + old_values: Previous values (for updates) + new_values: New values (for updates/creates) + error_message: Error message if failed + additional_context: Additional context + db: Optional database session + + Returns: + Created AuditTrail entry or None + """ + correlation_id = get_correlation_id() + + # Determine if audit requires review + requires_review = self._requires_audit_review(action=action, resource_type=resource_type, data_classification=data_classification, success=success) + + # Calculate changes + changes = None + if old_values and new_values: + changes = {k: {"old": old_values.get(k), "new": new_values.get(k)} for k in set(old_values.keys()) | set(new_values.keys()) if old_values.get(k) != new_values.get(k)} + + # Create audit trail + audit = self._create_audit_trail( + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + success=success, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + error_message=error_message, + context=additional_context, + correlation_id=correlation_id, + db=db, + ) + + # Log sensitive data access as security event + if data_classification in ["confidential", "restricted", "sensitive"]: + self._create_security_event( + event_type="data_access", + severity=SecuritySeverity.MEDIUM if success else SecuritySeverity.HIGH, + category="data_access", + user_id=user_id, + user_email=user_email, + client_ip=client_ip or "unknown", + user_agent=user_agent, + description=f"Access to {data_classification} {resource_type}: {resource_name or resource_id}", + threat_score=0.3 if success else 0.6, + context={ + "action": action, + "resource_type": resource_type, + "resource_id": resource_id, + "data_classification": data_classification, + }, + correlation_id=correlation_id, + db=db, + ) + + return audit + + def log_suspicious_activity( + self, + activity_type: str, + description: str, + user_id: Optional[str], + user_email: Optional[str], + client_ip: str, + user_agent: Optional[str], + threat_score: float, + severity: SecuritySeverity, + threat_indicators: Dict[str, Any], + action_taken: str, + additional_context: Optional[Dict[str, Any]] = None, + db: Optional[Session] = None, + ) -> Optional[SecurityEvent]: + """Log suspicious activity with threat analysis. 
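+
+        Example (an illustrative sketch with mocked persistence; the indicator
+        payload and action are arbitrary placeholders):
+
+            >>> from unittest.mock import MagicMock
+            >>> sec_logger = SecurityLogger()
+            >>> sec_logger._create_security_event = MagicMock(return_value=None)
+            >>> sec_logger.log_suspicious_activity(
+            ...     activity_type="scanner", description="Burst of 404 responses",
+            ...     user_id=None, user_email=None, client_ip="198.51.100.9",
+            ...     user_agent=None, threat_score=0.8, severity=SecuritySeverity.HIGH,
+            ...     threat_indicators={"not_found_count": 120}, action_taken="rate_limited",
+            ... ) is None
+            True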
+ + Args: + activity_type: Type of suspicious activity + description: Event description + user_id: User identifier (if known) + user_email: User email (if known) + client_ip: Client IP address + user_agent: Client user agent + threat_score: Calculated threat score (0.0-1.0) + severity: Event severity + threat_indicators: Dictionary of threat indicators + action_taken: Action taken in response + additional_context: Additional context + db: Optional database session + + Returns: + Created SecurityEvent or None + """ + correlation_id = get_correlation_id() + + event = self._create_security_event( + event_type=SecurityEventType.SUSPICIOUS_ACTIVITY, + severity=severity, + category="suspicious_activity", + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + threat_score=threat_score, + threat_indicators=threat_indicators, + action_taken=action_taken, + context=additional_context, + correlation_id=correlation_id, + db=db, + ) + + logger.warning( + f"Suspicious activity detected: {description}", + extra={ + "security_event": True, + "activity_type": activity_type, + "severity": severity.value, + "threat_score": threat_score, + "action_taken": action_taken, + "correlation_id": correlation_id, + }, + ) + + return event + + def _count_recent_failures(self, user_id: Optional[str] = None, client_ip: Optional[str] = None, minutes: Optional[int] = None, db: Optional[Session] = None) -> int: + """Count recent authentication failures. + + Args: + user_id: User identifier + client_ip: Client IP address + minutes: Time window in minutes + db: Optional database session + + Returns: + Count of recent failures + """ + if not user_id and not client_ip: + return 0 + + window_minutes = minutes or self.rate_limit_window_minutes + since = datetime.now(timezone.utc) - timedelta(minutes=window_minutes) + + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + stmt = select(func.count(SecurityEvent.id)).where(SecurityEvent.event_type == SecurityEventType.AUTHENTICATION_FAILURE, SecurityEvent.timestamp >= since) # pylint: disable=not-callable + + if user_id: + stmt = stmt.where(SecurityEvent.user_id == user_id) + if client_ip: + stmt = stmt.where(SecurityEvent.client_ip == client_ip) + + result = db.execute(stmt).scalar() + return result or 0 + + finally: + if should_close: + db.close() + + def _calculate_auth_threat_score(self, success: bool, failed_attempts: int, auth_method: str) -> float: # pylint: disable=unused-argument + """Calculate threat score for authentication attempt. + + Args: + success: Whether authentication succeeded + failed_attempts: Count of recent failures + auth_method: Authentication method used + + Returns: + Threat score from 0.0 to 1.0 + """ + if success: + return 0.0 + + # Base score for failure + score = 0.3 + + # Increase based on failed attempts + if failed_attempts >= 10: + score += 0.5 + elif failed_attempts >= 5: + score += 0.3 + elif failed_attempts >= 3: + score += 0.2 + + # Cap at 1.0 + return min(score, 1.0) + + def _requires_audit_review(self, action: str, resource_type: str, data_classification: Optional[str], success: bool) -> bool: + """Determine if audit entry requires manual review. 
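+
+        Example (pure rule evaluation; safe to run without a database):
+
+            >>> sec_logger = SecurityLogger()
+            >>> sec_logger._requires_audit_review("delete", "prompt", "restricted", True)
+            True
+            >>> sec_logger._requires_audit_review("read", "prompt", None, True)
+            False
+            >>> sec_logger._requires_audit_review("update", "role", None, True)
+            True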
+ + Args: + action: Action performed + resource_type: Resource type + data_classification: Data classification + success: Whether action succeeded + + Returns: + True if review required + """ + # Failed actions on sensitive data require review + if not success and data_classification in ["confidential", "restricted"]: + return True + + # Deletions of sensitive data require review + if action == "delete" and data_classification in ["confidential", "restricted"]: + return True + + # Privilege modifications require review + if resource_type in ["role", "permission", "team_member"]: + return True + + return False + + def _create_security_event( + self, + event_type: str, + severity: SecuritySeverity, + category: str, + client_ip: str, + description: str, + threat_score: float, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + user_agent: Optional[str] = None, + action_taken: Optional[str] = None, + failed_attempts_count: int = 0, + threat_indicators: Optional[Dict[str, Any]] = None, + context: Optional[Dict[str, Any]] = None, + correlation_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> Optional[SecurityEvent]: + """Create a security event record. + + Args: + event_type: Type of security event + severity: Event severity + category: Event category + client_ip: Client IP address + description: Event description + threat_score: Threat score (0.0-1.0) + user_id: User identifier + user_email: User email + user_agent: User agent string + action_taken: Action taken + failed_attempts_count: Failed attempts count + threat_indicators: Threat indicators + context: Additional context + correlation_id: Correlation ID + db: Optional database session + + Returns: + Created SecurityEvent or None + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + event = SecurityEvent( + event_type=event_type, + severity=severity.value, + category=category, + user_id=user_id, + user_email=user_email, + client_ip=client_ip, + user_agent=user_agent, + description=description, + action_taken=action_taken, + threat_score=threat_score, + threat_indicators=threat_indicators or {}, + failed_attempts_count=failed_attempts_count, + context=context, + correlation_id=correlation_id, + ) + + db.add(event) + db.commit() + db.refresh(event) + + return event + + except Exception as e: + logger.error(f"Failed to create security event: {e}") + db.rollback() + return None + + finally: + if should_close: + db.close() + + def _create_audit_trail( # pylint: disable=too-many-positional-arguments + self, + action: str, + resource_type: str, + user_id: str, + success: bool, + resource_id: Optional[str] = None, + resource_name: Optional[str] = None, + user_email: Optional[str] = None, + team_id: Optional[str] = None, + client_ip: Optional[str] = None, + user_agent: Optional[str] = None, + old_values: Optional[Dict[str, Any]] = None, + new_values: Optional[Dict[str, Any]] = None, + changes: Optional[Dict[str, Any]] = None, + data_classification: Optional[str] = None, + requires_review: bool = False, + error_message: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + correlation_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> Optional[AuditTrail]: + """Create an audit trail record. 
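+
+        Note: when ``db`` is omitted, a short-lived session is opened from
+        ``SessionLocal`` and closed after the write, and failures are logged
+        and rolled back instead of raised, so audit persistence never breaks
+        the calling request path.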
+ + Args: + action: Action performed + resource_type: Resource type + user_id: User performing action + success: Whether action succeeded + resource_id: Resource identifier + resource_name: Resource name + user_email: User email + team_id: Team context + client_ip: Client IP + user_agent: User agent + old_values: Previous values + new_values: New values + changes: Calculated changes + data_classification: Data classification + requires_review: Whether manual review needed + error_message: Error message if failed + context: Additional context + correlation_id: Correlation ID + db: Optional database session + + Returns: + Created AuditTrail or None + """ + should_close = False + if db is None: + db = SessionLocal() + should_close = True + + try: + audit = AuditTrail( + action=action, + resource_type=resource_type, + resource_id=resource_id, + resource_name=resource_name, + user_id=user_id, + user_email=user_email, + team_id=team_id, + client_ip=client_ip, + user_agent=user_agent, + old_values=old_values, + new_values=new_values, + changes=changes, + data_classification=data_classification, + requires_review=requires_review, + success=success, + error_message=error_message, + context=context, + correlation_id=correlation_id, + ) + + db.add(audit) + db.commit() + db.refresh(audit) + + return audit + + except Exception as e: + logger.error(f"Failed to create audit trail: {e}") + db.rollback() + return None + + finally: + if should_close: + db.close() + + +# Global security logger instance +_security_logger: Optional[SecurityLogger] = None + + +def get_security_logger() -> SecurityLogger: + """Get or create the global security logger instance. + + Returns: + Global SecurityLogger instance + """ + global _security_logger # pylint: disable=global-statement + if _security_logger is None: + _security_logger = SecurityLogger() + return _security_logger diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index e7f8aae4d..01f20b304 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -33,7 +33,10 @@ from mcpgateway.db import ServerMetric from mcpgateway.db import Tool as DbTool from mcpgateway.schemas import ServerCreate, ServerMetrics, ServerRead, ServerUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.performance_tracker import get_performance_tracker +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -130,6 +133,9 @@ def __init__(self) -> None: """ self._event_subscribers: List[asyncio.Queue] = [] self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) + self._structured_logger = get_structured_logger("server_service") + self._audit_trail = get_audit_trail_service() + self._performance_tracker = get_performance_tracker() async def initialize(self) -> None: """Initialize the server service.""" @@ -394,7 +400,7 @@ async def register_server( Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = 
ServerService() >>> db = MagicMock() @@ -406,6 +412,8 @@ async def register_server( >>> db.refresh = MagicMock() >>> service._notify_server_added = AsyncMock() >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> import asyncio >>> asyncio.run(service.register_server(db, server_in)) @@ -549,17 +557,91 @@ async def register_server( logger.debug(f"Server Data: {server_data}") await self._notify_server_added(db_server) logger.info(f"Registered server: {server_in.name}") + + # Structured logging: Audit trail for server creation + self._audit_trail.log_action( + user_id=created_by or "system", + action="create_server", + resource_type="server", + resource_id=db_server.id, + details={ + "server_name": db_server.name, + "visibility": visibility, + "team_id": team_id, + "associated_tools_count": len(db_server.tools), + "associated_resources_count": len(db_server.resources), + "associated_prompts_count": len(db_server.prompts), + "associated_a2a_agents_count": len(db_server.a2a_agents), + }, + metadata={ + "created_from_ip": created_from_ip, + "created_via": created_via, + "created_user_agent": created_user_agent, + }, + ) + + # Structured logging: Log successful server creation + self._structured_logger.log( + level="INFO", + message="Server created successfully", + event_type="server_created", + component="server_service", + server_id=db_server.id, + server_name=db_server.name, + visibility=visibility, + created_by=created_by, + user_email=created_by, + ) + db_server.team = self._get_team_name(db, db_server.team_id) return self._convert_server_to_read(db_server) except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + self._structured_logger.log( + level="ERROR", + message="Server creation failed due to database integrity error", + event_type="server_creation_failed", + component="server_service", + server_name=server_in.name, + error_type="IntegrityError", + error_message=str(ie), + created_by=created_by, + user_email=created_by, + ) raise ie except ServerNameConflictError as se: db.rollback() + + # Structured logging: Log name conflict error + self._structured_logger.log( + level="WARNING", + message="Server creation failed due to name conflict", + event_type="server_name_conflict", + component="server_service", + server_name=server_in.name, + visibility=visibility, + created_by=created_by, + user_email=created_by, + ) raise se except Exception as ex: db.rollback() + + # Structured logging: Log generic server creation failure + self._structured_logger.log( + level="ERROR", + message="Server creation failed", + event_type="server_creation_failed", + component="server_service", + server_name=server_in.name, + error_type=type(ex).__name__, + error_message=str(ex), + created_by=created_by, + user_email=created_by, + ) raise ServerError(f"Failed to register server: {str(ex)}") async def list_servers(self, db: Session, include_inactive: bool = False, tags: Optional[List[str]] = None) -> List[ServerRead]: @@ -731,7 +813,39 @@ async def get_server(self, db: Session, server_id: str) -> ServerRead: } logger.debug(f"Server Data: {server_data}") server.team = self._get_team_name(db, server.team_id) if server else None - return 
self._convert_server_to_read(server) + server_read = self._convert_server_to_read(server) + + self._structured_logger.log( + level="INFO", + message="Server retrieved successfully", + event_type="server_viewed", + component="server_service", + server_id=server.id, + server_name=server.name, + team_id=getattr(server, "team_id", None), + resource_type="server", + resource_id=server.id, + custom_fields={ + "enabled": server.enabled, + "tool_count": len(getattr(server, "tools", []) or []), + "resource_count": len(getattr(server, "resources", []) or []), + "prompt_count": len(getattr(server, "prompts", []) or []), + }, + db=db, + ) + + self._audit_trail.log_action( + action="view_server", + resource_type="server", + resource_id=server.id, + resource_name=server.name, + user_id="system", + team_id=getattr(server, "team_id", None), + context={"enabled": server.enabled}, + db=db, + ) + + return server_read async def update_server( self, @@ -769,7 +883,7 @@ async def update_server( Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = ServerService() >>> db = MagicMock() @@ -783,6 +897,8 @@ async def update_server( >>> db.refresh = MagicMock() >>> db.execute.return_value.scalar_one_or_none.return_value = None >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> server_update = MagicMock() >>> server_update.id = None # No UUID change @@ -927,6 +1043,44 @@ async def update_server( await self._notify_server_updated(server) logger.info(f"Updated server: {server.name}") + # Structured logging: Audit trail for server update + changes = [] + if server_update.name: + changes.append(f"name: {server_update.name}") + if server_update.visibility: + changes.append(f"visibility: {server_update.visibility}") + if server_update.team_id: + changes.append(f"team_id: {server_update.team_id}") + + self._audit_trail.log_action( + user_id=user_email or "system", + action="update_server", + resource_type="server", + resource_id=server.id, + details={ + "server_name": server.name, + "changes": ", ".join(changes) if changes else "metadata only", + "version": server.version, + }, + metadata={ + "modified_from_ip": modified_from_ip, + "modified_via": modified_via, + "modified_user_agent": modified_user_agent, + }, + ) + + # Structured logging: Log successful server update + self._structured_logger.log( + level="INFO", + message="Server updated successfully", + event_type="server_updated", + component="server_service", + server_id=server.id, + server_name=server.name, + modified_by=user_email, + user_email=user_email, + ) + # Build a dictionary with associated IDs server_data = { "id": server.id, @@ -946,13 +1100,50 @@ async def update_server( except IntegrityError as ie: db.rollback() logger.error(f"IntegrityErrors in group: {ie}") + + # Structured logging: Log database integrity error + self._structured_logger.log( + level="ERROR", + message="Server update failed due to database integrity error", + event_type="server_update_failed", + component="server_service", + server_id=server_id, + error_type="IntegrityError", + error_message=str(ie), + 
modified_by=user_email, + user_email=user_email, + ) raise ie except ServerNameConflictError as snce: db.rollback() logger.error(f"Server name conflict: {snce}") + + # Structured logging: Log name conflict error + self._structured_logger.log( + level="WARNING", + message="Server update failed due to name conflict", + event_type="server_name_conflict", + component="server_service", + server_id=server_id, + modified_by=user_email, + user_email=user_email, + ) raise snce except Exception as e: db.rollback() + + # Structured logging: Log generic server update failure + self._structured_logger.log( + level="ERROR", + message="Server update failed", + event_type="server_update_failed", + component="server_service", + server_id=server_id, + error_type=type(e).__name__, + error_message=str(e), + modified_by=user_email, + user_email=user_email, + ) raise ServerError(f"Failed to update server: {str(e)}") async def toggle_server_status(self, db: Session, server_id: str, activate: bool, user_email: Optional[str] = None) -> ServerRead: @@ -974,7 +1165,7 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> from mcpgateway.schemas import ServerRead >>> service = ServerService() >>> db = MagicMock() @@ -985,6 +1176,8 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool >>> service._notify_server_activated = AsyncMock() >>> service._notify_server_deactivated = AsyncMock() >>> service._convert_server_to_read = MagicMock(return_value='server_read') + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> ServerRead.model_validate = MagicMock(return_value='server_read') >>> import asyncio >>> asyncio.run(service.toggle_server_status(db, 'server_id', True)) @@ -1014,6 +1207,31 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool await self._notify_server_deactivated(server) logger.info(f"Server {server.name} {'activated' if activate else 'deactivated'}") + # Structured logging: Audit trail for server status toggle + self._audit_trail.log_action( + user_id=user_email or "system", + action="activate_server" if activate else "deactivate_server", + resource_type="server", + resource_id=server.id, + details={ + "server_name": server.name, + "new_status": "active" if activate else "inactive", + }, + ) + + # Structured logging: Log server status change + self._structured_logger.log( + level="INFO", + message=f"Server {'activated' if activate else 'deactivated'}", + event_type="server_status_changed", + component="server_service", + server_id=server.id, + server_name=server.name, + new_status="active" if activate else "inactive", + changed_by=user_email, + user_email=user_email, + ) + server_data = { "id": server.id, "name": server.name, @@ -1030,9 +1248,30 @@ async def toggle_server_status(self, db: Session, server_id: str, activate: bool logger.info(f"Server Data: {server_data}") return self._convert_server_to_read(server) except PermissionError as e: + # Structured logging: Log permission error + self._structured_logger.log( + level="WARNING", + message="Server status toggle failed due to insufficient permissions", + event_type="server_status_toggle_permission_denied", + component="server_service", + 
server_id=server_id, + user_email=user_email, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic server status toggle failure + self._structured_logger.log( + level="ERROR", + message="Server status toggle failed", + event_type="server_status_toggle_failed", + component="server_service", + server_id=server_id, + error_type=type(e).__name__, + error_message=str(e), + user_email=user_email, + ) raise ServerError(f"Failed to toggle server status: {str(e)}") async def delete_server(self, db: Session, server_id: str, user_email: Optional[str] = None) -> None: @@ -1050,7 +1289,7 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ Examples: >>> from mcpgateway.services.server_service import ServerService - >>> from unittest.mock import MagicMock, AsyncMock + >>> from unittest.mock import MagicMock, AsyncMock, patch >>> service = ServerService() >>> db = MagicMock() >>> server = MagicMock() @@ -1058,6 +1297,8 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ >>> db.delete = MagicMock() >>> db.commit = MagicMock() >>> service._notify_server_deleted = AsyncMock() + >>> service._structured_logger = MagicMock() # Mock structured logger to prevent database writes + >>> service._audit_trail = MagicMock() # Mock audit trail to prevent database writes >>> import asyncio >>> asyncio.run(service.delete_server(db, 'server_id', 'user@example.com')) """ @@ -1081,11 +1322,56 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[ await self._notify_server_deleted(server_info) logger.info(f"Deleted server: {server_info['name']}") - except PermissionError: + + # Structured logging: Audit trail for server deletion + self._audit_trail.log_action( + user_id=user_email or "system", + action="delete_server", + resource_type="server", + resource_id=server_info["id"], + details={ + "server_name": server_info["name"], + }, + ) + + # Structured logging: Log successful server deletion + self._structured_logger.log( + level="INFO", + message="Server deleted successfully", + event_type="server_deleted", + component="server_service", + server_id=server_info["id"], + server_name=server_info["name"], + deleted_by=user_email, + user_email=user_email, + ) + except PermissionError as pe: db.rollback() - raise + + # Structured logging: Log permission error + self._structured_logger.log( + level="WARNING", + message="Server deletion failed due to insufficient permissions", + event_type="server_deletion_permission_denied", + component="server_service", + server_id=server_id, + user_email=user_email, + ) + raise pe except Exception as e: db.rollback() + + # Structured logging: Log generic server deletion failure + self._structured_logger.log( + level="ERROR", + message="Server deletion failed", + event_type="server_deletion_failed", + component="server_service", + server_id=server_id, + error_type=type(e).__name__, + error_message=str(e), + user_email=user_email, + ) raise ServerError(f"Failed to delete server: {str(e)}") async def _publish_event(self, event: Dict[str, Any]) -> None: diff --git a/mcpgateway/services/structured_logger.py b/mcpgateway/services/structured_logger.py new file mode 100644 index 000000000..0d8a4a599 --- /dev/null +++ b/mcpgateway/services/structured_logger.py @@ -0,0 +1,441 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/structured_logger.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Structured Logger Service. 
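+
+Typical usage (illustrative; the component name and fields are placeholders,
+and additional keyword arguments mirror the service call sites):
+
+    structured_logger = get_structured_logger("my_service")
+    structured_logger.log(
+        level="INFO",
+        message="Something happened",
+        event_type="example_event",
+        component="my_service",
+    )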
+
+This module provides comprehensive structured logging with component-based loggers,
+automatic enrichment, intelligent routing, and database persistence.
+"""
+
+# Standard
+from datetime import datetime, timezone
+from enum import Enum
+import logging
+import os
+import socket
+import traceback
+from typing import Any, Dict, List, Optional, Union
+
+# Third-Party
+from sqlalchemy.orm import Session
+
+# First-Party
+from mcpgateway.config import settings
+from mcpgateway.db import SessionLocal, StructuredLogEntry
+from mcpgateway.services.performance_tracker import get_performance_tracker
+from mcpgateway.utils.correlation_id import get_correlation_id
+
+logger = logging.getLogger(__name__)
+
+
+class LogLevel(str, Enum):
+    """Log levels matching Python logging."""
+
+    DEBUG = "DEBUG"
+    INFO = "INFO"
+    WARNING = "WARNING"
+    ERROR = "ERROR"
+    CRITICAL = "CRITICAL"
+
+
+class LogCategory(str, Enum):
+    """Log categories for classification."""
+
+    APPLICATION = "application"
+    REQUEST = "request"
+    SECURITY = "security"
+    PERFORMANCE = "performance"
+    DATABASE = "database"
+    AUTHENTICATION = "authentication"
+    AUTHORIZATION = "authorization"
+    EXTERNAL_SERVICE = "external_service"
+    BUSINESS_LOGIC = "business_logic"
+    SYSTEM = "system"
+
+
+class LogEnricher:
+    """Enriches log entries with contextual information."""
+
+    @staticmethod
+    def enrich(entry: Dict[str, Any]) -> Dict[str, Any]:
+        """Enrich log entry with system and context information.
+
+        Args:
+            entry: Base log entry
+
+        Returns:
+            Enriched log entry
+        """
+        # Get correlation ID
+        correlation_id = get_correlation_id()
+        if correlation_id:
+            entry["correlation_id"] = correlation_id
+
+        # Add hostname and process info
+        entry.setdefault("hostname", socket.gethostname())
+        entry.setdefault("process_id", os.getpid())
+
+        # Add timestamp if not present
+        if "timestamp" not in entry:
+            entry["timestamp"] = datetime.now(timezone.utc)
+
+        # Add performance metrics if available
+        try:
+            perf_tracker = get_performance_tracker()
+            if correlation_id and perf_tracker and hasattr(perf_tracker, "get_current_operations"):
+                current_ops = perf_tracker.get_current_operations(correlation_id)  # pylint: disable=no-member
+                if current_ops:
+                    entry["active_operations"] = len(current_ops)
+        except Exception:  # nosec B110 - Graceful degradation if performance tracker unavailable
+            # Silently skip if performance tracker is unavailable or method doesn't exist
+            pass
+
+        # Add OpenTelemetry trace context if available
+        try:
+            # Third-Party
+            from opentelemetry import trace  # pylint: disable=import-outside-toplevel
+
+            span = trace.get_current_span()
+            if span and span.get_span_context().is_valid:
+                ctx = span.get_span_context()
+                entry["trace_id"] = format(ctx.trace_id, "032x")
+                entry["span_id"] = format(ctx.span_id, "016x")
+        except Exception:  # ImportError is a subclass of Exception; one clause covers missing otel and span errors
+            pass
+
+        return entry
+
+
+class LogRouter:
+    """Routes log entries to appropriate destinations."""
+
+    def __init__(self):
+        """Initialize log router."""
+        self.database_enabled = getattr(settings, "structured_logging_database_enabled", True)
+        self.external_enabled = getattr(settings, "structured_logging_external_enabled", False)
+
+    def route(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None:
+        """Route log entry to configured destinations.
+
+        Args:
+            entry: Log entry to route
+            db: Optional database session
+        """
+        # Always log to standard Python logger
+        self._log_to_python_logger(entry)
+
+        # Persist to database if enabled
+        if self.database_enabled:
+            self._persist_to_database(entry, db)
+
+        # Send to external systems if enabled
+        if self.external_enabled:
+            self._send_to_external(entry)
+
+    def _log_to_python_logger(self, entry: Dict[str, Any]) -> None:
+        """Log to standard Python logger.
+
+        Args:
+            entry: Log entry
+        """
+        level_str = entry.get("level", "INFO")
+        level = getattr(logging, level_str, logging.INFO)
+
+        message = entry.get("message", "")
+        component = entry.get("component", "")
+
+        log_message = f"[{component}] {message}" if component else message
+
+        # Build extra dict for structured logging
+        extra = {k: v for k, v in entry.items() if k not in ["message", "level"]}
+
+        logger.log(level, log_message, extra=extra)
+
+    def _persist_to_database(self, entry: Dict[str, Any], db: Optional[Session] = None) -> None:
+        """Persist log entry to database.
+
+        Args:
+            entry: Log entry
+            db: Optional database session
+        """
+        should_close = False
+        if db is None:
+            db = SessionLocal()
+            should_close = True
+
+        try:
+            # Build error_details JSON from error-related fields
+            error_details = None
+            if any([entry.get("error_type"), entry.get("error_message"), entry.get("error_stack_trace"), entry.get("error_context")]):
+                error_details = {
+                    "error_type": entry.get("error_type"),
+                    "error_message": entry.get("error_message"),
+                    "error_stack_trace": entry.get("error_stack_trace"),
+                    "error_context": entry.get("error_context"),
+                }
+
+            # Build performance_metrics JSON from performance-related fields
+            performance_metrics = None
+            perf_fields = {
+                "database_query_count": entry.get("database_query_count"),
+                "database_query_duration_ms": entry.get("database_query_duration_ms"),
+                "cache_hits": entry.get("cache_hits"),
+                "cache_misses": entry.get("cache_misses"),
+                "external_api_calls": entry.get("external_api_calls"),
+                "external_api_duration_ms": entry.get("external_api_duration_ms"),
+                "memory_usage_mb": entry.get("memory_usage_mb"),
+                "cpu_usage_percent": entry.get("cpu_usage_percent"),
+            }
+            if any(v is not None for v in perf_fields.values()):
+                performance_metrics = {k: v for k, v in perf_fields.items() if v is not None}
+
+            # Build threat_indicators JSON from security-related fields
+            threat_indicators = None
+            security_fields = {
+                "security_event_type": entry.get("security_event_type"),
+                "security_threat_score": entry.get("security_threat_score"),
+                "security_action_taken": entry.get("security_action_taken"),
+            }
+            if any(v is not None for v in security_fields.values()):
+                threat_indicators = {k: v for k, v in security_fields.items() if v is not None}
+
+            # Build context JSON from remaining fields
+            context_fields = {
+                "team_id": entry.get("team_id"),
+                "request_query": entry.get("request_query"),
+                "request_headers": entry.get("request_headers"),
+                "request_body_size": entry.get("request_body_size"),
+                "response_status_code": entry.get("response_status_code"),
+                "response_body_size": entry.get("response_body_size"),
+                "response_headers": entry.get("response_headers"),
+                "business_event_type": entry.get("business_event_type"),
+                "business_entity_type": entry.get("business_entity_type"),
+                "business_entity_id": entry.get("business_entity_id"),
+                "resource_type": entry.get("resource_type"),
+                "resource_id": entry.get("resource_id"),
+                "resource_action": entry.get("resource_action"),
+                "category": entry.get("category"),
+                "custom_fields": entry.get("custom_fields"),
+                "tags": entry.get("tags"),
+                "metadata": entry.get("metadata"),
+            }
+            context = {k: v for k, v in context_fields.items() if v is not None}
+
+            # Determine if this is a security event
+            is_security_event = entry.get("is_security_event", False) or bool(threat_indicators)
+            security_severity = entry.get("security_severity")
+
+            log_entry = StructuredLogEntry(
+                timestamp=entry.get("timestamp", datetime.now(timezone.utc)),
+                level=entry.get("level", "INFO"),
+                component=entry.get("component"),
+                message=entry.get("message", ""),
+                correlation_id=entry.get("correlation_id"),
+                request_id=entry.get("request_id"),
+                trace_id=entry.get("trace_id"),
+                span_id=entry.get("span_id"),
+                user_id=entry.get("user_id"),
+                user_email=entry.get("user_email"),
+                client_ip=entry.get("client_ip"),
+                user_agent=entry.get("user_agent"),
+                request_method=entry.get("request_method"),
+                request_path=entry.get("request_path"),
+                duration_ms=entry.get("duration_ms"),
+                operation_type=entry.get("operation_type"),
+                is_security_event=is_security_event,
+                security_severity=security_severity,
+                threat_indicators=threat_indicators,
+                context=context if context else None,
+                error_details=error_details,
+                performance_metrics=performance_metrics,
+                hostname=entry.get("hostname"),
+                process_id=entry.get("process_id"),
+                thread_id=entry.get("thread_id"),
+                environment=entry.get("environment", getattr(settings, "environment", "development")),
+                version=entry.get("version", getattr(settings, "version", "unknown")),
+            )
+
+            db.add(log_entry)
+            db.commit()
+
+        except Exception as e:
+            logger.error(f"Failed to persist log entry to database: {e}", exc_info=True)
+            # Also print to console for immediate visibility
+            print(f"ERROR persisting log to database: {e}")
+            traceback.print_exc()
+            if db:
+                db.rollback()
+
+        finally:
+            if should_close:
+                db.close()
+
+    def _send_to_external(self, entry: Dict[str, Any]) -> None:
+        """Send log entry to external systems.
+
+        Args:
+            entry: Log entry
+        """
+        # Placeholder for external logging integration
+        # Will be implemented in log exporters
+
+
+class StructuredLogger:
+    """Main structured logger with enrichment and routing."""
+
+    def __init__(self, component: str):
+        """Initialize structured logger.
+
+        Args:
+            component: Component name for log entries
+        """
+        self.component = component
+        self.enricher = LogEnricher()
+        self.router = LogRouter()
+
+    def log(
+        self,
+        level: Union[LogLevel, str],
+        message: str,
+        category: Optional[Union[LogCategory, str]] = None,
+        user_id: Optional[str] = None,
+        user_email: Optional[str] = None,
+        team_id: Optional[str] = None,
+        error: Optional[Exception] = None,
+        duration_ms: Optional[float] = None,
+        custom_fields: Optional[Dict[str, Any]] = None,
+        tags: Optional[List[str]] = None,
+        db: Optional[Session] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Log a structured message.
+
+        Args:
+            level: Log level
+            message: Log message
+            category: Log category
+            user_id: User identifier
+            user_email: User email
+            team_id: Team identifier
+            error: Exception object
+            duration_ms: Operation duration
+            custom_fields: Additional custom fields
+            tags: Log tags
+            db: Optional database session
+            **kwargs: Additional fields to include
+        """
+        # Build base entry
+        entry: Dict[str, Any] = {
+            "level": level.value if isinstance(level, LogLevel) else level,
+            "component": self.component,
+            "message": message,
+            "category": category.value if isinstance(category, LogCategory) else category,
+            "user_id": user_id,
+            "user_email": user_email,
+            "team_id": team_id,
+            "duration_ms": duration_ms,
+            "custom_fields": custom_fields,
+            "tags": tags,
+        }
+
+        # Add error information if present
+        if error:
+            entry["error_type"] = type(error).__name__
+            entry["error_message"] = str(error)
+            entry["error_stack_trace"] = "".join(traceback.format_exception(type(error), error, error.__traceback__))
+
+        # Add any additional kwargs
+        entry.update(kwargs)
+
+        # Enrich entry with context
+        entry = self.enricher.enrich(entry)
+
+        # Route to destinations
+        self.router.route(entry, db)
+
+    def debug(self, message: str, **kwargs: Any) -> None:
+        """Log debug message.
+
+        Args:
+            message: Log message
+            **kwargs: Additional context fields
+        """
+        self.log(LogLevel.DEBUG, message, **kwargs)
+
+    def info(self, message: str, **kwargs: Any) -> None:
+        """Log info message.
+
+        Args:
+            message: Log message
+            **kwargs: Additional context fields
+        """
+        self.log(LogLevel.INFO, message, **kwargs)
+
+    def warning(self, message: str, **kwargs: Any) -> None:
+        """Log warning message.
+
+        Args:
+            message: Log message
+            **kwargs: Additional context fields
+        """
+        self.log(LogLevel.WARNING, message, **kwargs)
+
+    def error(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None:
+        """Log error message.
+
+        Args:
+            message: Log message
+            error: Exception object if available
+            **kwargs: Additional context fields
+        """
+        self.log(LogLevel.ERROR, message, error=error, **kwargs)
+
+    def critical(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None:
+        """Log critical message.
+
+        Args:
+            message: Log message
+            error: Exception object if available
+            **kwargs: Additional context fields
+        """
+        self.log(LogLevel.CRITICAL, message, error=error, **kwargs)
+
+
+class ComponentLogger:
+    """Logger factory for component-specific loggers."""
+
+    _loggers: Dict[str, StructuredLogger] = {}
+
+    @classmethod
+    def get_logger(cls, component: str) -> StructuredLogger:
+        """Get or create a logger for a specific component.
+
+        Args:
+            component: Component name
+
+        Returns:
+            StructuredLogger instance for the component
+        """
+        if component not in cls._loggers:
+            cls._loggers[component] = StructuredLogger(component)
+        return cls._loggers[component]
+
+    @classmethod
+    def clear_loggers(cls) -> None:
+        """Clear all cached loggers (useful for testing)."""
+        cls._loggers.clear()
+
+
+# Global structured logger instance for backward compatibility
+def get_structured_logger(component: str = "mcpgateway") -> StructuredLogger:
+    """Get a structured logger instance.
+ + Args: + component: Component name + + Returns: + StructuredLogger instance + """ + return ComponentLogger.get_logger(component) diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 731b67e0a..5616e0ff4 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -63,10 +63,14 @@ ) from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer +from mcpgateway.services.audit_trail_service import get_audit_trail_service from mcpgateway.services.event_service import EventService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager +from mcpgateway.services.performance_tracker import get_performance_tracker +from mcpgateway.services.structured_logger import get_structured_logger from mcpgateway.services.team_management_service import TeamManagementService +from mcpgateway.utils.correlation_id import get_correlation_id from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.display_name import generate_display_name from mcpgateway.utils.metrics_common import build_top_performers @@ -81,6 +85,11 @@ logging_service = LoggingService() logger = logging_service.get_logger(__name__) +# Initialize performance tracker, structured logger, and audit trail for tool operations +perf_tracker = get_performance_tracker() +structured_logger = get_structured_logger("tool_service") +audit_trail = get_audit_trail_service() + def extract_using_jq(data, jq_filter=""): """ @@ -710,17 +719,109 @@ async def register_tool( db.commit() db.refresh(db_tool) await self._notify_tool_added(db_tool) + + # Structured logging: Audit trail for tool creation + audit_trail.log_action( + user_id=created_by or "system", + action="create_tool", + resource_type="tool", + resource_id=db_tool.id, + resource_name=db_tool.name, + user_email=owner_email, + team_id=team_id, + client_ip=created_from_ip, + user_agent=created_user_agent, + new_values={ + "name": db_tool.name, + "display_name": db_tool.display_name, + "visibility": visibility, + "integration_type": db_tool.integration_type, + }, + context={ + "created_via": created_via, + "import_batch_id": import_batch_id, + "federation_source": federation_source, + }, + db=db, + ) + + # Structured logging: Log successful tool creation + structured_logger.log( + level="INFO", + message="Tool created successfully", + event_type="tool_created", + component="tool_service", + user_id=created_by, + user_email=owner_email, + team_id=team_id, + resource_type="tool", + resource_id=db_tool.id, + custom_fields={ + "tool_name": db_tool.name, + "visibility": visibility, + "integration_type": db_tool.integration_type, + }, + db=db, + ) + + # Refresh db_tool after logging commits (they expire the session objects) + db.refresh(db_tool) return self._convert_tool_to_read(db_tool) except IntegrityError as ie: db.rollback() logger.error(f"IntegrityError during tool registration: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Tool creation failed due to database integrity error", + event_type="tool_creation_failed", + component="tool_service", + user_id=created_by, + user_email=owner_email, + error=ie, + custom_fields={ + "tool_name": tool.name, + }, + db=db, + ) raise ie except ToolNameConflictError as tnce: db.rollback() logger.error(f"ToolNameConflictError during tool registration: 
{tnce}") + + # Structured logging: Log name conflict error + structured_logger.log( + level="WARNING", + message="Tool creation failed due to name conflict", + event_type="tool_name_conflict", + component="tool_service", + user_id=created_by, + user_email=owner_email, + custom_fields={ + "tool_name": tool.name, + "visibility": visibility, + }, + db=db, + ) raise tnce except Exception as e: db.rollback() + + # Structured logging: Log generic tool creation failure + structured_logger.log( + level="ERROR", + message="Tool creation failed", + event_type="tool_creation_failed", + component="tool_service", + user_id=created_by, + user_email=owner_email, + error=e, + custom_fields={ + "tool_name": tool.name, + }, + db=db, + ) raise ToolError(f"Failed to register tool: {str(e)}") async def list_tools( @@ -1009,7 +1110,25 @@ async def get_tool(self, db: Session, tool_id: str) -> ToolRead: if not tool: raise ToolNotFoundError(f"Tool not found: {tool_id}") tool.team = self._get_team_name(db, getattr(tool, "team_id", None)) - return self._convert_tool_to_read(tool) + + tool_read = self._convert_tool_to_read(tool) + + structured_logger.log( + level="INFO", + message="Tool retrieved successfully", + event_type="tool_viewed", + component="tool_service", + team_id=getattr(tool, "team_id", None), + resource_type="tool", + resource_id=str(tool.id), + custom_fields={ + "tool_name": tool.name, + "include_metrics": bool(getattr(tool_read, "metrics", {})), + }, + db=db, + ) + + return tool_read async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] = None) -> None: """ @@ -1053,15 +1172,75 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] raise PermissionError("Only the owner can delete this tool") tool_info = {"id": tool.id, "name": tool.name} + tool_name = tool.name + tool_team_id = tool.team_id + db.delete(tool) db.commit() await self._notify_tool_deleted(tool_info) logger.info(f"Permanently deleted tool: {tool_info['name']}") - except PermissionError: + + # Structured logging: Audit trail for tool deletion + audit_trail.log_action( + user_id=user_email or "system", + action="delete_tool", + resource_type="tool", + resource_id=tool_info["id"], + resource_name=tool_name, + user_email=user_email, + team_id=tool_team_id, + old_values={ + "name": tool_name, + }, + db=db, + ) + + # Structured logging: Log successful tool deletion + structured_logger.log( + level="INFO", + message="Tool deleted successfully", + event_type="tool_deleted", + component="tool_service", + user_email=user_email, + team_id=tool_team_id, + resource_type="tool", + resource_id=tool_info["id"], + custom_fields={ + "tool_name": tool_name, + }, + db=db, + ) + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Tool deletion failed due to permission error", + event_type="tool_delete_permission_denied", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=pe, + db=db, + ) raise except Exception as e: db.rollback() + + # Structured logging: Log generic tool deletion failure + structured_logger.log( + level="ERROR", + message="Tool deletion failed", + event_type="tool_deletion_failed", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=e, + db=db, + ) raise ToolError(f"Failed to delete tool: {str(e)}") async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, 
reachable: bool, user_email: Optional[str] = None) -> ToolRead: @@ -1140,11 +1319,74 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re await self._notify_tool_activated(tool) logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}") + + # Structured logging: Audit trail for tool status toggle + audit_trail.log_action( + user_id=user_email or "system", + action="toggle_tool_status", + resource_type="tool", + resource_id=tool.id, + resource_name=tool.name, + user_email=user_email, + team_id=tool.team_id, + new_values={ + "enabled": tool.enabled, + "reachable": tool.reachable, + }, + context={ + "action": "activate" if activate else "deactivate", + }, + db=db, + ) + + # Structured logging: Log successful tool status toggle + structured_logger.log( + level="INFO", + message=f"Tool {'activated' if activate else 'deactivated'} successfully", + event_type="tool_status_toggled", + component="tool_service", + user_email=user_email, + team_id=tool.team_id, + resource_type="tool", + resource_id=tool.id, + custom_fields={ + "tool_name": tool.name, + "enabled": tool.enabled, + "reachable": tool.reachable, + }, + db=db, + ) + return self._convert_tool_to_read(tool) except PermissionError as e: + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Tool status toggle failed due to permission error", + event_type="tool_toggle_permission_denied", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=e, + db=db, + ) raise e except Exception as e: db.rollback() + + # Structured logging: Log generic tool status toggle failure + structured_logger.log( + level="ERROR", + message="Tool status toggle failed", + event_type="tool_toggle_failed", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=e, + db=db, + ) raise ToolError(f"Failed to toggle tool status: {str(e)}") async def invoke_tool( @@ -1182,15 +1424,17 @@ async def invoke_tool( Examples: >>> from mcpgateway.services.tool_service import ToolService - >>> from unittest.mock import MagicMock + >>> from unittest.mock import MagicMock, patch >>> service = ToolService() >>> db = MagicMock() >>> tool = MagicMock() >>> db.execute.return_value.scalar_one_or_none.side_effect = [tool, None] >>> tool.reachable = True >>> import asyncio - >>> result = asyncio.run(service.invoke_tool(db, 'tool_name', {})) - >>> isinstance(result, object) + >>> # Mock structured_logger to prevent database writes during doctest + >>> with patch('mcpgateway.services.tool_service.structured_logger'): + ... result = asyncio.run(service.invoke_tool(db, 'tool_name', {})) + ... 
isinstance(result, object) True """ # pylint: disable=comparison-with-callable @@ -1224,7 +1468,8 @@ async def invoke_tool( global_context.server_id = gateway_id else: # Create new context (fallback when middleware didn't run) - request_id = uuid.uuid4().hex + # Use correlation ID from context if available, otherwise generate new one + request_id = get_correlation_id() or uuid.uuid4().hex gateway_id = getattr(tool, "gateway_id", "unknown") server_id = gateway_id if isinstance(gateway_id, str) else "unknown" global_context = GlobalContext(request_id=request_id, server_id=server_id, tenant_id=None, user=app_user_email) @@ -1445,12 +1690,58 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): Returns: ToolResult: Result of tool call + + Raises: + Exception: On connection or communication errors """ - async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - tool_call_result = await session.call_tool(tool.original_name, arguments) - return tool_call_result + # Get correlation ID for distributed tracing + correlation_id = get_correlation_id() + + # Add correlation ID to headers + if correlation_id and headers: + headers["X-Correlation-ID"] = correlation_id + + # Log MCP call start + mcp_start_time = time.time() + structured_logger.log( + level="INFO", + message=f"MCP tool call started: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + metadata={"event": "mcp_call_started", "tool_name": tool.original_name, "tool_id": tool.id, "server_url": server_url, "transport": "sse"}, + ) + + try: + async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + tool_call_result = await session.call_tool(tool.original_name, arguments) + + # Log successful MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="INFO", + message=f"MCP tool call completed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + metadata={"event": "mcp_call_completed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "sse", "success": True}, + ) + + return tool_call_result + except Exception as e: + # Log failed MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="ERROR", + message=f"MCP tool call failed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + error_details={"error_type": type(e).__name__, "error_message": str(e)}, + metadata={"event": "mcp_call_failed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "sse"}, + ) + raise async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers): """Connect to an MCP server running with Streamable HTTP transport. 
@@ -1461,12 +1752,58 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head Returns: ToolResult: Result of tool call + + Raises: + Exception: On connection or communication errors """ - async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - tool_call_result = await session.call_tool(tool.original_name, arguments) - return tool_call_result + # Get correlation ID for distributed tracing + correlation_id = get_correlation_id() + + # Add correlation ID to headers + if correlation_id and headers: + headers["X-Correlation-ID"] = correlation_id + + # Log MCP call start + mcp_start_time = time.time() + structured_logger.log( + level="INFO", + message=f"MCP tool call started: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + metadata={"event": "mcp_call_started", "tool_name": tool.original_name, "tool_id": tool.id, "server_url": server_url, "transport": "streamablehttp"}, + ) + + try: + async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + tool_call_result = await session.call_tool(tool.original_name, arguments) + + # Log successful MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="INFO", + message=f"MCP tool call completed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + metadata={"event": "mcp_call_completed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "streamablehttp", "success": True}, + ) + + return tool_call_result + except Exception as e: + # Log failed MCP call + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="ERROR", + message=f"MCP tool call failed: {tool.original_name}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + error_details={"error_type": type(e).__name__, "error_message": str(e)}, + metadata={"event": "mcp_call_failed", "tool_name": tool.original_name, "tool_id": tool.id, "transport": "streamablehttp"}, + ) + raise tool_gateway_id = tool.gateway_id tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.enabled)).scalar_one_or_none() @@ -1546,12 +1883,44 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head span.set_attribute("error.message", str(e)) raise ToolInvocationError(f"Tool invocation failed: {error_message}") finally: + # Calculate duration + duration_ms = (time.monotonic() - start_time) * 1000 + # Add final span attributes if span: span.set_attribute("success", success) - span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) + span.set_attribute("duration.ms", duration_ms) + + # Record tool metric await self._record_tool_metric(db, tool, start_time, success, error_message) + # Log structured message with performance tracking + if success: + structured_logger.info( + f"Tool '{name}' invoked successfully", + user_id=app_user_email, + resource_type="tool", + resource_id=str(tool.id), + resource_action="invoke", + duration_ms=duration_ms, + custom_fields={"tool_name": name, 
"integration_type": tool.integration_type, "arguments_count": len(arguments) if arguments else 0}, + ) + else: + structured_logger.error( + f"Tool '{name}' invocation failed", + error=Exception(error_message) if error_message else None, + user_id=app_user_email, + resource_type="tool", + resource_id=str(tool.id), + resource_action="invoke", + duration_ms=duration_ms, + custom_fields={"tool_name": name, "integration_type": tool.integration_type, "error_message": error_message}, + ) + + # Track performance with threshold checking + with perf_tracker.track_operation("tool_invocation", name): + pass # Duration already captured above + async def update_tool( self, db: Session, @@ -1696,24 +2065,142 @@ async def update_tool( db.refresh(tool) await self._notify_tool_updated(tool) logger.info(f"Updated tool: {tool.name}") + + # Structured logging: Audit trail for tool update + changes = [] + if tool_update.name: + changes.append(f"name: {tool_update.name}") + if tool_update.visibility: + changes.append(f"visibility: {tool_update.visibility}") + if tool_update.description: + changes.append("description updated") + + audit_trail.log_action( + user_id=user_email or modified_by or "system", + action="update_tool", + resource_type="tool", + resource_id=tool.id, + resource_name=tool.name, + user_email=user_email, + team_id=tool.team_id, + client_ip=modified_from_ip, + user_agent=modified_user_agent, + new_values={ + "name": tool.name, + "display_name": tool.display_name, + "version": tool.version, + }, + context={ + "modified_via": modified_via, + "changes": ", ".join(changes) if changes else "metadata only", + }, + db=db, + ) + + # Structured logging: Log successful tool update + structured_logger.log( + level="INFO", + message="Tool updated successfully", + event_type="tool_updated", + component="tool_service", + user_id=modified_by, + user_email=user_email, + team_id=tool.team_id, + resource_type="tool", + resource_id=tool.id, + custom_fields={ + "tool_name": tool.name, + "version": tool.version, + }, + db=db, + ) + return self._convert_tool_to_read(tool) - except PermissionError: + except PermissionError as pe: db.rollback() + + # Structured logging: Log permission error + structured_logger.log( + level="WARNING", + message="Tool update failed due to permission error", + event_type="tool_update_permission_denied", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=pe, + db=db, + ) raise except IntegrityError as ie: db.rollback() logger.error(f"IntegrityError during tool update: {ie}") + + # Structured logging: Log database integrity error + structured_logger.log( + level="ERROR", + message="Tool update failed due to database integrity error", + event_type="tool_update_failed", + component="tool_service", + user_id=modified_by, + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=ie, + db=db, + ) raise ie except ToolNotFoundError as tnfe: db.rollback() logger.error(f"Tool not found during update: {tnfe}") + + # Structured logging: Log not found error + structured_logger.log( + level="ERROR", + message="Tool update failed - tool not found", + event_type="tool_not_found", + component="tool_service", + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=tnfe, + db=db, + ) raise tnfe except ToolNameConflictError as tnce: db.rollback() logger.error(f"Tool name conflict during update: {tnce}") + + # Structured logging: Log name conflict error + structured_logger.log( + level="WARNING", + 
message="Tool update failed due to name conflict", + event_type="tool_name_conflict", + component="tool_service", + user_id=modified_by, + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=tnce, + db=db, + ) raise tnce except Exception as ex: db.rollback() + + # Structured logging: Log generic tool update failure + structured_logger.log( + level="ERROR", + message="Tool update failed", + event_type="tool_update_failed", + component="tool_service", + user_id=modified_by, + user_email=user_email, + resource_type="tool", + resource_id=tool_id, + error=ex, + db=db, + ) raise ToolError(f"Failed to update tool: {str(ex)}") async def _notify_tool_updated(self, tool: DbTool) -> None: diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index ef418b617..f3e73d21a 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -3655,7 +3655,7 @@ function openResourceTestModal(resource) { // 2️⃣ If no template → show a simple message fieldsContainer.innerHTML = `
[admin.js template diff truncated by extraction; surviving fragment shows the trace detail view rendering ${escapeHtml(trace.correlation_id)}.]
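The patch threads `get_correlation_id()` from `mcpgateway.utils.correlation_id` through tool invocation, log enrichment, and the outbound MCP client headers, but the utility module itself is outside this diff. Below is a minimal sketch of how such a helper is typically built on `contextvars`; only `get_correlation_id()` is confirmed by the imports above, everything else here is an assumption.

```python
# Hypothetical sketch of mcpgateway/utils/correlation_id.py -- NOT part of this
# diff; only get_correlation_id() is confirmed by the imports in tool_service.py.
# Standard
from contextvars import ContextVar
from typing import Optional
import uuid

# A ContextVar is isolated per asyncio task / request, so services can read the
# ID without it being threaded through every function signature.
_correlation_id: ContextVar[Optional[str]] = ContextVar("correlation_id", default=None)


def set_correlation_id(value: Optional[str] = None) -> str:
    """Bind a correlation ID to the current context, generating one if absent."""
    cid = value or uuid.uuid4().hex
    _correlation_id.set(cid)
    return cid


def get_correlation_id() -> Optional[str]:
    """Return the correlation ID bound to the current context, if any."""
    return _correlation_id.get()
```

Under this model, middleware would call `set_correlation_id()` with the incoming `X-Correlation-ID` header (the same header the SSE and Streamable HTTP clients above propagate outbound), which is why `invoke_tool` can fall back to `get_correlation_id() or uuid.uuid4().hex`.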
[admin.js template diff fragment: the audit log table cell renders ${escapeHtml(audit.resource_id || "-")}.]
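`audit_trail.log_action(...)` appears at every create, update, delete, and toggle site above with a consistent keyword signature. A representative call is sketched below, mirroring the `delete_tool` hook; the values are illustrative, and whether the service opens its own session when `db` is omitted is an assumption.

```python
# Illustrative call mirroring the audit hooks added in this patch; keyword
# names match the calls above, but the values are made up.
# First-Party
from mcpgateway.db import SessionLocal
from mcpgateway.services.audit_trail_service import get_audit_trail_service

audit_trail = get_audit_trail_service()

with SessionLocal() as session:
    audit_trail.log_action(
        user_id="alice@example.com",   # the services fall back to "system" when no user is known
        action="delete_tool",
        resource_type="tool",
        resource_id="tool-123",
        resource_name="weather_lookup",
        user_email="alice@example.com",
        team_id="team-42",
        old_values={"name": "weather_lookup"},
        db=session,                    # pass the caller's session so the audit row joins its transaction
    )
```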
[admin.html diff flattened by extraction. Recoverable table change: the structured-log viewer keeps the "Time", "Level", and "Message" columns, renames "Entity" to "Component", and adds "User", "Duration", and "Correlation ID" columns.]
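For reference, a minimal caller of the new module, using only names defined in `structured_logger.py` above. Note that with default settings each call is also persisted to the `StructuredLogEntry` table via `LogRouter`, so a configured database is required for this to run end to end.

```python
# First-Party
from mcpgateway.services.structured_logger import LogCategory, get_structured_logger

log = get_structured_logger("tool_service")  # cached per component by ComponentLogger

# Convenience wrappers delegate to log() with the matching LogLevel; extra
# keyword arguments flow into the entry and, ultimately, the context JSON.
log.info(
    "Tool invoked",
    category=LogCategory.BUSINESS_LOGIC,
    user_email="alice@example.com",
    duration_ms=12.5,
    custom_fields={"tool_name": "weather_lookup"},
)

try:
    raise ValueError("boom")
except ValueError as exc:
    # error() captures the exception's type, message, and stack trace into the
    # error_details column built by LogRouter._persist_to_database().
    log.error("Tool invocation failed", error=exc, resource_type="tool", resource_id="tool-123")
```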