Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.

Commit 0b8d15d

Browse files
committed
Make get_event_loop() return the current loop if called from coroutines
1 parent 405d919 commit 0b8d15d

17 files changed

+102
-1
lines changed

asyncio/base_events.py

+3
Original file line numberDiff line numberDiff line change
@@ -400,14 +400,17 @@ def run_forever(self):
400400
old_agen_hooks = sys.get_asyncgen_hooks()
401401
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
402402
finalizer=self._asyncgen_finalizer_hook)
403+
old_loop = events._get_current_loop()
403404
try:
405+
events._set_current_loop(self)
404406
while True:
405407
self._run_once()
406408
if self._stopping:
407409
break
408410
finally:
409411
self._stopping = False
410412
self._thread_id = None
413+
events._set_current_loop(old_loop)
411414
self._set_coroutine_wrapper(False)
412415
if self._asyncgens is not None:
413416
sys.set_asyncgen_hooks(*old_agen_hooks)

asyncio/events.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,30 @@ def new_event_loop(self):
607607
_lock = threading.Lock()
608608

609609

610+
# A TLS for the current event loop, used by get_current_loop.
611+
class _CurrentLoop(threading.local):
612+
_loop = None
613+
_current_loop = _CurrentLoop()
614+
615+
616+
def _get_current_loop():
617+
"""Return the currently running event loop or None.
618+
619+
This is a low-level function intended to be used by event loops.
620+
This function is thread-specific.
621+
"""
622+
return _current_loop._loop
623+
624+
625+
def _set_current_loop(loop):
626+
"""Set the currently running event loop.
627+
628+
This is a low-level function intended to be used by event loops.
629+
This function is thread-specific.
630+
"""
631+
_current_loop._loop = loop
632+
633+
610634
def _init_event_loop_policy():
611635
global _event_loop_policy
612636
with _lock:
@@ -632,7 +656,17 @@ def set_event_loop_policy(policy):
632656

633657

634658
def get_event_loop():
635-
"""Equivalent to calling get_event_loop_policy().get_event_loop()."""
659+
"""Return an asyncio event loop.
660+
661+
When called from coroutines, this function will always return the
662+
current event loop.
663+
664+
If there is no current event loop set, the function will return
665+
the result of `get_event_loop_policy().get_event_loop()` call.
666+
"""
667+
current_loop = _get_current_loop()
668+
if current_loop is not None:
669+
return current_loop
636670
return get_event_loop_policy().get_event_loop()
637671

638672

asyncio/test_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,13 @@ def new_test_loop(self, gen=None):
449449
self.set_event_loop(loop)
450450
return loop
451451

452+
def setUp(self):
453+
self._get_current_loop = events._get_current_loop
454+
events._get_current_loop = lambda: None
455+
452456
def tearDown(self):
457+
events._get_current_loop = self._get_current_loop
458+
453459
events.set_event_loop(None)
454460

455461
# Detect CPython bug #23353: ensure that yield/yield-from is not used

tests/test_base_events.py

+2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def test_ipaddr_info_no_inet_pton(self, m_socket):
154154
class BaseEventLoopTests(test_utils.TestCase):
155155

156156
def setUp(self):
157+
super().setUp()
157158
self.loop = base_events.BaseEventLoop()
158159
self.loop._selector = mock.Mock()
159160
self.loop._selector.select.return_value = ()
@@ -976,6 +977,7 @@ def connection_lost(self, exc):
976977
class BaseEventLoopWithSelectorTests(test_utils.TestCase):
977978

978979
def setUp(self):
980+
super().setUp()
979981
self.loop = asyncio.new_event_loop()
980982
self.set_event_loop(self.loop)
981983

tests/test_events.py

+23
Original file line numberDiff line numberDiff line change
@@ -2233,6 +2233,7 @@ def noop(*args, **kwargs):
22332233
class HandleTests(test_utils.TestCase):
22342234

22352235
def setUp(self):
2236+
super().setUp()
22362237
self.loop = mock.Mock()
22372238
self.loop.get_debug.return_value = True
22382239

@@ -2411,6 +2412,7 @@ def __await__(self):
24112412
class TimerTests(unittest.TestCase):
24122413

24132414
def setUp(self):
2415+
super().setUp()
24142416
self.loop = mock.Mock()
24152417

24162418
def test_hash(self):
@@ -2719,6 +2721,27 @@ def test_set_event_loop_policy(self):
27192721
self.assertIs(policy, asyncio.get_event_loop_policy())
27202722
self.assertIsNot(policy, old_policy)
27212723

2724+
def test_get_event_loop_returns_current_loop(self):
2725+
class Policy(asyncio.DefaultEventLoopPolicy):
2726+
def get_event_loop(self):
2727+
raise NotImplementedError
2728+
2729+
loop = None
2730+
2731+
old_policy = asyncio.get_event_loop_policy()
2732+
try:
2733+
asyncio.set_event_loop_policy(Policy())
2734+
loop = asyncio.new_event_loop()
2735+
2736+
async def func():
2737+
self.assertIs(asyncio.get_event_loop(), loop)
2738+
2739+
loop.run_until_complete(func())
2740+
finally:
2741+
asyncio.set_event_loop_policy(old_policy)
2742+
if loop is not None:
2743+
loop.close()
2744+
27222745

27232746
if __name__ == '__main__':
27242747
unittest.main()

tests/test_futures.py

+3
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __iter__(self):
7979
class DuckTests(test_utils.TestCase):
8080

8181
def setUp(self):
82+
super().setUp()
8283
self.loop = self.new_test_loop()
8384
self.addCleanup(self.loop.close)
8485

@@ -96,6 +97,7 @@ def test_ensure_future(self):
9697
class FutureTests(test_utils.TestCase):
9798

9899
def setUp(self):
100+
super().setUp()
99101
self.loop = self.new_test_loop()
100102
self.addCleanup(self.loop.close)
101103

@@ -468,6 +470,7 @@ def test_set_result_unless_cancelled(self):
468470
class FutureDoneCallbackTests(test_utils.TestCase):
469471

470472
def setUp(self):
473+
super().setUp()
471474
self.loop = self.new_test_loop()
472475

473476
def run_briefly(self):

tests/test_locks.py

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
class LockTests(test_utils.TestCase):
2020

2121
def setUp(self):
22+
super().setUp()
2223
self.loop = self.new_test_loop()
2324

2425
def test_ctor_loop(self):
@@ -235,6 +236,7 @@ def test_context_manager_no_yield(self):
235236
class EventTests(test_utils.TestCase):
236237

237238
def setUp(self):
239+
super().setUp()
238240
self.loop = self.new_test_loop()
239241

240242
def test_ctor_loop(self):
@@ -364,6 +366,7 @@ def c1(result):
364366
class ConditionTests(test_utils.TestCase):
365367

366368
def setUp(self):
369+
super().setUp()
367370
self.loop = self.new_test_loop()
368371

369372
def test_ctor_loop(self):
@@ -699,6 +702,7 @@ def test_ambiguous_loops(self):
699702
class SemaphoreTests(test_utils.TestCase):
700703

701704
def setUp(self):
705+
super().setUp()
702706
self.loop = self.new_test_loop()
703707

704708
def test_ctor_loop(self):

tests/test_pep492.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
class BaseTest(test_utils.TestCase):
1818

1919
def setUp(self):
20+
super().setUp()
2021
self.loop = asyncio.BaseEventLoop()
2122
self.loop._process_events = mock.Mock()
2223
self.loop._selector = mock.Mock()

tests/test_proactor_events.py

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def close_transport(transport):
2424
class ProactorSocketTransportTests(test_utils.TestCase):
2525

2626
def setUp(self):
27+
super().setUp()
2728
self.loop = self.new_test_loop()
2829
self.addCleanup(self.loop.close)
2930
self.proactor = mock.Mock()
@@ -436,6 +437,8 @@ def test_dont_pause_writing(self):
436437
class BaseProactorEventLoopTests(test_utils.TestCase):
437438

438439
def setUp(self):
440+
super().setUp()
441+
439442
self.sock = test_utils.mock_nonblocking_socket()
440443
self.proactor = mock.Mock()
441444

tests/test_queues.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class _QueueTestBase(test_utils.TestCase):
1111

1212
def setUp(self):
13+
super().setUp()
1314
self.loop = self.new_test_loop()
1415

1516

tests/test_selector_events.py

+5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def close_transport(transport):
5151
class BaseSelectorEventLoopTests(test_utils.TestCase):
5252

5353
def setUp(self):
54+
super().setUp()
5455
self.selector = mock.Mock()
5556
self.selector.select.return_value = []
5657
self.loop = TestBaseSelectorEventLoop(self.selector)
@@ -698,6 +699,7 @@ def test_accept_connection_multiple(self):
698699
class SelectorTransportTests(test_utils.TestCase):
699700

700701
def setUp(self):
702+
super().setUp()
701703
self.loop = self.new_test_loop()
702704
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
703705
self.sock = mock.Mock(socket.socket)
@@ -793,6 +795,7 @@ def test_connection_lost(self):
793795
class SelectorSocketTransportTests(test_utils.TestCase):
794796

795797
def setUp(self):
798+
super().setUp()
796799
self.loop = self.new_test_loop()
797800
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
798801
self.sock = mock.Mock(socket.socket)
@@ -1141,6 +1144,7 @@ def test_transport_close_remove_writer(self, m_log):
11411144
class SelectorSslTransportTests(test_utils.TestCase):
11421145

11431146
def setUp(self):
1147+
super().setUp()
11441148
self.loop = self.new_test_loop()
11451149
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
11461150
self.sock = mock.Mock(socket.socket)
@@ -1501,6 +1505,7 @@ def test_ssl_transport_requires_ssl_module(self):
15011505
class SelectorDatagramTransportTests(test_utils.TestCase):
15021506

15031507
def setUp(self):
1508+
super().setUp()
15041509
self.loop = self.new_test_loop()
15051510
self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
15061511
self.sock = mock.Mock(spec_set=socket.socket)

tests/test_sslproto.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
class SslProtoHandshakeTests(test_utils.TestCase):
1919

2020
def setUp(self):
21+
super().setUp()
2122
self.loop = asyncio.new_event_loop()
2223
self.set_event_loop(self.loop)
2324

tests/test_streams.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class StreamReaderTests(test_utils.TestCase):
2222
DATA = b'line1\nline2\nline3\n'
2323

2424
def setUp(self):
25+
super().setUp()
2526
self.loop = asyncio.new_event_loop()
2627
self.set_event_loop(self.loop)
2728

tests/test_subprocess.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _start(self, *args, **kwargs):
3535

3636
class SubprocessTransportTests(test_utils.TestCase):
3737
def setUp(self):
38+
super().setUp()
3839
self.loop = self.new_test_loop()
3940
self.set_event_loop(self.loop)
4041

@@ -466,6 +467,7 @@ class SubprocessWatcherMixin(SubprocessMixin):
466467
Watcher = None
467468

468469
def setUp(self):
470+
super().setUp()
469471
policy = asyncio.get_event_loop_policy()
470472
self.loop = policy.new_event_loop()
471473
self.set_event_loop(self.loop)
@@ -490,6 +492,7 @@ class SubprocessFastWatcherTests(SubprocessWatcherMixin,
490492
class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
491493

492494
def setUp(self):
495+
super().setUp()
493496
self.loop = asyncio.ProactorEventLoop()
494497
self.set_event_loop(self.loop)
495498

tests/test_tasks.py

+5
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __call__(self, *args):
7575
class TaskTests(test_utils.TestCase):
7676

7777
def setUp(self):
78+
super().setUp()
7879
self.loop = self.new_test_loop()
7980

8081
def test_other_loop_future(self):
@@ -1933,6 +1934,7 @@ def cancelling_callback(_):
19331934
class GatherTestsBase:
19341935

19351936
def setUp(self):
1937+
super().setUp()
19361938
self.one_loop = self.new_test_loop()
19371939
self.other_loop = self.new_test_loop()
19381940
self.set_event_loop(self.one_loop, cleanup=False)
@@ -2216,6 +2218,7 @@ class RunCoroutineThreadsafeTests(test_utils.TestCase):
22162218
"""Test case for asyncio.run_coroutine_threadsafe."""
22172219

22182220
def setUp(self):
2221+
super().setUp()
22192222
self.loop = asyncio.new_event_loop()
22202223
self.set_event_loop(self.loop) # Will cleanup properly
22212224

@@ -2306,12 +2309,14 @@ def test_run_coroutine_threadsafe_task_factory_exception(self):
23062309

23072310
class SleepTests(test_utils.TestCase):
23082311
def setUp(self):
2312+
super().setUp()
23092313
self.loop = asyncio.new_event_loop()
23102314
asyncio.set_event_loop(None)
23112315

23122316
def tearDown(self):
23132317
self.loop.close()
23142318
self.loop = None
2319+
super().tearDown()
23152320

23162321
def test_sleep_zero(self):
23172322
result = 0

tests/test_unix_events.py

+5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def close_pipe_transport(transport):
4040
class SelectorEventLoopSignalTests(test_utils.TestCase):
4141

4242
def setUp(self):
43+
super().setUp()
4344
self.loop = asyncio.SelectorEventLoop()
4445
self.set_event_loop(self.loop)
4546

@@ -234,6 +235,7 @@ def test_close(self, m_signal):
234235
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
235236

236237
def setUp(self):
238+
super().setUp()
237239
self.loop = asyncio.SelectorEventLoop()
238240
self.set_event_loop(self.loop)
239241

@@ -338,6 +340,7 @@ def test_create_unix_connection_ssl_noserverhost(self):
338340
class UnixReadPipeTransportTests(test_utils.TestCase):
339341

340342
def setUp(self):
343+
super().setUp()
341344
self.loop = self.new_test_loop()
342345
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
343346
self.pipe = mock.Mock(spec_set=io.RawIOBase)
@@ -487,6 +490,7 @@ def test__call_connection_lost_with_err(self):
487490
class UnixWritePipeTransportTests(test_utils.TestCase):
488491

489492
def setUp(self):
493+
super().setUp()
490494
self.loop = self.new_test_loop()
491495
self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
492496
self.pipe = mock.Mock(spec_set=io.RawIOBase)
@@ -805,6 +809,7 @@ class ChildWatcherTestsMixin:
805809
ignore_warnings = mock.patch.object(log.logger, "warning")
806810

807811
def setUp(self):
812+
super().setUp()
808813
self.loop = self.new_test_loop()
809814
self.running = False
810815
self.zombies = {}

tests/test_windows_events.py

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def data_received(self, data):
3131
class ProactorTests(test_utils.TestCase):
3232

3333
def setUp(self):
34+
super().setUp()
3435
self.loop = asyncio.ProactorEventLoop()
3536
self.set_event_loop(self.loop)
3637

0 commit comments

Comments
 (0)