diff --git a/mypy/applytype.py b/mypy/applytype.py index 29f2287ef39c..5f066e5e62fc 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -3,7 +3,7 @@ import mypy.subtypes from mypy.sametypes import is_same_type from mypy.expandtype import expand_type -from mypy.types import Type, TypeVarType, CallableType, AnyType, Void +from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, Void from mypy.messages import MessageBuilder from mypy.nodes import Context @@ -48,7 +48,7 @@ def apply_generic_arguments(callable: CallableType, types: List[Type], msg.incompatible_typevar_value(callable, i + 1, type, context) # Create a map from type variable id to target type. - id_to_type = {} # type: Dict[int, Type] + id_to_type = {} # type: Dict[TypeVarId, Type] for i, tv in enumerate(tvars): if types[i]: id_to_type[tv.id] = types[i] diff --git a/mypy/checker.py b/mypy/checker.py index d0a51dac70d4..4bd90fd6378e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -32,7 +32,7 @@ from mypy.types import ( Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType, Instance, NoneTyp, ErrorType, strip_type, - UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType + UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType ) from mypy.sametypes import is_same_type from mypy.messages import MessageBuilder @@ -920,7 +920,7 @@ def check_getattr_method(self, typ: CallableType, context: Context) -> None: def expand_typevars(self, defn: FuncItem, typ: CallableType) -> List[Tuple[FuncItem, CallableType]]: # TODO use generator - subst = [] # type: List[List[Tuple[int, Type]]] + subst = [] # type: List[List[Tuple[TypeVarId, Type]]] tvars = typ.variables or [] tvars = tvars[:] if defn.info: @@ -2524,17 +2524,17 @@ def get_isinstance_type(node: Node, type_map: Dict[Node, Type]) -> Type: return UnionType(types) -def expand_node(defn: Node, map: Dict[int, Type]) -> Node: +def expand_node(defn: Node, map: Dict[TypeVarId, Type]) -> Node: visitor = TypeTransformVisitor(map) return defn.accept(visitor) -def expand_func(defn: FuncItem, map: Dict[int, Type]) -> FuncItem: +def expand_func(defn: FuncItem, map: Dict[TypeVarId, Type]) -> FuncItem: return cast(FuncItem, expand_node(defn, map)) class TypeTransformVisitor(TransformVisitor): - def __init__(self, map: Dict[int, Type]) -> None: + def __init__(self, map: Dict[TypeVarId, Type]) -> None: super().__init__() self.map = map diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d759c96475e1..a8649764a514 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4,7 +4,7 @@ from mypy.types import ( Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef, - TupleType, Instance, TypeVarType, ErasedType, UnionType, + TupleType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType, PartialType, DeletedType, UnboundType, UninhabitedType, TypeType ) from mypy.nodes import ( @@ -22,7 +22,7 @@ import mypy.checker from mypy import types from mypy.sametypes import is_same_type -from mypy.replacetvars import replace_func_type_vars +from mypy.erasetype import replace_meta_vars from mypy.messages import MessageBuilder from mypy import messages from mypy.infer import infer_type_arguments, infer_function_type_arguments @@ -34,6 +34,7 @@ from mypy.semanal import self_type from mypy.constraints import get_actual_type from mypy.checkstrformat import StringFormatterChecker +from mypy.expandtype import expand_type from mypy import experiments @@ -234,6 +235,7 @@ def check_call(self, callee: Type, args: List[Node], lambda i: self.accept(args[i])) if callee.is_generic(): + callee = freshen_generic_callable(callee) callee = self.infer_function_type_arguments_using_context( callee, context) callee = self.infer_function_type_arguments( @@ -394,12 +396,12 @@ def infer_function_type_arguments_using_context( ctx = self.chk.type_context[-1] if not ctx: return callable - # The return type may have references to function type variables that + # The return type may have references to type metavariables that # we are inferring right now. We must consider them as indeterminate # and they are not potential results; thus we replace them with the # special ErasedType type. On the other hand, class type variables are # valid results. - erased_ctx = replace_func_type_vars(ctx, ErasedType()) + erased_ctx = replace_meta_vars(ctx, ErasedType()) ret_type = callable.ret_type if isinstance(ret_type, TypeVarType): if ret_type.values or (not isinstance(ctx, Instance) or @@ -1264,7 +1266,8 @@ def visit_set_expr(self, e: SetExpr) -> Type: def check_list_or_set_expr(self, items: List[Node], fullname: str, tag: str, context: Context) -> Type: # Translate into type checking a generic function call. - tv = TypeVarType('T', -1, [], self.chk.object_type()) + tvdef = TypeVarDef('T', -1, [], self.chk.object_type()) + tv = TypeVarType(tvdef) constructor = CallableType( [tv], [nodes.ARG_STAR], @@ -1272,7 +1275,7 @@ def check_list_or_set_expr(self, items: List[Node], fullname: str, self.chk.named_generic_type(fullname, [tv]), self.named_type('builtins.function'), name=tag, - variables=[TypeVarDef('T', -1, None, self.chk.object_type())]) + variables=[tvdef]) return self.check_call(constructor, items, [nodes.ARG_POS] * len(items), context)[0] @@ -1301,20 +1304,21 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: def visit_dict_expr(self, e: DictExpr) -> Type: # Translate into type checking a generic function call. - tv1 = TypeVarType('KT', -1, [], self.chk.object_type()) - tv2 = TypeVarType('VT', -2, [], self.chk.object_type()) + ktdef = TypeVarDef('KT', -1, [], self.chk.object_type()) + vtdef = TypeVarDef('VT', -2, [], self.chk.object_type()) + kt = TypeVarType(ktdef) + vt = TypeVarType(vtdef) # The callable type represents a function like this: # # def (*v: Tuple[kt, vt]) -> Dict[kt, vt]: ... constructor = CallableType( - [TupleType([tv1, tv2], self.named_type('builtins.tuple'))], + [TupleType([kt, vt], self.named_type('builtins.tuple'))], [nodes.ARG_STAR], [None], - self.chk.named_generic_type('builtins.dict', [tv1, tv2]), + self.chk.named_generic_type('builtins.dict', [kt, vt]), self.named_type('builtins.function'), name='', - variables=[TypeVarDef('KT', -1, None, self.chk.object_type()), - TypeVarDef('VT', -2, None, self.chk.object_type())]) + variables=[ktdef, vtdef]) # Synthesize function arguments. args = [] # type: List[Node] for key, value in e.items: @@ -1360,7 +1364,7 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> CallableType: # they must be considered as indeterminate. We use ErasedType since it # does not affect type inference results (it is for purposes like this # only). - ctx = replace_func_type_vars(ctx, ErasedType()) + ctx = replace_meta_vars(ctx, ErasedType()) callable_ctx = cast(CallableType, ctx) @@ -1438,7 +1442,8 @@ def check_generator_or_comprehension(self, gen: GeneratorExpr, # Infer the type of the list comprehension by using a synthetic generic # callable type. - tv = TypeVarType('T', -1, [], self.chk.object_type()) + tvdef = TypeVarDef('T', -1, [], self.chk.object_type()) + tv = TypeVarType(tvdef) constructor = CallableType( [tv], [nodes.ARG_POS], @@ -1446,7 +1451,7 @@ def check_generator_or_comprehension(self, gen: GeneratorExpr, self.chk.named_generic_type(type_name, [tv]), self.chk.named_type('builtins.function'), name=id_for_messages, - variables=[TypeVarDef('T', -1, None, self.chk.object_type())]) + variables=[tvdef]) return self.check_call(constructor, [gen.left_expr], [nodes.ARG_POS], gen)[0] @@ -1456,17 +1461,18 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension): # Infer the type of the list comprehension by using a synthetic generic # callable type. - key_tv = TypeVarType('KT', -1, [], self.chk.object_type()) - value_tv = TypeVarType('VT', -2, [], self.chk.object_type()) + ktdef = TypeVarDef('KT', -1, [], self.chk.object_type()) + vtdef = TypeVarDef('VT', -2, [], self.chk.object_type()) + kt = TypeVarType(ktdef) + vt = TypeVarType(vtdef) constructor = CallableType( - [key_tv, value_tv], + [kt, vt], [nodes.ARG_POS, nodes.ARG_POS], [None, None], - self.chk.named_generic_type('builtins.dict', [key_tv, value_tv]), + self.chk.named_generic_type('builtins.dict', [kt, vt]), self.chk.named_type('builtins.function'), name='', - variables=[TypeVarDef('KT', -1, None, self.chk.object_type()), - TypeVarDef('VT', -2, None, self.chk.object_type())]) + variables=[ktdef, vtdef]) return self.check_call(constructor, [e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0] @@ -1775,3 +1781,14 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int: return 2 # Fall back to a conservative equality check for the remaining kinds of type. return 2 if is_same_type(erasetype.erase_type(actual), erasetype.erase_type(formal)) else 0 + + +def freshen_generic_callable(callee: CallableType) -> CallableType: + tvdefs = [] + tvmap = {} # type: Dict[TypeVarId, Type] + for v in callee.variables: + tvdef = TypeVarDef.new_unification_variable(v) + tvdefs.append(tvdef) + tvmap[v.id] = TypeVarType(tvdef) + + return cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvdefs) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 2c6ea3166d8b..cc0be2058fbf 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1,9 +1,9 @@ """Type checking of attribute access""" -from typing import cast, Callable, List, Optional +from typing import cast, Callable, List, Dict, Optional from mypy.types import ( - Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarDef, + Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarId, TypeVarDef, Overloaded, TypeVarType, TypeTranslator, UnionType, PartialType, DeletedType, NoneTyp, TypeType ) @@ -413,51 +413,15 @@ def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance, special_sig: Optional[str]) -> CallableType: """Create a type object type based on the signature of __init__.""" variables = [] # type: List[TypeVarDef] - for i, tvar in enumerate(info.defn.type_vars): - variables.append(TypeVarDef(tvar.name, i + 1, tvar.values, tvar.upper_bound, - tvar.variance)) - - initvars = init_type.variables - variables.extend(initvars) + variables.extend(info.defn.type_vars) + variables.extend(init_type.variables) callable_type = init_type.copy_modified( ret_type=self_type(info), fallback=type_type, name=None, variables=variables, special_sig=special_sig) c = callable_type.with_name('"{}"'.format(info.name())) - cc = convert_class_tvars_to_func_tvars(c, len(initvars)) - cc.is_classmethod_class = True - return cc - - -def convert_class_tvars_to_func_tvars(callable: CallableType, - num_func_tvars: int) -> CallableType: - return cast(CallableType, callable.accept(TvarTranslator(num_func_tvars))) - - -class TvarTranslator(TypeTranslator): - def __init__(self, num_func_tvars: int) -> None: - super().__init__() - self.num_func_tvars = num_func_tvars - - def visit_type_var(self, t: TypeVarType) -> Type: - if t.id < 0: - return t - else: - return TypeVarType(t.name, -t.id - self.num_func_tvars, t.values, t.upper_bound, - t.variance) - - def translate_variables(self, - variables: List[TypeVarDef]) -> List[TypeVarDef]: - if not variables: - return variables - items = [] # type: List[TypeVarDef] - for v in variables: - if v.id > 0: - items.append(TypeVarDef(v.name, -v.id - self.num_func_tvars, - v.values, v.upper_bound, v.variance)) - else: - items.append(v) - return items + c.is_classmethod_class = True + return c def map_type_from_supertype(typ: Type, sub_info: TypeInfo, diff --git a/mypy/constraints.py b/mypy/constraints.py index 016a1593c659..e9f0402d8fca 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -5,7 +5,7 @@ from mypy.types import ( CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType, Instance, TupleType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, - UninhabitedType, TypeType, is_named_instance + UninhabitedType, TypeType, TypeVarId, is_named_instance ) from mypy.maptype import map_instance_to_supertype from mypy import nodes @@ -23,11 +23,11 @@ class Constraint: It can be either T <: type or T :> type (T is a type variable). """ - type_var = 0 # Type variable id - op = 0 # SUBTYPE_OF or SUPERTYPE_OF - target = None # type: Type + type_var = None # Type variable id + op = 0 # SUBTYPE_OF or SUPERTYPE_OF + target = None # type: Type - def __init__(self, type_var: int, op: int, target: Type) -> None: + def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None: self.type_var = type_var self.op = op self.target = target diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 17b51e3782af..b2c1f76afea4 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -1,7 +1,7 @@ -from typing import Optional, Container +from typing import Optional, Container, Callable from mypy.types import ( - Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp, + Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp, TypeVarId, Instance, TypeVarType, CallableType, TupleType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, TypeTranslator, TypeList, UninhabitedType, TypeType ) @@ -105,20 +105,30 @@ def visit_instance(self, t: Instance) -> Type: return Instance(t.type, [], t.line) -def erase_typevars(t: Type, ids_to_erase: Optional[Container[int]] = None) -> Type: +def erase_typevars(t: Type, ids_to_erase: Optional[Container[TypeVarId]] = None) -> Type: """Replace all type variables in a type with any, or just the ones in the provided collection. """ - return t.accept(TypeVarEraser(ids_to_erase)) + def erase_id(id: TypeVarId) -> bool: + if ids_to_erase is None: + return True + return id in ids_to_erase + return t.accept(TypeVarEraser(erase_id, AnyType())) + + +def replace_meta_vars(t: Type, target_type: Type) -> Type: + """Replace unification variables in a type with the target type.""" + return t.accept(TypeVarEraser(lambda id: id.is_meta_var(), target_type)) class TypeVarEraser(TypeTranslator): """Implementation of type erasure""" - def __init__(self, ids_to_erase: Optional[Container[int]]) -> None: - self.ids_to_erase = ids_to_erase + def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None: + self.erase_id = erase_id + self.replacement = replacement def visit_type_var(self, t: TypeVarType) -> Type: - if self.ids_to_erase is not None and t.id not in self.ids_to_erase: - return t - return AnyType() + if self.erase_id(t.id): + return self.replacement + return t diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 2a4a67c68276..87b1641daecc 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -3,11 +3,11 @@ from mypy.types import ( Type, Instance, CallableType, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp, TypeVarType, Overloaded, TupleType, UnionType, ErasedType, TypeList, - PartialType, DeletedType, UninhabitedType, TypeType + PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId ) -def expand_type(typ: Type, env: Dict[int, Type]) -> Type: +def expand_type(typ: Type, env: Dict[TypeVarId, Type]) -> Type: """Substitute any type variable references in a type given by a type environment. """ @@ -16,23 +16,24 @@ def expand_type(typ: Type, env: Dict[int, Type]) -> Type: def expand_type_by_instance(typ: Type, instance: Instance) -> Type: - """Substitute type variables in type using values from an Instance.""" + """Substitute type variables in type using values from an Instance. + Type variables are considered to be bound by the class declaration.""" if instance.args == []: return typ else: - variables = {} # type: Dict[int, Type] - for i in range(len(instance.args)): - variables[i + 1] = instance.args[i] + variables = {} # type: Dict[TypeVarId, Type] + for binder, arg in zip(instance.type.defn.type_vars, instance.args): + variables[binder.id] = arg return expand_type(typ, variables) class ExpandTypeVisitor(TypeVisitor[Type]): """Visitor that substitutes type variables with values.""" - variables = None # type: Dict[int, Type] # TypeVar id -> TypeVar value + variables = None # type: Dict[TypeVarId, Type] # TypeVar id -> TypeVar value - def __init__(self, variables: Dict[int, Type]) -> None: + def __init__(self, variables: Dict[TypeVarId, Type]) -> None: self.variables = variables def visit_unbound_type(self, t: UnboundType) -> Type: diff --git a/mypy/infer.py b/mypy/infer.py index 3ba66efbe5ff..0047fe4536d3 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -3,7 +3,7 @@ from typing import List, Optional from mypy.constraints import infer_constraints, infer_constraints_for_callable -from mypy.types import Type, CallableType +from mypy.types import Type, TypeVarId, CallableType from mypy.solve import solve_constraints from mypy.constraints import SUBTYPE_OF @@ -35,7 +35,7 @@ def infer_function_type_arguments(callee_type: CallableType, return solve_constraints(type_vars, constraints, strict) -def infer_type_arguments(type_var_ids: List[int], +def infer_type_arguments(type_var_ids: List[TypeVarId], template: Type, actual: Type) -> List[Type]: # Like infer_function_type_arguments, but only match a single type # against a generic type. diff --git a/mypy/maptype.py b/mypy/maptype.py index 5eb6a0d92c31..dc8e7b29d3c5 100644 --- a/mypy/maptype.py +++ b/mypy/maptype.py @@ -2,7 +2,7 @@ from mypy.expandtype import expand_type from mypy.nodes import TypeInfo -from mypy.types import Type, Instance, AnyType +from mypy.types import Type, TypeVarId, Instance, AnyType def map_instance_to_supertype(instance: Instance, @@ -82,7 +82,7 @@ def map_instance_to_direct_supertypes(instance: Instance, return [Instance(supertype, [AnyType()] * len(supertype.type_vars))] -def instance_to_type_environment(instance: Instance) -> Dict[int, Type]: +def instance_to_type_environment(instance: Instance) -> Dict[TypeVarId, Type]: """Given an Instance, produce the resulting type environment for type variables bound by the Instance's class definition. @@ -92,5 +92,4 @@ def instance_to_type_environment(instance: Instance) -> Dict[int, Type]: arguments. The type variables are mapped by their `id`. """ - # Type variables bound by a class have `id` of 1, 2, etc. - return {i + 1: instance.args[i] for i in range(len(instance.args))} + return {binder.id: arg for binder, arg in zip(instance.type.defn.type_vars, instance.args)} diff --git a/mypy/nodes.py b/mypy/nodes.py index a027f6daccfd..61213c732fb0 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1983,8 +1983,8 @@ class SymbolTableNode: # AST node of definition (FuncDef/Var/TypeInfo/Decorator/TypeVarExpr, # or None for a bound type variable). node = None # type: Optional[SymbolNode] - # Type variable id (for bound type variables only) - tvar_id = 0 + # Type variable definition (for bound type variables only) + tvar_def = None # type: Optional[mypy.types.TypeVarDef] # Module id (e.g. "foo.bar") or None mod_id = '' # If this not None, override the type of the 'node' attribute. @@ -1997,13 +1997,14 @@ class SymbolTableNode: cross_ref = None # type: Optional[str] def __init__(self, kind: int, node: Optional[SymbolNode], mod_id: str = None, - typ: 'mypy.types.Type' = None, tvar_id: int = 0, + typ: 'mypy.types.Type' = None, + tvar_def: 'mypy.types.TypeVarDef' = None, module_public: bool = True) -> None: self.kind = kind self.node = node self.type_override = typ self.mod_id = mod_id - self.tvar_id = tvar_id + self.tvar_def = tvar_def self.module_public = module_public @property @@ -2046,8 +2047,8 @@ def serialize(self, prefix: str, name: str) -> JsonDict: data = {'.class': 'SymbolTableNode', 'kind': node_kinds[self.kind], } # type: JsonDict - if self.tvar_id: - data['tvar_id'] = self.tvar_id + if self.tvar_def: + data['tvar_def'] = self.tvar_def.serialize() if not self.module_public: data['module_public'] = False if self.kind == MODULE_REF: @@ -2089,8 +2090,8 @@ def deserialize(cls, data: JsonDict) -> 'SymbolTableNode': if 'type_override' in data: typ = mypy.types.Type.deserialize(data['type_override']) stnode = SymbolTableNode(kind, node, typ=typ) - if 'tvar_id' in data: - stnode.tvar_id = data['tvar_id'] + if 'tvar_def' in data: + stnode.tvar_def = mypy.types.TypeVarDef.deserialize(data['tvar_def']) if 'module_public' in data: stnode.module_public = data['module_public'] return stnode diff --git a/mypy/replacetvars.py b/mypy/replacetvars.py deleted file mode 100644 index 1ea5f83febf5..000000000000 --- a/mypy/replacetvars.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Type operations""" - -from mypy.types import Type, AnyType, TypeTranslator, TypeVarType - - -def replace_type_vars(typ: Type, func_tvars: bool = True) -> Type: - """Replace type variable references in a type with the Any type. If - func_tvars is false, only replace instance type variables. - """ - return typ.accept(ReplaceTypeVarsVisitor(func_tvars)) - - -class ReplaceTypeVarsVisitor(TypeTranslator): - # Only override type variable handling; otherwise perform an identity - # transformation. - - func_tvars = False - - def __init__(self, func_tvars: bool) -> None: - self.func_tvars = func_tvars - - def visit_type_var(self, t: TypeVarType) -> Type: - if t.id > 0 or self.func_tvars: - if t.line is not None: - return AnyType(t.line) - else: - return AnyType() - else: - return t - - -def replace_func_type_vars(typ: Type, target_type: Type) -> Type: - """Replace function type variables in a type with the target type.""" - return typ.accept(ReplaceFuncTypeVarsVisitor(target_type)) - - -class ReplaceFuncTypeVarsVisitor(TypeTranslator): - def __init__(self, target_type: Type) -> None: - self.target_type = target_type - - def visit_type_var(self, t: TypeVarType) -> Type: - if t.id < 0: - return self.target_type - else: - return t diff --git a/mypy/semanal.py b/mypy/semanal.py index 48075aab7a8c..3bc6b0ba3166 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -459,7 +459,7 @@ def analyze_function(self, defn: FuncItem) -> None: tvarnodes = self.add_func_type_variables_to_symbol_table(defn) next_function_tvar_id = min([self.next_function_tvar_id()] + - [n.tvar_id - 1 for n in tvarnodes]) + [n.tvar_def.id.raw_id - 1 for n in tvarnodes]) self.next_function_tvar_id_stack.append(next_function_tvar_id) if defn.type: @@ -516,7 +516,7 @@ def add_func_type_variables_to_symbol_table( name = item.name if name in names: self.name_already_defined(name, defn) - node = self.bind_type_var(name, item.id, defn) + node = self.bind_type_var(name, item, defn) nodes.append(node) names.add(name) return nodes @@ -527,11 +527,11 @@ def type_var_names(self) -> Set[str]: else: return set(self.type.type_vars) - def bind_type_var(self, fullname: str, id: int, + def bind_type_var(self, fullname: str, tvar_def: TypeVarDef, context: Context) -> SymbolTableNode: node = self.lookup_qualified(fullname, context) node.kind = BOUND_TVAR - node.tvar_id = id + node.tvar_def = tvar_def return node def check_function_signature(self, fdef: FuncItem) -> None: @@ -863,10 +863,9 @@ def is_instance_type(self, t: Type) -> bool: def bind_class_type_variables_in_symbol_table( self, info: TypeInfo) -> List[SymbolTableNode]: - vars = info.type_vars nodes = [] # type: List[SymbolTableNode] - for index, var in enumerate(vars, 1): - node = self.bind_type_var(var, index, info) + for var, binder in zip(info.type_vars, info.defn.type_vars): + node = self.bind_type_var(var, binder, info) nodes.append(node) return nodes @@ -2598,10 +2597,7 @@ def self_type(typ: TypeInfo) -> Union[Instance, TupleType]: """ tv = [] # type: List[Type] for i in range(len(typ.type_vars)): - tv.append(TypeVarType(typ.type_vars[i], i + 1, - typ.defn.type_vars[i].values, - typ.defn.type_vars[i].upper_bound, - typ.defn.type_vars[i].variance)) + tv.append(TypeVarType(typ.defn.type_vars[i])) inst = Instance(typ, tv) if typ.tuple_type is None: return inst diff --git a/mypy/solve.py b/mypy/solve.py index 9d751dd6b2db..e3bbfd72d420 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -2,7 +2,7 @@ from typing import List, Dict -from mypy.types import Type, Void, NoneTyp, AnyType, ErrorType, UninhabitedType +from mypy.types import Type, Void, NoneTyp, AnyType, ErrorType, UninhabitedType, TypeVarId from mypy.constraints import Constraint, SUPERTYPE_OF from mypy.join import join_types from mypy.meet import meet_types @@ -11,7 +11,7 @@ from mypy import experiments -def solve_constraints(vars: List[int], constraints: List[Constraint], +def solve_constraints(vars: List[TypeVarId], constraints: List[Constraint], strict=True) -> List[Type]: """Solve type constraints. @@ -23,7 +23,7 @@ def solve_constraints(vars: List[int], constraints: List[Constraint], pick AnyType. """ # Collect a list of constraints for each type variable. - cmap = {} # type: Dict[int, List[Constraint]] + cmap = {} # type: Dict[TypeVarId, List[Constraint]] for con in constraints: a = cmap.get(con.type_var, []) # type: List[Constraint] a.append(con) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index e4b8c0294900..73154e2cd867 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -14,7 +14,6 @@ Instance, NoneTyp, ErrorType, Overloaded, TypeType, ) from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, CONTRAVARIANT, INVARIANT, COVARIANT -from mypy.replacetvars import replace_type_vars from mypy.subtypes import is_subtype, is_more_precise, is_proper_subtype from mypy.typefixture import TypeFixture, InterfaceTypeFixture @@ -111,11 +110,11 @@ def test_trivial_expand(self): self.assert_expand(t, [], t) def test_expand_naked_type_var(self): - self.assert_expand(self.fx.t, [(1, self.fx.a)], self.fx.a) - self.assert_expand(self.fx.t, [(2, self.fx.a)], self.fx.t) + self.assert_expand(self.fx.t, [(self.fx.t.id, self.fx.a)], self.fx.a) + self.assert_expand(self.fx.t, [(self.fx.s.id, self.fx.a)], self.fx.t) def test_expand_basic_generic_types(self): - self.assert_expand(self.fx.gt, [(1, self.fx.a)], self.fx.ga) + self.assert_expand(self.fx.gt, [(self.fx.t.id, self.fx.a)], self.fx.ga) # IDEA: Add test cases for # tuple types @@ -132,25 +131,6 @@ def assert_expand(self, orig, map_items, result): # Remove erased tags (asterisks). assert_equal(str(exp).replace('*', ''), str(result)) - # replace_type_vars - - def test_trivial_replace(self): - for t in (self.fx.a, self.fx.o, self.fx.void, self.fx.nonet, - self.tuple(self.fx.a), - self.callable([], self.fx.a, self.fx.a), self.fx.anyt, - self.fx.err): - self.assert_replace(t, t) - - def test_replace_type_var(self): - self.assert_replace(self.fx.t, self.fx.anyt) - - def test_replace_generic_instance(self): - self.assert_replace(self.fx.ga, self.fx.ga) - self.assert_replace(self.fx.gt, self.fx.gdyn) - - def assert_replace(self, orig, result): - assert_equal(str(replace_type_vars(orig)), str(result)) - # erase_type def test_trivial_erase(self): diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 545a305347ff..0493d3a15f6b 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -91,11 +91,8 @@ def visit_unbound_type(self, t: UnboundType) -> Type: if len(t.args) > 0: self.fail('Type variable "{}" used with arguments'.format( t.name), t) - tvar_expr = cast(TypeVarExpr, sym.node) - return TypeVarType(t.name, sym.tvar_id, tvar_expr.values, - tvar_expr.upper_bound, - tvar_expr.variance, - t.line) + assert sym.tvar_def is not None + return TypeVarType(sym.tvar_def, t.line) elif fullname == 'builtins.None': if experiments.STRICT_OPTIONAL: if t.is_ret_type: @@ -280,7 +277,7 @@ def anal_array(self, a: List[Type]) -> List[Type]: def anal_var_defs(self, var_defs: List[TypeVarDef]) -> List[TypeVarDef]: a = [] # type: List[TypeVarDef] for vd in var_defs: - a.append(TypeVarDef(vd.name, vd.id, self.anal_array(vd.values), + a.append(TypeVarDef(vd.name, vd.id.raw_id, self.anal_array(vd.values), vd.upper_bound.accept(self), vd.variance, vd.line)) diff --git a/mypy/typefixture.py b/mypy/typefixture.py index ec76cf80d68b..59ffeea1046d 100644 --- a/mypy/typefixture.py +++ b/mypy/typefixture.py @@ -6,8 +6,8 @@ from typing import List from mypy.types import ( - TypeVarType, AnyType, Void, ErrorType, NoneTyp, Instance, CallableType, TypeVarDef, - TypeType, + Type, TypeVarType, AnyType, Void, ErrorType, NoneTyp, + Instance, CallableType, TypeVarDef, TypeType, ) from mypy.nodes import ( TypeInfo, ClassDef, Block, ARG_POS, ARG_OPT, ARG_STAR, SymbolTable, @@ -25,14 +25,19 @@ def __init__(self, variance: int=COVARIANT) -> None: self.oi = self.make_type_info('builtins.object') # class object self.o = Instance(self.oi, []) # object - # Type variables - self.t = TypeVarType('T', 1, [], self.o, variance) # T`1 (type variable) - self.tf = TypeVarType('T', -1, [], self.o, variance) # T`-1 (type variable) - self.tf2 = TypeVarType('T', -2, [], self.o, variance) # T`-2 (type variable) - self.s = TypeVarType('S', 2, [], self.o, variance) # S`2 (type variable) - self.s1 = TypeVarType('S', 1, [], self.o, variance) # S`1 (type variable) - self.sf = TypeVarType('S', -2, [], self.o, variance) # S`-2 (type variable) - self.sf1 = TypeVarType('S', -1, [], self.o, variance) # S`-1 (type variable) + # Type variables (these are effectively global) + + def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type, + variance: int) -> TypeVarType: + return TypeVarType(TypeVarDef(name, id, values, upper_bound, variance)) + + self.t = make_type_var('T', 1, [], self.o, variance) # T`1 (type variable) + self.tf = make_type_var('T', -1, [], self.o, variance) # T`-1 (type variable) + self.tf2 = make_type_var('T', -2, [], self.o, variance) # T`-2 (type variable) + self.s = make_type_var('S', 2, [], self.o, variance) # S`2 (type variable) + self.s1 = make_type_var('S', 1, [], self.o, variance) # S`1 (type variable) + self.sf = make_type_var('S', -2, [], self.o, variance) # S`-2 (type variable) + self.sf1 = make_type_var('S', -1, [], self.o, variance) # S`-1 (type variable) # Simple types self.anyt = AnyType() diff --git a/mypy/types.py b/mypy/types.py index 9b1e93e27863..8f0f938d09d1 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1,7 +1,9 @@ """Classes for representing mypy types.""" from abc import abstractmethod -from typing import Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional +from typing import ( + Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional, Union +) import mypy.nodes from mypy.nodes import INVARIANT, SymbolNode @@ -45,25 +47,84 @@ def deserialize(cls, data: JsonDict) -> 'Type': raise NotImplementedError('unexpected .class {}'.format(classname)) +class TypeVarId: + # A type variable is uniquely identified by its raw id and meta level. + + # For plain variables (type parameters of generic classes and + # functions) raw ids are allocated by semantic analysis, using + # positive ids 1, 2, ... for generic class parameters and negative + # ids -1, ... for generic function type arguments. This convention + # is only used to keep type variable ids distinct when allocating + # them; the type checker makes no distinction between class and + # function type variables. + + # Metavariables are allocated unique ids starting from 1. + raw_id = 0 # type: int + + # Level of the variable in type inference. Currently either 0 for + # declared types, or 1 for type inference metavariables. + meta_level = 0 # type: int + + # Class variable used for allocating fresh ids for metavariables. + next_raw_id = 1 # type: int + + def __init__(self, raw_id: int, meta_level: int = 0) -> None: + self.raw_id = raw_id + self.meta_level = meta_level + + @staticmethod + def new(meta_level: int) -> 'TypeVarId': + raw_id = TypeVarId.next_raw_id + TypeVarId.next_raw_id += 1 + return TypeVarId(raw_id, meta_level) + + def __repr__(self) -> str: + return self.raw_id.__repr__() + + def __eq__(self, other: object) -> bool: + if isinstance(other, TypeVarId): + return (self.raw_id == other.raw_id and + self.meta_level == other.meta_level) + else: + return False + + def __ne__(self, other: object) -> bool: + return not (self == other) + + def __hash__(self) -> int: + return hash((self.raw_id, self.meta_level)) + + def is_meta_var(self) -> bool: + return self.meta_level > 0 + + class TypeVarDef(mypy.nodes.Context): """Definition of a single type variable.""" name = '' - id = 0 - values = None # type: Optional[List[Type]] + id = None # type: TypeVarId + values = None # type: List[Type] # Value restriction, empty list if no restriction upper_bound = None # type: Type variance = INVARIANT # type: int line = 0 - def __init__(self, name: str, id: int, values: Optional[List[Type]], + def __init__(self, name: str, id: Union[TypeVarId, int], values: Optional[List[Type]], upper_bound: Type, variance: int = INVARIANT, line: int = -1) -> None: self.name = name + if isinstance(id, int): + id = TypeVarId(id) self.id = id self.values = values self.upper_bound = upper_bound self.variance = variance self.line = line + @staticmethod + def new_unification_variable(old: 'TypeVarDef') -> 'TypeVarDef': + new_id = TypeVarId.new(meta_level=1) + return TypeVarDef(old.name, new_id, old.values, + old.upper_bound, old.variance, old.line) + def get_line(self) -> int: return self.line @@ -76,9 +137,10 @@ def __repr__(self) -> str: return self.name def serialize(self) -> JsonDict: + assert not self.id.is_meta_var() return {'.class': 'TypeVarDef', 'name': self.name, - 'id': self.id, + 'id': self.id.raw_id, 'values': None if self.values is None else [v.serialize() for v in self.values], 'upper_bound': self.upper_bound.serialize(), 'variance': self.variance, @@ -368,19 +430,18 @@ class TypeVarType(Type): """ name = '' # Name of the type variable (for messages and debugging) - id = 0 # 1, 2, ... for type-related, -1, ... for function-related + id = None # type: TypeVarId values = None # type: List[Type] # Value restriction, empty list if no restriction upper_bound = None # type: Type # Upper bound for values # See comments in TypeVarDef for more about variance. variance = INVARIANT # type: int - def __init__(self, name: str, id: int, values: List[Type], upper_bound: Type, - variance: int = INVARIANT, line: int = -1) -> None: - self.name = name - self.id = id - self.values = values - self.upper_bound = upper_bound - self.variance = variance + def __init__(self, binder: TypeVarDef, line: int = -1) -> None: + self.name = binder.name + self.id = binder.id + self.values = binder.values + self.upper_bound = binder.upper_bound + self.variance = binder.variance super().__init__(line) def accept(self, visitor: 'TypeVisitor[T]') -> T: @@ -393,9 +454,10 @@ def erase_to_union_or_bound(self) -> Type: return self.upper_bound def serialize(self) -> JsonDict: + assert not self.id.is_meta_var() return {'.class': 'TypeVarType', 'name': self.name, - 'id': self.id, + 'id': self.id.raw_id, 'values': [v.serialize() for v in self.values], 'upper_bound': self.upper_bound.serialize(), 'variance': self.variance, @@ -404,11 +466,12 @@ def serialize(self) -> JsonDict: @classmethod def deserialize(cls, data: JsonDict) -> 'TypeVarType': assert data['.class'] == 'TypeVarType' - return TypeVarType(data['name'], + tvdef = TypeVarDef(data['name'], data['id'], [Type.deserialize(v) for v in data['values']], Type.deserialize(data['upper_bound']), data['variance']) + return TypeVarType(tvdef) class FunctionLike(Type): @@ -561,8 +624,8 @@ def items(self) -> List['CallableType']: def is_generic(self) -> bool: return bool(self.variables) - def type_var_ids(self) -> List[int]: - a = [] # type: List[int] + def type_var_ids(self) -> List[TypeVarId]: + a = [] # type: List[TypeVarId] for tv in self.variables: a.append(tv.id) return a diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index 706c055abce8..ad97d39710b6 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -748,3 +748,27 @@ def f2(iterable: Iterable[Tuple[str, Any]], **kw: Any) -> None: pass [builtins fixtures/dict.py] [out] + +[case testInferenceInGenericFunction] +from typing import TypeVar, List +T = TypeVar('T') +def f(a: T) -> None: + l = [] # type: List[T] + l.append(a) + l.append(1) # E: Argument 1 to "append" of "list" has incompatible type "int"; expected "T" +[builtins fixtures/list.py] +[out] +main: note: In function "f": + +[case testInferenceInGenericClass] +from typing import TypeVar, Generic, List +S = TypeVar('S') +T = TypeVar('T') +class A(Generic[S]): + def f(self, a: T, b: S) -> None: + l = [] # type: List[T] + l.append(a) + l.append(b) # E: Argument 1 to "append" of "list" has incompatible type "S"; expected "T" +[builtins fixtures/list.py] +[out] +main: note: In member "f" of class "A": diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index 0264bc981fc1..5d546f3f9de8 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -550,7 +550,7 @@ def f(): class A(Generic[T]): pass [out] CallExpr(4) : Any -NameExpr(4) : def [T] () -> A[T`-1] +NameExpr(4) : def [T] () -> A[T`1] [case testGenericCallInDynamicallyTypedFunction2] from typing import TypeVar, Generic @@ -561,7 +561,7 @@ class A(Generic[T]): def __init__(self, x: T) -> None: pass [out] CallExpr(4) : Any -NameExpr(4) : def [T] (x: T`-1) -> A[T`-1] +NameExpr(4) : def [T] (x: T`1) -> A[T`1] NameExpr(4) : def () -> Any [case testGenericCallInDynamicallyTypedFunction3]