|
1 | 1 | import base64 |
2 | 2 | import json |
3 | | -from collections.abc import Sequence |
| 3 | +from collections.abc import Callable, Sequence |
4 | 4 | from typing import Any, SupportsFloat |
5 | 5 |
|
6 | | -from cryptography.fernet import Fernet |
7 | 6 | from key_value.shared.errors.key_value import SerializationError |
8 | | -from key_value.shared.errors.wrappers.encryption import DecryptionError |
| 7 | +from key_value.shared.errors.wrappers.encryption import CorruptedDataError, DecryptionError, EncryptionError |
9 | 8 | from typing_extensions import override |
10 | 9 |
|
11 | 10 | from key_value.aio.protocols.key_value import AsyncKeyValue |
12 | 11 | from key_value.aio.wrappers.base import BaseWrapper |
13 | 12 |
|
14 | | -# Special keys used to store encrypted data |
# Reserved keys of the envelope dictionary that holds an encrypted value.
_ENCRYPTED_DATA_KEY = "__encrypted_data__"
_ENCRYPTION_VERSION_KEY = "__encryption_version__"


# Callable that takes plaintext bytes and returns the encrypted bytes.
EncryptionFn = Callable[[bytes], bytes]
# Callable that takes encrypted bytes plus the encryption version int stored
# alongside them and returns the decrypted bytes.
DecryptionFn = Callable[[bytes, int], bytes]
22 | 19 |
|
23 | 20 |
|
24 | | -class EncryptionWrapper(BaseWrapper): |
| 21 | +class BaseEncryptionWrapper(BaseWrapper): |
25 | 22 | """Wrapper that encrypts values before storing and decrypts on retrieval. |
26 | 23 |
|
27 | | - This wrapper encrypts the JSON-serialized value using Fernet (symmetric encryption) |
| 24 | + This wrapper encrypts the JSON-serialized value using a custom encryption function |
28 | 25 | and stores it as a base64-encoded string within a special key in the dictionary. |
29 | 26 | This allows encryption while maintaining the dict[str, Any] interface. |
30 | | -
|
31 | | - The encrypted format looks like: |
32 | | - { |
33 | | - "__encrypted_data__": "base64-encoded-encrypted-data", |
34 | | - "__encryption_version__": 1 |
35 | | - } |
36 | | -
|
37 | | - Note: The encryption key must be kept secret and secure. If the key is lost, |
38 | | - encrypted data cannot be recovered. |
39 | 27 | """ |
40 | 28 |
|
def __init__(
    self,
    key_value: AsyncKeyValue,
    encryption_fn: EncryptionFn,
    decryption_fn: DecryptionFn,
    encryption_version: int,
    raise_on_decryption_error: bool = True,
) -> None:
    """Set up the wrapper around ``key_value`` with the supplied cipher callables.

    Args:
        key_value: The store to wrap.
        encryption_fn: Callable that receives plaintext bytes and returns the encrypted bytes.
        decryption_fn: Callable that receives encrypted bytes together with an encryption
            version int and returns the decrypted bytes.
        encryption_version: Version number recorded next to every value this wrapper encrypts.
        raise_on_decryption_error: Whether a decryption failure raises an exception. Defaults to True.
    """
    # Wrapped store and error-handling policy.
    self.key_value: AsyncKeyValue = key_value
    self.raise_on_decryption_error: bool = raise_on_decryption_error

    # Cipher configuration: the version stamped on new values and the two callables.
    self.encryption_version: int = encryption_version
    self._encryption_fn: EncryptionFn = encryption_fn
    self._decryption_fn: DecryptionFn = decryption_fn

    super().__init__()
65 | 56 |
|
def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
    """Serialize ``value`` to JSON, encrypt it, and wrap it in the storage envelope.

    The encrypted format looks like:
    {
        "__encrypted_data__": "base64-encoded-encrypted-data",
        "__encryption_version__": 1
    }

    Args:
        value: The plaintext dictionary to encrypt.

    Returns:
        A two-key dictionary holding the base64-encoded ciphertext and the
        wrapper's encryption version.

    Raises:
        SerializationError: If ``value`` cannot be serialized to JSON.
        EncryptionError: If the encryption function fails or its output cannot
            be base64-encoded.
    """
    try:
        json_str: str = json.dumps(value, separators=(",", ":"))

        json_bytes: bytes = json_str.encode(encoding="utf-8")
    except (TypeError, ValueError, RecursionError) as e:
        # json.dumps signals unsupported types with TypeError, circular references
        # with ValueError, and excessive nesting with RecursionError. It never
        # raises JSONDecodeError (a ValueError subclass used only when parsing).
        msg: str = f"Failed to serialize object to JSON: {e}"
        raise SerializationError(msg) from e

    try:
        encrypted_bytes: bytes = self._encryption_fn(json_bytes)

        # Kept inside the guard so that a callable returning something other than
        # bytes surfaces as EncryptionError rather than a bare TypeError.
        base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
    except Exception as e:
        msg = "Failed to encrypt value"
        raise EncryptionError(msg) from e

    return {
        _ENCRYPTED_DATA_KEY: base64_str,
        _ENCRYPTION_VERSION_KEY: self.encryption_version,
    }
91 | 88 |
|
def _validate_encrypted_payload(self, value: dict[str, Any]) -> tuple[int, str]:
    """Check the shape of an encrypted envelope and extract its two fields.

    Args:
        value: The stored dictionary that is expected to be an encrypted envelope.

    Returns:
        A ``(encryption_version, encrypted_data)`` tuple taken from the envelope.

    Raises:
        CorruptedDataError: If either envelope key is missing, the version is not
            an int, or the encrypted data is not a str.
    """
    if _ENCRYPTION_VERSION_KEY not in value:
        msg = "missing encryption version key"
        raise CorruptedDataError(msg)

    encryption_version = value[_ENCRYPTION_VERSION_KEY]
    # bool is a subclass of int, so it must be excluded explicitly; otherwise a
    # stored True/False would be accepted as a version number.
    if isinstance(encryption_version, bool) or not isinstance(encryption_version, int):
        msg = f"expected encryption version to be an int, got {type(encryption_version)}"
        raise CorruptedDataError(msg)

    if _ENCRYPTED_DATA_KEY not in value:
        msg = "missing encrypted data key"
        raise CorruptedDataError(msg)

    encrypted_data = value[_ENCRYPTED_DATA_KEY]

    if not isinstance(encrypted_data, str):
        msg = f"expected encrypted data to be a str, got {type(encrypted_data)}"
        raise CorruptedDataError(msg)

    return encryption_version, encrypted_data
| 110 | + |
92 | 111 | def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None: |
93 | 112 | """Decrypt a value from the encrypted format.""" |
94 | 113 | if value is None: |
95 | 114 | return None |
96 | 115 |
|
97 | | - # Check if it's encrypted |
98 | | - if _ENCRYPTED_DATA_KEY not in value: |
| 116 | + # If the value is not actually encrypted, return it as-is |
| 117 | + if _ENCRYPTED_DATA_KEY not in value and isinstance(value, dict): # pyright: ignore[reportUnnecessaryIsInstance] |
99 | 118 | return value |
100 | 119 |
|
101 | | - # Extract encrypted data |
102 | | - base64_str = value[_ENCRYPTED_DATA_KEY] |
103 | | - if not isinstance(base64_str, str): |
104 | | - # Corrupted data, return as-is |
105 | | - msg = f"Corrupted data: expected str, got {type(base64_str)}" |
106 | | - raise TypeError(msg) |
107 | | - |
108 | 120 | try: |
109 | | - # Decode from base64 |
110 | | - encrypted_bytes: bytes = base64.b64decode(base64_str) |
| 121 | + encryption_version, encrypted_data = self._validate_encrypted_payload(value) |
| 122 | + |
| 123 | + encrypted_bytes: bytes = base64.b64decode(encrypted_data, validate=True) |
111 | 124 |
|
112 | | - # Decrypt with Fernet |
113 | | - json_bytes: bytes = self._fernet.decrypt(token=encrypted_bytes) |
| 125 | + json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version) |
114 | 126 |
|
115 | | - # Parse JSON |
116 | 127 | json_str: str = json_bytes.decode(encoding="utf-8") |
| 128 | + |
117 | 129 | return json.loads(json_str) # type: ignore[no-any-return] |
| 130 | + except CorruptedDataError: |
| 131 | + if self.raise_on_decryption_error: |
| 132 | + raise |
| 133 | + return None |
118 | 134 | except Exception as e: |
119 | 135 | msg = "Failed to decrypt value" |
120 | 136 | if self.raise_on_decryption_error: |
|
0 commit comments