diff --git a/mypy/binder.py b/mypy/binder.py index 109fef25ce6a..c1b6862c9e6d 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -296,16 +296,17 @@ def assign_type(self, expr: Expression, # (See discussion in #3526) elif (isinstance(type, AnyType) and isinstance(declared_type, UnionType) - and any(isinstance(item, NoneType) for item in declared_type.items) + and any(isinstance(get_proper_type(item), NoneType) for item in declared_type.items) and isinstance(get_proper_type(self.most_recent_enclosing_type(expr, NoneType())), NoneType)): # Replace any Nones in the union type with Any - new_items = [type if isinstance(item, NoneType) else item + new_items = [type if isinstance(get_proper_type(item), NoneType) else item for item in declared_type.items] self.put(expr, UnionType(new_items)) elif (isinstance(type, AnyType) and not (isinstance(declared_type, UnionType) - and any(isinstance(item, AnyType) for item in declared_type.items))): + and any(isinstance(get_proper_type(item), AnyType) + for item in declared_type.items))): # Assigning an Any value doesn't affect the type to avoid false negatives, unless # there is an Any item in a declared union type. self.put(expr, declared_type) diff --git a/mypy/checker.py b/mypy/checker.py index 320a2cf83f2b..18cb56fb368e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -36,8 +36,7 @@ UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, is_named_instance, union_items, TypeQuery, LiteralType, is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, - get_proper_types, is_literal_type -) + get_proper_types, is_literal_type, TypeAliasType) from mypy.sametypes import is_same_type from mypy.messages import ( MessageBuilder, make_inferred_type_note, append_invariance_notes, @@ -2480,7 +2479,7 @@ def check_multi_assignment(self, lvalues: List[Lvalue], # If this is an Optional type in non-strict Optional code, unwrap it. relevant_items = rvalue_type.relevant_items() if len(relevant_items) == 1: - rvalue_type = relevant_items[0] + rvalue_type = get_proper_type(relevant_items[0]) if isinstance(rvalue_type, AnyType): for lv in lvalues: @@ -2587,7 +2586,7 @@ def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expre # If this is an Optional type in non-strict Optional code, unwrap it. relevant_items = reinferred_rvalue_type.relevant_items() if len(relevant_items) == 1: - reinferred_rvalue_type = relevant_items[0] + reinferred_rvalue_type = get_proper_type(relevant_items[0]) if isinstance(reinferred_rvalue_type, UnionType): self.check_multi_assignment_from_union(lvalues, rvalue, reinferred_rvalue_type, context, @@ -3732,7 +3731,7 @@ def find_isinstance_check(self, node: Expression type = get_isinstance_type(node.args[1], type_map) if isinstance(vartype, UnionType): union_list = [] - for t in vartype.items: + for t in get_proper_types(vartype.items): if isinstance(t, TypeType): union_list.append(t.item) else: @@ -4558,6 +4557,7 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo # TODO: find a cleaner solution instead of this ad-hoc erasure. exp_signature = expand_type(signature, {tvar.id: erase_def_to_union_or_bound(tvar) for tvar in signature.variables}) + assert isinstance(exp_signature, ProperType) assert isinstance(exp_signature, CallableType) return is_callable_compatible(exp_signature, other, is_compat=is_more_precise, @@ -4641,6 +4641,11 @@ def visit_uninhabited_type(self, t: UninhabitedType) -> Type: return AnyType(TypeOfAny.from_error) return t + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + # Target of the alias cannot by an ambigous , so we just + # replace the arguments. + return t.copy_modified(args=[a.accept(self) for a in t.args]) + def is_node_static(node: Optional[Node]) -> Optional[bool]: """Find out if a node describes a static function method.""" diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b0ae074229ab..b86801e25f1b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -900,7 +900,7 @@ def check_callable_call(self, callee = callee.copy_modified(ret_type=new_ret_type) return callee.ret_type, callee - def analyze_type_type_callee(self, item: ProperType, context: Context) -> ProperType: + def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: """Analyze the callee X in X(...) where X is Type[item]. Return a Y that we can pass to check_call(Y, ...). @@ -913,7 +913,7 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Proper res = res.copy_modified(from_type_type=True) return expand_type_by_instance(res, item) if isinstance(item, UnionType): - return UnionType([self.analyze_type_type_callee(tp, context) + return UnionType([self.analyze_type_type_callee(get_proper_type(tp), context) for tp in item.relevant_items()], item.line) if isinstance(item, TypeVarType): # Pretend we're calling the typevar's upper bound, @@ -921,6 +921,7 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Proper # but better than AnyType...), but replace the return type # with typevar. callee = self.analyze_type_type_callee(get_proper_type(item.upper_bound), context) + callee = get_proper_type(callee) if isinstance(callee, CallableType): callee = callee.copy_modified(ret_type=item) elif isinstance(callee, Overloaded): @@ -2144,8 +2145,7 @@ def dangerous_comparison(self, left: Type, right: Type, if not self.chk.options.strict_equality: return False - left = get_proper_type(left) - right = get_proper_type(right) + left, right = get_proper_types((left, right)) if self.chk.binder.is_unreachable_warning_suppressed(): # We are inside a function that contains type variables with value restrictions in @@ -2165,6 +2165,7 @@ def dangerous_comparison(self, left: Type, right: Type, if isinstance(left, UnionType) and isinstance(right, UnionType): left = remove_optional(left) right = remove_optional(right) + left, right = get_proper_types((left, right)) py2 = self.chk.options.python_version < (3, 0) if (original_container and has_bytes_component(original_container, py2) and has_bytes_component(left, py2)): @@ -2794,7 +2795,7 @@ def try_getting_int_literals(self, index: Expression) -> Optional[List[int]]: return [typ.value] if isinstance(typ, UnionType): out = [] - for item in typ.items: + for item in get_proper_types(typ.items): if isinstance(item, LiteralType) and isinstance(item.value, int): out.append(item.value) else: @@ -2969,7 +2970,7 @@ class LongName(Generic[T]): ... # For example: # A = List[Tuple[T, T]] # x = A() <- same as List[Tuple[Any, Any]], see PEP 484. - item = set_any_tvars(target, alias_tvars, ctx.line, ctx.column) + item = get_proper_type(set_any_tvars(target, alias_tvars, ctx.line, ctx.column)) if isinstance(item, Instance): # Normally we get a callable type (or overloaded) with .is_type_obj() true # representing the class's constructor @@ -3052,7 +3053,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: type_context = get_proper_type(self.type_context[-1]) type_context_items = None if isinstance(type_context, UnionType): - tuples_in_context = [t for t in type_context.items + tuples_in_context = [t for t in get_proper_types(type_context.items) if (isinstance(t, TupleType) and len(t.items) == len(e.items)) or is_named_instance(t, 'builtins.tuple')] if len(tuples_in_context) == 1: @@ -3240,7 +3241,8 @@ def infer_lambda_type_using_context(self, e: LambdaExpr) -> Tuple[Optional[Calla ctx = get_proper_type(self.type_context[-1]) if isinstance(ctx, UnionType): - callables = [t for t in ctx.relevant_items() if isinstance(t, CallableType)] + callables = [t for t in get_proper_types(ctx.relevant_items()) + if isinstance(t, CallableType)] if len(callables) == 1: ctx = callables[0] diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 85683d3c82c1..859bd6afcc6d 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -6,7 +6,7 @@ from mypy.types import ( Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, - DeletedType, NoneType, TypeType, get_type_vars, get_proper_type, ProperType + DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType ) from mypy.nodes import ( TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr, @@ -377,7 +377,7 @@ def analyze_member_var_access(name: str, function = function_type(method, mx.builtin_type('builtins.function')) bound_method = bind_self(function, mx.self_type) typ = map_instance_to_supertype(itype, method.info) - getattr_type = expand_type_by_instance(bound_method, typ) + getattr_type = get_proper_type(expand_type_by_instance(bound_method, typ)) if isinstance(getattr_type, CallableType): result = getattr_type.ret_type @@ -394,7 +394,7 @@ def analyze_member_var_access(name: str, setattr_func = function_type(setattr_meth, mx.builtin_type('builtins.function')) bound_type = bind_self(setattr_func, mx.self_type) typ = map_instance_to_supertype(itype, setattr_meth.info) - setattr_type = expand_type_by_instance(bound_type, typ) + setattr_type = get_proper_type(expand_type_by_instance(bound_type, typ)) if isinstance(setattr_type, CallableType) and len(setattr_type.arg_types) > 0: return setattr_type.arg_types[-1] @@ -497,10 +497,11 @@ def instance_alias_type(alias: TypeAlias, As usual, we first erase any unbound type variables to Any. """ - target = get_proper_type(alias.target) - assert isinstance(target, Instance), "Must be called only with aliases to classes" + target = get_proper_type(alias.target) # type: Type + assert isinstance(get_proper_type(target), + Instance), "Must be called only with aliases to classes" target = set_any_tvars(target, alias.alias_tvars, alias.line, alias.column) - assert isinstance(target, Instance) + assert isinstance(target, Instance) # type: ignore[misc] tp = type_object_type(target.type, builtin_type) return expand_type_by_instance(tp, target) @@ -525,7 +526,7 @@ def analyze_var(name: str, if typ: if isinstance(typ, PartialType): return mx.chk.handle_partial_var_type(typ, mx.is_lvalue, var, mx.context) - t = expand_type_by_instance(typ, itype) + t = get_proper_type(expand_type_by_instance(typ, itype)) if mx.is_lvalue and var.is_property and not var.is_settable_property: # TODO allow setting attributes in subclass (although it is probably an error) mx.msg.read_only_property(name, itype.type, mx.context) @@ -577,7 +578,9 @@ def analyze_var(name: str, return result -def freeze_type_vars(member_type: ProperType) -> None: +def freeze_type_vars(member_type: Type) -> None: + if not isinstance(member_type, ProperType): + return if isinstance(member_type, CallableType): for v in member_type.variables: v.id.meta_level = 0 @@ -713,7 +716,7 @@ def analyze_class_attribute_access(itype: Instance, # x: T # C.x # Error, ambiguous access # C[int].x # Also an error, since C[int] is same as C at runtime - if isinstance(t, TypeVarType) or get_type_vars(t): + if isinstance(t, TypeVarType) or has_type_vars(t): # Exception: access on Type[...], including first argument of class methods is OK. if not isinstance(get_proper_type(mx.original_type), TypeType): if node.node.is_classvar: @@ -799,7 +802,7 @@ class B(A[str]): pass info = itype.type # type: TypeInfo if is_classmethod: assert isuper is not None - t = expand_type_by_instance(t, isuper) + t = get_proper_type(expand_type_by_instance(t, isuper)) # We add class type variables if the class method is accessed on class object # without applied type arguments, this matches the behavior of __init__(). # For example (continuing the example in docstring): diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 2add4509b298..08be7f86a2c1 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -19,7 +19,7 @@ from mypy.types import ( Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType, - CallableType, LiteralType + CallableType, LiteralType, get_proper_types ) from mypy.nodes import ( StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr, @@ -359,7 +359,8 @@ def check_specs_in_format_call(self, call: CallExpr, continue a_type = get_proper_type(actual_type) - actual_items = a_type.items if isinstance(a_type, UnionType) else [a_type] + actual_items = (get_proper_types(a_type.items) if isinstance(a_type, UnionType) + else [a_type]) for a_type in actual_items: if custom_special_method(a_type, '__format__'): continue diff --git a/mypy/constraints.py b/mypy/constraints.py index b81445532ec9..a078eb0b08b5 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -7,7 +7,7 @@ CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, - ProperType, get_proper_type + ProperType, get_proper_type, TypeAliasType ) from mypy.maptype import map_instance_to_supertype import mypy.subtypes @@ -16,6 +16,7 @@ from mypy.erasetype import erase_typevars from mypy.nodes import COVARIANT, CONTRAVARIANT from mypy.argmap import ArgTypeExpander +from mypy.typestate import TypeState SUBTYPE_OF = 0 # type: Final[int] SUPERTYPE_OF = 1 # type: Final[int] @@ -89,6 +90,21 @@ def infer_constraints(template: Type, actual: Type, The constraints are represented as Constraint objects. """ + if any(get_proper_type(template) == get_proper_type(t) for t in TypeState._inferring): + return [] + if (isinstance(template, TypeAliasType) and isinstance(actual, TypeAliasType) and + template.is_recursive and actual.is_recursive): + # This case requires special care because it may cause infinite recursion. + TypeState._inferring.append(template) + res = _infer_constraints(template, actual, direction) + TypeState._inferring.pop() + return res + return _infer_constraints(template, actual, direction) + + +def _infer_constraints(template: Type, actual: Type, + direction: int) -> List[Constraint]: + template = get_proper_type(template) actual = get_proper_type(actual) @@ -487,6 +503,9 @@ def visit_union_type(self, template: UnionType) -> List[Constraint]: assert False, ("Unexpected UnionType in ConstraintBuilderVisitor" " (should have been handled in infer_constraints)") + def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]: + assert False, "This should be never called, got {}".format(template) + def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Constraint]: res = [] # type: List[Constraint] for t in types: diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 62580dfb0f12..55cc58798c1c 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -4,7 +4,7 @@ Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, - get_proper_type + get_proper_type, TypeAliasType ) from mypy.nodes import ARG_STAR, ARG_STAR2 @@ -93,6 +93,9 @@ def visit_union_type(self, t: UnionType) -> ProperType: def visit_type_type(self, t: TypeType) -> ProperType: return TypeType.make_normalized(t.item.accept(self), line=t.line) + def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: + raise RuntimeError("Type aliases should be expanded before accepting this visitor") + def erase_typevars(t: Type, ids_to_erase: Optional[Container[TypeVarId]] = None) -> Type: """Replace all type variables in a type with any, @@ -122,6 +125,11 @@ def visit_type_var(self, t: TypeVarType) -> Type: return self.replacement return t + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + # Type alias target can't contain bound type variables, so + # it is safe to just erase the arguments. + return t.copy_modified(args=[a.accept(self) for a in t.args]) + def remove_instance_last_known_values(t: Type) -> Type: return t.accept(LastKnownValueEraser()) @@ -135,3 +143,8 @@ def visit_instance(self, t: Instance) -> Type: if t.last_known_value: return t.copy_modified(last_known_value=None) return t + + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + # Type aliases can't contain literal values, because they are + # always constructed as explicit types. + return t diff --git a/mypy/expandtype.py b/mypy/expandtype.py index d92275b684bf..128f187e3d88 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -4,23 +4,21 @@ Type, Instance, CallableType, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Overloaded, TupleType, TypedDictType, UnionType, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, - FunctionLike, TypeVarDef, LiteralType, get_proper_type, ProperType -) + FunctionLike, TypeVarDef, LiteralType, get_proper_type, ProperType, + TypeAliasType) -def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> ProperType: +def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: """Substitute any type variable references in a type given by a type environment. """ - + # TODO: use an overloaded signature? (ProperType stays proper after expansion.) return typ.accept(ExpandTypeVisitor(env)) -def expand_type_by_instance(typ: Type, instance: Instance) -> ProperType: +def expand_type_by_instance(typ: Type, instance: Instance) -> Type: """Substitute type variables in type using values from an Instance. Type variables are considered to be bound by the class declaration.""" - typ = get_proper_type(typ) - if instance.args == []: return typ else: @@ -53,7 +51,7 @@ def freshen_function_type_vars(callee: F) -> F: return cast(F, fresh_overload) -class ExpandTypeVisitor(TypeVisitor[ProperType]): +class ExpandTypeVisitor(TypeVisitor[Type]): """Visitor that substitutes type variables with values.""" variables = None # type: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value @@ -61,30 +59,30 @@ class ExpandTypeVisitor(TypeVisitor[ProperType]): def __init__(self, variables: Mapping[TypeVarId, Type]) -> None: self.variables = variables - def visit_unbound_type(self, t: UnboundType) -> ProperType: + def visit_unbound_type(self, t: UnboundType) -> Type: return t - def visit_any(self, t: AnyType) -> ProperType: + def visit_any(self, t: AnyType) -> Type: return t - def visit_none_type(self, t: NoneType) -> ProperType: + def visit_none_type(self, t: NoneType) -> Type: return t - def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType: + def visit_uninhabited_type(self, t: UninhabitedType) -> Type: return t - def visit_deleted_type(self, t: DeletedType) -> ProperType: + def visit_deleted_type(self, t: DeletedType) -> Type: return t - def visit_erased_type(self, t: ErasedType) -> ProperType: + def visit_erased_type(self, t: ErasedType) -> Type: # Should not get here. raise RuntimeError() - def visit_instance(self, t: Instance) -> ProperType: + def visit_instance(self, t: Instance) -> Type: args = self.expand_types(t.args) return Instance(t.type, args, t.line, t.column) - def visit_type_var(self, t: TypeVarType) -> ProperType: + def visit_type_var(self, t: TypeVarType) -> Type: repl = get_proper_type(self.variables.get(t.id, t)) if isinstance(repl, Instance): inst = repl @@ -94,44 +92,50 @@ def visit_type_var(self, t: TypeVarType) -> ProperType: else: return repl - def visit_callable_type(self, t: CallableType) -> ProperType: + def visit_callable_type(self, t: CallableType) -> Type: return t.copy_modified(arg_types=self.expand_types(t.arg_types), ret_type=t.ret_type.accept(self)) - def visit_overloaded(self, t: Overloaded) -> ProperType: + def visit_overloaded(self, t: Overloaded) -> Type: items = [] # type: List[CallableType] for item in t.items(): new_item = item.accept(self) + assert isinstance(new_item, ProperType) assert isinstance(new_item, CallableType) items.append(new_item) return Overloaded(items) - def visit_tuple_type(self, t: TupleType) -> ProperType: + def visit_tuple_type(self, t: TupleType) -> Type: return t.copy_modified(items=self.expand_types(t.items)) - def visit_typeddict_type(self, t: TypedDictType) -> ProperType: + def visit_typeddict_type(self, t: TypedDictType) -> Type: return t.copy_modified(item_types=self.expand_types(t.items.values())) - def visit_literal_type(self, t: LiteralType) -> ProperType: + def visit_literal_type(self, t: LiteralType) -> Type: # TODO: Verify this implementation is correct return t - def visit_union_type(self, t: UnionType) -> ProperType: + def visit_union_type(self, t: UnionType) -> Type: # After substituting for type variables in t.items, # some of the resulting types might be subtypes of others. from mypy.typeops import make_simplified_union # asdf return make_simplified_union(self.expand_types(t.items), t.line, t.column) - def visit_partial_type(self, t: PartialType) -> ProperType: + def visit_partial_type(self, t: PartialType) -> Type: return t - def visit_type_type(self, t: TypeType) -> ProperType: + def visit_type_type(self, t: TypeType) -> Type: # TODO: Verify that the new item type is valid (instance or # union of instances or Any). Sadly we can't report errors # here yet. item = t.item.accept(self) return TypeType.make_normalized(item) + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + # Target of the type alias cannot contain type variables, + # so we just expand the arguments. + return t.copy_modified(args=self.expand_types(t.args)) + def expand_types(self, types: Iterable[Type]) -> List[Type]: a = [] # type: List[Type] for t in types: diff --git a/mypy/fixup.py b/mypy/fixup.py index 8f3e29c9750d..73458c59e619 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -1,6 +1,7 @@ """Fix up various things after deserialization.""" from typing import Any, Dict, Optional +from typing_extensions import Final from mypy.nodes import ( MypyFile, SymbolNode, SymbolTable, SymbolTableNode, @@ -10,8 +11,7 @@ from mypy.types import ( CallableType, Instance, Overloaded, TupleType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, - TypeType, NOT_READY -) + TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny) from mypy.visitor import NodeVisitor from mypy.lookup import lookup_fully_qualified @@ -161,6 +161,15 @@ def visit_instance(self, inst: Instance) -> None: if inst.last_known_value is not None: inst.last_known_value.accept(self) + def visit_type_alias_type(self, t: TypeAliasType) -> None: + type_ref = t.type_ref + if type_ref is None: + return # We've already been here. + t.type_ref = None + t.alias = lookup_qualified_alias(self.modules, type_ref, self.allow_missing) + for a in t.args: + a.accept(self) + def visit_any(self, o: Any) -> None: pass # Nothing to descend into. @@ -262,6 +271,20 @@ def lookup_qualified_typeinfo(modules: Dict[str, MypyFile], name: str, return missing_info(modules) +def lookup_qualified_alias(modules: Dict[str, MypyFile], name: str, + allow_missing: bool) -> TypeAlias: + node = lookup_qualified(modules, name, allow_missing) + if isinstance(node, TypeAlias): + return node + else: + # Looks like a missing TypeAlias during an initial daemon load, put something there + assert allow_missing, "Should never get here in normal mode," \ + " got {}:{} instead of TypeAlias".format(type(node).__name__, + node.fullname() if node + else '') + return missing_alias() + + def lookup_qualified(modules: Dict[str, MypyFile], name: str, allow_missing: bool) -> Optional[SymbolNode]: stnode = lookup_qualified_stnode(modules, name, allow_missing) @@ -276,8 +299,11 @@ def lookup_qualified_stnode(modules: Dict[str, MypyFile], name: str, return lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing) +_SUGGESTION = "" # type: Final + + def missing_info(modules: Dict[str, MypyFile]) -> TypeInfo: - suggestion = "" + suggestion = _SUGGESTION.format('info') dummy_def = ClassDef(suggestion, Block([])) dummy_def.fullname = suggestion @@ -287,3 +313,9 @@ def missing_info(modules: Dict[str, MypyFile]) -> TypeInfo: info.bases = [Instance(obj_type, [])] info.mro = [info, obj_type] return info + + +def missing_alias() -> TypeAlias: + suggestion = _SUGGESTION.format('alias') + return TypeAlias(AnyType(TypeOfAny.special_form), suggestion, + line=-1, column=-1) diff --git a/mypy/indirection.py b/mypy/indirection.py index 0d5b3135560b..bae9be4cb750 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -20,14 +20,21 @@ class TypeIndirectionVisitor(TypeVisitor[Set[str]]): def __init__(self) -> None: self.cache = {} # type: Dict[types.Type, Set[str]] + self.seen_aliases = set() # type: Set[types.TypeAliasType] def find_modules(self, typs: Iterable[types.Type]) -> Set[str]: + self.seen_aliases.clear() return self._visit(typs) def _visit(self, typ_or_typs: Union[types.Type, Iterable[types.Type]]) -> Set[str]: typs = [typ_or_typs] if isinstance(typ_or_typs, types.Type) else typ_or_typs output = set() # type: Set[str] for typ in typs: + if isinstance(typ, types.TypeAliasType): + # Avoid infinite recursion for recursive type aliases. + if typ in self.seen_aliases: + continue + self.seen_aliases.add(typ) if typ in self.cache: modules = self.cache[typ] else: @@ -95,3 +102,6 @@ def visit_partial_type(self, t: types.PartialType) -> Set[str]: def visit_type_type(self, t: types.TypeType) -> Set[str]: return self._visit(t.item) + + def visit_type_alias_type(self, t: types.TypeAliasType) -> Set[str]: + return self._visit(types.get_proper_type(t)) diff --git a/mypy/join.py b/mypy/join.py index b8f1f2d66dc8..7472c2dead93 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -7,7 +7,7 @@ Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType, TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type, - ProperType, get_proper_types + ProperType, get_proper_types, TypeAliasType ) from mypy.maptype import map_instance_to_supertype from mypy.subtypes import ( @@ -21,6 +21,7 @@ def join_simple(declaration: Optional[Type], s: Type, t: Type) -> ProperType: """Return a simple least upper bound given the declared type.""" + # TODO: check infinite recursion for aliases here. declaration = get_proper_type(declaration) s = get_proper_type(s) t = get_proper_type(t) @@ -58,11 +59,25 @@ def join_simple(declaration: Optional[Type], s: Type, t: Type) -> ProperType: return declaration +def trivial_join(s: Type, t: Type) -> ProperType: + """Return one of types (expanded) if it is a supertype of other, otherwise top type.""" + if is_subtype(s, t): + return get_proper_type(t) + elif is_subtype(t, s): + return get_proper_type(s) + else: + return object_or_any_from_type(get_proper_type(t)) + + def join_types(s: Type, t: Type) -> ProperType: """Return the least upper bound of s and t. For example, the join of 'int' and 'object' is 'object'. """ + if mypy.typeops.is_recursive_pair(s, t): + # This case can trigger an infinite recursion, general support for this will be + # tricky so we use a trivial join (like for protocols). + return trivial_join(s, t) s = get_proper_type(s) t = get_proper_type(t) @@ -292,6 +307,9 @@ def visit_type_type(self, t: TypeType) -> ProperType: else: return self.default(self.s) + def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: + assert False, "This should be never called, got {}".format(t) + def join(self, s: Type, t: Type) -> ProperType: return join_types(s, t) @@ -454,12 +472,31 @@ def object_from_instance(instance: Instance) -> Instance: return res -def join_type_list(types: List[Type]) -> Type: +def object_or_any_from_type(typ: ProperType) -> ProperType: + # Similar to object_from_instance() but tries hard for all types. + # TODO: find a better way to get object, or make this more reliable. + if isinstance(typ, Instance): + return object_from_instance(typ) + elif isinstance(typ, (CallableType, TypedDictType, LiteralType)): + return object_from_instance(typ.fallback) + elif isinstance(typ, TupleType): + return object_from_instance(typ.partial_fallback) + elif isinstance(typ, TypeType): + return object_or_any_from_type(typ.item) + elif isinstance(typ, TypeVarType) and isinstance(typ.upper_bound, ProperType): + return object_or_any_from_type(typ.upper_bound) + elif isinstance(typ, UnionType): + joined = join_type_list([it for it in typ.items if isinstance(it, ProperType)]) + return object_or_any_from_type(joined) + return AnyType(TypeOfAny.implementation_artifact) + + +def join_type_list(types: List[Type]) -> ProperType: if not types: # This is a little arbitrary but reasonable. Any empty tuple should be compatible # with all variable length tuples, and this makes it possible. return UninhabitedType() - joined = types[0] + joined = get_proper_type(types[0]) for t in types[1:]: joined = join_types(joined, t) return joined diff --git a/mypy/maptype.py b/mypy/maptype.py index f90d0a056cab..5e58754655ef 100644 --- a/mypy/maptype.py +++ b/mypy/maptype.py @@ -2,7 +2,7 @@ from mypy.expandtype import expand_type from mypy.nodes import TypeInfo -from mypy.types import Type, TypeVarId, Instance, AnyType, TypeOfAny +from mypy.types import Type, TypeVarId, Instance, AnyType, TypeOfAny, ProperType def map_instance_to_supertype(instance: Instance, @@ -80,6 +80,7 @@ def map_instance_to_direct_supertypes(instance: Instance, if b.type == supertype: env = instance_to_type_environment(instance) t = expand_type(b, env) + assert isinstance(t, ProperType) assert isinstance(t, Instance) result.append(t) diff --git a/mypy/meet.py b/mypy/meet.py index 517eb93a5c81..59e94a24596f 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -8,19 +8,36 @@ Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, - ProperType, get_proper_type, get_proper_types + ProperType, get_proper_type, get_proper_types, TypeAliasType ) from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype from mypy.erasetype import erase_type from mypy.maptype import map_instance_to_supertype -from mypy.typeops import tuple_fallback, make_simplified_union +from mypy.typeops import tuple_fallback, make_simplified_union, is_recursive_pair from mypy import state # TODO Describe this module. +def trivial_meet(s: Type, t: Type) -> ProperType: + """Return one of types (expanded) if it is a subtype of other, otherwise bottom type.""" + if is_subtype(s, t): + return get_proper_type(s) + elif is_subtype(t, s): + return get_proper_type(t) + else: + if state.strict_optional: + return UninhabitedType() + else: + return NoneType() + + def meet_types(s: Type, t: Type) -> ProperType: """Return the greatest lower bound of two types.""" + if is_recursive_pair(s, t): + # This case can trigger an infinite recursion, general support for this will be + # tricky so we use a trivial meet (like for protocols). + return trivial_meet(s, t) s = get_proper_type(s) t = get_proper_type(t) @@ -35,6 +52,7 @@ def meet_types(s: Type, t: Type) -> ProperType: def narrow_declared_type(declared: Type, narrowed: Type) -> Type: """Return the declared type narrowed down to another type.""" + # TODO: check infinite recursion for aliases here. declared = get_proper_type(declared) narrowed = get_proper_type(narrowed) @@ -119,8 +137,7 @@ def is_overlapping_types(left: Type, If 'prohibit_none_typevar_overlap' is True, we disallow None from overlapping with TypeVars (in both strict-optional and non-strict-optional mode). """ - left = get_proper_type(left) - right = get_proper_type(right) + left, right = get_proper_types((left, right)) def _is_overlapping_types(left: Type, right: Type) -> bool: '''Encode the kind of overlapping check to perform. @@ -156,6 +173,7 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: left = UnionType.make_union(left.relevant_items()) if isinstance(right, UnionType): right = UnionType.make_union(right.relevant_items()) + left, right = get_proper_types((left, right)) # We check for complete overlaps next as a general-purpose failsafe. # If this check fails, we start checking to see if there exists a @@ -183,7 +201,8 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # If both types are singleton variants (and are not TypeVars), we've hit the base case: # we skip these checks to avoid infinitely recursing. - def is_none_typevar_overlap(t1: ProperType, t2: ProperType) -> bool: + def is_none_typevar_overlap(t1: Type, t2: Type) -> bool: + t1, t2 = get_proper_types((t1, t2)) return isinstance(t1, NoneType) and isinstance(t2, TypeVarType) if prohibit_none_typevar_overlap: @@ -242,9 +261,10 @@ def is_none_typevar_overlap(t1: ProperType, t2: ProperType) -> bool: if isinstance(left, TypeType) and isinstance(right, TypeType): return _is_overlapping_types(left.item, right.item) - def _type_object_overlap(left: ProperType, right: ProperType) -> bool: + def _type_object_overlap(left: Type, right: Type) -> bool: """Special cases for type object types overlaps.""" # TODO: these checks are a bit in gray area, adjust if they cause problems. + left, right = get_proper_types((left, right)) # 1. Type[C] vs Callable[..., C], where the latter is class object. if isinstance(left, TypeType) and isinstance(right, CallableType) and right.is_type_obj(): return _is_overlapping_types(left.item, right.ret_type) @@ -370,10 +390,11 @@ def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType, *, return True -def are_tuples_overlapping(left: ProperType, right: ProperType, *, +def are_tuples_overlapping(left: Type, right: Type, *, ignore_promotions: bool = False, prohibit_none_typevar_overlap: bool = False) -> bool: """Returns true if left and right are overlapping tuples.""" + left, right = get_proper_types((left, right)) left = adjust_tuple(left, right) or left right = adjust_tuple(right, left) or right assert isinstance(left, TupleType), 'Type {} is not a tuple'.format(left) @@ -612,6 +633,9 @@ def visit_type_type(self, t: TypeType) -> ProperType: else: return self.default(self.s) + def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: + assert False, "This should be never called, got {}".format(t) + def meet(self, s: Type, t: Type) -> ProperType: return meet_types(s, t) @@ -654,7 +678,7 @@ def meet_type_list(types: List[Type]) -> Type: return met -def typed_dict_mapping_pair(left: ProperType, right: ProperType) -> bool: +def typed_dict_mapping_pair(left: Type, right: Type) -> bool: """Is this a pair where one type is a TypedDict and another one is an instance of Mapping? This case requires a precise/principled consideration because there are two use cases @@ -662,6 +686,7 @@ def typed_dict_mapping_pair(left: ProperType, right: ProperType) -> bool: false positives for overloads, but we also need to avoid spuriously non-overlapping types to avoid false positives with --strict-equality. """ + left, right = get_proper_types((left, right)) assert not isinstance(left, TypedDictType) or not isinstance(right, TypedDictType) if isinstance(left, TypedDictType): @@ -673,7 +698,7 @@ def typed_dict_mapping_pair(left: ProperType, right: ProperType) -> bool: return isinstance(other, Instance) and other.type.has_base('typing.Mapping') -def typed_dict_mapping_overlap(left: ProperType, right: ProperType, +def typed_dict_mapping_overlap(left: Type, right: Type, overlapping: Callable[[Type, Type], bool]) -> bool: """Check if a TypedDict type is overlapping with a Mapping. @@ -703,6 +728,7 @@ def typed_dict_mapping_overlap(left: ProperType, right: ProperType, Mapping[, ]. This way we avoid false positives for overloads, and also avoid false positives for comparisons like SomeTypedDict == {} under --strict-equality. """ + left, right = get_proper_types((left, right)) assert not isinstance(left, TypedDictType) or not isinstance(right, TypedDictType) if isinstance(left, TypedDictType): diff --git a/mypy/messages.py b/mypy/messages.py index 021ac5c0f1a7..29b99352cfec 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -494,7 +494,7 @@ def incompatible_argument(self, expected_types = list(expected_type.items) else: expected_types = [expected_type] - for type in expected_types: + for type in get_proper_types(expected_types): if isinstance(arg_type, Instance) and isinstance(type, Instance): notes = append_invariance_notes(notes, arg_type, type) self.fail(msg, context, code=code) @@ -1484,9 +1484,10 @@ def format(typ: Type) -> str: elif isinstance(typ, UnionType): # Only print Unions as Optionals if the Optional wouldn't have to contain another Union print_as_optional = (len(typ.items) - - sum(isinstance(t, NoneType) for t in typ.items) == 1) + sum(isinstance(get_proper_type(t), NoneType) + for t in typ.items) == 1) if print_as_optional: - rest = [t for t in typ.items if not isinstance(t, NoneType)] + rest = [t for t in typ.items if not isinstance(get_proper_type(t), NoneType)] return 'Optional[{}]'.format(format(rest[0])) else: items = [] diff --git a/mypy/nodes.py b/mypy/nodes.py index 15168163b7e6..f294705ada01 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2721,7 +2721,8 @@ def f(x: B[T]) -> T: ... # without T, Any would be used here Python runtime limitation. line and column: Line an column on the original alias definition. """ - __slots__ = ('target', '_fullname', 'alias_tvars', 'no_args', 'normalized', 'line', 'column') + __slots__ = ('target', '_fullname', 'alias_tvars', 'no_args', 'normalized', + 'line', 'column', 'assuming', 'assuming_proper', 'inferring') def __init__(self, target: 'mypy.types.Type', fullname: str, line: int, column: int, *, diff --git a/mypy/sametypes.py b/mypy/sametypes.py index f09de9c18e15..024333a13ec8 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -4,8 +4,7 @@ Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, - ProperType, get_proper_type -) + ProperType, get_proper_type, TypeAliasType) from mypy.typeops import tuple_fallback, make_simplified_union @@ -85,6 +84,13 @@ def visit_instance(self, left: Instance) -> bool: is_same_types(left.args, self.right.args) and left.last_known_value == self.right.last_known_value) + def visit_type_alias_type(self, left: TypeAliasType) -> bool: + # Similar to protocols, two aliases with the same targets return False here, + # but both is_subtype(t, s) and is_subtype(s, t) return True. + return (isinstance(self.right, TypeAliasType) and + left.alias == self.right.alias and + is_same_types(left.args, self.right.args)) + def visit_type_var(self, left: TypeVarType) -> bool: return (isinstance(self.right, TypeVarType) and left.id == self.right.id) diff --git a/mypy/semanal.py b/mypy/semanal.py index 1975f2704c70..18e45d684f64 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -88,8 +88,7 @@ FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue, TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType, - get_proper_type, get_proper_types -) + get_proper_type, get_proper_types, TypeAliasType) from mypy.typeops import function_type from mypy.type_visitor import TypeQuery from mypy.nodes import implicit_module_attrs @@ -4873,6 +4872,9 @@ def visit_any(self, t: AnyType) -> Type: return t.copy_modified(TypeOfAny.special_form) return t + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + return t.copy_modified(args=[a.accept(self) for a in t.args]) + def apply_semantic_analyzer_patches(patches: List[Tuple[int, Callable[[], None]]]) -> None: """Call patch callbacks in the right order. diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index f7e9cd1b7471..fdb6a273c131 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -59,7 +59,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' from mypy.types import ( Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, - UnionType, Overloaded, PartialType, TypeType, LiteralType, + UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType ) from mypy.util import get_prefix @@ -346,6 +346,10 @@ def visit_partial_type(self, typ: PartialType) -> SnapshotItem: def visit_type_type(self, typ: TypeType) -> SnapshotItem: return ('TypeType', snapshot_type(typ.item)) + def visit_type_alias_type(self, typ: TypeAliasType) -> SnapshotItem: + assert typ.alias is not None + return ('TypeAliasType', typ.alias.fullname(), snapshot_types(typ.args)) + def snapshot_untyped_signature(func: Union[OverloadedFuncDef, FuncItem]) -> Tuple[object, ...]: """Create a snapshot of the signature of a function that has no explicit signature. diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 62c1e14cf83b..5080905d1eaf 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -59,7 +59,7 @@ Type, SyntheticTypeVisitor, Instance, AnyType, NoneType, CallableType, ErasedType, DeletedType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, Overloaded, TypeVarDef, TypeList, CallableArgument, EllipsisType, StarType, LiteralType, - RawExpressionType, PartialType, PlaceholderType, + RawExpressionType, PartialType, PlaceholderType, TypeAliasType ) from mypy.util import get_prefix, replace_object_state from mypy.typestate import TypeState @@ -343,6 +343,12 @@ def visit_instance(self, typ: Instance) -> None: if typ.last_known_value: typ.last_known_value.accept(self) + def visit_type_alias_type(self, typ: TypeAliasType) -> None: + assert typ.alias is not None + typ.alias = self.fixup(typ.alias) + for arg in typ.args: + arg.accept(self) + def visit_any(self, typ: AnyType) -> None: pass diff --git a/mypy/server/deps.py b/mypy/server/deps.py index db457e3e9c72..295b1bca266c 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -96,8 +96,8 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a from mypy.types import ( Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, - FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType -) + FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType, + TypeAliasType) from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.util import correct_relative_import from mypy.scope import Scope @@ -739,7 +739,7 @@ def add_operator_method_dependency_for_type(self, typ: ProperType, method: str) self.add_dependency(trigger) elif isinstance(typ, UnionType): for item in typ.items: - self.add_operator_method_dependency_for_type(item, method) + self.add_operator_method_dependency_for_type(get_proper_type(item), method) elif isinstance(typ, FunctionLike) and typ.is_type_obj(): self.add_operator_method_dependency_for_type(typ.fallback, method) elif isinstance(typ, TypeType): @@ -878,6 +878,14 @@ def visit_instance(self, typ: Instance) -> List[str]: triggers.extend(self.get_type_triggers(typ.last_known_value)) return triggers + def visit_type_alias_type(self, typ: TypeAliasType) -> List[str]: + assert typ.alias is not None + trigger = make_trigger(typ.alias.fullname()) + triggers = [trigger] + for arg in typ.args: + triggers.extend(self.get_type_triggers(arg)) + return triggers + def visit_any(self, typ: AnyType) -> List[str]: if typ.missing_import_name is not None: return [make_trigger(typ.missing_import_name)] diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 9f4c001264e1..1c5862471656 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1,13 +1,13 @@ from contextlib import contextmanager -from typing import Any, List, Optional, Callable, Tuple, Iterator, Set, Union, cast +from typing import Any, List, Optional, Callable, Tuple, Iterator, Set, Union, cast, TypeVar from typing_extensions import Final from mypy.types import ( Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, - FunctionLike, TypeOfAny, LiteralType, ProperType, get_proper_type + FunctionLike, TypeOfAny, LiteralType, ProperType, get_proper_type, TypeAliasType ) import mypy.applytype import mypy.constraints @@ -63,6 +63,46 @@ def is_subtype(left: Type, right: Type, between the type arguments (e.g., A and B), taking the variance of the type var into account. """ + if TypeState.is_assumed_subtype(left, right): + return True + if (isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType) and + left.is_recursive and right.is_recursive): + # This case requires special care because it may cause infinite recursion. + # Our view on recursive types is known under a fancy name of equirecursive mu-types. + # Roughly this means that a recursive type is defined as an alias where right hand side + # can refer to the type as a whole, for example: + # A = Union[int, Tuple[A, ...]] + # and an alias unrolled once represents the *same type*, in our case all these represent + # the same type: + # A + # Union[int, Tuple[A, ...]] + # Union[int, Tuple[Union[int, Tuple[A, ...]], ...]] + # The algorithm for subtyping is then essentially under the assumption that left <: right, + # check that get_proper_type(left) <: get_proper_type(right). On the example above, + # If we start with: + # A = Union[int, Tuple[A, ...]] + # B = Union[int, Tuple[B, ...]] + # When checking if A <: B we push pair (A, B) onto 'assuming' stack, then when after few + # steps we come back to initial call is_subtype(A, B) and immediately return True. + with pop_on_exit(TypeState._assuming, left, right): + return _is_subtype(left, right, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) + return _is_subtype(left, right, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) + + +def _is_subtype(left: Type, right: Type, + *, + ignore_type_params: bool = False, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> bool: left = get_proper_type(left) right = get_proper_type(right) @@ -433,10 +473,16 @@ def visit_type_type(self, left: TypeType) -> bool: return metaclass is not None and self._is_subtype(metaclass, right) return False + def visit_type_alias_type(self, left: TypeAliasType) -> bool: + assert False, "This should be never called, got {}".format(left) + + +T = TypeVar('T', Instance, TypeAliasType) + @contextmanager -def pop_on_exit(stack: List[Tuple[Instance, Instance]], - left: Instance, right: Instance) -> Iterator[None]: +def pop_on_exit(stack: List[Tuple[T, T]], + left: T, right: T) -> Iterator[None]: stack.append((left, right)) yield stack.pop() @@ -1038,7 +1084,7 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) if isinstance(t, UnionType): new_items = [item for item in t.relevant_items() - if (isinstance(item, AnyType) or + if (isinstance(get_proper_type(item), AnyType) or not covers_at_runtime(item, s, ignore_promotions))] return UnionType.make_union(new_items) else: @@ -1076,6 +1122,23 @@ def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = Fals If erase_instances is True, erase left instance *after* mapping it to supertype (this is useful for runtime isinstance() checks). """ + if TypeState.is_assumed_proper_subtype(left, right): + return True + if (isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType) and + left.is_recursive and right.is_recursive): + # This case requires special care because it may cause infinite recursion. + # See is_subtype() for more info. + with pop_on_exit(TypeState._assuming_proper, left, right): + return _is_proper_subtype(left, right, + ignore_promotions=ignore_promotions, + erase_instances=erase_instances) + return _is_proper_subtype(left, right, + ignore_promotions=ignore_promotions, + erase_instances=erase_instances) + + +def _is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False, + erase_instances: bool = False) -> bool: left = get_proper_type(left) right = get_proper_type(right) @@ -1281,6 +1344,9 @@ def visit_type_type(self, left: TypeType) -> bool: return metaclass is not None and self._is_proper_subtype(metaclass, right) return False + def visit_type_alias_type(self, left: TypeAliasType) -> bool: + assert False, "This should be never called, got {}".format(left) + def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Check if left is a more precise type than right. diff --git a/mypy/suggestions.py b/mypy/suggestions.py index 7bb4a583d0cf..44d15d2edd1e 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -33,7 +33,7 @@ TypeVarType, FunctionLike, TypeStrVisitor, TypeTranslator, is_optional, remove_optional, ProperType, get_proper_type, - TypedDictType + TypedDictType, TypeAliasType ) from mypy.build import State, Graph from mypy.nodes import ( @@ -688,7 +688,7 @@ def score_type(self, t: Type, arg_pos: bool) -> int: if arg_pos and isinstance(t, NoneType): return 20 if isinstance(t, UnionType): - if any(isinstance(x, AnyType) for x in t.items): + if any(isinstance(get_proper_type(x), AnyType) for x in t.items): return 20 if any(has_any_type(x) for x in t.items): return 15 @@ -716,7 +716,7 @@ def any_score_type(ut: Type, arg_pos: bool) -> float: if isinstance(t, NoneType) and arg_pos: return 0.5 if isinstance(t, UnionType): - if any(isinstance(x, AnyType) for x in t.items): + if any(isinstance(get_proper_type(x), AnyType) for x in t.items): return 0.5 if any(has_any_type(x) for x in t.items): return 0.25 @@ -825,6 +825,12 @@ class StrToText(TypeTranslator): def __init__(self, builtin_type: Callable[[str], Instance]) -> None: self.text_type = builtin_type('builtins.unicode') + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + exp_t = get_proper_type(t) + if isinstance(exp_t, Instance) and exp_t.type.fullname() == 'builtins.str': + return self.text_type + return t.copy_modified(args=[a.accept(self) for a in t.args]) + def visit_instance(self, t: Instance) -> Type: if t.type.fullname() == 'builtins.str': return self.text_type diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 604ce941e5e8..4609e0dd1a02 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -8,10 +8,11 @@ from mypy.join import join_types, join_simple from mypy.meet import meet_types, narrow_declared_type from mypy.sametypes import is_same_type +from mypy.indirection import TypeIndirectionVisitor from mypy.types import ( UnboundType, AnyType, CallableType, TupleType, TypeVarDef, Type, Instance, NoneType, Overloaded, TypeType, UnionType, UninhabitedType, TypeVarId, TypeOfAny, - LiteralType, + LiteralType, get_proper_type ) from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, CONTRAVARIANT, INVARIANT, COVARIANT from mypy.subtypes import is_subtype, is_more_precise, is_proper_subtype @@ -92,6 +93,39 @@ def test_generic_function_type(self) -> None: c2 = CallableType([], [], [], NoneType(), self.function, name=None, variables=v) assert_equal(str(c2), 'def [Y, X] ()') + def test_type_alias_expand_once(self) -> None: + A, target = self.fx.def_alias_1(self.fx.a) + assert get_proper_type(A) == target + assert get_proper_type(target) == target + + A, target = self.fx.def_alias_2(self.fx.a) + assert get_proper_type(A) == target + assert get_proper_type(target) == target + + def test_type_alias_expand_all(self) -> None: + A, _ = self.fx.def_alias_1(self.fx.a) + assert A.expand_all_if_possible() is None + A, _ = self.fx.def_alias_2(self.fx.a) + assert A.expand_all_if_possible() is None + + B = self.fx.non_rec_alias(self.fx.a) + C = self.fx.non_rec_alias(TupleType([B, B], Instance(self.fx.std_tuplei, + [B]))) + assert C.expand_all_if_possible() == TupleType([self.fx.a, self.fx.a], + Instance(self.fx.std_tuplei, + [self.fx.a])) + + def test_indirection_no_infinite_recursion(self) -> None: + A, _ = self.fx.def_alias_1(self.fx.a) + visitor = TypeIndirectionVisitor() + modules = A.accept(visitor) + assert modules == {'__main__', 'builtins'} + + A, _ = self.fx.def_alias_2(self.fx.a) + visitor = TypeIndirectionVisitor() + modules = A.accept(visitor) + assert modules == {'__main__', 'builtins'} + class TypeOpsSuite(Suite): def setUp(self) -> None: @@ -109,6 +143,12 @@ def test_trivial_expand(self) -> None: self.assert_expand(t, [], t) self.assert_expand(t, [], t) + def test_trivial_expand_recursive(self) -> None: + A, _ = self.fx.def_alias_1(self.fx.a) + self.assert_expand(A, [], A) + A, _ = self.fx.def_alias_2(self.fx.a) + self.assert_expand(A, [], A) + def test_expand_naked_type_var(self) -> None: self.assert_expand(self.fx.t, [(self.fx.t.id, self.fx.a)], self.fx.a) self.assert_expand(self.fx.t, [(self.fx.s.id, self.fx.a)], self.fx.t) @@ -149,6 +189,13 @@ def test_erase_with_generic_type(self) -> None: self.assert_erase(self.fx.hab, Instance(self.fx.hi, [self.fx.anyt, self.fx.anyt])) + def test_erase_with_generic_type_recursive(self) -> None: + tuple_any = Instance(self.fx.std_tuplei, [AnyType(TypeOfAny.explicit)]) + A, _ = self.fx.def_alias_1(self.fx.a) + self.assert_erase(A, tuple_any) + A, _ = self.fx.def_alias_2(self.fx.a) + self.assert_erase(A, UnionType([self.fx.a, tuple_any])) + def test_erase_with_tuple_type(self) -> None: self.assert_erase(self.tuple(self.fx.a), self.fx.std_tuple) @@ -280,6 +327,27 @@ def test_is_proper_subtype_and_subtype_literal_types(self) -> None: assert_true(is_subtype(lit1, fx.anyt)) assert_true(is_subtype(fx.anyt, lit1)) + def test_subtype_aliases(self) -> None: + A1, _ = self.fx.def_alias_1(self.fx.a) + AA1, _ = self.fx.def_alias_1(self.fx.a) + assert_true(is_subtype(A1, AA1)) + assert_true(is_subtype(AA1, A1)) + + A2, _ = self.fx.def_alias_2(self.fx.a) + AA2, _ = self.fx.def_alias_2(self.fx.a) + assert_true(is_subtype(A2, AA2)) + assert_true(is_subtype(AA2, A2)) + + B1, _ = self.fx.def_alias_1(self.fx.b) + B2, _ = self.fx.def_alias_2(self.fx.b) + assert_true(is_subtype(B1, A1)) + assert_true(is_subtype(B2, A2)) + assert_false(is_subtype(A1, B1)) + assert_false(is_subtype(A2, B2)) + + assert_false(is_subtype(A2, A1)) + assert_true(is_subtype(A1, A2)) + # can_be_true / can_be_false def test_empty_tuple_always_false(self) -> None: diff --git a/mypy/test/typefixture.py b/mypy/test/typefixture.py index 5f6680718fb8..b29f7164c911 100644 --- a/mypy/test/typefixture.py +++ b/mypy/test/typefixture.py @@ -3,15 +3,16 @@ It contains class TypeInfos and Type objects. """ -from typing import List, Optional +from typing import List, Optional, Tuple from mypy.types import ( Type, TypeVarType, AnyType, NoneType, Instance, CallableType, TypeVarDef, TypeType, - UninhabitedType, TypeOfAny + UninhabitedType, TypeOfAny, TypeAliasType, UnionType ) from mypy.nodes import ( TypeInfo, ClassDef, Block, ARG_POS, ARG_OPT, ARG_STAR, SymbolTable, - COVARIANT) + COVARIANT, TypeAlias +) class TypeFixture: @@ -238,6 +239,26 @@ def make_type_info(self, name: str, return info + def def_alias_1(self, base: Instance) -> Tuple[TypeAliasType, Type]: + A = TypeAliasType(None, []) + target = Instance(self.std_tuplei, + [UnionType([base, A])]) # A = Tuple[Union[base, A], ...] + AN = TypeAlias(target, '__main__.A', -1, -1) + A.alias = AN + return A, target + + def def_alias_2(self, base: Instance) -> Tuple[TypeAliasType, Type]: + A = TypeAliasType(None, []) + target = UnionType([base, + Instance(self.std_tuplei, [A])]) # A = Union[base, Tuple[A, ...]] + AN = TypeAlias(target, '__main__.A', -1, -1) + A.alias = AN + return A, target + + def non_rec_alias(self, target: Type) -> TypeAliasType: + AN = TypeAlias(target, '__main__.A', -1, -1) + return TypeAliasType(AN, []) + class InterfaceTypeFixture(TypeFixture): """Extension of TypeFixture that contains additional generic diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 7f001eed1f33..e2812364f25e 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -23,7 +23,7 @@ RawExpressionType, Instance, NoneType, TypeType, UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, - PlaceholderType, TypeAliasType + PlaceholderType, TypeAliasType, get_proper_type ) @@ -98,8 +98,9 @@ def visit_partial_type(self, t: PartialType) -> T: def visit_type_type(self, t: TypeType) -> T: pass + @abstractmethod def visit_type_alias_type(self, t: TypeAliasType) -> T: - raise NotImplementedError('TODO') + pass @trait @@ -232,6 +233,14 @@ def visit_overloaded(self, t: Overloaded) -> Type: def visit_type_type(self, t: TypeType) -> Type: return TypeType.make_normalized(t.item.accept(self), line=t.line, column=t.column) + @abstractmethod + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + # This method doesn't have a default implementation for type translators, + # because type aliases are special: some information is contained in the + # TypeAlias node, and we normally don't generate new nodes. Every subclass + # must implement this depending on its semantics. + pass + @trait class TypeQuery(SyntheticTypeVisitor[T]): @@ -313,6 +322,9 @@ def visit_ellipsis_type(self, t: EllipsisType) -> T: def visit_placeholder_type(self, t: PlaceholderType) -> T: return self.query_types(t.args) + def visit_type_alias_type(self, t: TypeAliasType) -> T: + return get_proper_type(t).accept(self) + def query_types(self, types: Iterable[Type]) -> T: """Perform a query for a list of types. diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 4c36355e80b6..e455f80dc12e 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -16,7 +16,7 @@ CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarDef, SyntheticTypeVisitor, StarType, PartialType, EllipsisType, UninhabitedType, TypeType, replace_alias_tvars, CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, - PlaceholderType, Overloaded, get_proper_type, ProperType + PlaceholderType, Overloaded, get_proper_type, TypeAliasType ) from mypy.nodes import ( @@ -310,8 +310,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt if len(t.args) != 1: self.fail('ClassVar[...] must have at most one type argument', t) return AnyType(TypeOfAny.from_error) - item = self.anal_type(t.args[0]) - return item + return self.anal_type(t.args[0]) elif fullname in ('mypy_extensions.NoReturn', 'typing.NoReturn'): return UninhabitedType(is_noreturn=True) elif fullname in ('typing_extensions.Literal', 'typing.Literal'): @@ -483,6 +482,9 @@ def visit_callable_argument(self, t: CallableArgument) -> Type: def visit_instance(self, t: Instance) -> Type: return t + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + return t + def visit_type_var(self, t: TypeVarType) -> Type: return t @@ -1021,7 +1023,7 @@ def set_any_tvars(tp: Type, vars: List[str], from_error: bool = False, disallow_any: bool = False, fail: Optional[MsgCallback] = None, - unexpanded_type: Optional[Type] = None) -> ProperType: + unexpanded_type: Optional[Type] = None) -> Type: if from_error or disallow_any: type_of_any = TypeOfAny.from_error else: diff --git a/mypy/typeops.py b/mypy/typeops.py index 8db2158d809c..53b65fefcad2 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -11,7 +11,7 @@ TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded, TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, - copy_type + copy_type, TypeAliasType ) from mypy.nodes import ( FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, Expression, @@ -26,6 +26,12 @@ from mypy import state +def is_recursive_pair(s: Type, t: Type) -> bool: + """Is this a pair of recursive type aliases?""" + return (isinstance(s, TypeAliasType) and isinstance(t, TypeAliasType) and + s.is_recursive and t.is_recursive) + + def tuple_fallback(typ: TupleType) -> Instance: """Return fallback type for a tuple.""" from mypy.join import join_type_list @@ -302,7 +308,7 @@ def make_simplified_union(items: Sequence[Type], all_items = [] # type: List[ProperType] for typ in items: if isinstance(typ, UnionType): - all_items.extend(typ.items) + all_items.extend(get_proper_types(typ.items)) else: all_items.append(typ) items = all_items diff --git a/mypy/types.py b/mypy/types.py index 34899c40f824..b2c689f537e2 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -154,15 +154,23 @@ class TypeAliasType(Type): during semantic analysis, but create an instance of this type that records the target alias definition node (mypy.nodes.TypeAlias) and type arguments (for generic aliases). - This is very similar to how TypeInfo vs Instance interact. + This is very similar to how TypeInfo vs Instance interact, where a recursive class-based + structure like + class Node: + value: int + children: List[Node] + can be represented in a tree-like manner. """ + __slots__ = ('alias', 'args', 'line', 'column', 'type_ref', '_is_recursive') + def __init__(self, alias: Optional[mypy.nodes.TypeAlias], args: List[Type], line: int = -1, column: int = -1) -> None: - super().__init__(line, column) self.alias = alias self.args = args self.type_ref = None # type: Optional[str] + self._is_recursive = None # type: Optional[bool] + super().__init__(line, column) def _expand_once(self) -> Type: """Expand to the target type exactly once. @@ -175,25 +183,41 @@ def _expand_once(self) -> Type: return replace_alias_tvars(self.alias.target, self.alias.alias_tvars, self.args, self.line, self.column) + def _partial_expansion(self) -> Tuple['ProperType', bool]: + # Private method mostly for debugging and testing. + unroller = UnrollAliasVisitor(set()) + unrolled = self.accept(unroller) + assert isinstance(unrolled, ProperType) + return unrolled, unroller.recursed + def expand_all_if_possible(self) -> Optional['ProperType']: """Attempt a full expansion of the type alias (including nested aliases). If the expansion is not possible, i.e. the alias is (mutually-)recursive, return None. """ - raise NotImplementedError('TODO') + unrolled, recursed = self._partial_expansion() + if recursed: + return None + return unrolled - # TODO: remove ignore caused by https://github.com/python/mypy/issues/6759 @property - def can_be_true(self) -> bool: # type: ignore[override] - assert self.alias is not None - return self.alias.target.can_be_true + def is_recursive(self) -> bool: + if self._is_recursive is not None: + return self._is_recursive + is_recursive = self.expand_all_if_possible() is None + self._is_recursive = is_recursive + return is_recursive - # TODO: remove ignore caused by https://github.com/python/mypy/issues/6759 - @property - def can_be_false(self) -> bool: # type: ignore[override] - assert self.alias is not None - return self.alias.target.can_be_false + def can_be_true_default(self) -> bool: + if self.alias is not None: + return self.alias.target.can_be_true + return super().can_be_true_default() + + def can_be_false_default(self) -> bool: + if self.alias is not None: + return self.alias.target.can_be_false + return super().can_be_false_default() def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_type_alias_type(self) @@ -202,6 +226,7 @@ def __hash__(self) -> int: return hash((self.alias, tuple(self.args))) def __eq__(self, other: object) -> bool: + # Note: never use this to determine subtype relationships, use is_subtype(). if not isinstance(other, TypeAliasType): return NotImplemented return (self.alias == other.alias @@ -223,7 +248,7 @@ def deserialize(cls, data: JsonDict) -> 'TypeAliasType': assert isinstance(args_list, list) args = [deserialize_type(arg) for arg in args_list] alias = TypeAliasType(None, args) - alias.type_ref = data['type_ref'] # TODO: fix this up in fixup.py. + alias.type_ref = data['type_ref'] return alias def copy_modified(self, *, @@ -1688,14 +1713,14 @@ def has_readable_member(self, name: str) -> bool: """ return all((isinstance(x, UnionType) and x.has_readable_member(name)) or (isinstance(x, Instance) and x.type.has_readable_member(name)) - for x in self.relevant_items()) + for x in get_proper_types(self.relevant_items())) - def relevant_items(self) -> List[ProperType]: + def relevant_items(self) -> List[Type]: """Removes NoneTypes from Unions when strict Optional checking is off.""" if state.strict_optional: return self.items else: - return [i for i in self.items if not isinstance(i, NoneType)] + return [i for i in get_proper_types(self.items) if not isinstance(i, NoneType)] def serialize(self) -> JsonDict: return {'.class': 'UnionType', @@ -1859,6 +1884,32 @@ def serialize(self) -> str: assert False, "Internal error: unresolved placeholder type {}".format(self.fullname) +@overload +def get_proper_type(typ: None) -> None: ... +@overload +def get_proper_type(typ: Type) -> ProperType: ... + + +def get_proper_type(typ: Optional[Type]) -> Optional[ProperType]: + if typ is None: + return None + while isinstance(typ, TypeAliasType): + typ = typ._expand_once() + assert isinstance(typ, ProperType), typ + # TODO: store the name of original type alias on this type, so we can show it in errors. + return typ + + +@overload +def get_proper_types(it: Iterable[Type]) -> List[ProperType]: ... +@overload +def get_proper_types(typ: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: ... + + +def get_proper_types(it: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: # type: ignore + return [get_proper_type(t) for t in it] + + # We split off the type visitor base classes to another module # to make it easier to gradually get modules working with mypyc. # Import them here, after the types are defined. @@ -2055,6 +2106,26 @@ def list_str(self, a: Iterable[Type]) -> str: return ', '.join(res) +class UnrollAliasVisitor(TypeTranslator): + def __init__(self, initial_aliases: Set[TypeAliasType]) -> None: + self.recursed = False + self.initial_aliases = initial_aliases + + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + if t in self.initial_aliases: + self.recursed = True + return AnyType(TypeOfAny.special_form) + # Create a new visitor on encountering a new type alias, so that an alias like + # A = Tuple[B, B] + # B = int + # will not be detected as recursive on the second encounter of B. + subvisitor = UnrollAliasVisitor(self.initial_aliases | {t}) + result = get_proper_type(t).accept(subvisitor) + if subvisitor.recursed: + self.recursed = True + return result + + def strip_type(typ: Type) -> ProperType: """Make a copy of type without 'debugging info' (function name).""" typ = get_proper_type(typ) @@ -2082,85 +2153,61 @@ def copy_type(t: TP) -> TP: return copy.copy(t) +class InstantiateAliasVisitor(TypeTranslator): + def __init__(self, vars: List[str], subs: List[Type]) -> None: + self.replacements = {v: s for (v, s) in zip(vars, subs)} + + def visit_type_alias_type(self, typ: TypeAliasType) -> Type: + return typ.copy_modified(args=[t.accept(self) for t in typ.args]) + + def visit_unbound_type(self, typ: UnboundType) -> Type: + # TODO: stop using unbound type variables for type aliases. + # Now that type aliases are very similar to TypeInfos we should + # make type variable tracking similar as well. Maybe we can even support + # upper bounds etc. for generic type aliases. + if typ.name in self.replacements: + return self.replacements[typ.name] + return typ + + def visit_type_var(self, typ: TypeVarType) -> Type: + if typ.name in self.replacements: + return self.replacements[typ.name] + return typ + + def replace_alias_tvars(tp: Type, vars: List[str], subs: List[Type], - newline: int, newcolumn: int) -> ProperType: + newline: int, newcolumn: int) -> Type: """Replace type variables in a generic type alias tp with substitutions subs resetting context. Length of subs should be already checked. """ - typ_args = get_typ_args(tp) - new_args = typ_args[:] - for i, arg in enumerate(typ_args): - if isinstance(arg, (UnboundType, TypeVarType)): - tvar = arg.name # type: Optional[str] - else: - tvar = None - if tvar and tvar in vars: - # Perform actual substitution... - new_args[i] = subs[vars.index(tvar)] - else: - # ...recursively, if needed. - new_args[i] = replace_alias_tvars(arg, vars, subs, newline, newcolumn) - return set_typ_args(tp, new_args, newline, newcolumn) - - -def get_typ_args(tp: Type) -> List[Type]: - """Get all type arguments from a parametrizable Type.""" - # TODO: replace this and related functions with proper visitors. - tp = get_proper_type(tp) # TODO: is this really needed? - - if not isinstance(tp, (Instance, UnionType, TupleType, CallableType)): - return [] - typ_args = (tp.args if isinstance(tp, Instance) else - tp.items if not isinstance(tp, CallableType) else - tp.arg_types + [tp.ret_type]) - return cast(List[Type], typ_args) - - -def set_typ_args(tp: Type, new_args: List[Type], line: int = -1, column: int = -1) -> ProperType: - """Return a copy of a parametrizable Type with arguments set to new_args.""" - tp = get_proper_type(tp) # TODO: is this really needed? - - if isinstance(tp, Instance): - return Instance(tp.type, new_args, line, column) - if isinstance(tp, TupleType): - return tp.copy_modified(items=new_args) - if isinstance(tp, UnionType): - return UnionType(new_args, line, column) - if isinstance(tp, CallableType): - return tp.copy_modified(arg_types=new_args[:-1], ret_type=new_args[-1], - line=line, column=column) - return tp - - -def get_type_vars(typ: Type) -> List[TypeVarType]: - """Get all type variables that are present in an already analyzed type, - without duplicates, in order of textual appearance. - Similar to TypeAnalyser.get_type_var_names. - """ - all_vars = [] # type: List[TypeVarType] - for t in get_typ_args(typ): - if isinstance(t, TypeVarType): - all_vars.append(t) - else: - all_vars.extend(get_type_vars(t)) - # Remove duplicates while preserving order - included = set() # type: Set[TypeVarId] - tvars = [] - for var in all_vars: - if var.id not in included: - tvars.append(var) - included.add(var.id) - return tvars - - -def flatten_nested_unions(types: Iterable[Type]) -> List[ProperType]: + replacer = InstantiateAliasVisitor(vars, subs) + new_tp = tp.accept(replacer) + new_tp.line = newline + new_tp.column = newcolumn + return new_tp + + +class HasTypeVars(TypeQuery[bool]): + def __init__(self) -> None: + super().__init__(any) + + def visit_type_var(self, t: TypeVarType) -> bool: + return True + + +def has_type_vars(typ: Type) -> bool: + """Check if a type contains any type variables (recursively).""" + return typ.accept(HasTypeVars()) + + +def flatten_nested_unions(types: Iterable[Type]) -> List[Type]: """Flatten nested unions in a type list.""" # This and similar functions on unions can cause infinite recursion # if passed a "pathological" alias like A = Union[int, A] or similar. # TODO: ban such aliases in semantic analyzer. - flat_items = [] # type: List[ProperType] - for tp in get_proper_types(types): - if isinstance(tp, UnionType): + flat_items = [] # type: List[Type] + for tp in types: + if isinstance(tp, ProperType) and isinstance(tp, UnionType): flat_items.extend(flatten_nested_unions(tp.items)) else: flat_items.append(tp) @@ -2189,13 +2236,15 @@ def is_generic_instance(tp: Type) -> bool: def is_optional(t: Type) -> bool: t = get_proper_type(t) - return isinstance(t, UnionType) and any(isinstance(e, NoneType) for e in t.items) + return isinstance(t, UnionType) and any(isinstance(get_proper_type(e), NoneType) + for e in t.items) -def remove_optional(typ: Type) -> ProperType: +def remove_optional(typ: Type) -> Type: typ = get_proper_type(typ) if isinstance(typ, UnionType): - return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneType)]) + return UnionType.make_union([t for t in typ.items + if not isinstance(get_proper_type(t), NoneType)]) else: return typ @@ -2211,31 +2260,6 @@ def is_literal_type(typ: ProperType, fallback_fullname: str, value: LiteralValue return typ.value == value -@overload -def get_proper_type(typ: None) -> None: ... -@overload -def get_proper_type(typ: Type) -> ProperType: ... - - -def get_proper_type(typ: Optional[Type]) -> Optional[ProperType]: - if typ is None: - return None - while isinstance(typ, TypeAliasType): - typ = typ._expand_once() - assert isinstance(typ, ProperType), typ - return typ - - -@overload -def get_proper_types(it: Iterable[Type]) -> List[ProperType]: ... -@overload -def get_proper_types(typ: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: ... - - -def get_proper_types(it: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: # type: ignore - return [get_proper_type(t) for t in it] - - names = globals().copy() # type: Final names.pop('NOT_READY', None) deserialize_map = { diff --git a/mypy/typestate.py b/mypy/typestate.py index c4a2554350e7..87dba5cbf601 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -3,11 +3,11 @@ and potentially other mutable TypeInfo state. This module contains mutable global state. """ -from typing import Dict, Set, Tuple, Optional +from typing import Dict, Set, Tuple, Optional, List from typing_extensions import ClassVar, Final from mypy.nodes import TypeInfo -from mypy.types import Instance +from mypy.types import Instance, TypeAliasType, get_proper_type, Type from mypy.server.trigger import make_trigger from mypy import state @@ -75,11 +75,36 @@ class TypeState: # a re-checked target) during the update. _rechecked_types = set() # type: Final[Set[TypeInfo]] + # The two attributes below are assumption stacks for subtyping relationships between + # recursive type aliases. Normally, one would pass type assumptions as an additional + # arguments to is_subtype(), but this would mean updating dozens of related functions + # threading this through all callsites (see also comment for TypeInfo.assuming). + _assuming = [] # type: Final[List[Tuple[TypeAliasType, TypeAliasType]]] + _assuming_proper = [] # type: Final[List[Tuple[TypeAliasType, TypeAliasType]]] + # Ditto for inference of generic constraints against recursive type aliases. + _inferring = [] # type: Final[List[TypeAliasType]] + # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing # via the cls parameter, since mypyc can optimize accesses to # Final attributes of a directly referenced type. + @staticmethod + def is_assumed_subtype(left: Type, right: Type) -> bool: + for (l, r) in reversed(TypeState._assuming): + if (get_proper_type(l) == get_proper_type(left) + and get_proper_type(r) == get_proper_type(right)): + return True + return False + + @staticmethod + def is_assumed_proper_subtype(left: Type, right: Type) -> bool: + for (l, r) in reversed(TypeState._assuming_proper): + if (get_proper_type(l) == get_proper_type(left) + and get_proper_type(r) == get_proper_type(right)): + return True + return False + @staticmethod def reset_all_subtype_caches() -> None: """Completely reset all known subtype caches.""" diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 86c4313f57fa..8d7459f7a551 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -6,7 +6,7 @@ Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType, TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, - PlaceholderType, PartialType, RawExpressionType + PlaceholderType, PartialType, RawExpressionType, TypeAliasType ) @@ -94,6 +94,9 @@ def visit_partial_type(self, t: PartialType) -> None: def visit_raw_expression_type(self, t: RawExpressionType) -> None: pass + def visit_type_alias_type(self, t: TypeAliasType) -> None: + self.traverse_types(t.args) + # Helpers def traverse_types(self, types: Iterable[Type]) -> None: