From 122ac0b1d3629b3f30b291c05b83f514b8699400 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Tue, 25 Jun 2024 15:49:00 -0700 Subject: [PATCH] feat(ingest): add async batch mode to the rest sink (#10733) --- .../src/datahub/ingestion/graph/client.py | 10 +- .../datahub/ingestion/sink/datahub_rest.py | 82 +++- .../ingestion/source/looker/looker_source.py | 2 +- .../utilities/advanced_thread_executor.py | 231 ---------- .../utilities/backpressure_aware_executor.py | 78 ++++ .../datahub/utilities/partition_executor.py | 404 ++++++++++++++++++ .../test_advanced_thread_executor.py | 128 ------ .../test_backpressure_aware_executor.py | 59 +++ .../unit/utilities/test_partition_executor.py | 150 +++++++ 9 files changed, 766 insertions(+), 378 deletions(-) delete mode 100644 metadata-ingestion/src/datahub/utilities/advanced_thread_executor.py create mode 100644 metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py create mode 100644 metadata-ingestion/src/datahub/utilities/partition_executor.py delete mode 100644 metadata-ingestion/tests/unit/utilities/test_advanced_thread_executor.py create mode 100644 metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py create mode 100644 metadata-ingestion/tests/unit/utilities/test_partition_executor.py diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index 795f5c1e4cd9b..252846326b49e 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -214,7 +214,7 @@ def _post_generic(self, url: str, payload_dict: Dict) -> Dict: def _make_rest_sink_config(self) -> "DatahubRestSinkConfig": from datahub.ingestion.sink.datahub_rest import ( DatahubRestSinkConfig, - SyncOrAsync, + RestSinkMode, ) # This is a bit convoluted - this DataHubGraph class is a subclass of DatahubRestEmitter, @@ -222,7 +222,7 @@ def _make_rest_sink_config(self) -> "DatahubRestSinkConfig": # TODO: We should refactor out the multithreading functionality of the sink # into a separate class that can be used by both the sink and the graph client # e.g. a DatahubBulkRestEmitter that both the sink and the graph client use. - return DatahubRestSinkConfig(**self.config.dict(), mode=SyncOrAsync.ASYNC) + return DatahubRestSinkConfig(**self.config.dict(), mode=RestSinkMode.ASYNC) @contextlib.contextmanager def make_rest_sink( @@ -253,14 +253,10 @@ def emit_all( ) -> None: """Emit all items in the iterable using multiple threads.""" + # The context manager also ensures that we raise an error if a failure occurs. with self.make_rest_sink(run_id=run_id) as sink: for item in items: sink.emit_async(item) - if sink.report.failures: - raise OperationalError( - f"Failed to emit {len(sink.report.failures)} records", - info=sink.report.as_obj(), - ) def get_aspect( self, diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py index dab8e99b797fe..33a8f4a126182 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py @@ -7,7 +7,7 @@ import threading import uuid from enum import auto -from typing import Optional, Union +from typing import List, Optional, Tuple, Union from datahub.cli.cli_utils import set_env_variables_override_config from datahub.configuration.common import ( @@ -16,6 +16,7 @@ OperationalError, ) from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.emitter.mcp_builder import mcps_from_mce from datahub.emitter.rest_emitter import DataHubRestEmitter from datahub.ingestion.api.common import RecordEnvelope, WorkUnit from datahub.ingestion.api.sink import ( @@ -30,7 +31,10 @@ MetadataChangeEvent, MetadataChangeProposal, ) -from datahub.utilities.advanced_thread_executor import PartitionExecutor +from datahub.utilities.partition_executor import ( + BatchPartitionExecutor, + PartitionExecutor, +) from datahub.utilities.perf_timer import PerfTimer from datahub.utilities.server_config_util import set_gms_config @@ -41,18 +45,26 @@ ) -class SyncOrAsync(ConfigEnum): +class RestSinkMode(ConfigEnum): SYNC = auto() ASYNC = auto() + # Uses the new ingestProposalBatch endpoint. Significantly more efficient than the other modes, + # but requires a server version that supports it. + # https://github.com/datahub-project/datahub/pull/10706 + ASYNC_BATCH = auto() + class DatahubRestSinkConfig(DatahubClientConfig): - mode: SyncOrAsync = SyncOrAsync.ASYNC + mode: RestSinkMode = RestSinkMode.ASYNC - # These only apply in async mode. + # These only apply in async modes. max_threads: int = DEFAULT_REST_SINK_MAX_THREADS max_pending_requests: int = 2000 + # Only applies in async batch mode. + max_per_batch: int = 100 + @dataclasses.dataclass class DataHubRestSinkReport(SinkReport): @@ -111,10 +123,20 @@ def __post_init__(self) -> None: set_env_variables_override_config(self.config.server, self.config.token) logger.debug("Setting gms config") set_gms_config(gms_config) - self.executor = PartitionExecutor( - max_workers=self.config.max_threads, - max_pending=self.config.max_pending_requests, - ) + + self.executor: Union[PartitionExecutor, BatchPartitionExecutor] + if self.config.mode == RestSinkMode.ASYNC_BATCH: + self.executor = BatchPartitionExecutor( + max_workers=self.config.max_threads, + max_pending=self.config.max_pending_requests, + process_batch=self._emit_batch_wrapper, + max_per_batch=self.config.max_per_batch, + ) + else: + self.executor = PartitionExecutor( + max_workers=self.config.max_threads, + max_pending=self.config.max_pending_requests, + ) @classmethod def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter: @@ -189,6 +211,7 @@ def _write_done_callback( self.report.report_warning({"warning": e.message, "info": e.info}) write_callback.on_failure(record_envelope, e, e.info) else: + logger.exception(f"Failure: {e}", exc_info=e) self.report.report_failure({"e": e}) write_callback.on_failure(record_envelope, Exception(e), {}) @@ -203,6 +226,30 @@ def _emit_wrapper( # TODO: Add timing metrics self.emitter.emit(record) + def _emit_batch_wrapper( + self, + records: List[ + Tuple[ + Union[ + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, + ], + ] + ], + ) -> None: + events: List[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]] = [] + for record in records: + event = record[0] + if isinstance(event, MetadataChangeEvent): + # Unpack MCEs into MCPs. + mcps = mcps_from_mce(event) + events.extend(mcps) + else: + events.append(event) + + self.emitter.emit_mcps(events) + def write_record_async( self, record_envelope: RecordEnvelope[ @@ -218,7 +265,8 @@ def write_record_async( # should only have a high value if the sink is actually a bottleneck. with self.report.main_thread_blocking_timer: record = record_envelope.record - if self.config.mode == SyncOrAsync.ASYNC: + if self.config.mode == RestSinkMode.ASYNC: + assert isinstance(self.executor, PartitionExecutor) partition_key = _get_partition_key(record_envelope) self.executor.submit( partition_key, @@ -229,6 +277,17 @@ def write_record_async( ), ) self.report.pending_requests += 1 + elif self.config.mode == RestSinkMode.ASYNC_BATCH: + assert isinstance(self.executor, BatchPartitionExecutor) + partition_key = _get_partition_key(record_envelope) + self.executor.submit( + partition_key, + record, + done_callback=functools.partial( + self._write_done_callback, record_envelope, write_callback + ), + ) + self.report.pending_requests += 1 else: # execute synchronously try: @@ -249,7 +308,8 @@ def emit_async( ) def close(self): - self.executor.shutdown() + with self.report.main_thread_blocking_timer: + self.executor.shutdown() def __repr__(self) -> str: return self.emitter.__repr__() diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py index c87ee1d77f5cd..2277f512def2f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_source.py @@ -102,7 +102,7 @@ OwnershipTypeClass, SubTypesClass, ) -from datahub.utilities.advanced_thread_executor import BackpressureAwareExecutor +from datahub.utilities.backpressure_aware_executor import BackpressureAwareExecutor logger = logging.getLogger(__name__) diff --git a/metadata-ingestion/src/datahub/utilities/advanced_thread_executor.py b/metadata-ingestion/src/datahub/utilities/advanced_thread_executor.py deleted file mode 100644 index 1e241a9a49e97..0000000000000 --- a/metadata-ingestion/src/datahub/utilities/advanced_thread_executor.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import annotations - -import collections -import concurrent.futures -import logging -import time -from concurrent.futures import Future, ThreadPoolExecutor -from threading import BoundedSemaphore -from typing import ( - Any, - Callable, - Deque, - Dict, - Iterable, - Iterator, - Optional, - Set, - Tuple, - TypeVar, -) - -from datahub.ingestion.api.closeable import Closeable - -logger = logging.getLogger(__name__) -_R = TypeVar("_R") -_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL = 0.05 - - -class PartitionExecutor(Closeable): - def __init__(self, max_workers: int, max_pending: int) -> None: - """A thread pool executor with partitioning and a pending request bound. - - It works similarly to a ThreadPoolExecutor, with the following changes: - - At most one request per partition key will be executing at a time. - - If the number of pending requests exceeds the threshold, the submit() call - will block until the number of pending requests drops below the threshold. - - Due to the interaction between max_workers and max_pending, it is possible - for execution to effectively be serialized when there's a large influx of - requests with the same key. This can be mitigated by setting a reasonably - large max_pending value. - - Args: - max_workers: The maximum number of threads to use for executing requests. - max_pending: The maximum number of pending (e.g. non-executing) requests to allow. - """ - self.max_workers = max_workers - self.max_pending = max_pending - - self._executor = ThreadPoolExecutor(max_workers=max_workers) - - # Each pending or executing request will acquire a permit from this semaphore. - self._semaphore = BoundedSemaphore(max_pending + max_workers) - - # A key existing in this dict means that there is a submitted request for that key. - # Any entries in the key's value e.g. the deque are requests that are waiting - # to be submitted once the current request for that key completes. - self._pending_by_key: Dict[ - str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]] - ] = {} - - def submit( - self, - key: str, - fn: Callable[..., _R], - *args: Any, - # Ideally, we would've used ParamSpec to annotate this method. However, - # due to the limitations of PEP 612, we can't add a keyword argument here. - # See https://peps.python.org/pep-0612/#concatenating-keyword-parameters - # As such, we're using Any here, and won't validate the args to this method. - # We might be able to work around it by moving the done_callback arg to be before - # the *args, but that would mean making done_callback a required arg instead of - # optional as it is now. - done_callback: Optional[Callable[[Future], None]] = None, - **kwargs: Any, - ) -> None: - """See concurrent.futures.Executor#submit""" - - self._semaphore.acquire() - - if key in self._pending_by_key: - self._pending_by_key[key].append((fn, args, kwargs, done_callback)) - - else: - self._pending_by_key[key] = collections.deque() - self._submit_nowait(key, fn, args, kwargs, done_callback=done_callback) - - def _submit_nowait( - self, - key: str, - fn: Callable[..., _R], - args: tuple, - kwargs: dict, - done_callback: Optional[Callable[[Future], None]], - ) -> Future: - future = self._executor.submit(fn, *args, **kwargs) - - def _system_done_callback(future: Future) -> None: - self._semaphore.release() - - # If there is another pending request for this key, submit it now. - # The key must exist in the map. - if self._pending_by_key[key]: - fn, args, kwargs, user_done_callback = self._pending_by_key[ - key - ].popleft() - - try: - self._submit_nowait(key, fn, args, kwargs, user_done_callback) - except RuntimeError as e: - if self._executor._shutdown: - # If we're in shutdown mode, then we can't submit any more requests. - # That means we'll need to drop requests on the floor, which is to - # be expected in shutdown mode. - # The only reason we'd normally be in shutdown here is during - # Python exit (e.g. KeyboardInterrupt), so this is reasonable. - logger.debug("Dropping request due to shutdown") - else: - raise e - - else: - # If there are no pending requests for this key, mark the key - # as no longer in progress. - del self._pending_by_key[key] - - if done_callback: - future.add_done_callback(done_callback) - future.add_done_callback(_system_done_callback) - return future - - def flush(self) -> None: - """Wait for all pending requests to complete.""" - - # Acquire all the semaphore permits so that no more requests can be submitted. - for _i in range(self.max_pending): - self._semaphore.acquire() - - # Now, wait for all the pending requests to complete. - while len(self._pending_by_key) > 0: - # TODO: There should be a better way to wait for all executor threads to be idle. - # One option would be to just shutdown the existing executor and create a new one. - time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL) - - # Now allow new requests to be submitted. - # TODO: With Python 3.9, release() can take a count argument. - for _i in range(self.max_pending): - self._semaphore.release() - - def shutdown(self) -> None: - """See concurrent.futures.Executor#shutdown. Behaves as if wait=True.""" - - self.flush() - assert len(self._pending_by_key) == 0 - - # Technically, the wait=True here is redundant, since all the threads should - # be idle now. - self._executor.shutdown(wait=True) - - def close(self) -> None: - self.shutdown() - - -class BackpressureAwareExecutor: - # This couldn't be a real executor because the semantics of submit wouldn't really make sense. - # In this variant, if we blocked on submit, then we would also be blocking the thread that - # we expect to be consuming the results. As such, I made it accept the full list of args - # up front, and that way the consumer can read results at its own pace. - - @classmethod - def map( - cls, - fn: Callable[..., _R], - args_list: Iterable[Tuple[Any, ...]], - max_workers: int, - max_pending: Optional[int] = None, - ) -> Iterator[Future[_R]]: - """Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer. - - The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result - objects in memory if the consumer is slow. Instead, the consumer can read the results - at its own pace and the executor threads will idle if they need to. - - Args: - fn: The function to apply to each input. - args_list: The list of inputs. In contrast to the builtin map, this is a list - of tuples, where each tuple is the arguments to fn. - max_workers: The maximum number of threads to use. - max_pending: The maximum number of pending results to keep in memory. - If not set, it will be set to 2*max_workers. - - Returns: - An iterable of futures. - - This differs from a traditional map because it returns futures - instead of the actual results, so that the caller is required - to handle exceptions. - - Additionally, it does not maintain the order of the arguments. - If you want to know which result corresponds to which input, - the mapped function should return some form of an identifier. - """ - - if max_pending is None: - max_pending = 2 * max_workers - assert max_pending >= max_workers - - pending_futures: Set[Future] = set() - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - for args in args_list: - # If the pending list is full, wait until one is done. - if len(pending_futures) >= max_pending: - (done, _) = concurrent.futures.wait( - pending_futures, return_when=concurrent.futures.FIRST_COMPLETED - ) - for future in done: - pending_futures.remove(future) - - # We don't want to call result() here because we want the caller - # to handle exceptions/cancellation. - yield future - - # Now that there's space in the pending list, enqueue the next task. - pending_futures.add(executor.submit(fn, *args)) - - # Wait for all the remaining tasks to complete. - for future in concurrent.futures.as_completed(pending_futures): - pending_futures.remove(future) - yield future - - assert not pending_futures diff --git a/metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py b/metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py new file mode 100644 index 0000000000000..988bd91c4a642 --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/backpressure_aware_executor.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import concurrent.futures +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Any, Callable, Iterable, Iterator, Optional, Set, Tuple, TypeVar + +_R = TypeVar("_R") + + +class BackpressureAwareExecutor: + # This couldn't be a real executor because the semantics of submit wouldn't really make sense. + # In this variant, if we blocked on submit, then we would also be blocking the thread that + # we expect to be consuming the results. As such, I made it accept the full list of args + # up front, and that way the consumer can read results at its own pace. + + @classmethod + def map( + cls, + fn: Callable[..., _R], + args_list: Iterable[Tuple[Any, ...]], + max_workers: int, + max_pending: Optional[int] = None, + ) -> Iterator[Future[_R]]: + """Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer. + + The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result + objects in memory if the consumer is slow. Instead, the consumer can read the results + at its own pace and the executor threads will idle if they need to. + + Args: + fn: The function to apply to each input. + args_list: The list of inputs. In contrast to the builtin map, this is a list + of tuples, where each tuple is the arguments to fn. + max_workers: The maximum number of threads to use. + max_pending: The maximum number of pending results to keep in memory. + If not set, it will be set to 2*max_workers. + + Returns: + An iterable of futures. + + This differs from a traditional map because it returns futures + instead of the actual results, so that the caller is required + to handle exceptions. + + Additionally, it does not maintain the order of the arguments. + If you want to know which result corresponds to which input, + the mapped function should return some form of an identifier. + """ + + if max_pending is None: + max_pending = 2 * max_workers + assert max_pending >= max_workers + + pending_futures: Set[Future] = set() + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for args in args_list: + # If the pending list is full, wait until one is done. + if len(pending_futures) >= max_pending: + (done, _) = concurrent.futures.wait( + pending_futures, return_when=concurrent.futures.FIRST_COMPLETED + ) + for future in done: + pending_futures.remove(future) + + # We don't want to call result() here because we want the caller + # to handle exceptions/cancellation. + yield future + + # Now that there's space in the pending list, enqueue the next task. + pending_futures.add(executor.submit(fn, *args)) + + # Wait for all the remaining tasks to complete. + for future in concurrent.futures.as_completed(pending_futures): + pending_futures.remove(future) + yield future + + assert not pending_futures diff --git a/metadata-ingestion/src/datahub/utilities/partition_executor.py b/metadata-ingestion/src/datahub/utilities/partition_executor.py new file mode 100644 index 0000000000000..05e81da47285d --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/partition_executor.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +import collections +import functools +import logging +import queue +import threading +import time +from concurrent.futures import Future, ThreadPoolExecutor +from datetime import datetime, timedelta, timezone +from threading import BoundedSemaphore +from typing import ( + Any, + Callable, + Deque, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + TypeVar, +) + +from datahub.ingestion.api.closeable import Closeable + +logger = logging.getLogger(__name__) +_R = TypeVar("_R") +_Args = TypeVar("_Args", bound=tuple) +_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL = 0.05 +_DEFAULT_BATCHER_MIN_PROCESS_INTERVAL = timedelta(seconds=30) + + +class PartitionExecutor(Closeable): + def __init__(self, max_workers: int, max_pending: int) -> None: + """A thread pool executor with partitioning and a pending request bound. + + It works similarly to a ThreadPoolExecutor, with the following changes: + - At most one request per partition key will be executing at a time. + - If the number of pending requests exceeds the threshold, the submit() call + will block until the number of pending requests drops below the threshold. + + Due to the interaction between max_workers and max_pending, it is possible + for execution to effectively be serialized when there's a large influx of + requests with the same key. This can be mitigated by setting a reasonably + large max_pending value. + + Args: + max_workers: The maximum number of threads to use for executing requests. + max_pending: The maximum number of pending (e.g. non-executing) requests to allow. + """ + self.max_workers = max_workers + self.max_pending = max_pending + + self._executor = ThreadPoolExecutor(max_workers=max_workers) + + # Each pending or executing request will acquire a permit from this semaphore. + self._semaphore = BoundedSemaphore(max_pending + max_workers) + + # A key existing in this dict means that there is a submitted request for that key. + # Any entries in the key's value e.g. the deque are requests that are waiting + # to be submitted once the current request for that key completes. + self._pending_by_key: Dict[ + str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]] + ] = {} + + def submit( + self, + key: str, + fn: Callable[..., _R], + *args: Any, + # Ideally, we would've used ParamSpec to annotate this method. However, + # due to the limitations of PEP 612, we can't add a keyword argument here. + # See https://peps.python.org/pep-0612/#concatenating-keyword-parameters + # As such, we're using Any here, and won't validate the args to this method. + # We might be able to work around it by moving the done_callback arg to be before + # the *args, but that would mean making done_callback a required arg instead of + # optional as it is now. + done_callback: Optional[Callable[[Future], None]] = None, + **kwargs: Any, + ) -> None: + """See concurrent.futures.Executor#submit""" + + self._semaphore.acquire() + + if key in self._pending_by_key: + self._pending_by_key[key].append((fn, args, kwargs, done_callback)) + + else: + self._pending_by_key[key] = collections.deque() + self._submit_nowait(key, fn, args, kwargs, done_callback=done_callback) + + def _submit_nowait( + self, + key: str, + fn: Callable[..., _R], + args: tuple, + kwargs: dict, + done_callback: Optional[Callable[[Future], None]], + ) -> Future: + future = self._executor.submit(fn, *args, **kwargs) + + def _system_done_callback(future: Future) -> None: + self._semaphore.release() + + # If there is another pending request for this key, submit it now. + # The key must exist in the map. + if self._pending_by_key[key]: + fn, args, kwargs, user_done_callback = self._pending_by_key[ + key + ].popleft() + + try: + self._submit_nowait(key, fn, args, kwargs, user_done_callback) + except RuntimeError as e: + if self._executor._shutdown: + # If we're in shutdown mode, then we can't submit any more requests. + # That means we'll need to drop requests on the floor, which is to + # be expected in shutdown mode. + # The only reason we'd normally be in shutdown here is during + # Python exit (e.g. KeyboardInterrupt), so this is reasonable. + logger.debug("Dropping request due to shutdown") + else: + raise e + + else: + # If there are no pending requests for this key, mark the key + # as no longer in progress. + del self._pending_by_key[key] + + if done_callback: + future.add_done_callback(done_callback) + future.add_done_callback(_system_done_callback) + return future + + def flush(self) -> None: + """Wait for all pending requests to complete.""" + + # Acquire all the semaphore permits so that no more requests can be submitted. + for _i in range(self.max_pending): + self._semaphore.acquire() + + # Now, wait for all the pending requests to complete. + while len(self._pending_by_key) > 0: + # TODO: There should be a better way to wait for all executor threads to be idle. + # One option would be to just shutdown the existing executor and create a new one. + time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL) + + # Now allow new requests to be submitted. + # TODO: With Python 3.9, release() can take a count argument. + for _i in range(self.max_pending): + self._semaphore.release() + + def shutdown(self) -> None: + """See concurrent.futures.Executor#shutdown. Behaves as if wait=True.""" + + self.flush() + assert len(self._pending_by_key) == 0 + + self._executor.shutdown(wait=True) + + def close(self) -> None: + self.shutdown() + + +class _BatchPartitionWorkItem(NamedTuple): + key: str + args: tuple + done_callback: Optional[Callable[[Future], None]] + + +def _now() -> datetime: + return datetime.now(tz=timezone.utc) + + +class BatchPartitionExecutor(Closeable): + def __init__( + self, + max_workers: int, + max_pending: int, + # Due to limitations of Python's typing, we can't express the type of the list + # effectively. Ideally we'd use ParamSpec here, but that's not allowed in a + # class context like this. + process_batch: Callable[[List], None], + max_per_batch: int = 100, + min_process_interval: timedelta = _DEFAULT_BATCHER_MIN_PROCESS_INTERVAL, + ) -> None: + """Similar to PartitionExecutor, but with batching. + + This takes in the stream of requests, automatically segments them into partition-aware + batches, and schedules them across a pool of worker threads. + + It maintains the invariant that multiple requests with the same key will not be in + flight concurrently, except when part of the same batch. Requests for a given key + will also be executed in the order they were submitted. + + Unlike the PartitionExecutor, this does not support return values or kwargs. + + Args: + max_workers: The maximum number of threads to use for executing requests. + max_pending: The maximum number of pending (e.g. non-executing) requests to allow. + max_per_batch: The maximum number of requests to include in a batch. + min_process_interval: When requests are coming in slowly, we will wait at least this long + before submitting a non-full batch. + process_batch: A function that takes in a list of argument tuples. + """ + self.max_workers = max_workers + self.max_pending = max_pending + self.max_per_batch = max_per_batch + self.process_batch = process_batch + self.min_process_interval = min_process_interval + assert self.max_workers > 1 + + # We add one here to account for the clearinghouse worker thread. + self._executor = ThreadPoolExecutor(max_workers=max_workers + 1) + self._clearinghouse_started = False + + self._pending_count = BoundedSemaphore(max_pending) + self._pending: "queue.Queue[Optional[_BatchPartitionWorkItem]]" = queue.Queue( + maxsize=max_pending + ) + + # If this is true, that means shutdown() has been called and self._pending is empty. + self._queue_empty_for_shutdown = False + + def _clearinghouse_worker(self) -> None: # noqa: C901 + # This worker will pull items off the queue, and submit them into the executor + # in batches. Only this worker will submit process commands to the executor thread pool. + + # The lock protects the function's internal state. + clearinghouse_state_lock = threading.Lock() + workers_available = self.max_workers + keys_in_flight: Set[str] = set() + keys_no_longer_in_flight: Set[str] = set() + pending_key_completion: List[_BatchPartitionWorkItem] = [] + + last_submit_time = _now() + + def _handle_batch_completion( + batch: List[_BatchPartitionWorkItem], future: Future + ) -> None: + with clearinghouse_state_lock: + for item in batch: + keys_no_longer_in_flight.add(item.key) + self._pending_count.release() + + # Separate from the above loop to avoid holding the lock while calling the callbacks. + for item in batch: + if item.done_callback: + item.done_callback(future) + + def _find_ready_items() -> List[_BatchPartitionWorkItem]: + with clearinghouse_state_lock: + # First, update the keys in flight. + for key in keys_no_longer_in_flight: + keys_in_flight.remove(key) + keys_no_longer_in_flight.clear() + + # Then, update the pending key completion and build the ready list. + pending = pending_key_completion.copy() + pending_key_completion.clear() + + ready: List[_BatchPartitionWorkItem] = [] + for item in pending: + if ( + len(ready) < self.max_per_batch + and item.key not in keys_in_flight + ): + ready.append(item) + else: + pending_key_completion.append(item) + + return ready + + def _build_batch() -> List[_BatchPartitionWorkItem]: + next_batch = _find_ready_items() + + while ( + not self._queue_empty_for_shutdown + and len(next_batch) < self.max_per_batch + ): + blocking = True + if ( + next_batch + and _now() - last_submit_time > self.min_process_interval + and workers_available > 0 + ): + # If we're past the submit deadline, pull from the queue + # in a non-blocking way, and submit the batch once the queue + # is empty. + blocking = False + + try: + next_item: Optional[_BatchPartitionWorkItem] = self._pending.get( + block=blocking, + timeout=self.min_process_interval.total_seconds(), + ) + if next_item is None: + self._queue_empty_for_shutdown = True + break + + with clearinghouse_state_lock: + if next_item.key in keys_in_flight: + pending_key_completion.append(next_item) + else: + next_batch.append(next_item) + except queue.Empty: + if not blocking: + break + + return next_batch + + def _submit_batch(next_batch: List[_BatchPartitionWorkItem]) -> None: + with clearinghouse_state_lock: + for item in next_batch: + keys_in_flight.add(item.key) + + nonlocal workers_available + workers_available -= 1 + + nonlocal last_submit_time + last_submit_time = _now() + + future = self._executor.submit( + self.process_batch, [item.args for item in next_batch] + ) + future.add_done_callback( + functools.partial(_handle_batch_completion, next_batch) + ) + + try: + # Normal operation - submit batches as they become available. + while not self._queue_empty_for_shutdown: + next_batch = _build_batch() + if next_batch: + _submit_batch(next_batch) + + # Shutdown time. + # Invariant - at this point, we know self._pending is empty. + # We just need to wait for the in-flight items to complete, + # and submit any currently pending items once possible. + while pending_key_completion: + next_batch = _build_batch() + if next_batch: + _submit_batch(next_batch) + time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL) + + # At this point, there are no more things to submit. + # We could wait for the in-flight items to complete, + # but the executor will take care of waiting for them to complete. + except Exception as e: + # This represents a fatal error that makes the entire executor defunct. + logger.exception( + "Threaded executor's clearinghouse worker failed.", exc_info=e + ) + finally: + self._clearinghouse_started = False + + def _ensure_clearinghouse_started(self) -> None: + # Lazily start the clearinghouse worker. + if not self._clearinghouse_started: + self._clearinghouse_started = True + self._executor.submit(self._clearinghouse_worker) + + def submit( + self, + key: str, + *args: Any, + done_callback: Optional[Callable[[Future], None]] = None, + ) -> None: + """See concurrent.futures.Executor#submit""" + + self._ensure_clearinghouse_started() + + self._pending_count.acquire() + self._pending.put(_BatchPartitionWorkItem(key, args, done_callback)) + + def shutdown(self) -> None: + if not self._clearinghouse_started: + # This is required to make shutdown() idempotent, which is important + # when it's called explicitly and then also by a context manager. + logger.debug("Shutting down: clearinghouse not started") + return + + logger.debug(f"Shutting down {self.__class__.__name__}") + + # Send the shutdown signal. + self._pending.put(None) + + # By acquiring all the permits, we ensure that no more tasks will be scheduled + # and automatically wait until all existing tasks have completed. + for _ in range(self.max_pending): + self._pending_count.acquire() + + # We must wait for the clearinghouse worker to exit before calling shutdown + # on the thread pool. Without this, the clearinghouse worker might fail to + # enqueue pending tasks into the pool. + while self._clearinghouse_started: + time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL) + + self._executor.shutdown(wait=False) + + def close(self) -> None: + self.shutdown() diff --git a/metadata-ingestion/tests/unit/utilities/test_advanced_thread_executor.py b/metadata-ingestion/tests/unit/utilities/test_advanced_thread_executor.py deleted file mode 100644 index 7b51c18a85c5f..0000000000000 --- a/metadata-ingestion/tests/unit/utilities/test_advanced_thread_executor.py +++ /dev/null @@ -1,128 +0,0 @@ -import time -from concurrent.futures import Future - -from datahub.utilities.advanced_thread_executor import ( - BackpressureAwareExecutor, - PartitionExecutor, -) -from datahub.utilities.perf_timer import PerfTimer - - -def test_partitioned_executor(): - executing_tasks = set() - done_tasks = set() - - def task(key: str, id: str) -> None: - executing_tasks.add((key, id)) - time.sleep(0.8) - done_tasks.add(id) - executing_tasks.remove((key, id)) - - with PartitionExecutor(max_workers=2, max_pending=10) as executor: - # Submit tasks with the same key. They should be executed sequentially. - executor.submit("key1", task, "key1", "task1") - executor.submit("key1", task, "key1", "task2") - executor.submit("key1", task, "key1", "task3") - - # Submit a task with a different key. It should be executed in parallel. - executor.submit("key2", task, "key2", "task4") - - saw_keys_in_parallel = False - while executing_tasks or not done_tasks: - keys_executing = [key for key, _ in executing_tasks] - assert list(sorted(keys_executing)) == list( - sorted(set(keys_executing)) - ), "partitioning not working" - - if len(keys_executing) == 2: - saw_keys_in_parallel = True - - time.sleep(0.1) - - executor.flush() - assert saw_keys_in_parallel - assert not executing_tasks - assert done_tasks == {"task1", "task2", "task3", "task4"} - - -def test_partitioned_executor_bounding(): - task_duration = 0.5 - done_tasks = set() - - def on_done(future: Future) -> None: - done_tasks.add(future.result()) - - def task(id: str) -> str: - time.sleep(task_duration) - return id - - with PartitionExecutor( - max_workers=5, max_pending=10 - ) as executor, PerfTimer() as timer: - # The first 15 submits should be non-blocking. - for i in range(15): - executor.submit(f"key{i}", task, f"task{i}", done_callback=on_done) - assert timer.elapsed_seconds() < task_duration - - # This submit should block. - executor.submit("key-blocking", task, "task-blocking", done_callback=on_done) - assert timer.elapsed_seconds() > task_duration - - # Wait for everything to finish. - executor.flush() - assert len(done_tasks) == 16 - - -def test_backpressure_aware_executor_simple(): - def task(i): - return i - - assert { - res.result() - for res in BackpressureAwareExecutor.map( - task, ((i,) for i in range(10)), max_workers=2 - ) - } == set(range(10)) - - -def test_backpressure_aware_executor_advanced(): - task_duration = 0.5 - started = set() - executed = set() - - def task(x, y): - assert x + 1 == y - started.add(x) - time.sleep(task_duration) - executed.add(x) - return x - - args_list = [(i, i + 1) for i in range(10)] - - with PerfTimer() as timer: - results = BackpressureAwareExecutor.map( - task, args_list, max_workers=2, max_pending=4 - ) - assert timer.elapsed_seconds() < task_duration - - # No tasks should have completed yet. - assert len(executed) == 0 - - # Consume the first result. - first_result = next(results) - assert 0 <= first_result.result() < 4 - assert timer.elapsed_seconds() > task_duration - - # By now, the first four tasks should have started. - time.sleep(task_duration) - assert {0, 1, 2, 3}.issubset(started) - assert 2 <= len(executed) <= 4 - - # Finally, consume the rest of the results. - assert {r.result() for r in results} == { - i for i in range(10) if i != first_result.result() - } - - # Validate that the entire process took about 5-10x the task duration. - # That's because we have 2 workers and 10 tasks. - assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration diff --git a/metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py b/metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py new file mode 100644 index 0000000000000..5b320b8a23254 --- /dev/null +++ b/metadata-ingestion/tests/unit/utilities/test_backpressure_aware_executor.py @@ -0,0 +1,59 @@ +import time + +from datahub.utilities.backpressure_aware_executor import BackpressureAwareExecutor +from datahub.utilities.perf_timer import PerfTimer + + +def test_backpressure_aware_executor_simple(): + def task(i): + return i + + assert { + res.result() + for res in BackpressureAwareExecutor.map( + task, ((i,) for i in range(10)), max_workers=2 + ) + } == set(range(10)) + + +def test_backpressure_aware_executor_advanced(): + task_duration = 0.5 + started = set() + executed = set() + + def task(x, y): + assert x + 1 == y + started.add(x) + time.sleep(task_duration) + executed.add(x) + return x + + args_list = [(i, i + 1) for i in range(10)] + + with PerfTimer() as timer: + results = BackpressureAwareExecutor.map( + task, args_list, max_workers=2, max_pending=4 + ) + assert timer.elapsed_seconds() < task_duration + + # No tasks should have completed yet. + assert len(executed) == 0 + + # Consume the first result. + first_result = next(results) + assert 0 <= first_result.result() < 4 + assert timer.elapsed_seconds() > task_duration + + # By now, the first four tasks should have started. + time.sleep(task_duration) + assert {0, 1, 2, 3}.issubset(started) + assert 2 <= len(executed) <= 4 + + # Finally, consume the rest of the results. + assert {r.result() for r in results} == { + i for i in range(10) if i != first_result.result() + } + + # Validate that the entire process took about 5-10x the task duration. + # That's because we have 2 workers and 10 tasks. + assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration diff --git a/metadata-ingestion/tests/unit/utilities/test_partition_executor.py b/metadata-ingestion/tests/unit/utilities/test_partition_executor.py new file mode 100644 index 0000000000000..81c5b898caf2b --- /dev/null +++ b/metadata-ingestion/tests/unit/utilities/test_partition_executor.py @@ -0,0 +1,150 @@ +import logging +import time +from concurrent.futures import Future + +from datahub.utilities.partition_executor import ( + BatchPartitionExecutor, + PartitionExecutor, +) +from datahub.utilities.perf_timer import PerfTimer + +logger = logging.getLogger(__name__) + + +def test_partitioned_executor(): + executing_tasks = set() + done_tasks = set() + + def task(key: str, id: str) -> None: + executing_tasks.add((key, id)) + time.sleep(0.8) + done_tasks.add(id) + executing_tasks.remove((key, id)) + + with PartitionExecutor(max_workers=2, max_pending=10) as executor: + # Submit tasks with the same key. They should be executed sequentially. + executor.submit("key1", task, "key1", "task1") + executor.submit("key1", task, "key1", "task2") + executor.submit("key1", task, "key1", "task3") + + # Submit a task with a different key. It should be executed in parallel. + executor.submit("key2", task, "key2", "task4") + + saw_keys_in_parallel = False + while executing_tasks or not done_tasks: + keys_executing = [key for key, _ in executing_tasks] + assert list(sorted(keys_executing)) == list( + sorted(set(keys_executing)) + ), "partitioning not working" + + if len(keys_executing) == 2: + saw_keys_in_parallel = True + + time.sleep(0.1) + + executor.flush() + assert saw_keys_in_parallel + assert not executing_tasks + assert done_tasks == {"task1", "task2", "task3", "task4"} + + +def test_partitioned_executor_bounding(): + task_duration = 0.5 + done_tasks = set() + + def on_done(future: Future) -> None: + done_tasks.add(future.result()) + + def task(id: str) -> str: + time.sleep(task_duration) + return id + + with PartitionExecutor( + max_workers=5, max_pending=10 + ) as executor, PerfTimer() as timer: + # The first 15 submits should be non-blocking. + for i in range(15): + executor.submit(f"key{i}", task, f"task{i}", done_callback=on_done) + assert timer.elapsed_seconds() < task_duration + + # This submit should block. + executor.submit("key-blocking", task, "task-blocking", done_callback=on_done) + assert timer.elapsed_seconds() > task_duration + + # Wait for everything to finish. + executor.flush() + assert len(done_tasks) == 16 + + +def test_batch_partition_executor_sequential_key_execution(): + executing_tasks = set() + done_tasks = set() + done_task_batches = set() + + def process_batch(batch): + for key, id in batch: + assert (key, id) not in executing_tasks, "Task is already executing" + executing_tasks.add((key, id)) + + time.sleep(0.5) # Simulate work + + for key, id in batch: + executing_tasks.remove((key, id)) + done_tasks.add(id) + + done_task_batches.add(tuple(id for _, id in batch)) + + with BatchPartitionExecutor( + max_workers=2, + max_pending=10, + max_per_batch=2, + process_batch=process_batch, + ) as executor: + # Submit tasks with the same key. The first two should get batched together. + executor.submit("key1", "key1", "task1") + executor.submit("key1", "key1", "task2") + executor.submit("key1", "key1", "task3") + + # Submit tasks with a different key. These should get their own batch. + executor.submit("key2", "key2", "task4") + executor.submit("key2", "key2", "task5") + + # Test idempotency of shutdown(). + executor.shutdown() + + # Check if all tasks were executed and completed. + assert done_tasks == { + "task1", + "task2", + "task3", + "task4", + "task5", + }, "Not all tasks completed" + + # Check the batching configuration. + assert done_task_batches == { + ("task1", "task2"), + ("task4", "task5"), + ("task3",), + } + + +def test_batch_partition_executor_max_batch_size(): + batches_processed = [] + + def process_batch(batch): + batches_processed.append(batch) + time.sleep(0.1) # Simulate batch processing time + + with BatchPartitionExecutor( + max_workers=5, max_pending=20, process_batch=process_batch, max_per_batch=2 + ) as executor: + # Submit more tasks than the max_per_batch to test batching limits. + for i in range(5): + executor.submit("key3", "key3", f"task{i}") + + # Check the batches. + logger.info(f"batches_processed: {batches_processed}") + assert len(batches_processed) == 3 + for batch in batches_processed: + assert len(batch) <= 2, "Batch size exceeded max_per_batch limit"