diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 031071281b38f7..3e4fa03974fbd5 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,6 +16,7 @@ import collections import collections.abc import concurrent.futures +import contextvars import functools import heapq import itertools @@ -789,7 +790,8 @@ def call_soon_threadsafe(self, callback, *args, context=None): self._write_to_self() return handle - def run_in_executor(self, executor, func, *args): + def run_in_executor(self, executor, func, *args, context=None, + retain_context=False): self._check_closed() if self._debug: self._check_callback(func, 'run_in_executor') @@ -800,8 +802,23 @@ def run_in_executor(self, executor, func, *args): if executor is None: executor = concurrent.futures.ThreadPoolExecutor() self._default_executor = executor - return futures.wrap_future( - executor.submit(func, *args), loop=self) + + if args: + runner = functools.partial(func, *args) + else: + runner = func + + if retain_context: + if not isinstance(executor, concurrent.futures.ThreadPoolExecutor): + raise RuntimeError( + 'retain_context=True supports only ThreadPoolExecutor') + + if context is None: + context = contextvars.copy_context() + + runner = functools.partial(context.run, runner) + + return futures.wrap_future(executor.submit(runner), loop=self) def set_default_executor(self, executor): if not isinstance(executor, concurrent.futures.ThreadPoolExecutor): diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 7256758465da1e..f97708c8b7acf5 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -2,6 +2,7 @@ import collections.abc import concurrent.futures +import contextvars import functools import io import os @@ -34,6 +35,9 @@ from test import support from test.support import ALWAYS_EQ, LARGEST, SMALLEST +foo_ctx = contextvars.ContextVar('foo') +foo_ctx.set('bar') + def tearDownModule(): asyncio.set_event_loop_policy(None) @@ -367,6 +371,58 @@ def run(): time.sleep(0.4) self.assertFalse(called) + def test_run_in_executor_hierarchy(self): + def run(): + foo_ctx.set('foo') + res = foo_ctx.get() + self.assertEqual(res, 'foo') + return res + + f = self.loop.run_in_executor(None, run, retain_context=True) + res = self.loop.run_until_complete(f) + self.assertEqual(res, 'foo') + + res = foo_ctx.get() + self.assertEqual(res, 'bar') + + def test_run_in_executor_no_context(self): + def run(): + return foo_ctx.get() + + f = self.loop.run_in_executor(None, run, retain_context=True) + res = self.loop.run_until_complete(f) + self.assertEqual(res, 'bar') + + def test_run_in_executor_context(self): + def run(): + return foo_ctx.get() + + context = contextvars.copy_context() + f = self.loop.run_in_executor(None, run, context=context, + retain_context=True) + res = self.loop.run_until_complete(f) + self.assertEqual(res, 'bar') + + def test_run_in_executor_context_args(self): + def run(arg): + return (arg, foo_ctx.get()) + + context = contextvars.copy_context() + f = self.loop.run_in_executor(None, run, 'yo', context=context, + retain_context=True) + res = self.loop.run_until_complete(f) + self.assertEqual(res, ('yo', 'bar')) + + def test_run_in_executor_context_subprocess(self): + def run(arg): + pass + + pool = concurrent.futures.ProcessPoolExecutor() + context = contextvars.copy_context() + with self.assertRaises(RuntimeError): + self.loop.run_in_executor(pool, run, retain_context=True) + pool.shutdown() + def test_reader_callback(self): r, w = socket.socketpair() r.setblocking(False) diff --git a/Misc/NEWS.d/next/Library/2018-07-01-02-37-05.bpo-34014.RfrJGJ.rst b/Misc/NEWS.d/next/Library/2018-07-01-02-37-05.bpo-34014.RfrJGJ.rst new file mode 100644 index 00000000000000..28c0a9c504da4a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-07-01-02-37-05.bpo-34014.RfrJGJ.rst @@ -0,0 +1 @@ +Added support of contextvars for BaseEventLoop.run_in_executor