Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move apply_type() to applytype.py #17346

Merged
merged 1 commit into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 133 additions & 2 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
from __future__ import annotations

from typing import Callable, Sequence
from typing import Callable, Iterable, Sequence

import mypy.subtypes
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type
from mypy.nodes import Context
from mypy.nodes import Context, TypeInfo
from mypy.type_visitor import TypeTranslator
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
ProperType,
Type,
TypeAliasType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
UnpackType,
get_proper_type,
remove_dups,
)


Expand Down Expand Up @@ -170,3 +178,126 @@ def apply_generic_arguments(
type_guard=type_guard,
type_is=type_is,
)


def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
"""Make free type variables generic in the type if possible.

This will translate the type `tp` while trying to create valid bindings for
type variables `poly_tvars` while traversing the type. This follows the same rules
as we do during semantic analysis phase, examples:
* Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
* Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
* List[T] -> None (not possible)
"""
try:
return tp.copy_modified(
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
variables=[],
)
except PolyTranslationError:
return None


class PolyTranslationError(Exception):
pass


class PolyTranslator(TypeTranslator):
"""Make free type variables generic in the type if possible.

See docstring for apply_poly() for details.
"""

def __init__(
self,
poly_tvars: Iterable[TypeVarLikeType],
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
seen_aliases: frozenset[TypeInfo] = frozenset(),
) -> None:
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars = bound_tvars
self.seen_aliases = seen_aliases

def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
found_vars = []
for arg in t.arg_types:
for tv in get_all_type_vars(arg):
if isinstance(tv, ParamSpecType):
normalized: TypeVarLikeType = tv.copy_modified(
flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
)
else:
normalized = tv
if normalized in self.poly_tvars and normalized not in self.bound_tvars:
found_vars.append(normalized)
return remove_dups(found_vars)

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = self.collect_vars(t)
self.bound_tvars |= set(found_vars)
result = super().visit_callable_type(t)
self.bound_tvars -= set(found_vars)

assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.variables = list(result.variables) + found_vars
return result

def visit_type_var(self, t: TypeVarType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var_tuple(t)

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
if not t.args:
return t.copy_modified()
if not t.is_recursive:
return get_proper_type(t).accept(self)
# We can't handle polymorphic application for recursive generic aliases
# without risking an infinite recursion, just give up for now.
raise PolyTranslationError()

def visit_instance(self, t: Instance) -> Type:
if t.type.has_param_spec_type:
# We need this special-casing to preserve the possibility to store a
# generic function in an instance type. Things like
# forall T . Foo[[x: T], T]
# are not really expressible in current type system, but this looks like
# a useful feature, so let's keep it.
param_spec_index = next(
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
)
p = get_proper_type(t.args[param_spec_index])
if isinstance(p, Parameters):
found_vars = self.collect_vars(p)
self.bound_tvars |= set(found_vars)
new_args = [a.accept(self) for a in t.args]
self.bound_tvars -= set(found_vars)

repl = new_args[param_spec_index]
assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
repl.variables = list(repl.variables) + list(found_vars)
return t.copy_modified(args=new_args)
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
if t.type in self.seen_aliases:
raise PolyTranslationError()
call = mypy.subtypes.find_member("__call__", t, t, is_operator=True)
assert call is not None
return call.accept(
PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
)
return super().visit_instance(t)
128 changes: 1 addition & 127 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
non_method_protocol_members,
)
from mypy.traverser import has_await_expression
from mypy.type_visitor import TypeTranslator
from mypy.typeanal import (
check_for_explicit_any,
fix_instance,
Expand Down Expand Up @@ -168,7 +167,6 @@
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
Expand All @@ -182,7 +180,6 @@
get_proper_types,
has_recursive_types,
is_named_instance,
remove_dups,
split_with_prefix_and_suffix,
)
from mypy.types_utils import (
Expand Down Expand Up @@ -2136,7 +2133,7 @@ def infer_function_type_arguments(
)
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
applied = apply_poly(poly_callee_type, free_vars)
applied = applytype.apply_poly(poly_callee_type, free_vars)
if applied is not None and all(
a is not None and not isinstance(get_proper_type(a), UninhabitedType)
for a in poly_inferred_args
Expand Down Expand Up @@ -6220,129 +6217,6 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_type)


def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
"""Make free type variables generic in the type if possible.

This will translate the type `tp` while trying to create valid bindings for
type variables `poly_tvars` while traversing the type. This follows the same rules
as we do during semantic analysis phase, examples:
* Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
* Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
* List[T] -> None (not possible)
"""
try:
return tp.copy_modified(
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
variables=[],
)
except PolyTranslationError:
return None


class PolyTranslationError(Exception):
pass


class PolyTranslator(TypeTranslator):
"""Make free type variables generic in the type if possible.

See docstring for apply_poly() for details.
"""

def __init__(
self,
poly_tvars: Iterable[TypeVarLikeType],
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
seen_aliases: frozenset[TypeInfo] = frozenset(),
) -> None:
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars = bound_tvars
self.seen_aliases = seen_aliases

def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
found_vars = []
for arg in t.arg_types:
for tv in get_all_type_vars(arg):
if isinstance(tv, ParamSpecType):
normalized: TypeVarLikeType = tv.copy_modified(
flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
)
else:
normalized = tv
if normalized in self.poly_tvars and normalized not in self.bound_tvars:
found_vars.append(normalized)
return remove_dups(found_vars)

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = self.collect_vars(t)
self.bound_tvars |= set(found_vars)
result = super().visit_callable_type(t)
self.bound_tvars -= set(found_vars)

assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.variables = list(result.variables) + found_vars
return result

def visit_type_var(self, t: TypeVarType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var_tuple(t)

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
if not t.args:
return t.copy_modified()
if not t.is_recursive:
return get_proper_type(t).accept(self)
# We can't handle polymorphic application for recursive generic aliases
# without risking an infinite recursion, just give up for now.
raise PolyTranslationError()

def visit_instance(self, t: Instance) -> Type:
if t.type.has_param_spec_type:
# We need this special-casing to preserve the possibility to store a
# generic function in an instance type. Things like
# forall T . Foo[[x: T], T]
# are not really expressible in current type system, but this looks like
# a useful feature, so let's keep it.
param_spec_index = next(
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
)
p = get_proper_type(t.args[param_spec_index])
if isinstance(p, Parameters):
found_vars = self.collect_vars(p)
self.bound_tvars |= set(found_vars)
new_args = [a.accept(self) for a in t.args]
self.bound_tvars -= set(found_vars)

repl = new_args[param_spec_index]
assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
repl.variables = list(repl.variables) + list(found_vars)
return t.copy_modified(args=new_args)
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
if t.type in self.seen_aliases:
raise PolyTranslationError()
call = find_member("__call__", t, t, is_operator=True)
assert call is not None
return call.accept(
PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
)
return super().visit_instance(t)


class ArgInferSecondPassQuery(types.BoolTypeQuery):
"""Query whether an argument type should be inferred in the second pass.

Expand Down
Loading