Skip to content
This repository has been archived by the owner on Nov 23, 2017. It is now read-only.

Commit

Permalink
Make get_event_loop() return the running loop if called from coroutines
Browse files Browse the repository at this point in the history
  • Loading branch information
1st1 committed Nov 4, 2016
1 parent 405d919 commit fb96b0d
Show file tree
Hide file tree
Showing 17 changed files with 102 additions and 1 deletion.
3 changes: 3 additions & 0 deletions asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,17 @@ def run_forever(self):
old_agen_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
finalizer=self._asyncgen_finalizer_hook)
old_loop = events._get_running_loop()
try:
events._set_running_loop(self)
while True:
self._run_once()
if self._stopping:
break
finally:
self._stopping = False
self._thread_id = None
events._set_running_loop(old_loop)
self._set_coroutine_wrapper(False)
if self._asyncgens is not None:
sys.set_asyncgen_hooks(*old_agen_hooks)
Expand Down
36 changes: 35 additions & 1 deletion asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,30 @@ def new_event_loop(self):
_lock = threading.Lock()


# A TLS for the running event loop, used by _get_running_loop.
class _RunningLoop(threading.local):
_loop = None
_running_loop = _RunningLoop()


def _get_running_loop():
"""Return the running event loop or None.
This is a low-level function intended to be used by event loops.
This function is thread-specific.
"""
return _running_loop._loop


def _set_running_loop(loop):
"""Set the running event loop.
This is a low-level function intended to be used by event loops.
This function is thread-specific.
"""
_running_loop._loop = loop


def _init_event_loop_policy():
global _event_loop_policy
with _lock:
Expand All @@ -632,7 +656,17 @@ def set_event_loop_policy(policy):


def get_event_loop():
"""Equivalent to calling get_event_loop_policy().get_event_loop()."""
"""Return an asyncio event loop.
When called from coroutines, this function will always return the
running event loop.
If there is no running event loop set, the function will return
the result of `get_event_loop_policy().get_event_loop()` call.
"""
current_loop = _get_running_loop()
if current_loop is not None:
return current_loop
return get_event_loop_policy().get_event_loop()


Expand Down
6 changes: 6 additions & 0 deletions asyncio/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,13 @@ def new_test_loop(self, gen=None):
self.set_event_loop(loop)
return loop

def setUp(self):
self._get_running_loop = events._get_running_loop
events._get_running_loop = lambda: None

def tearDown(self):
events._get_running_loop = self._get_running_loop

events.set_event_loop(None)

# Detect CPython bug #23353: ensure that yield/yield-from is not used
Expand Down
2 changes: 2 additions & 0 deletions tests/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_ipaddr_info_no_inet_pton(self, m_socket):
class BaseEventLoopTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = base_events.BaseEventLoop()
self.loop._selector = mock.Mock()
self.loop._selector.select.return_value = ()
Expand Down Expand Up @@ -976,6 +977,7 @@ def connection_lost(self, exc):
class BaseEventLoopWithSelectorTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)

Expand Down
23 changes: 23 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2233,6 +2233,7 @@ def noop(*args, **kwargs):
class HandleTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = mock.Mock()
self.loop.get_debug.return_value = True

Expand Down Expand Up @@ -2411,6 +2412,7 @@ def __await__(self):
class TimerTests(unittest.TestCase):

def setUp(self):
super().setUp()
self.loop = mock.Mock()

def test_hash(self):
Expand Down Expand Up @@ -2719,6 +2721,27 @@ def test_set_event_loop_policy(self):
self.assertIs(policy, asyncio.get_event_loop_policy())
self.assertIsNot(policy, old_policy)

def test_get_event_loop_returns_running_loop(self):
class Policy(asyncio.DefaultEventLoopPolicy):
def get_event_loop(self):
raise NotImplementedError

loop = None

old_policy = asyncio.get_event_loop_policy()
try:
asyncio.set_event_loop_policy(Policy())
loop = asyncio.new_event_loop()

async def func():
self.assertIs(asyncio.get_event_loop(), loop)

loop.run_until_complete(func())
finally:
asyncio.set_event_loop_policy(old_policy)
if loop is not None:
loop.close()


if __name__ == '__main__':
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __iter__(self):
class DuckTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.addCleanup(self.loop.close)

Expand All @@ -96,6 +97,7 @@ def test_ensure_future(self):
class FutureTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.addCleanup(self.loop.close)

Expand Down Expand Up @@ -468,6 +470,7 @@ def test_set_result_unless_cancelled(self):
class FutureDoneCallbackTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def run_briefly(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class LockTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down Expand Up @@ -235,6 +236,7 @@ def test_context_manager_no_yield(self):
class EventTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down Expand Up @@ -364,6 +366,7 @@ def c1(result):
class ConditionTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down Expand Up @@ -699,6 +702,7 @@ def test_ambiguous_loops(self):
class SemaphoreTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
Expand Down
1 change: 1 addition & 0 deletions tests/test_pep492.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class BaseTest(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.BaseEventLoop()
self.loop._process_events = mock.Mock()
self.loop._selector = mock.Mock()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def close_transport(transport):
class ProactorSocketTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.addCleanup(self.loop.close)
self.proactor = mock.Mock()
Expand Down Expand Up @@ -436,6 +437,8 @@ def test_dont_pause_writing(self):
class BaseProactorEventLoopTests(test_utils.TestCase):

def setUp(self):
super().setUp()

self.sock = test_utils.mock_nonblocking_socket()
self.proactor = mock.Mock()

Expand Down
1 change: 1 addition & 0 deletions tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class _QueueTestBase(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()


Expand Down
5 changes: 5 additions & 0 deletions tests/test_selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def close_transport(transport):
class BaseSelectorEventLoopTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.selector = mock.Mock()
self.selector.select.return_value = []
self.loop = TestBaseSelectorEventLoop(self.selector)
Expand Down Expand Up @@ -698,6 +699,7 @@ def test_accept_connection_multiple(self):
class SelectorTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
Expand Down Expand Up @@ -793,6 +795,7 @@ def test_connection_lost(self):
class SelectorSocketTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
Expand Down Expand Up @@ -1141,6 +1144,7 @@ def test_transport_close_remove_writer(self, m_log):
class SelectorSslTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
Expand Down Expand Up @@ -1501,6 +1505,7 @@ def test_ssl_transport_requires_ssl_module(self):
class SelectorDatagramTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
self.sock = mock.Mock(spec_set=socket.socket)
Expand Down
1 change: 1 addition & 0 deletions tests/test_sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
class SslProtoHandshakeTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)

Expand Down
1 change: 1 addition & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class StreamReaderTests(test_utils.TestCase):
DATA = b'line1\nline2\nline3\n'

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _start(self, *args, **kwargs):

class SubprocessTransportTests(test_utils.TestCase):
def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.set_event_loop(self.loop)

Expand Down Expand Up @@ -466,6 +467,7 @@ class SubprocessWatcherMixin(SubprocessMixin):
Watcher = None

def setUp(self):
super().setUp()
policy = asyncio.get_event_loop_policy()
self.loop = policy.new_event_loop()
self.set_event_loop(self.loop)
Expand All @@ -490,6 +492,7 @@ class SubprocessFastWatcherTests(SubprocessWatcherMixin,
class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.ProactorEventLoop()
self.set_event_loop(self.loop)

Expand Down
5 changes: 5 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __call__(self, *args):
class TaskTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()

def test_other_loop_future(self):
Expand Down Expand Up @@ -1933,6 +1934,7 @@ def cancelling_callback(_):
class GatherTestsBase:

def setUp(self):
super().setUp()
self.one_loop = self.new_test_loop()
self.other_loop = self.new_test_loop()
self.set_event_loop(self.one_loop, cleanup=False)
Expand Down Expand Up @@ -2216,6 +2218,7 @@ class RunCoroutineThreadsafeTests(test_utils.TestCase):
"""Test case for asyncio.run_coroutine_threadsafe."""

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop) # Will cleanup properly

Expand Down Expand Up @@ -2306,12 +2309,14 @@ def test_run_coroutine_threadsafe_task_factory_exception(self):

class SleepTests(test_utils.TestCase):
def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = None
super().tearDown()

def test_sleep_zero(self):
result = 0
Expand Down
5 changes: 5 additions & 0 deletions tests/test_unix_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def close_pipe_transport(transport):
class SelectorEventLoopSignalTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.SelectorEventLoop()
self.set_event_loop(self.loop)

Expand Down Expand Up @@ -234,6 +235,7 @@ def test_close(self, m_signal):
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.SelectorEventLoop()
self.set_event_loop(self.loop)

Expand Down Expand Up @@ -338,6 +340,7 @@ def test_create_unix_connection_ssl_noserverhost(self):
class UnixReadPipeTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase)
Expand Down Expand Up @@ -487,6 +490,7 @@ def test__call_connection_lost_with_err(self):
class UnixWritePipeTransportTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase)
Expand Down Expand Up @@ -805,6 +809,7 @@ class ChildWatcherTestsMixin:
ignore_warnings = mock.patch.object(log.logger, "warning")

def setUp(self):
super().setUp()
self.loop = self.new_test_loop()
self.running = False
self.zombies = {}
Expand Down
1 change: 1 addition & 0 deletions tests/test_windows_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def data_received(self, data):
class ProactorTests(test_utils.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.ProactorEventLoop()
self.set_event_loop(self.loop)

Expand Down

0 comments on commit fb96b0d

Please sign in to comment.