Skip to content

Commit f60a8d9

Browse files
committed
Updates from PR Feedback
1 parent f1d9df5 commit f60a8d9

File tree

9 files changed

+125
-47
lines changed

9 files changed

+125
-47
lines changed

key-value/key-value-aio/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ wrappers-encryption = ["cryptography>=45.0.0"]
4949

5050
[tool.pytest.ini_options]
5151
asyncio_mode = "auto"
52-
addopts = ["--inline-snapshot=fix,create"]
52+
addopts = ["--inline-snapshot=disable"]
5353
markers = [
5454
"skip_on_ci: Skip running the test when running on CI",
5555
]

key-value/key-value-aio/src/key_value/aio/wrappers/encryption/base.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
_ENCRYPTION_VERSION_KEY = "__encryption_version__"
1616
_ENCRYPTION_VERSION = 1
1717

18-
EncryptionFn = Callable[[bytes], bytes]
19-
DecryptionFn = Callable[[bytes], bytes]
18+
EncryptionFn = Callable[[bytes, int], bytes]
19+
DecryptionFn = Callable[[bytes, int], bytes]
2020

2121

2222
class EncryptionError(Exception):
@@ -45,18 +45,23 @@ def __init__(
4545
key_value: AsyncKeyValue,
4646
encryption_fn: EncryptionFn,
4747
decryption_fn: DecryptionFn,
48+
encryption_version: int,
4849
raise_on_decryption_error: bool = True,
4950
) -> None:
5051
"""Initialize the encryption wrapper.
5152
5253
Args:
5354
key_value: The store to wrap.
54-
encryption_fn: The encryption function to use.
55-
decryption_fn: The decryption function to use.
55+
encryption_fn: The encryption function to use. A callable that takes bytes and an
56+
encryption version int and returns encrypted bytes.
57+
decryption_fn: The decryption function to use. A callable that takes bytes and an
58+
encryption version int and returns decrypted bytes.
59+
encryption_version: The encryption version to use.
5660
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
5761
"""
5862
self.key_value: AsyncKeyValue = key_value
5963
self.raise_on_decryption_error: bool = raise_on_decryption_error
64+
self.encryption_version: int = encryption_version
6065

6166
self._encryption_fn: EncryptionFn = encryption_fn
6267
self._decryption_fn: DecryptionFn = decryption_fn
@@ -65,9 +70,6 @@ def __init__(
6570

6671
def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
6772
"""Encrypt a value into the encrypted format."""
68-
# # Don't encrypt if it's already encrypted
69-
# if _ENCRYPTED_DATA_KEY in value:
70-
# return value
7173

7274
# Serialize to JSON
7375
try:
@@ -79,14 +81,14 @@ def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
7981
json_bytes: bytes = json_str.encode(encoding="utf-8")
8082

8183
# Encrypt with Fernet
82-
encrypted_bytes: bytes = self._encryption_fn(json_bytes)
84+
encrypted_bytes: bytes = self._encryption_fn(json_bytes, self.encryption_version)
8385

8486
# Encode to base64 for storage in dict (though Fernet output is already base64)
8587
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
8688

8789
return {
8890
_ENCRYPTED_DATA_KEY: base64_str,
89-
_ENCRYPTION_VERSION_KEY: _ENCRYPTION_VERSION,
91+
_ENCRYPTION_VERSION_KEY: self.encryption_version,
9092
}
9193

9294
def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
@@ -105,12 +107,18 @@ def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
105107
msg = f"Corrupted data: expected str, got {type(base64_str)}"
106108
raise TypeError(msg)
107109

110+
encryption_version = value[_ENCRYPTION_VERSION_KEY]
111+
if not isinstance(encryption_version, int):
112+
# Corrupted data, return as-is
113+
msg = f"Corrupted data: expected int, got {type(encryption_version)}"
114+
raise TypeError(msg)
115+
108116
try:
109117
# Decode from base64
110118
encrypted_bytes: bytes = base64.b64decode(base64_str)
111119

112120
# Decrypt with Fernet
113-
json_bytes: bytes = self._decryption_fn(encrypted_bytes)
121+
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)
114122

115123
# Parse JSON
116124
json_str: str = json_bytes.decode(encoding="utf-8")

key-value/key-value-aio/src/key_value/aio/wrappers/encryption/fernet.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from cryptography.fernet import Fernet
2+
from key_value.shared.errors.wrappers.encryption import EncryptionVersionError
23
from typing_extensions import overload
34

45
from key_value.aio.protocols.key_value import AsyncKeyValue
56
from key_value.aio.wrappers.encryption.base import BaseEncryptionWrapper
67

8+
ENCRYPTION_VERSION = 1
9+
710

811
class FernetEncryptionWrapper(BaseEncryptionWrapper):
912
@overload
@@ -28,7 +31,7 @@ def __init__(
2831
key_value: AsyncKeyValue,
2932
*,
3033
source_material: str,
31-
salt: str | None = None,
34+
salt: str,
3235
raise_on_decryption_error: bool = True,
3336
) -> None:
3437
"""Initialize the Fernet encryption wrapper.
@@ -57,31 +60,43 @@ def __init__(
5760
if source_material is None:
5861
msg = "Must provide either fernet or source_material"
5962
raise ValueError(msg)
63+
if salt is None:
64+
msg = "Must provide a salt"
65+
raise ValueError(msg)
6066
fernet = Fernet(key=_generate_encryption_key(source_material=source_material, salt=salt))
6167

62-
def encrypt_with_fernet(data: bytes) -> bytes:
68+
def encrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
69+
if encryption_version > self.encryption_version:
70+
msg = f"Encryption failed: encryption version {encryption_version} is not supported"
71+
raise EncryptionVersionError(msg)
6372
return fernet.encrypt(data)
6473

65-
def decrypt_with_fernet(data: bytes) -> bytes:
74+
def decrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
75+
if encryption_version > self.encryption_version:
76+
msg = f"Decryption failed: encryption version {encryption_version} is not supported"
77+
raise EncryptionVersionError(msg)
6678
return fernet.decrypt(data)
6779

6880
super().__init__(
6981
key_value=key_value,
7082
encryption_fn=encrypt_with_fernet,
7183
decryption_fn=decrypt_with_fernet,
84+
encryption_version=ENCRYPTION_VERSION,
7285
raise_on_decryption_error=raise_on_decryption_error,
7386
)
7487

7588

76-
def _generate_encryption_key(source_material: str, salt: str | None = None) -> bytes:
89+
def _generate_encryption_key(source_material: str, salt: str) -> bytes:
90+
import base64
91+
7792
from cryptography.hazmat.primitives import hashes
7893
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
7994

80-
salt = salt or "py-key-value-salt"
81-
82-
return HKDF(
95+
derived_key = HKDF(
8396
algorithm=hashes.SHA256(),
8497
length=32,
8598
salt=salt.encode(),
8699
info=b"Fernet",
87100
).derive(key_material=source_material.encode())
101+
102+
return base64.urlsafe_b64encode(derived_key)

key-value/key-value-aio/tests/stores/wrappers/test_encryption.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from cryptography.fernet import Fernet
3+
from dirty_equals import IsStr
34
from inline_snapshot import snapshot
45
from key_value.shared.errors.wrappers.encryption import DecryptionError
56
from typing_extensions import override
@@ -40,6 +41,18 @@ async def test_encryption_encrypts_value(self, store: FernetEncryptionWrapper, m
4041
result = await store.get(collection="test", key="test")
4142
assert result == original_value
4243

44+
async def test_encryption_with_wrong_encryption_version(self, store: FernetEncryptionWrapper):
45+
"""Test that encryption fails with the wrong encryption version."""
46+
store.encryption_version = 2
47+
original_value = {"test": "value"}
48+
await store.put(collection="test", key="test", value=original_value)
49+
50+
assert await store.get(collection="test", key="test") is not None
51+
store.encryption_version = 1
52+
53+
with pytest.raises(DecryptionError):
54+
await store.get(collection="test", key="test")
55+
4356
async def test_encryption_with_string_key(self, store: FernetEncryptionWrapper, memory_store: MemoryStore):
4457
"""Test that encryption works with a string key."""
4558
original_value = {"test": "value"}
@@ -51,7 +64,7 @@ async def test_encryption_with_string_key(self, store: FernetEncryptionWrapper,
5164
raw_result = await memory_store.get(collection="test", key="test")
5265
assert raw_result == snapshot(
5366
{
54-
"__encrypted_data__": "Z0FBQUFBQm8tc0ZsYWhZUmJqUnN0VGlyeGVoUWZuczlPUllyWWxyVEotTVNMVFMtd1hoalNTQk56eFdzNGVocEg0T0xDeEVkTHpJckc2Z0lGZGpCTWZpS3o3cmVWRmRUTl91RENvSW8zNnI3QTlJVmtrQ1FtNnc9",
67+
"__encrypted_data__": IsStr(min_length=32),
5568
"__encryption_version__": 1,
5669
}
5770
)
@@ -109,7 +122,7 @@ async def test_decryption_ignores_corrupted_data(self, memory_store: MemoryStore
109122

110123
assert await store.get(collection="test", key="test") is None
111124

112-
async def test_decryption_with_wrong_key_returns_original(self, memory_store: MemoryStore):
125+
async def test_decryption_with_wrong_key_raises_error(self, memory_store: MemoryStore):
113126
"""Test that decryption with the wrong key raises an error."""
114127
fernet1 = Fernet(key=Fernet.generate_key())
115128
fernet2 = Fernet(key=Fernet.generate_key())

key-value/key-value-shared/src/key_value/shared/errors/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class BaseKeyValueError(Exception):
55
"""Base exception for all KV Store Adapter errors."""
66

77
extra_info: ExtraInfoType | None = None
8+
message: str | None = None
89

910
def __init__(self, message: str | None = None, extra_info: ExtraInfoType | None = None):
1011
message_parts: list[str] = []
@@ -19,6 +20,8 @@ def __init__(self, message: str | None = None, extra_info: ExtraInfoType | None
1920

2021
message_parts.append(extra_info_str)
2122

22-
super().__init__(": ".join(message_parts))
23+
self.message = ": ".join(message_parts)
24+
25+
super().__init__(self.message)
2326

2427
self.extra_info = extra_info

key-value/key-value-shared/src/key_value/shared/errors/wrappers/encryption.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ class EncryptionError(KeyValueOperationError):
77

88
class DecryptionError(EncryptionError):
99
"""Exception raised when decryption fails."""
10+
11+
12+
class EncryptionVersionError(EncryptionError):
13+
"""Exception raised when the encryption version is not supported."""

key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/encryption/base.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
_ENCRYPTION_VERSION_KEY = "__encryption_version__"
1919
_ENCRYPTION_VERSION = 1
2020

21-
EncryptionFn = Callable[[bytes], bytes]
22-
DecryptionFn = Callable[[bytes], bytes]
21+
EncryptionFn = Callable[[bytes, int], bytes]
22+
DecryptionFn = Callable[[bytes, int], bytes]
2323

2424

2525
class EncryptionError(Exception):
@@ -44,18 +44,27 @@ class BaseEncryptionWrapper(BaseWrapper):
4444
"""
4545

4646
def __init__(
47-
self, key_value: KeyValue, encryption_fn: EncryptionFn, decryption_fn: DecryptionFn, raise_on_decryption_error: bool = True
47+
self,
48+
key_value: KeyValue,
49+
encryption_fn: EncryptionFn,
50+
decryption_fn: DecryptionFn,
51+
encryption_version: int,
52+
raise_on_decryption_error: bool = True,
4853
) -> None:
4954
"""Initialize the encryption wrapper.
5055
5156
Args:
5257
key_value: The store to wrap.
53-
encryption_fn: The encryption function to use.
54-
decryption_fn: The decryption function to use.
58+
encryption_fn: The encryption function to use. A callable that takes bytes and an
59+
encryption version int and returns encrypted bytes.
60+
decryption_fn: The decryption function to use. A callable that takes bytes and an
61+
encryption version int and returns decrypted bytes.
62+
encryption_version: The encryption version to use.
5563
raise_on_decryption_error: Whether to raise an exception if decryption fails. Defaults to True.
5664
"""
5765
self.key_value: KeyValue = key_value
5866
self.raise_on_decryption_error: bool = raise_on_decryption_error
67+
self.encryption_version: int = encryption_version
5968

6069
self._encryption_fn: EncryptionFn = encryption_fn
6170
self._decryption_fn: DecryptionFn = decryption_fn
@@ -64,9 +73,6 @@ def __init__(
6473

6574
def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
6675
"""Encrypt a value into the encrypted format."""
67-
# # Don't encrypt if it's already encrypted
68-
# if _ENCRYPTED_DATA_KEY in value:
69-
# return value
7076

7177
# Serialize to JSON
7278
try:
@@ -78,12 +84,12 @@ def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
7884
json_bytes: bytes = json_str.encode(encoding="utf-8")
7985

8086
# Encrypt with Fernet
81-
encrypted_bytes: bytes = self._encryption_fn(json_bytes)
87+
encrypted_bytes: bytes = self._encryption_fn(json_bytes, self.encryption_version)
8288

8389
# Encode to base64 for storage in dict (though Fernet output is already base64)
8490
base64_str: str = base64.b64encode(encrypted_bytes).decode(encoding="ascii")
8591

86-
return {_ENCRYPTED_DATA_KEY: base64_str, _ENCRYPTION_VERSION_KEY: _ENCRYPTION_VERSION}
92+
return {_ENCRYPTED_DATA_KEY: base64_str, _ENCRYPTION_VERSION_KEY: self.encryption_version}
8793

8894
def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
8995
"""Decrypt a value from the encrypted format."""
@@ -101,12 +107,18 @@ def _decrypt_value(self, value: dict[str, Any] | None) -> dict[str, Any] | None:
101107
msg = f"Corrupted data: expected str, got {type(base64_str)}"
102108
raise TypeError(msg)
103109

110+
encryption_version = value[_ENCRYPTION_VERSION_KEY]
111+
if not isinstance(encryption_version, int):
112+
# Corrupted data, return as-is
113+
msg = f"Corrupted data: expected int, got {type(encryption_version)}"
114+
raise TypeError(msg)
115+
104116
try:
105117
# Decode from base64
106118
encrypted_bytes: bytes = base64.b64decode(base64_str)
107119

108120
# Decrypt with Fernet
109-
json_bytes: bytes = self._decryption_fn(encrypted_bytes)
121+
json_bytes: bytes = self._decryption_fn(encrypted_bytes, encryption_version)
110122

111123
# Parse JSON
112124
json_str: str = json_bytes.decode(encoding="utf-8")

key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/encryption/fernet.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
# from the original file 'fernet.py'
33
# DO NOT CHANGE! Change the original file instead.
44
from cryptography.fernet import Fernet
5+
from key_value.shared.errors.wrappers.encryption import EncryptionVersionError
56
from typing_extensions import overload
67

78
from key_value.sync.code_gen.protocols.key_value import KeyValue
89
from key_value.sync.code_gen.wrappers.encryption.base import BaseEncryptionWrapper
910

11+
ENCRYPTION_VERSION = 1
12+
1013

1114
class FernetEncryptionWrapper(BaseEncryptionWrapper):
1215
@overload
@@ -20,9 +23,7 @@ def __init__(self, key_value: KeyValue, *, fernet: Fernet, raise_on_decryption_e
2023
"""
2124

2225
@overload
23-
def __init__(
24-
self, key_value: KeyValue, *, source_material: str, salt: str | None = None, raise_on_decryption_error: bool = True
25-
) -> None:
26+
def __init__(self, key_value: KeyValue, *, source_material: str, salt: str, raise_on_decryption_error: bool = True) -> None:
2627
"""Initialize the Fernet encryption wrapper.
2728
2829
Args:
@@ -49,26 +50,40 @@ def __init__(
4950
if source_material is None:
5051
msg = "Must provide either fernet or source_material"
5152
raise ValueError(msg)
53+
if salt is None:
54+
msg = "Must provide a salt"
55+
raise ValueError(msg)
5256
fernet = Fernet(key=_generate_encryption_key(source_material=source_material, salt=salt))
5357

54-
def encrypt_with_fernet(data: bytes) -> bytes:
58+
def encrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
59+
if encryption_version > self.encryption_version:
60+
msg = f"Encryption failed: encryption version {encryption_version} is not supported"
61+
raise EncryptionVersionError(msg)
5562
return fernet.encrypt(data)
5663

57-
def decrypt_with_fernet(data: bytes) -> bytes:
64+
def decrypt_with_fernet(data: bytes, encryption_version: int) -> bytes:
65+
if encryption_version > self.encryption_version:
66+
msg = f"Decryption failed: encryption version {encryption_version} is not supported"
67+
raise EncryptionVersionError(msg)
5868
return fernet.decrypt(data)
5969

6070
super().__init__(
6171
key_value=key_value,
6272
encryption_fn=encrypt_with_fernet,
6373
decryption_fn=decrypt_with_fernet,
74+
encryption_version=ENCRYPTION_VERSION,
6475
raise_on_decryption_error=raise_on_decryption_error,
6576
)
6677

6778

68-
def _generate_encryption_key(source_material: str, salt: str | None = None) -> bytes:
79+
def _generate_encryption_key(source_material: str, salt: str) -> bytes:
80+
import base64
81+
6982
from cryptography.hazmat.primitives import hashes
7083
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
7184

72-
salt = salt or "py-key-value-salt"
85+
derived_key = HKDF(algorithm=hashes.SHA256(), length=32, salt=salt.encode(), info=b"Fernet").derive(
86+
key_material=source_material.encode()
87+
)
7388

74-
return HKDF(algorithm=hashes.SHA256(), length=32, salt=salt.encode(), info=b"Fernet").derive(key_material=source_material.encode())
89+
return base64.urlsafe_b64encode(derived_key)

0 commit comments

Comments
 (0)