Skip to content

Commit d96e188

Browse files
committed
Address PR Feedback
1 parent eca7bf4 commit d96e188

File tree

5 files changed

+65
-43
lines changed

5 files changed

+65
-43
lines changed

key-value/key-value-aio/src/key_value/aio/wrappers/encryption/base.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, SupportsFloat
55

66
from key_value.shared.errors.key_value import SerializationError
7-
from key_value.shared.errors.wrappers.encryption import CorruptedEncryptionDataError, DecryptionError
7+
from key_value.shared.errors.wrappers.encryption import CorruptedDataError, DecryptionError, EncryptionError
88
from typing_extensions import override
99

1010
from key_value.aio.protocols.key_value import AsyncKeyValue
@@ -18,10 +18,6 @@
1818
DecryptionFn = Callable[[bytes, int], bytes]
1919

2020

21-
class EncryptionError(Exception):
22-
"""Exception raised when encryption or decryption fails."""
23-
24-
2521
class BaseEncryptionWrapper(BaseWrapper):
2622
"""Wrapper that encrypts values before storing and decrypts on retrieval.
2723
@@ -90,36 +86,51 @@ def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
9086
_ENCRYPTION_VERSION_KEY: self.encryption_version,
9187
}
9288

89+
def _validate_encrypted_payload(self, value: dict[str, Any]) -> tuple[int, str]:
90+
if _ENCRYPTION_VERSION_KEY not in value:
91+
msg = "missing encryption version key"
92+
raise CorruptedDataError(msg)
93+
94+
encryption_version = value[_ENCRYPTION_VERSION_KEY]
95+
if not isinstance(encryption_version, int):
96+
msg = f"expected encryption version to be an int, got {type(encryption_version)}"
97+
raise CorruptedDataError(msg)
98+
99+
if _ENCRYPTED_DATA_KEY not in value:
100+
msg = "missing encrypted data key"
101+
raise CorruptedDataError(msg)
102+
103+
encrypted_data = value[_ENCRYPTED_DATA_KEY]
104+
105+
if not isinstance(encrypted_data, str):
106+
msg = f"expected encrypted data to be a str, got {type(encrypted_data)}"
107+
raise CorruptedDataError(msg)
108+
109+
return encryption_version, encrypted_data
110+
93111
def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
94112
"""Decrypt a value from the encrypted format."""
95113
if value is None:
96114
return None
97115

116+
# If the value is not actually encrypted, return it as-is
98117
if _ENCRYPTED_DATA_KEY not in value and isinstance(value, dict): # pyright: ignore[reportUnnecessaryIsInstance]
99118
return value
100119

101-
base64_str = value[_ENCRYPTED_DATA_KEY]
102-
if not isinstance(base64_str, str):
103-
msg = f"Corrupted data: expected str, got {type(base64_str)}"
104-
raise CorruptedEncryptionDataError(msg)
105-
106-
if _ENCRYPTION_VERSION_KEY not in value:
107-
msg = "Corrupted data: missing encryption version"
108-
raise CorruptedEncryptionDataError(msg)
109-
110-
encryption_version = value[_ENCRYPTION_VERSION_KEY]
111-
if not isinstance(encryption_version, int):
112-
msg = f"Corrupted data: expected int, got {type(encryption_version)}"
113-
raise CorruptedEncryptionDataError(msg)
114-
115120
try:
116-
encrypted_bytes: bytes = base64.b64decode(base64_str, validate=True)
121+
encryption_version, encrypted_data = self._validate_encrypted_payload(value)
122+
123+
encrypted_bytes: bytes = base64.b64decode(encrypted_data, validate=True)
117124

118125
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)
119126

120127
json_str: str = json_bytes.decode(encoding="utf-8")
121128

122129
return json.loads(json_str) # type: ignore[no-any-return]
130+
except CorruptedDataError:
131+
if self.raise_on_decryption_error:
132+
raise
133+
return None
123134
except Exception as e:
124135
msg = "Failed to decrypt value"
125136
if self.raise_on_decryption_error:

key-value/key-value-aio/src/key_value/aio/wrappers/encryption/fernet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def encrypt_with_fernet(data: bytes) -> bytes:
7676

7777
def decrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
7878
if encryption_version > self.encryption_version:
79-
msg = f"Decryption failed: encryption version {encryption_version} is not supported"
79+
msg = f"Decryption failed: encryption versions newer than {self.encryption_version} are not supported"
8080
raise EncryptionVersionError(msg)
8181
return fernet.decrypt(data)
8282

key-value/key-value-shared/src/key_value/shared/errors/wrappers/encryption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@ class EncryptionVersionError(EncryptionError):
1313
"""Exception raised when the encryption version is not supported."""
1414

1515

16-
class CorruptedEncryptionDataError(EncryptionError):
16+
class CorruptedDataError(DecryptionError):
1717
"""Exception raised when the encrypted data is corrupted."""

key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/encryption/base.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, SupportsFloat
88

99
from key_value.shared.errors.key_value import SerializationError
10-
from key_value.shared.errors.wrappers.encryption import CorruptedEncryptionDataError, DecryptionError
10+
from key_value.shared.errors.wrappers.encryption import CorruptedDataError, DecryptionError, EncryptionError
1111
from typing_extensions import override
1212

1313
from key_value.sync.code_gen.protocols.key_value import KeyValue
@@ -20,10 +20,6 @@
2020
DecryptionFn = Callable[[bytes, int], bytes]
2121

2222

23-
class EncryptionError(Exception):
24-
"""Exception raised when encryption or decryption fails."""
25-
26-
2723
class BaseEncryptionWrapper(BaseWrapper):
2824
"""Wrapper that encrypts values before storing and decrypts on retrieval.
2925
@@ -89,36 +85,51 @@ def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
8985

9086
return {_ENCRYPTED_DATA_KEY: base64_str, _ENCRYPTION_VERSION_KEY: self.encryption_version}
9187

88+
def _validate_encrypted_payload(self, value: dict[str, Any]) -> tuple[int, str]:
89+
if _ENCRYPTION_VERSION_KEY not in value:
90+
msg = "missing encryption version key"
91+
raise CorruptedDataError(msg)
92+
93+
encryption_version = value[_ENCRYPTION_VERSION_KEY]
94+
if not isinstance(encryption_version, int):
95+
msg = f"expected encryption version to be an int, got {type(encryption_version)}"
96+
raise CorruptedDataError(msg)
97+
98+
if _ENCRYPTED_DATA_KEY not in value:
99+
msg = "missing encrypted data key"
100+
raise CorruptedDataError(msg)
101+
102+
encrypted_data = value[_ENCRYPTED_DATA_KEY]
103+
104+
if not isinstance(encrypted_data, str):
105+
msg = f"expected encrypted data to be a str, got {type(encrypted_data)}"
106+
raise CorruptedDataError(msg)
107+
108+
return (encryption_version, encrypted_data)
109+
92110
def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
93111
"""Decrypt a value from the encrypted format."""
94112
if value is None:
95113
return None
96114

115+
# If the value is not actually encrypted, return it as-is
97116
if _ENCRYPTED_DATA_KEY not in value and isinstance(value, dict): # pyright: ignore[reportUnnecessaryIsInstance]
98117
return value
99118

100-
base64_str = value[_ENCRYPTED_DATA_KEY]
101-
if not isinstance(base64_str, str):
102-
msg = f"Corrupted data: expected str, got {type(base64_str)}"
103-
raise CorruptedEncryptionDataError(msg)
104-
105-
if _ENCRYPTION_VERSION_KEY not in value:
106-
msg = "Corrupted data: missing encryption version"
107-
raise CorruptedEncryptionDataError(msg)
108-
109-
encryption_version = value[_ENCRYPTION_VERSION_KEY]
110-
if not isinstance(encryption_version, int):
111-
msg = f"Corrupted data: expected int, got {type(encryption_version)}"
112-
raise CorruptedEncryptionDataError(msg)
113-
114119
try:
115-
encrypted_bytes: bytes = base64.b64decode(base64_str, validate=True)
120+
(encryption_version, encrypted_data) = self._validate_encrypted_payload(value)
121+
122+
encrypted_bytes: bytes = base64.b64decode(encrypted_data, validate=True)
116123

117124
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)
118125

119126
json_str: str = json_bytes.decode(encoding="utf-8")
120127

121128
return json.loads(json_str) # type: ignore[no-any-return]
129+
except CorruptedDataError:
130+
if self.raise_on_decryption_error:
131+
raise
132+
return None
122133
except Exception as e:
123134
msg = "Failed to decrypt value"
124135
if self.raise_on_decryption_error:

key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/encryption/fernet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def encrypt_with_fernet(data: bytes) -> bytes:
6666

6767
def decrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
6868
if encryption_version > self.encryption_version:
69-
msg = f"Decryption failed: encryption version {encryption_version} is not supported"
69+
msg = f"Decryption failed: encryption versions newer than {self.encryption_version} are not supported"
7070
raise EncryptionVersionError(msg)
7171
return fernet.decrypt(data)
7272

0 commit comments

Comments
 (0)