Skip to content
This repository has been archived by the owner on Nov 23, 2017. It is now read-only.

Make get_event_loop() return the current loop if called from coroutines/callbacks #452

Merged
merged 1 commit into from
Nov 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,21 +393,26 @@ def run_forever(self):
"""Run until stop() is called."""
self._check_closed()
if self.is_running():
raise RuntimeError('Event loop is running.')
raise RuntimeError('This event loop is already running')
if events._get_running_loop() is not None:
raise RuntimeError(
'Cannot run the event loop while another loop is running')
self._set_coroutine_wrapper(self._debug)
self._thread_id = threading.get_ident()
if self._asyncgens is not None:
old_agen_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
finalizer=self._asyncgen_finalizer_hook)
try:
events._set_running_loop(self)
while True:
self._run_once()
if self._stopping:
break
finally:
self._stopping = False
self._thread_id = None
events._set_running_loop(None)
self._set_coroutine_wrapper(False)
if self._asyncgens is not None:
sys.set_asyncgen_hooks(*old_agen_hooks)
Expand Down
36 changes: 35 additions & 1 deletion asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,30 @@ def new_event_loop(self):
_lock = threading.Lock()


# A TLS for the running event loop, used by _get_running_loop.
class _RunningLoop(threading.local):
_loop = None
_running_loop = _RunningLoop()


def _get_running_loop():
"""Return the running event loop or None.

This is a low-level function intended to be used by event loops.
This function is thread-specific.
"""
return _running_loop._loop


def _set_running_loop(loop):
"""Set the running event loop.

This is a low-level function intended to be used by event loops.
This function is thread-specific.
"""
_running_loop._loop = loop


def _init_event_loop_policy():
global _event_loop_policy
with _lock:
Expand All @@ -632,7 +656,17 @@ def set_event_loop_policy(policy):


def get_event_loop():
"""Equivalent to calling get_event_loop_policy().get_event_loop()."""
"""Return an asyncio event loop.

When called from a coroutine or a callback (e.g. scheduled with call_soon
or similar API), this function will always return the running event loop.

If there is no running event loop set, the function will return
the result of `get_event_loop_policy().get_event_loop()` call.
"""
current_loop = _get_running_loop()
if current_loop is not None:
return current_loop
return get_event_loop_policy().get_event_loop()


Expand Down
6 changes: 6 additions & 0 deletions asyncio/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,13 @@ def new_test_loop(self, gen=None):
self.set_event_loop(loop)
return loop

def setUp(self):
self._get_running_loop = events._get_running_loop
events._get_running_loop = lambda: None

def tearDown(self):
events._get_running_loop = self._get_running_loop

events.set_event_loop(None)

# Detect CPython bug #23353: ensure that yield/yield-from is not used
Expand Down
20 changes: 20 additions & 0 deletions tests/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_ipaddr_info_no_inet_pton(self, m_socket):
class BaseEventLoopTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = base_events.BaseEventLoop()
self.loop._selector = mock.Mock()
self.loop._selector.select.return_value = ()
Expand Down Expand Up @@ -976,6 +977,7 @@ def connection_lost(self, exc):
class BaseEventLoopWithSelectorTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)

Expand Down Expand Up @@ -1692,5 +1694,23 @@ def stop_loop_coro(loop):
"took .* seconds$")


class RunningLoopTests(unittest.TestCase):

def test_running_loop_within_a_loop(self):
@asyncio.coroutine
def runner(loop):
loop.run_forever()

loop = asyncio.new_event_loop()
outer_loop = asyncio.new_event_loop()
try:
with self.assertRaisesRegex(RuntimeError,
'while another loop is running'):
outer_loop.run_until_complete(runner(loop))
finally:
loop.close()
outer_loop.close()


if __name__ == '__main__':
unittest.main()
23 changes: 23 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2233,6 +2233,7 @@ def noop(*args, **kwargs):
class HandleTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = mock.Mock()
self.loop.get_debug.return_value = True

Expand Down Expand Up @@ -2411,6 +2412,7 @@ def __await__(self):
class TimerTests(unittest.TestCase):

def setUp(self):
super().setUp()
self.loop = mock.Mock()

def test_hash(self):
Expand Down Expand Up @@ -2719,6 +2721,27 @@ def test_set_event_loop_policy(self):
self.assertIs(policy, asyncio.get_event_loop_policy())
self.assertIsNot(policy, old_policy)

def test_get_event_loop_returns_running_loop(self):
class Policy(asyncio.DefaultEventLoopPolicy):
def get_event_loop(self):
raise NotImplementedError

loop = None

old_policy = asyncio.get_event_loop_policy()
try:
asyncio.set_event_loop_policy(Policy())
loop = asyncio.new_event_loop()

async def func():
self.assertIs(asyncio.get_event_loop(), loop)

loop.run_until_complete(func())
finally:
asyncio.set_event_loop_policy(old_policy)
if loop is not None:
loop.close()


if __name__ == '__main__':
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __iter__(self):
class DuckTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.addCleanup(self.loop.close)

Expand All @@ -96,6 +97,7 @@ def test_ensure_future(self):
class FutureTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.addCleanup(self.loop.close)

Expand Down Expand Up @@ -468,6 +470,7 @@ def test_set_result_unless_cancelled(self):
class FutureDoneCallbackTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def run_briefly(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class LockTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down Expand Up @@ -235,6 +236,7 @@ def test_context_manager_no_yield(self):
class EventTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down Expand Up @@ -364,6 +366,7 @@ def c1(result):
class ConditionTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down Expand Up @@ -699,6 +702,7 @@ def test_ambiguous_loops(self):
class SemaphoreTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down
1 change: 1 addition & 0 deletions tests/test_pep492.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class BaseTest(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.BaseEventLoop()
self.loop._process_events = mock.Mock()
self.loop._selector = mock.Mock()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def close_transport(transport):
class ProactorSocketTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.addCleanup(self.loop.close)
self.proactor = mock.Mock()
Expand Down Expand Up @@ -436,6 +437,8 @@ def test_dont_pause_writing(self):
class BaseProactorEventLoopTests(test_utils.TestCase):

def setUp(self):
super().setUp()

self.sock = test_utils.mock_nonblocking_socket()
self.proactor = mock.Mock()

Expand Down
1 change: 1 addition & 0 deletions tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class _QueueTestBase(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()


Expand Down
5 changes: 5 additions & 0 deletions tests/test_selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def close_transport(transport):
class BaseSelectorEventLoopTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.selector = mock.Mock()
self.selector.select.return_value = []
self.loop = TestBaseSelectorEventLoop(self.selector)
Expand Down Expand Up @@ -698,6 +699,7 @@ def test_accept_connection_multiple(self):
class SelectorTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
Expand Down Expand Up @@ -793,6 +795,7 @@ def test_connection_lost(self):
class SelectorSocketTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
Expand Down Expand Up @@ -1141,6 +1144,7 @@ def test_transport_close_remove_writer(self, m_log):
class SelectorSslTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
Expand Down Expand Up @@ -1501,6 +1505,7 @@ def test_ssl_transport_requires_ssl_module(self):
class SelectorDatagramTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
self.sock = mock.Mock(spec_set=socket.socket)
Expand Down
1 change: 1 addition & 0 deletions tests/test_sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
class SslProtoHandshakeTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)

Expand Down
1 change: 1 addition & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class StreamReaderTests(test_utils.TestCase):
DATA = b'line1\nline2\nline3\n'

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _start(self, *args, **kwargs):

class SubprocessTransportTests(test_utils.TestCase):
def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.set_event_loop(self.loop)

Expand Down Expand Up @@ -466,6 +467,7 @@ class SubprocessWatcherMixin(SubprocessMixin):
Watcher = None

def setUp(self):
super().setUp()
policy = asyncio.get_event_loop_policy()
self.loop = policy.new_event_loop()
self.set_event_loop(self.loop)
Expand All @@ -490,6 +492,7 @@ class SubprocessFastWatcherTests(SubprocessWatcherMixin,
class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.ProactorEventLoop()
self.set_event_loop(self.loop)

Expand Down
5 changes: 5 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __call__(self, *args):
class TaskTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_other_loop_future(self):
Expand Down Expand Up @@ -1933,6 +1934,7 @@ def cancelling_callback(_):
class GatherTestsBase:

def setUp(self):
super().setUp()
self.one_loop = self.new_test_loop()
self.other_loop = self.new_test_loop()
self.set_event_loop(self.one_loop, cleanup=False)
Expand Down Expand Up @@ -2216,6 +2218,7 @@ class RunCoroutineThreadsafeTests(test_utils.TestCase):
"""Test case for asyncio.run_coroutine_threadsafe."""

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop) # Will cleanup properly

Expand Down Expand Up @@ -2306,12 +2309,14 @@ def test_run_coroutine_threadsafe_task_factory_exception(self):

class SleepTests(test_utils.TestCase):
def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = None
super().tearDown()

def test_sleep_zero(self):
result = 0
Expand Down
Loading