From ef3da6eda2147387cbda07908e90b058fbdf5d4e Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Thu, 8 Aug 2024 01:05:48 +0200 Subject: [PATCH] Add optional Arrow deserialization support (#2632) --- docs/guide/configuration.asciidoc | 2 +- elasticsearch/serializer.py | 34 +++++++++++++++++++ noxfile.py | 2 +- pyproject.toml | 2 ++ .../test_client/test_deprecated_options.py | 2 ++ .../test_client/test_serializers.py | 3 ++ test_elasticsearch/test_serializer.py | 27 ++++++++++++++- 7 files changed, 69 insertions(+), 3 deletions(-) diff --git a/docs/guide/configuration.asciidoc b/docs/guide/configuration.asciidoc index dc934ba61..15c3f413c 100644 --- a/docs/guide/configuration.asciidoc +++ b/docs/guide/configuration.asciidoc @@ -359,7 +359,7 @@ The calculation is equal to `min(dead_node_backoff_factor * (2 ** (consecutive_f [[serializer]] === Serializers -Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, and `application/mapbox-vector-tile`. +Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, `application/vnd.apache.arrow.stream` and `application/mapbox-vector-tile`. You can define custom serializers via the `serializers` parameter: diff --git a/elasticsearch/serializer.py b/elasticsearch/serializer.py index 37ad5724c..c281f1348 100644 --- a/elasticsearch/serializer.py +++ b/elasticsearch/serializer.py @@ -49,6 +49,14 @@ _OrjsonSerializer = None # type: ignore[assignment,misc] +try: + import pyarrow as pa + + __all__.append("PyArrowSerializer") +except ImportError: + pa = None + + class JsonSerializer(_JsonSerializer): mimetype: ClassVar[str] = "application/json" @@ -114,6 +122,29 @@ def dumps(self, data: bytes) -> bytes: raise SerializationError(f"Cannot serialize {data!r} into a MapBox vector tile") +if pa is not None: + + class PyArrowSerializer(Serializer): + """PyArrow serializer for deserializing Arrow Stream data.""" + + mimetype: ClassVar[str] = "application/vnd.apache.arrow.stream" + + def loads(self, data: bytes) -> pa.Table: + try: + with pa.ipc.open_stream(data) as reader: + return reader.read_all() + except pa.ArrowException as e: + raise SerializationError( + message=f"Unable to deserialize as Arrow stream: {data!r}", + errors=(e,), + ) + + def dumps(self, data: Any) -> bytes: + raise SerializationError( + message="Elasticsearch does not accept Arrow input data" + ) + + DEFAULT_SERIALIZERS: Dict[str, Serializer] = { JsonSerializer.mimetype: JsonSerializer(), MapboxVectorTileSerializer.mimetype: MapboxVectorTileSerializer(), @@ -122,6 +153,9 @@ def dumps(self, data: bytes) -> bytes: CompatibilityModeNdjsonSerializer.mimetype: CompatibilityModeNdjsonSerializer(), } +if pa is not None: + DEFAULT_SERIALIZERS[PyArrowSerializer.mimetype] = PyArrowSerializer() + # Alias for backwards compatibility JSONSerializer = JsonSerializer diff --git a/noxfile.py b/noxfile.py index 69e53417f..600120bb3 100644 --- a/noxfile.py +++ b/noxfile.py @@ -94,7 +94,7 @@ def lint(session): session.run("flake8", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) - session.install(".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV) + session.install(".[async,requests,orjson,pyarrow,vectorstore_mmr]", env=INSTALL_ENV) # Run mypy on the package and then the type examples separately for # the two 
different mypy use-cases, ourselves and our users. diff --git a/pyproject.toml b/pyproject.toml index 2a35c51f0..ba83d6329 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ async = ["aiohttp>=3,<4"] requests = ["requests>=2.4.0, !=2.32.2, <3.0.0"] orjson = ["orjson>=3"] +pyarrow = ["pyarrow>=1"] # Maximal Marginal Relevance (MMR) for search results vectorstore_mmr = ["numpy>=1", "simsimd>=3"] dev = [ @@ -69,6 +70,7 @@ dev = [ "orjson", "numpy", "simsimd", + "pyarrow", "pandas", "mapbox-vector-tile", ] diff --git a/test_elasticsearch/test_client/test_deprecated_options.py b/test_elasticsearch/test_client/test_deprecated_options.py index dd1016bb9..810e75cf4 100644 --- a/test_elasticsearch/test_client/test_deprecated_options.py +++ b/test_elasticsearch/test_client/test_deprecated_options.py @@ -134,6 +134,7 @@ class CustomSerializer(JsonSerializer): "application/x-ndjson", "application/json", "text/*", + "application/vnd.apache.arrow.stream", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", } @@ -154,6 +155,7 @@ class CustomSerializer(JsonSerializer): "application/x-ndjson", "application/json", "text/*", + "application/vnd.apache.arrow.stream", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", "application/cbor", diff --git a/test_elasticsearch/test_client/test_serializers.py b/test_elasticsearch/test_client/test_serializers.py index fa1ea362c..9d13386ed 100644 --- a/test_elasticsearch/test_client/test_serializers.py +++ b/test_elasticsearch/test_client/test_serializers.py @@ -94,6 +94,7 @@ class CustomSerializer: "application/json", "text/*", "application/x-ndjson", + "application/vnd.apache.arrow.stream", "application/vnd.mapbox-vector-tile", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", @@ -121,6 +122,7 @@ class CustomSerializer: "application/json", "text/*", "application/x-ndjson", + "application/vnd.apache.arrow.stream", "application/vnd.mapbox-vector-tile", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", @@ -140,6 +142,7 @@ class CustomSerializer: "application/json", "text/*", "application/x-ndjson", + "application/vnd.apache.arrow.stream", "application/vnd.mapbox-vector-tile", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", diff --git a/test_elasticsearch/test_serializer.py b/test_elasticsearch/test_serializer.py index ba5f1adec..02723e8f4 100644 --- a/test_elasticsearch/test_serializer.py +++ b/test_elasticsearch/test_serializer.py @@ -19,6 +19,7 @@ from datetime import datetime from decimal import Decimal +import pyarrow as pa import pytest try: @@ -31,7 +32,12 @@ from elasticsearch import Elasticsearch from elasticsearch.exceptions import SerializationError -from elasticsearch.serializer import JSONSerializer, OrjsonSerializer, TextSerializer +from elasticsearch.serializer import ( + JSONSerializer, + OrjsonSerializer, + PyArrowSerializer, + TextSerializer, +) requires_numpy_and_pandas = pytest.mark.skipif( np is None or pd is None, reason="Test requires numpy and pandas to be available" @@ -157,6 +163,25 @@ def test_serializes_pandas_category(json_serializer): assert b'{"d":[1,2,3]}' == json_serializer.dumps({"d": cat}) +def test_pyarrow_loads(): + data = [ + pa.array([1, 2, 3, 4]), + pa.array(["foo", "bar", "baz", None]), + pa.array([True, None, False, True]), + ] + batch = pa.record_batch(data, names=["f0", "f1", "f2"]) + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, 
batch.schema) as writer: + writer.write_batch(batch) + + serializer = PyArrowSerializer() + assert serializer.loads(sink.getvalue()).to_pydict() == { + "f0": [1, 2, 3, 4], + "f1": ["foo", "bar", "baz", None], + "f2": [True, None, False, True], + } + + def test_json_raises_serialization_error_on_dump_error(json_serializer): with pytest.raises(SerializationError): json_serializer.dumps(object())
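
Note for readers of this patch: the snippet below is an illustrative sketch, not part of the diff. It exercises the new PyArrowSerializer on its own, decoding an in-memory Arrow IPC stream into a pyarrow.Table (the same round trip the test above performs) and confirming that serializing Arrow data back out is rejected. The column names and values are made up for the example.

    # Illustrative sketch (not part of the patch): PyArrowSerializer round trip.
    # Requires the optional "pyarrow" extra introduced above.
    import pyarrow as pa

    from elasticsearch.exceptions import SerializationError
    from elasticsearch.serializer import PyArrowSerializer

    # Build a small Arrow IPC stream in memory; the columns are arbitrary.
    table = pa.table({"id": [1, 2, 3], "name": ["foo", "bar", None]})
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)

    serializer = PyArrowSerializer()

    # loads() decodes the stream into a pyarrow.Table ...
    decoded = serializer.loads(sink.getvalue())
    assert decoded.to_pydict() == table.to_pydict()

    # ... which can feed straight into pandas, if pandas is installed:
    # df = decoded.to_pandas()

    # dumps() is deliberately unsupported: Elasticsearch does not accept Arrow input.
    try:
        serializer.dumps(table)
    except SerializationError:
        pass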
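
Because the serializer is added to DEFAULT_SERIALIZERS whenever pyarrow is importable, responses carrying the application/vnd.apache.arrow.stream content type are handled without any client configuration. A minimal check, again only a sketch:

    # Sketch: the Arrow serializer is registered automatically when the
    # optional pyarrow dependency is installed; without it, neither the
    # class nor the mimetype entry exists.
    from elasticsearch.serializer import DEFAULT_SERIALIZERS, PyArrowSerializer

    assert isinstance(
        DEFAULT_SERIALIZERS[PyArrowSerializer.mimetype], PyArrowSerializer
    )

As with the other content types, the `serializers` parameter described in the configuration docs above can still be used to supply a custom implementation for this mimetype if needed.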