1- from datetime import datetime # noqa: TC003
1+ from collections .abc import Sequence
2+ from datetime import datetime
23from typing import Any , overload
34
45from elastic_transport import ObjectApiResponse # noqa: TC002
5- from key_value .shared .utils . compound import compound_key
6+ from key_value .shared .errors import DeserializationError
67from key_value .shared .utils .managed_entry import ManagedEntry , load_from_json
78from key_value .shared .utils .sanitize import (
89 ALPHANUMERIC_CHARACTERS ,
2122 BaseEnumerateKeysStore ,
2223 BaseStore ,
2324)
25+ from key_value .aio .stores .elasticsearch .utils import new_bulk_action
2426
2527try :
2628 from elasticsearch import AsyncElasticsearch
7173ALLOWED_INDEX_CHARACTERS : str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
7274
7375
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
    """Build the Elasticsearch document body for one managed entry.

    The document always contains ``collection``, ``key`` and ``value`` (the
    entry serialized to JSON without its metadata). ``created_at`` and
    ``expires_at`` are included, as ISO 8601 strings, only when the entry has
    them set.

    Args:
        collection: Collection name stored in the document.
        key: Key stored in the document.
        managed_entry: Entry whose value and timestamps are written.

    Returns:
        The document as a plain dictionary.
    """
    serialized_value = managed_entry.to_json(include_metadata=False)

    # Timestamps are optional; unset ones are left out of the document entirely.
    optional_timestamps = (
        ("created_at", managed_entry.created_at),
        ("expires_at", managed_entry.expires_at),
    )

    return {
        "collection": collection,
        "key": key,
        "value": serialized_value,
        **{field: moment.isoformat() for field, moment in optional_timestamps if moment},
    }
89+
90+
def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
    """Rebuild a ``ManagedEntry`` from an Elasticsearch document source.

    Args:
        source: The document's ``_source`` mapping. ``value`` must be a
            non-empty string; ``created_at`` and ``expires_at`` are optional
            and are handed to ``try_parse_datetime_str``.

    Returns:
        The entry with its value loaded from JSON and both timestamps attached.

    Raises:
        DeserializationError: If ``value`` is missing, empty, or not a string.
    """
    raw_value = source.get("value")
    if not raw_value or not isinstance(raw_value, str):
        msg = "Value is not a string"
        raise DeserializationError(msg)

    # Timestamps are parsed before the value is decoded.
    created: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
    expires: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))

    return ManagedEntry(value=load_from_json(raw_value), created_at=created, expires_at=expires)
104+
105+
74106class ElasticsearchStore (
75107 BaseEnumerateCollectionsStore , BaseEnumerateKeysStore , BaseDestroyCollectionStore , BaseCullStore , BaseContextManagerStore , BaseStore
76108):
@@ -156,13 +188,17 @@ def _sanitize_document_id(self, key: str) -> str:
156188 allowed_characters = ALLOWED_KEY_CHARACTERS ,
157189 )
158190
def _get_destination(self, *, collection: str, key: str) -> tuple[str, str]:
    """Return the ``(index name, document id)`` pair for a collection and key.

    The collection is passed through ``_sanitize_index_name`` and the key
    through ``_sanitize_document_id``.
    """
    return (
        self._sanitize_index_name(collection=collection),
        self._sanitize_document_id(key=key),
    )
196+
159197 @override
160198 async def _get_managed_entry (self , * , key : str , collection : str ) -> ManagedEntry | None :
161- combo_key : str = compound_key (collection = collection , key = key )
199+ index_name , document_id = self . _get_destination (collection = collection , key = key )
162200
163- elasticsearch_response = await self ._client .options (ignore_status = 404 ).get (
164- index = self ._sanitize_index_name (collection = collection ), id = self ._sanitize_document_id (key = combo_key )
165- )
201+ elasticsearch_response = await self ._client .options (ignore_status = 404 ).get (index = index_name , id = document_id )
166202
167203 body : dict [str , Any ] = get_body_from_response (response = elasticsearch_response )
168204
@@ -181,6 +217,39 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
181217 expires_at = expires_at ,
182218 )
183219
@override
async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
    """Fetch several entries from one collection with a single ``mget`` call.

    Args:
        collection: Collection whose sanitized name is used as the index.
        keys: Keys to look up; each is sanitized into a document id.

    Returns:
        One item per input key, in the same order. An item is ``None`` when
        the response has no hit for that document id (not found, reported
        with an error, or returned without a source).

    Raises:
        DeserializationError: Propagated from ``source_to_managed_entry`` when
            a returned document's ``value`` is not a non-empty string.
    """
    if not keys:
        return []

    index_name: str = self._sanitize_index_name(collection=collection)
    document_ids: list[str] = [self._sanitize_document_id(key=key) for key in keys]

    # Use mget for efficient batch retrieval; a 404 status is not raised.
    elasticsearch_response = await self._client.options(ignore_status=404).mget(
        index=index_name,
        docs=[{"_id": document_id} for document_id in document_ids],
    )

    body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)

    entries_by_id: dict[str, ManagedEntry] = {}
    for doc in body.get("docs", []):
        # Check the value of "found", not just its presence: a miss is
        # reported as found=False, while per-document errors omit the flag.
        if not doc.get("found"):
            continue

        doc_id = doc.get("_id")
        source = doc.get("_source")
        if not doc_id or not source:
            continue

        entries_by_id[doc_id] = source_to_managed_entry(source=source)

    # Return entries in the same order as input keys; ids without a hit map to None.
    return [entries_by_id.get(document_id) for document_id in document_ids]
252+
@property
def _should_refresh_on_put(self) -> bool:
    """Value passed as ``refresh`` on writes: true unless the store is serverless."""
    serverless = self._is_serverless
    return not serverless
@@ -193,32 +262,54 @@ async def _put_managed_entry(
193262 collection : str ,
194263 managed_entry : ManagedEntry ,
195264 ) -> None :
196- combo_key : str = compound_key (collection = collection , key = key )
265+ index_name : str = self ._sanitize_index_name (collection = collection )
266+ document_id : str = self ._sanitize_document_id (key = key )
197267
198- document : dict [str , Any ] = {
199- "collection" : collection ,
200- "key" : key ,
201- "value" : managed_entry .to_json (include_metadata = False ),
202- }
203-
204- if managed_entry .created_at :
205- document ["created_at" ] = managed_entry .created_at .isoformat ()
206- if managed_entry .expires_at :
207- document ["expires_at" ] = managed_entry .expires_at .isoformat ()
268+ document : dict [str , Any ] = managed_entry_to_document (collection = collection , key = key , managed_entry = managed_entry )
208269
209270 _ = await self ._client .index (
210- index = self . _sanitize_index_name ( collection = collection ) ,
211- id = self . _sanitize_document_id ( key = combo_key ) ,
271+ index = index_name ,
272+ id = document_id ,
212273 body = document ,
213274 refresh = self ._should_refresh_on_put ,
214275 )
215276
@override
async def _put_managed_entries(
    self,
    *,
    collection: str,
    keys: Sequence[str],
    managed_entries: Sequence[ManagedEntry],
    ttl: float | None,
    created_at: datetime,
    expires_at: datetime | None,
) -> None:
    """Write several entries to one collection with a single bulk request.

    Args:
        collection: Collection whose sanitized name is used as the index.
        keys: Keys to write; each is sanitized into a document id.
        managed_entries: Entries to store, paired with ``keys`` by position.
            The lengths must match (``zip(..., strict=True)``).
        ttl: Not used by this implementation.
        created_at: Not used by this implementation; each entry's own
            timestamps are written instead.
        expires_at: Not used by this implementation.
    """
    if not keys:
        return

    index_name: str = self._sanitize_index_name(collection=collection)

    # The bulk body alternates an "index" action with the document it applies to.
    operations: list[dict[str, Any]] = []
    for key, managed_entry in zip(keys, managed_entries, strict=True):
        operations += (
            new_bulk_action(action="index", index=index_name, document_id=self._sanitize_document_id(key=key)),
            managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry),
        )

    _ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put)  # pyright: ignore[reportUnknownMemberType]
305+
216306 @override
217307 async def _delete_managed_entry (self , * , key : str , collection : str ) -> bool :
218- combo_key : str = compound_key (collection = collection , key = key )
308+ index_name : str = self ._sanitize_index_name (collection = collection )
309+ document_id : str = self ._sanitize_document_id (key = key )
219310
220311 elasticsearch_response : ObjectApiResponse [Any ] = await self ._client .options (ignore_status = 404 ).delete (
221- index = self . _sanitize_index_name ( collection = collection ) , id = self . _sanitize_document_id ( key = combo_key )
312+ index = index_name , id = document_id
222313 )
223314
224315 body : dict [str , Any ] = get_body_from_response (response = elasticsearch_response )
@@ -228,6 +319,34 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
228319
229320 return result == "deleted"
230321
@override
async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str) -> int:
    """Delete several entries from one collection with a single bulk request.

    Args:
        keys: Keys to delete; each is sanitized into a document id.
        collection: Collection whose sanitized name is used as the index.

    Returns:
        The number of bulk items whose ``delete`` result is ``"deleted"``.
        Returns 0 without calling Elasticsearch when ``keys`` is empty.
    """
    if not keys:
        return 0

    # All keys target the same index, so sanitize the collection name once
    # instead of once per key.
    index_name: str = self._sanitize_index_name(collection=collection)

    operations: list[dict[str, Any]] = [
        new_bulk_action(action="delete", index=index_name, document_id=self._sanitize_document_id(key=key))
        for key in keys
    ]

    elasticsearch_response = await self._client.bulk(operations=operations)  # pyright: ignore[reportUnknownMemberType]

    body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)

    # Count successful deletions; items with any other result are not counted.
    return sum(1 for item in body.get("items", []) if item.get("delete", {}).get("result") == "deleted")
349+
231350 @override
232351 async def _get_collection_keys (self , * , collection : str , limit : int | None = None ) -> list [str ]:
233352 """Get up to 10,000 keys in the specified collection (eventually consistent)."""
0 commit comments