From b57cd31fc49063dcb569f6047ad6d544bd31daa0 Mon Sep 17 00:00:00 2001 From: Sebastian Goodman <164915775+seagrine@users.noreply.github.com> Date: Tue, 12 Nov 2024 03:44:42 -0800 Subject: [PATCH] Make BulkIndexError and ScanError serializable (#2669) Co-authored-by: Quentin Pradet (cherry picked from commit 08addf2255fa0678f58a51e8ee41b9f827dc42eb) --- elasticsearch/helpers/errors.py | 16 ++++++++++++---- test_elasticsearch/test_helpers.py | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/elasticsearch/helpers/errors.py b/elasticsearch/helpers/errors.py index 359fe87b1..4814ca581 100644 --- a/elasticsearch/helpers/errors.py +++ b/elasticsearch/helpers/errors.py @@ -15,18 +15,26 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Type class BulkIndexError(Exception): - def __init__(self, message: Any, errors: List[Dict[str, Any]]): + def __init__(self, message: str, errors: List[Dict[str, Any]]): super().__init__(message) self.errors: List[Dict[str, Any]] = errors + def __reduce__( + self, + ) -> Tuple[Type["BulkIndexError"], Tuple[str, List[Dict[str, Any]]]]: + return (self.__class__, (self.args[0], self.errors)) + class ScanError(Exception): scroll_id: str - def __init__(self, scroll_id: str, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, scroll_id: str, *args: Any) -> None: + super().__init__(*args) self.scroll_id = scroll_id + + def __reduce__(self) -> Tuple[Type["ScanError"], Tuple[str, str]]: + return (self.__class__, (self.scroll_id,) + self.args) diff --git a/test_elasticsearch/test_helpers.py b/test_elasticsearch/test_helpers.py index c9284afc5..e30635f44 100644 --- a/test_elasticsearch/test_helpers.py +++ b/test_elasticsearch/test_helpers.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import pickle import threading import time from unittest import mock @@ -182,3 +183,19 @@ class TestExpandActions: @pytest.mark.parametrize("action", ["whatever", b"whatever"]) def test_string_actions_are_marked_as_simple_inserts(self, action): assert ({"index": {}}, b"whatever") == helpers.expand_action(action) + + +def test_serialize_bulk_index_error(): + error = helpers.BulkIndexError("message", [{"error": 1}]) + pickled = pickle.loads(pickle.dumps(error)) + assert pickled.__class__ == helpers.BulkIndexError + assert pickled.errors == error.errors + assert pickled.args == error.args + + +def test_serialize_scan_error(): + error = helpers.ScanError("scroll_id", "shard_message") + pickled = pickle.loads(pickle.dumps(error)) + assert pickled.__class__ == helpers.ScanError + assert pickled.scroll_id == error.scroll_id + assert pickled.args == error.args