Skip to content

Commit f41e24c

Browse files
authored
Lenient handling of trivial Callable suffixes (#15913)
Fixes #15734 Fixes #15188 Fixes #14321 Fixes #13107 (plain Callable was already working, this fixes the protocol example) Fixes #16058 It looks like treating trivial suffixes (especially for erased callables) as "whatever works" is a right thing, because it reflects the whole idea of why we normally check subtyping with respect to an e.g. erased type. As you can see this fixes a bunch of issues. Note it was necessary to make couple more tweaks to make everything work smoothly: * Adjust self-type erasure level in `checker.py` to match other places. * Explicitly allow `Callable` as a `self`/`cls` annotation (actually I am not sure we need to keep this check at all, since we now have good inference for self-types, and we check they are safe either at definition site or at call site).
1 parent b327557 commit f41e24c

9 files changed

+204
-11
lines changed

mypy/checker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,9 @@ def check_func_def(
12081208
):
12091209
if defn.is_class or defn.name == "__new__":
12101210
ref_type = mypy.types.TypeType.make_normalized(ref_type)
1211-
erased = get_proper_type(erase_to_bound(arg_type))
1211+
# This level of erasure matches the one in checkmember.check_self_arg(),
1212+
# better keep these two checks consistent.
1213+
erased = get_proper_type(erase_typevars(erase_to_bound(arg_type)))
12121214
if not is_subtype(ref_type, erased, ignore_type_params=True):
12131215
if (
12141216
isinstance(erased, Instance)

mypy/checkmember.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,8 @@ def f(self: S) -> T: ...
896896
return functype
897897
else:
898898
selfarg = get_proper_type(item.arg_types[0])
899+
# This level of erasure matches the one in checker.check_func_def(),
900+
# better keep these two checks consistent.
899901
if subtypes.is_subtype(dispatched_arg_type, erase_typevars(erase_to_bound(selfarg))):
900902
new_items.append(item)
901903
elif isinstance(selfarg, ParamSpecType):

mypy/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,6 +2132,9 @@ def report_protocol_problems(
21322132
not is_subtype(subtype, erase_type(supertype), options=self.options)
21332133
or not subtype.type.defn.type_vars
21342134
or not supertype.type.defn.type_vars
2135+
# Always show detailed message for ParamSpec
2136+
or subtype.type.has_param_spec_type
2137+
or supertype.type.has_param_spec_type
21352138
):
21362139
type_name = format_type(subtype, self.options, module_names=True)
21372140
self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code)

mypy/subtypes.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,18 @@ def are_trivial_parameters(param: Parameters | NormalizedCallableType) -> bool:
15191519
)
15201520

15211521

1522+
def is_trivial_suffix(param: Parameters | NormalizedCallableType) -> bool:
1523+
param_star = param.var_arg()
1524+
param_star2 = param.kw_arg()
1525+
return (
1526+
param.arg_kinds[-2:] == [ARG_STAR, ARG_STAR2]
1527+
and param_star is not None
1528+
and isinstance(get_proper_type(param_star.typ), AnyType)
1529+
and param_star2 is not None
1530+
and isinstance(get_proper_type(param_star2.typ), AnyType)
1531+
)
1532+
1533+
15221534
def are_parameters_compatible(
15231535
left: Parameters | NormalizedCallableType,
15241536
right: Parameters | NormalizedCallableType,
@@ -1540,6 +1552,7 @@ def are_parameters_compatible(
15401552
# Treat "def _(*a: Any, **kw: Any) -> X" similarly to "Callable[..., X]"
15411553
if are_trivial_parameters(right):
15421554
return True
1555+
trivial_suffix = is_trivial_suffix(right)
15431556

15441557
# Match up corresponding arguments and check them for compatibility. In
15451558
# every pair (argL, argR) of corresponding arguments from L and R, argL must
@@ -1570,7 +1583,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
15701583
if right_arg is None:
15711584
return False
15721585
if left_arg is None:
1573-
return not allow_partial_overlap
1586+
return not allow_partial_overlap and not trivial_suffix
15741587
return not is_compat(right_arg.typ, left_arg.typ)
15751588

15761589
if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
@@ -1594,7 +1607,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
15941607
# arguments. Get all further positional args of left, and make sure
15951608
# they're more general than the corresponding member in right.
15961609
# TODO: are we handling UnpackType correctly here?
1597-
if right_star is not None:
1610+
if right_star is not None and not trivial_suffix:
15981611
# Synthesize an anonymous formal argument for the right
15991612
right_by_position = right.try_synthesizing_arg_from_vararg(None)
16001613
assert right_by_position is not None
@@ -1621,7 +1634,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
16211634
# Phase 1d: Check kw args. Right has an infinite series of optional named
16221635
# arguments. Get all further named args of left, and make sure
16231636
# they're more general than the corresponding member in right.
1624-
if right_star2 is not None:
1637+
if right_star2 is not None and not trivial_suffix:
16251638
right_names = {name for name in right.arg_names if name is not None}
16261639
left_only_names = set()
16271640
for name, kind in zip(left.arg_names, left.arg_kinds):

mypy/typeops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ def supported_self_type(typ: ProperType) -> bool:
251251
"""
252252
if isinstance(typ, TypeType):
253253
return supported_self_type(typ.item)
254+
if isinstance(typ, CallableType):
255+
# Special case: allow class callable instead of Type[...] as cls annotation,
256+
# as well as callable self for callback protocols.
257+
return True
254258
return isinstance(typ, TypeVarType) or (
255259
isinstance(typ, Instance) and typ != fill_typevars(typ.type)
256260
)

test-data/unit/check-callable.test

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,34 @@ a: A
598598
a() # E: Missing positional argument "other" in call to "__call__" of "A"
599599
a(a)
600600
a(lambda: None)
601+
602+
[case testCallableSubtypingTrivialSuffix]
603+
from typing import Any, Protocol
604+
605+
class Call(Protocol):
606+
def __call__(self, x: int, *args: Any, **kwargs: Any) -> None: ...
607+
608+
def f1() -> None: ...
609+
a1: Call = f1 # E: Incompatible types in assignment (expression has type "Callable[[], None]", variable has type "Call") \
610+
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
611+
def f2(x: str) -> None: ...
612+
a2: Call = f2 # E: Incompatible types in assignment (expression has type "Callable[[str], None]", variable has type "Call") \
613+
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
614+
def f3(y: int) -> None: ...
615+
a3: Call = f3 # E: Incompatible types in assignment (expression has type "Callable[[int], None]", variable has type "Call") \
616+
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
617+
def f4(x: int) -> None: ...
618+
a4: Call = f4
619+
620+
def f5(x: int, y: int) -> None: ...
621+
a5: Call = f5
622+
623+
def f6(x: int, y: int = 0) -> None: ...
624+
a6: Call = f6
625+
626+
def f7(x: int, *, y: int) -> None: ...
627+
a7: Call = f7
628+
629+
def f8(x: int, *args: int, **kwargs: str) -> None: ...
630+
a8: Call = f8
631+
[builtins fixtures/tuple.pyi]

test-data/unit/check-modules.test

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3193,29 +3193,29 @@ from test1 import aaaa # E: Module "test1" has no attribute "aaaa"
31933193
import b
31943194
[file a.py]
31953195
class Foo:
3196-
def frobnicate(self, x, *args, **kwargs): pass
3196+
def frobnicate(self, x: str, *args, **kwargs): pass
31973197
[file b.py]
31983198
from a import Foo
31993199
class Bar(Foo):
32003200
def frobnicate(self) -> None: pass
32013201
[file b.py.2]
32023202
from a import Foo
32033203
class Bar(Foo):
3204-
def frobnicate(self, *args) -> None: pass
3204+
def frobnicate(self, *args: int) -> None: pass
32053205
[file b.py.3]
32063206
from a import Foo
32073207
class Bar(Foo):
3208-
def frobnicate(self, *args) -> None: pass # type: ignore[override] # I know
3208+
def frobnicate(self, *args: int) -> None: pass # type: ignore[override] # I know
32093209
[builtins fixtures/dict.pyi]
32103210
[out1]
32113211
tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo"
32123212
tmp/b.py:3: note: Superclass:
3213-
tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any
3213+
tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any
32143214
tmp/b.py:3: note: Subclass:
32153215
tmp/b.py:3: note: def frobnicate(self) -> None
32163216
[out2]
32173217
tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo"
32183218
tmp/b.py:3: note: Superclass:
3219-
tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any
3219+
tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any
32203220
tmp/b.py:3: note: Subclass:
3221-
tmp/b.py:3: note: def frobnicate(self, *args: Any) -> None
3221+
tmp/b.py:3: note: def frobnicate(self, *args: int) -> None

test-data/unit/check-parameter-specification.test

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1729,7 +1729,12 @@ class A(Protocol[P]):
17291729
...
17301730

17311731
def bar(b: A[P]) -> A[Concatenate[int, P]]:
1732-
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
1732+
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") \
1733+
# N: Following member(s) of "A[P]" have conflicts: \
1734+
# N: Expected: \
1735+
# N: def foo(self, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \
1736+
# N: Got: \
1737+
# N: def foo(self, *args: P.args, **kwargs: P.kwargs) -> Any
17331738
[builtins fixtures/paramspec.pyi]
17341739

17351740
[case testParamSpecPrefixSubtypingValidNonStrict]
@@ -1825,6 +1830,138 @@ c: C[int, [int, str], str] # E: Nested parameter specifications are not allowed
18251830
reveal_type(c) # N: Revealed type is "__main__.C[Any]"
18261831
[builtins fixtures/paramspec.pyi]
18271832

1833+
[case testParamSpecConcatenateSelfType]
1834+
from typing import Callable
1835+
from typing_extensions import ParamSpec, Concatenate
1836+
1837+
P = ParamSpec("P")
1838+
class A:
1839+
def __init__(self, a_param_1: str) -> None: ...
1840+
1841+
@classmethod
1842+
def add_params(cls: Callable[P, A]) -> Callable[Concatenate[float, P], A]:
1843+
def new_constructor(i: float, *args: P.args, **kwargs: P.kwargs) -> A:
1844+
return cls(*args, **kwargs)
1845+
return new_constructor
1846+
1847+
@classmethod
1848+
def remove_params(cls: Callable[Concatenate[str, P], A]) -> Callable[P, A]:
1849+
def new_constructor(*args: P.args, **kwargs: P.kwargs) -> A:
1850+
return cls("my_special_str", *args, **kwargs)
1851+
return new_constructor
1852+
1853+
reveal_type(A.add_params()) # N: Revealed type is "def (builtins.float, a_param_1: builtins.str) -> __main__.A"
1854+
reveal_type(A.remove_params()) # N: Revealed type is "def () -> __main__.A"
1855+
[builtins fixtures/paramspec.pyi]
1856+
1857+
[case testParamSpecConcatenateCallbackProtocol]
1858+
from typing import Protocol, TypeVar
1859+
from typing_extensions import ParamSpec, Concatenate
1860+
1861+
P = ParamSpec("P")
1862+
R = TypeVar("R", covariant=True)
1863+
1864+
class Path: ...
1865+
1866+
class Function(Protocol[P, R]):
1867+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
1868+
1869+
def file_cache(fn: Function[Concatenate[Path, P], R]) -> Function[P, R]:
1870+
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
1871+
return fn(Path(), *args, **kw)
1872+
return wrapper
1873+
1874+
@file_cache
1875+
def get_thing(path: Path, *, some_arg: int) -> int: ...
1876+
reveal_type(get_thing) # N: Revealed type is "__main__.Function[[*, some_arg: builtins.int], builtins.int]"
1877+
get_thing(some_arg=1) # OK
1878+
[builtins fixtures/paramspec.pyi]
1879+
1880+
[case testParamSpecConcatenateKeywordOnly]
1881+
from typing import Callable, TypeVar
1882+
from typing_extensions import ParamSpec, Concatenate
1883+
1884+
P = ParamSpec("P")
1885+
R = TypeVar("R")
1886+
1887+
class Path: ...
1888+
1889+
def file_cache(fn: Callable[Concatenate[Path, P], R]) -> Callable[P, R]:
1890+
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
1891+
return fn(Path(), *args, **kw)
1892+
return wrapper
1893+
1894+
@file_cache
1895+
def get_thing(path: Path, *, some_arg: int) -> int: ...
1896+
reveal_type(get_thing) # N: Revealed type is "def (*, some_arg: builtins.int) -> builtins.int"
1897+
get_thing(some_arg=1) # OK
1898+
[builtins fixtures/paramspec.pyi]
1899+
1900+
[case testParamSpecConcatenateCallbackApply]
1901+
from typing import Callable, Protocol
1902+
from typing_extensions import ParamSpec, Concatenate
1903+
1904+
P = ParamSpec("P")
1905+
1906+
class FuncType(Protocol[P]):
1907+
def __call__(self, x: int, s: str, *args: P.args, **kw_args: P.kwargs) -> str: ...
1908+
1909+
def forwarder1(fp: FuncType[P], *args: P.args, **kw_args: P.kwargs) -> str:
1910+
return fp(0, '', *args, **kw_args)
1911+
1912+
def forwarder2(fp: Callable[Concatenate[int, str, P], str], *args: P.args, **kw_args: P.kwargs) -> str:
1913+
return fp(0, '', *args, **kw_args)
1914+
1915+
def my_f(x: int, s: str, d: bool) -> str: ...
1916+
forwarder1(my_f, True) # OK
1917+
forwarder2(my_f, True) # OK
1918+
forwarder1(my_f, 1.0) # E: Argument 2 to "forwarder1" has incompatible type "float"; expected "bool"
1919+
forwarder2(my_f, 1.0) # E: Argument 2 to "forwarder2" has incompatible type "float"; expected "bool"
1920+
[builtins fixtures/paramspec.pyi]
1921+
1922+
[case testParamSpecCallbackProtocolSelf]
1923+
from typing import Callable, Protocol, TypeVar
1924+
from typing_extensions import ParamSpec, Concatenate
1925+
1926+
Params = ParamSpec("Params")
1927+
Result = TypeVar("Result", covariant=True)
1928+
1929+
class FancyMethod(Protocol):
1930+
def __call__(self, arg1: int, arg2: str) -> bool: ...
1931+
def return_me(self: Callable[Params, Result]) -> Callable[Params, Result]: ...
1932+
def return_part(self: Callable[Concatenate[int, Params], Result]) -> Callable[Params, Result]: ...
1933+
1934+
m: FancyMethod
1935+
reveal_type(m.return_me()) # N: Revealed type is "def (arg1: builtins.int, arg2: builtins.str) -> builtins.bool"
1936+
reveal_type(m.return_part()) # N: Revealed type is "def (arg2: builtins.str) -> builtins.bool"
1937+
[builtins fixtures/paramspec.pyi]
1938+
1939+
[case testParamSpecInferenceCallableAgainstAny]
1940+
from typing import Callable, TypeVar, Any
1941+
from typing_extensions import ParamSpec, Concatenate
1942+
1943+
_P = ParamSpec("_P")
1944+
_R = TypeVar("_R")
1945+
1946+
class A: ...
1947+
a = A()
1948+
1949+
def a_func(
1950+
func: Callable[Concatenate[A, _P], _R],
1951+
) -> Callable[Concatenate[Any, _P], _R]:
1952+
def wrapper(__a: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
1953+
return func(a, *args, **kwargs)
1954+
return wrapper
1955+
1956+
def test(a, *args): ...
1957+
x: Any
1958+
y: object
1959+
1960+
a_func(test)
1961+
x = a_func(test)
1962+
y = a_func(test)
1963+
[builtins fixtures/paramspec.pyi]
1964+
18281965
[case testParamSpecInferenceWithCallbackProtocol]
18291966
from typing import Protocol, Callable, ParamSpec
18301967

test-data/unit/fixtures/paramspec.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class object:
1616

1717
class function: ...
1818
class ellipsis: ...
19+
class classmethod: ...
1920

2021
class type:
2122
def __init__(self, *a: object) -> None: ...

0 commit comments

Comments
 (0)