diff --git a/Include/internal/pycore_global_objects_fini_generated.h b/Include/internal/pycore_global_objects_fini_generated.h index 4b12ae523c3260..3c528f64e609e2 100644 --- a/Include/internal/pycore_global_objects_fini_generated.h +++ b/Include/internal/pycore_global_objects_fini_generated.h @@ -853,6 +853,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(copy)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(copyreg)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(coro)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(coro_result)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(count)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(cwd)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(d)); diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h index 17fb9ffbbf9f11..1eea159b49d710 100644 --- a/Include/internal/pycore_global_strings.h +++ b/Include/internal/pycore_global_strings.h @@ -339,6 +339,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(copy) STRUCT_FOR_ID(copyreg) STRUCT_FOR_ID(coro) + STRUCT_FOR_ID(coro_result) STRUCT_FOR_ID(count) STRUCT_FOR_ID(cwd) STRUCT_FOR_ID(d) diff --git a/Include/internal/pycore_runtime_init_generated.h b/Include/internal/pycore_runtime_init_generated.h index b240be57369d9d..a88e5c8eb58698 100644 --- a/Include/internal/pycore_runtime_init_generated.h +++ b/Include/internal/pycore_runtime_init_generated.h @@ -845,6 +845,7 @@ extern "C" { INIT_ID(copy), \ INIT_ID(copyreg), \ INIT_ID(coro), \ + INIT_ID(coro_result), \ INIT_ID(count), \ INIT_ID(cwd), \ INIT_ID(d), \ diff --git a/Include/internal/pycore_unicodeobject_generated.h b/Include/internal/pycore_unicodeobject_generated.h index fea9b6dbb1a75f..7ffeed7bb67f4b 100644 --- a/Include/internal/pycore_unicodeobject_generated.h +++ b/Include/internal/pycore_unicodeobject_generated.h @@ -870,6 +870,9 @@ _PyUnicode_InitStaticStrings(void) { string = &_Py_ID(coro); 
assert(_PyUnicode_CheckConsistency(string, 1)); PyUnicode_InternInPlace(&string); + string = &_Py_ID(coro_result); + assert(_PyUnicode_CheckConsistency(string, 1)); + PyUnicode_InternInPlace(&string); string = &_Py_ID(count); assert(_PyUnicode_CheckConsistency(string, 1)); PyUnicode_InternInPlace(&string); diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index fed16ec7c67fac..e0f24efde8f494 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -36,7 +36,40 @@ tasks.__all__ + threads.__all__ + timeouts.__all__ + - transports.__all__) + transports.__all__ + ( + 'create_eager_task_factory', + 'eager_task_factory', + )) + +# throwing things here temporarily to defer premature dir layout bikeshedding + +def create_eager_task_factory(custom_task_constructor): + + def factory(loop, coro, *, name=None, context=None): + loop._check_closed() + if not loop.is_running(): + return custom_task_constructor(coro, loop=loop, name=name, context=context) + + try: + result = coro.send(None) + except StopIteration as si: + fut = loop.create_future() + fut.set_result(si.value) + return fut + except Exception as ex: + fut = loop.create_future() + fut.set_exception(ex) + return fut + else: + task = custom_task_constructor( + coro, loop=loop, name=name, context=context, coro_result=result) + if task._source_traceback: + del task._source_traceback[-1] + return task + + return factory + +eager_task_factory = create_eager_task_factory(Task) if sys.platform == 'win32': # pragma: no cover from .windows_events import * diff --git a/Lib/asyncio/runners.py b/Lib/asyncio/runners.py index 1b89236599aad7..1d70f2c013befe 100644 --- a/Lib/asyncio/runners.py +++ b/Lib/asyncio/runners.py @@ -45,10 +45,11 @@ class Runner: # Note: the class is final, it is not intended for inheritance. 
- def __init__(self, *, debug=None, loop_factory=None): + def __init__(self, *, debug=None, loop_factory=None, task_factory=None): self._state = _State.CREATED self._debug = debug self._loop_factory = loop_factory + self._task_factory = task_factory self._loop = None self._context = None self._interrupt_count = 0 @@ -144,6 +145,8 @@ def _lazy_init(self): self._loop = self._loop_factory() if self._debug is not None: self._loop.set_debug(self._debug) + if self._task_factory is not None: + self._loop.set_task_factory(self._task_factory) self._context = contextvars.copy_context() self._state = _State.INITIALIZED @@ -157,7 +160,7 @@ def _on_sigint(self, signum, frame, main_task): raise KeyboardInterrupt() -def run(main, *, debug=None, loop_factory=None): +def run(main, *, debug=None, loop_factory=None, task_factory=None): """Execute the coroutine and return the result. This function runs the passed coroutine, taking care of @@ -190,7 +193,7 @@ async def main(): raise RuntimeError( "asyncio.run() cannot be called from a running event loop") - with Runner(debug=debug, loop_factory=loop_factory) as runner: + with Runner(debug=debug, loop_factory=loop_factory, task_factory=task_factory) as runner: return runner.run(main) diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 0fdea3697ece3d..70bc2c405fecaf 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -163,7 +163,8 @@ def create_task(self, coro, *, name=None, context=None): task = self._loop.create_task(coro) else: task = self._loop.create_task(coro, context=context) - tasks._set_task_name(task, name) + if name is not None and not task.done(): # If it's done already, it's a future + tasks._set_task_name(task, name) task.add_done_callback(self._on_task_done) self._tasks.add(task) return task diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 1c20754b839b69..5124af082f220e 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -76,6 +76,8 @@ def 
_set_task_name(task, name): set_name(name) +_NOT_SET = object() + class Task(futures._PyFuture): # Inherit Python Task implementation # from a Python Future implementation. @@ -94,7 +96,8 @@ class Task(futures._PyFuture): # Inherit Python Task implementation # status is still pending _log_destroy_pending = True - def __init__(self, coro, *, loop=None, name=None, context=None): + def __init__(self, coro, *, loop=None, name=None, context=None, + coro_result=_NOT_SET): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] @@ -118,7 +121,10 @@ def __init__(self, coro, *, loop=None, name=None, context=None): else: self._context = context - self._loop.call_soon(self.__step, context=self._context) + if coro_result is _NOT_SET: + self._loop.call_soon(self.__step, context=self._context) + else: + self.__step_handle_result(coro_result) _register_task(self) def __del__(self): @@ -288,55 +294,58 @@ def __step(self, exc=None): except BaseException as exc: super().set_exception(exc) else: - blocking = getattr(result, '_asyncio_future_blocking', None) - if blocking is not None: + self.__step_handle_result(result) + finally: + _leave_task(self._loop, self) + self = None # Needed to break cycles when an exception occurs. + + def __step_handle_result(self, result): + blocking = getattr(result, '_asyncio_future_blocking', None) + if blocking is not None: # Yielded Future must come from Future.__iter__(). 
- if futures._get_loop(result) is not self._loop: + if futures._get_loop(result) is not self._loop: + new_exc = RuntimeError( + f'Task {self!r} got Future ' + f'{result!r} attached to a different loop') + self._loop.call_soon( + self.__step, new_exc, context=self._context) + elif blocking: + if result is self: new_exc = RuntimeError( - f'Task {self!r} got Future ' - f'{result!r} attached to a different loop') + f'Task cannot await on itself: {self!r}') self._loop.call_soon( self.__step, new_exc, context=self._context) - elif blocking: - if result is self: - new_exc = RuntimeError( - f'Task cannot await on itself: {self!r}') - self._loop.call_soon( - self.__step, new_exc, context=self._context) - else: - result._asyncio_future_blocking = False - result.add_done_callback( - self.__wakeup, context=self._context) - self._fut_waiter = result - if self._must_cancel: - if self._fut_waiter.cancel( - msg=self._cancel_message): - self._must_cancel = False else: - new_exc = RuntimeError( - f'yield was used instead of yield from ' - f'in task {self!r} with {result!r}') - self._loop.call_soon( - self.__step, new_exc, context=self._context) - - elif result is None: - # Bare yield relinquishes control for one event loop iteration. - self._loop.call_soon(self.__step, context=self._context) - elif inspect.isgenerator(result): - # Yielding a generator is just wrong. - new_exc = RuntimeError( - f'yield was used instead of yield from for ' - f'generator in task {self!r} with {result!r}') - self._loop.call_soon( - self.__step, new_exc, context=self._context) + result._asyncio_future_blocking = False + result.add_done_callback( + self.__wakeup, context=self._context) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel( + msg=self._cancel_message): + self._must_cancel = False else: - # Yielding something else is an error. 
- new_exc = RuntimeError(f'Task got bad yield: {result!r}') + new_exc = RuntimeError( + f'yield was used instead of yield from ' + f'in task {self!r} with {result!r}') self._loop.call_soon( self.__step, new_exc, context=self._context) - finally: - _leave_task(self._loop, self) - self = None # Needed to break cycles when an exception occurs. + + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self.__step, context=self._context) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + new_exc = RuntimeError( + f'yield was used instead of yield from for ' + f'generator in task {self!r} with {result!r}') + self._loop.call_soon( + self.__step, new_exc, context=self._context) + else: + # Yielding something else is an error. + new_exc = RuntimeError(f'Task got bad yield: {result!r}') + self._loop.call_soon( + self.__step, new_exc, context=self._context) def __wakeup(self, future): try: diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py new file mode 100644 index 00000000000000..d62302f5558f05 --- /dev/null +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -0,0 +1,498 @@ +"""Tests for base_events.py""" + +import time +import unittest +from unittest import mock + +import asyncio +from asyncio import base_events +from asyncio import tasks +from test.test_asyncio import utils as test_utils +from test import support + +MOCK_ANY = mock.ANY + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class EagerTaskFactoryLoopTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.loop.set_task_factory(asyncio.eager_task_factory) + self.set_event_loop(self.loop) + + def test_eager_task_factory_set(self): + self.assertIs(self.loop.get_task_factory(), asyncio.eager_task_factory) + + def test_close(self): + self.assertFalse(self.loop.is_closed()) + self.loop.close() 
+ self.assertTrue(self.loop.is_closed()) + + # it should be possible to call close() more than once + self.loop.close() + self.loop.close() + + # operation blocked when the loop is closed + f = self.loop.create_future() + self.assertRaises(RuntimeError, self.loop.run_forever) + self.assertRaises(RuntimeError, self.loop.run_until_complete, f) + + def test__add_callback_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop, None) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_cancelled_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop, None) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, asyncio.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_soon_non_callable(self): + self.loop.set_debug(True) + with self.assertRaisesRegex(TypeError, 'a callable object'): + self.loop.call_soon(1) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, asyncio.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + with self.assertRaises(TypeError, msg="delay must not be None"): + self.loop.call_later(None, cb) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = mock.Mock() + delay = 0.1 + + when = self.loop.time() + delay + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + dt = 
self.loop.time() - t0 + + # 50 ms: maximum granularity of the event loop + self.assertGreaterEqual(dt, delay - 0.050, dt) + # tolerate a difference of +800 ms because some Python buildbots + # are really slow + self.assertLessEqual(dt, 0.9, dt) + with self.assertRaises(TypeError, msg="when cannot be None"): + self.loop.call_at(None, cb) + + def test_run_until_complete_loop(self): + task = self.loop.create_future() + other_loop = self.new_test_loop() + self.addCleanup(other_loop.close) + self.assertRaises(ValueError, + other_loop.run_until_complete, task) + + def test_run_until_complete_loop_orphan_future_close_loop(self): + class ShowStopper(SystemExit): + pass + + async def foo(delay): + await asyncio.sleep(delay) + + def throw(): + raise ShowStopper + + self.loop._process_events = mock.Mock() + self.loop.call_soon(throw) + with self.assertRaises(ShowStopper): + self.loop.run_until_complete(foo(0.1)) + + # This call fails if run_until_complete does not clean up + # done-callback for the previous future. 
+ self.loop.run_until_complete(foo(0.2)) + + def test_default_exc_handler_callback(self): + self.loop._process_events = mock.Mock() + + def zero_error(fut): + fut.set_result(True) + 1/0 + + # Test call_soon (events.Handle) + with mock.patch('asyncio.base_events.logger') as log: + fut = self.loop.create_future() + self.loop.call_soon(zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + # Test call_later (events.TimerHandle) + with mock.patch('asyncio.base_events.logger') as log: + fut = self.loop.create_future() + self.loop.call_later(0.01, zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_coro(self): + self.loop._process_events = mock.Mock() + + async def zero_error_coro(): + await asyncio.sleep(0.01) + 1/0 + + # Test Future.__del__ + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.ensure_future(zero_error_coro(), loop=self.loop) + fut.add_done_callback(lambda *args: self.loop.stop()) + self.loop.run_forever() + fut = None # Trigger Future.__del__ or futures._TracebackLogger + support.gc_collect() + # Future.__del__ in logs error with an actual exception context + log.error.assert_called_with( + test_utils.MockPattern('.*exception was never retrieved'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_set_exc_handler_invalid(self): + with self.assertRaisesRegex(TypeError, 'A callable object or None'): + self.loop.set_exception_handler('spam') + + def test_set_exc_handler_custom(self): + def zero_error(): + 1/0 + + def run_loop(): + handle = self.loop.call_soon(zero_error) + self.loop._run_once() + return handle + + 
self.loop.set_debug(True) + self.loop._process_events = mock.Mock() + + self.assertIsNone(self.loop.get_exception_handler()) + mock_handler = mock.Mock() + self.loop.set_exception_handler(mock_handler) + self.assertIs(self.loop.get_exception_handler(), mock_handler) + handle = run_loop() + mock_handler.assert_called_with(self.loop, { + 'exception': MOCK_ANY, + 'message': test_utils.MockPattern( + 'Exception in callback.*zero_error'), + 'handle': handle, + 'source_traceback': handle._source_traceback, + }) + mock_handler.reset_mock() + + self.loop.set_exception_handler(None) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + self.assertFalse(mock_handler.called) + + def test_set_exc_handler_broken(self): + def run_loop(): + def zero_error(): + 1/0 + self.loop.call_soon(zero_error) + self.loop._run_once() + + def handler(loop, context): + raise AttributeError('spam') + + self.loop._process_events = mock.Mock() + + self.loop.set_exception_handler(handler) + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Unhandled error in exception handler'), + exc_info=(AttributeError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_broken(self): + _context = None + + class Loop(base_events.BaseEventLoop): + + _selector = mock.Mock() + _process_events = mock.Mock() + + def default_exception_handler(self, context): + nonlocal _context + _context = context + # Simulates custom buggy "default_exception_handler" + raise ValueError('spam') + + loop = Loop() + self.addCleanup(loop.close) + asyncio.set_event_loop(loop) + + def run_loop(): + def zero_error(): + 1/0 + loop.call_soon(zero_error) + loop._run_once() + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + 'Exception in default 
exception handler', + exc_info=True) + + def custom_handler(loop, context): + raise ValueError('ham') + + _context = None + loop.set_exception_handler(custom_handler) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern('Exception in default exception.*' + 'while handling.*in custom'), + exc_info=True) + + # Check that original context was passed to default + # exception handler. + self.assertIn('context', _context) + self.assertIs(type(_context['context']['exception']), + ZeroDivisionError) + + def test_eager_task_factory_with_custom_task_ctor(self): + + class MyTask(asyncio.Task): + pass + + async def coro(): + pass + + factory = asyncio.create_eager_task_factory(MyTask) + + self.loop.set_task_factory(factory) + self.assertIs(self.loop.get_task_factory(), factory) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + + def test_create_named_task(self): + async def test(): + pass + + task = self.loop.create_task(test(), name='test_task') + try: + self.assertEqual(task.get_name(), 'test_task') + finally: + self.loop.run_until_complete(task) + + def test_run_forever_keyboard_interrupt(self): + # Python issue #22601: ensure that the temporary task created by + # run_forever() consumes the KeyboardInterrupt and so don't log + # a warning + async def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + self.loop.close() + support.gc_collect() + + self.assertFalse(self.loop.call_exception_handler.called) + + def test_run_until_complete_baseexception(self): + # Python issue #22429: run_until_complete() must not schedule a pending + # call to stop() if the future raised a BaseException + async def raise_keyboard_interrupt(): + raise 
KeyboardInterrupt + + self.loop._process_events = mock.Mock() + + with self.assertRaises(KeyboardInterrupt): + self.loop.run_until_complete(raise_keyboard_interrupt()) + + def func(): + self.loop.stop() + func.called = True + func.called = False + self.loop.call_soon(self.loop.call_soon, func) + self.loop.run_forever() + self.assertTrue(func.called) + + def test_run_once(self): + # Simple test for test_utils.run_once(). It may seem strange + # to have a test for this (the function isn't even used!) but + # it's a de-factor standard API for library tests. This tests + # the idiom: loop.call_soon(loop.stop); loop.run_forever(). + count = 0 + + def callback(): + nonlocal count + count += 1 + + self.loop._process_events = mock.Mock() + self.loop.call_soon(callback) + test_utils.run_once(self.loop) + self.assertEqual(count, 1) + + +class AsyncTaskCounter: + def __init__(self, loop, *, task_class, eager): + self.suspense_count = 0 + self.task_count = 0 + + def CountingTask(*args, **kwargs): + self.task_count += 1 + return task_class(*args, **kwargs) + + if eager: + factory = asyncio.create_eager_task_factory(CountingTask) + else: + def factory(loop, coro, **kwargs): + return CountingTask(coro, loop=loop, **kwargs) + loop.set_task_factory(factory) + + def get(self): + return self.task_count + + +async def awaitable_chain(depth): + if depth == 0: + return 0 + return 1 + await awaitable_chain(depth - 1) + + +async def recursive_taskgroups(width, depth): + if depth == 0: + return 0 + + async with asyncio.TaskGroup() as tg: + futures = [ + tg.create_task(recursive_taskgroups(width, depth - 1)) + for _ in range(width) + ] + return sum( + (1 if isinstance(fut, (asyncio.Task, tasks._CTask, tasks._PyTask)) else 0) + + fut.result() + for fut in futures + ) + + +async def recursive_gather(width, depth): + if depth == 0: + return + + await asyncio.gather( + *[recursive_gather(width, depth - 1) for _ in range(width)] + ) + + +class BaseTaskCountingTests: + + Task = None + eager = 
None + expected_task_count = None + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager) + self.set_event_loop(self.loop) + + def test_awaitables_chain(self): + observed_depth = self.loop.run_until_complete(awaitable_chain(100)) + self.assertEqual(observed_depth, 100) + self.assertEqual(self.counter.get(), 1) + + def test_recursive_taskgroups(self): + num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4)) + self.assertEqual(num_tasks, self.expected_task_count - 1) # 5 + 5^2 + 5^3 + 5^4 + self.assertEqual(self.counter.get(), self.expected_task_count) # 1 + ^^ + + def test_recursive_gather(self): + self.loop.run_until_complete(recursive_gather(5, 4)) + self.assertEqual(self.counter.get(), self.expected_task_count) # 1 + 5 + 5^2 + 5^3 + 5^4 + + +class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests): + eager = False + expected_task_count = 781 # 1 + 5 + 5^2 + 5^3 + 5^4 + + +class BaseEagerTaskFactoryTests(BaseTaskCountingTests): + eager = True + expected_task_count = 156 # 1 + 5 + 5^2 + 5^3 + + +class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = asyncio.Task + + +class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = asyncio.Task + + +class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = tasks._PyTask + + +class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = tasks._PyTask + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + + +if __name__ == '__main__': + unittest.main() diff --git 
a/Misc/NEWS.d/next/Library/2023-03-15-12-18-07.gh-issue-97696.DtnpIC.rst b/Misc/NEWS.d/next/Library/2023-03-15-12-18-07.gh-issue-97696.DtnpIC.rst new file mode 100644 index 00000000000000..eb1861b4e5aaaf --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-03-15-12-18-07.gh-issue-97696.DtnpIC.rst @@ -0,0 +1,4 @@ +Implemented an eager task factory in asyncio. When set as a task factory on +an event loop, it performs eager execution of coroutines and returns a +completed future instead of scheduling a task to the event loop if the +coroutine can complete without suspending. diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 13d98eedf32f0e..edb2016be98acf 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -156,6 +156,9 @@ class _asyncio.Future "FutureObj *" "&Future_Type" /* Get FutureIter from Future */ static PyObject * future_new_iter(PyObject *); +static PyObject * +task_step_handle_result_impl(asyncio_state *state, TaskObj *task, PyObject *result); + static int _is_coroutine(asyncio_state *state, PyObject *coro) @@ -2032,15 +2035,16 @@ _asyncio.Task.__init__ loop: object = None name: object = None context: object = None + coro_result: object = NULL A coroutine wrapped in a Future. [clinic start generated code]*/ static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name, PyObject *context) -/*[clinic end generated code: output=49ac96fe33d0e5c7 input=924522490c8ce825]*/ - + PyObject *name, PyObject *context, + PyObject *coro_result) +/*[clinic end generated code: output=e241855787412a77 input=3fcd7fb1c00d3f87]*/ { if (future_init((FutureObj*)self, loop)) { return -1; @@ -2088,8 +2092,18 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, return -1; } - if (task_call_step_soon(state, self, NULL)) { - return -1; + if (coro_result == NULL) { + if (task_call_step_soon(state, self, NULL)) { + return -1; + } + } + else { + // TODO this is a sketchy incref... 
+ // Py_INCREF(coro_result); + // TODO: check return value, error on NULL + // (but first let's add a test case that hits this) + PyObject * res = task_step_handle_result_impl(state, self, coro_result); + assert(res != NULL); } return register_task(state, (PyObject*)self); } @@ -2827,6 +2841,24 @@ task_step_impl(asyncio_state *state, TaskObj *task, PyObject *exc) Py_RETURN_NONE; } + PyObject *ret = task_step_handle_result_impl(state, task, result); + Py_XDECREF(result); + return ret; + +fail: + Py_XDECREF(result); + return NULL; +} + + +static PyObject * +task_step_handle_result_impl(asyncio_state *state, TaskObj *task, PyObject *result) +{ + int res; + PyObject *o; + + // Py_INCREF(result); + if (result == (PyObject*)task) { /* We have a task that wants to await on itself */ goto self_await; @@ -2863,7 +2895,8 @@ task_step_impl(asyncio_state *state, TaskObj *task, PyObject *exc) Py_DECREF(tmp); /* task._fut_waiter = result */ - task->task_fut_waiter = result; /* no incref is necessary */ + Py_INCREF(result); + task->task_fut_waiter = result; if (task->task_must_cancel) { PyObject *r; @@ -2956,7 +2989,8 @@ task_step_impl(asyncio_state *state, TaskObj *task, PyObject *exc) Py_DECREF(tmp); /* task._fut_waiter = result */ - task->task_fut_waiter = result; /* no incref is necessary */ + Py_INCREF(result); + task->task_fut_waiter = result; if (task->task_must_cancel) { PyObject *r; @@ -2991,21 +3025,21 @@ task_step_impl(asyncio_state *state, TaskObj *task, PyObject *exc) state, task, PyExc_RuntimeError, "yield was used instead of yield from for " "generator in task %R with %R", task, result); - Py_DECREF(result); + // Py_DECREF(result); return o; } /* The `result` is none of the above */ o = task_set_error_soon( state, task, PyExc_RuntimeError, "Task got bad yield: %R", result); - Py_DECREF(result); + // Py_DECREF(result); return o; self_await: o = task_set_error_soon( state, task, PyExc_RuntimeError, "Task cannot await on itself: %R", task); - Py_DECREF(result); + // 
Py_DECREF(result); return o; yield_insteadof_yf: @@ -3014,7 +3048,7 @@ task_step_impl(asyncio_state *state, TaskObj *task, PyObject *exc) "yield was used instead of yield from " "in task %R with %R", task, result); - Py_DECREF(result); + // Py_DECREF(result); return o; different_loop: @@ -3022,11 +3056,11 @@ task_step_impl(asyncio_state *state, TaskObj *task, PyObject *exc) state, task, PyExc_RuntimeError, "Task %R got Future %R attached to a different loop", task, result); - Py_DECREF(result); + // Py_DECREF(result); return o; fail: - Py_XDECREF(result); + // Py_XDECREF(result); return NULL; } diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h index 43c5d771798634..47a678b50784b1 100644 --- a/Modules/clinic/_asynciomodule.c.h +++ b/Modules/clinic/_asynciomodule.c.h @@ -482,14 +482,16 @@ _asyncio_Future__make_cancelled_error(FutureObj *self, PyObject *Py_UNUSED(ignor } PyDoc_STRVAR(_asyncio_Task___init____doc__, -"Task(coro, *, loop=None, name=None, context=None)\n" +"Task(coro, *, loop=None, name=None, context=None,\n" +" coro_result=)\n" "--\n" "\n" "A coroutine wrapped in a Future."); static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name, PyObject *context); + PyObject *name, PyObject *context, + PyObject *coro_result); static int _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) @@ -497,14 +499,14 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) int return_value = -1; #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) - #define NUM_KEYWORDS 4 + #define NUM_KEYWORDS 5 static struct { PyGC_Head _this_is_not_used; PyObject_VAR_HEAD PyObject *ob_item[NUM_KEYWORDS]; } _kwtuple = { .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) - .ob_item = { &_Py_ID(coro), &_Py_ID(loop), &_Py_ID(name), &_Py_ID(context), }, + .ob_item = { &_Py_ID(coro), &_Py_ID(loop), &_Py_ID(name), &_Py_ID(context), &_Py_ID(coro_result), }, }; 
#undef NUM_KEYWORDS #define KWTUPLE (&_kwtuple.ob_base.ob_base) @@ -513,14 +515,14 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) # define KWTUPLE NULL #endif // !Py_BUILD_CORE - static const char * const _keywords[] = {"coro", "loop", "name", "context", NULL}; + static const char * const _keywords[] = {"coro", "loop", "name", "context", "coro_result", NULL}; static _PyArg_Parser _parser = { .keywords = _keywords, .fname = "Task", .kwtuple = KWTUPLE, }; #undef KWTUPLE - PyObject *argsbuf[4]; + PyObject *argsbuf[5]; PyObject * const *fastargs; Py_ssize_t nargs = PyTuple_GET_SIZE(args); Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 1; @@ -528,6 +530,7 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) PyObject *loop = Py_None; PyObject *name = Py_None; PyObject *context = Py_None; + PyObject *coro_result = NULL; fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 1, 0, argsbuf); if (!fastargs) { @@ -549,9 +552,15 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) goto skip_optional_kwonly; } } - context = fastargs[3]; + if (fastargs[3]) { + context = fastargs[3]; + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + coro_result = fastargs[4]; skip_optional_kwonly: - return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name, context); + return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name, context, coro_result); exit: return return_value; @@ -1302,4 +1311,4 @@ _asyncio_current_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs, exit: return return_value; } -/*[clinic end generated code: output=00f494214f2fd008 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=d7cd98454c53b85a input=a9049054013a1b77]*/ diff --git a/async_tree.py b/async_tree.py new file mode 100644 index 00000000000000..0da16e049b9eaf --- /dev/null +++ b/async_tree.py @@ -0,0 +1,233 
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
"""
Benchmark script for recursive async tree workloads. This script includes the
following microbenchmark scenarios:

1) "no_suspension": No suspension in the async tree.
2) "suspense_all": Suspension (simulating IO) at all leaf nodes in the async tree.
3) "memoization": Simulated IO calls at all leaf nodes, but with memoization. Only
   un-memoized IO calls will result in suspensions.
4) "cpu_io_mixed": A mix of CPU-bound workload and IO-bound workload (with
   memoization) at the leaf nodes.

Use the commandline flag or choose the corresponding AsyncTree class
to run the desired microbenchmark scenario.
"""


import asyncio
import math
import random
import time
from argparse import ArgumentParser


# Tree shape: NUM_RECURSE_BRANCHES ** NUM_RECURSE_LEVELS leaf nodes total.
NUM_RECURSE_LEVELS = 6
NUM_RECURSE_BRANCHES = 6
IO_SLEEP_TIME = 0.05
DEFAULT_MEMOIZABLE_PERCENTAGE = 90
DEFAULT_CPU_PROBABILITY = 0.5
FACTORIAL_N = 500


def parse_args():
    """Parse and return the command-line arguments for the benchmark."""
    parser = ArgumentParser(
        description="""\
Benchmark script for recursive async tree workloads. It can be run as a standalone
script, in which case you can specify the microbenchmark scenario to run and whether
to print the results.
"""
    )
    parser.add_argument(
        "-s",
        "--scenario",
        choices=["no_suspension", "suspense_all", "memoization", "cpu_io_mixed"],
        default="no_suspension",
        help="""\
Determines which microbenchmark scenario to run. Defaults to no_suspension. Options:
1) "no_suspension": No suspension in the async tree.
2) "suspense_all": Suspension (simulating IO) at all leaf nodes in the async tree.
3) "memoization": Simulated IO calls at all leaf nodes, but with memoization. Only
   un-memoized IO calls will result in suspensions.
4) "cpu_io_mixed": A mix of CPU-bound workload and IO-bound workload (with
   memoization) at the leaf nodes.
""",
    )
    parser.add_argument(
        "-m",
        "--memoizable-percentage",
        type=int,
        default=DEFAULT_MEMOIZABLE_PERCENTAGE,
        help="""\
Sets the percentage (0-100) of the data that should be memoized, defaults to 90. For
example, at the default 90 percent, data 1-90 will be memoized and data 91-100 will not.
""",
    )
    parser.add_argument(
        "-c",
        "--cpu-probability",
        type=float,
        default=DEFAULT_CPU_PROBABILITY,
        help="""\
Sets the probability (0-1) that a leaf node will execute a cpu-bound workload instead
of an io-bound workload. Defaults to 0.5. Only applies to the "cpu_io_mixed"
microbenchmark scenario.
""",
    )
    parser.add_argument(
        "-p",
        "--print",
        action="store_true",
        default=False,
        help="Print the results (runtime and number of Tasks created).",
    )
    parser.add_argument(
        "-g",
        "--gather",
        action="store_true",
        default=False,
        help="Use gather (if not specified, use TaskGroup if available, otherwise use gather).",
    )
    parser.add_argument(
        "-e",
        "--eager",
        action="store_true",
        default=False,
        help="Use the eager task factory.",
    )
    return parser.parse_args()


class AsyncTree:
    """Base class driving a recursive fan-out of async work.

    Subclasses implement ``suspense_func`` to define what happens at each
    leaf node. Counters:
      - suspense_count: number of simulated IO calls that actually awaited.
      - task_count: number of Tasks that were genuinely scheduled to the loop.
    """

    def __init__(
        self,
        memoizable_percentage=DEFAULT_MEMOIZABLE_PERCENTAGE,
        cpu_probability=DEFAULT_CPU_PROBABILITY,
        use_gather=None,
        use_eager_factory=None,
    ):
        self.suspense_count = 0
        self.task_count = 0
        self.memoizable_percentage = memoizable_percentage
        self.cpu_probability = cpu_probability
        # Fall back to gather() on Pythons without asyncio.TaskGroup.
        has_taskgroups = hasattr(asyncio, "TaskGroup")
        self.use_gather = use_gather or (not has_taskgroups)
        # The eager factory is only usable when this asyncio build provides it.
        has_eager_factory = hasattr(asyncio, "create_eager_task_factory")
        self.use_eager_factory = use_eager_factory and has_eager_factory
        self.cache = {}
        # set to deterministic random, so that the results are reproducible
        random.seed(0)

    async def mock_io_call(self):
        """Simulate one IO call by sleeping; counts as a suspension."""
        self.suspense_count += 1
        await asyncio.sleep(IO_SLEEP_TIME)

    async def suspense_func(self):
        """Leaf-node workload; overridden by each scenario subclass."""
        raise NotImplementedError(
            "To be implemented by each microbenchmark's derived class."
        )

    async def recurse(self, recurse_level):
        """Fan out NUM_RECURSE_BRANCHES children until reaching the leaves."""
        if recurse_level == 0:
            await self.suspense_func()
            return

        if self.use_gather:
            await asyncio.gather(
                *[self.recurse(recurse_level - 1) for _ in range(NUM_RECURSE_BRANCHES)]
            )
        else:
            async with asyncio.TaskGroup() as tg:
                for _ in range(NUM_RECURSE_BRANCHES):
                    tg.create_task(self.recurse(recurse_level - 1))

    async def run_benchmark(self):
        """Run the full tree from the root."""
        await self.recurse(NUM_RECURSE_LEVELS)

    def run(self):
        """Execute the benchmark under asyncio.run with a counting task factory.

        NOTE(review): passing ``task_factory=`` to asyncio.run relies on the
        accompanying runners.py change in this patch — confirm it is applied.
        """

        # Sentinel so we can tell "coro_result was not passed" apart from a
        # legitimate eager result of None.
        _NOT_SET = object()

        def counting_task_constructor(coro, *, loop=None, name=None, context=None, coro_result=_NOT_SET):
            if coro_result is _NOT_SET:
                # Only count constructions that actually schedule a Task on the
                # event loop; when coro_result IS set (even to None), the eager
                # factory already ran the coroutine to completion synchronously.
                self.task_count += 1
                return asyncio.Task(coro, loop=loop, name=name, context=context)
            return asyncio.Task(coro, loop=loop, name=name, context=context, coro_result=coro_result)

        def counting_task_factory(loop, coro, *, name=None, context=None, coro_result=_NOT_SET):
            # The constructor's own sentinel default makes the branching there
            # authoritative; just forward everything.
            return counting_task_constructor(
                coro, loop=loop, name=name, context=context, coro_result=coro_result)

        asyncio.run(
            self.run_benchmark(),
            task_factory=(
                asyncio.create_eager_task_factory(counting_task_constructor)
                if self.use_eager_factory else counting_task_factory
            ),
        )


class NoSuspensionAsyncTree(AsyncTree):
    """Scenario 1: leaves return immediately; nothing ever suspends."""

    async def suspense_func(self):
        return


class SuspenseAllAsyncTree(AsyncTree):
    """Scenario 2: every leaf performs one simulated IO suspension."""

    async def suspense_func(self):
        await self.mock_io_call()


class MemoizationAsyncTree(AsyncTree):
    """Scenario 3: leaves do simulated IO, memoized per data value."""

    async def suspense_func(self):
        # deterministic random (seed preset in AsyncTree.__init__)
        data = random.randint(1, 100)

        # Values at or below the memoizable percentage hit the cache once
        # warm, skipping the IO suspension entirely.
        if data <= self.memoizable_percentage:
            if self.cache.get(data):
                return data

            self.cache[data] = True

        await self.mock_io_call()
        return data


class CpuIoMixedAsyncTree(MemoizationAsyncTree):
    """Scenario 4: leaves randomly do CPU work or memoized simulated IO."""

    async def suspense_func(self):
        if random.random() < self.cpu_probability:
            # mock cpu-bound call
            return math.factorial(FACTORIAL_N)
        else:
            return await MemoizationAsyncTree.suspense_func(self)


if __name__ == "__main__":
    args = parse_args()
    scenario = args.scenario

    trees = {
        "no_suspension": NoSuspensionAsyncTree,
        "suspense_all": SuspenseAllAsyncTree,
        "memoization": MemoizationAsyncTree,
        "cpu_io_mixed": CpuIoMixedAsyncTree,
    }
    async_tree_class = trees[scenario]
    async_tree = async_tree_class(
        args.memoizable_percentage, args.cpu_probability, args.gather, args.eager)

    start_time = time.perf_counter()
    async_tree.run()
    end_time = time.perf_counter()

    if args.print:
        eager_or_tg = "gather" if async_tree.use_gather else "TaskGroup"
        task_factory = "eager" if async_tree.use_eager_factory else "standard"
        print(f"Scenario: {scenario}")
        print(f"Method: {eager_or_tg}")
        print(f"Task factory: {task_factory}")
        print(f"Time: {end_time - start_time} s")
        print(f"Tasks created: {async_tree.task_count}")
        print(f"Suspense called: {async_tree.suspense_count}")