22from datetime import datetime
33from typing import Any , overload
44
5- from elastic_transport import ObjectApiResponse # noqa: TC002
6- from key_value .shared .errors import DeserializationError
7- from key_value .shared .utils .managed_entry import ManagedEntry , load_from_json
5+ from elastic_transport import ObjectApiResponse
6+ from elastic_transport import SerializationError as ElasticsearchSerializationError
7+ from key_value .shared .errors import DeserializationError , SerializationError
8+ from key_value .shared .utils .managed_entry import ManagedEntry , load_from_json , verify_dict
89from key_value .shared .utils .sanitize import (
910 ALPHANUMERIC_CHARACTERS ,
1011 LOWERCASE_ALPHABET ,
2223 BaseEnumerateKeysStore ,
2324 BaseStore ,
2425)
25- from key_value .aio .stores .elasticsearch .utils import new_bulk_action
26+ from key_value .aio .stores .elasticsearch .utils import LessCapableJsonSerializer , LessCapableNdjsonSerializer , new_bulk_action
2627
2728try :
2829 from elasticsearch import AsyncElasticsearch
5556 "type" : "keyword" ,
5657 },
5758 "value" : {
58- "type" : "keyword" ,
59- "index" : False ,
60- "doc_values" : False ,
61- "ignore_above" : 256 ,
59+ "properties" : {
60+ # The `string` field may look like it should be a text/keyword field, but a disabled
61+ # `object` field is the recommended mapping for large stringified JSON: it is kept in `_source` without being parsed or indexed.
62+ "string" : {
63+ "type" : "object" ,
64+ "enabled" : False ,
65+ },
66+ "flattened" : {
67+ "type" : "flattened" ,
68+ },
69+ },
6270 },
6371 },
6472}
7381ALLOWED_INDEX_CHARACTERS : str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
7482
7583
76- def managed_entry_to_document (collection : str , key : str , managed_entry : ManagedEntry ) -> dict [str , Any ]:
77- document : dict [str , Any ] = {
78- "collection" : collection ,
79- "key" : key ,
80- "value" : managed_entry .to_json (include_metadata = False ),
81- }
84+ def managed_entry_to_document (collection : str , key : str , managed_entry : ManagedEntry , * , native_storage : bool = False ) -> dict [str , Any ]:
85+ document : dict [str , Any ] = {"collection" : collection , "key" : key , "value" : {}}
86+
87+ # Store in appropriate field based on mode
88+ if native_storage :
89+ document ["value" ]["flattened" ] = managed_entry .value_as_dict
90+ else :
91+ document ["value" ]["string" ] = managed_entry .value_as_json
8292
8393 if managed_entry .created_at :
8494 document ["created_at" ] = managed_entry .created_at .isoformat ()
@@ -89,15 +99,31 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedE
8999
90100
def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
    """Convert an Elasticsearch document `_source` into a ManagedEntry.

    The stored value lives under the `value` object, either as a native dict in
    `value.flattened` or as a JSON string in `value.string`.

    Args:
        source: The `_source` body of a document written by this store.

    Returns:
        The ManagedEntry built from the stored value and the optional
        `created_at` / `expires_at` timestamps.

    Raises:
        DeserializationError: If the `value` object is missing, is not a dict,
            holds neither sub-field, or holds a `string` sub-field that is not
            a string.
    """
    raw_value = source.get("value")

    if not raw_value or not isinstance(raw_value, dict):
        msg = "Value field not found or invalid type"
        raise DeserializationError(msg)

    value: dict[str, Any]

    # Try flattened field first, fall back to string field. Presence is tested
    # with `is not None` rather than truthiness so that an empty dict stored in
    # native mode still round-trips instead of being reported as missing.
    value_flattened = raw_value.get("flattened")  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
    if value_flattened is not None:
        value = verify_dict(obj=value_flattened)
    else:
        value_str = raw_value.get("string")  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
        if value_str is None:
            msg = "Value field not found or invalid type"
            raise DeserializationError(msg)
        if not isinstance(value_str, str):
            msg = "Value in `value` field is not a string"
            raise DeserializationError(msg)
        value = load_from_json(value_str)

    created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
    expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))

    return ManagedEntry(
        value=value,
        created_at=created_at,
        expires_at=expires_at,
    )
@@ -114,11 +140,28 @@ class ElasticsearchStore(
114140
115141 _index_prefix : str
116142
143+ _native_storage : bool
144+
117145 @overload
118- def __init__ (self , * , elasticsearch_client : AsyncElasticsearch , index_prefix : str , default_collection : str | None = None ) -> None : ...
146+ def __init__ (
147+ self ,
148+ * ,
149+ elasticsearch_client : AsyncElasticsearch ,
150+ index_prefix : str ,
151+ native_storage : bool = True ,
152+ default_collection : str | None = None ,
153+ ) -> None : ...
119154
120155 @overload
121- def __init__ (self , * , url : str , api_key : str | None = None , index_prefix : str , default_collection : str | None = None ) -> None : ...
156+ def __init__ (
157+ self ,
158+ * ,
159+ url : str ,
160+ api_key : str | None = None ,
161+ index_prefix : str ,
162+ native_storage : bool = True ,
163+ default_collection : str | None = None ,
164+ ) -> None : ...
122165
123166 def __init__ (
124167 self ,
@@ -127,6 +170,7 @@ def __init__(
127170 url : str | None = None ,
128171 api_key : str | None = None ,
129172 index_prefix : str ,
173+ native_storage : bool = True ,
130174 default_collection : str | None = None ,
131175 ) -> None :
132176 """Initialize the elasticsearch store.
@@ -136,6 +180,8 @@ def __init__(
136180 url: The url of the elasticsearch cluster.
137181 api_key: The api key to use.
138182 index_prefix: The index prefix to use. Collections will be prefixed with this prefix.
183+ native_storage: Whether to use native storage mode (flattened field type) or serialize
184+ all values to JSON strings. Defaults to True.
139185 default_collection: The default collection to use if no collection is provided.
140186 """
141187 if elasticsearch_client is None and url is None :
@@ -152,7 +198,12 @@ def __init__(
152198 msg = "Either elasticsearch_client or url must be provided"
153199 raise ValueError (msg )
154200
201+ LessCapableJsonSerializer .install_serializer (client = self ._client )
202+ LessCapableJsonSerializer .install_default_serializer (client = self ._client )
203+ LessCapableNdjsonSerializer .install_serializer (client = self ._client )
204+
155205 self ._index_prefix = index_prefix
206+ self ._native_storage = native_storage
156207 self ._is_serverless = False
157208
158209 super ().__init__ (default_collection = default_collection )
@@ -205,18 +256,11 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
205256 if not (source := get_source_from_body (body = body )):
206257 return None
207258
208- if not (value_str := source .get ("value" )) or not isinstance (value_str , str ):
259+ try :
260+ return source_to_managed_entry (source = source )
261+ except DeserializationError :
209262 return None
210263
211- created_at : datetime | None = try_parse_datetime_str (value = source .get ("created_at" ))
212- expires_at : datetime | None = try_parse_datetime_str (value = source .get ("expires_at" ))
213-
214- return ManagedEntry (
215- value = load_from_json (value_str ),
216- created_at = created_at ,
217- expires_at = expires_at ,
218- )
219-
220264 @override
221265 async def _get_managed_entries (self , * , collection : str , keys : Sequence [str ]) -> list [ManagedEntry | None ]:
222266 if not keys :
@@ -265,15 +309,23 @@ async def _put_managed_entry(
265309 index_name : str = self ._sanitize_index_name (collection = collection )
266310 document_id : str = self ._sanitize_document_id (key = key )
267311
268- document : dict [str , Any ] = managed_entry_to_document (collection = collection , key = key , managed_entry = managed_entry )
269-
270- _ = await self ._client .index (
271- index = index_name ,
272- id = document_id ,
273- body = document ,
274- refresh = self ._should_refresh_on_put ,
312+ document : dict [str , Any ] = managed_entry_to_document (
313+ collection = collection , key = key , managed_entry = managed_entry , native_storage = self ._native_storage
275314 )
276315
316+ try :
317+ _ = await self ._client .index (
318+ index = index_name ,
319+ id = document_id ,
320+ body = document ,
321+ refresh = self ._should_refresh_on_put ,
322+ )
323+ except ElasticsearchSerializationError as e :
324+ msg = f"Failed to serialize document: { e } "
325+ raise SerializationError (message = msg ) from e
326+ except Exception :
327+ raise
328+
277329 @override
278330 async def _put_managed_entries (
279331 self ,
@@ -297,11 +349,18 @@ async def _put_managed_entries(
297349
298350 index_action : dict [str , Any ] = new_bulk_action (action = "index" , index = index_name , document_id = document_id )
299351
300- document : dict [str , Any ] = managed_entry_to_document (collection = collection , key = key , managed_entry = managed_entry )
352+ document : dict [str , Any ] = managed_entry_to_document (
353+ collection = collection , key = key , managed_entry = managed_entry , native_storage = self ._native_storage
354+ )
301355
302356 operations .extend ([index_action , document ])
303-
304- _ = await self ._client .bulk (operations = operations , refresh = self ._should_refresh_on_put ) # pyright: ignore[reportUnknownMemberType]
357+ try :
358+ _ = await self ._client .bulk (operations = operations , refresh = self ._should_refresh_on_put ) # pyright: ignore[reportUnknownMemberType]
359+ except ElasticsearchSerializationError as e :
360+ msg = f"Failed to serialize bulk operations: { e } "
361+ raise SerializationError (message = msg ) from e
362+ except Exception :
363+ raise
305364
306365 @override
307366 async def _delete_managed_entry (self , * , key : str , collection : str ) -> bool :
0 commit comments