Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-39622: Interrupt the main asyncio task on Ctrl+C #32105

Merged
merged 22 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion Lib/asyncio/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

import contextvars
import enum
import functools
import threading
import signal
import sys
from . import coroutines
from . import events
from . import exceptions
from . import tasks


Expand Down Expand Up @@ -47,6 +52,7 @@ def __init__(self, *, debug=None, loop_factory=None):
self._loop_factory = loop_factory
self._loop = None
self._context = None
self._interrunt_count = 0
gvanrossum marked this conversation as resolved.
Show resolved Hide resolved
asvetlov marked this conversation as resolved.
Show resolved Hide resolved

def __enter__(self):
self._lazy_init()
Expand Down Expand Up @@ -89,7 +95,28 @@ def run(self, coro, *, context=None):
if context is None:
context = self._context
task = self._loop.create_task(coro, context=context)
return self._loop.run_until_complete(task)

if (threading.current_thread() is threading.main_thread()
and signal.getsignal(signal.SIGINT) is signal.default_int_handler
):
sigint_handler = functools.partial(self._on_sigint, main_task=task)
signal.signal(signal.SIGINT, sigint_handler)
else:
sigint_handler = None

self._interrunt_count = 0
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
try:
return self._loop.run_until_complete(task)
except exceptions.CancelledError:
if self._interrunt_count > 0 and task.uncancel() == 0:
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
raise KeyboardInterrupt()
else:
raise # CancelledError
finally:
if (sigint_handler is not None
and signal.getsignal(signal.SIGINT) is sigint_handler
):
signal.signal(signal.SIGINT, signal.default_int_handler)

def _lazy_init(self):
if self._state is _State.CLOSED:
Expand All @@ -105,6 +132,14 @@ def _lazy_init(self):
self._context = contextvars.copy_context()
self._state = _State.INITIALIZED

def _on_sigint(self, signum, frame, main_task):
self._interrunt_count += 1
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
if self._interrunt_count == 1 and not main_task.done():
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
main_task.cancel()
# wakeup a loop if it is blocked by selector.select() with long timeour
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
self._loop.call_soon_threadsafe(lambda: None)
return
raise KeyboardInterrupt()


def run(main, *, debug=None):
Expand Down
59 changes: 58 additions & 1 deletion Lib/test/test_asyncio/test_runners.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import _thread
import asyncio
import contextvars
import gc
import re
import threading
import unittest

from unittest import mock
Expand All @@ -12,6 +14,10 @@ def tearDownModule():
asyncio.set_event_loop_policy(None)


def interrupt_self():
_thread.interrupt_main()


class TestPolicy(asyncio.AbstractEventLoopPolicy):

def __init__(self, loop_factory):
Expand Down Expand Up @@ -298,7 +304,7 @@ async def get_context():

self.assertEqual(2, runner.run(get_context()).get(cvar))

def test_recursine_run(self):
def test_recursive_run(self):
async def g():
pass

Expand All @@ -318,6 +324,57 @@ async def f():
):
runner.run(f())

def test_interrupt_call_soon(self):
# The only case when task is not suspended by waiting a future
# or another task
assert threading.current_thread() is threading.main_thread()

async def coro():
with self.assertRaises(asyncio.CancelledError):
while True:
await asyncio.sleep(0)
raise asyncio.CancelledError()

with asyncio.Runner() as runner:
runner.get_loop().call_later(0.1, interrupt_self)
with self.assertRaises(KeyboardInterrupt):
runner.run(coro())

def test_interrupt_wait(self):
# interrupting when waiting a future cancels both future and main task
assert threading.current_thread() is threading.main_thread()

async def coro(fut):
with self.assertRaises(asyncio.CancelledError):
await fut
raise asyncio.CancelledError()

with asyncio.Runner() as runner:
fut = runner.get_loop().create_future()
runner.get_loop().call_later(0.1, interrupt_self)

with self.assertRaises(KeyboardInterrupt):
runner.run(coro(fut))

self.assertTrue(fut.cancelled())

def test_interrupt_cancelled_task(self):
# interrupting cancelled main task doesn't raise KeyboardInterrupt
assert threading.current_thread() is threading.main_thread()

async def subtask(task):
await asyncio.sleep(0)
task.cancel()
interrupt_self()

async def coro():
asyncio.create_task(subtask(asyncio.current_task()))
await asyncio.sleep(10)

with asyncio.Runner() as runner:
with self.assertRaises(asyncio.CancelledError):
runner.run(coro())


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Handle Ctrl+C in asyncio programs to interrupt the main task.