Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 60 additions & 27 deletions src/qasync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import time
from concurrent.futures import Future
from queue import Queue
from threading import Lock
from typing import TYPE_CHECKING, Literal, Tuple, cast, get_args

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -172,51 +173,83 @@ def __init__(self, max_workers=10, stack_size=None):
self.__workers = [
_QThreadWorker(self.__queue, i + 1, stack_size) for i in range(max_workers)
]
self.__shutdown_lock = Lock()
self.__been_shutdown = False

for w in self.__workers:
w.start()

def submit(self, callback, *args, **kwargs):
if self.__been_shutdown:
raise RuntimeError("QThreadExecutor has been shutdown")
with self.__shutdown_lock:
if self.__been_shutdown:
raise RuntimeError("QThreadExecutor has been shutdown")

future = Future()
self._logger.debug(
"Submitting callback %s with args %s and kwargs %s to thread worker queue",
callback,
args,
kwargs,
)
self.__queue.put((future, callback, args, kwargs))
return future
future = Future()
self._logger.debug(
"Submitting callback %s with args %s and kwargs %s to thread worker queue",
callback,
args,
kwargs,
)
self.__queue.put((future, callback, args, kwargs))
return future

def map(self, func, *iterables, timeout=None):
raise NotImplementedError("use as_completed on the event loop")

def shutdown(self, wait=True):
if self.__been_shutdown:
raise RuntimeError("QThreadExecutor has been shutdown")

self.__been_shutdown = True
deadline = time.monotonic() + timeout if timeout is not None else None
futures = [self.submit(func, *args) for args in zip(*iterables)]

self._logger.debug("Shutting down")
for i in range(len(self.__workers)):
# Signal workers to stop
self.__queue.put(None)
if wait:
for w in self.__workers:
w.wait()
# must have generator as a closure so that the submit occurs before first iteration
def generator():
try:
futures.reverse()
while futures:
if deadline is not None:
yield _result_or_cancel(
futures.pop(), timeout=deadline - time.monotonic()
)
else:
yield _result_or_cancel(futures.pop())
finally:
for future in futures:
future.cancel()

return generator()

def shutdown(self, wait=True, *, cancel_futures=False):
with self.__shutdown_lock:
self.__been_shutdown = True
self._logger.debug("Shutting down")
if cancel_futures:
# pop all the futures and cancel them
while not self.__queue.empty():
item = self.__queue.get_nowait()
if item is not None:
future, _, _, _ = item
future.cancel()
for i in range(len(self.__workers)):
# Signal workers to stop
self.__queue.put(None)
if wait:
for w in self.__workers:
w.wait()

def __enter__(self, *args):
if self.__been_shutdown:
raise RuntimeError("QThreadExecutor has been shutdown")
return self

def __exit__(self, *args):
self.shutdown()


def _result_or_cancel(fut, timeout=None):
try:
try:
return fut.result(timeout)
finally:
fut.cancel()
finally:
del fut # break reference cycle in exceptions


def _format_handle(handle: asyncio.Handle):
cb = getattr(handle, "_callback", None)
if isinstance(getattr(cb, "__self__", None), asyncio.tasks.Task):
Expand Down
167 changes: 161 additions & 6 deletions tests/test_qthreadexec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import logging
import threading
import weakref
from concurrent.futures import Future, TimeoutError
from unittest.mock import Mock, patch

import pytest

Expand Down Expand Up @@ -44,15 +46,28 @@ def shutdown_executor():
return exe


def test_shutdown_after_shutdown(shutdown_executor):
with pytest.raises(RuntimeError):
shutdown_executor.shutdown()
@pytest.fixture
def executor0():
"""
Provides a QThreadExecutor with max_workers=0 for deterministic testing.
"""
executor = qasync.QThreadExecutor(max_workers=0)
try:
yield executor
finally:
executor.shutdown(wait=True, cancel_futures=False)


@pytest.mark.parametrize("wait", [True, False])
def test_shutdown_after_shutdown(shutdown_executor, wait):
# it is safe to shutdown twice
shutdown_executor.shutdown(wait=wait)


def test_ctx_after_shutdown(shutdown_executor):
with pytest.raises(RuntimeError):
with shutdown_executor:
pass
# it is safe to enter and exit the context after shutdown
with shutdown_executor:
pass


def test_submit_after_shutdown(shutdown_executor):
Expand Down Expand Up @@ -104,3 +119,143 @@ def test_no_stale_reference_as_result(executor, disable_executor_logging):
assert collected is True, (
"Stale reference to executor result not collected within timeout."
)


def test_context(executor):
"""Test that the context manager will shutdown executor"""
with executor:
f = executor.submit(lambda: 42)
assert f.result() == 42

# it can be entered again
with executor:
# but will fail when we submit
with pytest.raises(RuntimeError):
executor.submit(lambda: 42)


@pytest.mark.parametrize("cancel", [True, False])
def test_shutdown_cancel_futures(executor0, cancel):
"""Test that shutdown with cancel_futures=True cancels all remaining futures in the queue."""

futures = [executor0.submit(lambda: None) for _ in range(10)]

# Shutdown with cancel_futures parameter
executor0.shutdown(wait=False, cancel_futures=cancel)

if cancel:
# All futures should be cancelled since no workers consumed them
cancelled_count = sum(1 for f in futures if f.cancelled())
assert cancelled_count == 10, (
f"Expected all 10 futures to be cancelled, got {cancelled_count}"
)
else:
# No futures should be cancelled, they should still be pending
cancelled_count = sum(1 for f in futures if f.cancelled())
assert cancelled_count == 0, (
f"Expected no futures to be cancelled, got {cancelled_count}"
)


def test_map(executor):
"""Basic test of executor map functionality"""
results = list(executor.map(lambda x: x + 1, range(10)))
assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

results = list(executor.map(lambda x, y: x + y, range(10), range(9)))
assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16]


def test_map_timeout(executor0):
"""Test that map with timeout propagates the timeout parameter to future.result()"""

f = Mock(spec=Future)
f.result = Mock(side_effect=TimeoutError("Timeout"))
f.cancel = Mock(return_value=True)

with patch.object(executor0, "submit", return_value=f):
with pytest.raises(TimeoutError, match="Timeout"):
list(executor0.map(lambda x: x, [1], timeout=0.5))

# Verify the timeout parameter was passed to result() (not None)
# Note: The timeout is calculated as (deadline - time.monotonic()), so it will be
# slightly less than 0.5 due to the time taken to submit futures and start iteration
assert f.result.called
f_timeout = f.result.call_args[0][0] if f.result.call_args[0] else None
assert f_timeout is not None
assert f_timeout <= 0.5


def test_map_error(executor0):
"""Test that map with an exception will raise, and remaining tasks are cancelled"""

# Create 3 futures: one success, one exception, one to be cancelled
mock_futures = []

# First future succeeds
f0 = Mock(spec=Future)
f0.result = Mock(return_value=0)
f0.cancel = Mock(return_value=True)
mock_futures.append(f0)

# Second future raises an exception
f1 = Future()
f1.set_exception(ValueError("Test error"))
mock_futures.append(f1)

# Third future should be cancelled
f2 = Mock(spec=Future)
f2.result = Mock(return_value=2)
f2.cancel = Mock(return_value=True)
mock_futures.append(f2)

with patch.object(executor0, "submit", side_effect=mock_futures):
with pytest.raises(ValueError, match="Test error"):
list(executor0.map(lambda x: x, range(3)))

# Verify the third future was cancelled when the exception occurred
assert f2.cancel.called, "Future after exception should have been cancelled"


def test_map_start(executor0):
"""Test that map starts tasks immediately, before iterating"""

# Mock future that returns immediately
mock_future = Mock(spec=Future)
mock_future.result = Mock(return_value=0)
mock_future.cancel = Mock(return_value=True)

with patch.object(executor0, "submit", return_value=mock_future) as mock_submit:
# Create the map - submit should be called immediately
m = executor0.map(lambda x: x, range(1))

# Verify submit was called before we start iterating
mock_submit.assert_called_once()

# Now iterate to verify the result
assert list(m) == [0]


def test_map_close(executor0):
"""Test that closing a running map cancels all remaining tasks."""

# Create mock futures with proper result() method
mock_futures = []
for i in range(10):
mock_future = Mock(spec=Future)
mock_future.cancel = Mock(return_value=True)
mock_future.result = Mock(return_value=i)
mock_futures.append(mock_future)

# Mock submit to return our pre-created futures
with patch.object(executor0, "submit", side_effect=mock_futures):
m = executor0.map(lambda x: x, range(10))
# must start the generator so that close() has any effect
assert next(m) == 0
m.close()

# All futures should have cancel() called:
# - The first one via _result_or_cancel after next() consumed it
# - The rest via the finally block when the generator is closed
for i, f in enumerate(mock_futures):
assert f.cancel.called, f"Future {i} should have been cancelled"
Loading