Skip to content

Commit ae6f3d1

Browse files
fix: merge with main and re-add serialization parameters
Resolved merge conflicts with PR #208 (SanitizationStrategy) by:
- Adopting the new SanitizationStrategy infrastructure from main
- Re-adding the key, collection, and version parameters to serialization
- Updating all 12 store implementations to pass metadata
- Running codegen to generate the sync versions

Changes:
- SerializationAdapter.dump_dict() and dump_json() now accept key, collection, and version
- All async stores updated to pass key/collection to serialization
- All sync stores regenerated via codegen
- Elasticsearch mapping includes the version field

This preserves both PR #208's sanitization improvements and our PR #204's enumeration support for stores that sanitize/hash keys.

Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent 2dc7f22 commit ae6f3d1

File tree

11 files changed

+219
-249
lines changed

11 files changed

+219
-249
lines changed

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from elastic_transport import SerializationError as ElasticsearchSerializationError
88
from key_value.shared.errors import DeserializationError, SerializationError
99
from key_value.shared.utils.managed_entry import ManagedEntry
10+
from key_value.shared.utils.sanitization import AlwaysHashStrategy, HashFragmentMode, HybridSanitizationStrategy
1011
from key_value.shared.utils.sanitize import (
1112
ALPHANUMERIC_CHARACTERS,
1213
LOWERCASE_ALPHABET,
1314
NUMBERS,
14-
sanitize_string,
15+
UPPERCASE_ALPHABET,
1516
)
1617
from key_value.shared.utils.serialization import SerializationAdapter
1718
from key_value.shared.utils.time_to_live import now_as_epoch
@@ -148,7 +149,7 @@ class ElasticsearchStore(
148149

149150
_native_storage: bool
150151

151-
_adapter: SerializationAdapter
152+
_serializer: SerializationAdapter
152153

153154
@overload
154155
def __init__(
@@ -210,12 +211,31 @@ def __init__(
210211
LessCapableJsonSerializer.install_default_serializer(client=self._client)
211212
LessCapableNdjsonSerializer.install_serializer(client=self._client)
212213

213-
self._index_prefix = index_prefix
214+
self._index_prefix = index_prefix.lower()
214215
self._native_storage = native_storage
215216
self._is_serverless = False
216-
self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)
217217

218-
super().__init__(default_collection=default_collection)
218+
# We have 240 characters to work with
219+
# We need to account for the index prefix and the hyphen.
220+
max_index_length = MAX_INDEX_LENGTH - (len(self._index_prefix) + 1)
221+
222+
self._serializer = ElasticsearchSerializationAdapter(native_storage=native_storage)
223+
224+
# We allow uppercase through the sanitizer so we can lowercase them instead of them
225+
# all turning into underscores.
226+
collection_sanitization = HybridSanitizationStrategy(
227+
replacement_character="_",
228+
max_length=max_index_length,
229+
allowed_characters=UPPERCASE_ALPHABET + ALLOWED_INDEX_CHARACTERS,
230+
hash_fragment_mode=HashFragmentMode.ALWAYS,
231+
)
232+
key_sanitization = AlwaysHashStrategy()
233+
234+
super().__init__(
235+
default_collection=default_collection,
236+
collection_sanitization_strategy=collection_sanitization,
237+
key_sanitization_strategy=key_sanitization,
238+
)
219239

220240
@override
221241
async def _setup(self) -> None:
@@ -225,32 +245,22 @@ async def _setup(self) -> None:
225245

226246
@override
227247
async def _setup_collection(self, *, collection: str) -> None:
228-
index_name = self._sanitize_index_name(collection=collection)
248+
index_name = self._get_index_name(collection=collection)
229249

230250
if await self._client.options(ignore_status=404).indices.exists(index=index_name):
231251
return
232252

233253
_ = await self._client.options(ignore_status=404).indices.create(index=index_name, mappings=DEFAULT_MAPPING, settings={})
234254

235-
def _sanitize_index_name(self, collection: str) -> str:
236-
return sanitize_string(
237-
value=self._index_prefix + "-" + collection,
238-
replacement_character="_",
239-
max_length=MAX_INDEX_LENGTH,
240-
allowed_characters=ALLOWED_INDEX_CHARACTERS,
241-
)
255+
def _get_index_name(self, collection: str) -> str:
256+
return self._index_prefix + "-" + self._sanitize_collection(collection=collection).lower()
242257

243-
def _sanitize_document_id(self, key: str) -> str:
244-
return sanitize_string(
245-
value=key,
246-
replacement_character="_",
247-
max_length=MAX_KEY_LENGTH,
248-
allowed_characters=ALLOWED_KEY_CHARACTERS,
249-
)
258+
def _get_document_id(self, key: str) -> str:
259+
return self._sanitize_key(key=key)
250260

251261
def _get_destination(self, *, collection: str, key: str) -> tuple[str, str]:
252-
index_name: str = self._sanitize_index_name(collection=collection)
253-
document_id: str = self._sanitize_document_id(key=key)
262+
index_name: str = self._get_index_name(collection=collection)
263+
document_id: str = self._get_document_id(key=key)
254264

255265
return index_name, document_id
256266

@@ -266,7 +276,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
266276
return None
267277

268278
try:
269-
return self._adapter.load_dict(data=source)
279+
return self._serializer.load_dict(data=source)
270280
except DeserializationError:
271281
return None
272282

@@ -276,8 +286,8 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
276286
return []
277287

278288
# Use mget for efficient batch retrieval
279-
index_name = self._sanitize_index_name(collection=collection)
280-
document_ids = [self._sanitize_document_id(key=key) for key in keys]
289+
index_name = self._get_index_name(collection=collection)
290+
document_ids = [self._get_document_id(key=key) for key in keys]
281291
docs = [{"_id": document_id} for document_id in document_ids]
282292

283293
elasticsearch_response = await self._client.options(ignore_status=404).mget(index=index_name, docs=docs)
@@ -299,7 +309,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
299309
continue
300310

301311
try:
302-
entries_by_id[doc_id] = self._adapter.load_dict(data=source)
312+
entries_by_id[doc_id] = self._serializer.load_dict(data=source)
303313
except DeserializationError as e:
304314
logger.error(
305315
"Failed to deserialize Elasticsearch document in batch operation",
@@ -327,10 +337,10 @@ async def _put_managed_entry(
327337
collection: str,
328338
managed_entry: ManagedEntry,
329339
) -> None:
330-
index_name: str = self._sanitize_index_name(collection=collection)
331-
document_id: str = self._sanitize_document_id(key=key)
340+
index_name: str = self._get_index_name(collection=collection)
341+
document_id: str = self._get_document_id(key=key)
332342

333-
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry, key=key, collection=collection)
343+
document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry, key=key, collection=collection)
334344

335345
try:
336346
_ = await self._client.index(
@@ -361,14 +371,14 @@ async def _put_managed_entries(
361371

362372
operations: list[dict[str, Any]] = []
363373

364-
index_name: str = self._sanitize_index_name(collection=collection)
374+
index_name: str = self._get_index_name(collection=collection)
365375

366376
for key, managed_entry in zip(keys, managed_entries, strict=True):
367-
document_id: str = self._sanitize_document_id(key=key)
377+
document_id: str = self._get_document_id(key=key)
368378

369379
index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
370380

371-
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry, key=key, collection=collection)
381+
document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry, key=key, collection=collection)
372382

373383
operations.extend([index_action, document])
374384

@@ -382,8 +392,8 @@ async def _put_managed_entries(
382392

383393
@override
384394
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
385-
index_name: str = self._sanitize_index_name(collection=collection)
386-
document_id: str = self._sanitize_document_id(key=key)
395+
index_name: str = self._get_index_name(collection=collection)
396+
document_id: str = self._get_document_id(key=key)
387397

388398
elasticsearch_response: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete(
389399
index=index_name, id=document_id
@@ -431,7 +441,7 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
431441
limit = min(limit or DEFAULT_PAGE_SIZE, PAGE_LIMIT)
432442

433443
result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).search(
434-
index=self._sanitize_index_name(collection=collection),
444+
index=self._get_index_name(collection=collection),
435445
fields=[{"key": None}],
436446
body={
437447
"query": {
@@ -447,7 +457,15 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
447457
if not (hits := get_hits_from_response(response=result)):
448458
return []
449459

450-
return [key for hit in hits if (key := get_first_value_from_field_in_hit(hit=hit, field="key", value_type=str))]
460+
all_keys: list[str] = []
461+
462+
for hit in hits:
463+
if not (key := get_first_value_from_field_in_hit(hit=hit, field="key", value_type=str)):
464+
continue
465+
466+
all_keys.append(key)
467+
468+
return all_keys
451469

452470
@override
453471
async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
@@ -478,7 +496,7 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
478496
@override
479497
async def _delete_collection(self, *, collection: str) -> bool:
480498
result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete_by_query(
481-
index=self._sanitize_index_name(collection=collection),
499+
index=self._get_index_name(collection=collection),
482500
body={
483501
"query": {
484502
"term": {

key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from key_value.shared.utils.compound import compound_key
44
from key_value.shared.utils.managed_entry import ManagedEntry
5-
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string
5+
from key_value.shared.utils.sanitization import HybridSanitizationStrategy
6+
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS
67
from typing_extensions import override
78

89
from key_value.aio.stores.base import BaseStore
@@ -15,11 +16,9 @@
1516
raise ImportError(msg) from e
1617

1718
DEFAULT_KEYCHAIN_SERVICE = "py-key-value"
18-
MAX_KEY_LENGTH = 256
19-
ALLOWED_KEY_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
2019

21-
MAX_COLLECTION_LENGTH = 256
22-
ALLOWED_COLLECTION_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
20+
MAX_KEY_COLLECTION_LENGTH = 256
21+
ALLOWED_KEY_COLLECTION_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
2322

2423

2524
class KeyringStore(BaseStore):
@@ -48,25 +47,19 @@ def __init__(
4847
"""
4948
self._service_name = service_name
5049

51-
super().__init__(default_collection=default_collection)
52-
53-
def _sanitize_collection_name(self, collection: str) -> str:
54-
return sanitize_string(
55-
value=collection,
56-
max_length=MAX_COLLECTION_LENGTH,
57-
allowed_characters=ALLOWED_COLLECTION_CHARACTERS,
50+
sanitization_strategy = HybridSanitizationStrategy(
51+
replacement_character="_", max_length=MAX_KEY_COLLECTION_LENGTH, allowed_characters=ALLOWED_KEY_COLLECTION_CHARACTERS
5852
)
5953

60-
def _sanitize_key(self, key: str) -> str:
61-
return sanitize_string(
62-
value=key,
63-
max_length=MAX_KEY_LENGTH,
64-
allowed_characters=ALLOWED_KEY_CHARACTERS,
54+
super().__init__(
55+
default_collection=default_collection,
56+
collection_sanitization_strategy=sanitization_strategy,
57+
key_sanitization_strategy=sanitization_strategy,
6558
)
6659

6760
@override
6861
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
69-
sanitized_collection = self._sanitize_collection_name(collection=collection)
62+
sanitized_collection = self._sanitize_collection(collection=collection)
7063
sanitized_key = self._sanitize_key(key=key)
7164

7265
combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
@@ -83,7 +76,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
8376

8477
@override
8578
async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None:
86-
sanitized_collection = self._sanitize_collection_name(collection=collection)
79+
sanitized_collection = self._sanitize_collection(collection=collection)
8780
sanitized_key = self._sanitize_key(key=key)
8881

8982
combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
@@ -94,7 +87,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry:
9487

9588
@override
9689
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
97-
sanitized_collection = self._sanitize_collection_name(collection=collection)
90+
sanitized_collection = self._sanitize_collection(collection=collection)
9891
sanitized_key = self._sanitize_key(key=key)
9992

10093
combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)

key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from key_value.shared.utils.compound import compound_key
66
from key_value.shared.utils.managed_entry import ManagedEntry
7+
from key_value.shared.utils.sanitization import HashExcessLengthStrategy
78
from typing_extensions import override
89

910
from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseStore
@@ -46,7 +47,12 @@ def __init__(
4647
"""
4748
self._client = client or Client(host=host, port=port)
4849

49-
super().__init__(default_collection=default_collection)
50+
sanitization_strategy = HashExcessLengthStrategy(max_length=MAX_KEY_LENGTH)
51+
52+
super().__init__(
53+
default_collection=default_collection,
54+
key_sanitization_strategy=sanitization_strategy,
55+
)
5056

5157
def sanitize_key(self, key: str) -> str:
5258
if len(key) > MAX_KEY_LENGTH:

0 commit comments

Comments
 (0)