Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ The following wrappers are available:
| Wrapper | Description | Example |
|---------|---------------|-----|
| CompressionWrapper | Compress values before storing and decompress on retrieval. | `CompressionWrapper(key_value=memory_store, min_size_to_compress=0)` |
| EncryptionWrapper | Encrypt values before storing and decrypt on retrieval. | `EncryptionWrapper(key_value=memory_store, encryption_key=Fernet.generate_key())` |
| FernetEncryptionWrapper | Encrypt values before storing and decrypt on retrieval. | `FernetEncryptionWrapper(key_value=memory_store, source_material="your-source-material", salt="your-salt")` |
| FallbackWrapper | Fallback to a secondary store when the primary store fails. | `FallbackWrapper(primary_key_value=memory_store, fallback_key_value=memory_store)` |
| LimitSizeWrapper | Limit the size of entries stored in the cache. | `LimitSizeWrapper(key_value=memory_store, max_size=1024, raise_on_too_large=True)` |
| LoggingWrapper | Log the operations performed on the store. | `LoggingWrapper(key_value=memory_store, log_level=logging.INFO, structured_logs=True)` |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from key_value.aio.wrappers.encryption.wrapper import EncryptionWrapper
from key_value.aio.wrappers.encryption.base import BaseEncryptionWrapper
from key_value.aio.wrappers.encryption.fernet import FernetEncryptionWrapper

__all__ = ["EncryptionWrapper"]
__all__ = ["BaseEncryptionWrapper", "FernetEncryptionWrapper"]
Original file line number Diff line number Diff line change
@@ -1,120 +1,136 @@
import base64
import json
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import Any, SupportsFloat

from cryptography.fernet import Fernet
from key_value.shared.errors.key_value import SerializationError
from key_value.shared.errors.wrappers.encryption import DecryptionError
from key_value.shared.errors.wrappers.encryption import CorruptedDataError, DecryptionError, EncryptionError
from typing_extensions import override

from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.wrappers.base import BaseWrapper

# Special keys used to store encrypted data
_ENCRYPTED_DATA_KEY = "__encrypted_data__"
_ENCRYPTION_VERSION_KEY = "__encryption_version__"
_ENCRYPTION_VERSION = 1


class EncryptionError(Exception):
"""Exception raised when encryption or decryption fails."""
EncryptionFn = Callable[[bytes], bytes]
DecryptionFn = Callable[[bytes, int], bytes]


class EncryptionWrapper(BaseWrapper):
class BaseEncryptionWrapper(BaseWrapper):
"""Wrapper that encrypts values before storing and decrypts on retrieval.

This wrapper encrypts the JSON-serialized value using Fernet (symmetric encryption)
This wrapper encrypts the JSON-serialized value using a custom encryption function
and stores it as a base64-encoded string within a special key in the dictionary.
This allows encryption while maintaining the dict[str, Any] interface.

The encrypted format looks like:
{
"__encrypted_data__": "base64-encoded-encrypted-data",
"__encryption_version__": 1
}

Note: The encryption key must be kept secret and secure. If the key is lost,
encrypted data cannot be recovered.
"""

def __init__(
self,
key_value: AsyncKeyValue,
encryption_key: bytes | str,
encryption_fn: EncryptionFn,
decryption_fn: DecryptionFn,
encryption_version: int,
raise_on_decryption_error: bool = True,
) -> None:
"""Initialize the encryption wrapper.

Args:
key_value: The store to wrap.
encryption_key: The encryption key to use. Can be a bytes object or a base64-encoded string.
Use Fernet.generate_key() to generate a new key.
encryption_fn: The encryption function to use. A callable that takes bytes and returns encrypted bytes.
decryption_fn: The decryption function to use. A callable that takes bytes and an
encryption version int and returns decrypted bytes.
encryption_version: The encryption version to use.
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
"""
self.key_value: AsyncKeyValue = key_value
self.raise_on_decryption_error: bool = raise_on_decryption_error

# Convert string key to bytes if needed
if isinstance(encryption_key, str):
encryption_key = encryption_key.encode("utf-8")
self.encryption_version: int = encryption_version

self._fernet: Fernet = Fernet(key=encryption_key)
self._encryption_fn: EncryptionFn = encryption_fn
self._decryption_fn: DecryptionFn = decryption_fn

super().__init__()

def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
"""Encrypt a value into the encrypted format."""
# Don't encrypt if it's already encrypted
if _ENCRYPTED_DATA_KEY in value:
return value
"""Encrypt a value into the encrypted format.

The encrypted format looks like:
{
"__encrypted_data__": "base64-encoded-encrypted-data",
"__encryption_version__": 1
}
"""

# Serialize to JSON
try:
json_str: str = json.dumps(value, separators=(",", ":"))

json_bytes: bytes = json_str.encode(encoding="utf-8")
except (json.JSONDecodeError, TypeError) as e:
msg: str = f"Failed to serialize object to JSON: {e}"
raise SerializationError(msg) from e

json_bytes: bytes = json_str.encode(encoding="utf-8")

# Encrypt with Fernet
encrypted_bytes: bytes = self._fernet.encrypt(data=json_bytes)
try:
encrypted_bytes: bytes = self._encryption_fn(json_bytes)

# Encode to base64 for storage in dict (though Fernet output is already base64)
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
except Exception as e:
msg = "Failed to encrypt value"
raise EncryptionError(msg) from e

return {
_ENCRYPTED_DATA_KEY: base64_str,
_ENCRYPTION_VERSION_KEY: _ENCRYPTION_VERSION,
_ENCRYPTION_VERSION_KEY: self.encryption_version,
}

def _validate_encrypted_payload(self, value: dict[str, Any]) -> tuple[int, str]:
if _ENCRYPTION_VERSION_KEY not in value:
msg = "missing encryption version key"
raise CorruptedDataError(msg)

encryption_version = value[_ENCRYPTION_VERSION_KEY]
if not isinstance(encryption_version, int):
msg = f"expected encryption version to be an int, got {type(encryption_version)}"
raise CorruptedDataError(msg)

if _ENCRYPTED_DATA_KEY not in value:
msg = "missing encrypted data key"
raise CorruptedDataError(msg)

encrypted_data = value[_ENCRYPTED_DATA_KEY]

if not isinstance(encrypted_data, str):
msg = f"expected encrypted data to be a str, got {type(encrypted_data)}"
raise CorruptedDataError(msg)

return encryption_version, encrypted_data

def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
"""Decrypt a value from the encrypted format."""
if value is None:
return None

# Check if it's encrypted
if _ENCRYPTED_DATA_KEY not in value:
# If the value is not actually encrypted, return it as-is
if _ENCRYPTED_DATA_KEY not in value and isinstance(value, dict): # pyright: ignore[reportUnnecessaryIsInstance]
return value

# Extract encrypted data
base64_str = value[_ENCRYPTED_DATA_KEY]
if not isinstance(base64_str, str):
# Corrupted data, return as-is
msg = f"Corrupted data: expected str, got {type(base64_str)}"
raise TypeError(msg)

try:
# Decode from base64
encrypted_bytes: bytes = base64.b64decode(base64_str)
encryption_version, encrypted_data = self._validate_encrypted_payload(value)

encrypted_bytes: bytes = base64.b64decode(encrypted_data, validate=True)

# Decrypt with Fernet
json_bytes: bytes = self._fernet.decrypt(token=encrypted_bytes)
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)

# Parse JSON
json_str: str = json_bytes.decode(encoding="utf-8")

return json.loads(json_str) # type: ignore[no-any-return]
except CorruptedDataError:
if self.raise_on_decryption_error:
raise
return None
except Exception as e:
msg = "Failed to decrypt value"
if self.raise_on_decryption_error:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from cryptography.fernet import Fernet, MultiFernet
from key_value.shared.errors.wrappers.encryption import EncryptionVersionError
from typing_extensions import overload

from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.wrappers.encryption.base import BaseEncryptionWrapper

ENCRYPTION_VERSION = 1

KDF_ITERATIONS = 1_200_000


class FernetEncryptionWrapper(BaseEncryptionWrapper):
"""Wrapper that encrypts values before storing and decrypts on retrieval using Fernet (symmetric encryption)."""

@overload
def __init__(
self,
key_value: AsyncKeyValue,
*,
fernet: Fernet | MultiFernet,
raise_on_decryption_error: bool = True,
) -> None:
"""Initialize the Fernet encryption wrapper.

Args:
key_value: The key-value store to wrap.
fernet: The Fernet or MultiFernet instance to use for encryption and decryption MultiFernet is used to support
key rotation by allowing you to provide multiple Fernet instances that are attempted in order.
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
"""

@overload
def __init__(
self,
key_value: AsyncKeyValue,
*,
source_material: str,
salt: str,
raise_on_decryption_error: bool = True,
) -> None:
"""Initialize the Fernet encryption wrapper.

Args:
key_value: The key-value store to wrap.
source_material: A string to use as the source material for the encryption key.
salt: A string to use as the salt for the encryption key.
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
"""

def __init__(
self,
key_value: AsyncKeyValue,
*,
fernet: Fernet | MultiFernet | None = None,
source_material: str | None = None,
salt: str | None = None,
raise_on_decryption_error: bool = True,
) -> None:
if fernet is not None: # noqa: SIM102
if source_material or salt:
msg = "Cannot provide fernet together with source_material or salt"
raise ValueError(msg)

if fernet is None:
if not source_material or not source_material.strip():
msg = "Must provide either fernet or source_material"
raise ValueError(msg)
if not salt or not salt.strip():
msg = "Must provide a salt"
raise ValueError(msg)
fernet = Fernet(key=_generate_encryption_key(source_material=source_material, salt=salt))

def encrypt_with_fernet(data: bytes) -> bytes:
return fernet.encrypt(data)

def decrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
if encryption_version > self.encryption_version:
msg = f"Decryption failed: encryption versions newer than {self.encryption_version} are not supported"
raise EncryptionVersionError(msg)
return fernet.decrypt(data)

super().__init__(
key_value=key_value,
encryption_fn=encrypt_with_fernet,
decryption_fn=decrypt_with_fernet,
encryption_version=ENCRYPTION_VERSION,
raise_on_decryption_error=raise_on_decryption_error,
)


def _generate_encryption_key(source_material: str, salt: str) -> bytes:
"""Generate a Fernet encryption key from a source material and salt using PBKDF2."""
import base64

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

pbkdf2 = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt.encode(),
iterations=KDF_ITERATIONS,
).derive(key_material=source_material.encode())

return base64.urlsafe_b64encode(pbkdf2)
Loading