Skip to content

Improve type checking of ParamSpec in calls #11603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from mypy.maptype import map_instance_to_supertype
from mypy.types import (
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, ParamSpecType, get_proper_type
)
from mypy import nodes

Expand Down Expand Up @@ -191,6 +191,9 @@ def expand_actual_type(self,
else:
self.tuple_index += 1
return actual_type.items[self.tuple_index - 1]
elif isinstance(actual_type, ParamSpecType):
# ParamSpec is valid in *args but it can't be unpacked.
return actual_type
else:
return AnyType(TypeOfAny.from_error)
elif actual_kind == nodes.ARG_STAR2:
Expand All @@ -215,6 +218,9 @@ def expand_actual_type(self,
actual_type,
self.context.mapping_type.type,
).args[1]
elif isinstance(actual_type, ParamSpecType):
# ParamSpec is valid in **kwargs but it can't be unpacked.
return actual_type
else:
return AnyType(TypeOfAny.from_error)
else:
Expand Down
26 changes: 13 additions & 13 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,24 +1032,24 @@ def check_callable_call(self,
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context)
if need_refresh:
# Argument kinds etc. may have changed; recalculate actual-to-formal map
# Argument kinds etc. may have changed due to
# ParamSpec variables being replaced with an arbitrary
# number of arguments; recalculate actual-to-formal map
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
callee.arg_kinds, callee.arg_names,
lambda i: self.accept(args[i]))

param_spec = callee.param_spec()
if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]:
arg1 = get_proper_type(self.accept(args[0]))
arg2 = get_proper_type(self.accept(args[1]))
if (is_named_instance(arg1, 'builtins.tuple')
and is_named_instance(arg2, 'builtins.dict')):
assert isinstance(arg1, Instance)
assert isinstance(arg2, Instance)
if (isinstance(arg1.args[0], ParamSpecType)
and isinstance(arg2.args[1], ParamSpecType)):
# TODO: Check ParamSpec ids and flavors
return callee.ret_type, callee
arg1 = self.accept(args[0])
arg2 = self.accept(args[1])
if (isinstance(arg1, ParamSpecType)
and isinstance(arg2, ParamSpecType)
and arg1.flavor == ParamSpecFlavor.ARGS
and arg2.flavor == ParamSpecFlavor.KWARGS
and arg1.id == arg2.id == param_spec.id):
return callee.ret_type, callee

arg_types = self.infer_arg_types_in_context(
callee, args, arg_kinds, formal_to_actual)
Expand Down Expand Up @@ -4003,7 +4003,7 @@ def is_valid_var_arg(self, typ: Type) -> bool:
is_subtype(typ, self.chk.named_generic_type('typing.Iterable',
[AnyType(TypeOfAny.special_form)])) or
isinstance(typ, AnyType) or
(isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.ARGS))
isinstance(typ, ParamSpecType))

def is_valid_keyword_var_arg(self, typ: Type) -> bool:
"""Is a type valid as a **kwargs argument?"""
Expand All @@ -4012,7 +4012,7 @@ def is_valid_keyword_var_arg(self, typ: Type) -> bool:
[self.named_type('builtins.str'), AnyType(TypeOfAny.special_form)])) or
is_subtype(typ, self.chk.named_generic_type('typing.Mapping',
[UninhabitedType(), UninhabitedType()])) or
(isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.KWARGS)
isinstance(typ, ParamSpecType)
)
if self.chk.options.python_version[0] < 3:
ret = ret or is_subtype(typ, self.chk.named_generic_type('typing.Mapping',
Expand Down
2 changes: 2 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
# Return copy of instance with type erasure flag on.
return Instance(inst.type, inst.args, line=inst.line,
column=inst.column, erased=True)
elif isinstance(repl, ParamSpecType):
return repl.with_flavor(t.flavor)
else:
return repl

Expand Down
4 changes: 4 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,10 @@ def new_unification_variable(old: 'ParamSpecType') -> 'ParamSpecType':
return ParamSpecType(old.name, old.fullname, new_id, old.flavor, old.upper_bound,
line=old.line, column=old.column)

def with_flavor(self, flavor: int) -> 'ParamSpecType':
return ParamSpecType(self.name, self.fullname, self.id, flavor,
upper_bound=self.upper_bound)

def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_param_spec(self)

Expand Down
32 changes: 32 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,35 @@ reveal_type(register(lambda x: f(x), x=1)) # N: Revealed type is "def (x: Any)"
register(lambda x: f(x)) # E: Missing positional argument "x" in call to "register"
register(lambda x: f(x), y=1) # E: Unexpected keyword argument "y" for "register"
[builtins fixtures/dict.pyi]

[case testParamSpecInvalidCalls]
from typing import Callable, Generic
from typing_extensions import ParamSpec

P = ParamSpec('P')
P2 = ParamSpec('P2')

class C(Generic[P, P2]):
def m1(self, *args: P.args, **kwargs: P.kwargs) -> None:
self.m1(*args, **kwargs)
self.m2(*args, **kwargs) # E: Argument 1 to "m2" of "C" has incompatible type "*P.args"; expected "P2.args" \
# E: Argument 2 to "m2" of "C" has incompatible type "**P.kwargs"; expected "P2.kwargs"
self.m1(*kwargs, **args) # E: Argument 1 to "m1" of "C" has incompatible type "*P.kwargs"; expected "P.args" \
# E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs"
self.m3(*args, **kwargs) # E: Argument 1 to "m3" of "C" has incompatible type "*P.args"; expected "int" \
# E: Argument 2 to "m3" of "C" has incompatible type "**P.kwargs"; expected "int"
self.m4(*args, **kwargs) # E: Argument 1 to "m4" of "C" has incompatible type "*P.args"; expected "int" \
# E: Argument 2 to "m4" of "C" has incompatible type "**P.kwargs"; expected "int"

self.m1(*args, **args) # E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs"
self.m1(*kwargs, **kwargs) # E: Argument 1 to "m1" of "C" has incompatible type "*P.kwargs"; expected "P.args"

def m2(self, *args: P2.args, **kwargs: P2.kwargs) -> None:
pass

def m3(self, *args: int, **kwargs: int) -> None:
pass

def m4(self, x: int) -> None:
pass
[builtins fixtures/dict.pyi]