Skip to content

Commit ed20f38

Browse files
committed
Improve ParamSpec type inference from lambda
If ParamSpec is in the context of a lambda, treat it similar to `Callable[..., Any]`. This allows us to infer at least argument counts and kinds. Types can't be inferred since that would require "backwards" type inference, which we don't support. Follow-up to #11594.
1 parent 4996d57 commit ed20f38

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

mypy/checkexpr.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3608,18 +3608,19 @@ def infer_lambda_type_using_context(self, e: LambdaExpr) -> Tuple[Optional[Calla
36083608

36093609
arg_kinds = [arg.kind for arg in e.arguments]
36103610

3611-
if callable_ctx.is_ellipsis_args:
3611+
if callable_ctx.is_ellipsis_args or ctx.param_spec() is not None:
36123612
# Fill in Any arguments to match the arguments of the lambda.
36133613
callable_ctx = callable_ctx.copy_modified(
36143614
is_ellipsis_args=False,
36153615
arg_types=[AnyType(TypeOfAny.special_form)] * len(arg_kinds),
36163616
arg_kinds=arg_kinds,
3617-
arg_names=[None] * len(arg_kinds)
3617+
arg_names=e.arg_names[:],
36183618
)
36193619

36203620
if ARG_STAR in arg_kinds or ARG_STAR2 in arg_kinds:
36213621
# TODO treat this case appropriately
36223622
return callable_ctx, None
3623+
36233624
if callable_ctx.arg_kinds != arg_kinds:
36243625
# Incompatible context; cannot use it to infer types.
36253626
self.chk.fail(message_registry.CANNOT_INFER_LAMBDA_TYPE, e)

test-data/unit/check-parameter-specification.test

+18
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,21 @@ c2 = c
328328
c3 = C(f)
329329
c = c3
330330
[builtins fixtures/dict.pyi]
331+
332+
[case testParamSpecInferredFromLambda]
333+
from typing import Callable, TypeVar
334+
from typing_extensions import ParamSpec
335+
336+
P = ParamSpec('P')
337+
T = TypeVar('T')
338+
339+
# Similar to atexit.register
340+
def register(f: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Callable[P, T]: ... # N: "register" defined here
341+
342+
def f(x: int) -> None: pass
343+
344+
reveal_type(register(lambda: f(1))) # N: Revealed type is "def ()"
345+
reveal_type(register(lambda x: f(x), x=1)) # N: Revealed type is "def (x: Any)"
346+
register(lambda x: f(x)) # E: Missing positional argument "x" in call to "register"
347+
register(lambda x: f(x), y=1) # E: Unexpected keyword argument "y" for "register"
348+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)