Skip to content

Commit b3a84a6

Browse files
feat: add seed support, default values, and JSON serialization wrappers (#122)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent 8151b4f commit b3a84a6

File tree

17 files changed

+573
-42
lines changed

17 files changed

+573
-42
lines changed

key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Sequence
2-
from typing import Any, Generic, SupportsFloat, TypeVar, get_origin
2+
from typing import Any, Generic, SupportsFloat, TypeVar, get_origin, overload
33

44
from key_value.shared.errors import DeserializationError, SerializationError
55
from key_value.shared.type_checking.bear_spray import bear_spray
@@ -71,11 +71,22 @@ def _serialize_model(self, value: T) -> dict[str, Any]:
7171
msg = f"Invalid Pydantic model: {e}"
7272
raise SerializationError(msg) from e
7373

74-
async def get(self, key: str, *, collection: str | None = None) -> T | None:
74+
@overload
75+
async def get(self, key: str, *, collection: str | None = None, default: T) -> T: ...
76+
77+
@overload
78+
async def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ...
79+
80+
async def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None:
7581
"""Get and validate a model by key.
7682
83+
Args:
84+
key: The key to retrieve.
85+
collection: The collection to use. If not provided, uses the default collection.
86+
default: The default value to return if the key doesn't exist or validation fails.
87+
7788
Returns:
78-
The parsed model instance, or None if not present.
89+
The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails.
7990
8091
Raises:
8192
DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to
@@ -84,15 +95,28 @@ async def get(self, key: str, *, collection: str | None = None) -> T | None:
8495
collection = collection or self._default_collection
8596

8697
if value := await self._key_value.get(key=key, collection=collection):
87-
return self._validate_model(value=value)
98+
validated = self._validate_model(value=value)
99+
if validated is not None:
100+
return validated
88101

89-
return None
102+
return default
90103

91-
async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[T | None]:
104+
@overload
105+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ...
106+
107+
@overload
108+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ...
109+
110+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]:
92111
"""Batch get and validate models by keys, preserving order.
93112
113+
Args:
114+
keys: The list of keys to retrieve.
115+
collection: The collection to use. If not provided, uses the default collection.
116+
default: The default value to return for keys that don't exist or fail validation.
117+
94118
Returns:
95-
A list of parsed model instances, or None if missing.
119+
A list of parsed model instances, with default values for missing keys or validation failures.
96120
97121
Raises:
98122
DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to
@@ -102,7 +126,14 @@ async def get_many(self, keys: Sequence[str], *, collection: str | None = None)
102126

103127
values: list[dict[str, Any] | None] = await self._key_value.get_many(keys=keys, collection=collection)
104128

105-
return [self._validate_model(value=value) if value else None for value in values]
129+
result: list[T | None] = []
130+
for value in values:
131+
if value is None:
132+
result.append(default)
133+
else:
134+
validated = self._validate_model(value=value)
135+
result.append(validated if validated is not None else default)
136+
return result
106137

107138
async def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
108139
"""Serialize and store a model.

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from asyncio.locks import Lock
77
from collections import defaultdict
88
from collections.abc import Mapping, Sequence
9-
from types import TracebackType
9+
from types import MappingProxyType, TracebackType
1010
from typing import Any, SupportsFloat
1111

1212
from key_value.shared.constants import DEFAULT_COLLECTION_NAME
@@ -24,6 +24,16 @@
2424
AsyncKeyValueProtocol,
2525
)
2626

27+
# Shape of user-supplied seed data: {collection: {key: {field: value, ...}}}.
SEED_DATA_TYPE = Mapping[str, Mapping[str, Mapping[str, Any]]]
# Read-only counterpart of SEED_DATA_TYPE, with every nesting level wrapped in a MappingProxyType.
FROZEN_SEED_DATA_TYPE = MappingProxyType[str, MappingProxyType[str, MappingProxyType[str, Any]]]
# An empty, read-only seed.
DEFAULT_SEED_DATA: FROZEN_SEED_DATA_TYPE = MappingProxyType({})


def _seed_to_frozen_seed_data(seed: SEED_DATA_TYPE) -> FROZEN_SEED_DATA_TYPE:
    """Build a read-only snapshot of the seed data.

    Every nesting level (collections, keys, entry fields) is copied into a new dict and wrapped
    in a ``MappingProxyType``. Copying matters because a ``MappingProxyType`` is only a view of
    the mapping it wraps: without the copy, a caller mutating its own seed dict after
    constructing the store would silently change the "frozen" seed as well.

    Args:
        seed: Seed data in the form ``{collection: {key: {field: value, ...}}}``.

    Returns:
        A read-only mapping with the same three-level structure. The copy of each entry is
        shallow, so mutable objects stored as field values (lists, nested dicts) are still
        shared with the caller.
    """
    return MappingProxyType(
        {
            collection: MappingProxyType(
                # dict(value) detaches the entry from the caller's mapping before it is wrapped.
                {key: MappingProxyType(dict(value)) for key, value in items.items()}
            )
            for collection, items in seed.items()
        }
    )
36+
2737

2838
class BaseStore(AsyncKeyValueProtocol, ABC):
2939
"""An opinionated Abstract base class for managed key-value stores using ManagedEntry objects.
@@ -43,21 +53,28 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
4353
_setup_collection_locks: defaultdict[str, Lock]
4454
_setup_collection_complete: defaultdict[str, bool]
4555

56+
_seed: FROZEN_SEED_DATA_TYPE
57+
4658
default_collection: str
4759

48-
def __init__(self, *, default_collection: str | None = None) -> None:
60+
def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None:
4961
"""Initialize the managed key-value store.
5062
5163
Args:
5264
default_collection: The default collection to use if no collection is provided.
5365
Defaults to "default_collection".
66+
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
67+
Seeding occurs once during store initialization (when the store is first entered or when the
68+
first operation is performed on the store).
5469
"""
5570

5671
self._setup_complete = False
5772
self._setup_lock = Lock()
5873
self._setup_collection_locks = defaultdict(Lock)
5974
self._setup_collection_complete = defaultdict(bool)
6075

76+
self._seed = _seed_to_frozen_seed_data(seed=seed or {})
77+
6178
self.default_collection = default_collection or DEFAULT_COLLECTION_NAME
6279

6380
if not hasattr(self, "_stable_api"):
@@ -74,6 +91,13 @@ async def _setup(self) -> None:
7491
async def _setup_collection(self, *, collection: str) -> None:
7592
"""Initialize the collection (called once before first use of the collection)."""
7693

94+
async def _seed_store(self) -> None:
    """Write every entry of the configured seed data into the store.

    For each seeded collection, the collection is first prepared through ``setup_collection``;
    each of its entries is then copied into a plain ``dict`` and written with ``put``.
    """
    for collection_name in self._seed:
        await self.setup_collection(collection=collection_name)

        seeded_entries = self._seed[collection_name]
        for entry_key in seeded_entries:
            # The seed holds read-only mappings; put() is handed a regular dict copy.
            entry_value = dict(seeded_entries[entry_key])
            await self.put(key=entry_key, value=entry_value, collection=collection_name)
100+
77101
async def setup(self) -> None:
78102
if not self._setup_complete:
79103
async with self._setup_lock:
@@ -84,8 +108,11 @@ async def setup(self) -> None:
84108
raise StoreSetupError(
85109
message=f"Failed to setup key value store: {e}", extra_info={"store": self.__class__.__name__}
86110
) from e
111+
87112
self._setup_complete = True
88113

114+
await self._seed_store()
115+
89116
async def setup_collection(self, *, collection: str) -> None:
90117
await self.setup()
91118

key-value/key-value-aio/src/key_value/aio/stores/memory/store.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import Self, override
99

1010
from key_value.aio.stores.base import (
11+
SEED_DATA_TYPE,
1112
BaseDestroyCollectionStore,
1213
BaseDestroyStore,
1314
BaseEnumerateCollectionsStore,
@@ -109,12 +110,21 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol
109110

110111
_cache: dict[str, MemoryCollection]
111112

112-
def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, default_collection: str | None = None):
113+
def __init__(
114+
self,
115+
*,
116+
max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION,
117+
default_collection: str | None = None,
118+
seed: SEED_DATA_TYPE | None = None,
119+
):
113120
"""Initialize a fixed-size in-memory store.
114121
115122
Args:
116123
max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000.
117124
default_collection: The default collection to use if no collection is provided.
125+
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
126+
Each value must be a mapping (dict) that will be stored as the entry's value.
127+
Seeding occurs once, when the store is set up (on entering the store or on its first operation).
118128
"""
119129

120130
self.max_entries_per_collection = max_entries_per_collection
@@ -123,11 +133,25 @@ def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_
123133

124134
self._stable_api = True
125135

126-
super().__init__(default_collection=default_collection)
136+
super().__init__(default_collection=default_collection, seed=seed)
137+
138+
@override
async def _setup(self) -> None:
    """Run ``_setup_collection`` for every collection named in the seed data."""
    seeded_collection_names = iter(self._seed)
    for collection_name in seeded_collection_names:
        await self._setup_collection(collection=collection_name)
127142

128143
@override
async def _setup_collection(self, *, collection: str) -> None:
    """Create the in-memory cache for a collection if it does not exist yet.

    This method only allocates the ``MemoryCollection``; it does not write any seed data.
    Seed entries are written by the base store's ``_seed_store``.

    Args:
        collection: The collection name.
    """
    # Keep an existing collection (and its entries) instead of replacing it with an empty one.
    if collection in self._cache:
        return

    self._cache[collection] = MemoryCollection(max_entries=self.max_entries_per_collection)
131155

132156
@override
133157
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:

key-value/key-value-aio/src/key_value/aio/wrappers/default_value/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Default value wrapper for returning fallback values when keys are not found."""
2+
3+
from key_value.aio.wrappers.default_value.wrapper import DefaultValueWrapper
4+
5+
__all__ = ["DefaultValueWrapper"]

key-value/key-value-aio/src/key_value/aio/wrappers/default_value/wrapper.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from collections.abc import Mapping, Sequence
2+
from typing import Any, SupportsFloat
3+
4+
from key_value.shared.utils.managed_entry import dump_to_json, load_from_json
5+
from typing_extensions import override
6+
7+
from key_value.aio.protocols.key_value import AsyncKeyValue
8+
from key_value.aio.wrappers.base import BaseWrapper
9+
10+
11+
class DefaultValueWrapper(BaseWrapper):
    """Wrapper that substitutes a configured default whenever the wrapped store has no value.

    The behavior mirrors ``dict.get(key, default)``: reads that would yield ``None`` yield the
    default value instead.

    The default is never written to the wrapped store. When a default is returned from a TTL
    lookup, the TTL reported with it is the fixed ``default_ttl`` given at construction. Choose
    ``default_ttl`` with care if the value's TTL is consumed by callers or by other wrappers.
    """

    key_value: AsyncKeyValue  # Alias for BaseWrapper compatibility
    _default_ttl: float | None
    _default_value_json: str

    def __init__(
        self,
        key_value: AsyncKeyValue,
        default_value: Mapping[str, Any],
        default_ttl: SupportsFloat | None = None,
    ) -> None:
        """Initialize the DefaultValueWrapper.

        Args:
            key_value: The underlying key-value store to wrap.
            default_value: The value handed back when a key has no stored value.
            default_ttl: The TTL reported alongside a default value. Defaults to None.
        """
        self.key_value = key_value
        # The default is kept as JSON text; each returned default is re-loaded from this text.
        self._default_value_json = dump_to_json(obj=dict(default_value))
        self._default_ttl = float(default_ttl) if default_ttl is not None else None

    def _new_default_value(self) -> dict[str, Any]:
        """Load the default value from its stored JSON text."""
        return load_from_json(json_str=self._default_value_json)

    @override
    async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
        stored = await self.key_value.get(key=key, collection=collection)
        if stored is None:
            return self._new_default_value()
        return stored

    @override
    async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
        fetched = await self.key_value.get_many(keys=keys, collection=collection)
        filled: list[dict[str, Any] | None] = []
        for item in fetched:
            filled.append(self._new_default_value() if item is None else item)
        return filled

    @override
    async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
        stored, remaining = await self.key_value.ttl(key=key, collection=collection)
        if stored is not None:
            return (stored, remaining)
        return (self._new_default_value(), self._default_ttl)

    @override
    async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
        pairs = await self.key_value.ttl_many(keys=keys, collection=collection)
        filled: list[tuple[dict[str, Any] | None, float | None]] = []
        for stored, remaining in pairs:
            if stored is None:
                filled.append((self._new_default_value(), self._default_ttl))
            else:
                filled.append((stored, remaining))
        return filled

key-value/key-value-aio/tests/adapters/test_pydantic.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ class Order(BaseModel):
3838
FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc)
3939

4040
SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30)
41+
SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25)
4142
SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10, url=AnyHttpUrl(url="https://example.com"))
4243
SAMPLE_ORDER: Order = Order(created_at=datetime.now(), updated_at=datetime.now(), user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False)
4344

4445
TEST_COLLECTION: str = "test_collection"
4546
TEST_KEY: str = "test_key"
47+
TEST_KEY_2: str = "test_key_2"
4648

4749

4850
class TestPydanticAdapter:
@@ -77,8 +79,17 @@ async def test_simple_adapter(self, user_adapter: PydanticAdapter[User]):
7779

7880
assert await user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
7981

80-
cached_user = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
81-
assert cached_user is None
82+
assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None
83+
84+
async def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]):
    """A supplied default is returned only for keys that have no stored model."""
    # Nothing is stored yet, so the default comes back.
    missing = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER)
    assert missing == SAMPLE_USER

    # A stored model takes precedence over the default.
    await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2)
    present = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER)
    assert present == SAMPLE_USER_2

    # Batch lookup: the stored key yields its model, the unknown key yields the default.
    batch = await user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER)
    assert batch == snapshot([SAMPLE_USER_2, SAMPLE_USER])
8293

8394
async def test_simple_adapter_with_validation_error_ignore(
8495
self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]
@@ -98,12 +109,10 @@ async def test_simple_adapter_with_validation_error_raise(
98109

99110
async def test_complex_adapter(self, order_adapter: PydanticAdapter[Order]):
    """A nested model round-trips through the adapter and is gone after deletion."""
    await order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10)
    stored_order = await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
    assert stored_order == SAMPLE_ORDER

    was_deleted = await order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
    assert was_deleted
    after_delete = await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
    assert after_delete is None
107116

108117
async def test_complex_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore):
109118
await product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10)
@@ -127,5 +136,4 @@ async def test_complex_adapter_with_list(self, product_list_adapter: PydanticAda
127136
)
128137

129138
assert await product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
130-
cached_products = await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
131-
assert cached_products is None
139+
assert await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None

key-value/key-value-aio/tests/stores/memory/test_memory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@ class TestMemoryStore(BaseStoreTests):
1010
@pytest.fixture
async def store(self) -> MemoryStore:
    """Provide the MemoryStore under test, limited to 500 entries per collection."""
    memory_store = MemoryStore(max_entries_per_collection=500)
    return memory_store
13+
14+
async def test_seed(self):
    """Seed data passed to the constructor is readable through the store."""
    seeded_value = {"obj_key": "obj_value"}
    seed_data = {"test_collection": {"test_key": seeded_value}}
    store = MemoryStore(max_entries_per_collection=500, seed=seed_data)

    fetched = await store.get(key="test_key", collection="test_collection")
    assert fetched == {"obj_key": "obj_value"}

0 commit comments

Comments
 (0)