@@ -411,15 +411,18 @@ class NonCallableMock(Base):
411
411
# necessary.
412
412
_lock = RLock ()
413
413
414
- def __new__ (cls , / , * args , ** kw ):
414
+ def __new__ (
415
+ cls , spec = None , wraps = None , name = None , spec_set = None ,
416
+ parent = None , _spec_state = None , _new_name = '' , _new_parent = None ,
417
+ _spec_as_instance = False , _eat_self = None , unsafe = False , ** kwargs
418
+ ):
415
419
# every instance has its own class
416
420
# so we can create magic methods on the
417
421
# class without stomping on other mocks
418
422
bases = (cls ,)
419
423
if not issubclass (cls , AsyncMockMixin ):
420
424
# Check if spec is an async object or function
421
- bound_args = _MOCK_SIG .bind_partial (cls , * args , ** kw ).arguments
422
- spec_arg = bound_args .get ('spec_set' , bound_args .get ('spec' ))
425
+ spec_arg = spec_set or spec
423
426
if spec_arg is not None and _is_async_obj (spec_arg ):
424
427
bases = (AsyncMockMixin , cls )
425
428
new = type (cls .__name__ , bases , {'__doc__' : cls .__doc__ })
@@ -503,11 +506,6 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,
503
506
504
507
_spec_class = None
505
508
_spec_signature = None
506
- _spec_asyncs = []
507
-
508
- for attr in dir (spec ):
509
- if iscoroutinefunction (getattr (spec , attr , None )):
510
- _spec_asyncs .append (attr )
511
509
512
510
if spec is not None and not _is_list (spec ):
513
511
if isinstance (spec , type ):
@@ -525,7 +523,6 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,
525
523
__dict__ ['_spec_set' ] = spec_set
526
524
__dict__ ['_spec_signature' ] = _spec_signature
527
525
__dict__ ['_mock_methods' ] = spec
528
- __dict__ ['_spec_asyncs' ] = _spec_asyncs
529
526
530
527
def __get_return_value (self ):
531
528
ret = self ._mock_return_value
@@ -1015,7 +1012,8 @@ def _get_child_mock(self, /, **kw):
1015
1012
For non-callable mocks the callable variant will be used (rather than
1016
1013
any custom subclass)."""
1017
1014
_new_name = kw .get ("_new_name" )
1018
- if _new_name in self .__dict__ ['_spec_asyncs' ]:
1015
+ _spec_val = getattr (self .__dict__ ["_spec_class" ], _new_name , None )
1016
+ if _spec_val is not None and asyncio .iscoroutinefunction (_spec_val ):
1019
1017
return AsyncMock (** kw )
1020
1018
1021
1019
if self ._mock_sealed :
@@ -1057,9 +1055,6 @@ def _calls_repr(self, prefix="Calls"):
1057
1055
return f"\n { prefix } : { safe_repr (self .mock_calls )} ."
1058
1056
1059
1057
1060
- _MOCK_SIG = inspect .signature (NonCallableMock .__init__ )
1061
-
1062
-
1063
1058
class _AnyComparer (list ):
1064
1059
"""A list which checks if it contains a call which may have an
1065
1060
argument of ANY, flipping the components of item and self from
@@ -2183,6 +2178,10 @@ def __get__(self, obj, _type=None):
2183
2178
return self .create_mock ()
2184
2179
2185
2180
2181
+ _CODE_ATTRS = dir (CodeType )
2182
+ _CODE_SIG = inspect .signature (partial (CodeType .__init__ , None ))
2183
+
2184
+
2186
2185
class AsyncMockMixin (Base ):
2187
2186
await_count = _delegating_property ('await_count' )
2188
2187
await_args = _delegating_property ('await_args' )
@@ -2200,7 +2199,9 @@ def __init__(self, /, *args, **kwargs):
2200
2199
self .__dict__ ['_mock_await_count' ] = 0
2201
2200
self .__dict__ ['_mock_await_args' ] = None
2202
2201
self .__dict__ ['_mock_await_args_list' ] = _CallList ()
2203
- code_mock = NonCallableMock (spec_set = CodeType )
2202
+ code_mock = NonCallableMock (spec_set = _CODE_ATTRS )
2203
+ code_mock .__dict__ ["_spec_class" ] = CodeType
2204
+ code_mock .__dict__ ["_spec_signature" ] = _CODE_SIG
2204
2205
code_mock .co_flags = inspect .CO_COROUTINE
2205
2206
self .__dict__ ['__code__' ] = code_mock
2206
2207
self .__dict__ ['__name__' ] = 'AsyncMock'
0 commit comments