diff --git a/Containerfile.lite b/Containerfile.lite index 94b8fa0f0..92f50fc48 100644 --- a/Containerfile.lite +++ b/Containerfile.lite @@ -76,7 +76,7 @@ SHELL ["/bin/bash", "-euo", "pipefail", "-c"] ARG PYTHON_VERSION ARG ROOTFS_PATH -ARG TARGETPLATFORM +ARG TARGETPLATFORM=linux/amd64 ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL='False' # ---------------------------------------------------------------------------- diff --git a/Makefile b/Makefile index d8ced2b75..3331906c8 100644 --- a/Makefile +++ b/Makefile @@ -2658,7 +2658,7 @@ docker-dev: @$(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile docker: - @$(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile + @$(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile.lite docker-prod: @DOCKER_CONTENT_TRUST=1 $(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile.lite diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 5d859a2cb..490d59b73 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -95,6 +95,7 @@ ) from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService from mcpgateway.services.catalog_service import catalog_service +from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.services.export_service import ExportError, ExportService from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayNameConflictError, GatewayNotFoundError, GatewayService, GatewayUrlConflictError from mcpgateway.services.import_service import ConflictStrategy @@ -113,7 +114,6 @@ from mcpgateway.utils.create_jwt_token import create_jwt_token, get_jwt_token from mcpgateway.utils.error_formatter import ErrorFormatter from mcpgateway.utils.metadata_capture import MetadataCapture -from mcpgateway.utils.oauth_encryption import get_oauth_encryption from mcpgateway.utils.pagination import generate_pagination_links from mcpgateway.utils.passthrough_headers import PassthroughHeadersError from mcpgateway.utils.retry_manager import ResilientHttpClient @@ -6194,7 +6194,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present if oauth_config and "client_secret" in oauth_config: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -6231,7 +6231,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -6503,7 +6503,7 @@ async def admin_edit_gateway( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present and not empty if oauth_config and "client_secret" in oauth_config and oauth_config["client_secret"]: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -6540,7 +6540,7 @@ async def admin_edit_gateway( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -9571,7 +9571,7 @@ async def admin_add_a2a_agent( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present if oauth_config and "client_secret" in oauth_config: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -9608,7 +9608,7 @@ async def admin_add_a2a_agent( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -9890,7 +9890,7 @@ async def admin_edit_a2a_agent( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present and not empty if oauth_config and "client_secret" in oauth_config and oauth_config["client_secret"]: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -9927,7 +9927,7 @@ async def admin_edit_a2a_agent( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py new file mode 100644 index 000000000..09f3075a4 --- /dev/null +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- +"""Use Argon2id for encryption key + +Revision ID: a706a3320c56 +Revises: h2b3c4d5e6f7 +Create Date: 2025-10-30 15:31:25.115536 + +""" + +# Standard +import base64 +import json +import logging +import os +from typing import Optional, Sequence, Union + +# Third-Party +from alembic import op +from argon2.low_level import hash_secret_raw, Type +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +import sqlalchemy as sa +from sqlalchemy import text + +# First-Party +from mcpgateway.config import settings + +logger = logging.getLogger(__name__) + +# revision identifiers, used by Alembic. +revision: str = "a706a3320c56" +down_revision: Union[str, Sequence[str], None] = "h2b3c4d5e6f7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def reencrypt_with_argon2id(encrypted_text: str) -> str: + """Re-encrypts an existing encrypted text using Argon2id KDF. + + Args: + encrypted_text: The original encrypted text using PBKDF2HMAC. + + Returns: + A JSON string containing the Argon2id re-encrypted token and parameters. + """ + encryption_secret = settings.auth_encryption_secret.get_secret_value().encode() + original_kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=b"mcp_gateway_oauth", # Fixed salt for consistency + iterations=100000, + ) + original_key = base64.urlsafe_b64encode(original_kdf.derive(encryption_secret)) + original_fernet = Fernet(original_key) + original_encrypted_bytes = base64.urlsafe_b64decode(encrypted_text.encode()) + original_decrypted_bytes = original_fernet.decrypt(original_encrypted_bytes) + + time_cost = getattr(settings, "argon2id_time_cost", 3) + memory_cost = getattr(settings, "argon2id_memory_cost", 65536) + parallelism = getattr(settings, "argon2id_parallelism", 1) + hash_len = 32 + + salt = os.urandom(16) + argon2id_raw = hash_secret_raw( + secret=encryption_secret, + salt=salt, + time_cost=time_cost, + memory_cost=memory_cost, # KiB + parallelism=parallelism, + hash_len=hash_len, + type=Type.ID, + ) + argon2id_key = base64.urlsafe_b64encode(argon2id_raw) + argon2id_fernet = Fernet(argon2id_key) + argon2id_encrypted_bytes = argon2id_fernet.encrypt(original_decrypted_bytes) + return json.dumps( + { + "kdf": "argon2id", + "t": time_cost, + "m": memory_cost, + "p": parallelism, + "salt": base64.b64encode(salt).decode(), + "token": argon2id_encrypted_bytes.decode(), + } + ) + + +def reencrypt_with_pbkdf2hmac(argon2id_bundle: str) -> Optional[str]: + """Re-encrypts an Argon2id encrypted bundle back to PBKDF2HMAC. + + Args: + argon2id_bundle: The JSON string containing Argon2id encrypted data. + + Returns: + A PBKDF2HMAC re-encrypted token. + + Raises: + ValueError: If the input is not a valid Argon2id bundle. + """ + try: + argon2id_data = json.loads(argon2id_bundle) + if argon2id_data.get("kdf") != "argon2id": + raise ValueError("Not an Argon2id bundle") + + encryption_secret = settings.auth_encryption_secret.get_secret_value().encode() + salt = base64.b64decode(argon2id_data["salt"]) + time_cost = argon2id_data["t"] + memory_cost = argon2id_data["m"] + parallelism = argon2id_data["p"] + argon2id_raw = hash_secret_raw( + secret=encryption_secret, + salt=salt, + time_cost=time_cost, + memory_cost=memory_cost, # KiB + parallelism=parallelism, + hash_len=32, + type=Type.ID, + ) + argon2id_key = base64.urlsafe_b64encode(argon2id_raw) + argon2id_fernet = Fernet(argon2id_key) + argon2id_encrypted_bytes = argon2id_data["token"].encode() + decrypted_bytes = argon2id_fernet.decrypt(argon2id_encrypted_bytes) + + original_kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=b"mcp_gateway_oauth", # Fixed salt for consistency + iterations=100000, + ) + original_key = base64.urlsafe_b64encode(original_kdf.derive(encryption_secret)) + original_fernet = Fernet(original_key) + original_encrypted_bytes = original_fernet.encrypt(decrypted_bytes) + return base64.urlsafe_b64encode(original_encrypted_bytes).decode() + except Exception as e: + raise ValueError("Invalid Argon2id bundle") from e + + +def _reflect(conn): + """Reflect relevant tables. + + Args: + conn: The database connection. + + Returns: + A dict of reflected tables. + """ + md = sa.MetaData() + gateways = sa.Table("gateways", md, autoload_with=conn) + a2a_agents = sa.Table("a2a_agents", md, autoload_with=conn) + return {"gateways": gateways, "a2a_agents": a2a_agents} + + +def _is_json(col): + """Check if a column is of JSON type. + + Args: + col: The column to check. + + Returns: + True if the column is of JSON type. + """ + return isinstance(col.type, sa.JSON) + + +def _looks_argon2_bundle(val: Optional[str]) -> bool: + """Heuristic for Argon2id bundle format (JSON with kdf=argon2id). + + Args: + val: The encrypted value. + + Returns: + True if it looks like an Argon2id encrypted token. + """ + if not val: + return False + # Fast path: Fernet tokens usually start with 'gAAAAA'; Argon2 bundle is JSON + if val and val[:1] in ("{", "["): + try: + obj = json.loads(val) + return isinstance(obj, dict) and obj.get("kdf") == "argon2id" + except Exception: + return False + return False + + +def _looks_legacy_pbkdf2_token(val: Optional[str]) -> bool: + """Heuristic for legacy PBKDF2 format (base64-wrapped Fernet token string, not JSON). + + Args: + val: The encrypted value. + + Returns: + True if it looks like a legacy PBKDF2 encrypted token. + """ + if not val or not isinstance(val, str): + return False + # Legacy column stored base64(urlsafe) of the Fernet token (which is itself base64 bytes), + # so it's NOT JSON and usually not starting with '{' + return not val.startswith("{") + + +def _upgrade_value(old: Optional[str]) -> Optional[str]: + """PBKDF2 -> Argon2id bundle, when needed. + + Args: + old: The existing encrypted value. + + Returns: + The re-encrypted value using Argon2id, or None if no change is needed. + """ + if not old: + return None + if _looks_argon2_bundle(old): + return None # already migrated + if not _looks_legacy_pbkdf2_token(old): + return None # unknown format; skip + try: + return reencrypt_with_argon2id(old) + except Exception as e: + logger.warning("Upgrade skip (cannot re-encrypt PBKDF2 value): %s", e) + return None + + +def _downgrade_value(old: Optional[str]) -> Optional[str]: + """Argon2id bundle -> PBKDF2 legacy, when needed. + + Args: + old: The existing encrypted value. + + Returns: + The re-encrypted value using PBKDF2HMAC, or None if no change is needed. + """ + if not old: + return None + if not _looks_argon2_bundle(old): + return None # not an Argon2 bundle + try: + return reencrypt_with_pbkdf2hmac(old) + except Exception as e: + logger.warning("Downgrade skip (cannot re-encrypt Argon2 bundle): %s", e) + return None + + +def _upgrade_json_client_secret(conn, table): + """Upgrade JSON client_secret fields in the given table. + + Args: + conn: The database connection. + table: The table to upgrade. + """ + t = table + sel = sa.select(t.c.id, t.c.oauth_config).where(t.c.oauth_config.isnot(None)) + for row in conn.execute(sel).mappings(): + rid = row["id"] + cfg = row["oauth_config"] + if isinstance(cfg, str): + try: + cfg = json.loads(cfg) + except json.JSONDecodeError as e: + logger.warning("Skipping %s.id=%s: invalid JSON (%s)", table, rid, e) + continue + if not isinstance(cfg, dict): + continue + + old = cfg.get("client_secret") + new = _upgrade_value(old) # your helper + if not new: + continue + + cfg["client_secret"] = new + value = cfg if _is_json(t.c.oauth_config) else json.dumps(cfg) + upd = sa.update(t).where(t.c.id == rid).values(oauth_config=value) + conn.execute(upd) + + +def _downgrade_json_client_secret(conn, table): + """Downgrade JSON client_secret fields in the given table. + + Args: + conn: The database connection. + table: The table to downgrade. + """ + t = table + sel = sa.select(t.c.id, t.c.oauth_config).where(t.c.oauth_config.isnot(None)) + for row in conn.execute(sel).mappings(): + rid = row["id"] + cfg = row["oauth_config"] + if isinstance(cfg, str): + try: + cfg = json.loads(cfg) + except json.JSONDecodeError as e: + logger.warning("Skipping %s.id=%s: invalid JSON (%s)", table, rid, e) + continue + if not isinstance(cfg, dict): + continue + + old = cfg.get("client_secret") + new = _downgrade_value(old) # your helper + if not new: + continue + + cfg["client_secret"] = new + value = cfg if _is_json(t.c.oauth_config) else json.dumps(cfg) + upd = sa.update(t).where(t.c.id == rid).values(oauth_config=value) + conn.execute(upd) + + +def upgrade() -> None: + """Use Argon2id KDF for encryption key re-encryption.""" + bind = op.get_bind() + + conn = op.get_bind() + t = _reflect(conn) + + # JSON: gateways.oauth_config.client_secret + _upgrade_json_client_secret(conn, t["gateways"]) + + # JSON: a2a_agents.oauth_config.client_secret + _upgrade_json_client_secret(conn, t["a2a_agents"]) + + # oauth_tokens: access_token, refresh_token + rows = ( + bind.execute( + text( + """ + SELECT id, access_token, refresh_token + FROM oauth_tokens + WHERE (access_token IS NOT NULL OR refresh_token IS NOT NULL) + """ + ) + ) + .mappings() + .all() + ) + + for r in rows: + tid = r["id"] + at = r["access_token"] + rt = r["refresh_token"] + nat = _upgrade_value(at) + nrt = _upgrade_value(rt) + if nat or nrt: + bind.execute( + text( + """ + UPDATE oauth_tokens + SET access_token = COALESCE(:nat, access_token), + refresh_token = COALESCE(:nrt, refresh_token) + WHERE id = :id + """ + ), + {"nat": nat, "nrt": nrt, "id": tid}, + ) + + # registered_oauth_clients: client_secret_encrypted, registration_access_token_encrypted + rows = ( + bind.execute( + text( + """ + SELECT id, client_secret_encrypted, registration_access_token_encrypted + FROM registered_oauth_clients + WHERE client_secret_encrypted IS NOT NULL + OR registration_access_token_encrypted IS NOT NULL + """ + ) + ) + .mappings() + .all() + ) + + for r in rows: + rid = r["id"] + cs = r["client_secret_encrypted"] + rat = r["registration_access_token_encrypted"] + ncs = _upgrade_value(cs) + nrat = _upgrade_value(rat) + if ncs or nrat: + bind.execute( + text( + """ + UPDATE registered_oauth_clients + SET client_secret_encrypted = COALESCE(:ncs, client_secret_encrypted), + registration_access_token_encrypted = COALESCE(:nrat, registration_access_token_encrypted) + WHERE id = :id + """ + ), + {"ncs": ncs, "nrat": nrat, "id": rid}, + ) + + # sso_providers: client_secret_encrypted + rows = ( + bind.execute( + text( + """ + SELECT id, client_secret_encrypted + FROM sso_providers + WHERE client_secret_encrypted IS NOT NULL + """ + ) + ) + .mappings() + .all() + ) + + for r in rows: + sid = r["id"] + cs = r["client_secret_encrypted"] + ncs = _upgrade_value(cs) + if ncs: + bind.execute( + text( + """ + UPDATE sso_providers + SET client_secret_encrypted = :ncs + WHERE id = :id + """ + ), + {"ncs": ncs, "id": sid}, + ) + + logger.info("Upgrade complete: PBKDF2 -> Argon2id bundle re-encryption.") + + +def downgrade() -> None: + """Revert to PBKDF2HMAC KDF for encryption key re-encryption.""" + bind = op.get_bind() + + # JSON: gateways.oauth_config.client_secret + _downgrade_json_client_secret(bind, "gateways") + + # JSON: a2a_agents.oauth_config.client_secret + _downgrade_json_client_secret(bind, "a2a_agents") + + # oauth_tokens: access_token, refresh_token + rows = ( + bind.execute( + text( + """ + SELECT id, access_token, refresh_token + FROM oauth_tokens + WHERE (access_token IS NOT NULL OR refresh_token IS NOT NULL) + """ + ) + ) + .mappings() + .all() + ) + + for r in rows: + tid = r["id"] + at = r["access_token"] + rt = r["refresh_token"] + nat = _downgrade_value(at) + nrt = _downgrade_value(rt) + if nat or nrt: + bind.execute( + text( + """ + UPDATE oauth_tokens + SET access_token = COALESCE(:nat, access_token), + refresh_token = COALESCE(:nrt, refresh_token) + WHERE id = :id + """ + ), + {"nat": nat, "nrt": nrt, "id": tid}, + ) + + # registered_oauth_clients: client_secret_encrypted, registration_access_token_encrypted + rows = ( + bind.execute( + text( + """ + SELECT id, client_secret_encrypted, registration_access_token_encrypted + FROM registered_oauth_clients + WHERE client_secret_encrypted IS NOT NULL + OR registration_access_token_encrypted IS NOT NULL + """ + ) + ) + .mappings() + .all() + ) + + for r in rows: + rid = r["id"] + cs = r["client_secret_encrypted"] + rat = r["registration_access_token_encrypted"] + ncs = _downgrade_value(cs) + nrat = _downgrade_value(rat) + if ncs or nrat: + bind.execute( + text( + """ + UPDATE registered_oauth_clients + SET client_secret_encrypted = COALESCE(:ncs, client_secret_encrypted), + registration_access_token_encrypted = COALESCE(:nrat, registration_access_token_encrypted) + WHERE id = :id + """ + ), + {"ncs": ncs, "nrat": nrat, "id": rid}, + ) + + # sso_providers: client_secret_encrypted + rows = ( + bind.execute( + text( + """ + SELECT id, client_secret_encrypted + FROM sso_providers + WHERE client_secret_encrypted IS NOT NULL + """ + ) + ) + .mappings() + .all() + ) + + for r in rows: + sid = r["id"] + cs = r["client_secret_encrypted"] + ncs = _downgrade_value(cs) + if ncs: + bind.execute( + text( + """ + UPDATE sso_providers + SET client_secret_encrypted = :ncs + WHERE id = :id + """ + ), + {"ncs": ncs, "id": sid}, + ) + + logger.info("Downgrade complete: Argon2id bundle -> PBKDF2 legacy re-encryption.") diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py index e52ceb093..f8a1c17d3 100644 --- a/mcpgateway/routers/oauth_router.py +++ b/mcpgateway/routers/oauth_router.py @@ -113,9 +113,9 @@ async def initiate_oauth_flow( decrypted_secret = None if registered_client.client_secret_encrypted: # First-Party - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.services.encryption_service import get_encryption_service - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(registered_client.client_secret_encrypted) # Update oauth_config with registered credentials diff --git a/mcpgateway/services/dcr_service.py b/mcpgateway/services/dcr_service.py index 88573b126..3e95d1a52 100644 --- a/mcpgateway/services/dcr_service.py +++ b/mcpgateway/services/dcr_service.py @@ -25,7 +25,7 @@ # First-Party from mcpgateway.config import get_settings from mcpgateway.db import RegisteredOAuthClient -from mcpgateway.utils.oauth_encryption import get_oauth_encryption +from mcpgateway.services.encryption_service import get_encryption_service logger = logging.getLogger(__name__) @@ -168,7 +168,7 @@ async def register_client(self, gateway_id: str, gateway_name: str, issuer: str, raise DcrError(f"Failed to register client with {issuer}: {e}") # Encrypt secrets - encryption = get_oauth_encryption(self.settings.auth_encryption_secret) + encryption = get_encryption_service(self.settings.auth_encryption_secret) client_secret = registration_response.get("client_secret") client_secret_encrypted = encryption.encrypt_secret(client_secret) if client_secret else None @@ -260,7 +260,7 @@ async def update_client_registration(self, client_record: RegisteredOAuthClient, raise DcrError("Cannot update client: no registration_access_token available") # Decrypt registration access token - encryption = get_oauth_encryption(self.settings.auth_encryption_secret) + encryption = get_encryption_service(self.settings.auth_encryption_secret) registration_access_token = encryption.decrypt_secret(client_record.registration_access_token_encrypted) # Build update request @@ -313,7 +313,7 @@ async def delete_client_registration(self, client_record: RegisteredOAuthClient, return True # Consider it deleted locally # Decrypt registration access token - encryption = get_oauth_encryption(self.settings.auth_encryption_secret) + encryption = get_encryption_service(self.settings.auth_encryption_secret) registration_access_token = encryption.decrypt_secret(client_record.registration_access_token_encrypted) # Send delete request diff --git a/mcpgateway/services/encryption_service.py b/mcpgateway/services/encryption_service.py new file mode 100644 index 000000000..b5c926851 --- /dev/null +++ b/mcpgateway/services/encryption_service.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/encryption_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti, Madhav Kandukuri + +Encryption Service. + +This service provides encryption and decryption functions for client secrets +using the AUTH_ENCRYPTION_SECRET from configuration. +""" + +# Standard +import base64 +import json +import logging +import os +from typing import Optional, Union + +# Third-Party +from argon2.low_level import hash_secret_raw, Type +from cryptography.fernet import Fernet +from pydantic import SecretStr + +# First-Party +from mcpgateway.config import settings + +logger = logging.getLogger(__name__) + + +class EncryptionService: + """Handles encryption and decryption of client secrets. + + Examples: + Basic roundtrip: + >>> enc = EncryptionService(SecretStr('very-secret-key')) + >>> cipher = enc.encrypt_secret('hello') + >>> isinstance(cipher, str) and enc.is_encrypted(cipher) + True + >>> enc.decrypt_secret(cipher) + 'hello' + + Non-encrypted text detection: + >>> enc.is_encrypted('plain-text') + False + """ + + def __init__( + self, encryption_secret: Union[SecretStr, str], time_cost: Optional[int] = None, memory_cost: Optional[int] = None, parallelism: Optional[int] = None, hash_len: int = 32, salt_len: int = 16 + ): + """Initialize the encryption handler. + + Args: + encryption_secret: Secret key for encryption/decryption + time_cost: Argon2id time cost parameter + memory_cost: Argon2id memory cost parameter (in KiB) + parallelism: Argon2id parallelism parameter + hash_len: Length of the derived key + salt_len: Length of the salt + """ + # Handle both SecretStr and plain string for backwards compatibility + if isinstance(encryption_secret, SecretStr): + self.encryption_secret = encryption_secret.get_secret_value().encode() + else: + # If a plain string is passed, use it directly (for testing/legacy code) + self.encryption_secret = str(encryption_secret).encode() + self.time_cost = time_cost or getattr(settings, "argon2id_time_cost", 3) + self.memory_cost = memory_cost or getattr(settings, "argon2id_memory_cost", 65536) + self.parallelism = parallelism or getattr(settings, "argon2id_parallelism", 1) + self.hash_len = hash_len + self.salt_len = salt_len + + def derive_key_argon2id(self, passphrase: bytes, salt: bytes, time_cost: int, memory_cost: int, parallelism: int) -> bytes: + """Derive a key from a passphrase using Argon2id. + + Args: + passphrase: The passphrase to derive the key from + salt: The salt to use in key derivation + time_cost: Argon2id time cost parameter + memory_cost: Argon2id memory cost parameter (in KiB) + parallelism: Argon2id parallelism parameter + + Returns: + The derived key + """ + raw = hash_secret_raw( + secret=passphrase, + salt=salt, + time_cost=time_cost, + memory_cost=memory_cost, # KiB + parallelism=parallelism, + hash_len=self.hash_len, + type=Type.ID, + ) + return base64.urlsafe_b64encode(raw) + + def encrypt_secret(self, plaintext: str) -> str: + """Encrypt a plaintext secret. + + Args: + plaintext: The secret to encrypt + + Returns: + Base64-encoded encrypted string + + Raises: + Exception: If encryption fails + """ + try: + salt = os.urandom(16) + key = self.derive_key_argon2id(self.encryption_secret, salt, self.time_cost, self.memory_cost, self.parallelism) + fernet = Fernet(key) + encrypted = fernet.encrypt(plaintext.encode()) + return json.dumps( + { + "kdf": "argon2id", + "t": self.time_cost, + "m": self.memory_cost, + "p": self.parallelism, + "salt": base64.b64encode(salt).decode(), + "token": encrypted.decode(), + } + ) + except Exception as e: + logger.error(f"Failed to encrypt secret: {e}") + raise + + def decrypt_secret(self, bundle_json: str) -> Optional[str]: + """Decrypt an encrypted secret. + + Args: + bundle_json: str: JSON string containing encryption metadata and token + + Returns: + Decrypted secret string, or None if decryption fails + """ + try: + b = json.loads(bundle_json) + salt = base64.b64decode(b["salt"]) + key = self.derive_key_argon2id(self.encryption_secret, salt, time_cost=b["t"], memory_cost=b["m"], parallelism=b["p"]) + fernet = Fernet(key) + decrypted = fernet.decrypt(b["token"].encode()) + return decrypted.decode() + except Exception as e: + logger.error(f"Failed to decrypt secret: {e}") + return None + + def is_encrypted(self, text: str) -> bool: + """Check if a string appears to be encrypted. + + Args: + text: String to check + + Returns: + True if the string appears to be encrypted + + Note: + Supports both legacy PBKDF2 (base64-wrapped Fernet) and new Argon2id + (JSON bundle) formats. Checks JSON format first, then falls back to + base64 check for legacy format. + """ + if not text: + return False + + # Check for new Argon2id JSON bundle format + if text.startswith("{"): + try: + obj = json.loads(text) + if isinstance(obj, dict) and obj.get("kdf") == "argon2id": + return True + except (json.JSONDecodeError, ValueError, KeyError): + # Not valid JSON or missing expected structure - continue to legacy check + pass + + # Check for legacy PBKDF2 base64-wrapped Fernet format + try: + decoded = base64.urlsafe_b64decode(text.encode()) + # Encrypted data should be at least 32 bytes (Fernet minimum) + return len(decoded) >= 32 + except Exception: + return False + + +def get_encryption_service(encryption_secret: Union[SecretStr, str]) -> EncryptionService: + """Get an EncryptionService instance. + + Args: + encryption_secret: Secret key for encryption/decryption (SecretStr or plain string) + + Returns: + EncryptionService instance + + Examples: + >>> enc = get_encryption_service(SecretStr('k')) + >>> isinstance(enc, EncryptionService) + True + >>> enc2 = get_encryption_service('plain-key') + >>> isinstance(enc2, EncryptionService) + True + """ + return EncryptionService(encryption_secret) diff --git a/mcpgateway/services/metrics.py b/mcpgateway/services/metrics.py index f339d0b25..b3276ba23 100644 --- a/mcpgateway/services/metrics.py +++ b/mcpgateway/services/metrics.py @@ -118,4 +118,9 @@ def setup_metrics(app): @app.get("/metrics/prometheus") async def metrics_disabled(): + """Returns metrics response when metrics collection is disabled. + + Returns: + Response: HTTP 503 response indicating metrics are disabled. + """ return Response(content='{"error": "Metrics collection is disabled"}', media_type="application/json", status_code=status.HTTP_503_SERVICE_UNAVAILABLE) diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py index 1e755af00..610ed8213 100644 --- a/mcpgateway/services/oauth_manager.py +++ b/mcpgateway/services/oauth_manager.py @@ -28,7 +28,7 @@ # First-Party from mcpgateway.config import get_settings -from mcpgateway.utils.oauth_encryption import get_oauth_encryption +from mcpgateway.services.encryption_service import get_encryption_service logger = logging.getLogger(__name__) @@ -222,7 +222,7 @@ async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str: if len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -313,7 +313,7 @@ async def _password_flow(self, credentials: Dict[str, Any]) -> str: if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -430,7 +430,7 @@ async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -1007,7 +1007,7 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret diff --git a/mcpgateway/services/sso_service.py b/mcpgateway/services/sso_service.py index 032b22bf7..fb3bbdc78 100644 --- a/mcpgateway/services/sso_service.py +++ b/mcpgateway/services/sso_service.py @@ -22,11 +22,7 @@ import urllib.parse # Third-Party -from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC import httpx -from pydantic import SecretStr from sqlalchemy import and_, select from sqlalchemy.orm import Session @@ -34,6 +30,7 @@ from mcpgateway.config import settings from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now from mcpgateway.services.email_auth_service import EmailAuthService +from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.utils.create_jwt_token import create_jwt_token # Logger @@ -64,39 +61,7 @@ def __init__(self, db: Session): """ self.db = db self.auth_service = EmailAuthService(db) - self._encryption_key = self._get_or_create_encryption_key() - - def _get_or_create_encryption_key(self) -> bytes: - """Get or create encryption key for client secrets. - - Returns: - Encryption key bytes - """ - # Use the same encryption secret as the auth service - key = settings.auth_encryption_secret - - if not key: - # Generate a new key - in production, this should be persisted - key = Fernet.generate_key() - # Derive a proper Fernet key from the secret - - # Unwrap SecretStr if necessary - if isinstance(key, SecretStr): - key = key.get_secret_value() - - # Convert string to bytes - if isinstance(key, str): - key = key.encode("utf-8") - - # Derive a 32-byte key using PBKDF2 - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=b"sso_salt", # Static salt for consistency - iterations=100000, - ) - derived_key = base64.urlsafe_b64encode(kdf.derive(key)) - return derived_key + self._encryption = get_encryption_service(settings.auth_encryption_secret) def _encrypt_secret(self, secret: str) -> str: """Encrypt a client secret for secure storage. @@ -107,10 +72,9 @@ def _encrypt_secret(self, secret: str) -> str: Returns: Encrypted secret string """ - fernet = Fernet(self._encryption_key) - return fernet.encrypt(secret.encode()).decode() + return self._encryption.encrypt_secret(secret) - def _decrypt_secret(self, encrypted_secret: str) -> str: + def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]: """Decrypt a client secret for use. Args: @@ -119,8 +83,11 @@ def _decrypt_secret(self, encrypted_secret: str) -> str: Returns: Plain text client secret """ - fernet = Fernet(self._encryption_key) - return fernet.decrypt(encrypted_secret.encode()).decode() + decrypted: str | None = self._encryption.decrypt_secret(encrypted_secret) + if decrypted: + return decrypted + + return None def list_enabled_providers(self) -> List[SSOProvider]: """Get list of enabled SSO providers. diff --git a/mcpgateway/services/token_storage_service.py b/mcpgateway/services/token_storage_service.py index da441c7b7..ef6b18e66 100644 --- a/mcpgateway/services/token_storage_service.py +++ b/mcpgateway/services/token_storage_service.py @@ -22,8 +22,8 @@ # First-Party from mcpgateway.config import get_settings from mcpgateway.db import OAuthToken +from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.services.oauth_manager import OAuthError -from mcpgateway.utils.oauth_encryption import get_oauth_encryption logger = logging.getLogger(__name__) @@ -68,7 +68,7 @@ def __init__(self, db: Session): self.db = db try: settings = get_settings() - self.encryption = get_oauth_encryption(settings.auth_encryption_secret) + self.encryption = get_encryption_service(settings.auth_encryption_secret) except (ImportError, AttributeError): logger.warning("OAuth encryption not available, using plain text storage") self.encryption = None diff --git a/mcpgateway/utils/oauth_encryption.py b/mcpgateway/utils/oauth_encryption.py deleted file mode 100644 index bd58a49a6..000000000 --- a/mcpgateway/utils/oauth_encryption.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./mcpgateway/utils/oauth_encryption.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Mihai Criveti - -OAuth Encryption Utilities. - -This module provides encryption and decryption functions for OAuth client secrets -using the AUTH_ENCRYPTION_SECRET from configuration. -""" - -# Standard -import base64 -import logging -from typing import Optional - -# Third-Party -from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from pydantic import SecretStr - -logger = logging.getLogger(__name__) - - -class OAuthEncryption: - """Handles encryption and decryption of OAuth client secrets. - - Examples: - Basic roundtrip: - >>> enc = OAuthEncryption(SecretStr('very-secret-key')) - >>> cipher = enc.encrypt_secret('hello') - >>> isinstance(cipher, str) and enc.is_encrypted(cipher) - True - >>> enc.decrypt_secret(cipher) - 'hello' - - Non-encrypted text detection: - >>> enc.is_encrypted('plain-text') - False - """ - - def __init__(self, encryption_secret: SecretStr): - """Initialize the encryption handler. - - Args: - encryption_secret: Secret key for encryption/decryption - """ - self.encryption_secret = encryption_secret.get_secret_value().encode() - self._fernet = None - - def _get_fernet(self) -> Fernet: - """Get or create Fernet instance for encryption. - - Returns: - Fernet instance for encryption/decryption - """ - if self._fernet is None: - # Derive a key from the encryption secret using PBKDF2 - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=b"mcp_gateway_oauth", # Fixed salt for consistency - iterations=100000, - ) - key = base64.urlsafe_b64encode(kdf.derive(self.encryption_secret)) - self._fernet = Fernet(key) - return self._fernet - - def encrypt_secret(self, plaintext: str) -> str: - """Encrypt a plaintext secret. - - Args: - plaintext: The secret to encrypt - - Returns: - Base64-encoded encrypted string - - Raises: - Exception: If encryption fails - """ - try: - fernet = self._get_fernet() - encrypted = fernet.encrypt(plaintext.encode()) - return base64.urlsafe_b64encode(encrypted).decode() - except Exception as e: - logger.error(f"Failed to encrypt OAuth secret: {e}") - raise - - def decrypt_secret(self, encrypted_text: str) -> Optional[str]: - """Decrypt an encrypted secret. - - Args: - encrypted_text: Base64-encoded encrypted string - - Returns: - Decrypted secret string, or None if decryption fails - """ - try: - fernet = self._get_fernet() - encrypted_bytes = base64.urlsafe_b64decode(encrypted_text.encode()) - decrypted = fernet.decrypt(encrypted_bytes) - return decrypted.decode() - except Exception as e: - logger.error(f"Failed to decrypt OAuth secret: {e}") - return None - - def is_encrypted(self, text: str) -> bool: - """Check if a string appears to be encrypted. - - Args: - text: String to check - - Returns: - True if the string appears to be encrypted - """ - try: - # Try to decode as base64 and check if it looks like encrypted data - decoded = base64.urlsafe_b64decode(text.encode()) - # Encrypted data should be at least 32 bytes (Fernet minimum) - return len(decoded) >= 32 - except Exception: - return False - - -def get_oauth_encryption(encryption_secret: SecretStr) -> OAuthEncryption: - """Get an OAuth encryption instance. - - Args: - encryption_secret: Secret key for encryption/decryption - - Returns: - OAuthEncryption instance - - Examples: - >>> enc = get_oauth_encryption(SecretStr('k')) - >>> isinstance(enc, OAuthEncryption) - True - """ - return OAuthEncryption(encryption_secret) diff --git a/tests/unit/mcpgateway/services/test_dcr_service.py b/tests/unit/mcpgateway/services/test_dcr_service.py index 9f493d103..a9b0a6987 100644 --- a/tests/unit/mcpgateway/services/test_dcr_service.py +++ b/tests/unit/mcpgateway/services/test_dcr_service.py @@ -429,7 +429,7 @@ class TestUpdateClientRegistration: @pytest.mark.asyncio async def test_update_client_registration_success(self, test_db): """Test successful client registration update.""" - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.config import get_settings dcr_service = DcrService() @@ -442,7 +442,7 @@ async def test_update_client_registration_success(self, test_db): test_db.commit() # Encrypt the registration access token properly - encryption = get_oauth_encryption(get_settings().auth_encryption_secret) + encryption = get_encryption_service(get_settings().auth_encryption_secret) encrypted_token = encryption.encrypt_secret("registration-access-token") client_record = RegisteredOAuthClient( @@ -474,7 +474,7 @@ async def test_update_client_registration_success(self, test_db): @pytest.mark.asyncio async def test_update_client_registration_uses_access_token(self, test_db): """Test that update uses registration_access_token.""" - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.config import get_settings dcr_service = DcrService() @@ -487,7 +487,7 @@ async def test_update_client_registration_uses_access_token(self, test_db): test_db.commit() # Encrypt the registration access token properly - encryption = get_oauth_encryption(get_settings().auth_encryption_secret) + encryption = get_encryption_service(get_settings().auth_encryption_secret) encrypted_token = encryption.encrypt_secret("registration-access-token") client_record = RegisteredOAuthClient( diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index 982206152..568c35477 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -2114,7 +2114,7 @@ async def test_admin_add_gateway_with_oauth_config(self, mock_register_gateway, mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - with patch("mcpgateway.admin.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_encryption_service") as mock_get_encryption: mock_encryption = MagicMock() mock_encryption.encrypt_secret.return_value = "encrypted-secret" mock_get_encryption.return_value = mock_encryption @@ -2175,7 +2175,7 @@ async def test_admin_edit_gateway_with_oauth_config(self, mock_update_gateway, m mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - with patch("mcpgateway.admin.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_encryption_service") as mock_get_encryption: mock_encryption = MagicMock() mock_encryption.encrypt_secret.return_value = "encrypted-edit-secret" mock_get_encryption.return_value = mock_encryption @@ -2204,7 +2204,7 @@ async def test_admin_edit_gateway_oauth_empty_client_secret(self, mock_update_ga mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - should not be called for empty secret - with patch("mcpgateway.admin.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_encryption_service") as mock_get_encryption: mock_encryption = MagicMock() mock_get_encryption.return_value = mock_encryption diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py index 802227c2a..3431119a6 100644 --- a/tests/unit/mcpgateway/test_oauth_manager.py +++ b/tests/unit/mcpgateway/test_oauth_manager.py @@ -20,7 +20,7 @@ from mcpgateway.db import OAuthToken from mcpgateway.services.oauth_manager import OAuthError, OAuthManager from mcpgateway.services.token_storage_service import TokenStorageService -from mcpgateway.utils.oauth_encryption import OAuthEncryption +from mcpgateway.services.encryption_service import EncryptionService class TestOAuthManager: @@ -298,7 +298,7 @@ async def test_client_credentials_flow_with_encrypted_secret(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() mock_encryption.decrypt_secret.return_value = "decrypted_secret" mock_get_encryption.return_value = mock_encryption @@ -371,7 +371,7 @@ async def test_client_credentials_flow_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - line 108 mock_encryption.decrypt_secret.return_value = None @@ -890,7 +890,7 @@ async def test_exchange_code_for_tokens_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - lines 438-439 mock_encryption.decrypt_secret.return_value = None @@ -981,7 +981,7 @@ async def test_exchange_code_for_token_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - lines 216-217 mock_encryption.decrypt_secret.return_value = None @@ -1260,7 +1260,7 @@ async def test_exchange_code_for_tokens_decryption_success(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption succeeds - lines 435-437 mock_encryption.decrypt_secret.return_value = "decrypted_secret" @@ -1301,7 +1301,7 @@ async def test_exchange_code_for_tokens_decryption_exception(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption throws exception - lines 440-441 mock_encryption.decrypt_secret.side_effect = ValueError("Decryption failed") @@ -1454,7 +1454,7 @@ async def test_exchange_code_for_token_decryption_success(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption succeeds - lines 213-215 mock_encryption.decrypt_secret.return_value = "decrypted_secret" @@ -1495,7 +1495,7 @@ async def test_exchange_code_for_token_decryption_exception(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption throws exception - lines 218-219 mock_encryption.decrypt_secret.side_effect = ValueError("Decryption failed") @@ -1661,7 +1661,7 @@ def test_init_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_encryption = Mock() mock_get_enc.return_value = mock_encryption @@ -1709,7 +1709,7 @@ async def test_store_tokens_new_record_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1810,7 +1810,7 @@ async def test_store_tokens_update_existing_record(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1854,7 +1854,7 @@ async def test_store_tokens_without_refresh_token(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1911,7 +1911,7 @@ async def test_get_valid_token_success_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -2442,29 +2442,17 @@ async def test_cleanup_expired_tokens_exception(self): mock_db.rollback.assert_called_once() -class TestOAuthEncryption: - """Test cases for OAuthEncryption class.""" +class TestEncryptionService: + """Test cases for EncryptionService class.""" def test_init(self): - """Test OAuthEncryption initialization.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) + """Test EncryptionService initialization.""" + encryption = EncryptionService(SecretStr("test_secret_key")) assert encryption.encryption_secret == b"test_secret_key" - assert encryption._fernet is None - - def test_get_fernet_creates_instance(self): - """Test _get_fernet creates Fernet instance on first call.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) - - fernet1 = encryption._get_fernet() - fernet2 = encryption._get_fernet() - - # Should return same instance (cached) - assert fernet1 is fernet2 - assert encryption._fernet is not None def test_encrypt_secret_success(self): """Test successful secret encryption.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) + encryption = EncryptionService(SecretStr("test_secret_key")) plaintext = "my_secret_token_123" encrypted = encryption.encrypt_secret(plaintext) @@ -2479,8 +2467,8 @@ def test_encrypt_secret_success(self): def test_encrypt_secret_different_keys_different_output(self): """Test that different keys produce different encrypted output.""" - encryption1 = OAuthEncryption(SecretStr("key1")) - encryption2 = OAuthEncryption(SecretStr("key2")) + encryption1 = EncryptionService(SecretStr("key1")) + encryption2 = EncryptionService(SecretStr("key2")) plaintext = "same_secret" encrypted1 = encryption1.encrypt_secret(plaintext) @@ -2491,7 +2479,7 @@ def test_encrypt_secret_different_keys_different_output(self): def test_encrypt_secret_same_key_different_output(self): """Test that same key produces different encrypted output due to nonce.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) plaintext = "same_secret" encrypted1 = encryption.encrypt_secret(plaintext) @@ -2506,7 +2494,7 @@ def test_encrypt_secret_same_key_different_output(self): def test_encrypt_secret_empty_string(self): """Test encrypting empty string.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) encrypted = encryption.encrypt_secret("") decrypted = encryption.decrypt_secret(encrypted) @@ -2515,7 +2503,7 @@ def test_encrypt_secret_empty_string(self): def test_encrypt_secret_unicode_characters(self): """Test encrypting string with unicode characters.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) plaintext = "šŸ” secret with Ć©mojis and spĆ©ciĆ l chars Ʊ" encrypted = encryption.encrypt_secret(plaintext) @@ -2525,20 +2513,15 @@ def test_encrypt_secret_unicode_characters(self): def test_encrypt_secret_exception_handling(self): """Test exception handling in encrypt_secret.""" - encryption = OAuthEncryption(SecretStr("test_key")) - - # Mock the Fernet instance to raise an exception - with patch.object(encryption, "_get_fernet") as mock_get_fernet: - mock_fernet = Mock() - mock_fernet.encrypt.side_effect = Exception("Encryption failed") - mock_get_fernet.return_value = mock_fernet + encryption = EncryptionService(SecretStr("test_key")) + with patch.object(encryption, "derive_key_argon2id", side_effect=Exception("Encryption failed")): with pytest.raises(Exception, match="Encryption failed"): encryption.encrypt_secret("test") def test_decrypt_secret_success(self): """Test successful secret decryption.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) + encryption = EncryptionService(SecretStr("test_secret_key")) plaintext = "original_secret" # First encrypt @@ -2551,7 +2534,7 @@ def test_decrypt_secret_success(self): def test_decrypt_secret_invalid_data(self): """Test decryption with invalid encrypted data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) result = encryption.decrypt_secret("invalid_encrypted_data") @@ -2559,8 +2542,8 @@ def test_decrypt_secret_invalid_data(self): def test_decrypt_secret_wrong_key(self): """Test decryption with wrong key.""" - encryption1 = OAuthEncryption(SecretStr("key1")) - encryption2 = OAuthEncryption(SecretStr("key2")) + encryption1 = EncryptionService(SecretStr("key1")) + encryption2 = EncryptionService(SecretStr("key2")) # Encrypt with one key encrypted = encryption1.encrypt_secret("secret") @@ -2572,7 +2555,7 @@ def test_decrypt_secret_wrong_key(self): def test_decrypt_secret_corrupted_data(self): """Test decryption with corrupted base64 data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Create valid encrypted data then corrupt it encrypted = encryption.encrypt_secret("test") @@ -2584,7 +2567,7 @@ def test_decrypt_secret_corrupted_data(self): def test_decrypt_secret_malformed_base64(self): """Test decryption with malformed base64.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) result = encryption.decrypt_secret("not_valid_base64!@#") @@ -2592,7 +2575,7 @@ def test_decrypt_secret_malformed_base64(self): def test_decrypt_secret_empty_string(self): """Test decryption with empty string.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) result = encryption.decrypt_secret("") @@ -2600,7 +2583,7 @@ def test_decrypt_secret_empty_string(self): def test_is_encrypted_valid_encrypted_data(self): """Test is_encrypted with valid encrypted data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) encrypted = encryption.encrypt_secret("test_data") @@ -2608,14 +2591,14 @@ def test_is_encrypted_valid_encrypted_data(self): def test_is_encrypted_plain_text(self): """Test is_encrypted with plain text.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) assert encryption.is_encrypted("plain_text_secret") is False assert encryption.is_encrypted("another_plain_string") is False def test_is_encrypted_short_data(self): """Test is_encrypted with short data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Fernet encrypted data should be at least 32 bytes short_data = "dGVzdA==" # "test" in base64 (only 4 bytes when decoded) @@ -2624,7 +2607,7 @@ def test_is_encrypted_short_data(self): def test_is_encrypted_valid_base64_but_not_encrypted(self): """Test is_encrypted with valid base64 that's not encrypted data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Create base64 data that's long enough but not encrypted # Standard @@ -2641,32 +2624,32 @@ def test_is_encrypted_valid_base64_but_not_encrypted(self): def test_is_encrypted_invalid_base64(self): """Test is_encrypted with invalid base64.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) assert encryption.is_encrypted("not_base64!@#$%") is False def test_is_encrypted_exception_handling(self): """Test exception handling in is_encrypted.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Test with None (should handle gracefully) with patch("base64.urlsafe_b64decode", side_effect=Exception("Base64 error")): result = encryption.is_encrypted("any_string") assert result is False - def test_get_oauth_encryption_function(self): - """Test the get_oauth_encryption utility function.""" + def test_get_encryption_service_function(self): + """Test the get_encryption_service utility function.""" # First-Party - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.services.encryption_service import get_encryption_service - encryption = get_oauth_encryption(SecretStr("test_secret")) + encryption = get_encryption_service(SecretStr("test_secret")) - assert isinstance(encryption, OAuthEncryption) + assert isinstance(encryption, EncryptionService) assert encryption.encryption_secret == b"test_secret" def test_encryption_roundtrip_multiple_values(self): """Test encryption/decryption roundtrip with multiple values.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) test_values = [ "simple_token", @@ -2687,8 +2670,8 @@ def test_encryption_roundtrip_multiple_values(self): def test_encryption_key_derivation_consistency(self): """Test that key derivation is consistent across instances.""" # Create two instances with same key - encryption1 = OAuthEncryption(SecretStr("same_key")) - encryption2 = OAuthEncryption(SecretStr("same_key")) + encryption1 = EncryptionService(SecretStr("same_key")) + encryption2 = EncryptionService(SecretStr("same_key")) # Encrypt with first instance plaintext = "test_consistency" @@ -2702,7 +2685,7 @@ def test_encryption_key_derivation_consistency(self): def test_encryption_with_long_key(self): """Test encryption with very long key.""" long_key = SecretStr("a" * 1000) # Very long key - encryption = OAuthEncryption(long_key) + encryption = EncryptionService(long_key) encrypted = encryption.encrypt_secret("test_data") decrypted = encryption.decrypt_secret(encrypted) @@ -2712,25 +2695,9 @@ def test_encryption_with_long_key(self): def test_encryption_with_special_char_key(self): """Test encryption with key containing special characters.""" special_key = SecretStr("key_with_special_chars!@#$%^&*()_+-={}[]|\\:;\"'<>?,./") - encryption = OAuthEncryption(special_key) + encryption = EncryptionService(special_key) encrypted = encryption.encrypt_secret("test_data") decrypted = encryption.decrypt_secret(encrypted) assert decrypted == "test_data" - - def test_fernet_instance_caching(self): - """Test that Fernet instance is properly cached.""" - encryption = OAuthEncryption(SecretStr("test_key")) - - # First call should create instance - assert encryption._fernet is None - fernet1 = encryption._get_fernet() - assert encryption._fernet is not None - - # Subsequent calls should return cached instance - fernet2 = encryption._get_fernet() - fernet3 = encryption._get_fernet() - - assert fernet1 is fernet2 - assert fernet2 is fernet3