
Add WorkerThreadPool for running synchronous work in threads #7875

Merged · 6 commits · Dec 15, 2022
33 changes: 31 additions & 2 deletions src/prefect/_internal/concurrency/event_loop.py
@@ -25,7 +25,7 @@ def get_running_loop() -> Optional[asyncio.BaseEventLoop]:
return None


-def run_in_loop_thread(
+def call_in_loop(
__loop: asyncio.AbstractEventLoop,
__fn: Callable[P, T],
*args: P.args,
@@ -34,6 +34,16 @@ def run_in_loop_thread(
"""
    Run a synchronous call in the event loop's thread from another thread.
"""
future = call_soon_in_loop(__loop, __fn, *args, **kwargs)
return future.result()


def call_soon_in_loop(
__loop: asyncio.AbstractEventLoop,
__fn: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs
) -> concurrent.futures.Future:
future = concurrent.futures.Future()

@functools.wraps(__fn)
@@ -46,4 +56,23 @@ def wrapper() -> None:
raise

__loop.call_soon_threadsafe(wrapper)
-    return future.result()
+    return future
Contributor: Can't say I understand what's going on here, but I'm ok with that.

Contributor Author (@zanieb, Dec 15, 2022): These utilities are kind of dumb, but basically AbstractEventLoop.call_soon_threadsafe returns a Handle, which is the most useless object around town: you can cancel it and that's it. In cases where we actually want to know what our function returned or to wait for its result, we need to return something else. To accomplish this, we wrap the function that we submit to call_soon_threadsafe and use a concurrent.futures.Future to capture the return value.

Contributor Author: I've added some additional documentation for these functions in the next PR.
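
A hypothetical usage sketch of call_in_loop (the surrounding thread and event are assumptions for illustration, not part of the PR): a worker thread hops onto the event loop's thread to set an asyncio.Event, blocking until the call has actually run, since asyncio primitives are not thread-safe.

import asyncio
import threading

from prefect._internal.concurrency.event_loop import call_in_loop


async def main():
    loop = asyncio.get_running_loop()
    event = asyncio.Event()

    def from_worker_thread():
        # Schedules `event.set` on the loop's thread and blocks on the
        # wrapped future's result until it has executed there.
        call_in_loop(loop, event.set)

    threading.Thread(target=from_worker_thread).start()
    await event.wait()  # woken by the set() that ran on this loop's thread


asyncio.run(main())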



def call_soon(
__fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> concurrent.futures.Future:
future = concurrent.futures.Future()
__loop = asyncio.get_running_loop()

@functools.wraps(__fn)
def wrapper() -> None:
try:
future.set_result(__fn(*args, **kwargs))
except BaseException as exc:
future.set_exception(exc)
if not isinstance(exc, Exception):
raise

__loop.call_soon(wrapper)
return future
7 changes: 2 additions & 5 deletions src/prefect/_internal/concurrency/primitives.py
@@ -5,10 +5,7 @@
import concurrent.futures
from typing import Generic, Optional, TypeVar

-from prefect._internal.concurrency.event_loop import (
-    get_running_loop,
-    run_in_loop_thread,
-)
+from prefect._internal.concurrency.event_loop import call_in_loop, get_running_loop

T = TypeVar("T")

@@ -33,7 +30,7 @@ def set(self) -> None:
self._is_set = True
if self._loop:
if self._loop != get_running_loop():
-                run_in_loop_thread(self._loop, self._event.set)
+                call_in_loop(self._loop, self._event.set)
else:
self._event.set()

172 changes: 172 additions & 0 deletions src/prefect/_internal/concurrency/workers.py
@@ -0,0 +1,172 @@
import asyncio
Contributor Author (@zanieb): I am considering renaming this module to threads.py to clear the path for process-based workers, but I think I will defer that to the future.

Contributor: The _internal path allows us the freedom to do that without worrying about compatibility, right?

Contributor Author: Yep!

import contextvars
import dataclasses
import threading
import weakref
from queue import Queue
from typing import Callable, Dict, Optional, Set, Tuple, TypeVar, Union

import anyio.abc
from typing_extensions import ParamSpec

from prefect._internal.concurrency.primitives import Future

T = TypeVar("T")
P = ParamSpec("P")


@dataclasses.dataclass
class _WorkItem:
"""
A representation of work sent to a worker thread.
"""

future: Future
fn: Callable
args: Tuple
kwargs: Dict
context: contextvars.Context

def run(self):
if not self.future.set_running_or_notify_cancel():
return
try:
result = self.context.run(self.fn, *self.args, **self.kwargs)
except BaseException as exc:
self.future.set_exception(exc)
# Prevent reference cycle in `exc`
self = None

Check notice (Code scanning / CodeQL): Unused local variable. Variable self is not used.
Contributor: I'm confused about how this works.

Contributor Author: Thanks for calling this out! This is a common pattern in CPython, but I also have no idea how it works. Let me try to find some resources.

Contributor Author: python/cpython#80111 may be helpful?

Contributor Author: I think the exception captures frame locals, so the future references the exception, which references the work item, which references the future, and we have a cycle. If we set self to None, the exception no longer references the work item.

else:
self.future.set_result(result)
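
A minimal standalone sketch of the cycle being broken here (assumed names, not the PR's code): the exception holds its traceback, the traceback holds the frame, the frame's locals hold the work item, and the work item's future holds the exception again.

import concurrent.futures


class Item:
    def __init__(self):
        self.future = concurrent.futures.Future()

    def run(self):
        self.future.set_running_or_notify_cancel()
        try:
            raise ValueError("boom")
        except BaseException as exc:
            self.future.set_exception(exc)
            # Cycle: exc.__traceback__ -> frame -> locals["self"]
            # -> self.future -> exc. Clearing the local breaks the loop,
            # so the objects die by refcount instead of waiting for the
            # cyclic garbage collector.
            self = None


item = Item()
item.run()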


class _WorkerThread(threading.Thread):
def __init__(
self,
queue: "Queue[Union[_WorkItem, None]]", # Typing only supported in Python 3.9+
idle: threading.Semaphore,
name: str = None,
):
super().__init__(name=name)
self._queue = queue
self._idle = idle

def run(self) -> None:
while True:
work_item = self._queue.get()
if work_item is None:
# Shutdown command received; forward to other workers and exit
self._queue.put_nowait(None)
Contributor: I'm wondering about a slightly different approach here, though I haven't thought it through. What if WorkerThreadPool handed a shutdown_event (Event) to WorkerThread on init? The worker checks every iteration of its run loop to see if shutdown_event is set, and if it is, run() returns. I'm slightly biased toward this approach, assuming it actually makes sense, to avoid attaching a semantic value to None. Up to you though!

Contributor Author: Hm, the issue is that self._queue.get() is blocking, so the worker will not do anything until it receives something from the queue. We could check an event at the end of each work item, but we would still need to push something into the queue to wake up all the workers. Both AnyIO's worker threads and the CPython thread pool executor use this model; I trust what they're up to for now :)

Contributor: Ah, you're right -- I went looking to see how ThreadPoolExecutor handled this and missed its use of None to signal. If it's good enough for them I suppose it's good enough for us! 😂

Contributor: A bit late, but I finally feel up to speed with all of this. Double-checking my understanding from https://github.com/python/cpython/blob/1332fdabbab75bc9e4bced064dc4daab2d7acb47/Lib/asyncio/queues.py#L149: queue.get() will only ever return None after an exception, so we can use it as a signal here?

Contributor: Also double-checking: by blocking here, this keeps us from burning through CPU on this while True loop, yeah?

Contributor Author: Yes, it stops us from burning CPU. It returns None when we put None in the queue :D; exceptions are raised, not returned. We place None in the queue to signal shutdown and to wake up the worker, since it is otherwise blocked.

return

work_item.run()
self._idle.release()

del work_item
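
A standalone sketch of the sentinel pattern discussed above (names assumed, not the PR's code): a single None shuts down every worker, because each worker re-enqueues the sentinel before exiting.

import threading
from queue import Queue


def worker(queue: Queue) -> None:
    while True:
        item = queue.get()          # blocks while idle; no CPU spin
        if item is None:            # shutdown sentinel received
            queue.put_nowait(None)  # forward it so sibling workers wake too
            return
        item()                      # run the work item


queue: Queue = Queue()
threads = [threading.Thread(target=worker, args=(queue,)) for _ in range(4)]
for thread in threads:
    thread.start()
queue.put_nowait(None)  # one sentinel is enough to stop all four workers
for thread in threads:
    thread.join()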


class WorkerThreadPool:
def __init__(self, max_workers: int = 40) -> None:
self._queue: "Queue[Union[_WorkItem, None]]" = Queue()
self._workers: Set[_WorkerThread] = set()
self._max_workers = max_workers
self._idle = threading.Semaphore(0)
self._lock = asyncio.Lock()
self._shutdown = False

# On garbage collection of the pool, signal shutdown to workers
weakref.finalize(self, self._queue.put_nowait, None)

async def submit(
self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> Future[T]:
"""
Submit a function to run in a worker thread.

Returns a future which can be used to retrieve the result of the function.
"""
async with self._lock:
Contributor Author: Technically, because this function never calls await, it does not need a lock here. I'm tempted to remove the lock and make this a synchronous method, but it feels safer to wait and see what we need in the future.

Contributor: Hmmm. My instinct would be to remove the lock until we need it, but I'll defer to you here because I'm not sure what keeping it protects us against -- you may have a better idea of that.

Contributor Author: I've removed the lock in a following pull request; it turns out I need submit to be synchronous. I had actually never considered that an async function that doesn't await anything doesn't need a lock; this was pointed out to me during review of an httpx PR :)

Contributor (@peytonrunyan, Dec 20, 2022): Double-checking understanding here: this is because we aren't going to be handing execution to another coroutine without an await, right? So as long as this object is only being accessed by coroutines running within the same thread, and not by other threads, we don't have to worry about race conditions without the await?

if self._shutdown:
raise RuntimeError("Work cannot be submitted to pool after shutdown.")

future = Future()

work_item = _WorkItem(
future=future,
fn=fn,
args=args,
kwargs=kwargs,
context=contextvars.copy_context(),
)

# Place the new work item on the work queue
self._queue.put_nowait(work_item)

# Ensure there are workers available to run the work
self._adjust_worker_count()
Contributor: Would it ever be possible for work to be submitted and error before the worker count is adjusted?

Contributor Author: I don't think the threads will context switch until after submission completes because control isn't yielded, but even if it did, I don't think there would be significant effects.


return future
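
A minimal sketch of the point from the lock discussion above (standalone, assumed names): an async def body with no await runs atomically with respect to other coroutines on the same loop, because the loop can only switch tasks at await points.

import asyncio

counter = 0


async def increment():
    global counter
    value = counter      # no await between this read...
    counter = value + 1  # ...and this write, so no other task can interleave


async def main():
    await asyncio.gather(*[increment() for _ in range(1000)])
    assert counter == 1000  # no lost updates, even without a lock


asyncio.run(main())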

async def shutdown(self, task_status: Optional[anyio.abc.TaskStatus] = None):
"""
Shutdown the pool, waiting for all workers to complete before returning.

        If work is submitted before shutdown, it will run to completion.
After shutdown, new work may not be submitted.

When called with `TaskGroup.start(...)`, the task will be reported as started
after signalling shutdown to workers.
"""
async with self._lock:
self._shutdown = True
self._queue.put_nowait(None)

if task_status:
task_status.started()

# Avoid blocking the event loop while waiting for threads to join by
# joining in another thread; we use a new instance of ourself to avoid
# reimplementing threaded work.
pool = WorkerThreadPool(max_workers=1)
futures = [await pool.submit(worker.join) for worker in self._workers]
await asyncio.gather(*[future.aresult() for future in futures])
Contributor Author (on lines +130 to +132): This is really meta, but it's way easier than spinning up a single thread manually just for this purpose.

Contributor: I was confused for a second but then remembered -- yeah, the ThreadPoolExecutor interface (and thus ours) is way easier to deal with.


self._workers.clear()

def _adjust_worker_count(self):
"""
        This method should be called after work is added to the queue.

If no workers are idle and the maximum worker count is not reached, add a new
        worker. Otherwise, decrement the idle worker count since work has been added
to the queue and a worker will be busy.

Note on cleanup of workers:
            Workers are only removed on shutdown. Workers could be shut down after a
            period of idleness. However, we expect usage in Prefect to generally be
incurred in a workflow that will not have idle workers once they are
created. As long as the maximum number of workers remains relatively small,
            the overhead of idle workers should be negligible.
"""
if (
# `acquire` returns false if the idle count is at zero; otherwise, it
# decrements the idle count and returns true
not self._idle.acquire(blocking=False)
and len(self._workers) < self._max_workers
):
self._add_worker()
Contributor (@peytonrunyan, Dec 20, 2022): Double-checking understanding here. We've got the semaphore set to 0. So when we go to submit work, acquire() gives us back False, we check that we have room for additional workers, we spin up a new worker, it does its work, then calls release(), incrementing our semaphore, which lets us know we have idle workers in the pool, yeah?
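
A small sketch of that exchange (standalone, assumed names): acquire(blocking=False) asks "is any worker idle?" and decrements the count if so, while release() is a worker reporting itself idle again.

import threading

idle = threading.Semaphore(0)  # zero: no idle workers yet
workers: list = []
MAX_WORKERS = 4


def on_submit() -> None:
    # False means the counter was zero (nobody idle); spawn if there's room.
    if not idle.acquire(blocking=False) and len(workers) < MAX_WORKERS:
        workers.append(object())  # stand-in for starting a worker thread


def on_work_finished() -> None:
    idle.release()  # worker is idle again; increment the counter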


def _add_worker(self):
worker = _WorkerThread(
queue=self._queue,
idle=self._idle,
name=f"PrefectWorker-{len(self._workers)}",
)
self._workers.add(worker)
worker.start()

async def __aenter__(self):
return self

async def __aexit__(self, *_):
await self.shutdown()
111 changes: 111 additions & 0 deletions tests/_internal/concurrency/test_workers.py
@@ -0,0 +1,111 @@
import asyncio
import time

import anyio
import pytest

from prefect._internal.concurrency.workers import WorkerThreadPool


def identity(x):
return x


async def test_submit():
async with WorkerThreadPool() as pool:
future = await pool.submit(identity, 1)
assert await future.aresult() == 1


async def test_submit_many():
async with WorkerThreadPool() as pool:
futures = [await pool.submit(identity, i) for i in range(100)]
results = await asyncio.gather(*[future.aresult() for future in futures])
assert results == list(range(100))
assert len(pool._workers) == pool._max_workers


async def test_submit_reuses_idle_thread():
Contributor: I might preemptively decorate this with the flaky decorator to get retries.

Contributor Author: It should never flake in theory :) We can also just sleep for a full second there instead of doing a busy wait, but this is a bit faster. We're just letting Python context switch to call the release method.

async with WorkerThreadPool() as pool:
future = await pool.submit(identity, 1)
await future.aresult()

# Spin until the worker is marked as idle
with anyio.fail_after(1):
while pool._idle._value == 0:
await anyio.sleep(0)

future = await pool.submit(identity, 1)
await future.aresult()
assert len(pool._workers) == 1


async def test_submit_after_shutdown():
pool = WorkerThreadPool()
await pool.shutdown()

with pytest.raises(
RuntimeError, match="Work cannot be submitted to pool after shutdown"
):
await pool.submit(identity, 1)


async def test_submit_during_shutdown():
async with WorkerThreadPool() as pool:

async with anyio.create_task_group() as tg:
await tg.start(pool.shutdown)

with pytest.raises(
RuntimeError, match="Work cannot be submitted to pool after shutdown"
):
await pool.submit(identity, 1)


async def test_shutdown_no_workers():
pool = WorkerThreadPool()
await pool.shutdown()


async def test_shutdown_multiple_times():
pool = WorkerThreadPool()
await pool.submit(identity, 1)
await pool.shutdown()
await pool.shutdown()


async def test_shutdown_with_idle_workers():
pool = WorkerThreadPool()
futures = [await pool.submit(identity, 1) for _ in range(5)]
await asyncio.gather(*[future.aresult() for future in futures])
await pool.shutdown()


async def test_shutdown_with_active_worker():
pool = WorkerThreadPool()
future = await pool.submit(time.sleep, 1)
await pool.shutdown()
assert await future.aresult() is None
Contributor: I have a mild concern that aresult() returning None in the successful case leaves room to mask an unknown failure case that would erroneously lead to the same value. I don't feel strongly about it though.

Contributor Author: 👍 I could add a function that sleeps then returns a value.
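
A hypothetical test along the lines the author suggests (not part of this PR): return a distinctive value after sleeping so a masked failure cannot be mistaken for a successful None.

def sleep_then_return(value):
    time.sleep(1)
    return value


async def test_shutdown_with_active_worker_returns_value():
    pool = WorkerThreadPool()
    future = await pool.submit(sleep_then_return, "done")
    await pool.shutdown()
    assert await future.aresult() == "done"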



async def test_shutdown_exception_during_join():
Contributor: Where does the join happen? I'm a little confused about what this test checks.

Contributor Author: Ah, pool.shutdown joins all the threads. When called with start, it sends the shutdown signal to workers, then joins the threads. Here the pattern is a bit like:

--> Test: Shutdown pool
--> Pool: Sends signal to workers
--> Pool: Awaits on worker join, which context switches
--> Test: Raises exception
--> Pool: Shuts down cleanly still

I wrote this while dealing with some weird issues with pool shutdown when an exception was raised. It's unclear to me how to clarify it / whether it's worth keeping.

pool = WorkerThreadPool()
future = await pool.submit(identity, 1)
await future.aresult()

try:
async with anyio.create_task_group() as tg:
await tg.start(pool.shutdown)
raise ValueError()
except ValueError:
pass

assert pool._shutdown is True


async def test_context_manager_with_outstanding_future():
async with WorkerThreadPool() as pool:
future = await pool.submit(identity, 1)

assert pool._shutdown is True
assert await future.aresult() == 1