diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py
index d7702c9..e5f568d 100644
--- a/src/qasync/__init__.py
+++ b/src/qasync/__init__.py
@@ -8,7 +8,14 @@
 BSD License
 """
 
-__all__ = ["QEventLoop", "QThreadExecutor", "asyncSlot", "asyncClose", "asyncWrap"]
+__all__ = [
+    "QEventLoop",
+    "QThreadExecutor",
+    "QThreadPoolExecutor",
+    "asyncSlot",
+    "asyncClose",
+    "asyncWrap",
+]
 
 import asyncio
 import contextlib
@@ -22,6 +29,7 @@
 import time
 from concurrent.futures import Future
 from queue import Queue
+from weakref import WeakSet
 
 logger = logging.getLogger(__name__)
 
@@ -162,8 +170,80 @@ def wait(self):
         super().wait()
 
 
+def _result_or_cancel(fut, timeout=None):
+    try:
+        try:
+            return fut.result(timeout)
+        finally:
+            fut.cancel()
+    finally:
+        # Break a reference cycle with the exception stored on the future
+        del fut
+
+
+class QThreadExecutorBase:
+    def __init__(self):
+        self._been_shutdown = False
+        self.futures = WeakSet()
+
+    def submit(self, callback, *args, **kwargs):
+        raise NotImplementedError()
+
+    def map(self, func, *iterables, timeout=None, chunksize=1):
+        """Map func over the iterables, blocking to yield each result in order."""
+        # Based on the standard library implementation of Executor.map.
+        end_time = time.monotonic() + timeout if timeout is not None else None
+        futures = [self.submit(func, *args) for args in zip(*iterables)]
+
+        # The yield must live in an inner generator so that the submit() calls
+        # above run as soon as map() is called, not on first iteration.
+        def generator():
+            # Reverse and pop so finished futures are not kept alive
+            # (their stored exceptions can create reference cycles).
+            try:
+                futures.reverse()
+                while futures:
+                    if end_time is not None:
+                        yield _result_or_cancel(
+                            futures.pop(), timeout=end_time - time.monotonic()
+                        )
+                    else:
+                        yield _result_or_cancel(futures.pop())
+            finally:
+                for future in futures:
+                    future.cancel()
+
+        return generator()
+
+    def shutdown(self, wait=True, *, cancel_futures=False):
+        if self._been_shutdown:
+            raise RuntimeError(f"{self.__class__.__name__} has been shutdown")
+        self._been_shutdown = True
+
+    def __enter__(self, *args):
+        if self._been_shutdown:
+            raise RuntimeError(f"{self.__class__.__name__} has been shutdown")
+        return self
+
+    def __exit__(self, *args):
+        self.shutdown()
+
+    @staticmethod
+    def compute_stack_size():
+        # Match cpython/Python/thread_pthread.h
+        if sys.platform.startswith("darwin"):
+            stack_size = 16 * 2**20
+        elif sys.platform.startswith("freebsd"):
+            stack_size = 4 * 2**20
+        elif sys.platform.startswith("aix"):
+            stack_size = 2 * 2**20
+        else:
+            stack_size = None
+        return stack_size
+
+
 @with_logger
-class QThreadExecutor:
+class QThreadExecutor(QThreadExecutorBase):
     """
     ThreadExecutor that produces QThreads.
 
@@ -181,23 +261,16 @@ def __init__(self, max_workers=10, stack_size=None): self.__max_workers = max_workers self.__queue = Queue() if stack_size is None: - # Match cpython/Python/thread_pthread.h - if sys.platform.startswith("darwin"): - stack_size = 16 * 2**20 - elif sys.platform.startswith("freebsd"): - stack_size = 4 * 2**20 - elif sys.platform.startswith("aix"): - stack_size = 2 * 2**20 + stack_size = self.compute_stack_size() self.__workers = [ _QThreadWorker(self.__queue, i + 1, stack_size) for i in range(max_workers) ] - self.__been_shutdown = False for w in self.__workers: w.start() def submit(self, callback, *args, **kwargs): - if self.__been_shutdown: + if self._been_shutdown: raise RuntimeError("QThreadExecutor has been shutdown") future = Future() @@ -208,32 +281,80 @@ def submit(self, callback, *args, **kwargs): kwargs, ) self.__queue.put((future, callback, args, kwargs)) + self.futures.add(future) return future - def map(self, func, *iterables, timeout=None): - raise NotImplementedError("use as_completed on the event loop") - - def shutdown(self, wait=True): - if self.__been_shutdown: - raise RuntimeError("QThreadExecutor has been shutdown") - - self.__been_shutdown = True + def shutdown(self, wait=True, *, cancel_futures=False): + super().shutdown(wait=wait, cancel_futures=cancel_futures) self._logger.debug("Shutting down") for i in range(len(self.__workers)): # Signal workers to stop self.__queue.put(None) + if cancel_futures: + for future in self.futures: + future.cancel() if wait: for w in self.__workers: w.wait() - def __enter__(self, *args): - if self.__been_shutdown: - raise RuntimeError("QThreadExecutor has been shutdown") - return self - def __exit__(self, *args): - self.shutdown() +class _QThreadPoolExecutorRunnable(QtCore.QRunnable): + def __init__(self, callback, *args, **kwargs): + super().__init__() + self._callback = callback + self._args = args + self._kwargs = kwargs + self.future = Future() + + def run(self): + if self.future.set_running_or_notify_cancel(): + try: + result = self._callback(*self._args, **self._kwargs) + self.future.set_result(result) + except Exception as e: + self.future.set_exception(e) + + +@with_logger +class QThreadPoolExecutor(QThreadExecutorBase): + """ + ThreadPoolExecutor uses a QThreadPool as the underlying implementation. + + Same API as `concurrent.futures.Executor` + + >>> from qasync import QThreadPoolExecutor + >>> with QThreadPoolExecutor() as executor: + ... f = executor.submit(lambda x: 2 + x, 2) + ... r = f.result() + ... 
assert r == 4 + """ + + def __init__(self, pool=None): + super().__init__() + self.pool = pool or QtCore.QThreadPool.globalInstance() + + def submit(self, callback, *args, **kwargs): + if self._been_shutdown: + raise RuntimeError(f"{self.__class__.__name__} has been shutdown") + + runnable = _QThreadPoolExecutorRunnable(callback, *args, **kwargs) + self.pool.start(runnable) + self.futures.add(runnable.future) + return runnable.future + + def shutdown(self, wait=True, *, cancel_futures=False): + super().shutdown(wait=wait, cancel_futures=cancel_futures) + self._logger.debug("Shutting down") + if cancel_futures: + for future in self.futures: + future.cancel() + if wait: + for w in list(self.futures): + try: + w.result() + except Exception: + pass def _format_handle(handle: asyncio.Handle): diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index 67c1833..7363dfa 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -4,7 +4,10 @@ # BSD License import logging import threading +import time import weakref +from concurrent.futures import TimeoutError +from itertools import islice import pytest @@ -21,7 +24,11 @@ def disable_executor_logging(): To avoid issues with tests targeting stale references, we disable logging for QThreadExecutor and _QThreadWorker classes. """ - for cls in (qasync.QThreadExecutor, qasync._QThreadWorker): + for cls in ( + qasync.QThreadExecutor, + qasync._QThreadWorker, + qasync.QThreadPoolExecutor, + ): logger_name = cls.__qualname__ if cls.__module__ is not None: logger_name = f"{cls.__module__}.{logger_name}" @@ -30,16 +37,38 @@ def disable_executor_logging(): logger.propagate = False -@pytest.fixture +@pytest.fixture(params=[qasync.QThreadExecutor, qasync.QThreadPoolExecutor]) def executor(request): - exe = qasync.QThreadExecutor(5) - request.addfinalizer(exe.shutdown) + exe = get_executor(request) + request.addfinalizer(lambda: safe_shutdown(exe)) return exe -@pytest.fixture -def shutdown_executor(): - exe = qasync.QThreadExecutor(5) +def get_executor(request): + if request.param is qasync.QThreadPoolExecutor: + pool = qasync.QtCore.QThreadPool() + stack_size = qasync.QThreadExecutorBase.compute_stack_size() + if stack_size is not None: + pool.setStackSize(stack_size) + pool.setMaxThreadCount(5) + return request.param(pool) + else: + return request.param(5) + + +def safe_shutdown(executor): + try: + executor.shutdown() + except Exception: + pass + if isinstance(executor, qasync.QThreadPoolExecutor): + # empty the underlying QThreadPool object + executor.pool.waitForDone() + + +@pytest.fixture(params=[qasync.QThreadExecutor, qasync.QThreadPoolExecutor]) +def shutdown_executor(request): + exe = get_executor(request) exe.shutdown() return exe @@ -55,7 +84,7 @@ def test_ctx_after_shutdown(shutdown_executor): pass -def test_submit_after_shutdown(shutdown_executor): +def _test_submit_after_shutdown(shutdown_executor): with pytest.raises(RuntimeError): shutdown_executor.submit(None) @@ -64,6 +93,7 @@ def test_stack_recursion_limit(executor): # Test that worker threads have sufficient stack size for the default # sys.getrecursionlimit. If not this should fail with SIGSEGV or SIGBUS # (or event SIGILL?) + def rec(a, *args, **kwargs): rec(a, *args, **kwargs) @@ -104,3 +134,103 @@ def test_no_stale_reference_as_result(executor, disable_executor_logging): assert collected is True, ( "Stale reference to executor result not collected within timeout." 
) + + +def test_map(executor): + """Basic test of executor map functionality""" + results = list(executor.map(lambda x: x + 1, range(10))) + assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + results = list(executor.map(lambda x, y: x + y, range(10), range(9))) + assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16] + + +def test_map_timeout(executor): + """Test that map with timeout raises TimeoutError and cancels futures""" + results = [] + + def func(x): + nonlocal results + time.sleep(0.05) + results.append(x) + return x + + start = time.monotonic() + with pytest.raises(TimeoutError): + list(executor.map(func, range(10), timeout=0.01)) + duration = time.monotonic() - start + assert duration < 0.05 + + executor.shutdown(wait=True) + # only about half of the tasks should have completed + # because the max number of workers is 5 and the rest of + # the tasks were not started at the time of the cancel. + assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + +def test_map_error(executor): + """Test that map with an exception will raise, and remaining tasks are cancelled""" + results = [] + + def func(x): + nonlocal results + time.sleep(0.05) + if len(results) == 5: + raise ValueError("Test error") + results.append(x) + return x + + with pytest.raises(ValueError): + list(executor.map(func, range(15))) + + executor.shutdown(wait=True, cancel_futures=False) + assert len(results) <= 10, "Final 5 at least should have been cancelled" + + +@pytest.mark.parametrize("cancel", [True, False]) +def test_map_shutdown(executor, cancel): + results = [] + + def func(x): + nonlocal results + time.sleep(0.05) + results.append(x) + return x + + # Get the first few results. + # Keep the iterator alive so that it isn't closed when its reference is dropped. + m = executor.map(func, range(15)) + values = list(islice(m, 5)) + assert values == [0, 1, 2, 3, 4] + + executor.shutdown(wait=True, cancel_futures=cancel) + if cancel: + assert len(results) < 15, "Some tasks should have been cancelled" + else: + assert len(results) == 15, "All tasks should have been completed" + + +def test_map_start(executor): + """Test that map starts tasks immediately, before iterating""" + e = threading.Event() + m = executor.map(lambda x: (e.set(), x), range(1)) + e.wait(timeout=0.1) + assert list(m) == [(None, 0)] + + +def test_context(executor): + """Test that the context manager will shutdown executor""" + with executor: + f = executor.submit(lambda: 42) + assert f.result() == 42 + + with pytest.raises(RuntimeError): + executor.submit(lambda: 42) + + +def test_default_pool_executor(): + """Test that using the global instance of QThreadPool works""" + with qasync.QThreadPoolExecutor() as executor: + f = executor.submit(lambda: 42) + assert f.result() == 42 + executor.pool.waitForDone()
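
For reference, here is a minimal usage sketch of the executor API added by this patch (illustrative only, not part of the diff). It assumes the patched qasync module is importable with a Qt binding installed, and relies only on submit(), map(), and the context-manager shutdown implemented above, plus the standard QThreadPool methods setMaxThreadCount() and waitForDone():

from qasync import QThreadPoolExecutor, QtCore

# Bound the worker count with a private pool instead of the global instance,
# mirroring the parametrized test fixture above.
pool = QtCore.QThreadPool()
pool.setMaxThreadCount(4)

with QThreadPoolExecutor(pool) as executor:
    # submit() returns a concurrent.futures.Future backed by a QRunnable.
    future = executor.submit(pow, 2, 10)
    assert future.result() == 1024

    # map() submits every task up front and yields results lazily, in order.
    squares = list(executor.map(lambda x: x * x, range(5)))
    assert squares == [0, 1, 4, 9, 16]

# Leaving the context called shutdown(); drain the pool before process exit.
pool.waitForDone()

Because submit() returns plain concurrent.futures.Future objects, the same executor instance should also be usable with loop.run_in_executor() on a qasync QEventLoop.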