diff --git a/asyncio/base_events.py b/asyncio/base_events.py index e5feb998..c519fda4 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -339,9 +339,11 @@ def run_forever(self): self._check_closed() if self.is_running(): raise RuntimeError('Event loop is running.') + policy = events.get_event_loop_policy() self._set_coroutine_wrapper(self._debug) self._thread_id = threading.get_ident() try: + policy.set_running_loop(self) while True: self._run_once() if self._stopping: @@ -349,6 +351,7 @@ def run_forever(self): finally: self._stopping = False self._thread_id = None + policy.set_running_loop(None) self._set_coroutine_wrapper(False) def run_until_complete(self, future): diff --git a/asyncio/events.py b/asyncio/events.py index c48c5bed..b36738bc 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -514,21 +514,44 @@ class AbstractEventLoopPolicy: def get_event_loop(self): """Get the event loop for the current context. - Returns an event loop object implementing the BaseEventLoop interface, - or raises an exception in case no event loop has been set for the - current context and the current policy does not specify to create one. + Returns an event loop object implementing the BaseEventLoop interface: + - the running loop if it has been set (using set_running_loop) + - the loop for the current context otherwise. - It should never return None.""" + It may also raise an exception in case no event loop has been set for + the current context and the current policy does not specify to create + one. It should never return None. + """ raise NotImplementedError def set_event_loop(self, loop): - """Set the event loop for the current context to loop.""" + """Set the event loop for the current context.""" + raise NotImplementedError + + def get_running_loop(self): + """Get the running event loop running for the current context, if any. + + Returns an event loop object implementing the BaseEventLoop interface. + If no running loop is set, it returns None. + """ + raise NotImplementedError + + def set_running_loop(self, loop): + """Set the running event loop for the current context. + + The loop argument can be None to clear the former running loop. + This method should be called by the event loop itself to set the + running loop when it starts, and clear it when it's done. + """ raise NotImplementedError def new_event_loop(self): - """Create and return a new event loop object according to this - policy's rules. If there's need to set this loop as the event loop for - the current context, set_event_loop must be called explicitly.""" + """Create and return a new event loop object. + + The loop is created according to the policy's rules. + If there is need to set this loop as the event loop for the + current context, set_event_loop must be called explicitly. + """ raise NotImplementedError # Child processes handling (Unix only). @@ -559,16 +582,22 @@ class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): class _Local(threading.local): _loop = None + _running_loop = None _set_called = False def __init__(self): self._local = self._Local() def get_event_loop(self): - """Get the event loop. + """Get the event loop for the current context. - This may be None or an instance of EventLoop. + Returns an event loop object implementing the BaseEventLoop interface: + - the running loop if it has been set (using set_running_loop) + - the loop for the current thread otherwise. """ + running_loop = self.get_running_loop() + if running_loop is not None: + return running_loop if (self._local._loop is None and not self._local._set_called and isinstance(threading.current_thread(), threading._MainThread)): @@ -579,11 +608,26 @@ def get_event_loop(self): return self._local._loop def set_event_loop(self, loop): - """Set the event loop.""" + """Set the event loop for the current thread.""" self._local._set_called = True assert loop is None or isinstance(loop, AbstractEventLoop) self._local._loop = loop + def get_running_loop(self): + """Get the running event loop for the current thread if any. + + This may be None or an instance of EventLoop. + """ + return self._local._running_loop + + def set_running_loop(self, loop): + """Set the running event loop for the current thread.""" + assert loop is None or isinstance(loop, AbstractEventLoop) + running_loop = self._local._running_loop + if running_loop is not None and loop is not None: + raise RuntimeError('A loop is already running') + self._local._running_loop = loop + def new_event_loop(self): """Create a new event loop. diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 396e6aed..e91d8bae 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -403,9 +403,25 @@ def get_function_source(func): class TestCase(unittest.TestCase): + def disable_get_event_loop(self): + policy = events.get_event_loop_policy() + if hasattr(policy, '_patched_get_event_loop'): + return + + def reset_event_loop_method(): + policy.get_running_loop = old_get_running_loop + del policy._patched_get_event_loop + + old_get_running_loop = policy.get_running_loop + policy.get_running_loop = lambda: None + policy._patched_get_event_loop = True + + self.addCleanup(reset_event_loop_method) + def set_event_loop(self, loop, *, cleanup=True): assert loop is not None # ensure that the event loop is passed explicitly in asyncio + self.disable_get_event_loop() events.set_event_loop(None) if cleanup: self.addCleanup(loop.close) diff --git a/tests/test_events.py b/tests/test_events.py index d52213ce..52525424 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -2467,6 +2467,9 @@ def test_event_loop_policy(self): self.assertRaises(NotImplementedError, policy.get_event_loop) self.assertRaises(NotImplementedError, policy.set_event_loop, object()) self.assertRaises(NotImplementedError, policy.new_event_loop) + self.assertRaises(NotImplementedError, policy.get_running_loop) + self.assertRaises(NotImplementedError, policy.set_running_loop, + object()) self.assertRaises(NotImplementedError, policy.get_child_watcher) self.assertRaises(NotImplementedError, policy.set_child_watcher, object()) @@ -2534,6 +2537,54 @@ def test_set_event_loop(self): loop.close() old_loop.close() + def test_set_running_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + self.assertIsNone(policy._local._running_loop) + self.assertIsNone(policy.get_running_loop()) + + self.assertRaises(AssertionError, policy.set_running_loop, object()) + + loop = policy.new_event_loop() + policy.set_running_loop(loop) + + self.assertIs(policy._local._running_loop, loop) + self.assertIs(policy.get_running_loop(), loop) + loop.close() + + loop2 = policy.new_event_loop() + self.assertRaises(RuntimeError, policy.set_running_loop, loop2) + loop2.close() + + policy.set_running_loop(None) + self.assertIsNone(policy._local._running_loop) + + def test_get_event_loop_after_set_running_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + running_loop = policy.new_event_loop() + policy.set_running_loop(running_loop) + + self.assertIsNone(policy._local._loop) + self.assertIs(policy.get_event_loop(), running_loop) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + + self.assertIs(policy._local._loop, loop) + self.assertIs(policy.get_event_loop(), running_loop) + + policy.set_running_loop(None) + running_loop.close() + + self.assertIs(policy._local._loop, loop) + self.assertIs(policy.get_event_loop(), loop) + + policy.set_event_loop(None) + loop.close() + + self.assertIsNone(policy._local._loop) + self.assertRaises(RuntimeError, policy.get_event_loop) + def test_get_event_loop_policy(self): policy = asyncio.get_event_loop_policy() self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy)