@@ -7,11 +7,12 @@
 from elastic_transport import SerializationError as ElasticsearchSerializationError
 from key_value.shared.errors import DeserializationError, SerializationError
 from key_value.shared.utils.managed_entry import ManagedEntry
+from key_value.shared.utils.sanitization import AlwaysHashStrategy, HashFragmentMode, HybridSanitizationStrategy
 from key_value.shared.utils.sanitize import (
     ALPHANUMERIC_CHARACTERS,
     LOWERCASE_ALPHABET,
     NUMBERS,
-    sanitize_string,
+    UPPERCASE_ALPHABET,
 )
 from key_value.shared.utils.serialization import SerializationAdapter
 from key_value.shared.utils.time_to_live import now_as_epoch
@@ -148,7 +149,7 @@ class ElasticsearchStore(
 
     _native_storage: bool
 
-    _adapter: SerializationAdapter
+    _serializer: SerializationAdapter
 
     @overload
     def __init__(
@@ -210,12 +211,31 @@ def __init__(
         LessCapableJsonSerializer.install_default_serializer(client=self._client)
         LessCapableNdjsonSerializer.install_serializer(client=self._client)
 
-        self._index_prefix = index_prefix
+        self._index_prefix = index_prefix.lower()
         self._native_storage = native_storage
         self._is_serverless = False
-        self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)
 
-        super().__init__(default_collection=default_collection)
+        # We have 240 characters to work with.
+        # We need to account for the index prefix and the hyphen.
+        max_index_length = MAX_INDEX_LENGTH - (len(self._index_prefix) + 1)
+
+        self._serializer = ElasticsearchSerializationAdapter(native_storage=native_storage)
+
+        # We allow uppercase characters through the sanitizer so we can lowercase them later
+        # instead of having them all turn into underscores.
+        collection_sanitization = HybridSanitizationStrategy(
+            replacement_character="_",
+            max_length=max_index_length,
+            allowed_characters=UPPERCASE_ALPHABET + ALLOWED_INDEX_CHARACTERS,
+            hash_fragment_mode=HashFragmentMode.ALWAYS,
+        )
+        key_sanitization = AlwaysHashStrategy()
+
+        super().__init__(
+            default_collection=default_collection,
+            collection_sanitization_strategy=collection_sanitization,
+            key_sanitization_strategy=key_sanitization,
+        )
 
     @override
     async def _setup(self) -> None:
@@ -225,32 +245,22 @@ async def _setup(self) -> None:
 
     @override
     async def _setup_collection(self, *, collection: str) -> None:
-        index_name = self._sanitize_index_name(collection=collection)
+        index_name = self._get_index_name(collection=collection)
 
         if await self._client.options(ignore_status=404).indices.exists(index=index_name):
             return
 
         _ = await self._client.options(ignore_status=404).indices.create(index=index_name, mappings=DEFAULT_MAPPING, settings={})
 
-    def _sanitize_index_name(self, collection: str) -> str:
-        return sanitize_string(
-            value=self._index_prefix + "-" + collection,
-            replacement_character="_",
-            max_length=MAX_INDEX_LENGTH,
-            allowed_characters=ALLOWED_INDEX_CHARACTERS,
-        )
+    def _get_index_name(self, collection: str) -> str:
+        return self._index_prefix + "-" + self._sanitize_collection(collection=collection).lower()
 
-    def _sanitize_document_id(self, key: str) -> str:
-        return sanitize_string(
-            value=key,
-            replacement_character="_",
-            max_length=MAX_KEY_LENGTH,
-            allowed_characters=ALLOWED_KEY_CHARACTERS,
-        )
+    def _get_document_id(self, key: str) -> str:
+        return self._sanitize_key(key=key)
 
     def _get_destination(self, *, collection: str, key: str) -> tuple[str, str]:
-        index_name: str = self._sanitize_index_name(collection=collection)
-        document_id: str = self._sanitize_document_id(key=key)
+        index_name: str = self._get_index_name(collection=collection)
+        document_id: str = self._get_document_id(key=key)
 
         return index_name, document_id
 
@@ -266,7 +276,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
             return None
 
         try:
-            return self._adapter.load_dict(data=source)
+            return self._serializer.load_dict(data=source)
         except DeserializationError:
             return None
 
@@ -276,8 +286,8 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
             return []
 
         # Use mget for efficient batch retrieval
-        index_name = self._sanitize_index_name(collection=collection)
-        document_ids = [self._sanitize_document_id(key=key) for key in keys]
+        index_name = self._get_index_name(collection=collection)
+        document_ids = [self._get_document_id(key=key) for key in keys]
         docs = [{"_id": document_id} for document_id in document_ids]
 
         elasticsearch_response = await self._client.options(ignore_status=404).mget(index=index_name, docs=docs)
@@ -299,7 +309,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
                 continue
 
             try:
-                entries_by_id[doc_id] = self._adapter.load_dict(data=source)
+                entries_by_id[doc_id] = self._serializer.load_dict(data=source)
             except DeserializationError as e:
                 logger.error(
                     "Failed to deserialize Elasticsearch document in batch operation",
@@ -327,10 +337,10 @@ async def _put_managed_entry(
         collection: str,
         managed_entry: ManagedEntry,
     ) -> None:
-        index_name: str = self._sanitize_index_name(collection=collection)
-        document_id: str = self._sanitize_document_id(key=key)
+        index_name: str = self._get_index_name(collection=collection)
+        document_id: str = self._get_document_id(key=key)
 
-        document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry, key=key, collection=collection)
+        document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry, key=key, collection=collection)
 
         try:
             _ = await self._client.index(
@@ -361,14 +371,14 @@ async def _put_managed_entries(
 
         operations: list[dict[str, Any]] = []
 
-        index_name: str = self._sanitize_index_name(collection=collection)
+        index_name: str = self._get_index_name(collection=collection)
 
         for key, managed_entry in zip(keys, managed_entries, strict=True):
-            document_id: str = self._sanitize_document_id(key=key)
+            document_id: str = self._get_document_id(key=key)
 
             index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
 
-            document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry, key=key, collection=collection)
+            document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry, key=key, collection=collection)
 
             operations.extend([index_action, document])
 
@@ -382,8 +392,8 @@ async def _put_managed_entries(
 
     @override
     async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
-        index_name: str = self._sanitize_index_name(collection=collection)
-        document_id: str = self._sanitize_document_id(key=key)
+        index_name: str = self._get_index_name(collection=collection)
+        document_id: str = self._get_document_id(key=key)
 
         elasticsearch_response: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete(
             index=index_name, id=document_id
@@ -431,7 +441,7 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
         limit = min(limit or DEFAULT_PAGE_SIZE, PAGE_LIMIT)
 
         result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).search(
-            index=self._sanitize_index_name(collection=collection),
+            index=self._get_index_name(collection=collection),
             fields=[{"key": None}],
             body={
                 "query": {
@@ -447,7 +457,15 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
         if not (hits := get_hits_from_response(response=result)):
             return []
 
-        return [key for hit in hits if (key := get_first_value_from_field_in_hit(hit=hit, field="key", value_type=str))]
+        all_keys: list[str] = []
+
+        for hit in hits:
+            if not (key := get_first_value_from_field_in_hit(hit=hit, field="key", value_type=str)):
+                continue
+
+            all_keys.append(key)
+
+        return all_keys
 
     @override
     async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
@@ -478,7 +496,7 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
     @override
     async def _delete_collection(self, *, collection: str) -> bool:
         result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete_by_query(
-            index=self._sanitize_index_name(collection=collection),
+            index=self._get_index_name(collection=collection),
             body={
                 "query": {
                     "term": {