Skip to content

Commit d591c4e

Browse files
refactor: create BasePydanticAdapter to eliminate code duplication
Created a new BasePydanticAdapter abstract base class that contains all shared functionality between PydanticAdapter and DataclassAdapter. This eliminates ~200 lines of duplicated code and addresses the SonarQube code duplication warning (9.5% -> expected to be <3%). Key changes: - Added BasePydanticAdapter in adapters/base/ with all shared methods - Refactored PydanticAdapter to inherit from BasePydanticAdapter - Refactored DataclassAdapter to inherit from BasePydanticAdapter - Both adapters now only contain their __init__ and _get_model_type_name() Benefits: - DRY principle: eliminates code duplication - Maintainability: bug fixes only need to be made once - Consistency: ensures both adapters behave identically - Future-proofing: new adapters can reuse the base class All tests pass (17 tests), type checking passes with 0 errors. Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent c317c01 commit d591c4e

File tree

4 files changed

+248
-395
lines changed

4 files changed

+248
-395
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from key_value.aio.adapters.base.adapter import BasePydanticAdapter
2+
3+
__all__ = ["BasePydanticAdapter"]
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
from abc import ABC, abstractmethod
2+
from collections.abc import Sequence
3+
from typing import Any, Generic, SupportsFloat, TypeVar, overload
4+
5+
from key_value.shared.errors import DeserializationError, SerializationError
6+
from pydantic import ValidationError
7+
from pydantic.type_adapter import TypeAdapter
8+
from pydantic_core import PydanticSerializationError
9+
10+
from key_value.aio.protocols.key_value import AsyncKeyValue
11+
12+
T = TypeVar("T")
13+
14+
15+
class BasePydanticAdapter(Generic[T], ABC):
16+
"""Base adapter using Pydantic TypeAdapter for validation and serialization.
17+
18+
This abstract base class provides shared functionality for adapters that use
19+
Pydantic's TypeAdapter for validation and serialization. Concrete subclasses
20+
must implement _get_model_type_name() to provide appropriate error messages.
21+
"""
22+
23+
_key_value: AsyncKeyValue
24+
_is_list_model: bool
25+
_type_adapter: TypeAdapter[T]
26+
_default_collection: str | None
27+
_raise_on_validation_error: bool
28+
29+
@abstractmethod
30+
def _get_model_type_name(self) -> str:
31+
"""Return the model type name for error messages.
32+
33+
Returns:
34+
A string describing the model type (e.g., "Pydantic model", "dataclass").
35+
"""
36+
...
37+
38+
def _validate_model(self, value: dict[str, Any]) -> T | None:
39+
"""Validate and deserialize a dict into the configured model type.
40+
41+
This method handles both single models and list models. For list models, it expects the value
42+
to contain an "items" key with the list data, following the convention used by `_serialize_model`.
43+
If validation fails and `raise_on_validation_error` is False, returns None instead of raising.
44+
45+
Args:
46+
value: The dict to validate and convert to a model.
47+
48+
Returns:
49+
The validated model instance, or None if validation fails and errors are suppressed.
50+
51+
Raises:
52+
DeserializationError: If validation fails and `raise_on_validation_error` is True.
53+
"""
54+
try:
55+
if self._is_list_model:
56+
return self._type_adapter.validate_python(value.get("items", []))
57+
58+
return self._type_adapter.validate_python(value)
59+
except ValidationError as e:
60+
if self._raise_on_validation_error:
61+
msg = f"Invalid {self._get_model_type_name()}: {value}"
62+
raise DeserializationError(msg) from e
63+
return None
64+
65+
def _serialize_model(self, value: T) -> dict[str, Any]:
66+
"""Serialize a model to a dict for storage.
67+
68+
This method handles both single models and list models. For list models, it wraps the serialized
69+
list in a dict with an "items" key (e.g., {"items": [...]}) to ensure consistent dict-based storage
70+
format across all value types. This wrapping convention is expected by `_validate_model` during
71+
deserialization.
72+
73+
Args:
74+
value: The model instance to serialize.
75+
76+
Returns:
77+
A dict representation of the model suitable for storage.
78+
79+
Raises:
80+
SerializationError: If the model cannot be serialized.
81+
"""
82+
try:
83+
if self._is_list_model:
84+
return {"items": self._type_adapter.dump_python(value, mode="json")}
85+
86+
return self._type_adapter.dump_python(value, mode="json") # pyright: ignore[reportAny]
87+
except PydanticSerializationError as e:
88+
msg = f"Invalid {self._get_model_type_name()}: {e}"
89+
raise SerializationError(msg) from e
90+
91+
@overload
92+
async def get(self, key: str, *, collection: str | None = None, default: T) -> T: ...
93+
94+
@overload
95+
async def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ...
96+
97+
async def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None:
98+
"""Get and validate a model by key.
99+
100+
Args:
101+
key: The key to retrieve.
102+
collection: The collection to use. If not provided, uses the default collection.
103+
default: The default value to return if the key doesn't exist or validation fails.
104+
105+
Returns:
106+
The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails.
107+
108+
Raises:
109+
DeserializationError: If the stored data cannot be validated as the model and the adapter is configured to
110+
raise on validation error.
111+
112+
Note:
113+
When raise_on_validation_error=False and validation fails, returns the default value (which may be None).
114+
When raise_on_validation_error=True and validation fails, raises DeserializationError.
115+
"""
116+
collection = collection or self._default_collection
117+
118+
if value := await self._key_value.get(key=key, collection=collection):
119+
validated = self._validate_model(value=value)
120+
if validated is not None:
121+
return validated
122+
123+
return default
124+
125+
@overload
126+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ...
127+
128+
@overload
129+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ...
130+
131+
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]:
132+
"""Batch get and validate models by keys, preserving order.
133+
134+
Args:
135+
keys: The list of keys to retrieve.
136+
collection: The collection to use. If not provided, uses the default collection.
137+
default: The default value to return for keys that don't exist or fail validation.
138+
139+
Returns:
140+
A list of parsed model instances, with default values for missing keys or validation failures.
141+
142+
Raises:
143+
DeserializationError: If the stored data cannot be validated as the model and the adapter is configured to
144+
raise on validation error.
145+
146+
Note:
147+
When raise_on_validation_error=False and validation fails for any key, that position in the returned list
148+
will contain the default value (which may be None). The method returns a complete list matching the order
149+
and length of the input keys, with defaults substituted for missing or invalid entries.
150+
"""
151+
collection = collection or self._default_collection
152+
153+
values: list[dict[str, Any] | None] = await self._key_value.get_many(keys=keys, collection=collection)
154+
155+
result: list[T | None] = []
156+
for value in values:
157+
if value is None:
158+
result.append(default)
159+
else:
160+
validated = self._validate_model(value=value)
161+
result.append(validated if validated is not None else default)
162+
return result
163+
164+
async def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
165+
"""Serialize and store a model.
166+
167+
Propagates SerializationError if the model cannot be serialized.
168+
"""
169+
collection = collection or self._default_collection
170+
171+
value_dict: dict[str, Any] = self._serialize_model(value=value)
172+
173+
await self._key_value.put(key=key, value=value_dict, collection=collection, ttl=ttl)
174+
175+
async def put_many(
176+
self, keys: Sequence[str], values: Sequence[T], *, collection: str | None = None, ttl: SupportsFloat | None = None
177+
) -> None:
178+
"""Serialize and store multiple models, preserving order alignment with keys."""
179+
collection = collection or self._default_collection
180+
181+
value_dicts: list[dict[str, Any]] = [self._serialize_model(value=value) for value in values]
182+
183+
await self._key_value.put_many(keys=keys, values=value_dicts, collection=collection, ttl=ttl)
184+
185+
async def delete(self, key: str, *, collection: str | None = None) -> bool:
186+
"""Delete a model by key. Returns True if a value was deleted, else False."""
187+
collection = collection or self._default_collection
188+
189+
return await self._key_value.delete(key=key, collection=collection)
190+
191+
async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
192+
"""Delete multiple models by key. Returns the count of deleted entries."""
193+
collection = collection or self._default_collection
194+
195+
return await self._key_value.delete_many(keys=keys, collection=collection)
196+
197+
async def ttl(self, key: str, *, collection: str | None = None) -> tuple[T | None, float | None]:
198+
"""Get a model and its TTL seconds if present.
199+
200+
Args:
201+
key: The key to retrieve.
202+
collection: The collection to use. If not provided, uses the default collection.
203+
204+
Returns:
205+
A tuple of (model, ttl_seconds). Returns (None, None) if the key is missing or validation fails.
206+
207+
Note:
208+
When validation fails and raise_on_validation_error=False, returns (None, None) even if TTL data exists.
209+
When validation fails and raise_on_validation_error=True, raises DeserializationError.
210+
"""
211+
collection = collection or self._default_collection
212+
213+
entry: dict[str, Any] | None
214+
ttl_info: float | None
215+
216+
entry, ttl_info = await self._key_value.ttl(key=key, collection=collection)
217+
218+
if entry is None:
219+
return (None, None)
220+
221+
if validated_model := self._validate_model(value=entry):
222+
return (validated_model, ttl_info)
223+
224+
return (None, None)
225+
226+
async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[T | None, float | None]]:
227+
"""Batch get models with TTLs. Each element is (model|None, ttl_seconds|None)."""
228+
collection = collection or self._default_collection
229+
230+
entries: list[tuple[dict[str, Any] | None, float | None]] = await self._key_value.ttl_many(keys=keys, collection=collection)
231+
232+
return [(self._validate_model(value=entry) if entry else None, ttl_info) for entry, ttl_info in entries]

0 commit comments

Comments
 (0)