From ea5e8b981d1c3bafd5896f977b83f3d37c60ecd3 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 15 Dec 2021 17:21:22 +0000 Subject: [PATCH] Fix inferring constrainta from param spec callable against Any (#11725) We didn't infer constraints from the return type. Fixes #11704. --- mypy/constraints.py | 10 +++---- .../unit/check-parameter-specification.test | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 9c1bfb0eba53..5a78cdb94e93 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -581,12 +581,12 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: if param_spec is None: # FIX what if generic res = self.infer_against_any(template.arg_types, self.actual) - res.extend(infer_constraints(template.ret_type, any_type, self.direction)) - return res else: - return [Constraint(param_spec.id, - SUBTYPE_OF, - callable_with_ellipsis(any_type, any_type, template.fallback))] + res = [Constraint(param_spec.id, + SUBTYPE_OF, + callable_with_ellipsis(any_type, any_type, template.fallback))] + res.extend(infer_constraints(template.ret_type, any_type, self.direction)) + return res elif isinstance(self.actual, Overloaded): return self.infer_against_overloaded(self.actual, template) elif isinstance(self.actual, TypeType): diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 6beba0772d36..f6123915aada 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -378,3 +378,31 @@ class C(Generic[P, P2]): def m4(self, x: int) -> None: pass [builtins fixtures/dict.pyi] + +[case testParamSpecOverUnannotatedDecorator] +from typing import Callable, Iterator, TypeVar, ContextManager, Any +from typing_extensions import ParamSpec + +from nonexistent import deco2 # type: ignore + +T = TypeVar("T") +P = ParamSpec("P") +T_co = TypeVar("T_co", covariant=True) + +class CM(ContextManager[T_co]): + def __call__(self, func: T) -> T: ... + +def deco1( + func: Callable[P, Iterator[T]]) -> Callable[P, CM[T]]: ... + +@deco1 +@deco2 +def f(): + pass + +reveal_type(f) # N: Revealed type is "def (*Any, **Any) -> __main__.CM[Any]" + +with f() as x: + pass +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi]