Skip to content

bpo-17013: Extend Mock.called to allow waiting for calls #17133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Doc/library/unittest.mock.rst
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ the *new_callable* argument to :func:`patch`.

.. attribute:: called

A boolean representing whether or not the mock object has been called:
A boolean-like object representing whether or not the mock object has been called:

>>> mock = Mock(return_value=None)
>>> mock.called
Expand All @@ -502,6 +502,15 @@ the *new_callable* argument to :func:`patch`.
>>> mock.called
True

The object gives access to methods helpful in multithreaded tests:

- :meth:`wait(/, skip=0, timeout=None)` asserts that mock is called
*skip* times during *timeout*

- :meth:`wait_for(predicate, /, timeout=None)` asserts that
*predicate* was ``True`` at least once during the timeout;
*predicate* receives exactly one positional argument: the mock itself

.. attribute:: call_count

An integer telling you how many times the mock object has been called:
Expand Down
72 changes: 67 additions & 5 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@
import pprint
import sys
import builtins
import threading
from types import CodeType, ModuleType, MethodType
from unittest.util import safe_repr
from functools import wraps, partial
from functools import wraps, partial, total_ordering


_builtins = {name for name in dir(builtins) if not name.startswith('_')}
Expand Down Expand Up @@ -217,7 +218,7 @@ def reset_mock():
if _is_instance_mock(ret) and not ret is mock:
ret.reset_mock()

funcopy.called = False
funcopy.called = _CallEvent(mock)
funcopy.call_count = 0
funcopy.call_args = None
funcopy.call_args_list = _CallList()
Expand Down Expand Up @@ -439,7 +440,7 @@ def __init__(
__dict__['_mock_wraps'] = wraps
__dict__['_mock_delegate'] = None

__dict__['_mock_called'] = False
__dict__['_mock_called'] = _CallEvent(self)
__dict__['_mock_call_args'] = None
__dict__['_mock_call_count'] = 0
__dict__['_mock_call_args_list'] = _CallList()
Expand Down Expand Up @@ -577,7 +578,7 @@ def reset_mock(self, visited=None,*, return_value=False, side_effect=False):
return
visited.append(id(self))

self.called = False
self.called = _CallEvent(self)
self.call_args = None
self.call_count = 0
self.mock_calls = _CallList()
Expand Down Expand Up @@ -1093,8 +1094,8 @@ def _mock_call(self, /, *args, **kwargs):
return self._execute_mock_call(*args, **kwargs)

def _increment_mock_call(self, /, *args, **kwargs):
self.called = True
self.call_count += 1
self.called._notify()

# handle call_args
# needs to be set here so assertions on call arguments pass before
Expand Down Expand Up @@ -2358,6 +2359,67 @@ def _format_call_signature(name, args, kwargs):
return message % formatted_args


@total_ordering
class _CallEvent(object):
def __init__(self, mock):
self._mock = mock
self._condition = threading.Condition()

def wait(self, /, skip=0, timeout=None):
"""
Wait for any call.

:param skip: How many calls will be skipped.
As a result, the mock should be called at least
``skip + 1`` times.
"""
def predicate(mock):
return mock.call_count > skip

self.wait_for(predicate, timeout=timeout)

def wait_for(self, predicate, /, timeout=None):
"""
Wait for a given predicate to become True.

:param predicate: A callable that receives mock which result
will be interpreted as a boolean value.
The final predicate value is the return value.
"""
try:
self._condition.acquire()

def _predicate():
return predicate(self._mock)

b = self._condition.wait_for(_predicate, timeout)

if not b:
msg = (f"{self._mock._mock_name or 'mock'} was not called before"
f" timeout({timeout}).")
raise AssertionError(msg)
finally:
self._condition.release()

def __bool__(self):
return self._mock.call_count != 0

def __eq__(self, other):
return bool(self) == other

def __lt__(self, other):
return bool(self) < other

def __repr__(self):
return repr(bool(self))

def _notify(self):
try:
self._condition.acquire()
self._condition.notify_all()
finally:
self._condition.release()


class _Call(tuple):
"""
Expand Down
16 changes: 16 additions & 0 deletions Lib/unittest/test/testmock/support.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import concurrent.futures
import time


target = {'foo': 'FOO'}


Expand All @@ -14,3 +18,15 @@ def wibble(self): pass

class X(object):
pass


def call_after_delay(func, /, *args, **kwargs):
time.sleep(kwargs.pop('delay'))
func(*args, **kwargs)


def run_async(func, /, *args, executor=None, delay=0, **kwargs):
if executor is None:
executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

executor.submit(call_after_delay, func, *args, **kwargs, delay=delay)
104 changes: 102 additions & 2 deletions Lib/unittest/test/testmock/testmock.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import concurrent.futures
import copy
import re
import sys
import tempfile

from test.support import ALWAYS_EQ
from test.support import ALWAYS_EQ, start_threads
import unittest
from unittest.test.testmock.support import is_instance
from unittest.test.testmock.support import is_instance, run_async
from unittest import mock
from unittest.mock import (
call, DEFAULT, patch, sentinel,
Expand Down Expand Up @@ -2059,6 +2060,105 @@ def trace(frame, event, arg): # pragma: no cover
obj = mock(spec=Something)
self.assertIsInstance(obj, Something)

def test_wait_until_called(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, delay=0.01)
mock.method_1.called.wait()
mock.method_1.assert_called_once()

def test_wait_until_called_called_before(self):
mock = Mock(spec=Something)()
mock.method_1()
mock.method_1.wait_until_called()
mock.method_1.assert_called_once()

def test_wait_until_called_magic_method(self):
mock = MagicMock(spec=Something)()
run_async(mock.method_1.__str__, delay=0.01)
mock.method_1.__str__.called.wait()
mock.method_1.__str__.assert_called_once()

def test_wait_until_called_timeout(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, delay=0.2)

with self.assertRaises(AssertionError):
mock.method_1.called.wait(timeout=0.1)

mock.method_1.assert_not_called()
mock.method_1.called.wait()
mock.method_1.assert_called_once()

def test_wait_until_any_call_positional(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, 1, delay=0.1)
run_async(mock.method_1, 2, delay=0.2)
run_async(mock.method_1, 3, delay=0.3)

for arg in (1, 2, 3):
self.assertNotIn(call(arg), mock.method_1.mock_calls)
mock.method_1.called.wait_for(lambda m: call(arg) in m.call_args_list)
mock.method_1.assert_called_with(arg)

def test_wait_until_any_call_keywords(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, a=1, delay=0.1)
run_async(mock.method_1, a=2, delay=0.2)
run_async(mock.method_1, a=3, delay=0.3)

for arg in (1, 2, 3):
self.assertNotIn(call(arg), mock.method_1.mock_calls)
mock.method_1.called.wait_for(lambda m: call(a=arg) in m.call_args_list)
mock.method_1.assert_called_with(a=arg)

def test_wait_until_any_call_no_argument(self):
mock = Mock(spec=Something)()
mock.method_1(1)
mock.method_1assert_called_once_with(1)

with self.assertRaises(AssertionError):
mock.method_1.called.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)

mock.method_1()
mock.method_1.called.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)

def test_called_is_boolean_like(self):
mock = Mock(spec=Something)()

self.assertFalse(mock.method_1.called)

self.assertEqual(mock.method_1.called, False)
self.assertEqual(mock.method_1.called, 0)
self.assertEqual(mock.method_1.called, 0.0)

self.assertLess(mock.method_1.called, 1)
self.assertLess(mock.method_1.called, 1.0)

self.assertLessEqual(mock.method_1.called, False)
self.assertLessEqual(mock.method_1.called, 0)
self.assertLessEqual(mock.method_1.called, 0.0)

self.assertEqual(str(mock.method_1.called), str(False))
self.assertEqual(repr(mock.method_1.called), repr(False))

mock.method_1()

self.assertTrue(mock.method_1.called)

self.assertEqual(mock.method_1.called, True)
self.assertEqual(mock.method_1.called, 1)
self.assertEqual(mock.method_1.called, 1.0)

self.assertGreater(mock.method_1.called, 0)
self.assertGreater(mock.method_1.called, 0.0)

self.assertGreaterEqual(mock.method_1.called, True)
self.assertGreaterEqual(mock.method_1.called, 1)
self.assertGreaterEqual(mock.method_1.called, 1.0)

self.assertEqual(str(mock.method_1.called), str(True))
self.assertEqual(repr(mock.method_1.called), repr(True))


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Extend :attr:`called` of :class:`Mock.called` to wait for the calls in
multithreaded tests. Patch by Ilya Kulakov.