diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 73fd62a..222182e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,12 @@ Note however that this is just a stop gap measure, and new code should use ``MockerFixture`` for type annotations. +* Improved typing for ``MockerFixture.patch`` (`#201`_). Thanks `@srittau`_ for the PR. + +.. _@srittau: https://github.com/srittau +.. _#201: https://github.com/pytest-dev/pytest-mock/pull/201 + + 3.3.0 (2020-08-21) ------------------ diff --git a/src/pytest_mock/plugin.py b/src/pytest_mock/plugin.py index 3ebaefb..0573a4b 100644 --- a/src/pytest_mock/plugin.py +++ b/src/pytest_mock/plugin.py @@ -1,6 +1,6 @@ import builtins import unittest.mock -from typing import cast, Generator, Mapping, Iterable, Tuple +from typing import cast, overload, Generator, Mapping, Iterable, Tuple, TypeVar from typing import Any from typing import Callable from typing import Dict @@ -17,6 +17,8 @@ import pytest +_T = TypeVar("_T") + def _get_mock_module(config): """ @@ -50,7 +52,9 @@ def __init__(self, config: Any) -> None: self._patches = [] # type: List[Any] self._mocks = [] # type: List[Any] self.mock_module = mock_module = _get_mock_module(config) - self.patch = self._Patcher(self._patches, self._mocks, mock_module) + self.patch = self._Patcher( + self._patches, self._mocks, mock_module + ) # type: MockerFixture._Patcher # aliases for convenience self.Mock = mock_module.Mock self.MagicMock = mock_module.MagicMock @@ -254,6 +258,63 @@ def dict( **kwargs ) + @overload + def __call__( + self, + target: str, + new: None = ..., + spec: Optional[builtins.object] = ..., + create: bool = ..., + spec_set: Optional[builtins.object] = ..., + autospec: Optional[builtins.object] = ..., + new_callable: None = ..., + **kwargs: Any + ) -> unittest.mock.MagicMock: + ... + + @overload + def __call__( + self, + target: str, + new: _T, + spec: Optional[builtins.object] = ..., + create: bool = ..., + spec_set: Optional[builtins.object] = ..., + autospec: Optional[builtins.object] = ..., + new_callable: None = ..., + **kwargs: Any + ) -> _T: + ... + + @overload + def __call__( + self, + target: str, + new: None, + spec: Optional[builtins.object], + create: bool, + spec_set: Optional[builtins.object], + autospec: Optional[builtins.object], + new_callable: Callable[[], _T], + **kwargs: Any + ) -> _T: + ... + + @overload + def __call__( + self, + target: str, + new: None = ..., + spec: Optional[builtins.object] = ..., + create: bool = ..., + spec_set: Optional[builtins.object] = ..., + autospec: Optional[builtins.object] = ..., + *, + new_callable: Callable[[], _T], + **kwargs: Any + ) -> _T: + ... + def __call__( self, target: str, @@ -262,9 +323,9 @@ def __call__( create: bool = False, spec_set: Optional[builtins.object] = None, autospec: Optional[builtins.object] = None, - new_callable: Optional[builtins.object] = None, + new_callable: Optional[Callable[[], Any]] = None, **kwargs: Any - ) -> unittest.mock.MagicMock: + ) -> Any: """API to mock.patch""" if new is self.DEFAULT: new = self.mock_module.DEFAULT