Skip to content

Commit ccabd02

Browse files
committed
Updates for PR Feedback
1 parent 1bb30e5 commit ccabd02

File tree

11 files changed

+306
-12
lines changed

11 files changed

+306
-12
lines changed

key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,13 @@ async def get(self, key: str, *, collection: str | None = None, default: T | Non
9595

9696
return default
9797

98-
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T | None]:
98+
@overload
99+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ...
100+
101+
@overload
102+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ...
103+
104+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]:
99105
"""Batch get and validate models by keys, preserving order.
100106
101107
Args:

key-value/key-value-aio/tests/adapters/test_pydantic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ class Order(BaseModel):
3838
FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc)
3939

4040
SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30)
41+
SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25)
4142
SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10, url=AnyHttpUrl(url="https://example.com"))
4243
SAMPLE_ORDER: Order = Order(created_at=datetime.now(), updated_at=datetime.now(), user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False)
4344

4445
TEST_COLLECTION: str = "test_collection"
4546
TEST_KEY: str = "test_key"
47+
TEST_KEY_2: str = "test_key_2"
4648

4749

4850
class TestPydanticAdapter:
@@ -80,6 +82,15 @@ async def test_simple_adapter(self, user_adapter: PydanticAdapter[User]):
8082
cached_user = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
8183
assert cached_user is None
8284

85+
async def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]):
86+
cached_user = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER)
87+
assert cached_user == SAMPLE_USER
88+
89+
await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2)
90+
91+
cached_users = await user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER)
92+
assert sorted(cached_users, key=lambda x: x.name) == [SAMPLE_USER_2, SAMPLE_USER]
93+
8394
async def test_simple_adapter_with_validation_error_ignore(
8495
self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]
8596
):

key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# from the original file 'adapter.py'
33
# DO NOT CHANGE! Change the original file instead.
44
from collections.abc import Sequence
5-
from typing import Any, Generic, SupportsFloat, TypeVar, get_origin
5+
from typing import Any, Generic, SupportsFloat, TypeVar, get_origin, overload
66

77
from key_value.shared.errors import DeserializationError, SerializationError
88
from key_value.shared.type_checking.bear_spray import bear_spray
@@ -65,11 +65,22 @@ def _serialize_model(self, value: T) -> dict[str, Any]:
6565
msg = f"Invalid Pydantic model: {e}"
6666
raise SerializationError(msg) from e
6767

68-
def get(self, key: str, *, collection: str | None = None) -> T | None:
68+
@overload
69+
def get(self, key: str, *, collection: str | None = None, default: T) -> T: ...
70+
71+
@overload
72+
def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ...
73+
74+
def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None:
6975
"""Get and validate a model by key.
7076
77+
Args:
78+
key: The key to retrieve.
79+
collection: The collection to use. If not provided, uses the default collection.
80+
default: The default value to return if the key doesn't exist or validation fails.
81+
7182
Returns:
72-
The parsed model instance, or None if not present.
83+
The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails.
7384
7485
Raises:
7586
DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to
@@ -78,15 +89,28 @@ def get(self, key: str, *, collection: str | None = None) -> T | None:
7889
collection = collection or self._default_collection
7990

8091
if value := self._key_value.get(key=key, collection=collection):
81-
return self._validate_model(value=value)
92+
validated = self._validate_model(value=value)
93+
if validated is not None:
94+
return validated
8295

83-
return None
96+
return default
8497

85-
def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[T | None]:
98+
@overload
99+
def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ...
100+
101+
@overload
102+
def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ...
103+
104+
def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]:
86105
"""Batch get and validate models by keys, preserving order.
87106
107+
Args:
108+
keys: The list of keys to retrieve.
109+
collection: The collection to use. If not provided, uses the default collection.
110+
default: The default value to return for keys that don't exist or fail validation.
111+
88112
Returns:
89-
A list of parsed model instances, or None if missing.
113+
A list of parsed model instances, with default values for missing keys or validation failures.
90114
91115
Raises:
92116
DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to
@@ -96,7 +120,14 @@ def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> lis
96120

97121
values: list[dict[str, Any] | None] = self._key_value.get_many(keys=keys, collection=collection)
98122

99-
return [self._validate_model(value=value) if value else None for value in values]
123+
result: list[T | None] = []
124+
for value in values:
125+
if value is None:
126+
result.append(default)
127+
else:
128+
validated = self._validate_model(value=value)
129+
result.append(validated if validated is not None else default)
130+
return result
100131

101132
def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
102133
"""Serialize and store a model.

key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# from the original file 'store.py'
33
# DO NOT CHANGE! Change the original file instead.
44
import sys
5+
from collections.abc import Mapping
56
from dataclasses import dataclass, field
67
from datetime import datetime
78
from typing import Any, SupportsFloat
@@ -103,26 +104,63 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol
103104
max_entries_per_collection: int
104105

105106
_cache: dict[str, MemoryCollection]
106-
107-
def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, default_collection: str | None = None):
107+
_seed: Mapping[str, Mapping[str, Mapping[str, Any]]]
108+
109+
def __init__(
110+
self,
111+
*,
112+
max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION,
113+
default_collection: str | None = None,
114+
seed: Mapping[str, Mapping[str, Mapping[str, Any]]] | None = None,
115+
):
108116
"""Initialize a fixed-size in-memory store.
109117
110118
Args:
111119
max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000.
112120
default_collection: The default collection to use if no collection is provided.
121+
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
122+
Each value must be a mapping (dict) that will be stored as the entry's value.
123+
Seeding occurs lazily when each collection is first accessed.
113124
"""
114125

115126
self.max_entries_per_collection = max_entries_per_collection
116127

117128
self._cache = {}
129+
self._seed = seed or {}
118130

119131
self._stable_api = True
120132

121133
super().__init__(default_collection=default_collection)
122134

135+
def _create_collection(self, collection: str) -> MemoryCollection:
136+
"""Create a new collection.
137+
138+
Args:
139+
collection: The collection name.
140+
141+
Returns:
142+
The created MemoryCollection instance.
143+
"""
144+
collection_cache = MemoryCollection(max_entries=self.max_entries_per_collection)
145+
self._cache[collection] = collection_cache
146+
return collection_cache
147+
123148
@override
124149
def _setup_collection(self, *, collection: str) -> None:
125-
self._cache[collection] = MemoryCollection(max_entries=self.max_entries_per_collection)
150+
"""Set up a collection, creating it and seeding it if seed data is available.
151+
152+
Args:
153+
collection: The collection name.
154+
"""
155+
# Create the collection
156+
collection_cache = self._create_collection(collection)
157+
158+
# Seed the collection if seed data is available for it
159+
if collection in self._seed:
160+
items = self._seed[collection]
161+
for key, value in items.items():
162+
managed_entry = ManagedEntry(value=value)
163+
collection_cache.put(key=key, value=managed_entry)
126164

127165
@override
128166
def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# WARNING: this file is auto-generated by 'build_sync_library.py'
2+
# from the original file '__init__.py'
3+
# DO NOT CHANGE! Change the original file instead.
4+
"""Default value wrapper for returning fallback values when keys are not found."""
5+
6+
from key_value.sync.code_gen.wrappers.default_value.wrapper import DefaultValueWrapper
7+
8+
__all__ = ["DefaultValueWrapper"]
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# WARNING: this file is auto-generated by 'build_sync_library.py'
2+
# from the original file 'wrapper.py'
3+
# DO NOT CHANGE! Change the original file instead.
4+
from collections.abc import Mapping, Sequence
5+
from typing import Any
6+
7+
from typing_extensions import override
8+
9+
from key_value.sync.code_gen.protocols.key_value import KeyValue
10+
from key_value.sync.code_gen.wrappers.base import BaseWrapper
11+
12+
13+
class DefaultValueWrapper(BaseWrapper):
14+
"""A wrapper that returns a default value when a key is not found.
15+
16+
This wrapper provides dict.get(key, default) behavior for the key-value store,
17+
allowing you to specify a default value to return instead of None when a key doesn't exist.
18+
"""
19+
20+
_key_value: KeyValue
21+
key_value: KeyValue # Alias for BaseWrapper compatibility
22+
23+
def __init__(self, key_value: KeyValue, default_value: Mapping[str, Any], default_ttl: float | None = None) -> None:
24+
"""Initialize the DefaultValueWrapper.
25+
26+
Args:
27+
key_value: The underlying key-value store to wrap.
28+
default_value: The default value to return when a key is not found.
29+
default_ttl: The TTL to return for default values. Defaults to None.
30+
"""
31+
self._key_value = key_value
32+
self.key_value = key_value # Alias for BaseWrapper compatibility
33+
self._default_value = default_value
34+
self._default_ttl = default_ttl
35+
36+
@override
37+
def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
38+
"""Get a value by key, returning the default value if not found.
39+
40+
Args:
41+
key: The key to retrieve.
42+
collection: The collection to use.
43+
44+
Returns:
45+
The value associated with the key, or the default value if not found.
46+
"""
47+
result = self._key_value.get(key=key, collection=collection)
48+
return result if result is not None else dict(self._default_value)
49+
50+
@override
51+
def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
52+
"""Get multiple values by keys, returning the default value for missing keys.
53+
54+
Args:
55+
keys: The keys to retrieve.
56+
collection: The collection to use.
57+
58+
Returns:
59+
A list of values, with default values for missing keys.
60+
"""
61+
results = self._key_value.get_many(keys=keys, collection=collection)
62+
return [result if result is not None else dict(self._default_value) for result in results]
63+
64+
@override
65+
def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
66+
"""Get a value and its TTL, returning the default value if not found.
67+
68+
Args:
69+
key: The key to retrieve.
70+
collection: The collection to use.
71+
72+
Returns:
73+
A tuple of (value, ttl), with default value and default TTL if not found.
74+
"""
75+
(result, ttl_value) = self._key_value.ttl(key=key, collection=collection)
76+
if result is None:
77+
return (dict(self._default_value), self._default_ttl)
78+
return (result, ttl_value)
79+
80+
@override
81+
def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
82+
"""Get multiple values with TTLs, returning the default value for missing keys.
83+
84+
Args:
85+
keys: The keys to retrieve.
86+
collection: The collection to use.
87+
88+
Returns:
89+
A list of (value, ttl) tuples, with default values and default TTL for missing keys.
90+
"""
91+
results = self._key_value.ttl_many(keys=keys, collection=collection)
92+
return [
93+
(result, ttl_value) if result is not None else (dict(self._default_value), self._default_ttl) for (result, ttl_value) in results
94+
]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# WARNING: this file is auto-generated by 'build_sync_library.py'
2+
# from the original file '__init__.py'
3+
# DO NOT CHANGE! Change the original file instead.
4+
"""Pydantic JSON wrapper for ensuring JSON-serializable value storage."""
5+
6+
from key_value.sync.code_gen.wrappers.pydantic_json.wrapper import PydanticJsonWrapper
7+
8+
__all__ = ["PydanticJsonWrapper"]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# WARNING: this file is auto-generated by 'build_sync_library.py'
2+
# from the original file 'wrapper.py'
3+
# DO NOT CHANGE! Change the original file instead.
4+
from collections.abc import Mapping, Sequence
5+
from typing import Any, SupportsFloat
6+
7+
from pydantic import TypeAdapter
8+
from typing_extensions import override
9+
10+
from key_value.sync.code_gen.protocols.key_value import KeyValue
11+
from key_value.sync.code_gen.wrappers.base import BaseWrapper
12+
13+
14+
class PydanticJsonWrapper(BaseWrapper):
15+
"""A wrapper that ensures all values are JSON-serializable using Pydantic's TypeAdapter.
16+
17+
This wrapper automatically converts values to JSON-safe formats before storage,
18+
ensuring compatibility with stores that require JSON-serializable data.
19+
"""
20+
21+
_key_value: KeyValue
22+
key_value: KeyValue # Alias for BaseWrapper compatibility
23+
24+
def __init__(self, key_value: KeyValue) -> None:
25+
"""Initialize the PydanticJsonWrapper.
26+
27+
Args:
28+
key_value: The underlying key-value store to wrap.
29+
"""
30+
self._key_value = key_value
31+
self.key_value = key_value # Alias for BaseWrapper compatibility
32+
self._adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
33+
34+
def _to_json_safe(self, value: Mapping[str, Any]) -> dict[str, Any]:
35+
"""Convert a value to a JSON-safe format using Pydantic.
36+
37+
Args:
38+
value: The value to convert.
39+
40+
Returns:
41+
A JSON-safe dictionary.
42+
"""
43+
return self._adapter.dump_python(value, mode="json") # type: ignore[return-value]
44+
45+
@override
46+
def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
47+
"""Store a value after converting it to JSON-safe format.
48+
49+
Args:
50+
key: The key to store.
51+
value: The value to store (will be converted to JSON-safe format).
52+
collection: The collection to use.
53+
ttl: The time-to-live in seconds.
54+
"""
55+
json_safe_value = self._to_json_safe(value)
56+
self._key_value.put(key=key, value=json_safe_value, collection=collection, ttl=ttl)
57+
58+
@override
59+
def put_many(
60+
self, keys: Sequence[str], values: Sequence[Mapping[str, Any]], *, collection: str | None = None, ttl: SupportsFloat | None = None
61+
) -> None:
62+
"""Store multiple values after converting them to JSON-safe format.
63+
64+
Args:
65+
keys: The keys to store.
66+
values: The values to store (will be converted to JSON-safe format).
67+
collection: The collection to use.
68+
ttl: The time-to-live in seconds for all items.
69+
"""
70+
json_safe_values = [self._to_json_safe(value) for value in values]
71+
self._key_value.put_many(keys=keys, values=json_safe_values, collection=collection, ttl=ttl)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# WARNING: this file is auto-generated by 'build_sync_library.py'
2+
# from the original file '__init__.py'
3+
# DO NOT CHANGE! Change the original file instead.
4+
"""Default value wrapper for returning fallback values when keys are not found."""
5+
6+
from key_value.sync.code_gen.wrappers.default_value.wrapper import DefaultValueWrapper
7+
8+
__all__ = ["DefaultValueWrapper"]

0 commit comments

Comments
 (0)