Skip to content

Commit 6a409c6

Browse files
authored
Add FernetEncryptionWrapper and simplify setup (#77)
1 parent 6c46b5b commit 6a409c6

File tree

12 files changed

+534
-371
lines changed

12 files changed

+534
-371
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ The following wrappers are available:
183183
| Wrapper | Description | Example |
184184
|---------|---------------|-----|
185185
| CompressionWrapper | Compress values before storing and decompress on retrieval. | `CompressionWrapper(key_value=memory_store, min_size_to_compress=0)` |
186-
| EncryptionWrapper | Encrypt values before storing and decrypt on retrieval. | `EncryptionWrapper(key_value=memory_store, encryption_key=Fernet.generate_key())` |
186+
| FernetEncryptionWrapper | Encrypt values before storing and decrypt on retrieval. | `FernetEncryptionWrapper(key_value=memory_store, source_material="your-source-material", salt="your-salt")` |
187187
| FallbackWrapper | Fallback to a secondary store when the primary store fails. | `FallbackWrapper(primary_key_value=memory_store, fallback_key_value=memory_store)` |
188188
| LimitSizeWrapper | Limit the size of entries stored in the cache. | `LimitSizeWrapper(key_value=memory_store, max_size=1024, raise_on_too_large=True)` |
189189
| LoggingWrapper | Log the operations performed on the store. | `LoggingWrapper(key_value=memory_store, log_level=logging.INFO, structured_logs=True)` |
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from key_value.aio.wrappers.encryption.wrapper import EncryptionWrapper
1+
from key_value.aio.wrappers.encryption.base import BaseEncryptionWrapper
2+
from key_value.aio.wrappers.encryption.fernet import FernetEncryptionWrapper
23

3-
__all__ = ["EncryptionWrapper"]
4+
__all__ = ["BaseEncryptionWrapper", "FernetEncryptionWrapper"]

key-value/key-value-aio/src/key_value/aio/wrappers/encryption/wrapper.py renamed to key-value/key-value-aio/src/key_value/aio/wrappers/encryption/base.py

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,120 +1,136 @@
11
import base64
22
import json
3-
from collections.abc import Sequence
3+
from collections.abc import Callable, Sequence
44
from typing import Any, SupportsFloat
55

6-
from cryptography.fernet import Fernet
76
from key_value.shared.errors.key_value import SerializationError
8-
from key_value.shared.errors.wrappers.encryption import DecryptionError
7+
from key_value.shared.errors.wrappers.encryption import CorruptedDataError, DecryptionError, EncryptionError
98
from typing_extensions import override
109

1110
from key_value.aio.protocols.key_value import AsyncKeyValue
1211
from key_value.aio.wrappers.base import BaseWrapper
1312

14-
# Special keys used to store encrypted data
1513
_ENCRYPTED_DATA_KEY = "__encrypted_data__"
1614
_ENCRYPTION_VERSION_KEY = "__encryption_version__"
17-
_ENCRYPTION_VERSION = 1
1815

1916

20-
class EncryptionError(Exception):
21-
"""Exception raised when encryption or decryption fails."""
17+
EncryptionFn = Callable[[bytes], bytes]
18+
DecryptionFn = Callable[[bytes, int], bytes]
2219

2320

24-
class EncryptionWrapper(BaseWrapper):
21+
class BaseEncryptionWrapper(BaseWrapper):
2522
"""Wrapper that encrypts values before storing and decrypts on retrieval.
2623
27-
This wrapper encrypts the JSON-serialized value using Fernet (symmetric encryption)
24+
This wrapper encrypts the JSON-serialized value using a custom encryption function
2825
and stores it as a base64-encoded string within a special key in the dictionary.
2926
This allows encryption while maintaining the dict[str, Any] interface.
30-
31-
The encrypted format looks like:
32-
{
33-
"__encrypted_data__": "base64-encoded-encrypted-data",
34-
"__encryption_version__": 1
35-
}
36-
37-
Note: The encryption key must be kept secret and secure. If the key is lost,
38-
encrypted data cannot be recovered.
3927
"""
4028

4129
def __init__(
4230
self,
4331
key_value: AsyncKeyValue,
44-
encryption_key: bytes | str,
32+
encryption_fn: EncryptionFn,
33+
decryption_fn: DecryptionFn,
34+
encryption_version: int,
4535
raise_on_decryption_error: bool = True,
4636
) -> None:
4737
"""Initialize the encryption wrapper.
4838
4939
Args:
5040
key_value: The store to wrap.
51-
encryption_key: The encryption key to use. Can be a bytes object or a base64-encoded string.
52-
Use Fernet.generate_key() to generate a new key.
41+
encryption_fn: The encryption function to use. A callable that takes bytes and returns encrypted bytes.
42+
decryption_fn: The decryption function to use. A callable that takes bytes and an
43+
encryption version int and returns decrypted bytes.
44+
encryption_version: The encryption version to use.
5345
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
5446
"""
5547
self.key_value: AsyncKeyValue = key_value
5648
self.raise_on_decryption_error: bool = raise_on_decryption_error
5749

58-
# Convert string key to bytes if needed
59-
if isinstance(encryption_key, str):
60-
encryption_key = encryption_key.encode("utf-8")
50+
self.encryption_version: int = encryption_version
6151

62-
self._fernet: Fernet = Fernet(key=encryption_key)
52+
self._encryption_fn: EncryptionFn = encryption_fn
53+
self._decryption_fn: DecryptionFn = decryption_fn
6354

6455
super().__init__()
6556

6657
def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
67-
"""Encrypt a value into the encrypted format."""
68-
# Don't encrypt if it's already encrypted
69-
if _ENCRYPTED_DATA_KEY in value:
70-
return value
58+
"""Encrypt a value into the encrypted format.
59+
60+
The encrypted format looks like:
61+
{
62+
"__encrypted_data__": "base64-encoded-encrypted-data",
63+
"__encryption_version__": 1
64+
}
65+
"""
7166

7267
# Serialize to JSON
7368
try:
7469
json_str: str = json.dumps(value, separators=(",", ":"))
70+
71+
json_bytes: bytes = json_str.encode(encoding="utf-8")
7572
except (json.JSONDecodeError, TypeError) as e:
7673
msg: str = f"Failed to serialize object to JSON: {e}"
7774
raise SerializationError(msg) from e
7875

79-
json_bytes: bytes = json_str.encode(encoding="utf-8")
80-
81-
# Encrypt with Fernet
82-
encrypted_bytes: bytes = self._fernet.encrypt(data=json_bytes)
76+
try:
77+
encrypted_bytes: bytes = self._encryption_fn(json_bytes)
8378

84-
# Encode to base64 for storage in dict (though Fernet output is already base64)
85-
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
79+
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
80+
except Exception as e:
81+
msg = "Failed to encrypt value"
82+
raise EncryptionError(msg) from e
8683

8784
return {
8885
_ENCRYPTED_DATA_KEY: base64_str,
89-
_ENCRYPTION_VERSION_KEY: _ENCRYPTION_VERSION,
86+
_ENCRYPTION_VERSION_KEY: self.encryption_version,
9087
}
9188

89+
def _validate_encrypted_payload(self, value: dict[str, Any]) -> tuple[int, str]:
90+
if _ENCRYPTION_VERSION_KEY not in value:
91+
msg = "missing encryption version key"
92+
raise CorruptedDataError(msg)
93+
94+
encryption_version = value[_ENCRYPTION_VERSION_KEY]
95+
if not isinstance(encryption_version, int):
96+
msg = f"expected encryption version to be an int, got {type(encryption_version)}"
97+
raise CorruptedDataError(msg)
98+
99+
if _ENCRYPTED_DATA_KEY not in value:
100+
msg = "missing encrypted data key"
101+
raise CorruptedDataError(msg)
102+
103+
encrypted_data = value[_ENCRYPTED_DATA_KEY]
104+
105+
if not isinstance(encrypted_data, str):
106+
msg = f"expected encrypted data to be a str, got {type(encrypted_data)}"
107+
raise CorruptedDataError(msg)
108+
109+
return encryption_version, encrypted_data
110+
92111
def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
93112
"""Decrypt a value from the encrypted format."""
94113
if value is None:
95114
return None
96115

97-
# Check if it's encrypted
98-
if _ENCRYPTED_DATA_KEY not in value:
116+
# If the value is not actually encrypted, return it as-is
117+
if _ENCRYPTED_DATA_KEY not in value and isinstance(value, dict): # pyright: ignore[reportUnnecessaryIsInstance]
99118
return value
100119

101-
# Extract encrypted data
102-
base64_str = value[_ENCRYPTED_DATA_KEY]
103-
if not isinstance(base64_str, str):
104-
# Corrupted data, return as-is
105-
msg = f"Corrupted data: expected str, got {type(base64_str)}"
106-
raise TypeError(msg)
107-
108120
try:
109-
# Decode from base64
110-
encrypted_bytes: bytes = base64.b64decode(base64_str)
121+
encryption_version, encrypted_data = self._validate_encrypted_payload(value)
122+
123+
encrypted_bytes: bytes = base64.b64decode(encrypted_data, validate=True)
111124

112-
# Decrypt with Fernet
113-
json_bytes: bytes = self._fernet.decrypt(token=encrypted_bytes)
125+
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)
114126

115-
# Parse JSON
116127
json_str: str = json_bytes.decode(encoding="utf-8")
128+
117129
return json.loads(json_str) # type: ignore[no-any-return]
130+
except CorruptedDataError:
131+
if self.raise_on_decryption_error:
132+
raise
133+
return None
118134
except Exception as e:
119135
msg = "Failed to decrypt value"
120136
if self.raise_on_decryption_error:
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from cryptography.fernet import Fernet, MultiFernet
2+
from key_value.shared.errors.wrappers.encryption import EncryptionVersionError
3+
from typing_extensions import overload
4+
5+
from key_value.aio.protocols.key_value import AsyncKeyValue
6+
from key_value.aio.wrappers.encryption.base import BaseEncryptionWrapper
7+
8+
ENCRYPTION_VERSION = 1
9+
10+
KDF_ITERATIONS = 1_200_000
11+
12+
13+
class FernetEncryptionWrapper(BaseEncryptionWrapper):
14+
"""Wrapper that encrypts values before storing and decrypts on retrieval using Fernet (symmetric encryption)."""
15+
16+
@overload
17+
def __init__(
18+
self,
19+
key_value: AsyncKeyValue,
20+
*,
21+
fernet: Fernet | MultiFernet,
22+
raise_on_decryption_error: bool = True,
23+
) -> None:
24+
"""Initialize the Fernet encryption wrapper.
25+
26+
Args:
27+
key_value: The key-value store to wrap.
28+
fernet: The Fernet or MultiFernet instance to use for encryption and decryption MultiFernet is used to support
29+
key rotation by allowing you to provide multiple Fernet instances that are attempted in order.
30+
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
31+
"""
32+
33+
@overload
34+
def __init__(
35+
self,
36+
key_value: AsyncKeyValue,
37+
*,
38+
source_material: str,
39+
salt: str,
40+
raise_on_decryption_error: bool = True,
41+
) -> None:
42+
"""Initialize the Fernet encryption wrapper.
43+
44+
Args:
45+
key_value: The key-value store to wrap.
46+
source_material: A string to use as the source material for the encryption key.
47+
salt: A string to use as the salt for the encryption key.
48+
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
49+
"""
50+
51+
def __init__(
52+
self,
53+
key_value: AsyncKeyValue,
54+
*,
55+
fernet: Fernet | MultiFernet | None = None,
56+
source_material: str | None = None,
57+
salt: str | None = None,
58+
raise_on_decryption_error: bool = True,
59+
) -> None:
60+
if fernet is not None: # noqa: SIM102
61+
if source_material or salt:
62+
msg = "Cannot provide fernet together with source_material or salt"
63+
raise ValueError(msg)
64+
65+
if fernet is None:
66+
if not source_material or not source_material.strip():
67+
msg = "Must provide either fernet or source_material"
68+
raise ValueError(msg)
69+
if not salt or not salt.strip():
70+
msg = "Must provide a salt"
71+
raise ValueError(msg)
72+
fernet = Fernet(key=_generate_encryption_key(source_material=source_material, salt=salt))
73+
74+
def encrypt_with_fernet(data: bytes) -> bytes:
75+
return fernet.encrypt(data)
76+
77+
def decrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
78+
if encryption_version > self.encryption_version:
79+
msg = f"Decryption failed: encryption versions newer than {self.encryption_version} are not supported"
80+
raise EncryptionVersionError(msg)
81+
return fernet.decrypt(data)
82+
83+
super().__init__(
84+
key_value=key_value,
85+
encryption_fn=encrypt_with_fernet,
86+
decryption_fn=decrypt_with_fernet,
87+
encryption_version=ENCRYPTION_VERSION,
88+
raise_on_decryption_error=raise_on_decryption_error,
89+
)
90+
91+
92+
def _generate_encryption_key(source_material: str, salt: str) -> bytes:
93+
"""Generate a Fernet encryption key from a source material and salt using PBKDF2."""
94+
import base64
95+
96+
from cryptography.hazmat.primitives import hashes
97+
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
98+
99+
pbkdf2 = PBKDF2HMAC(
100+
algorithm=hashes.SHA256(),
101+
length=32,
102+
salt=salt.encode(),
103+
iterations=KDF_ITERATIONS,
104+
).derive(key_material=source_material.encode())
105+
106+
return base64.urlsafe_b64encode(pbkdf2)

0 commit comments

Comments
 (0)