diff --git a/config.yml b/config.yml index 35159b3ac..5055eb5a2 100644 --- a/config.yml +++ b/config.yml @@ -82,6 +82,14 @@ #elasticsearch.log_level: INFO # # +## Maximum number of times failed Elasticsearch requests are retried, except bulk requests +#elasticsearch.max_retries: 5 +# +# +## Retry interval between failed Elasticsearch requests, except bulk requests +#elasticsearch.retry_interval: 10 +# +# ## ------------------------------- Elasticsearch: Bulk ------------------------ # ## Options for the Bulk API calls behavior - all options can be @@ -120,6 +128,12 @@ #elasticsearch.bulk.concurrent_downloads: 10 # # +## Maximum number of times failed bulk requests are retried +#elasticsearch.bulk.max_retries: 5 +# +# +## Retry interval between failed bulk attempts +#elasticsearch.bulk.retry_interval: 10 ## ------------------------------- Service ---------------------------------- # ## Connector service/framework related configurations diff --git a/connectors/config.py b/connectors/config.py index e42318e41..2e7373708 100644 --- a/connectors/config.py +++ b/connectors/config.py @@ -10,6 +10,9 @@ from connectors.logger import logger +DEFAULT_ELASTICSEARCH_MAX_RETRIES = 5 +DEFAULT_ELASTICSEARCH_RETRY_INTERVAL = 10 + DEFAULT_MAX_FILE_SIZE = 10485760 # 10MB @@ -21,6 +24,7 @@ def load_config(config_file): _nest_configs(nested_yaml_config, key, value) configuration = dict(_merge_dicts(_default_config(), nested_yaml_config)) _ent_search_config(configuration) + return configuration @@ -60,9 +64,12 @@ def _default_config(): "chunk_size": 1000, "max_concurrency": 5, "chunk_max_mem_size": 5, + "max_retries": DEFAULT_ELASTICSEARCH_MAX_RETRIES, + "retry_interval": DEFAULT_ELASTICSEARCH_RETRY_INTERVAL, "concurrent_downloads": 10, - "max_retries": 3, }, + "max_retries": DEFAULT_ELASTICSEARCH_MAX_RETRIES, + "retry_interval": DEFAULT_ELASTICSEARCH_RETRY_INTERVAL, "retry_on_timeout": True, "request_timeout": 120, "max_wait_duration": 120, diff --git a/connectors/es/client.py b/connectors/es/client.py index 30170d6c0..c6569283d 100644 --- a/connectors/es/client.py +++ b/connectors/es/client.py @@ -9,6 +9,7 @@ import time from enum import Enum +from elastic_transport import ConnectionTimeout from elastic_transport.client_utils import url_to_node_config from elasticsearch import ApiError, AsyncElasticsearch, ConflictError from elasticsearch import ( @@ -16,8 +17,16 @@ ) from connectors import __version__ +from connectors.config import ( + DEFAULT_ELASTICSEARCH_MAX_RETRIES, + DEFAULT_ELASTICSEARCH_RETRY_INTERVAL, +) from connectors.logger import logger, set_extra_logger -from connectors.utils import CancellableSleeps +from connectors.utils import ( + CancellableSleeps, + RetryStrategy, + time_to_sleep_between_retries, +) class License(Enum): @@ -41,6 +50,12 @@ def __init__(self, config): use_default_ports_for_scheme=True, ) self._sleeps = CancellableSleeps() + self._retrier = TransientElasticsearchRetrier( + logger, + config.get("max_retries", DEFAULT_ELASTICSEARCH_MAX_RETRIES), + config.get("retry_interval", DEFAULT_ELASTICSEARCH_RETRY_INTERVAL), + ) + options = { "hosts": [self.host], "request_timeout": config.get("request_timeout", 120), @@ -100,7 +115,9 @@ async def has_active_license_enabled(self, license_): Tuple: (boolean if `license_` is enabled and not expired, actual license Elasticsearch is using) """ - license_response = await self.client.license.get() + license_response = await self._retrier.execute_with_retry( + self.client.license.get + ) license_info = license_response.get("license", {}) 
is_expired = license_info.get("status", "").lower() == "expired" @@ -125,23 +142,9 @@ async def has_active_license_enabled(self, license_): ) async def close(self): + await self._retrier.close() await self.client.close() - async def ping(self): - try: - await self.client.info() - except ApiError as e: - logger.error(f"The server returned a {e.status_code} code") - if e.info is not None and "error" in e.info and "reason" in e.info["error"]: - logger.error(e.info["error"]["reason"]) - return False - except ElasticConnectionError as e: - logger.error("Could not connect to the server") - if e.message is not None: - logger.error(e.message) - return False - return True - async def wait(self): backoff = self.initial_backoff_duration start = time.time() @@ -162,6 +165,81 @@ async def wait(self): await self.close() return False + async def ping(self): + try: + await self.client.info() + except ApiError as e: + logger.error(f"The server returned a {e.status_code} code") + if e.info is not None and "error" in e.info and "reason" in e.info["error"]: + logger.error(e.info["error"]["reason"]) + return False + except ElasticConnectionError as e: + logger.error("Could not connect to the server") + if e.message is not None: + logger.error(e.message) + return False + return True + + +class RetryInterruptedError(Exception): + pass + + +class TransientElasticsearchRetrier: + def __init__( + self, + logger_, + max_retries, + retry_interval, + retry_strategy=RetryStrategy.LINEAR_BACKOFF, + ): + self._logger = logger_ + self._sleeps = CancellableSleeps() + self._keep_retrying = True + self._error_codes_to_retry = [429, 500, 502, 503, 504] + self._max_retries = max_retries + self._retry_interval = retry_interval + self._retry_strategy = retry_strategy + + async def close(self): + self._sleeps.cancel() + self._keep_retrying = False + + async def _sleep(self, retry): + time_to_sleep = time_to_sleep_between_retries( + self._retry_strategy, self._retry_interval, retry + ) + self._logger.debug(f"Attempt {retry}: sleeping for {time_to_sleep}") + await self._sleeps.sleep(time_to_sleep) + + async def execute_with_retry(self, func): + retry = 0 + while self._keep_retrying and retry < self._max_retries: + retry += 1 + try: + result = await func() + + return result + except ConnectionTimeout: + self._logger.debug(f"Attempt {retry}: connection timeout") + + if retry >= self._max_retries: + raise + except ApiError as e: + self._logger.debug( + f"Attempt {retry}: api error with status {e.status_code}" + ) + + if e.status_code not in self._error_codes_to_retry: + raise + if retry >= self._max_retries: + raise + + await self._sleep(retry) + + msg = "Retry operation was interrupted" + raise RetryInterruptedError(msg) + def with_concurrency_control(retries=3): def wrapper(func): diff --git a/connectors/es/management_client.py b/connectors/es/management_client.py index 496ad8709..5c4c91a6b 100644 --- a/connectors/es/management_client.py +++ b/connectors/es/management_client.py @@ -4,6 +4,8 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
# +from functools import partial + from elasticsearch import ( NotFoundError as ElasticNotFoundError, ) @@ -39,22 +41,33 @@ async def ensure_exists(self, indices=None): for index in indices: logger.debug(f"Checking index {index}") - if not await self.client.indices.exists(index=index): - await self.client.indices.create(index=index) + if not await self._retrier.execute_with_retry( + partial(self.client.indices.exists, index=index) + ): + await self._retrier.execute_with_retry( + partial(self.client.indices.create, index=index) + ) logger.debug(f"Created index {index}") async def create_content_index(self, search_index_name, language_code): settings = Settings(language_code=language_code, analysis_icu=False).to_hash() mappings = Mappings.default_text_fields_mappings(is_connectors_index=True) - return await self.client.indices.create( - index=search_index_name, mappings=mappings, settings=settings + return await self._retrier.execute_with_retry( + partial( + self.client.indices.create, + index=search_index_name, + mappings=mappings, + settings=settings, + ) ) async def ensure_content_index_mappings(self, index, mappings): # open = Match open, non-hidden indices. Also matches any non-hidden data stream. # Content indices are always non-hidden. - response = await self.client.indices.get_mapping(index=index) + response = await self._retrier.execute_with_retry( + partial(self.client.indices.get_mapping, index=index) + ) existing_mappings = response[index].get("mappings", {}) if len(existing_mappings) == 0: @@ -62,11 +75,14 @@ async def ensure_content_index_mappings(self, index, mappings): logger.debug( "Index %s has no mappings or it's empty. Adding mappings...", index ) - await self.client.indices.put_mapping( - index=index, - dynamic=mappings.get("dynamic", False), - dynamic_templates=mappings.get("dynamic_templates", []), - properties=mappings.get("properties", {}), + await self._retrier.execute_with_retry( + partial( + self.client.indices.put_mapping, + index=index, + dynamic=mappings.get("dynamic", False), + dynamic_templates=mappings.get("dynamic_templates", []), + properties=mappings.get("properties", {}), + ) ) logger.debug("Successfully added mappings for index %s", index) else: @@ -82,34 +98,62 @@ async def ensure_ingest_pipeline_exists( self, pipeline_id, version, description, processors ): try: - await self.client.ingest.get_pipeline(id=pipeline_id) + await self._retrier.execute_with_retry( + partial(self.client.ingest.get_pipeline, id=pipeline_id) + ) except ElasticNotFoundError: - await self.client.ingest.put_pipeline( - id=pipeline_id, - version=version, - description=description, - processors=processors, + await self._retrier.execute_with_retry( + partial( + self.client.ingest.put_pipeline, + id=pipeline_id, + version=version, + description=description, + processors=processors, + ) ) async def delete_indices(self, indices): - await self.client.indices.delete(index=indices, ignore_unavailable=True) + await self._retrier.execute_with_retry( + partial(self.client.indices.delete, index=indices, ignore_unavailable=True) + ) async def clean_index(self, index_name): - return await self.client.delete_by_query( - index=index_name, body={"query": {"match_all": {}}}, ignore_unavailable=True + return await self._retrier.execute_with_retry( + partial( + self.client.delete_by_query, + index=index_name, + body={"query": {"match_all": {}}}, + ignore_unavailable=True, + ) ) async def list_indices(self): - return await self.client.indices.stats(index="search-*") + return await 
self._retrier.execute_with_retry( + partial(self.client.indices.stats, index="search-*") + ) async def index_exists(self, index_name): - return await self.client.indices.exists(index=index_name) + return await self._retrier.execute_with_retry( + partial(self.client.indices.exists, index=index_name) + ) async def upsert(self, _id, index_name, doc): - await self.client.index( - id=_id, - index=index_name, - document=doc, + return await self._retrier.execute_with_retry( + partial( + self.client.index, + id=_id, + index=index_name, + document=doc, + ) + ) + + async def bulk_insert(self, operations, pipeline): + return await self._retrier.execute_with_retry( + partial( + self.client.bulk, + operations=operations, + pipeline=pipeline, + ) ) async def yield_existing_documents_metadata(self, index): diff --git a/connectors/es/sink.py b/connectors/es/sink.py index ca0b68269..ed55e8823 100644 --- a/connectors/es/sink.py +++ b/connectors/es/sink.py @@ -24,13 +24,16 @@ import time from collections import defaultdict +from connectors.config import ( + DEFAULT_ELASTICSEARCH_MAX_RETRIES, + DEFAULT_ELASTICSEARCH_RETRY_INTERVAL, +) from connectors.es.management_client import ESManagementClient from connectors.es.settings import TIMESTAMP_FIELD, Mappings from connectors.filtering.basic_rule import BasicRuleEngine, parse from connectors.logger import logger, tracer from connectors.protocol import Filter, JobType from connectors.utils import ( - DEFAULT_BULK_MAX_RETRIES, DEFAULT_CHUNK_MEM_SIZE, DEFAULT_CHUNK_SIZE, DEFAULT_CONCURRENT_DOWNLOADS, @@ -96,6 +99,7 @@ def __init__( chunk_mem_size, max_concurrency, max_retries, + retry_interval, logger_=None, ): self.client = client @@ -106,6 +110,7 @@ def __init__( self.chunk_mem_size = chunk_mem_size * 1024 * 1024 self.bulk_tasks = ConcurrentTasks(max_concurrency=max_concurrency) self.max_retires = max_retries + self.retry_interval = retry_interval self.indexed_document_count = 0 self.indexed_document_volume = 0 self.deleted_document_count = 0 @@ -130,7 +135,8 @@ def _bulk_op(self, doc, operation=OP_INDEX): @tracer.start_as_current_span("_bulk API call", slow_log=1.0) async def _batch_bulk(self, operations, stats): - @retryable(retries=self.max_retires) + # TODO: make this retry policy work with unified retry strategy + @retryable(retries=self.max_retires, interval=self.retry_interval) async def _bulk_api_call(): return await self.client.client.bulk( operations=operations, pipeline=self.pipeline["name"] @@ -143,7 +149,9 @@ async def _bulk_api_call(): self._logger.debug( f"Task {task_num} - Sending a batch of {len(operations)} ops -- {get_mb_size(operations)}MiB" ) - res = await _bulk_api_call() + + # TODO: retry 429s for individual items here + res = await self.client.bulk_insert(operations, self.pipeline["name"]) if res.get("errors"): for item in res["items"]: for op, data in item.items(): @@ -777,7 +785,10 @@ async def async_bulk( concurrent_downloads = options.get( "concurrent_downloads", DEFAULT_CONCURRENT_DOWNLOADS ) - max_bulk_retries = options.get("max_retries", DEFAULT_BULK_MAX_RETRIES) + max_bulk_retries = options.get("max_retries", DEFAULT_ELASTICSEARCH_MAX_RETRIES) + retry_interval = options.get( + "retry_interval", DEFAULT_ELASTICSEARCH_RETRY_INTERVAL + ) stream = MemQueue(maxsize=queue_size, maxmemsize=queue_mem_size * 1024 * 1024) @@ -807,6 +818,7 @@ async def async_bulk( chunk_mem_size=chunk_mem_size, max_concurrency=max_concurrency, max_retries=max_bulk_retries, + retry_interval=retry_interval, logger_=self._logger, ) self._sink_task = 
asyncio.create_task(self._sink.run()) diff --git a/connectors/utils.py b/connectors/utils.py index d9fc783fe..c22fbb493 100644 --- a/connectors/utils.py +++ b/connectors/utils.py @@ -35,7 +35,6 @@ DEFAULT_CHUNK_MEM_SIZE = 25 DEFAULT_MAX_CONCURRENCY = 5 DEFAULT_CONCURRENT_DOWNLOADS = 10 -DEFAULT_BULK_MAX_RETRIES = 3 # Regular expression pattern to match a basic email format (no whitespace, valid domain) EMAIL_REGEX_PATTERN = r"^\S+@\S+\.\S+$" diff --git a/tests/es/test_client.py b/tests/es/test_client.py index 806b7093f..2a6ecac50 100644 --- a/tests/es/test_client.py +++ b/tests/es/test_client.py @@ -4,15 +4,18 @@ # you may not use this file except in compliance with the Elastic License 2.0. # import base64 +from functools import cached_property from unittest import mock from unittest.mock import AsyncMock, Mock import pytest -from elasticsearch import ConflictError, ConnectionError +from elasticsearch import ApiError, ConflictError, ConnectionError, ConnectionTimeout from connectors.es.client import ( ESClient, License, + RetryInterruptedError, + TransientElasticsearchRetrier, with_concurrency_control, ) @@ -232,3 +235,113 @@ async def test_es_client_no_server(self): # Execute assert not await es_client.ping() await es_client.close() + + +class TestTransientElasticsearchRetrier: + @cached_property + def logger_mock(self): + return Mock() + + @cached_property + def max_retries(self): + return 5 + + @cached_property + def retry_interval(self): + return 50 + + @pytest.mark.asyncio + async def test_execute_with_retry(self, patch_sleep): + retrier = TransientElasticsearchRetrier( + self.logger_mock, self.max_retries, self.retry_interval + ) + + async def _func(): + pass + + await retrier.execute_with_retry(_func) + + assert patch_sleep.not_called() + + @pytest.mark.asyncio + async def test_execute_with_retry_429_with_recovery(self, patch_sleep): + retrier = TransientElasticsearchRetrier( + self.logger_mock, self.max_retries, self.retry_interval + ) + + # Emulate {nr_failed_requests} failures from Elasticsearch + nr_failed_requests = 2 + + global attempt + attempt = 0 + + async def _func(): + global attempt + + meta_mock = Mock() + meta_mock.status = 429 + + if attempt < nr_failed_requests: + attempt += 1 + raise ApiError(429, meta_mock, "data") + pass + + await retrier.execute_with_retry(_func) + + assert patch_sleep.awaited_exactly(2) + + @pytest.mark.asyncio + async def test_execute_with_retry_429_no_recovery(self, patch_sleep): + retrier = TransientElasticsearchRetrier( + self.logger_mock, self.max_retries, self.retry_interval + ) + + # Emulate failures from Elasticsearch that we cannot recover from + + async def _func(): + meta_mock = Mock() + meta_mock.status = 429 + raise ApiError(429, meta_mock, "data") + + with pytest.raises(ApiError) as e: + await retrier.execute_with_retry(_func) + + assert e is not None + assert patch_sleep.awaited_exactly(self.max_retries) + + @pytest.mark.asyncio + async def test_execute_with_retry_connection_timeout(self, patch_sleep): + retrier = TransientElasticsearchRetrier( + self.logger_mock, self.max_retries, self.retry_interval + ) + + # Emulate failures from Elasticsearch that we cannot recover from + + async def _func(): + msg = ":stop:" + raise ConnectionTimeout(msg) + + with pytest.raises(ConnectionTimeout) as e: + await retrier.execute_with_retry(_func) + + assert e is not None + assert patch_sleep.awaited_exactly(self.max_retries) + + @pytest.mark.asyncio + async def test_execute_with_retry_cancelled_midway(self, patch_sleep): + retrier = 
TransientElasticsearchRetrier( + self.logger_mock, self.max_retries, self.retry_interval + ) + + # Emulate failures from Elasticsearch that we cannot recover from + + async def _func(): + await retrier.close() + msg = ":stop:" + raise ConnectionTimeout(msg) + + with pytest.raises(RetryInterruptedError) as e: + await retrier.execute_with_retry(_func) + + assert e is not None + assert patch_sleep.not_awaited() diff --git a/tests/es/test_management_client.py b/tests/es/test_management_client.py index ce9849979..0b05df3a6 100644 --- a/tests/es/test_management_client.py +++ b/tests/es/test_management_client.py @@ -5,7 +5,7 @@ # from datetime import datetime from unittest import mock -from unittest.mock import ANY, AsyncMock +from unittest.mock import ANY, AsyncMock, Mock import pytest import pytest_asyncio @@ -120,9 +120,11 @@ async def test_ensure_ingest_pipeline_exists_when_pipeline_do_not_exist( description = "that's a pipeline" processors = ["something"] - es_management_client.client.ingest.get_pipeline.side_effect = ( - ElasticNotFoundError("1", "2", "3") - ) + error_meta = Mock() + error_meta.status = 404 + error = ElasticNotFoundError("1", error_meta, "3") + + es_management_client.client.ingest.get_pipeline.side_effect = error await es_management_client.ensure_ingest_pipeline_exists( pipeline_id, version, description, processors diff --git a/tests/test_sink.py b/tests/test_sink.py index d60426232..16b422272 100644 --- a/tests/test_sink.py +++ b/tests/test_sink.py @@ -11,7 +11,7 @@ from unittest.mock import ANY, AsyncMock, Mock, call import pytest -from elasticsearch import BadRequestError +from elasticsearch import ApiError, BadRequestError from connectors.es import Mappings from connectors.es.management_client import ESManagementClient @@ -1051,6 +1051,7 @@ def test_bulk_populate_stats(res, expected_result): chunk_mem_size=0, max_concurrency=0, max_retries=3, + retry_interval=10, ) sink._populate_stats(deepcopy(STATS), res) @@ -1076,11 +1077,18 @@ async def test_batch_bulk_with_retry(): chunk_mem_size=0, max_concurrency=0, max_retries=3, + retry_interval=10, ) with mock.patch.object(asyncio, "sleep"): # first call raises exception, and the second call succeeds - client.client.bulk = AsyncMock(side_effect=[Exception(), {"items": []}]) + error_meta = Mock() + error_meta.status = 429 + first_call_error = ApiError(429, meta=error_meta, body="error") + second_call_result = {"items": []} + client.client.bulk = AsyncMock( + side_effect=[first_call_error, second_call_result] + ) await sink._batch_bulk([], {OP_INDEX: {}, OP_UPSERT: {}, OP_DELETE: {}}) assert client.client.bulk.await_count == 2 @@ -1183,6 +1191,7 @@ async def test_sink_fetch_doc(): chunk_mem_size=0, max_concurrency=0, max_retries=3, + retry_interval=10, ) doc = await sink.fetch_doc() @@ -1203,6 +1212,7 @@ async def test_force_canceled_sink_fetch_doc(): chunk_mem_size=0, max_concurrency=0, max_retries=3, + retry_interval=10, ) sink.force_cancel() @@ -1223,6 +1233,7 @@ async def test_force_canceled_sink_with_other_errors(patch_logger): chunk_mem_size=0, max_concurrency=0, max_retries=3, + retry_interval=10, ) sink.force_cancel()
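
For reference, a minimal usage sketch of the new TransientElasticsearchRetrier introduced in connectors/es/client.py (this is not part of the diff; the AsyncElasticsearch instance and the index name are placeholders). execute_with_retry awaits a zero-argument callable, so calls that need arguments are bound with functools.partial, mirroring the ESManagementClient changes above:

from functools import partial

from elasticsearch import AsyncElasticsearch

from connectors.es.client import TransientElasticsearchRetrier
from connectors.logger import logger


async def example():
    # Placeholder client; in the framework this is the AsyncElasticsearch
    # instance owned by ESClient / ESManagementClient.
    client = AsyncElasticsearch(hosts=["http://localhost:9200"])

    # Values mirror the new config.yml defaults:
    # elasticsearch.max_retries and elasticsearch.retry_interval.
    retrier = TransientElasticsearchRetrier(logger, max_retries=5, retry_interval=10)
    try:
        # Retries ConnectionTimeout and transient ApiError statuses
        # (429, 500, 502, 503, 504) with linear backoff between attempts;
        # any other error is re-raised immediately.
        return await retrier.execute_with_retry(
            partial(client.indices.exists, index="search-example")
        )
    finally:
        await retrier.close()
        await client.close()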