Skip to content

Commit 7a3edc9

Browse files
feat: add native storage mode for Elasticsearch (#149)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent a694fe8 commit 7a3edc9

File tree

14 files changed

+664
-130
lines changed

14 files changed

+664
-130
lines changed

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 98 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from datetime import datetime
33
from typing import Any, overload
44

5-
from elastic_transport import ObjectApiResponse # noqa: TC002
6-
from key_value.shared.errors import DeserializationError
7-
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json
5+
from elastic_transport import ObjectApiResponse
6+
from elastic_transport import SerializationError as ElasticsearchSerializationError
7+
from key_value.shared.errors import DeserializationError, SerializationError
8+
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict
89
from key_value.shared.utils.sanitize import (
910
ALPHANUMERIC_CHARACTERS,
1011
LOWERCASE_ALPHABET,
@@ -22,7 +23,7 @@
2223
BaseEnumerateKeysStore,
2324
BaseStore,
2425
)
25-
from key_value.aio.stores.elasticsearch.utils import new_bulk_action
26+
from key_value.aio.stores.elasticsearch.utils import LessCapableJsonSerializer, LessCapableNdjsonSerializer, new_bulk_action
2627

2728
try:
2829
from elasticsearch import AsyncElasticsearch
@@ -55,10 +56,17 @@
5556
"type": "keyword",
5657
},
5758
"value": {
58-
"type": "keyword",
59-
"index": False,
60-
"doc_values": False,
61-
"ignore_above": 256,
59+
"properties": {
60+
# You might think the `string` field should be a text/keyword field
61+
# but this is the recommended mapping for large stringified json
62+
"string": {
63+
"type": "object",
64+
"enabled": False,
65+
},
66+
"flattened": {
67+
"type": "flattened",
68+
},
69+
},
6270
},
6371
},
6472
}
@@ -73,12 +81,14 @@
7381
ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
7482

7583

76-
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
77-
document: dict[str, Any] = {
78-
"collection": collection,
79-
"key": key,
80-
"value": managed_entry.to_json(include_metadata=False),
81-
}
84+
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]:
85+
document: dict[str, Any] = {"collection": collection, "key": key, "value": {}}
86+
87+
# Store in appropriate field based on mode
88+
if native_storage:
89+
document["value"]["flattened"] = managed_entry.value_as_dict
90+
else:
91+
document["value"]["string"] = managed_entry.value_as_json
8292

8393
if managed_entry.created_at:
8494
document["created_at"] = managed_entry.created_at.isoformat()
@@ -89,15 +99,31 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedE
8999

90100

91101
def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
    """Rebuild a ``ManagedEntry`` from the ``_source`` of an Elasticsearch document.

    The ``value`` field must be a non-empty object. Its payload is read from the ``flattened``
    sub-field (the value stored as a dict) when that sub-field is present, otherwise from the
    ``string`` sub-field (the value stored as a JSON string).

    Args:
        source: The document source; reads the ``value``, ``created_at`` and ``expires_at`` keys.

    Returns:
        A ``ManagedEntry`` carrying the decoded value and the parsed timestamps.

    Raises:
        DeserializationError: If ``value`` is missing, empty or not a dict, if neither sub-field
            holds a usable payload, or if the ``string`` sub-field is not a ``str``.
    """
    raw_value = source.get("value")

    if not raw_value or not isinstance(raw_value, dict):
        msg = "Value field not found or invalid type"
        raise DeserializationError(msg)

    value: dict[str, Any]

    # Try flattened field first, fall back to string field.
    # Presence is checked with `is not None` instead of truthiness: an empty dict stored in the
    # flattened field is a legitimate value and must not be mistaken for a missing field.
    if (value_flattened := raw_value.get("flattened")) is not None:  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
        value = verify_dict(obj=value_flattened)
    elif value_str := raw_value.get("string"):  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
        if not isinstance(value_str, str):
            msg = "Value in `value` field is not a string"
            raise DeserializationError(msg)
        value = load_from_json(value_str)
    else:
        msg = "Value field not found or invalid type"
        raise DeserializationError(msg)

    # NOTE(review): try_parse_datetime_str is annotated as returning `datetime | None`; presumably
    # it yields None for absent or unparsable timestamps — confirm against the helper.
    created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
    expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))

    return ManagedEntry(
        value=value,
        created_at=created_at,
        expires_at=expires_at,
    )
@@ -114,11 +140,28 @@ class ElasticsearchStore(
114140

115141
_index_prefix: str
116142

143+
_native_storage: bool
144+
117145
@overload
118-
def __init__(self, *, elasticsearch_client: AsyncElasticsearch, index_prefix: str, default_collection: str | None = None) -> None: ...
146+
def __init__(
147+
self,
148+
*,
149+
elasticsearch_client: AsyncElasticsearch,
150+
index_prefix: str,
151+
native_storage: bool = True,
152+
default_collection: str | None = None,
153+
) -> None: ...
119154

120155
@overload
121-
def __init__(self, *, url: str, api_key: str | None = None, index_prefix: str, default_collection: str | None = None) -> None: ...
156+
def __init__(
157+
self,
158+
*,
159+
url: str,
160+
api_key: str | None = None,
161+
index_prefix: str,
162+
native_storage: bool = True,
163+
default_collection: str | None = None,
164+
) -> None: ...
122165

123166
def __init__(
124167
self,
@@ -127,6 +170,7 @@ def __init__(
127170
url: str | None = None,
128171
api_key: str | None = None,
129172
index_prefix: str,
173+
native_storage: bool = True,
130174
default_collection: str | None = None,
131175
) -> None:
132176
"""Initialize the elasticsearch store.
@@ -136,6 +180,8 @@ def __init__(
136180
url: The url of the elasticsearch cluster.
137181
api_key: The api key to use.
138182
index_prefix: The index prefix to use. Collections will be prefixed with this prefix.
183+
native_storage: Whether to use native storage mode (flattened field type) or serialize
184+
all values to JSON strings. Defaults to True.
139185
default_collection: The default collection to use if no collection is provided.
140186
"""
141187
if elasticsearch_client is None and url is None:
@@ -152,7 +198,12 @@ def __init__(
152198
msg = "Either elasticsearch_client or url must be provided"
153199
raise ValueError(msg)
154200

201+
LessCapableJsonSerializer.install_serializer(client=self._client)
202+
LessCapableJsonSerializer.install_default_serializer(client=self._client)
203+
LessCapableNdjsonSerializer.install_serializer(client=self._client)
204+
155205
self._index_prefix = index_prefix
206+
self._native_storage = native_storage
156207
self._is_serverless = False
157208

158209
super().__init__(default_collection=default_collection)
@@ -205,18 +256,11 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
205256
if not (source := get_source_from_body(body=body)):
206257
return None
207258

208-
if not (value_str := source.get("value")) or not isinstance(value_str, str):
259+
try:
260+
return source_to_managed_entry(source=source)
261+
except DeserializationError:
209262
return None
210263

211-
created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
212-
expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))
213-
214-
return ManagedEntry(
215-
value=load_from_json(value_str),
216-
created_at=created_at,
217-
expires_at=expires_at,
218-
)
219-
220264
@override
221265
async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
222266
if not keys:
@@ -265,15 +309,23 @@ async def _put_managed_entry(
265309
index_name: str = self._sanitize_index_name(collection=collection)
266310
document_id: str = self._sanitize_document_id(key=key)
267311

268-
document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)
269-
270-
_ = await self._client.index(
271-
index=index_name,
272-
id=document_id,
273-
body=document,
274-
refresh=self._should_refresh_on_put,
312+
document: dict[str, Any] = managed_entry_to_document(
313+
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
275314
)
276315

316+
try:
317+
_ = await self._client.index(
318+
index=index_name,
319+
id=document_id,
320+
body=document,
321+
refresh=self._should_refresh_on_put,
322+
)
323+
except ElasticsearchSerializationError as e:
324+
msg = f"Failed to serialize document: {e}"
325+
raise SerializationError(message=msg) from e
326+
except Exception:
327+
raise
328+
277329
@override
278330
async def _put_managed_entries(
279331
self,
@@ -297,11 +349,18 @@ async def _put_managed_entries(
297349

298350
index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
299351

300-
document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)
352+
document: dict[str, Any] = managed_entry_to_document(
353+
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
354+
)
301355

302356
operations.extend([index_action, document])
303-
304-
_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
357+
try:
358+
_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
359+
except ElasticsearchSerializationError as e:
360+
msg = f"Failed to serialize bulk operations: {e}"
361+
raise SerializationError(message=msg) from e
362+
except Exception:
363+
raise
305364

306365
@override
307366
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/utils.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1-
from typing import Any, TypeVar, cast
1+
from typing import Any, ClassVar, TypeVar, cast
22

3-
from elastic_transport import ObjectApiResponse
4-
from key_value.shared.utils.managed_entry import ManagedEntry
3+
from elastic_transport import (
4+
JsonSerializer,
5+
NdjsonSerializer,
6+
ObjectApiResponse,
7+
SerializationError,
8+
)
9+
10+
from elasticsearch import AsyncElasticsearch
511

612

713
def get_body_from_response(response: ObjectApiResponse[Any]) -> dict[str, Any]:
@@ -28,7 +34,10 @@ def get_aggregations_from_body(body: dict[str, Any]) -> dict[str, Any]:
2834
if not (aggregations := body.get("aggregations")):
2935
return {}
3036

31-
if not isinstance(aggregations, dict) or not all(isinstance(key, str) for key in aggregations): # pyright: ignore[reportUnknownVariableType]
37+
if not isinstance(aggregations, dict) or not all(
38+
isinstance(key, str)
39+
for key in aggregations # pyright: ignore[reportUnknownVariableType]
40+
):
3241
return {}
3342

3443
return cast("dict[str, Any]", aggregations)
@@ -108,20 +117,50 @@ def get_first_value_from_field_in_hit(hit: dict[str, Any], field: str, value_typ
108117
return values[0]
109118

110119

111-
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
112-
document: dict[str, Any] = {
113-
"collection": collection,
114-
"key": key,
115-
"value": managed_entry.to_json(include_metadata=False),
116-
}
120+
def new_bulk_action(action: str, index: str, document_id: str) -> dict[str, Any]:
    """Build the action line that precedes a document in a bulk request.

    Returns a single-key dict mapping ``action`` to the target ``_index`` and ``_id``.
    """
    target: dict[str, Any] = {"_index": index, "_id": document_id}
    return {action: target}
122+
117123

118-
if managed_entry.created_at:
119-
document["created_at"] = managed_entry.created_at.isoformat()
120-
if managed_entry.expires_at:
121-
document["expires_at"] = managed_entry.expires_at.isoformat()
124+
class LessCapableJsonSerializer(JsonSerializer):
    """A JSON serializer that doesn't try to be smart with datetime, floats, etc.

    Any value that reaches ``default`` is rejected with a ``SerializationError`` instead of
    being converted.
    """

    mimetype: ClassVar[str] = "application/json"
    compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+json"

    def default(self, data: Any) -> Any:
        """Always raise ``SerializationError`` describing ``data`` and its type.

        NOTE(review): presumably the base serializer calls this hook only for objects the plain
        JSON encoder cannot handle — confirm against elastic_transport.
        """
        message = f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})"
        raise SerializationError(message=message)

    @classmethod
    def install_default_serializer(cls, client: AsyncElasticsearch) -> None:
        """Register this serializer for its mimetypes and make it the client's default serializer."""
        cls.install_serializer(client=client)
        serializer_collection = client.transport.serializers
        serializer_collection.default_serializer = cls()

    @classmethod
    def install_serializer(cls, client: AsyncElasticsearch) -> None:
        """Register a fresh instance of this serializer under both of its mimetypes on ``client``."""
        content_types = (cls.mimetype, cls.compatibility_mimetype)
        # Build every instance first, then register them in a single update call.
        replacements = {content_type: cls() for content_type in content_types}
        client.transport.serializers.serializers.update(replacements)
148+
149+
150+
class LessCapableNdjsonSerializer(NdjsonSerializer):
    """An NDJSON serializer that doesn't try to be smart with datetime, floats, etc.

    Values that reach ``default`` are handed to ``LessCapableJsonSerializer.default``, so both
    serializers reject them the same way.
    """

    mimetype: ClassVar[str] = "application/x-ndjson"
    compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+x-ndjson"

    def default(self, data: Any) -> Any:
        """Delegate to ``LessCapableJsonSerializer.default``, which raises ``SerializationError``."""
        # Called unbound with this instance as `self`; the two classes are not related by
        # inheritance, hence the pyright suppressions.
        return LessCapableJsonSerializer.default(self=self, data=data)  # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportArgumentType]

    @classmethod
    def install_serializer(cls, client: AsyncElasticsearch) -> None:
        """Register a fresh instance of this serializer under both of its mimetypes on ``client``."""
        content_types = (cls.mimetype, cls.compatibility_mimetype)
        # Build every instance first, then register them in a single update call.
        replacements = {content_type: cls() for content_type in content_types}
        client.transport.serializers.serializers.update(replacements)

0 commit comments

Comments
 (0)