Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-32972: Add unittest.AsyncioTestCase #10296

Closed
4 changes: 2 additions & 2 deletions Lib/unittest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def testMultiply(self):
SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
"""

__all__ = ['TestResult', 'TestCase', 'TestSuite',
__all__ = ['TestResult', 'TestCase', 'TestSuite', 'AsyncioTestCase',
'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main',
'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
'expectedFailure', 'TextTestResult', 'installHandler',
Expand All @@ -58,7 +58,7 @@ def testMultiply(self):

from .result import TestResult
from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip,
skipIf, skipUnless, expectedFailure)
skipIf, skipUnless, expectedFailure, AsyncioTestCase)
from .suite import BaseTestSuite, TestSuite
from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames,
findTestCases)
Expand Down
225 changes: 187 additions & 38 deletions Lib/unittest/case.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test case implementation"""

import asyncio
import sys
import functools
import difflib
Expand All @@ -23,25 +24,20 @@
'Set self.maxDiff to None to see it.')

class SkipTest(Exception):
"""
Raise this exception in a test to skip it.
"""Raise this exception in a test to skip it.

Usually you can use TestCase.skipTest() or one of the skipping decorators
instead of raising this directly.
"""

class _ShouldStop(Exception):
"""
The test should stop.
"""
"""The test should stop."""

class _UnexpectedSuccess(Exception):
"""
The test was supposed to fail, but it didn't!
"""
"""The test was supposed to fail, but it didn't!"""


class _Outcome(object):
class _Outcome:
def __init__(self, result=None):
self.expecting_failure = False
self.result = result
Expand Down Expand Up @@ -81,6 +77,16 @@ def testPartExecutor(self, test_case, isTest=False):
self.success = self.success and old_success


class _IgnoredOutcome:
def __init__(self, result=None):
self.expecting_failure = None
self.success = True

@contextlib.contextmanager
def testPartExecutor(self, test_case, isTest=False):
yield


def _id(obj):
return obj

Expand Down Expand Up @@ -109,9 +115,7 @@ def doModuleCleanups():


def skip(reason):
"""
Unconditionally skip a test.
"""
"""Unconditionally skip a test."""
def decorator(test_item):
if not isinstance(test_item, type):
@functools.wraps(test_item)
Expand All @@ -125,17 +129,13 @@ def skip_wrapper(*args, **kwargs):
return decorator

def skipIf(condition, reason):
"""
Skip a test if the condition is true.
"""
"""Skip a test if the condition is true."""
if condition:
return skip(reason)
return _id

def skipUnless(condition, reason):
"""
Skip a test unless the condition is true.
"""
"""Skip a test unless the condition is true."""
if not condition:
return skip(reason)
return _id
Expand Down Expand Up @@ -171,8 +171,7 @@ def __init__(self, expected, test_case, expected_regex=None):
self.msg = None

def handle(self, name, args, kwargs):
"""
If args is empty, assertRaises/Warns is being used as a
"""If args is empty, assertRaises/Warns is being used as a
context manager, so check for a 'msg' kwarg and return self.
If args is not empty, call a callable passing positional and keyword
arguments.
Expand Down Expand Up @@ -367,7 +366,7 @@ def __iter__(self):
yield k


class TestCase(object):
class TestCase:
"""A class whose instances are single test cases.

By default, the test code itself should be placed in a method named
Expand Down Expand Up @@ -633,17 +632,7 @@ def run(self, result=None):
outcome = _Outcome(result)
try:
self._outcome = outcome

with outcome.testPartExecutor(self):
self.setUp()
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
testMethod()
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
self.tearDown()

self._runTest(testMethod, outcome, expecting_failure)
self.doCleanups()
for test, reason in outcome.skipped:
self._addSkip(result, test, reason)
Expand All @@ -658,6 +647,7 @@ def run(self, result=None):
result.addSuccess(self)
return result
finally:
self._terminateTest(outcome)
result.stopTest(self)
if orig_result is None:
stopTestRun = getattr(result, 'stopTestRun', None)
Expand Down Expand Up @@ -701,14 +691,27 @@ def doClassCleanups(cls):
def __call__(self, *args, **kwds):
return self.run(*args, **kwds)

def _runTest(self, testMethod, outcome, expecting_failure):
"""Run the test and collect errors into a TestResult"""
with outcome.testPartExecutor(self):
self.setUp()
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
testMethod()
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
self.tearDown()

def _terminateTest(self, outcome):
"""Hook that is called after a test run is complete."""
pass

def debug(self):
"""Run the test without collecting errors in a TestResult"""
self.setUp()
getattr(self, self._testMethodName)()
self.tearDown()
while self._cleanups:
function, args, kwargs = self._cleanups.pop(-1)
function(*args, **kwargs)
self._runTest(getattr(self, self._testMethodName),
_IgnoredOutcome(), None)
self.doCleanups()

def skipTest(self, reason):
"""Skip this test."""
Expand Down Expand Up @@ -1483,3 +1486,149 @@ def shortDescription(self):

def __str__(self):
return "{} {}".format(self.test_case, self._subDescription())


class AsyncioTestCase(TestCase):
"""Extension of `unittest.TestCase` for concurrent test cases.

This extension of `unittest.TestCase` runs the instance methods
decorated with the **async**/**await** syntax on a event loop.
The event loop is created anew for each test method run and
destroyed when the test method has exited. Both the `setUp` and
`tearDown` method are decorated with ``async``. Sub-classes that
extend either method MUST call the parent method using the ``await``
keyword.

If individual test methods are co-routines (as defined by
``asyncio.iscoroutinefunction``), then they will be run on the
active event loop; otherwise, they will be called as simple methods.
In the latter case, the event loop IS NOT RUNNING; however, you
can use it to run async code if necessary.

Clean up functions will be run on the event loop if they are
detected as co-routines; otherwise, they are called as standard
functions. After ALL cleanup calls are completed, the event loop
is stopped and closed.

When subclassing AsyncioTestCase, the event loop is available
via the ``loop`` property. Tests can assume that `self.loop`
refers to the event loop created for the test case. It is also
the active event loop as returned by ``asyncio.get_event_loop()``.
Test cases SHOULD NOT change the active event loop during the test
run.

The lifecycle of a single test execution is thus::

- ``TestCase`` instance is created
- ``setUpClass`` is executed
- for each test method
- new event loop is created
- ``setUp`` is executed on the event loop
- if ``setUp`` succeeds
- the test method is run on the event loop if it is a
co-routine; otherwise it is simply called
- ``tearDown`` is executed on the event loop
- the event loop is stopped if it is somehow running
- ``doCleanups`` is called
- callables registered with ``addCleanup`` are called
in reverse-order executing them on the event loop if
necessary
- the event loop is stopped if necessary
- the event loop is closed and unregistered

The process is tightly managed by the ``_runTest`` and ``doCleanups``
methods. Care must be taken if extending this class to ensure that
the event loop is properly managed.

"""

def __init__(self, methodName='runTest'):
super().__init__(methodName)
self.__loop = None

@classmethod
def setUpClass(cls):
"""Hook method for one-time initialization.

Subclasses MUST invoke this method when extending it.
"""
super().setUpClass()
cls.__saved_policy = asyncio.events._event_loop_policy

@classmethod
def tearDownClass(cls):
"""Hook method for one-time cleanup.

Subclasses MUST invoke this method when extending it.
"""
super().tearDownClass()
asyncio.set_event_loop_policy(cls.__saved_policy)

async def asyncSetUp(self):
"""Hook method for setting up the test fixture before exercising it.

This method invokes ``setUp`` inline so subclasses MUST invoke this
method using the ``await`` keyword when extending it.
"""
self.setUp()

async def asyncTearDown(self):
"""Hook method for deconstructing the test fixture after testing it.

This method invokes ``tearDown`` inline so subclasses MUST invoke
this method using the ``await`` keyword when extending it.
"""
self.tearDown()

@property
def loop(self):
"""Active event loop for the test case."""
if self.__loop is None:
self.__loop = asyncio.new_event_loop()
self.__loop.set_debug(True)
asyncio.set_event_loop(self.__loop)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the test case have to modify the global state like this? It is calling self.loop.run_until_complete when running things.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it does. This is the only way to ensure that the code under test gets the correct loop.

return self.__loop

def _runTest(self, testMethod, outcome, expecting_failure):
try:
with outcome.testPartExecutor(self):
self.loop.run_until_complete(self.asyncSetUp())
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
if asyncio.iscoroutinefunction(testMethod):
self.loop.run_until_complete(testMethod())
else:
testMethod()
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
self.loop.run_until_complete(self.asyncTearDown())
finally:
if self.loop.is_running():
self.loop.stop()

def _terminateTest(self, outcome):
super()._terminateTest(outcome)
if self.loop.is_running():
self.loop.stop()
self.loop.run_until_complete(
self.loop.shutdown_asyncgens())
asyncio.set_event_loop(None)
if not self.loop.is_closed():
self.loop.close()
self.__loop = None

def doCleanups(self):
"""Execute all cleanup functions.

Normally called for you after tearDown.
"""
outcome = self._outcome or _Outcome()
while self._cleanups:
cleanup_func, args, kwargs = self._cleanups.pop()
with outcome.testPartExecutor(self):
if asyncio.iscoroutinefunction(cleanup_func):
self.loop.run_until_complete(
cleanup_func(*args, **kwargs))
else:
cleanup_func(*args, **kwargs)
Loading