Skip to content

Commit d62a400

Browse files
committed
Fixes for store tests
1 parent 194bc94 commit d62a400

File tree

16 files changed

+193
-47
lines changed

16 files changed

+193
-47
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,17 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
7474

7575
default_collection: str
7676

77-
def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None:
77+
def __init__(
78+
self,
79+
*,
80+
serialization_adapter: SerializationAdapter | None = None,
81+
default_collection: str | None = None,
82+
seed: SEED_DATA_TYPE | None = None,
83+
) -> None:
7884
"""Initialize the managed key-value store.
7985
8086
Args:
87+
serialization_adapter: The serialization adapter to use for the store.
8188
default_collection: The default collection to use if no collection is provided.
8289
Defaults to "default_collection".
8390
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
@@ -94,7 +101,7 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP
94101

95102
self.default_collection = default_collection or DEFAULT_COLLECTION_NAME
96103

97-
self._serialization_adapter = BasicSerializationAdapter()
104+
self._serialization_adapter = serialization_adapter or BasicSerializationAdapter()
98105

99106
if not hasattr(self, "_stable_api"):
100107
self._stable_api = False

key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from key_value.shared.utils.compound import compound_key
66
from key_value.shared.utils.managed_entry import ManagedEntry
7-
from key_value.shared.utils.serialization import BasicSerializationAdapter
87
from typing_extensions import override
98

109
from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseStore
@@ -45,11 +44,9 @@ def __init__(
4544
port: Memcached port. Defaults to 11211.
4645
default_collection: The default collection to use if no collection is provided.
4746
"""
48-
super().__init__(default_collection=default_collection)
49-
5047
self._client = client or Client(host=host, port=port)
5148

52-
self._serialization_adapter = BasicSerializationAdapter(value_format="dict")
49+
super().__init__(default_collection=default_collection)
5350

5451
def sanitize_key(self, key: str) -> str:
5552
if len(key) > MAX_KEY_LENGTH:

key-value/key-value-aio/src/key_value/aio/stores/memory/store.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,6 @@ def __init__(
117117

118118
self._stable_api = True
119119

120-
self._serialization_adapter = BasicSerializationAdapter()
121-
122120
super().__init__(default_collection=default_collection, seed=seed)
123121

124122
@override

key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from datetime import datetime, timezone
33
from typing import Any, overload
44

5-
from key_value.shared.errors import DeserializationError
5+
from bson.errors import InvalidDocument
6+
from key_value.shared.errors import DeserializationError, SerializationError
67
from key_value.shared.utils.managed_entry import ManagedEntry
78
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string
89
from key_value.shared.utils.serialization import SerializationAdapter
@@ -268,17 +269,22 @@ async def _put_managed_entry(
268269

269270
sanitized_collection = self._sanitize_collection_name(collection=collection)
270271

271-
_ = await self._collections_by_name[sanitized_collection].update_one(
272-
filter={"key": key},
273-
update={
274-
"$set": {
275-
"collection": collection,
276-
"key": key,
277-
**mongo_doc,
278-
}
279-
},
280-
upsert=True,
281-
)
272+
try:
273+
# Ensure that the value is serializable to JSON
274+
_ = managed_entry.value_as_json
275+
_ = await self._collections_by_name[sanitized_collection].update_one(
276+
filter={"key": key},
277+
update={
278+
"$set": {
279+
"key": key,
280+
**mongo_doc,
281+
}
282+
},
283+
upsert=True,
284+
)
285+
except InvalidDocument as e:
286+
msg = f"Failed to update MongoDB document: {e}"
287+
raise SerializationError(message=msg) from e
282288

283289
@override
284290
async def _put_managed_entries(

key-value/key-value-aio/src/key_value/aio/stores/simple/store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ def __init__(self, max_entries: int | None = None, default_collection: str | Non
5858
default_collection: The default collection to use if no collection is provided.
5959
"""
6060

61-
super().__init__(default_collection=default_collection)
62-
6361
self.max_entries = max_entries or sys.maxsize
6462

6563
self._data = defaultdict[str, SimpleStoreEntry]()
6664

6765
self._serialization_adapter = ValueOnlySerializationAdapter()
6866

67+
super().__init__(default_collection=default_collection)
68+
6969
@override
7070
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
7171
combo_key: str = compound_key(collection=collection, key=key)

key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def test_managed_entry_document_conversion_native_mode():
5656

5757
assert document == snapshot(
5858
{
59-
"key": "test",
6059
"value": {"object": {"test": "test"}},
6160
"created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc),
6261
"expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc),
@@ -81,7 +80,6 @@ def test_managed_entry_document_conversion_legacy_mode():
8180

8281
assert document == snapshot(
8382
{
84-
"key": "test",
8583
"value": {"string": '{"test": "test"}'},
8684
"created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc),
8785
"expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc),

key-value/key-value-aio/tests/stores/valkey/test_valkey.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class ValkeyFailedToStartError(Exception):
3636
pass
3737

3838

39+
def get_valkey_client_from_store(store: ValkeyStore) -> BaseClient:
40+
return store._connected_client # pyright: ignore[reportPrivateUsage, reportReturnType]
41+
42+
3943
@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running")
4044
@pytest.mark.skipif(detect_on_windows(), reason="Valkey is not supported on Windows")
4145
class TestValkeyStore(ContextManagerStoreTestMixin, BaseStoreTests):
@@ -94,9 +98,10 @@ async def valkey_client(self, store: ValkeyStore):
9498
@override
9599
async def test_not_unbounded(self, store: BaseStore): ...
96100

97-
async def test_value_stored(self, store: ValkeyStore, valkey_client: BaseClient):
101+
async def test_value_stored(self, store: ValkeyStore):
98102
await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30})
99103

104+
valkey_client = get_valkey_client_from_store(store=store)
100105
value = await valkey_client.get(key="test::test_key")
101106
assert value is not None
102107
value_as_dict = json.loads(value.decode("utf-8"))

key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ def load_from_json(json_str: str) -> dict[str, Any]:
8585
@bear_enforce
8686
def verify_dict(obj: Any) -> dict[str, Any]:
8787
if not isinstance(obj, Mapping):
88-
msg = "Object is not a dictionary"
89-
raise DeserializationError(msg)
88+
msg = "Object is not a Mapping"
89+
raise TypeError(msg)
9090

9191
if not all(isinstance(key, str) for key in obj): # pyright: ignore[reportUnknownVariableType]
9292
msg = "Object contains non-string keys"
93-
raise DeserializationError(msg)
93+
raise TypeError(msg)
9494

9595
return dict(obj) # pyright: ignore[reportUnknownArgumentType]
9696

key-value/key-value-shared/src/key_value/shared/utils/serialization.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from key_value.shared.errors import DeserializationError
1313
from key_value.shared.utils.managed_entry import ManagedEntry, dump_to_json, load_from_json, verify_dict
14-
from key_value.shared.utils.time_to_live import try_parse_datetime_str
1514

1615
T = TypeVar("T")
1716

@@ -25,6 +24,14 @@ def key_must_be(dictionary: dict[str, Any], /, key: str, expected_type: type[T])
2524
return dictionary[key]
2625

2726

27+
def parse_datetime_str(value: str) -> datetime:
28+
try:
29+
return datetime.fromisoformat(value)
30+
except ValueError:
31+
msg = f"Invalid datetime string: {value}"
32+
raise DeserializationError(message=msg) from None
33+
34+
2835
class SerializationAdapter(ABC):
2936
"""Base class for store-specific serialization adapters.
3037
@@ -43,7 +50,7 @@ def __init__(
4350
self._value_format = value_format
4451

4552
def load_json(self, json_str: str) -> ManagedEntry:
46-
"""Convert a JSON string to a dictionary."""
53+
"""Convert a JSON string to a ManagedEntry."""
4754
loaded_data: dict[str, Any] = load_from_json(json_str=json_str)
4855

4956
return self.load_dict(data=loaded_data)
@@ -63,9 +70,9 @@ def load_dict(self, data: dict[str, Any]) -> ManagedEntry:
6370

6471
if self._date_format == "isoformat":
6572
if created_at := key_must_be(data, key="created_at", expected_type=str):
66-
managed_entry_proto["created_at"] = try_parse_datetime_str(value=created_at)
73+
managed_entry_proto["created_at"] = parse_datetime_str(value=created_at)
6774
if expires_at := key_must_be(data, key="expires_at", expected_type=str):
68-
managed_entry_proto["expires_at"] = try_parse_datetime_str(value=expires_at)
75+
managed_entry_proto["expires_at"] = parse_datetime_str(value=expires_at)
6976

7077
if self._date_format == "datetime":
7178
if created_at := key_must_be(data, key="created_at", expected_type=datetime):
@@ -104,10 +111,16 @@ def dump_dict(self, entry: ManagedEntry, exclude_none: bool = True) -> dict[str,
104111

105112
data: dict[str, Any] = {
106113
"value": entry.value_as_dict if self._value_format == "dict" else entry.value_as_json,
107-
"created_at": entry.created_at_isoformat,
108-
"expires_at": entry.expires_at_isoformat,
109114
}
110115

116+
if self._date_format == "isoformat":
117+
data["created_at"] = entry.created_at_isoformat
118+
data["expires_at"] = entry.expires_at_isoformat
119+
120+
if self._date_format == "datetime":
121+
data["created_at"] = entry.created_at
122+
data["expires_at"] = entry.expires_at
123+
111124
if exclude_none:
112125
data = {k: v for k, v in data.items() if v is not None}
113126

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from datetime import datetime, timedelta, timezone
2+
3+
import pytest
4+
from inline_snapshot import snapshot
5+
6+
from key_value.shared.utils.managed_entry import ManagedEntry
7+
from key_value.shared.utils.serialization import BasicSerializationAdapter, ValueOnlySerializationAdapter
8+
9+
FIXED_DATETIME_ONE = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
10+
FIXED_DATETIME_ONE_ISOFORMAT = FIXED_DATETIME_ONE.isoformat()
11+
FIXED_DATETIME_ONE_PLUS_10_SECONDS = FIXED_DATETIME_ONE + timedelta(seconds=10)
12+
FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT = FIXED_DATETIME_ONE_PLUS_10_SECONDS.isoformat()
13+
14+
FIXED_DATETIME_TWO = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc)
15+
FIXED_DATETIME_TWO_PLUS_10_SECONDS = FIXED_DATETIME_TWO + timedelta(seconds=10)
16+
FIXED_DATETIME_TWO_ISOFORMAT = FIXED_DATETIME_TWO.isoformat()
17+
FIXED_DATETIME_TWO_PLUS_10_SECONDS_ISOFORMAT = FIXED_DATETIME_TWO_PLUS_10_SECONDS.isoformat()
18+
19+
TEST_DATA_ONE = {"key_one": "value_one", "key_two": "value_two", "key_three": {"nested_key": "nested_value"}}
20+
TEST_ENTRY_ONE = ManagedEntry(value=TEST_DATA_ONE, created_at=FIXED_DATETIME_ONE, expires_at=FIXED_DATETIME_ONE_PLUS_10_SECONDS)
21+
TEST_DATA_TWO = {"key_one": ["value_one", "value_two", "value_three"], "key_two": 123, "key_three": {"nested_key": "nested_value"}}
22+
TEST_ENTRY_TWO = ManagedEntry(value=TEST_DATA_TWO, created_at=FIXED_DATETIME_TWO, expires_at=FIXED_DATETIME_TWO_PLUS_10_SECONDS)
23+
24+
25+
@pytest.fixture
26+
def serialization_adapter() -> BasicSerializationAdapter:
27+
return BasicSerializationAdapter()
28+
29+
30+
class TestBasicSerializationAdapter:
31+
@pytest.fixture
32+
def adapter(self) -> BasicSerializationAdapter:
33+
return BasicSerializationAdapter()
34+
35+
def test_entry_one(self, adapter: BasicSerializationAdapter):
36+
assert adapter.dump_dict(entry=TEST_ENTRY_ONE) == snapshot(
37+
{
38+
"value": TEST_DATA_ONE,
39+
"created_at": FIXED_DATETIME_ONE_ISOFORMAT,
40+
"expires_at": FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT,
41+
}
42+
)
43+
44+
assert adapter.dump_json(entry=TEST_ENTRY_ONE) == snapshot(
45+
'{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"key_one": "value_one", "key_three": {"nested_key": "nested_value"}, "key_two": "value_two"}}'
46+
)
47+
48+
assert adapter.load_dict(data=adapter.dump_dict(entry=TEST_ENTRY_ONE)) == snapshot(TEST_ENTRY_ONE)
49+
assert adapter.load_json(json_str=adapter.dump_json(entry=TEST_ENTRY_ONE)) == snapshot(TEST_ENTRY_ONE)
50+
51+
def test_entry_two(self, adapter: BasicSerializationAdapter):
52+
assert adapter.dump_dict(entry=TEST_ENTRY_TWO) == snapshot(
53+
{
54+
"value": TEST_DATA_TWO,
55+
"created_at": FIXED_DATETIME_TWO_ISOFORMAT,
56+
"expires_at": FIXED_DATETIME_TWO_PLUS_10_SECONDS_ISOFORMAT,
57+
}
58+
)
59+
60+
assert adapter.dump_json(entry=TEST_ENTRY_TWO) == snapshot(
61+
'{"created_at": "2025-01-01T00:00:01+00:00", "expires_at": "2025-01-01T00:00:11+00:00", "value": {"key_one": ["value_one", "value_two", "value_three"], "key_three": {"nested_key": "nested_value"}, "key_two": 123}}'
62+
)
63+
64+
assert adapter.load_dict(data=adapter.dump_dict(entry=TEST_ENTRY_TWO)) == snapshot(TEST_ENTRY_TWO)
65+
assert adapter.load_json(json_str=adapter.dump_json(entry=TEST_ENTRY_TWO)) == snapshot(TEST_ENTRY_TWO)
66+
67+
68+
class TestValueOnlySerializationAdapter:
69+
@pytest.fixture
70+
def adapter(self) -> ValueOnlySerializationAdapter:
71+
return ValueOnlySerializationAdapter()
72+
73+
def test_entry_one(self, adapter: ValueOnlySerializationAdapter):
74+
assert adapter.dump_dict(entry=TEST_ENTRY_ONE) == snapshot(
75+
{"key_one": "value_one", "key_two": "value_two", "key_three": {"nested_key": "nested_value"}}
76+
)
77+
78+
assert adapter.dump_json(entry=TEST_ENTRY_ONE) == snapshot(
79+
'{"key_one": "value_one", "key_three": {"nested_key": "nested_value"}, "key_two": "value_two"}'
80+
)
81+
82+
assert adapter.load_json(json_str=adapter.dump_json(entry=TEST_ENTRY_ONE)) == snapshot(
83+
ManagedEntry(value={"key_one": "value_one", "key_three": {"nested_key": "nested_value"}, "key_two": "value_two"})
84+
)
85+
assert adapter.load_dict(data=adapter.dump_dict(entry=TEST_ENTRY_ONE)) == snapshot(
86+
ManagedEntry(value={"key_one": "value_one", "key_two": "value_two", "key_three": {"nested_key": "nested_value"}})
87+
)
88+
89+
def test_entry_two(self, adapter: ValueOnlySerializationAdapter):
90+
assert adapter.dump_dict(entry=TEST_ENTRY_TWO) == snapshot(
91+
{"key_one": ["value_one", "value_two", "value_three"], "key_two": 123, "key_three": {"nested_key": "nested_value"}}
92+
)
93+
94+
assert adapter.dump_json(entry=TEST_ENTRY_TWO) == snapshot(
95+
'{"key_one": ["value_one", "value_two", "value_three"], "key_three": {"nested_key": "nested_value"}, "key_two": 123}'
96+
)
97+
98+
assert adapter.load_json(json_str=adapter.dump_json(entry=TEST_ENTRY_TWO)) == snapshot(
99+
ManagedEntry(
100+
value={"key_one": ["value_one", "value_two", "value_three"], "key_three": {"nested_key": "nested_value"}, "key_two": 123}
101+
)
102+
)
103+
assert adapter.load_dict(data=adapter.dump_dict(entry=TEST_ENTRY_TWO)) == snapshot(
104+
ManagedEntry(
105+
value={"key_one": ["value_one", "value_two", "value_three"], "key_two": 123, "key_three": {"nested_key": "nested_value"}}
106+
)
107+
)

0 commit comments

Comments
 (0)