From eab4041b55a992d26520dfcf463cae96d9bf55cb Mon Sep 17 00:00:00 2001 From: Ilya Kulakov Date: Tue, 12 Nov 2019 13:41:01 -0800 Subject: [PATCH] bpo-17013: Extend Mock.called to allow waiting for calls New methods allow tests to wait for calls executing in other threads. --- Doc/library/unittest.mock.rst | 11 +- Lib/unittest/mock.py | 72 +++++++++++- Lib/unittest/test/testmock/support.py | 16 +++ Lib/unittest/test/testmock/testmock.py | 104 +++++++++++++++++- .../2019-11-12-13-11-44.bpo-17013.C06aC9.rst | 2 + 5 files changed, 197 insertions(+), 8 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2019-11-12-13-11-44.bpo-17013.C06aC9.rst diff --git a/Doc/library/unittest.mock.rst b/Doc/library/unittest.mock.rst index 746cf765b077fd..0c8b99094346b8 100644 --- a/Doc/library/unittest.mock.rst +++ b/Doc/library/unittest.mock.rst @@ -493,7 +493,7 @@ the *new_callable* argument to :func:`patch`. .. attribute:: called - A boolean representing whether or not the mock object has been called: + A boolean-like object representing whether or not the mock object has been called: >>> mock = Mock(return_value=None) >>> mock.called @@ -502,6 +502,15 @@ the *new_callable* argument to :func:`patch`. >>> mock.called True + The object gives access to methods helpful in multithreaded tests: + + - :meth:`wait(/, skip=0, timeout=None)` asserts that mock is called + *skip* times during *timeout* + + - :meth:`wait_for(predicate, /, timeout=None)` asserts that + *predicate* was ``True`` at least once during the timeout; + *predicate* receives exactly one positional argument: the mock itself + .. attribute:: call_count An integer telling you how many times the mock object has been called: diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index a48132c5b1cb5b..4c5b6c93d184df 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -32,9 +32,10 @@ import pprint import sys import builtins +import threading from types import CodeType, ModuleType, MethodType from unittest.util import safe_repr -from functools import wraps, partial +from functools import wraps, partial, total_ordering _builtins = {name for name in dir(builtins) if not name.startswith('_')} @@ -217,7 +218,7 @@ def reset_mock(): if _is_instance_mock(ret) and not ret is mock: ret.reset_mock() - funcopy.called = False + funcopy.called = _CallEvent(mock) funcopy.call_count = 0 funcopy.call_args = None funcopy.call_args_list = _CallList() @@ -439,7 +440,7 @@ def __init__( __dict__['_mock_wraps'] = wraps __dict__['_mock_delegate'] = None - __dict__['_mock_called'] = False + __dict__['_mock_called'] = _CallEvent(self) __dict__['_mock_call_args'] = None __dict__['_mock_call_count'] = 0 __dict__['_mock_call_args_list'] = _CallList() @@ -577,7 +578,7 @@ def reset_mock(self, visited=None,*, return_value=False, side_effect=False): return visited.append(id(self)) - self.called = False + self.called = _CallEvent(self) self.call_args = None self.call_count = 0 self.mock_calls = _CallList() @@ -1093,8 +1094,8 @@ def _mock_call(self, /, *args, **kwargs): return self._execute_mock_call(*args, **kwargs) def _increment_mock_call(self, /, *args, **kwargs): - self.called = True self.call_count += 1 + self.called._notify() # handle call_args # needs to be set here so assertions on call arguments pass before @@ -2358,6 +2359,67 @@ def _format_call_signature(name, args, kwargs): return message % formatted_args +@total_ordering +class _CallEvent(object): + def __init__(self, mock): + self._mock = mock + self._condition = threading.Condition() + + def wait(self, /, skip=0, timeout=None): + """ + Wait for any call. + + :param skip: How many calls will be skipped. + As a result, the mock should be called at least + ``skip + 1`` times. + """ + def predicate(mock): + return mock.call_count > skip + + self.wait_for(predicate, timeout=timeout) + + def wait_for(self, predicate, /, timeout=None): + """ + Wait for a given predicate to become True. + + :param predicate: A callable that receives mock which result + will be interpreted as a boolean value. + The final predicate value is the return value. + """ + try: + self._condition.acquire() + + def _predicate(): + return predicate(self._mock) + + b = self._condition.wait_for(_predicate, timeout) + + if not b: + msg = (f"{self._mock._mock_name or 'mock'} was not called before" + f" timeout({timeout}).") + raise AssertionError(msg) + finally: + self._condition.release() + + def __bool__(self): + return self._mock.call_count != 0 + + def __eq__(self, other): + return bool(self) == other + + def __lt__(self, other): + return bool(self) < other + + def __repr__(self): + return repr(bool(self)) + + def _notify(self): + try: + self._condition.acquire() + self._condition.notify_all() + finally: + self._condition.release() + class _Call(tuple): """ diff --git a/Lib/unittest/test/testmock/support.py b/Lib/unittest/test/testmock/support.py index 49986d65dc47af..424e4757b97e11 100644 --- a/Lib/unittest/test/testmock/support.py +++ b/Lib/unittest/test/testmock/support.py @@ -1,3 +1,7 @@ +import concurrent.futures +import time + + target = {'foo': 'FOO'} @@ -14,3 +18,15 @@ def wibble(self): pass class X(object): pass + + +def call_after_delay(func, /, *args, **kwargs): + time.sleep(kwargs.pop('delay')) + func(*args, **kwargs) + + +def run_async(func, /, *args, executor=None, delay=0, **kwargs): + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) + + executor.submit(call_after_delay, func, *args, **kwargs, delay=delay) diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py index 01bc4794446521..5654ea2c146691 100644 --- a/Lib/unittest/test/testmock/testmock.py +++ b/Lib/unittest/test/testmock/testmock.py @@ -1,11 +1,12 @@ +import concurrent.futures import copy import re import sys import tempfile -from test.support import ALWAYS_EQ +from test.support import ALWAYS_EQ, start_threads import unittest -from unittest.test.testmock.support import is_instance +from unittest.test.testmock.support import is_instance, run_async from unittest import mock from unittest.mock import ( call, DEFAULT, patch, sentinel, @@ -2059,6 +2060,105 @@ def trace(frame, event, arg): # pragma: no cover obj = mock(spec=Something) self.assertIsInstance(obj, Something) + def test_wait_until_called(self): + mock = Mock(spec=Something)() + run_async(mock.method_1, delay=0.01) + mock.method_1.called.wait() + mock.method_1.assert_called_once() + + def test_wait_until_called_called_before(self): + mock = Mock(spec=Something)() + mock.method_1() + mock.method_1.wait_until_called() + mock.method_1.assert_called_once() + + def test_wait_until_called_magic_method(self): + mock = MagicMock(spec=Something)() + run_async(mock.method_1.__str__, delay=0.01) + mock.method_1.__str__.called.wait() + mock.method_1.__str__.assert_called_once() + + def test_wait_until_called_timeout(self): + mock = Mock(spec=Something)() + run_async(mock.method_1, delay=0.2) + + with self.assertRaises(AssertionError): + mock.method_1.called.wait(timeout=0.1) + + mock.method_1.assert_not_called() + mock.method_1.called.wait() + mock.method_1.assert_called_once() + + def test_wait_until_any_call_positional(self): + mock = Mock(spec=Something)() + run_async(mock.method_1, 1, delay=0.1) + run_async(mock.method_1, 2, delay=0.2) + run_async(mock.method_1, 3, delay=0.3) + + for arg in (1, 2, 3): + self.assertNotIn(call(arg), mock.method_1.mock_calls) + mock.method_1.called.wait_for(lambda m: call(arg) in m.call_args_list) + mock.method_1.assert_called_with(arg) + + def test_wait_until_any_call_keywords(self): + mock = Mock(spec=Something)() + run_async(mock.method_1, a=1, delay=0.1) + run_async(mock.method_1, a=2, delay=0.2) + run_async(mock.method_1, a=3, delay=0.3) + + for arg in (1, 2, 3): + self.assertNotIn(call(arg), mock.method_1.mock_calls) + mock.method_1.called.wait_for(lambda m: call(a=arg) in m.call_args_list) + mock.method_1.assert_called_with(a=arg) + + def test_wait_until_any_call_no_argument(self): + mock = Mock(spec=Something)() + mock.method_1(1) + mock.method_1assert_called_once_with(1) + + with self.assertRaises(AssertionError): + mock.method_1.called.wait_for(lambda m: call() in m.call_args_list, timeout=0.01) + + mock.method_1() + mock.method_1.called.wait_for(lambda m: call() in m.call_args_list, timeout=0.01) + + def test_called_is_boolean_like(self): + mock = Mock(spec=Something)() + + self.assertFalse(mock.method_1.called) + + self.assertEqual(mock.method_1.called, False) + self.assertEqual(mock.method_1.called, 0) + self.assertEqual(mock.method_1.called, 0.0) + + self.assertLess(mock.method_1.called, 1) + self.assertLess(mock.method_1.called, 1.0) + + self.assertLessEqual(mock.method_1.called, False) + self.assertLessEqual(mock.method_1.called, 0) + self.assertLessEqual(mock.method_1.called, 0.0) + + self.assertEqual(str(mock.method_1.called), str(False)) + self.assertEqual(repr(mock.method_1.called), repr(False)) + + mock.method_1() + + self.assertTrue(mock.method_1.called) + + self.assertEqual(mock.method_1.called, True) + self.assertEqual(mock.method_1.called, 1) + self.assertEqual(mock.method_1.called, 1.0) + + self.assertGreater(mock.method_1.called, 0) + self.assertGreater(mock.method_1.called, 0.0) + + self.assertGreaterEqual(mock.method_1.called, True) + self.assertGreaterEqual(mock.method_1.called, 1) + self.assertGreaterEqual(mock.method_1.called, 1.0) + + self.assertEqual(str(mock.method_1.called), str(True)) + self.assertEqual(repr(mock.method_1.called), repr(True)) + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS.d/next/Library/2019-11-12-13-11-44.bpo-17013.C06aC9.rst b/Misc/NEWS.d/next/Library/2019-11-12-13-11-44.bpo-17013.C06aC9.rst new file mode 100644 index 00000000000000..2270198c3975a7 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-11-12-13-11-44.bpo-17013.C06aC9.rst @@ -0,0 +1,2 @@ +Extend :attr:`called` of :class:`Mock.called` to wait for the calls in +multithreaded tests. Patch by Ilya Kulakov.