Skip to content

Commit

Permalink
Fix crash on nested generic callable (#14093)
Browse files Browse the repository at this point in the history
Fixes #10244
Fixes #13515 

This fixes only the crash part, I am going to fix also the embarrassing
type variable clash in a separate PR, since it is completely unrelated
issue.

The crash happens because solver can call `is_suptype()` on the
constraint bounds, and those can contain `<Erased>`. Then if it is a
generic callable type (e.g. `def [S] (S) -> T` when used as a context is
erased to `def [S] (S) -> <Erased>`), `is_subtype()` will try unifying
them, causing the crash when applying unified arguments.

My fix is to simply allow subtyping between callable types that contain
`<Erased>`, we anyway allow checking subtpying between all other types
with `<Erased>` components. And this technically can be useful, e.g. `[T
<: DerivedGen1[<Erased>], T <: DerivedGen2[<Erased>]]` will be solved as
`T <: NonGenBase`.

Btw this crash technically has nothing to do with dataclasses, but it
looks like there is no other way in mypy to define a callable with
generic callable as argument type, if I try:
```python
def foo(x: Callable[[S], T]) -> T: ...
```
to repro the crash, mypy instead interprets `foo` as `def [S, T] (x:
Callable[[S], T]) -> T`, i.e. the argument type is not generic. I also
tried callback protocols, but they also don't repro the crash (at least
I can't find a repro), because protocols use variance for subtyping,
before actually checking member types.
  • Loading branch information
ilevkivskyi committed Nov 16, 2022
1 parent e01359d commit 7d0d1d9
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 13 deletions.
19 changes: 14 additions & 5 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def apply_generic_arguments(
report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
context: Context,
skip_unsatisfied: bool = False,
allow_erased_callables: bool = False,
) -> CallableType:
"""Apply generic type arguments to a callable type.
Expand Down Expand Up @@ -130,18 +131,26 @@ def apply_generic_arguments(
+ callable.arg_names[star_index + 1 :]
)
arg_types = (
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
[
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[:star_index]
]
+ expanded
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
+ [
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[star_index + 1 :]
]
)
else:
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
arg_types = [
expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types
]
arg_kinds = callable.arg_kinds
arg_names = callable.arg_names

# Apply arguments to TypeGuard if any.
if callable.type_guard is not None:
type_guard = expand_type(callable.type_guard, id_to_type)
type_guard = expand_type(callable.type_guard, id_to_type, allow_erased_callables)
else:
type_guard = None

Expand All @@ -150,7 +159,7 @@ def apply_generic_arguments(

return callable.copy_modified(
arg_types=arg_types,
ret_type=expand_type(callable.ret_type, id_to_type),
ret_type=expand_type(callable.ret_type, id_to_type, allow_erased_callables),
variables=remaining_tvars,
type_guard=type_guard,
arg_kinds=arg_kinds,
Expand Down
29 changes: 22 additions & 7 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,26 @@


@overload
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType:
def expand_type(
typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
) -> ProperType:
...


@overload
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
def expand_type(
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
) -> Type:
...


def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
def expand_type(
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
) -> Type:
"""Substitute any type variable references in a type given by a type
environment.
"""
return typ.accept(ExpandTypeVisitor(env))
return typ.accept(ExpandTypeVisitor(env, allow_erased_callables))


@overload
Expand Down Expand Up @@ -129,8 +135,11 @@ class ExpandTypeVisitor(TypeVisitor[Type]):

variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
def __init__(
self, variables: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
) -> None:
self.variables = variables
self.allow_erased_callables = allow_erased_callables

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
Expand All @@ -148,8 +157,14 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
return t

def visit_erased_type(self, t: ErasedType) -> Type:
# Should not get here.
raise RuntimeError()
if not self.allow_erased_callables:
raise RuntimeError()
# This may happen during type inference if some function argument
# type is a generic callable, and its erased form will appear in inferred
# constraints, then solver may check subtyping between them, which will trigger
# unify_generic_callables(), this is why we can get here. In all other cases it
# is a sign of a bug, since <Erased> should never appear in any stored types.
return t

def visit_instance(self, t: Instance) -> Type:
args = self.expand_types_with_unpack(list(t.args))
Expand Down
6 changes: 5 additions & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1667,8 +1667,12 @@ def report(*args: Any) -> None:
nonlocal had_errors
had_errors = True

# This function may be called by the solver, so we need to allow erased types here.
# We anyway allow checking subtyping between other types containing <Erased>
# (probably also because solver needs subtyping). See also comment in
# ExpandTypeVisitor.visit_erased_type().
applied = mypy.applytype.apply_generic_arguments(
type, non_none_inferred_vars, report, context=target
type, non_none_inferred_vars, report, context=target, allow_erased_callables=True
)
if had_errors:
return None
Expand Down
23 changes: 23 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -1958,3 +1958,26 @@ lst = SubLinkedList(1, LinkedList(2)) # E: Argument 2 to "SubLinkedList" has in
reveal_type(lst.next) # N: Revealed type is "Union[__main__.SubLinkedList, None]"
reveal_type(SubLinkedList) # N: Revealed type is "def (value: builtins.int, next: Union[__main__.SubLinkedList, None] =) -> __main__.SubLinkedList"
[builtins fixtures/dataclasses.pyi]

[case testNoCrashOnNestedGenericCallable]
from dataclasses import dataclass
from typing import Generic, TypeVar, Callable

T = TypeVar('T')
R = TypeVar('R')
X = TypeVar('X')

@dataclass
class Box(Generic[T]):
inner: T

@dataclass
class Cont(Generic[R]):
run: Box[Callable[[X], R]]

def const_two(x: T) -> str:
return "two"

c = Cont(Box(const_two))
reveal_type(c) # N: Revealed type is "__main__.Cont[builtins.str]"
[builtins fixtures/dataclasses.pyi]

0 comments on commit 7d0d1d9

Please sign in to comment.