From 0c2ee8e7804bdccd4e2da397dd673ad600a7748d Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Mon, 12 Dec 2022 22:54:04 -0600 Subject: [PATCH 1/6] Add `WorkerThreadPool` for running synchronous work in threads --- .../_internal/concurrency/event_loop.py | 33 +++- .../_internal/concurrency/primitives.py | 7 +- src/prefect/_internal/concurrency/workers.py | 161 ++++++++++++++++++ tests/_internal/concurrency/test_workers.py | 73 ++++++++ 4 files changed, 267 insertions(+), 7 deletions(-) create mode 100644 src/prefect/_internal/concurrency/workers.py create mode 100644 tests/_internal/concurrency/test_workers.py diff --git a/src/prefect/_internal/concurrency/event_loop.py b/src/prefect/_internal/concurrency/event_loop.py index 3a41bf59d25d..21763e05b585 100644 --- a/src/prefect/_internal/concurrency/event_loop.py +++ b/src/prefect/_internal/concurrency/event_loop.py @@ -25,7 +25,7 @@ def get_running_loop() -> Optional[asyncio.BaseEventLoop]: return None -def run_in_loop_thread( +def call_in_loop( __loop: asyncio.AbstractEventLoop, __fn: Callable[P, T], *args: P.args, @@ -34,6 +34,16 @@ def run_in_loop_thread( """ Run a synchronous call in event loop's thread from another thread. """ + future = call_soon_in_loop(__loop, __fn, *args, **kwargs) + return future.result() + + +def call_soon_in_loop( + __loop: asyncio.AbstractEventLoop, + __fn: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs +) -> concurrent.futures.Future: future = concurrent.futures.Future() @functools.wraps(__fn) @@ -46,4 +56,23 @@ def wrapper() -> None: raise __loop.call_soon_threadsafe(wrapper) - return future.result() + return future + + +def call_soon( + __fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> concurrent.futures.Future: + future = concurrent.futures.Future() + __loop = asyncio.get_running_loop() + + @functools.wraps(__fn) + def wrapper() -> None: + try: + future.set_result(__fn(*args, **kwargs)) + except BaseException as exc: + future.set_exception(exc) + if not isinstance(exc, Exception): + raise + + __loop.call_soon(wrapper) + return future diff --git a/src/prefect/_internal/concurrency/primitives.py b/src/prefect/_internal/concurrency/primitives.py index 84cd63fc1aab..ac470486adb2 100644 --- a/src/prefect/_internal/concurrency/primitives.py +++ b/src/prefect/_internal/concurrency/primitives.py @@ -5,10 +5,7 @@ import concurrent.futures from typing import Generic, Optional, TypeVar -from prefect._internal.concurrency.event_loop import ( - get_running_loop, - run_in_loop_thread, -) +from prefect._internal.concurrency.event_loop import call_in_loop, get_running_loop T = TypeVar("T") @@ -33,7 +30,7 @@ def set(self) -> None: self._is_set = True if self._loop: if self._loop != get_running_loop(): - run_in_loop_thread(self._loop, self._event.set) + call_in_loop(self._loop, self._event.set) else: self._event.set() diff --git a/src/prefect/_internal/concurrency/workers.py b/src/prefect/_internal/concurrency/workers.py new file mode 100644 index 000000000000..fd9a9e5e7946 --- /dev/null +++ b/src/prefect/_internal/concurrency/workers.py @@ -0,0 +1,161 @@ +import asyncio +import contextvars +import dataclasses +import threading +import weakref +from queue import Queue +from typing import Callable, Dict, Optional, Set, Tuple, TypeVar, Union + +import anyio.abc +from typing_extensions import ParamSpec + +from prefect._internal.concurrency.primitives import Future + +T = TypeVar("T") +P = ParamSpec("P") + + +@dataclasses.dataclass +class _WorkItem: + """ + A representation of work sent to a worker thread. + """ + + future: Future + fn: Callable + args: Tuple + kwargs: Dict + context: contextvars.Context + + def run(self): + if not self.future.set_running_or_notify_cancel(): + return + try: + result = self.context.run(self.fn, *self.args, **self.kwargs) + except BaseException as exc: + self.future.set_exception(exc) + # Prevent reference cycle in `exc` + self = None + else: + self.future.set_result(result) + + +class _WorkerThread(threading.Thread): + def __init__( + self, + queue: Queue[Union[_WorkItem, None]], + idle: threading.Semaphore, + name: str = None, + ): + super().__init__(name=name) + self._queue = queue + self._idle = idle + + def run(self) -> None: + while True: + work_item = self._queue.get() + if work_item is None: + # Shutdown command received; forward to other workers and exit + self._queue.put_nowait(None) + return + + self._idle.release() + work_item.run() + + del work_item + + +class WorkerThreadPool: + def __init__(self, max_workers: int = 40) -> None: + self._queue: Queue[Union[_WorkItem, None]] = Queue() + self._workers: Set[_WorkerThread] = set() + self._max_workers = max_workers + self._idle = threading.Semaphore(0) + self._lock = asyncio.Lock() + self._shutdown = False + + # On garbage collection of the pool, signal shutdown to workers + weakref.finalize(self, self._queue.put_nowait, None) + + async def submit( + self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs + ) -> Future[T]: + """ + Submit a function to run in a worker thread. + + Returns a future which can be used to retrieve the result of the function. + """ + async with self._lock: + if self._shutdown: + raise RuntimeError("Work cannot be submitted to pool after shutdown.") + + future = Future() + + work_item = _WorkItem( + future=future, + fn=fn, + args=args, + kwargs=kwargs, + context=contextvars.copy_context(), + ) + + # Place the new work item on the work queue + self._queue.put_nowait(work_item) + + # Ensure there are workers available to run the work + self._adjust_worker_count() + + return future + + async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None): + """ + Shutdown the pool, waiting for all workers to complete before returning. + + If work is submitted before shutdown, they will run to completion. + After shutdown, new work may not be submitted. + + When called with `TaskGroup.start(...)`, the task will be reported as started + after signalling shutdown to workers. + """ + async with self._lock: + self._queue.put_nowait(None) + + if task_status: + task_status.started() + + # Avoid blocking the event loop while waiting for threads to join by + # joining in another thread; we use a new instance of ourself to avoid + # reimplementing threaded work. + pool = WorkerThreadPool(max_workers=1) + futures = [await pool.submit(worker.join) for worker in self._workers] + await asyncio.gather(*[future.aresult() for future in futures]) + + self._workers.clear() + self._shutdown = True + + def _adjust_worker_count(self): + """ + If no workers are idle and the maximum worker count is not reached, add a new + worker. + + Note on cleanup of workers: + Workers are only removed on shutdown. Workers could be shutdown after a + period of idle. However, we expect usage in Prefect to generally be + incurred in a workflow that will not have idle workers once they are + created. As long as the maximum number of workers remains relatively small, + the overhead of idle workers should be negligable. + """ + if ( + not self._idle.acquire(blocking=False) + and len(self._workers) < self._max_workers + ): + self._add_worker() + + def _add_worker(self): + worker = _WorkerThread( + queue=self._queue, + idle=self._idle, + name=f"PrefectWorker-{len(self._workers)}", + ) + self._workers.add(worker) + worker.start() diff --git a/tests/_internal/concurrency/test_workers.py b/tests/_internal/concurrency/test_workers.py new file mode 100644 index 000000000000..a426fc0a4b62 --- /dev/null +++ b/tests/_internal/concurrency/test_workers.py @@ -0,0 +1,73 @@ +import asyncio +import time + +import anyio +import pytest + +from prefect._internal.concurrency.workers import WorkerThreadPool + + +def identity(x): + return x + + +async def test_submit(): + pool = WorkerThreadPool() + future = await pool.submit(identity, 1) + assert await future.aresult() == 1 + + +async def test_submit_many(): + pool = WorkerThreadPool() + futures = [await pool.submit(identity, i) for i in range(100)] + results = await asyncio.gather(*[future.aresult() for future in futures]) + assert results == list(range(100)) + assert len(pool._workers) == pool._max_workers + + +async def test_submit_after_shutdown(): + pool = WorkerThreadPool() + await pool.shutdown() + + with pytest.raises( + RuntimeError, match="Work cannot be submitted to pool after shutdown" + ): + await pool.submit(identity, 1) + + +async def test_submit_during_shutdown(): + pool = WorkerThreadPool() + + async with anyio.create_task_group() as tg: + await tg.start(pool.shutdown) + + with pytest.raises( + RuntimeError, match="Work cannot be submitted to pool after shutdown" + ): + await pool.submit(identity, 1) + + +async def test_shutdown_no_workers(): + pool = WorkerThreadPool() + await pool.shutdown() + + +async def test_shutdown_multiple_times(): + pool = WorkerThreadPool() + await pool.submit(identity, 1) + await pool.shutdown() + await pool.shutdown() + + +async def test_shutdown_with_idle_workers(): + pool = WorkerThreadPool() + futures = [await pool.submit(identity, 1) for _ in range(5)] + await asyncio.gather(*[future.aresult() for future in futures]) + await pool.shutdown() + + +async def test_shutdown_with_active_worker(): + pool = WorkerThreadPool() + future = await pool.submit(time.sleep, 1) + await pool.shutdown() + assert await future.aresult() is None From 74b23a8696eae0d6bf1a10fff896389fee207505 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Tue, 13 Dec 2022 18:37:49 -0600 Subject: [PATCH 2/6] Add a context manager for consistent cleanup of pool --- src/prefect/_internal/concurrency/workers.py | 8 ++- tests/_internal/concurrency/test_workers.py | 62 +++++++++++++++----- 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/src/prefect/_internal/concurrency/workers.py b/src/prefect/_internal/concurrency/workers.py index fd9a9e5e7946..18ceb9a28a78 100644 --- a/src/prefect/_internal/concurrency/workers.py +++ b/src/prefect/_internal/concurrency/workers.py @@ -118,6 +118,7 @@ async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None): after signalling shutdown to workers. """ async with self._lock: + self._shutdown = True self._queue.put_nowait(None) if task_status: @@ -131,7 +132,6 @@ async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None): await asyncio.gather(*[future.aresult() for future in futures]) self._workers.clear() - self._shutdown = True def _adjust_worker_count(self): """ @@ -159,3 +159,9 @@ def _add_worker(self): ) self._workers.add(worker) worker.start() + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + await self.shutdown() diff --git a/tests/_internal/concurrency/test_workers.py b/tests/_internal/concurrency/test_workers.py index a426fc0a4b62..350cec7c85a6 100644 --- a/tests/_internal/concurrency/test_workers.py +++ b/tests/_internal/concurrency/test_workers.py @@ -12,17 +12,26 @@ def identity(x): async def test_submit(): - pool = WorkerThreadPool() - future = await pool.submit(identity, 1) - assert await future.aresult() == 1 + async with WorkerThreadPool() as pool: + future = await pool.submit(identity, 1) + assert await future.aresult() == 1 async def test_submit_many(): - pool = WorkerThreadPool() - futures = [await pool.submit(identity, i) for i in range(100)] - results = await asyncio.gather(*[future.aresult() for future in futures]) - assert results == list(range(100)) - assert len(pool._workers) == pool._max_workers + async with WorkerThreadPool() as pool: + futures = [await pool.submit(identity, i) for i in range(100)] + results = await asyncio.gather(*[future.aresult() for future in futures]) + assert results == list(range(100)) + assert len(pool._workers) == pool._max_workers + + +async def test_submit_reuses_idle_thread(): + async with WorkerThreadPool() as pool: + future = await pool.submit(identity, 1) + await future.aresult() + future = await pool.submit(identity, 1) + await future.aresult() + assert len(pool._workers) == 1 async def test_submit_after_shutdown(): @@ -36,15 +45,15 @@ async def test_submit_after_shutdown(): async def test_submit_during_shutdown(): - pool = WorkerThreadPool() + async with WorkerThreadPool() as pool: - async with anyio.create_task_group() as tg: - await tg.start(pool.shutdown) + async with anyio.create_task_group() as tg: + await tg.start(pool.shutdown) - with pytest.raises( - RuntimeError, match="Work cannot be submitted to pool after shutdown" - ): - await pool.submit(identity, 1) + with pytest.raises( + RuntimeError, match="Work cannot be submitted to pool after shutdown" + ): + await pool.submit(identity, 1) async def test_shutdown_no_workers(): @@ -71,3 +80,26 @@ async def test_shutdown_with_active_worker(): future = await pool.submit(time.sleep, 1) await pool.shutdown() assert await future.aresult() is None + + +async def test_shutdown_exception_during_join(): + pool = WorkerThreadPool() + future = await pool.submit(identity, 1) + await future.aresult() + + try: + async with anyio.create_task_group() as tg: + await tg.start(pool.shutdown) + raise ValueError() + except ValueError: + pass + + assert pool._shutdown is True + + +async def test_context_manager_with_outstanding_future(): + async with WorkerThreadPool() as pool: + future = await pool.submit(identity, 1) + + assert pool._shutdown is True + assert await future.aresult() == 1 From 081ffc8cbfbb28c480a21db829fcab378def4855 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Tue, 13 Dec 2022 18:38:27 -0600 Subject: [PATCH 3/6] Release idle slots _after_ work is complete --- src/prefect/_internal/concurrency/workers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prefect/_internal/concurrency/workers.py b/src/prefect/_internal/concurrency/workers.py index 18ceb9a28a78..1c8cb847342f 100644 --- a/src/prefect/_internal/concurrency/workers.py +++ b/src/prefect/_internal/concurrency/workers.py @@ -59,8 +59,8 @@ def run(self) -> None: self._queue.put_nowait(None) return - self._idle.release() work_item.run() + self._idle.release() del work_item From 17939f07020db8b3930001f68d12ce0dea9c6d49 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Wed, 14 Dec 2022 09:43:58 -0600 Subject: [PATCH 4/6] Fix idle test --- src/prefect/_internal/concurrency/workers.py | 5 +---- tests/_internal/concurrency/test_workers.py | 6 ++++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/prefect/_internal/concurrency/workers.py b/src/prefect/_internal/concurrency/workers.py index 1c8cb847342f..6ef811396299 100644 --- a/src/prefect/_internal/concurrency/workers.py +++ b/src/prefect/_internal/concurrency/workers.py @@ -145,10 +145,7 @@ def _adjust_worker_count(self): created. As long as the maximum number of workers remains relatively small, the overhead of idle workers should be negligable. """ - if ( - not self._idle.acquire(blocking=False) - and len(self._workers) < self._max_workers - ): + if not self._idle.acquire(timeout=0) and len(self._workers) < self._max_workers: self._add_worker() def _add_worker(self): diff --git a/tests/_internal/concurrency/test_workers.py b/tests/_internal/concurrency/test_workers.py index 350cec7c85a6..b550e4774be8 100644 --- a/tests/_internal/concurrency/test_workers.py +++ b/tests/_internal/concurrency/test_workers.py @@ -29,6 +29,12 @@ async def test_submit_reuses_idle_thread(): async with WorkerThreadPool() as pool: future = await pool.submit(identity, 1) await future.aresult() + + # Spin until the worker is marked as idle + with anyio.fail_after(1): + while pool._idle._value == 0: + await anyio.sleep(0) + future = await pool.submit(identity, 1) await future.aresult() assert len(pool._workers) == 1 From 3994db8bb2d0a75618df4772b795049f7baa6203 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Wed, 14 Dec 2022 09:47:25 -0600 Subject: [PATCH 5/6] Update annotations to avoid breaking on Python < 3.9 --- src/prefect/_internal/concurrency/workers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/prefect/_internal/concurrency/workers.py b/src/prefect/_internal/concurrency/workers.py index 6ef811396299..fb5248ea7d14 100644 --- a/src/prefect/_internal/concurrency/workers.py +++ b/src/prefect/_internal/concurrency/workers.py @@ -43,7 +43,7 @@ def run(self): class _WorkerThread(threading.Thread): def __init__( self, - queue: Queue[Union[_WorkItem, None]], + queue: "Queue[Union[_WorkItem, None]]", # Typing only supported in Python 3.9+ idle: threading.Semaphore, name: str = None, ): @@ -67,7 +67,7 @@ def run(self) -> None: class WorkerThreadPool: def __init__(self, max_workers: int = 40) -> None: - self._queue: Queue[Union[_WorkItem, None]] = Queue() + self._queue: "Queue[Union[_WorkItem, None]]" = Queue() self._workers: Set[_WorkerThread] = set() self._max_workers = max_workers self._idle = threading.Semaphore(0) From d21e4cba30dbc92ebc235119700740c86c7da714 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Wed, 14 Dec 2022 10:39:13 -0600 Subject: [PATCH 6/6] Improve _adjust_worker_count docstring --- src/prefect/_internal/concurrency/workers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/prefect/_internal/concurrency/workers.py b/src/prefect/_internal/concurrency/workers.py index fb5248ea7d14..9e98ad600f06 100644 --- a/src/prefect/_internal/concurrency/workers.py +++ b/src/prefect/_internal/concurrency/workers.py @@ -135,8 +135,11 @@ async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None): def _adjust_worker_count(self): """ + This method should called after work is added to the queue. + If no workers are idle and the maximum worker count is not reached, add a new - worker. + worker. Otherwise, decrement the idle worker count since work as been added + to the queue and a worker will be busy. Note on cleanup of workers: Workers are only removed on shutdown. Workers could be shutdown after a @@ -145,7 +148,12 @@ def _adjust_worker_count(self): created. As long as the maximum number of workers remains relatively small, the overhead of idle workers should be negligable. """ - if not self._idle.acquire(timeout=0) and len(self._workers) < self._max_workers: + if ( + # `acquire` returns false if the idle count is at zero; otherwise, it + # decrements the idle count and returns true + not self._idle.acquire(blocking=False) + and len(self._workers) < self._max_workers + ): self._add_worker() def _add_worker(self):