Skip to content

Commit d58fe24

Browse files
authored
Implement bulk operations for stores with native batch APIs (#79)
1 parent 4af690a commit d58fe24

File tree

38 files changed

+1432
-262
lines changed

38 files changed

+1432
-262
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from asyncio.locks import Lock
77
from collections import defaultdict
88
from collections.abc import Mapping, Sequence
9+
from datetime import datetime
910
from types import MappingProxyType, TracebackType
1011
from typing import Any, SupportsFloat
1112

1213
from key_value.shared.constants import DEFAULT_COLLECTION_NAME
1314
from key_value.shared.errors import StoreSetupError
1415
from key_value.shared.type_checking.bear_spray import bear_enforce
1516
from key_value.shared.utils.managed_entry import ManagedEntry
16-
from key_value.shared.utils.time_to_live import now, prepare_ttl
17+
from key_value.shared.utils.time_to_live import prepare_entry_timestamps
1718
from typing_extensions import Self, override
1819

1920
from key_value.aio.protocols.key_value import (
@@ -207,13 +208,36 @@ async def ttl_many(
207208
return [(dict(entry.value), entry.ttl) if entry and not entry.is_expired else (None, None) for entry in entries]
208209

209210
@abstractmethod
210-
async def _put_managed_entry(self, *, collection: str, key: str, managed_entry: ManagedEntry) -> None:
211+
async def _put_managed_entry(
212+
self,
213+
*,
214+
collection: str,
215+
key: str,
216+
managed_entry: ManagedEntry,
217+
) -> None:
211218
"""Store a managed entry by key in the specified collection."""
212219
...
213220

214-
async def _put_managed_entries(self, *, collection: str, keys: Sequence[str], managed_entries: Sequence[ManagedEntry]) -> None:
215-
"""Store multiple managed entries by key in the specified collection."""
221+
async def _put_managed_entries(
222+
self,
223+
*,
224+
collection: str,
225+
keys: Sequence[str],
226+
managed_entries: Sequence[ManagedEntry],
227+
ttl: float | None, # noqa: ARG002
228+
created_at: datetime, # noqa: ARG002
229+
expires_at: datetime | None, # noqa: ARG002
230+
) -> None:
231+
"""Store multiple managed entries by key in the specified collection.
216232
233+
Args:
234+
collection: The collection to store entries in
235+
keys: The keys for the entries
236+
managed_entries: The managed entries to store
237+
ttl: The TTL in seconds (None for no expiration)
238+
created_at: The creation timestamp for all entries
239+
expires_at: The expiration timestamp for all entries (None if no TTL)
240+
"""
217241
for key, managed_entry in zip(keys, managed_entries, strict=True):
218242
await self._put_managed_entry(
219243
collection=collection,
@@ -228,30 +252,16 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non
228252
collection = collection or self.default_collection
229253
await self.setup_collection(collection=collection)
230254

231-
managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=prepare_ttl(t=ttl), created_at=now())
255+
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
256+
257+
managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at)
232258

233259
await self._put_managed_entry(
234260
collection=collection,
235261
key=key,
236262
managed_entry=managed_entry,
237263
)
238264

239-
def _prepare_put_many(
240-
self, *, keys: Sequence[str], values: Sequence[Mapping[str, Any]], ttl: SupportsFloat | None
241-
) -> tuple[Sequence[str], Sequence[Mapping[str, Any]], float | None]:
242-
"""Prepare multiple managed entries for a put_many operation.
243-
244-
Inheriting classes can use this method if they need to modify a put_many operation."""
245-
246-
if len(keys) != len(values):
247-
msg = "put_many called but a different number of keys and values were provided"
248-
raise ValueError(msg) from None
249-
250-
ttl_for_entries: float | None = prepare_ttl(t=ttl)
251-
252-
return (keys, values, ttl_for_entries)
253-
254-
@bear_enforce
255265
@override
256266
async def put_many(
257267
self,
@@ -266,11 +276,24 @@ async def put_many(
266276
collection = collection or self.default_collection
267277
await self.setup_collection(collection=collection)
268278

269-
keys, values, ttl_for_entries = self._prepare_put_many(keys=keys, values=values, ttl=ttl)
279+
if len(keys) != len(values):
280+
msg = "put_many called but a different number of keys and values were provided"
281+
raise ValueError(msg) from None
270282

271-
managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, ttl=ttl_for_entries, created_at=now()) for value in values]
283+
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
272284

273-
await self._put_managed_entries(collection=collection, keys=keys, managed_entries=managed_entries)
285+
managed_entries: list[ManagedEntry] = [
286+
ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values
287+
]
288+
289+
await self._put_managed_entries(
290+
collection=collection,
291+
keys=keys,
292+
managed_entries=managed_entries,
293+
ttl=ttl_seconds,
294+
created_at=created_at,
295+
expires_at=expires_at,
296+
)
274297

275298
@abstractmethod
276299
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 140 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from datetime import datetime # noqa: TC003
1+
from collections.abc import Sequence
2+
from datetime import datetime
23
from typing import Any, overload
34

45
from elastic_transport import ObjectApiResponse # noqa: TC002
5-
from key_value.shared.utils.compound import compound_key
6+
from key_value.shared.errors import DeserializationError
67
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json
78
from key_value.shared.utils.sanitize import (
89
ALPHANUMERIC_CHARACTERS,
@@ -21,6 +22,7 @@
2122
BaseEnumerateKeysStore,
2223
BaseStore,
2324
)
25+
from key_value.aio.stores.elasticsearch.utils import new_bulk_action
2426

2527
try:
2628
from elasticsearch import AsyncElasticsearch
@@ -71,6 +73,36 @@
7173
ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
7274

7375

76+
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
77+
document: dict[str, Any] = {
78+
"collection": collection,
79+
"key": key,
80+
"value": managed_entry.to_json(include_metadata=False),
81+
}
82+
83+
if managed_entry.created_at:
84+
document["created_at"] = managed_entry.created_at.isoformat()
85+
if managed_entry.expires_at:
86+
document["expires_at"] = managed_entry.expires_at.isoformat()
87+
88+
return document
89+
90+
91+
def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
92+
if not (value_str := source.get("value")) or not isinstance(value_str, str):
93+
msg = "Value is not a string"
94+
raise DeserializationError(msg)
95+
96+
created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
97+
expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))
98+
99+
return ManagedEntry(
100+
value=load_from_json(value_str),
101+
created_at=created_at,
102+
expires_at=expires_at,
103+
)
104+
105+
74106
class ElasticsearchStore(
75107
BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyCollectionStore, BaseCullStore, BaseContextManagerStore, BaseStore
76108
):
@@ -156,13 +188,17 @@ def _sanitize_document_id(self, key: str) -> str:
156188
allowed_characters=ALLOWED_KEY_CHARACTERS,
157189
)
158190

191+
def _get_destination(self, *, collection: str, key: str) -> tuple[str, str]:
192+
index_name: str = self._sanitize_index_name(collection=collection)
193+
document_id: str = self._sanitize_document_id(key=key)
194+
195+
return index_name, document_id
196+
159197
@override
160198
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
161-
combo_key: str = compound_key(collection=collection, key=key)
199+
index_name, document_id = self._get_destination(collection=collection, key=key)
162200

163-
elasticsearch_response = await self._client.options(ignore_status=404).get(
164-
index=self._sanitize_index_name(collection=collection), id=self._sanitize_document_id(key=combo_key)
165-
)
201+
elasticsearch_response = await self._client.options(ignore_status=404).get(index=index_name, id=document_id)
166202

167203
body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)
168204

@@ -181,6 +217,39 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
181217
expires_at=expires_at,
182218
)
183219

220+
@override
221+
async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
222+
if not keys:
223+
return []
224+
225+
# Use mget for efficient batch retrieval
226+
index_name = self._sanitize_index_name(collection=collection)
227+
document_ids = [self._sanitize_document_id(key=key) for key in keys]
228+
docs = [{"_id": document_id} for document_id in document_ids]
229+
230+
elasticsearch_response = await self._client.options(ignore_status=404).mget(index=index_name, docs=docs)
231+
232+
body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)
233+
docs_result = body.get("docs", [])
234+
235+
entries_by_id: dict[str, ManagedEntry | None] = {}
236+
for doc in docs_result:
237+
if not (doc_id := doc.get("_id")):
238+
continue
239+
240+
if "found" not in doc:
241+
entries_by_id[doc_id] = None
242+
continue
243+
244+
if not (source := doc.get("_source")):
245+
entries_by_id[doc_id] = None
246+
continue
247+
248+
entries_by_id[doc_id] = source_to_managed_entry(source=source)
249+
250+
# Return entries in the same order as input keys
251+
return [entries_by_id.get(document_id) for document_id in document_ids]
252+
184253
@property
185254
def _should_refresh_on_put(self) -> bool:
186255
return not self._is_serverless
@@ -193,32 +262,54 @@ async def _put_managed_entry(
193262
collection: str,
194263
managed_entry: ManagedEntry,
195264
) -> None:
196-
combo_key: str = compound_key(collection=collection, key=key)
265+
index_name: str = self._sanitize_index_name(collection=collection)
266+
document_id: str = self._sanitize_document_id(key=key)
197267

198-
document: dict[str, Any] = {
199-
"collection": collection,
200-
"key": key,
201-
"value": managed_entry.to_json(include_metadata=False),
202-
}
203-
204-
if managed_entry.created_at:
205-
document["created_at"] = managed_entry.created_at.isoformat()
206-
if managed_entry.expires_at:
207-
document["expires_at"] = managed_entry.expires_at.isoformat()
268+
document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)
208269

209270
_ = await self._client.index(
210-
index=self._sanitize_index_name(collection=collection),
211-
id=self._sanitize_document_id(key=combo_key),
271+
index=index_name,
272+
id=document_id,
212273
body=document,
213274
refresh=self._should_refresh_on_put,
214275
)
215276

277+
@override
278+
async def _put_managed_entries(
279+
self,
280+
*,
281+
collection: str,
282+
keys: Sequence[str],
283+
managed_entries: Sequence[ManagedEntry],
284+
ttl: float | None,
285+
created_at: datetime,
286+
expires_at: datetime | None,
287+
) -> None:
288+
if not keys:
289+
return
290+
291+
operations: list[dict[str, Any]] = []
292+
293+
index_name: str = self._sanitize_index_name(collection=collection)
294+
295+
for key, managed_entry in zip(keys, managed_entries, strict=True):
296+
document_id: str = self._sanitize_document_id(key=key)
297+
298+
index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
299+
300+
document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)
301+
302+
operations.extend([index_action, document])
303+
304+
_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
305+
216306
@override
217307
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
218-
combo_key: str = compound_key(collection=collection, key=key)
308+
index_name: str = self._sanitize_index_name(collection=collection)
309+
document_id: str = self._sanitize_document_id(key=key)
219310

220311
elasticsearch_response: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete(
221-
index=self._sanitize_index_name(collection=collection), id=self._sanitize_document_id(key=combo_key)
312+
index=index_name, id=document_id
222313
)
223314

224315
body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)
@@ -228,6 +319,34 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
228319

229320
return result == "deleted"
230321

322+
@override
323+
async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str) -> int:
324+
if not keys:
325+
return 0
326+
327+
operations: list[dict[str, Any]] = []
328+
329+
for key in keys:
330+
index_name, document_id = self._get_destination(collection=collection, key=key)
331+
332+
delete_action: dict[str, Any] = new_bulk_action(action="delete", index=index_name, document_id=document_id)
333+
334+
operations.append(delete_action)
335+
336+
elasticsearch_response = await self._client.bulk(operations=operations) # pyright: ignore[reportUnknownMemberType]
337+
338+
body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)
339+
340+
# Count successful deletions
341+
deleted_count = 0
342+
items = body.get("items", [])
343+
for item in items:
344+
delete_result = item.get("delete", {})
345+
if delete_result.get("result") == "deleted":
346+
deleted_count += 1
347+
348+
return deleted_count
349+
231350
@override
232351
async def _get_collection_keys(self, *, collection: str, limit: int | None = None) -> list[str]:
233352
"""Get up to 10,000 keys in the specified collection (eventually consistent)."""

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, TypeVar, cast
22

33
from elastic_transport import ObjectApiResponse
4+
from key_value.shared.utils.managed_entry import ManagedEntry
45

56

67
def get_body_from_response(response: ObjectApiResponse[Any]) -> dict[str, Any]:
@@ -105,3 +106,22 @@ def get_first_value_from_field_in_hit(hit: dict[str, Any], field: str, value_typ
105106
msg: str = f"Field {field} in hit {hit} is not a single value"
106107
raise TypeError(msg)
107108
return values[0]
109+
110+
111+
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
112+
document: dict[str, Any] = {
113+
"collection": collection,
114+
"key": key,
115+
"value": managed_entry.to_json(include_metadata=False),
116+
}
117+
118+
if managed_entry.created_at:
119+
document["created_at"] = managed_entry.created_at.isoformat()
120+
if managed_entry.expires_at:
121+
document["expires_at"] = managed_entry.expires_at.isoformat()
122+
123+
return document
124+
125+
126+
def new_bulk_action(action: str, index: str, document_id: str) -> dict[str, Any]:
127+
return {action: {"_index": index, "_id": document_id}}

0 commit comments

Comments
 (0)