Skip to content

Commit d720829

Browse files
committed
Updates for PR Feedback
1 parent 2af53c3 commit d720829

File tree

15 files changed

+305
-251
lines changed

15 files changed

+305
-251
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
AsyncKeyValueProtocol,
2525
)
2626

27+
SEED_DATA_TYPE = Mapping[str, Mapping[str, Mapping[str, Any]]]
28+
2729

2830
class BaseStore(AsyncKeyValueProtocol, ABC):
2931
"""An opinionated Abstract base class for managed key-value stores using ManagedEntry objects.
@@ -43,21 +45,29 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
4345
_setup_collection_locks: defaultdict[str, Lock]
4446
_setup_collection_complete: defaultdict[str, bool]
4547

48+
_seed: SEED_DATA_TYPE
49+
4650
default_collection: str
4751

48-
def __init__(self, *, default_collection: str | None = None) -> None:
52+
def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None:
4953
"""Initialize the managed key-value store.
5054
5155
Args:
5256
default_collection: The default collection to use if no collection is provided.
5357
Defaults to "default_collection".
58+
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
59+
Each value must be a mapping (dict) that will be stored as the entry's value.
60+
Seeding occurs during every store setup (when the store is entered or when the first operation
61+
is performed on the store).
5462
"""
5563

5664
self._setup_complete = False
5765
self._setup_lock = Lock()
5866
self._setup_collection_locks = defaultdict(Lock)
5967
self._setup_collection_complete = defaultdict(bool)
6068

69+
self._seed = seed or {}
70+
6171
self.default_collection = default_collection or DEFAULT_COLLECTION_NAME
6272

6373
if not hasattr(self, "_stable_api"):
@@ -74,6 +84,13 @@ async def _setup(self) -> None:
7484
async def _setup_collection(self, *, collection: str) -> None:
7585
"""Initialize the collection (called once before first use of the collection)."""
7686

87+
async def _seed_store(self) -> None:
88+
"""Seed the store with the data from the seed."""
89+
for collection, items in self._seed.items():
90+
await self.setup_collection(collection=collection)
91+
for key, value in items.items():
92+
await self.put(key=key, value=value, collection=collection)
93+
7794
async def setup(self) -> None:
7895
if not self._setup_complete:
7996
async with self._setup_lock:
@@ -86,6 +103,8 @@ async def setup(self) -> None:
86103
) from e
87104
self._setup_complete = True
88105

106+
await self._seed_store()
107+
89108
async def setup_collection(self, *, collection: str) -> None:
90109
await self.setup()
91110

key-value/key-value-aio/src/key_value/aio/stores/memory/store.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import sys
2-
from collections.abc import Mapping
32
from dataclasses import dataclass, field
43
from datetime import datetime
54
from typing import Any, SupportsFloat
@@ -9,6 +8,7 @@
98
from typing_extensions import Self, override
109

1110
from key_value.aio.stores.base import (
11+
SEED_DATA_TYPE,
1212
BaseDestroyCollectionStore,
1313
BaseDestroyStore,
1414
BaseEnumerateCollectionsStore,
@@ -109,14 +109,13 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol
109109
max_entries_per_collection: int
110110

111111
_cache: dict[str, MemoryCollection]
112-
_seed: Mapping[str, Mapping[str, Mapping[str, Any]]]
113112

114113
def __init__(
115114
self,
116115
*,
117116
max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION,
118117
default_collection: str | None = None,
119-
seed: Mapping[str, Mapping[str, Mapping[str, Any]]] | None = None,
118+
seed: SEED_DATA_TYPE | None = None,
120119
):
121120
"""Initialize a fixed-size in-memory store.
122121
@@ -131,24 +130,15 @@ def __init__(
131130
self.max_entries_per_collection = max_entries_per_collection
132131

133132
self._cache = {}
134-
self._seed = seed or {}
135133

136134
self._stable_api = True
137135

138-
super().__init__(default_collection=default_collection)
136+
super().__init__(default_collection=default_collection, seed=seed)
139137

140-
def _create_collection(self, collection: str) -> MemoryCollection:
141-
"""Create a new collection.
142-
143-
Args:
144-
collection: The collection name.
145-
146-
Returns:
147-
The created MemoryCollection instance.
148-
"""
149-
collection_cache = MemoryCollection(max_entries=self.max_entries_per_collection)
150-
self._cache[collection] = collection_cache
151-
return collection_cache
138+
@override
139+
async def _setup(self) -> None:
140+
for collection in self._seed:
141+
await self._setup_collection(collection=collection)
152142

153143
@override
154144
async def _setup_collection(self, *, collection: str) -> None:
@@ -157,15 +147,11 @@ async def _setup_collection(self, *, collection: str) -> None:
157147
Args:
158148
collection: The collection name.
159149
"""
160-
# Create the collection
161-
collection_cache = self._create_collection(collection)
162-
163-
# Seed the collection if seed data is available for it
164-
if collection in self._seed:
165-
items = self._seed[collection]
166-
for key, value in items.items():
167-
managed_entry = ManagedEntry(value=value)
168-
collection_cache.put(key=key, value=managed_entry)
150+
if collection in self._cache:
151+
return
152+
153+
collection_cache = MemoryCollection(max_entries=self.max_entries_per_collection)
154+
self._cache[collection] = collection_cache
169155

170156
@override
171157
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:

key-value/key-value-aio/src/key_value/aio/wrappers/default_value/wrapper.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Mapping, Sequence
2-
from typing import Any
2+
from typing import Any, SupportsFloat
33

4+
from key_value.shared.utils.managed_entry import dump_to_json, load_from_json
45
from typing_extensions import override
56

67
from key_value.aio.protocols.key_value import AsyncKeyValue
@@ -12,28 +13,35 @@ class DefaultValueWrapper(BaseWrapper):
1213
1314
This wrapper provides dict.get(key, default) behavior for the key-value store,
1415
allowing you to specify a default value to return instead of None when a key doesn't exist.
16+
17+
It does not store the default value in the underlying key-value store and the TTL returned with the default
18+
value is hard-coded based on the default_ttl parameter. Picking a default_ttl requires careful consideration
19+
of how the value will be used and if any other wrappers will be used that may rely on the TTL.
1520
"""
1621

17-
_key_value: AsyncKeyValue
1822
key_value: AsyncKeyValue # Alias for BaseWrapper compatibility
23+
_default_ttl: float | None
24+
_default_value_json: str
1925

2026
def __init__(
2127
self,
2228
key_value: AsyncKeyValue,
2329
default_value: Mapping[str, Any],
24-
default_ttl: float | None = None,
30+
default_ttl: SupportsFloat | None = None,
2531
) -> None:
2632
"""Initialize the DefaultValueWrapper.
2733
2834
Args:
2935
key_value: The underlying key-value store to wrap.
3036
default_value: The default value to return when a key is not found.
31-
default_ttl: The TTL to return for default values. Defaults to None.
37+
default_ttl: The TTL to return to the caller for default values. Defaults to None.
3238
"""
33-
self._key_value = key_value
34-
self.key_value = key_value # Alias for BaseWrapper compatibility
35-
self._default_value = default_value
36-
self._default_ttl = default_ttl
39+
self.key_value = key_value
40+
self._default_value_json = dump_to_json(obj=dict(default_value))
41+
self._default_ttl = None if default_ttl is None else float(default_ttl)
42+
43+
def _new_default_value(self) -> dict[str, Any]:
44+
return load_from_json(json_str=self._default_value_json)
3745

3846
@override
3947
async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
@@ -46,8 +54,8 @@ async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any
4654
Returns:
4755
The value associated with the key, or the default value if not found.
4856
"""
49-
result = await self._key_value.get(key=key, collection=collection)
50-
return result if result is not None else dict(self._default_value)
57+
result = await self.key_value.get(key=key, collection=collection)
58+
return result if result is not None else self._new_default_value()
5159

5260
@override
5361
async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
@@ -60,8 +68,8 @@ async def get_many(self, keys: Sequence[str], *, collection: str | None = None)
6068
Returns:
6169
A list of values, with default values for missing keys.
6270
"""
63-
results = await self._key_value.get_many(keys=keys, collection=collection)
64-
return [result if result is not None else dict(self._default_value) for result in results]
71+
results = await self.key_value.get_many(keys=keys, collection=collection)
72+
return [result if result is not None else self._new_default_value() for result in results]
6573

6674
@override
6775
async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
@@ -74,9 +82,9 @@ async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[st
7482
Returns:
7583
A tuple of (value, ttl), with default value and default TTL if not found.
7684
"""
77-
result, ttl_value = await self._key_value.ttl(key=key, collection=collection)
85+
result, ttl_value = await self.key_value.ttl(key=key, collection=collection)
7886
if result is None:
79-
return (dict(self._default_value), self._default_ttl)
87+
return (self._new_default_value(), self._default_ttl)
8088
return (result, ttl_value)
8189

8290
@override
@@ -90,7 +98,7 @@ async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None)
9098
Returns:
9199
A list of (value, ttl) tuples, with default values and default TTL for missing keys.
92100
"""
93-
results = await self._key_value.ttl_many(keys=keys, collection=collection)
101+
results = await self.key_value.ttl_many(keys=keys, collection=collection)
94102
return [
95-
(result, ttl_value) if result is not None else (dict(self._default_value), self._default_ttl) for result, ttl_value in results
103+
(result, ttl_value) if result is not None else (self._new_default_value(), self._default_ttl) for result, ttl_value in results
96104
]

key-value/key-value-aio/src/key_value/aio/wrappers/pydantic_json/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

key-value/key-value-aio/src/key_value/aio/wrappers/pydantic_json/wrapper.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

key-value/key-value-aio/tests/stores/memory/test_memory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@ class TestMemoryStore(BaseStoreTests):
1010
@pytest.fixture
1111
async def store(self) -> MemoryStore:
1212
return MemoryStore(max_entries_per_collection=500)
13+
14+
async def test_seed(self):
15+
store = MemoryStore(max_entries_per_collection=500, seed={"test_collection": {"test_key": {"obj_key": "obj_value"}}})
16+
assert await store.get(key="test_key", collection="test_collection") == {"obj_key": "obj_value"}

0 commit comments

Comments
 (0)