Skip to content

Commit

Permalink
Improve type checking of generic function bodies (#1580)
Browse files Browse the repository at this point in the history
This commit series gives type variable ids their own type and replaces the class/function type variable distinction with a plain/metavariable distinction. The main goal is to fix #603, but it's also progress towards #1261 and other bugs involving type variable inference.

Metavariables (or unification variables) are variables introduced during type inference to represent the types that will be substituted for generic class or function type parameters. They only exist during type inference and should never escape into the inferred type of identifiers.

Fixes #603.
  • Loading branch information
rwbarton authored Jun 15, 2016
1 parent a3f002f commit 584f8f3
Show file tree
Hide file tree
Showing 19 changed files with 235 additions and 223 deletions.
4 changes: 2 additions & 2 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import mypy.subtypes
from mypy.sametypes import is_same_type
from mypy.expandtype import expand_type
from mypy.types import Type, TypeVarType, CallableType, AnyType, Void
from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, Void
from mypy.messages import MessageBuilder
from mypy.nodes import Context

Expand Down Expand Up @@ -48,7 +48,7 @@ def apply_generic_arguments(callable: CallableType, types: List[Type],
msg.incompatible_typevar_value(callable, i + 1, type, context)

# Create a map from type variable id to target type.
id_to_type = {} # type: Dict[int, Type]
id_to_type = {} # type: Dict[TypeVarId, Type]
for i, tv in enumerate(tvars):
if types[i]:
id_to_type[tv.id] = types[i]
Expand Down
10 changes: 5 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from mypy.types import (
Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType,
Instance, NoneTyp, ErrorType, strip_type,
UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType
)
from mypy.sametypes import is_same_type
from mypy.messages import MessageBuilder
Expand Down Expand Up @@ -920,7 +920,7 @@ def check_getattr_method(self, typ: CallableType, context: Context) -> None:
def expand_typevars(self, defn: FuncItem,
typ: CallableType) -> List[Tuple[FuncItem, CallableType]]:
# TODO use generator
subst = [] # type: List[List[Tuple[int, Type]]]
subst = [] # type: List[List[Tuple[TypeVarId, Type]]]
tvars = typ.variables or []
tvars = tvars[:]
if defn.info:
Expand Down Expand Up @@ -2524,17 +2524,17 @@ def get_isinstance_type(node: Node, type_map: Dict[Node, Type]) -> Type:
return UnionType(types)


def expand_node(defn: Node, map: Dict[int, Type]) -> Node:
def expand_node(defn: Node, map: Dict[TypeVarId, Type]) -> Node:
visitor = TypeTransformVisitor(map)
return defn.accept(visitor)


def expand_func(defn: FuncItem, map: Dict[int, Type]) -> FuncItem:
def expand_func(defn: FuncItem, map: Dict[TypeVarId, Type]) -> FuncItem:
return cast(FuncItem, expand_node(defn, map))


class TypeTransformVisitor(TransformVisitor):
def __init__(self, map: Dict[int, Type]) -> None:
def __init__(self, map: Dict[TypeVarId, Type]) -> None:
super().__init__()
self.map = map

Expand Down
59 changes: 38 additions & 21 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
TupleType, Instance, TypeVarType, ErasedType, UnionType,
TupleType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UnboundType, UninhabitedType, TypeType
)
from mypy.nodes import (
Expand All @@ -22,7 +22,7 @@
import mypy.checker
from mypy import types
from mypy.sametypes import is_same_type
from mypy.replacetvars import replace_func_type_vars
from mypy.erasetype import replace_meta_vars
from mypy.messages import MessageBuilder
from mypy import messages
from mypy.infer import infer_type_arguments, infer_function_type_arguments
Expand All @@ -34,6 +34,7 @@
from mypy.semanal import self_type
from mypy.constraints import get_actual_type
from mypy.checkstrformat import StringFormatterChecker
from mypy.expandtype import expand_type

from mypy import experiments

Expand Down Expand Up @@ -234,6 +235,7 @@ def check_call(self, callee: Type, args: List[Node],
lambda i: self.accept(args[i]))

if callee.is_generic():
callee = freshen_generic_callable(callee)
callee = self.infer_function_type_arguments_using_context(
callee, context)
callee = self.infer_function_type_arguments(
Expand Down Expand Up @@ -394,12 +396,12 @@ def infer_function_type_arguments_using_context(
ctx = self.chk.type_context[-1]
if not ctx:
return callable
# The return type may have references to function type variables that
# The return type may have references to type metavariables that
# we are inferring right now. We must consider them as indeterminate
# and they are not potential results; thus we replace them with the
# special ErasedType type. On the other hand, class type variables are
# valid results.
erased_ctx = replace_func_type_vars(ctx, ErasedType())
erased_ctx = replace_meta_vars(ctx, ErasedType())
ret_type = callable.ret_type
if isinstance(ret_type, TypeVarType):
if ret_type.values or (not isinstance(ctx, Instance) or
Expand Down Expand Up @@ -1264,15 +1266,16 @@ def visit_set_expr(self, e: SetExpr) -> Type:
def check_list_or_set_expr(self, items: List[Node], fullname: str,
tag: str, context: Context) -> Type:
# Translate into type checking a generic function call.
tv = TypeVarType('T', -1, [], self.chk.object_type())
tvdef = TypeVarDef('T', -1, [], self.chk.object_type())
tv = TypeVarType(tvdef)
constructor = CallableType(
[tv],
[nodes.ARG_STAR],
[None],
self.chk.named_generic_type(fullname, [tv]),
self.named_type('builtins.function'),
name=tag,
variables=[TypeVarDef('T', -1, None, self.chk.object_type())])
variables=[tvdef])
return self.check_call(constructor,
items,
[nodes.ARG_POS] * len(items), context)[0]
Expand Down Expand Up @@ -1301,20 +1304,21 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:

def visit_dict_expr(self, e: DictExpr) -> Type:
# Translate into type checking a generic function call.
tv1 = TypeVarType('KT', -1, [], self.chk.object_type())
tv2 = TypeVarType('VT', -2, [], self.chk.object_type())
ktdef = TypeVarDef('KT', -1, [], self.chk.object_type())
vtdef = TypeVarDef('VT', -2, [], self.chk.object_type())
kt = TypeVarType(ktdef)
vt = TypeVarType(vtdef)
# The callable type represents a function like this:
#
# def <unnamed>(*v: Tuple[kt, vt]) -> Dict[kt, vt]: ...
constructor = CallableType(
[TupleType([tv1, tv2], self.named_type('builtins.tuple'))],
[TupleType([kt, vt], self.named_type('builtins.tuple'))],
[nodes.ARG_STAR],
[None],
self.chk.named_generic_type('builtins.dict', [tv1, tv2]),
self.chk.named_generic_type('builtins.dict', [kt, vt]),
self.named_type('builtins.function'),
name='<list>',
variables=[TypeVarDef('KT', -1, None, self.chk.object_type()),
TypeVarDef('VT', -2, None, self.chk.object_type())])
variables=[ktdef, vtdef])
# Synthesize function arguments.
args = [] # type: List[Node]
for key, value in e.items:
Expand Down Expand Up @@ -1360,7 +1364,7 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> CallableType:
# they must be considered as indeterminate. We use ErasedType since it
# does not affect type inference results (it is for purposes like this
# only).
ctx = replace_func_type_vars(ctx, ErasedType())
ctx = replace_meta_vars(ctx, ErasedType())

callable_ctx = cast(CallableType, ctx)

Expand Down Expand Up @@ -1438,15 +1442,16 @@ def check_generator_or_comprehension(self, gen: GeneratorExpr,

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
tv = TypeVarType('T', -1, [], self.chk.object_type())
tvdef = TypeVarDef('T', -1, [], self.chk.object_type())
tv = TypeVarType(tvdef)
constructor = CallableType(
[tv],
[nodes.ARG_POS],
[None],
self.chk.named_generic_type(type_name, [tv]),
self.chk.named_type('builtins.function'),
name=id_for_messages,
variables=[TypeVarDef('T', -1, None, self.chk.object_type())])
variables=[tvdef])
return self.check_call(constructor,
[gen.left_expr], [nodes.ARG_POS], gen)[0]

Expand All @@ -1456,17 +1461,18 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension):

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
key_tv = TypeVarType('KT', -1, [], self.chk.object_type())
value_tv = TypeVarType('VT', -2, [], self.chk.object_type())
ktdef = TypeVarDef('KT', -1, [], self.chk.object_type())
vtdef = TypeVarDef('VT', -2, [], self.chk.object_type())
kt = TypeVarType(ktdef)
vt = TypeVarType(vtdef)
constructor = CallableType(
[key_tv, value_tv],
[kt, vt],
[nodes.ARG_POS, nodes.ARG_POS],
[None, None],
self.chk.named_generic_type('builtins.dict', [key_tv, value_tv]),
self.chk.named_generic_type('builtins.dict', [kt, vt]),
self.chk.named_type('builtins.function'),
name='<dictionary-comprehension>',
variables=[TypeVarDef('KT', -1, None, self.chk.object_type()),
TypeVarDef('VT', -2, None, self.chk.object_type())])
variables=[ktdef, vtdef])
return self.check_call(constructor,
[e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0]

Expand Down Expand Up @@ -1775,3 +1781,14 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int:
return 2
# Fall back to a conservative equality check for the remaining kinds of type.
return 2 if is_same_type(erasetype.erase_type(actual), erasetype.erase_type(formal)) else 0


def freshen_generic_callable(callee: CallableType) -> CallableType:
tvdefs = []
tvmap = {} # type: Dict[TypeVarId, Type]
for v in callee.variables:
tvdef = TypeVarDef.new_unification_variable(v)
tvdefs.append(tvdef)
tvmap[v.id] = TypeVarType(tvdef)

return cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvdefs)
48 changes: 6 additions & 42 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Type checking of attribute access"""

from typing import cast, Callable, List, Optional
from typing import cast, Callable, List, Dict, Optional

from mypy.types import (
Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarDef,
Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarId, TypeVarDef,
Overloaded, TypeVarType, TypeTranslator, UnionType, PartialType,
DeletedType, NoneTyp, TypeType
)
Expand Down Expand Up @@ -413,51 +413,15 @@ def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance,
special_sig: Optional[str]) -> CallableType:
"""Create a type object type based on the signature of __init__."""
variables = [] # type: List[TypeVarDef]
for i, tvar in enumerate(info.defn.type_vars):
variables.append(TypeVarDef(tvar.name, i + 1, tvar.values, tvar.upper_bound,
tvar.variance))

initvars = init_type.variables
variables.extend(initvars)
variables.extend(info.defn.type_vars)
variables.extend(init_type.variables)

callable_type = init_type.copy_modified(
ret_type=self_type(info), fallback=type_type, name=None, variables=variables,
special_sig=special_sig)
c = callable_type.with_name('"{}"'.format(info.name()))
cc = convert_class_tvars_to_func_tvars(c, len(initvars))
cc.is_classmethod_class = True
return cc


def convert_class_tvars_to_func_tvars(callable: CallableType,
num_func_tvars: int) -> CallableType:
return cast(CallableType, callable.accept(TvarTranslator(num_func_tvars)))


class TvarTranslator(TypeTranslator):
def __init__(self, num_func_tvars: int) -> None:
super().__init__()
self.num_func_tvars = num_func_tvars

def visit_type_var(self, t: TypeVarType) -> Type:
if t.id < 0:
return t
else:
return TypeVarType(t.name, -t.id - self.num_func_tvars, t.values, t.upper_bound,
t.variance)

def translate_variables(self,
variables: List[TypeVarDef]) -> List[TypeVarDef]:
if not variables:
return variables
items = [] # type: List[TypeVarDef]
for v in variables:
if v.id > 0:
items.append(TypeVarDef(v.name, -v.id - self.num_func_tvars,
v.values, v.upper_bound, v.variance))
else:
items.append(v)
return items
c.is_classmethod_class = True
return c


def map_type_from_supertype(typ: Type, sub_info: TypeInfo,
Expand Down
10 changes: 5 additions & 5 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mypy.types import (
CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType,
Instance, TupleType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
UninhabitedType, TypeType, is_named_instance
UninhabitedType, TypeType, TypeVarId, is_named_instance
)
from mypy.maptype import map_instance_to_supertype
from mypy import nodes
Expand All @@ -23,11 +23,11 @@ class Constraint:
It can be either T <: type or T :> type (T is a type variable).
"""

type_var = 0 # Type variable id
op = 0 # SUBTYPE_OF or SUPERTYPE_OF
target = None # type: Type
type_var = None # Type variable id
op = 0 # SUBTYPE_OF or SUPERTYPE_OF
target = None # type: Type

def __init__(self, type_var: int, op: int, target: Type) -> None:
def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None:
self.type_var = type_var
self.op = op
self.target = target
Expand Down
28 changes: 19 additions & 9 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Container
from typing import Optional, Container, Callable

from mypy.types import (
Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp,
Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp, TypeVarId,
Instance, TypeVarType, CallableType, TupleType, UnionType, Overloaded, ErasedType,
PartialType, DeletedType, TypeTranslator, TypeList, UninhabitedType, TypeType
)
Expand Down Expand Up @@ -105,20 +105,30 @@ def visit_instance(self, t: Instance) -> Type:
return Instance(t.type, [], t.line)


def erase_typevars(t: Type, ids_to_erase: Optional[Container[int]] = None) -> Type:
def erase_typevars(t: Type, ids_to_erase: Optional[Container[TypeVarId]] = None) -> Type:
"""Replace all type variables in a type with any,
or just the ones in the provided collection.
"""
return t.accept(TypeVarEraser(ids_to_erase))
def erase_id(id: TypeVarId) -> bool:
if ids_to_erase is None:
return True
return id in ids_to_erase
return t.accept(TypeVarEraser(erase_id, AnyType()))


def replace_meta_vars(t: Type, target_type: Type) -> Type:
"""Replace unification variables in a type with the target type."""
return t.accept(TypeVarEraser(lambda id: id.is_meta_var(), target_type))


class TypeVarEraser(TypeTranslator):
"""Implementation of type erasure"""

def __init__(self, ids_to_erase: Optional[Container[int]]) -> None:
self.ids_to_erase = ids_to_erase
def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
self.erase_id = erase_id
self.replacement = replacement

def visit_type_var(self, t: TypeVarType) -> Type:
if self.ids_to_erase is not None and t.id not in self.ids_to_erase:
return t
return AnyType()
if self.erase_id(t.id):
return self.replacement
return t
Loading

0 comments on commit 584f8f3

Please sign in to comment.