Add WorkerThreadPool for running synchronous work in threads #7875
Changes from all commits
0c2ee8e
74b23a8
081ffc8
17939f0
3994db8
d21e4cb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import asyncio | ||
I am considering renaming this module to
The
Yep! |
||
import contextvars | ||
import dataclasses | ||
import threading | ||
import weakref | ||
from queue import Queue | ||
from typing import Callable, Dict, Optional, Set, Tuple, TypeVar, Union | ||
|
||
import anyio.abc | ||
from typing_extensions import ParamSpec | ||
|
||
from prefect._internal.concurrency.primitives import Future | ||
|
||
T = TypeVar("T") | ||
P = ParamSpec("P") | ||
|
||
|
||
@dataclasses.dataclass | ||
class _WorkItem: | ||
""" | ||
A representation of work sent to a worker thread. | ||
""" | ||
|
||
future: Future | ||
fn: Callable | ||
args: Tuple | ||
kwargs: Dict | ||
context: contextvars.Context | ||
|
||
def run(self): | ||
if not self.future.set_running_or_notify_cancel(): | ||
return | ||
try: | ||
result = self.context.run(self.fn, *self.args, **self.kwargs) | ||
except BaseException as exc: | ||
self.future.set_exception(exc) | ||
# Prevent reference cycle in `exc` | ||
self = None | ||
Check notice Code scanning / CodeQL Unused local variable
Variable self is not used.
I'm confused about how this works
Thanks for calling this out! This is a common pattern in CPython, but I also have no idea how it works. Let me try to find some resources.
python/cpython#80111 may be helpful?
I think the exception captures frame locals so the future references the exception which references the work item which references the future and we have a cycle. If we set |
||
else: | ||
self.future.set_result(result) | ||
|
||
|
||
class _WorkerThread(threading.Thread): | ||
def __init__( | ||
self, | ||
queue: "Queue[Union[_WorkItem, None]]", # Typing only supported in Python 3.9+ | ||
idle: threading.Semaphore, | ||
name: Optional[str] = None, | ||
): | ||
super().__init__(name=name) | ||
self._queue = queue | ||
self._idle = idle | ||
|
||
def run(self) -> None: | ||
while True: | ||
work_item = self._queue.get() | ||
if work_item is None: | ||
# Shutdown command received; forward to other workers and exit | ||
self._queue.put_nowait(None) | ||
I'm wondering about a slightly different approach here, though I haven't thought it through. But what if
Hm, the issue is that
Ah, you're right -- I went looking to see how ThreadPoolExecutor handled this and missed its use of None to signal. If it's good enough for them I suppose it's good enough for us! 😂
A bit late but I finally feel up to speed with all of this. Double checking my understanding from https://github.com/python/cpython/blob/1332fdabbab75bc9e4bced064dc4daab2d7acb47/Lib/asyncio/queues.py#L149
Also double checking - by blocking here, this keeps us from burning through CPU on this
Yes it stops us from burning CPU. It returns |
||
return | ||
|
||
work_item.run() | ||
self._idle.release() | ||
|
||
del work_item | ||
|
||
|
||
class WorkerThreadPool: | ||
def __init__(self, max_workers: int = 40) -> None: | ||
self._queue: "Queue[Union[_WorkItem, None]]" = Queue() | ||
self._workers: Set[_WorkerThread] = set() | ||
self._max_workers = max_workers | ||
self._idle = threading.Semaphore(0) | ||
self._lock = asyncio.Lock() | ||
self._shutdown = False | ||
|
||
# On garbage collection of the pool, signal shutdown to workers | ||
weakref.finalize(self, self._queue.put_nowait, None) | ||
|
||
async def submit( | ||
self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs | ||
) -> Future[T]: | ||
""" | ||
Submit a function to run in a worker thread. | ||
|
||
Returns a future which can be used to retrieve the result of the function. | ||
""" | ||
async with self._lock: | ||
Technically, because this function never calls
Hmmm. My instinct would be to remove the lock until we need it, but I'll defer to you here because I'm not sure what keeping it protects us against -- you may have a better idea of that.
I've removed the lock in a following pull request -- it turns out I need I actually had never considered that if your async function doesn't await anything it doesn't need a lock, this was pointed out to me during review of an httpx PR :)
Double checking understanding here - this is because we aren't going to be handing execution to another coroutine without an |
||
if self._shutdown: | ||
raise RuntimeError("Work cannot be submitted to pool after shutdown.") | ||
|
||
future = Future() | ||
|
||
work_item = _WorkItem( | ||
future=future, | ||
fn=fn, | ||
args=args, | ||
kwargs=kwargs, | ||
context=contextvars.copy_context(), | ||
) | ||
|
||
# Place the new work item on the work queue | ||
self._queue.put_nowait(work_item) | ||
|
||
# Ensure there are workers available to run the work | ||
self._adjust_worker_count() | ||
would it ever be possible for work to be submitted and error before the worker count is adjusted?
I don't think the threads will context switch until after submission completes because control isn't yielded, but even if it did I don't think there would be significant effects. |
||
|
||
return future | ||
|
||
async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None): | ||
""" | ||
Shutdown the pool, waiting for all workers to complete before returning. | ||
|
||
If work is submitted before shutdown, it will run to completion. | ||
After shutdown, new work may not be submitted. | ||
|
||
When called with `TaskGroup.start(...)`, the task will be reported as started | ||
after signalling shutdown to workers. | ||
""" | ||
async with self._lock: | ||
self._shutdown = True | ||
self._queue.put_nowait(None) | ||
|
||
if task_status: | ||
task_status.started() | ||
|
||
# Avoid blocking the event loop while waiting for threads to join by | ||
# joining in another thread; we use a new instance of ourself to avoid | ||
# reimplementing threaded work. | ||
pool = WorkerThreadPool(max_workers=1) | ||
futures = [await pool.submit(worker.join) for worker in self._workers] | ||
await asyncio.gather(*[future.aresult() for future in futures]) | ||
Comment on lines +130 to +132
This is really meta but it's way easier than spinning a single thread manually just for this purpose.
I was confused for a second but then remembered -- yeah, the |
||
|
||
self._workers.clear() | ||
|
||
def _adjust_worker_count(self): | ||
""" | ||
This method should be called after work is added to the queue. | ||
|
||
If no workers are idle and the maximum worker count is not reached, add a new | ||
worker. Otherwise, decrement the idle worker count since work has been added | ||
to the queue and a worker will be busy. | ||
|
||
Note on cleanup of workers: | ||
Workers are only removed on shutdown. Workers could be shutdown after a | ||
period of idle. However, we expect usage in Prefect to generally be | ||
incurred in a workflow that will not have idle workers once they are | ||
created. As long as the maximum number of workers remains relatively small, | ||
the overhead of idle workers should be negligible. | ||
""" | ||
if ( | ||
# `acquire` returns false if the idle count is at zero; otherwise, it | ||
# decrements the idle count and returns true | ||
not self._idle.acquire(blocking=False) | ||
and len(self._workers) < self._max_workers | ||
): | ||
self._add_worker() | ||
Double checking understanding here. We've got the semaphore set to |
||
|
||
def _add_worker(self): | ||
worker = _WorkerThread( | ||
queue=self._queue, | ||
idle=self._idle, | ||
name=f"PrefectWorker-{len(self._workers)}", | ||
) | ||
self._workers.add(worker) | ||
worker.start() | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self, *_): | ||
await self.shutdown() |
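
The review thread on the `None` sentinel in `_WorkerThread.run` can be illustrated in isolation. The following is only a minimal sketch, not the PR's implementation: `worker` and `demo_sentinel` are hypothetical names, and plain callables stand in for `_WorkItem`. It shows that a single `None` placed on a FIFO `queue.Queue` is enough to stop every worker, because each worker puts the sentinel back before exiting, and that the blocking `get()` lets idle workers wait without spinning.

```python
import threading
from queue import Queue
from typing import Callable, List, Optional


def worker(queue: "Queue[Optional[Callable[[], None]]]") -> None:
    """Hypothetical stand-in for a worker thread's run loop."""
    while True:
        item = queue.get()  # Blocks (no busy loop) until an item is available
        if item is None:
            # Shutdown sentinel: forward it so the next worker also sees it
            queue.put_nowait(None)
            return
        item()


def demo_sentinel(n_workers: int = 3, n_items: int = 10) -> List[int]:
    queue: "Queue[Optional[Callable[[], None]]]" = Queue()
    results: List[int] = []
    results_lock = threading.Lock()

    def make_item(i: int) -> Callable[[], None]:
        def item() -> None:
            with results_lock:
                results.append(i)

        return item

    threads = [
        threading.Thread(target=worker, args=(queue,)) for _ in range(n_workers)
    ]
    for thread in threads:
        thread.start()

    for i in range(n_items):
        queue.put_nowait(make_item(i))

    # One sentinel, regardless of how many workers are running
    queue.put_nowait(None)

    for thread in threads:
        thread.join()

    return sorted(results)
```

Because the queue is FIFO, work queued before the sentinel is dequeued before any worker can see `None`, which matches the docstring's promise that work submitted before shutdown runs to completion.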
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import asyncio | ||
import time | ||
|
||
import anyio | ||
import pytest | ||
|
||
from prefect._internal.concurrency.workers import WorkerThreadPool | ||
|
||
|
||
def identity(x): | ||
return x | ||
|
||
|
||
async def test_submit(): | ||
async with WorkerThreadPool() as pool: | ||
future = await pool.submit(identity, 1) | ||
assert await future.aresult() == 1 | ||
|
||
|
||
async def test_submit_many(): | ||
async with WorkerThreadPool() as pool: | ||
futures = [await pool.submit(identity, i) for i in range(100)] | ||
results = await asyncio.gather(*[future.aresult() for future in futures]) | ||
assert results == list(range(100)) | ||
assert len(pool._workers) == pool._max_workers | ||
|
||
|
||
async def test_submit_reuses_idle_thread(): | ||
I might preemptively decorate this with the flaky decorator to get retries.
It should never flake in theory :) We can also just sleep for a full second there instead of doing a busy wait, but this is a bit faster. We're just letting Python context switch to call the release method. |
||
async with WorkerThreadPool() as pool: | ||
future = await pool.submit(identity, 1) | ||
await future.aresult() | ||
|
||
# Spin until the worker is marked as idle | ||
with anyio.fail_after(1): | ||
while pool._idle._value == 0: | ||
await anyio.sleep(0) | ||
|
||
future = await pool.submit(identity, 1) | ||
await future.aresult() | ||
assert len(pool._workers) == 1 | ||
|
||
|
||
async def test_submit_after_shutdown(): | ||
pool = WorkerThreadPool() | ||
await pool.shutdown() | ||
|
||
with pytest.raises( | ||
RuntimeError, match="Work cannot be submitted to pool after shutdown" | ||
): | ||
await pool.submit(identity, 1) | ||
|
||
|
||
async def test_submit_during_shutdown(): | ||
async with WorkerThreadPool() as pool: | ||
|
||
async with anyio.create_task_group() as tg: | ||
await tg.start(pool.shutdown) | ||
|
||
with pytest.raises( | ||
RuntimeError, match="Work cannot be submitted to pool after shutdown" | ||
): | ||
await pool.submit(identity, 1) | ||
|
||
|
||
async def test_shutdown_no_workers(): | ||
pool = WorkerThreadPool() | ||
await pool.shutdown() | ||
|
||
|
||
async def test_shutdown_multiple_times(): | ||
pool = WorkerThreadPool() | ||
await pool.submit(identity, 1) | ||
await pool.shutdown() | ||
await pool.shutdown() | ||
|
||
|
||
async def test_shutdown_with_idle_workers(): | ||
pool = WorkerThreadPool() | ||
futures = [await pool.submit(identity, 1) for _ in range(5)] | ||
await asyncio.gather(*[future.aresult() for future in futures]) | ||
await pool.shutdown() | ||
|
||
|
||
async def test_shutdown_with_active_worker(): | ||
pool = WorkerThreadPool() | ||
future = await pool.submit(time.sleep, 1) | ||
await pool.shutdown() | ||
assert await future.aresult() is None | ||
I have a mild concern that
👍 I could add a function that sleeps then returns a value. |
||
|
||
|
||
async def test_shutdown_exception_during_join(): | ||
Where does the
Ah
I wrote this while dealing with some weird issues with pool shutdown when an exception was raised. It's unclear to me how to clarify it / if it's worth keeping. |
||
pool = WorkerThreadPool() | ||
future = await pool.submit(identity, 1) | ||
await future.aresult() | ||
|
||
try: | ||
async with anyio.create_task_group() as tg: | ||
await tg.start(pool.shutdown) | ||
raise ValueError() | ||
except ValueError: | ||
pass | ||
|
||
assert pool._shutdown is True | ||
|
||
|
||
async def test_context_manager_with_outstanding_future(): | ||
async with WorkerThreadPool() as pool: | ||
future = await pool.submit(identity, 1) | ||
|
||
assert pool._shutdown is True | ||
assert await future.aresult() == 1 |
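
For readers following `test_submit_reuses_idle_thread`, the idle-count bookkeeping in `_adjust_worker_count` can be sketched on its own. This is a hedged illustration rather than the PR's code: `needs_new_worker` and `demo_idle_count` are hypothetical names, and the worker counts are passed in as plain integers. It relies only on the documented behavior of `threading.Semaphore.acquire(blocking=False)`, which returns `False` when the counter is zero and otherwise decrements it and returns `True`.

```python
import threading
from typing import List


def needs_new_worker(
    idle: threading.Semaphore, n_workers: int, max_workers: int
) -> bool:
    """Hypothetical mirror of the condition checked after queueing work."""
    # A successful non-blocking acquire "claims" one idle worker for the new
    # work item; a failed acquire means no worker is idle.
    return not idle.acquire(blocking=False) and n_workers < max_workers


def demo_idle_count() -> List[bool]:
    idle = threading.Semaphore(0)  # No workers exist yet, so none are idle
    decisions = []

    # No idle workers and below the maximum: spawn one
    decisions.append(needs_new_worker(idle, n_workers=0, max_workers=2))

    # The worker finishes its item and marks itself idle
    idle.release()

    # An idle worker is available: claim it instead of spawning
    decisions.append(needs_new_worker(idle, n_workers=1, max_workers=2))

    # The idle count is back at zero: spawn a second worker
    decisions.append(needs_new_worker(idle, n_workers=1, max_workers=2))

    # No idle workers, but the maximum is reached: the work waits in the queue
    decisions.append(needs_new_worker(idle, n_workers=2, max_workers=2))

    return decisions
```

The test's busy wait exists because the worker calls `release()` after the future's result is set, so the submitting side can observe the result slightly before the worker is counted as idle.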
Can't say I understand what's going on here, but I'm ok with that.

These utilities are kind of dumb, but basically `AbstractEventLoop.call_soon_threadsafe` returns a `Handle`, which is the most useless object around town: you can cancel it and that's it. In cases where we might like, actually want to know what our function returned or wait for its result, we need to return something else. To accomplish this, we wrap the function that we submit to `call_soon_threadsafe` and use a threading `Future` to capture the return value.

I've added some additional documentation for these functions in the next PR
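
The wrapping approach described in the comment above could look roughly like the following. This is a sketch under stated assumptions, not the utilities from this PR: `call_soon_in_loop` and `demo_call_soon` are hypothetical names, and the standard library's `concurrent.futures.Future` is used in place of the `Future` from `prefect._internal.concurrency.primitives`.

```python
import asyncio
import concurrent.futures
import threading
from typing import Any, Callable


def call_soon_in_loop(
    loop: asyncio.AbstractEventLoop, fn: Callable[..., Any], *args: Any
) -> "concurrent.futures.Future[Any]":
    """Hypothetical helper: schedule `fn` on `loop` and expose its result."""
    future: "concurrent.futures.Future[Any]" = concurrent.futures.Future()

    def wrapper() -> None:
        # Skip the call if the future was cancelled before the loop got to it
        if not future.set_running_or_notify_cancel():
            return
        try:
            future.set_result(fn(*args))
        except BaseException as exc:
            future.set_exception(exc)

    # The returned Handle is discarded; callers use the future instead
    loop.call_soon_threadsafe(wrapper)
    return future


def demo_call_soon() -> int:
    # Run an event loop in a background thread and call into it from here
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()
    try:
        future = call_soon_in_loop(loop, lambda x: x + 1, 41)
        return future.result(timeout=5)
    finally:
        loop.call_soon_threadsafe(loop.stop)
        thread.join()
        loop.close()
```

Unlike the `Handle`, the returned future lets the calling thread block on `result()`, inspect an exception, or cancel the call before it starts.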