Skip to content

Commit 0481973

Browse files
Address CodeRabbit review feedback
- Replace truthiness checks with explicit None checks in BasePydanticAdapter to correctly handle empty dicts/lists - Validate missing 'items' wrapper for list models instead of silently accepting malformed payloads - Restrict list model detection to only 'list' type (not all Sequence types) in DataclassAdapter - Update @bear_spray comment to accurately describe why it's needed - Include type in error message for missing type argument - Add edge-case tests for empty list TTL and missing 'items' wrapper scenarios - All 28 adapter tests passing, linting and type checking clean Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent d591c4e commit 0481973

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed

key-value/key-value-aio/src/key_value/aio/adapters/base/adapter.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ def _validate_model(self, value: dict[str, Any]) -> T | None:
5353
"""
5454
try:
5555
if self._is_list_model:
56-
return self._type_adapter.validate_python(value.get("items", []))
56+
if "items" not in value:
57+
if self._raise_on_validation_error:
58+
msg = f"Invalid {self._get_model_type_name()} payload: missing 'items' wrapper"
59+
raise DeserializationError(msg)
60+
return None
61+
return self._type_adapter.validate_python(value["items"])
5762

5863
return self._type_adapter.validate_python(value)
5964
except ValidationError as e:
@@ -115,7 +120,8 @@ async def get(self, key: str, *, collection: str | None = None, default: T | Non
115120
"""
116121
collection = collection or self._default_collection
117122

118-
if value := await self._key_value.get(key=key, collection=collection):
123+
value = await self._key_value.get(key=key, collection=collection)
124+
if value is not None:
119125
validated = self._validate_model(value=value)
120126
if validated is not None:
121127
return validated
@@ -218,7 +224,8 @@ async def ttl(self, key: str, *, collection: str | None = None) -> tuple[T | Non
218224
if entry is None:
219225
return (None, None)
220226

221-
if validated_model := self._validate_model(value=entry):
227+
validated_model = self._validate_model(value=entry)
228+
if validated_model is not None:
222229
return (validated_model, ttl_info)
223230

224231
return (None, None)
@@ -229,4 +236,4 @@ async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None)
229236

230237
entries: list[tuple[dict[str, Any] | None, float | None]] = await self._key_value.ttl_many(keys=keys, collection=collection)
231238

232-
return [(self._validate_model(value=entry) if entry else None, ttl_info) for entry, ttl_info in entries]
239+
return [(self._validate_model(value=entry) if entry is not None else None, ttl_info) for entry, ttl_info in entries]

key-value/key-value-aio/src/key_value/aio/adapters/dataclass/adapter.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from collections.abc import Sequence
21
from dataclasses import is_dataclass
32
from typing import Any, TypeVar, get_args, get_origin
43

@@ -20,8 +19,8 @@ class DataclassAdapter(BasePydanticAdapter[T]):
2019

2120
_inner_type: type[Any]
2221

23-
# Beartype doesn't like our `type[T]` includes a bound on Sequence[...] as the subscript is not checkable at runtime
24-
# For just the next 20 or so lines we are no longer bear bros but have no fear, we will be back soon!
22+
# Beartype cannot handle the parameterized type annotation (type[T]) used here for this generic dataclass adapter.
23+
# Using @bear_spray to bypass beartype's runtime checks for this specific method.
2524
@bear_spray
2625
def __init__(
2726
self,
@@ -44,13 +43,13 @@ def __init__(
4443
self._key_value = key_value
4544

4645
origin = get_origin(dataclass_type)
47-
self._is_list_model = origin is not None and isinstance(origin, type) and issubclass(origin, Sequence)
46+
self._is_list_model = origin is list
4847

4948
# Extract the inner type for list models
5049
if self._is_list_model:
5150
args = get_args(dataclass_type)
5251
if not args:
53-
msg = "List type must have a type argument"
52+
msg = f"List type {dataclass_type} must have a type argument"
5453
raise TypeError(msg)
5554
self._inner_type = args[0]
5655
else:

key-value/key-value-aio/tests/adapters/test_dataclass.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import dataclass
22
from datetime import datetime, timezone
3+
from typing import Any
34

45
import pytest
56
from inline_snapshot import snapshot
@@ -225,3 +226,28 @@ async def test_default_collection(self, store: MemoryStore):
225226
assert cached_user == SAMPLE_USER
226227

227228
assert await adapter.delete(key=TEST_KEY)
229+
230+
async def test_ttl_with_empty_list(self, product_list_adapter: DataclassAdapter[list[Product]]):
231+
"""Test that TTL with empty list returns correctly (not None)."""
232+
await product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[], ttl=10)
233+
value, ttl = await product_list_adapter.ttl(collection=TEST_COLLECTION, key=TEST_KEY)
234+
assert value == []
235+
assert ttl is not None
236+
assert ttl > 0
237+
238+
async def test_list_payload_missing_items_returns_none(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore):
239+
"""Test that list payload without 'items' wrapper returns None when raise_on_validation_error is False."""
240+
# Manually insert malformed payload without the 'items' wrapper
241+
# The payload is a dict but without the expected 'items' key for list models
242+
malformed_payload: dict[str, Any] = {"wrong": []}
243+
await store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload)
244+
assert await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None
245+
246+
async def test_list_payload_missing_items_raises(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore):
247+
"""Test that list payload without 'items' wrapper raises DeserializationError when configured."""
248+
product_list_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage]
249+
# Manually insert malformed payload without the 'items' wrapper
250+
malformed_payload: dict[str, Any] = {"wrong": []}
251+
await store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload)
252+
with pytest.raises(DeserializationError, match="missing 'items'"):
253+
await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)

0 commit comments

Comments
 (0)