11from collections .abc import Mapping , Sequence
2- from typing import Any
2+ from typing import Any , SupportsFloat
33
4+ from key_value .shared .utils .managed_entry import dump_to_json , load_from_json
45from typing_extensions import override
56
67from key_value .aio .protocols .key_value import AsyncKeyValue
@@ -12,28 +13,35 @@ class DefaultValueWrapper(BaseWrapper):
1213
1314 This wrapper provides dict.get(key, default) behavior for the key-value store,
1415 allowing you to specify a default value to return instead of None when a key doesn't exist.
16+
17+ It does not store the default value in the underlying key-value store and the TTL returned with the default
18+ value is hard-coded based on the default_ttl parameter. Picking a default_ttl requires careful consideration
19+ of how the value will be used and if any other wrappers will be used that may rely on the TTL.
1520 """
1621
17- _key_value : AsyncKeyValue
1822 key_value : AsyncKeyValue # Alias for BaseWrapper compatibility
23+ _default_ttl : float | None
24+ _default_value_json : str
1925
2026 def __init__ (
2127 self ,
2228 key_value : AsyncKeyValue ,
2329 default_value : Mapping [str , Any ],
24- default_ttl : float | None = None ,
30+ default_ttl : SupportsFloat | None = None ,
2531 ) -> None :
2632 """Initialize the DefaultValueWrapper.
2733
2834 Args:
2935 key_value: The underlying key-value store to wrap.
3036 default_value: The default value to return when a key is not found.
31- default_ttl: The TTL to return for default values. Defaults to None.
37+ default_ttl: The TTL to return to the caller for default values. Defaults to None.
3238 """
33- self ._key_value = key_value
34- self .key_value = key_value # Alias for BaseWrapper compatibility
35- self ._default_value = default_value
36- self ._default_ttl = default_ttl
39+ self .key_value = key_value
40+ self ._default_value_json = dump_to_json (obj = dict (default_value ))
41+ self ._default_ttl = None if default_ttl is None else float (default_ttl )
42+
43+ def _new_default_value (self ) -> dict [str , Any ]:
44+ return load_from_json (json_str = self ._default_value_json )
3745
3846 @override
3947 async def get (self , key : str , * , collection : str | None = None ) -> dict [str , Any ] | None :
@@ -46,8 +54,8 @@ async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any
4654 Returns:
4755 The value associated with the key, or the default value if not found.
4856 """
49- result = await self ._key_value .get (key = key , collection = collection )
50- return result if result is not None else dict ( self ._default_value )
57+ result = await self .key_value .get (key = key , collection = collection )
58+ return result if result is not None else self ._new_default_value ( )
5159
5260 @override
5361 async def get_many (self , keys : Sequence [str ], * , collection : str | None = None ) -> list [dict [str , Any ] | None ]:
@@ -60,8 +68,8 @@ async def get_many(self, keys: Sequence[str], *, collection: str | None = None)
6068 Returns:
6169 A list of values, with default values for missing keys.
6270 """
63- results = await self ._key_value .get_many (keys = keys , collection = collection )
64- return [result if result is not None else dict ( self ._default_value ) for result in results ]
71+ results = await self .key_value .get_many (keys = keys , collection = collection )
72+ return [result if result is not None else self ._new_default_value ( ) for result in results ]
6573
6674 @override
6775 async def ttl (self , key : str , * , collection : str | None = None ) -> tuple [dict [str , Any ] | None , float | None ]:
@@ -74,9 +82,9 @@ async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[st
7482 Returns:
7583 A tuple of (value, ttl), with default value and default TTL if not found.
7684 """
77- result , ttl_value = await self ._key_value .ttl (key = key , collection = collection )
85+ result , ttl_value = await self .key_value .ttl (key = key , collection = collection )
7886 if result is None :
79- return (dict ( self ._default_value ), self ._default_ttl )
87+ return (self ._new_default_value ( ), self ._default_ttl )
8088 return (result , ttl_value )
8189
8290 @override
@@ -90,7 +98,7 @@ async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None)
9098 Returns:
9199 A list of (value, ttl) tuples, with default values and default TTL for missing keys.
92100 """
93- results = await self ._key_value .ttl_many (keys = keys , collection = collection )
101+ results = await self .key_value .ttl_many (keys = keys , collection = collection )
94102 return [
95- (result , ttl_value ) if result is not None else (dict ( self ._default_value ), self ._default_ttl ) for result , ttl_value in results
103+ (result , ttl_value ) if result is not None else (self ._new_default_value ( ), self ._default_ttl ) for result , ttl_value in results
96104 ]
0 commit comments