diff --git a/misc/proper_plugin.py b/misc/proper_plugin.py new file mode 100644 index 000000000000..c38fbc023967 --- /dev/null +++ b/misc/proper_plugin.py @@ -0,0 +1,70 @@ +from mypy.plugin import Plugin, FunctionContext +from mypy.types import Type, Instance, CallableType, UnionType, get_proper_type + +import os.path +from typing_extensions import Type as typing_Type +from typing import Optional, Callable + +FILE_WHITELIST = [ + 'checker.py', + 'checkexpr.py', + 'checkmember.py', + 'messages.py', + 'semanal.py', + 'typeanal.py' +] + + +class ProperTypePlugin(Plugin): + """ + A plugin to ensure that every type is expanded before doing any special-casing. + + This solves the problem that we have hundreds of call sites like: + + if isinstance(typ, UnionType): + ... # special-case union + + But after introducing a new type TypeAliasType (and removing immediate expansion) + all these became dangerous because typ may be e.g. an alias to union. + """ + def get_function_hook(self, fullname: str + ) -> Optional[Callable[[FunctionContext], Type]]: + if fullname == 'builtins.isinstance': + return isinstance_proper_hook + return None + + +def isinstance_proper_hook(ctx: FunctionContext) -> Type: + if os.path.split(ctx.api.path)[-1] in FILE_WHITELIST: + return ctx.default_return_type + for arg in ctx.arg_types[0]: + if is_improper_type(arg): + right = get_proper_type(ctx.arg_types[1][0]) + if isinstance(right, CallableType) and right.is_type_obj(): + if right.type_object().fullname() in ('mypy.types.Type', + 'mypy.types.ProperType', + 'mypy.types.TypeAliasType'): + # Special case: things like assert isinstance(typ, ProperType) are always OK. + return ctx.default_return_type + if right.type_object().fullname() in ('mypy.types.UnboundType', + 'mypy.types.TypeVarType'): + # Special case: these are not valid targets for a type alias and thus safe. + return ctx.default_return_type + ctx.api.fail('Never apply isinstance() to unexpanded types;' + ' use mypy.types.get_proper_type() first', ctx.context) + return ctx.default_return_type + + +def is_improper_type(typ: Type) -> bool: + """Is this a type that is not a subtype of ProperType?""" + typ = get_proper_type(typ) + if isinstance(typ, Instance): + info = typ.type + return info.has_base('mypy.types.Type') and not info.has_base('mypy.types.ProperType') + if isinstance(typ, UnionType): + return any(is_improper_type(t) for t in typ.items) + return False + + +def plugin(version: str) -> typing_Type[ProperTypePlugin]: + return ProperTypePlugin diff --git a/mypy/applytype.py b/mypy/applytype.py index b72253385c36..de678d04f442 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -3,7 +3,9 @@ import mypy.subtypes import mypy.sametypes from mypy.expandtype import expand_type -from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType +from mypy.types import ( + Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types +) from mypy.messages import MessageBuilder from mypy.nodes import Context @@ -25,10 +27,10 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona assert len(tvars) == len(orig_types) # Check that inferred type variable values are compatible with allowed # values and bounds. Also, promote subtype values to allowed values. - types = list(orig_types) + types = get_proper_types(orig_types) for i, type in enumerate(types): assert not isinstance(type, PartialType), "Internal error: must never apply partial type" - values = callable.variables[i].values + values = get_proper_types(callable.variables[i].values) if type is None: continue if values: diff --git a/mypy/argmap.py b/mypy/argmap.py index 8305a371e0a6..62c312e78a83 100644 --- a/mypy/argmap.py +++ b/mypy/argmap.py @@ -2,7 +2,9 @@ from typing import List, Optional, Sequence, Callable, Set -from mypy.types import Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType +from mypy.types import ( + Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type +) from mypy import nodes @@ -34,7 +36,7 @@ def map_actuals_to_formals(actual_kinds: List[int], formal_to_actual[fi].append(ai) elif actual_kind == nodes.ARG_STAR: # We need to know the actual type to map varargs. - actualt = actual_arg_type(ai) + actualt = get_proper_type(actual_arg_type(ai)) if isinstance(actualt, TupleType): # A tuple actual maps to a fixed number of formals. for _ in range(len(actualt.items)): @@ -65,7 +67,7 @@ def map_actuals_to_formals(actual_kinds: List[int], formal_to_actual[formal_kinds.index(nodes.ARG_STAR2)].append(ai) else: assert actual_kind == nodes.ARG_STAR2 - actualt = actual_arg_type(ai) + actualt = get_proper_type(actual_arg_type(ai)) if isinstance(actualt, TypedDictType): for name, value in actualt.items.items(): if name in formal_names: @@ -153,6 +155,7 @@ def expand_actual_type(self, This is supposed to be called for each formal, in order. Call multiple times per formal if multiple actuals map to a formal. """ + actual_type = get_proper_type(actual_type) if actual_kind == nodes.ARG_STAR: if isinstance(actual_type, Instance): if actual_type.type.fullname() == 'builtins.list': diff --git a/mypy/binder.py b/mypy/binder.py index e8ad27aed3e9..45855aa1b9d5 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -4,7 +4,9 @@ from typing import Dict, List, Set, Iterator, Union, Optional, Tuple, cast from typing_extensions import DefaultDict -from mypy.types import Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType +from mypy.types import ( + Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, get_proper_type +) from mypy.subtypes import is_subtype from mypy.join import join_simple from mypy.sametypes import is_same_type @@ -191,7 +193,7 @@ def update_from_options(self, frames: List[Frame]) -> bool: type = resulting_values[0] assert type is not None - declaration_type = self.declarations.get(key) + declaration_type = get_proper_type(self.declarations.get(key)) if isinstance(declaration_type, AnyType): # At this point resulting values can't contain None, see continue above if not all(is_same_type(type, cast(Type, t)) for t in resulting_values[1:]): @@ -246,6 +248,9 @@ def assign_type(self, expr: Expression, type: Type, declared_type: Optional[Type], restrict_any: bool = False) -> None: + type = get_proper_type(type) + declared_type = get_proper_type(declared_type) + if self.type_assignments is not None: # We are in a multiassign from union, defer the actual binding, # just collect the types. @@ -270,7 +275,7 @@ def assign_type(self, expr: Expression, # times? return - enclosing_type = self.most_recent_enclosing_type(expr, type) + enclosing_type = get_proper_type(self.most_recent_enclosing_type(expr, type)) if isinstance(enclosing_type, AnyType) and not restrict_any: # If x is Any and y is int, after x = y we do not infer that x is int. # This could be changed. @@ -287,7 +292,8 @@ def assign_type(self, expr: Expression, elif (isinstance(type, AnyType) and isinstance(declared_type, UnionType) and any(isinstance(item, NoneType) for item in declared_type.items) - and isinstance(self.most_recent_enclosing_type(expr, NoneType()), NoneType)): + and isinstance(get_proper_type(self.most_recent_enclosing_type(expr, NoneType())), + NoneType)): # Replace any Nones in the union type with Any new_items = [type if isinstance(item, NoneType) else item for item in declared_type.items] @@ -320,6 +326,7 @@ def invalidate_dependencies(self, expr: BindableExpression) -> None: self._cleanse_key(dep) def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Optional[Type]: + type = get_proper_type(type) if isinstance(type, AnyType): return get_declaration(expr) key = literal_hash(expr) @@ -412,7 +419,7 @@ def top_frame_context(self) -> Iterator[Frame]: def get_declaration(expr: BindableExpression) -> Optional[Type]: if isinstance(expr, RefExpr) and isinstance(expr.node, Var): - type = expr.node.type + type = get_proper_type(expr.node.type) if not isinstance(type, PartialType): return type return None diff --git a/mypy/checker.py b/mypy/checker.py index 97d69645b7d1..6a984220bcdc 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -464,7 +464,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: impl_type = None # type: Optional[CallableType] if defn.impl: if isinstance(defn.impl, FuncDef): - inner_type = defn.impl.type + inner_type = defn.impl.type # type: Optional[Type] elif isinstance(defn.impl, Decorator): inner_type = defn.impl.var.type else: @@ -3634,8 +3634,8 @@ def find_isinstance_check(self, node: Expression # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively vartype = type_map[node] - if_type = true_only(vartype) - else_type = false_only(vartype) + if_type = true_only(vartype) # type: Type + else_type = false_only(vartype) # type: Type ref = node # type: Expression if_map = {ref: if_type} if not isinstance(if_type, UninhabitedType) else None else_map = {ref: else_type} if not isinstance(else_type, UninhabitedType) else None @@ -4122,7 +4122,7 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: # expressions whose type is refined by both conditions. (We do not # learn anything about expressions whose type is refined by only # one condition.) - result = {} + result = {} # type: Dict[Expression, Type] for n1 in m1: for n2 in m2: if literal_hash(n1) == literal_hash(n2): diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index bf3ca0935fd6..6d437089b930 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,7 +18,7 @@ TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike, - StarType, is_optional, remove_optional, is_generic_instance + StarType, is_optional, remove_optional, is_generic_instance, get_proper_type ) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, @@ -585,6 +585,7 @@ def apply_function_plugin(self, # Apply method plugin method_callback = self.plugin.get_method_hook(fullname) assert method_callback is not None # Assume that caller ensures this + object_type = get_proper_type(object_type) return method_callback( MethodContext(object_type, formal_arg_types, formal_arg_kinds, callee.arg_names, formal_arg_names, @@ -606,6 +607,7 @@ def apply_method_signature_hook( for formal, actuals in enumerate(formal_to_actual): for actual in actuals: formal_arg_exprs[formal].append(args[actual]) + object_type = get_proper_type(object_type) return signature_hook( MethodSigContext(object_type, formal_arg_exprs, callee, context, self.chk)) else: @@ -2702,7 +2704,7 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) else: typ = self.accept(index) if isinstance(typ, UnionType): - key_types = typ.items + key_types = list(typ.items) # type: List[Type] else: key_types = [typ] @@ -3541,7 +3543,7 @@ def has_member(self, typ: Type, member: str) -> bool: elif isinstance(typ, TypeType): # Type[Union[X, ...]] is always normalized to Union[Type[X], ...], # so we don't need to care about unions here. - item = typ.item + item = typ.item # type: Type if isinstance(item, TypeVarType): item = item.upper_bound if isinstance(item, TupleType): @@ -3735,8 +3737,7 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type, # noqa not is_overlapping_types(known_type, restriction, prohibit_none_typevar_overlap=True)): return None - ans = narrow_declared_type(known_type, restriction) - return ans + return narrow_declared_type(known_type, restriction) return known_type diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 50adb3973cb7..2a9d87564275 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -6,7 +6,7 @@ from mypy.types import ( Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef, Overloaded, TypeVarType, UnionType, PartialType, UninhabitedType, TypeOfAny, LiteralType, - DeletedType, NoneType, TypeType, function_type, get_type_vars, + DeletedType, NoneType, TypeType, function_type, get_type_vars, get_proper_type ) from mypy.nodes import ( TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr, @@ -371,8 +371,8 @@ def analyze_member_var_access(name: str, fullname = '{}.{}'.format(method.info.fullname(), name) hook = mx.chk.plugin.get_attribute_hook(fullname) if hook: - result = hook(AttributeContext(mx.original_type, result, - mx.context, mx.chk)) + result = hook(AttributeContext(get_proper_type(mx.original_type), + result, mx.context, mx.chk)) return result else: setattr_meth = info.get_method('__setattr__') @@ -511,7 +511,7 @@ def analyze_var(name: str, mx.msg.read_only_property(name, itype.type, mx.context) if mx.is_lvalue and var.is_classvar: mx.msg.cant_assign_to_classvar(name, mx.context) - result = t + result = t # type: Type if var.is_initialized_in_class and isinstance(t, FunctionLike) and not t.is_type_obj(): if mx.is_lvalue: if var.is_property: @@ -552,7 +552,8 @@ def analyze_var(name: str, result = analyze_descriptor_access(mx.original_type, result, mx.builtin_type, mx.msg, mx.context, chk=mx.chk) if hook: - result = hook(AttributeContext(mx.original_type, result, mx.context, mx.chk)) + result = hook(AttributeContext(get_proper_type(mx.original_type), + result, mx.context, mx.chk)) return result diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 215193e4b6fd..fb3553838682 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -6,7 +6,7 @@ from typing_extensions import Final, TYPE_CHECKING from mypy.types import ( - Type, AnyType, TupleType, Instance, UnionType, TypeOfAny + Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type ) from mypy.nodes import ( StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr @@ -137,7 +137,7 @@ def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier], if checkers is None: return - rhs_type = self.accept(replacements) + rhs_type = get_proper_type(self.accept(replacements)) rep_types = [] # type: List[Type] if isinstance(rhs_type, TupleType): rep_types = rhs_type.items diff --git a/mypy/constraints.py b/mypy/constraints.py index 99979dc27e22..aebd1a76680b 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -7,6 +7,7 @@ CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, + ProperType, get_proper_type ) from mypy.maptype import map_instance_to_supertype import mypy.subtypes @@ -88,6 +89,8 @@ def infer_constraints(template: Type, actual: Type, The constraints are represented as Constraint objects. """ + template = get_proper_type(template) + actual = get_proper_type(actual) # If the template is simply a type variable, emit a Constraint directly. # We need to handle this case before handling Unions for two reasons: @@ -199,12 +202,12 @@ def is_same_constraint(c1: Constraint, c2: Constraint) -> bool: and mypy.sametypes.is_same_type(c1.target, c2.target)) -def simplify_away_incomplete_types(types: List[Type]) -> List[Type]: +def simplify_away_incomplete_types(types: Iterable[Type]) -> List[Type]: complete = [typ for typ in types if is_complete_type(typ)] if complete: return complete else: - return types + return list(types) def is_complete_type(typ: Type) -> bool: @@ -229,9 +232,9 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]): # The type that is compared against a template # TODO: The value may be None. Is that actually correct? - actual = None # type: Type + actual = None # type: ProperType - def __init__(self, actual: Type, direction: int) -> None: + def __init__(self, actual: ProperType, direction: int) -> None: # Direction must be SUBTYPE_OF or SUPERTYPE_OF. self.actual = actual self.direction = direction @@ -298,7 +301,7 @@ def visit_instance(self, template: Instance) -> List[Constraint]: if isinstance(actual, Instance): instance = actual erased = erase_typevars(template) - assert isinstance(erased, Instance) + assert isinstance(erased, Instance) # type: ignore # We always try nominal inference if possible, # it is much faster than the structural one. if (self.direction == SUBTYPE_OF and diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 9521459632d4..ee4a2d9f5bdc 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -3,12 +3,13 @@ from mypy.types import ( Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, - DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, + DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, + get_proper_type ) from mypy.nodes import ARG_STAR, ARG_STAR2 -def erase_type(typ: Type) -> Type: +def erase_type(typ: Type) -> ProperType: """Erase any type variables from a type. Also replace tuple types with the corresponding concrete types. @@ -20,43 +21,43 @@ def erase_type(typ: Type) -> Type: Callable[[A1, A2, ...], R] -> Callable[..., Any] Type[X] -> Type[Any] """ - + typ = get_proper_type(typ) return typ.accept(EraseTypeVisitor()) -class EraseTypeVisitor(TypeVisitor[Type]): +class EraseTypeVisitor(TypeVisitor[ProperType]): - def visit_unbound_type(self, t: UnboundType) -> Type: + def visit_unbound_type(self, t: UnboundType) -> ProperType: # TODO: replace with an assert after UnboundType can't leak from semantic analysis. return AnyType(TypeOfAny.from_error) - def visit_any(self, t: AnyType) -> Type: + def visit_any(self, t: AnyType) -> ProperType: return t - def visit_none_type(self, t: NoneType) -> Type: + def visit_none_type(self, t: NoneType) -> ProperType: return t - def visit_uninhabited_type(self, t: UninhabitedType) -> Type: + def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType: return t - def visit_erased_type(self, t: ErasedType) -> Type: + def visit_erased_type(self, t: ErasedType) -> ProperType: # Should not get here. raise RuntimeError() - def visit_partial_type(self, t: PartialType) -> Type: + def visit_partial_type(self, t: PartialType) -> ProperType: # Should not get here. raise RuntimeError() - def visit_deleted_type(self, t: DeletedType) -> Type: + def visit_deleted_type(self, t: DeletedType) -> ProperType: return t - def visit_instance(self, t: Instance) -> Type: + def visit_instance(self, t: Instance) -> ProperType: return Instance(t.type, [AnyType(TypeOfAny.special_form)] * len(t.args), t.line) - def visit_type_var(self, t: TypeVarType) -> Type: + def visit_type_var(self, t: TypeVarType) -> ProperType: return AnyType(TypeOfAny.special_form) - def visit_callable_type(self, t: CallableType) -> Type: + def visit_callable_type(self, t: CallableType) -> ProperType: # We must preserve the fallback type for overload resolution to work. any_type = AnyType(TypeOfAny.special_form) return CallableType( @@ -69,26 +70,26 @@ def visit_callable_type(self, t: CallableType) -> Type: implicit=True, ) - def visit_overloaded(self, t: Overloaded) -> Type: + def visit_overloaded(self, t: Overloaded) -> ProperType: return t.fallback.accept(self) - def visit_tuple_type(self, t: TupleType) -> Type: + def visit_tuple_type(self, t: TupleType) -> ProperType: return t.partial_fallback.accept(self) - def visit_typeddict_type(self, t: TypedDictType) -> Type: + def visit_typeddict_type(self, t: TypedDictType) -> ProperType: return t.fallback.accept(self) - def visit_literal_type(self, t: LiteralType) -> Type: + def visit_literal_type(self, t: LiteralType) -> ProperType: # The fallback for literal types should always be either # something like int or str, or an enum class -- types that # don't contain any TypeVars. So there's no need to visit it. return t - def visit_union_type(self, t: UnionType) -> Type: + def visit_union_type(self, t: UnionType) -> ProperType: erased_items = [erase_type(item) for item in t.items] return UnionType.make_simplified_union(erased_items) - def visit_type_type(self, t: TypeType) -> Type: + def visit_type_type(self, t: TypeType) -> ProperType: return TypeType.make_normalized(t.item.accept(self), line=t.line) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 379342b0361f..0f04fe43f1c0 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -4,11 +4,11 @@ Type, Instance, CallableType, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Overloaded, TupleType, TypedDictType, UnionType, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, - FunctionLike, TypeVarDef, LiteralType, + FunctionLike, TypeVarDef, LiteralType, get_proper_type, ProperType ) -def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: +def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> ProperType: """Substitute any type variable references in a type given by a type environment. """ @@ -16,9 +16,10 @@ def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: return typ.accept(ExpandTypeVisitor(env)) -def expand_type_by_instance(typ: Type, instance: Instance) -> Type: +def expand_type_by_instance(typ: Type, instance: Instance) -> ProperType: """Substitute type variables in type using values from an Instance. Type variables are considered to be bound by the class declaration.""" + typ = get_proper_type(typ) if instance.args == []: return typ @@ -52,7 +53,7 @@ def freshen_function_type_vars(callee: F) -> F: return cast(F, fresh_overload) -class ExpandTypeVisitor(TypeVisitor[Type]): +class ExpandTypeVisitor(TypeVisitor[ProperType]): """Visitor that substitutes type variables with values.""" variables = None # type: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value @@ -60,31 +61,31 @@ class ExpandTypeVisitor(TypeVisitor[Type]): def __init__(self, variables: Mapping[TypeVarId, Type]) -> None: self.variables = variables - def visit_unbound_type(self, t: UnboundType) -> Type: + def visit_unbound_type(self, t: UnboundType) -> ProperType: return t - def visit_any(self, t: AnyType) -> Type: + def visit_any(self, t: AnyType) -> ProperType: return t - def visit_none_type(self, t: NoneType) -> Type: + def visit_none_type(self, t: NoneType) -> ProperType: return t - def visit_uninhabited_type(self, t: UninhabitedType) -> Type: + def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType: return t - def visit_deleted_type(self, t: DeletedType) -> Type: + def visit_deleted_type(self, t: DeletedType) -> ProperType: return t - def visit_erased_type(self, t: ErasedType) -> Type: + def visit_erased_type(self, t: ErasedType) -> ProperType: # Should not get here. raise RuntimeError() - def visit_instance(self, t: Instance) -> Type: + def visit_instance(self, t: Instance) -> ProperType: args = self.expand_types(t.args) return Instance(t.type, args, t.line, t.column) - def visit_type_var(self, t: TypeVarType) -> Type: - repl = self.variables.get(t.id, t) + def visit_type_var(self, t: TypeVarType) -> ProperType: + repl = get_proper_type(self.variables.get(t.id, t)) if isinstance(repl, Instance): inst = repl # Return copy of instance with type erasure flag on. @@ -93,11 +94,11 @@ def visit_type_var(self, t: TypeVarType) -> Type: else: return repl - def visit_callable_type(self, t: CallableType) -> Type: + def visit_callable_type(self, t: CallableType) -> ProperType: return t.copy_modified(arg_types=self.expand_types(t.arg_types), ret_type=t.ret_type.accept(self)) - def visit_overloaded(self, t: Overloaded) -> Type: + def visit_overloaded(self, t: Overloaded) -> ProperType: items = [] # type: List[CallableType] for item in t.items(): new_item = item.accept(self) @@ -105,25 +106,25 @@ def visit_overloaded(self, t: Overloaded) -> Type: items.append(new_item) return Overloaded(items) - def visit_tuple_type(self, t: TupleType) -> Type: + def visit_tuple_type(self, t: TupleType) -> ProperType: return t.copy_modified(items=self.expand_types(t.items)) - def visit_typeddict_type(self, t: TypedDictType) -> Type: + def visit_typeddict_type(self, t: TypedDictType) -> ProperType: return t.copy_modified(item_types=self.expand_types(t.items.values())) - def visit_literal_type(self, t: LiteralType) -> Type: + def visit_literal_type(self, t: LiteralType) -> ProperType: # TODO: Verify this implementation is correct return t - def visit_union_type(self, t: UnionType) -> Type: + def visit_union_type(self, t: UnionType) -> ProperType: # After substituting for type variables in t.items, # some of the resulting types might be subtypes of others. return UnionType.make_simplified_union(self.expand_types(t.items), t.line, t.column) - def visit_partial_type(self, t: PartialType) -> Type: + def visit_partial_type(self, t: PartialType) -> ProperType: return t - def visit_type_type(self, t: TypeType) -> Type: + def visit_type_type(self, t: TypeType) -> ProperType: # TODO: Verify that the new item type is valid (instance or # union of instances or Any). Sadly we can't report errors # here yet. diff --git a/mypy/exprtotype.py b/mypy/exprtotype.py index 54875fab82b3..dac9063eb946 100644 --- a/mypy/exprtotype.py +++ b/mypy/exprtotype.py @@ -10,7 +10,7 @@ from mypy.fastparse import parse_type_string from mypy.types import ( Type, UnboundType, TypeList, EllipsisType, AnyType, CallableArgument, TypeOfAny, - RawExpressionType, + RawExpressionType, ProperType ) @@ -29,7 +29,7 @@ def _extract_argument_name(expr: Expression) -> Optional[str]: raise TypeTranslationError() -def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = None) -> Type: +def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = None) -> ProperType: """Translate an expression to the corresponding type. The result is not semantically analyzed. It can be UnboundType or TypeList. diff --git a/mypy/fastparse.py b/mypy/fastparse.py index cf6e374c41a6..0cd4b0a3be70 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -30,7 +30,7 @@ ) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, - TypeOfAny, Instance, RawExpressionType, + TypeOfAny, Instance, RawExpressionType, ProperType ) from mypy import defaults from mypy import message_registry, errorcodes as codes @@ -197,7 +197,7 @@ def parse_type_comment(type_comment: str, column: int, errors: Optional[Errors], assume_str_is_unicode: bool = True, - ) -> Tuple[Optional[List[str]], Optional[Type]]: + ) -> Tuple[Optional[List[str]], Optional[ProperType]]: """Parse type portion of a type comment (+ optional type ignore). Return (ignore info, parsed type). @@ -229,7 +229,7 @@ def parse_type_comment(type_comment: str, def parse_type_string(expr_string: str, expr_fallback_name: str, - line: int, column: int, assume_str_is_unicode: bool = True) -> Type: + line: int, column: int, assume_str_is_unicode: bool = True) -> ProperType: """Parses a type that was originally present inside of an explicit string, byte string, or unicode string. @@ -348,7 +348,7 @@ def translate_stmt_list(self, def translate_type_comment(self, n: Union[ast3.stmt, ast3.arg], - type_comment: Optional[str]) -> Optional[Type]: + type_comment: Optional[str]) -> Optional[ProperType]: if type_comment is None: return None else: @@ -555,7 +555,8 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], func_type = None if any(arg_types) or return_type: - if len(arg_types) != 1 and any(isinstance(t, EllipsisType) for t in arg_types): + if len(arg_types) != 1 and any(isinstance(t, EllipsisType) # type: ignore + for t in arg_types): self.fail("Ellipses cannot accompany other argument types " "in function type signature", lineno, n.col_offset) elif len(arg_types) > len(arg_kinds): @@ -1267,12 +1268,12 @@ def invalid_type(self, node: AST, note: Optional[str] = None) -> RawExpressionTy ) @overload - def visit(self, node: ast3.expr) -> Type: ... + def visit(self, node: ast3.expr) -> ProperType: ... @overload # noqa - def visit(self, node: Optional[AST]) -> Optional[Type]: ... # noqa + def visit(self, node: Optional[AST]) -> Optional[ProperType]: ... # noqa - def visit(self, node: Optional[AST]) -> Optional[Type]: # noqa + def visit(self, node: Optional[AST]) -> Optional[ProperType]: # noqa """Modified visit -- keep track of the stack of nodes""" if node is None: return None @@ -1460,7 +1461,7 @@ def visit_Str(self, n: Str) -> Type: # Do an ignore because the field doesn't exist in 3.8 (where # this method doesn't actually ever run.) - kind = n.kind # type: str # type: ignore + kind = n.kind # type: str if 'u' in kind or self.assume_str_is_unicode: return parse_type_string(n.s, 'builtins.unicode', self.line, n.col_offset, diff --git a/mypy/fastparse2.py b/mypy/fastparse2.py index 1545fbde10e6..c925303a9a2f 100644 --- a/mypy/fastparse2.py +++ b/mypy/fastparse2.py @@ -405,7 +405,8 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement: func_type = None if any(arg_types) or return_type: - if len(arg_types) != 1 and any(isinstance(t, EllipsisType) for t in arg_types): + if len(arg_types) != 1 and any(isinstance(t, EllipsisType) # type: ignore + for t in arg_types): self.fail("Ellipses cannot accompany other argument types " "in function type signature", lineno, n.col_offset) elif len(arg_types) > len(arg_kinds): diff --git a/mypy/join.py b/mypy/join.py index f651f43203c5..a64ed62cbdb4 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -6,7 +6,8 @@ from mypy.types import ( Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType, TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, - PartialType, DeletedType, UninhabitedType, TypeType, true_or_false, TypeOfAny, + PartialType, DeletedType, UninhabitedType, TypeType, true_or_false, TypeOfAny, get_proper_type, + ProperType, get_proper_types ) from mypy.maptype import map_instance_to_supertype from mypy.subtypes import ( @@ -18,8 +19,11 @@ from mypy import state -def join_simple(declaration: Optional[Type], s: Type, t: Type) -> Type: +def join_simple(declaration: Optional[Type], s: Type, t: Type) -> ProperType: """Return a simple least upper bound given the declared type.""" + declaration = get_proper_type(declaration) + s = get_proper_type(s) + t = get_proper_type(t) if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false): # if types are restricted in different ways, use the more general versions @@ -54,11 +58,14 @@ def join_simple(declaration: Optional[Type], s: Type, t: Type) -> Type: return declaration -def join_types(s: Type, t: Type) -> Type: +def join_types(s: Type, t: Type) -> ProperType: """Return the least upper bound of s and t. For example, the join of 'int' and 'object' is 'object'. """ + s = get_proper_type(s) + t = get_proper_type(t) + if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false): # if types are restricted in different ways, use the more general versions s = true_or_false(s) @@ -83,29 +90,29 @@ def join_types(s: Type, t: Type) -> Type: return t.accept(TypeJoinVisitor(s)) -class TypeJoinVisitor(TypeVisitor[Type]): +class TypeJoinVisitor(TypeVisitor[ProperType]): """Implementation of the least upper bound algorithm. Attributes: s: The other (left) type operand. """ - def __init__(self, s: Type) -> None: + def __init__(self, s: ProperType) -> None: self.s = s - def visit_unbound_type(self, t: UnboundType) -> Type: + def visit_unbound_type(self, t: UnboundType) -> ProperType: return AnyType(TypeOfAny.special_form) - def visit_union_type(self, t: UnionType) -> Type: + def visit_union_type(self, t: UnionType) -> ProperType: if is_subtype(self.s, t): return t else: return UnionType.make_simplified_union([self.s, t]) - def visit_any(self, t: AnyType) -> Type: + def visit_any(self, t: AnyType) -> ProperType: return t - def visit_none_type(self, t: NoneType) -> Type: + def visit_none_type(self, t: NoneType) -> ProperType: if state.strict_optional: if isinstance(self.s, (NoneType, UninhabitedType)): return t @@ -116,22 +123,22 @@ def visit_none_type(self, t: NoneType) -> Type: else: return self.s - def visit_uninhabited_type(self, t: UninhabitedType) -> Type: + def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType: return self.s - def visit_deleted_type(self, t: DeletedType) -> Type: + def visit_deleted_type(self, t: DeletedType) -> ProperType: return self.s - def visit_erased_type(self, t: ErasedType) -> Type: + def visit_erased_type(self, t: ErasedType) -> ProperType: return self.s - def visit_type_var(self, t: TypeVarType) -> Type: + def visit_type_var(self, t: TypeVarType) -> ProperType: if isinstance(self.s, TypeVarType) and self.s.id == t.id: return self.s else: return self.default(self.s) - def visit_instance(self, t: Instance) -> Type: + def visit_instance(self, t: Instance) -> ProperType: if isinstance(self.s, Instance): nominal = join_instances(t, self.s) structural = None # type: Optional[Instance] @@ -160,12 +167,13 @@ def visit_instance(self, t: Instance) -> Type: else: return self.default(self.s) - def visit_callable_type(self, t: CallableType) -> Type: + def visit_callable_type(self, t: CallableType) -> ProperType: if isinstance(self.s, CallableType) and is_similar_callables(t, self.s): if is_equivalent(t, self.s): return combine_similar_callables(t, self.s) result = join_similar_callables(t, self.s) - if any(isinstance(tp, (NoneType, UninhabitedType)) for tp in result.arg_types): + if any(isinstance(tp, (NoneType, UninhabitedType)) + for tp in get_proper_types(result.arg_types)): # We don't want to return unusable Callable, attempt fallback instead. return join_types(t.fallback, self.s) return result @@ -178,7 +186,7 @@ def visit_callable_type(self, t: CallableType) -> Type: return join_types(t, call) return join_types(t.fallback, self.s) - def visit_overloaded(self, t: Overloaded) -> Type: + def visit_overloaded(self, t: Overloaded) -> ProperType: # This is more complex than most other cases. Here are some # examples that illustrate how this works. # @@ -231,7 +239,7 @@ def visit_overloaded(self, t: Overloaded) -> Type: return join_types(t, call) return join_types(t.fallback, s) - def visit_tuple_type(self, t: TupleType) -> Type: + def visit_tuple_type(self, t: TupleType) -> ProperType: if isinstance(self.s, TupleType) and self.s.length() == t.length(): items = [] # type: List[Type] for i in range(t.length()): @@ -243,7 +251,7 @@ def visit_tuple_type(self, t: TupleType) -> Type: else: return self.default(self.s) - def visit_typeddict_type(self, t: TypedDictType) -> Type: + def visit_typeddict_type(self, t: TypedDictType) -> ProperType: if isinstance(self.s, TypedDictType): items = OrderedDict([ (item_name, s_item_type) @@ -262,7 +270,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: else: return self.default(self.s) - def visit_literal_type(self, t: LiteralType) -> Type: + def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType): if t == self.s: return t @@ -271,12 +279,12 @@ def visit_literal_type(self, t: LiteralType) -> Type: else: return join_types(self.s, t.fallback) - def visit_partial_type(self, t: PartialType) -> Type: + def visit_partial_type(self, t: PartialType) -> ProperType: # We only have partial information so we can't decide the join result. We should # never get here. assert False, "Internal error" - def visit_type_type(self, t: TypeType) -> Type: + def visit_type_type(self, t: TypeType) -> ProperType: if isinstance(self.s, TypeType): return TypeType.make_normalized(self.join(t.item, self.s.item), line=t.line) elif isinstance(self.s, Instance) and self.s.type.fullname() == 'builtins.type': @@ -284,10 +292,11 @@ def visit_type_type(self, t: TypeType) -> Type: else: return self.default(self.s) - def join(self, s: Type, t: Type) -> Type: + def join(self, s: Type, t: Type) -> ProperType: return join_types(s, t) - def default(self, typ: Type) -> Type: + def default(self, typ: Type) -> ProperType: + typ = get_proper_type(typ) if isinstance(typ, Instance): return object_from_instance(typ) elif isinstance(typ, UnboundType): @@ -304,7 +313,7 @@ def default(self, typ: Type) -> Type: return AnyType(TypeOfAny.special_form) -def join_instances(t: Instance, s: Instance) -> Type: +def join_instances(t: Instance, s: Instance) -> ProperType: """Calculate the join of two instance types. """ if t.type == s.type: @@ -328,7 +337,7 @@ def join_instances(t: Instance, s: Instance) -> Type: return join_instances_via_supertype(s, t) -def join_instances_via_supertype(t: Instance, s: Instance) -> Type: +def join_instances_via_supertype(t: Instance, s: Instance) -> ProperType: # Give preference to joins via duck typing relationship, so that # join(int, float) == float, for example. if t.type._promote and is_subtype(t.type._promote, s): @@ -338,7 +347,7 @@ def join_instances_via_supertype(t: Instance, s: Instance) -> Type: # Compute the "best" supertype of t when joined with s. # The definition of "best" may evolve; for now it is the one with # the longest MRO. Ties are broken by using the earlier base. - best = None # type: Optional[Type] + best = None # type: Optional[ProperType] for base in t.type.bases: mapped = map_instance_to_supertype(t, base.type) res = join_instances(mapped, s) @@ -351,6 +360,9 @@ def join_instances_via_supertype(t: Instance, s: Instance) -> Type: def is_better(t: Type, s: Type) -> bool: # Given two possible results from join_instances_via_supertype(), # indicate whether t is the better one. + t = get_proper_type(t) + s = get_proper_type(s) + if isinstance(t, Instance): if not isinstance(s, Instance): return True diff --git a/mypy/meet.py b/mypy/meet.py index aabdfb67a24e..072dfc25988f 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -8,6 +8,7 @@ Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, + ProperType, get_proper_type, get_proper_types ) from mypy.subtypes import ( is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, @@ -21,8 +22,11 @@ # TODO Describe this module. -def meet_types(s: Type, t: Type) -> Type: +def meet_types(s: Type, t: Type) -> ProperType: """Return the greatest lower bound of two types.""" + s = get_proper_type(s) + t = get_proper_type(t) + if isinstance(s, ErasedType): return s if isinstance(s, AnyType): @@ -34,6 +38,9 @@ def meet_types(s: Type, t: Type) -> Type: def narrow_declared_type(declared: Type, narrowed: Type) -> Type: """Return the declared type narrowed down to another type.""" + declared = get_proper_type(declared) + narrowed = get_proper_type(narrowed) + if declared == narrowed: return declared if isinstance(declared, UnionType): @@ -57,7 +64,7 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: elif isinstance(declared, TypedDictType) and isinstance(narrowed, Instance): # Special case useful for selecting TypedDicts from unions using isinstance(x, dict). if (narrowed.type.fullname() == 'builtins.dict' and - all(isinstance(t, AnyType) for t in narrowed.args)): + all(isinstance(t, AnyType) for t in get_proper_types(narrowed.args))): return declared return meet_types(declared, narrowed) return narrowed @@ -88,13 +95,15 @@ def get_possible_variants(typ: Type) -> List[Type]: Normalizing both kinds of types in the same way lets us reuse the same algorithm for both. """ + typ = get_proper_type(typ) + if isinstance(typ, TypeVarType): if len(typ.values) > 0: return typ.values else: return [typ.upper_bound] elif isinstance(typ, UnionType): - return typ.items + return list(typ.items) elif isinstance(typ, Overloaded): # Note: doing 'return typ.items()' makes mypy # infer a too-specific return type of List[CallableType] @@ -113,6 +122,8 @@ def is_overlapping_types(left: Type, If 'prohibit_none_typevar_overlap' is True, we disallow None from overlapping with TypeVars (in both strict-optional and non-strict-optional mode). """ + left = get_proper_type(left) + right = get_proper_type(right) def _is_overlapping_types(left: Type, right: Type) -> bool: '''Encode the kind of overlapping check to perform. @@ -175,7 +186,7 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # If both types are singleton variants (and are not TypeVars), we've hit the base case: # we skip these checks to avoid infinitely recursing. - def is_none_typevar_overlap(t1: Type, t2: Type) -> bool: + def is_none_typevar_overlap(t1: ProperType, t2: ProperType) -> bool: return isinstance(t1, NoneType) and isinstance(t2, TypeVarType) if prohibit_none_typevar_overlap: @@ -234,7 +245,7 @@ def is_none_typevar_overlap(t1: Type, t2: Type) -> bool: if isinstance(left, TypeType) and isinstance(right, TypeType): return _is_overlapping_types(left.item, right.item) - def _type_object_overlap(left: Type, right: Type) -> bool: + def _type_object_overlap(left: ProperType, right: ProperType) -> bool: """Special cases for type object types overlaps.""" # TODO: these checks are a bit in gray area, adjust if they cause problems. # 1. Type[C] vs Callable[..., C], where the latter is class object. @@ -356,7 +367,7 @@ def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType, *, return True -def are_tuples_overlapping(left: Type, right: Type, *, +def are_tuples_overlapping(left: ProperType, right: ProperType, *, ignore_promotions: bool = False, prohibit_none_typevar_overlap: bool = False) -> bool: """Returns true if left and right are overlapping tuples.""" @@ -372,7 +383,7 @@ def are_tuples_overlapping(left: Type, right: Type, *, for l, r in zip(left.items, right.items)) -def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: +def adjust_tuple(left: ProperType, r: ProperType) -> Optional[TupleType]: """Find out if `left` is a Tuple[A, ...], and adjust its length to `right`""" if isinstance(left, Instance) and left.type.fullname() == 'builtins.tuple': n = r.length() if isinstance(r, TupleType) else 1 @@ -381,15 +392,16 @@ def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: def is_tuple(typ: Type) -> bool: + typ = get_proper_type(typ) return (isinstance(typ, TupleType) or (isinstance(typ, Instance) and typ.type.fullname() == 'builtins.tuple')) -class TypeMeetVisitor(TypeVisitor[Type]): - def __init__(self, s: Type) -> None: +class TypeMeetVisitor(TypeVisitor[ProperType]): + def __init__(self, s: ProperType) -> None: self.s = s - def visit_unbound_type(self, t: UnboundType) -> Type: + def visit_unbound_type(self, t: UnboundType) -> ProperType: if isinstance(self.s, NoneType): if state.strict_optional: return AnyType(TypeOfAny.special_form) @@ -400,10 +412,10 @@ def visit_unbound_type(self, t: UnboundType) -> Type: else: return AnyType(TypeOfAny.special_form) - def visit_any(self, t: AnyType) -> Type: + def visit_any(self, t: AnyType) -> ProperType: return self.s - def visit_union_type(self, t: UnionType) -> Type: + def visit_union_type(self, t: UnionType) -> ProperType: if isinstance(self.s, UnionType): meets = [] # type: List[Type] for x in t.items: @@ -414,7 +426,7 @@ def visit_union_type(self, t: UnionType) -> Type: for x in t.items] return UnionType.make_simplified_union(meets) - def visit_none_type(self, t: NoneType) -> Type: + def visit_none_type(self, t: NoneType) -> ProperType: if state.strict_optional: if isinstance(self.s, NoneType) or (isinstance(self.s, Instance) and self.s.type.fullname() == 'builtins.object'): @@ -424,10 +436,10 @@ def visit_none_type(self, t: NoneType) -> Type: else: return t - def visit_uninhabited_type(self, t: UninhabitedType) -> Type: + def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType: return t - def visit_deleted_type(self, t: DeletedType) -> Type: + def visit_deleted_type(self, t: DeletedType) -> ProperType: if isinstance(self.s, NoneType): if state.strict_optional: return t @@ -438,16 +450,16 @@ def visit_deleted_type(self, t: DeletedType) -> Type: else: return t - def visit_erased_type(self, t: ErasedType) -> Type: + def visit_erased_type(self, t: ErasedType) -> ProperType: return self.s - def visit_type_var(self, t: TypeVarType) -> Type: + def visit_type_var(self, t: TypeVarType) -> ProperType: if isinstance(self.s, TypeVarType) and self.s.id == t.id: return self.s else: return self.default(self.s) - def visit_instance(self, t: Instance) -> Type: + def visit_instance(self, t: Instance) -> ProperType: if isinstance(self.s, Instance): si = self.s if t.type == si.type: @@ -488,12 +500,12 @@ def visit_instance(self, t: Instance) -> Type: return meet_types(t, self.s) return self.default(self.s) - def visit_callable_type(self, t: CallableType) -> Type: + def visit_callable_type(self, t: CallableType) -> ProperType: if isinstance(self.s, CallableType) and is_similar_callables(t, self.s): if is_equivalent(t, self.s): return combine_similar_callables(t, self.s) result = meet_similar_callables(t, self.s) - if isinstance(result.ret_type, UninhabitedType): + if isinstance(get_proper_type(result.ret_type), UninhabitedType): # Return a plain None or instead of a weird function. return self.default(self.s) return result @@ -509,7 +521,7 @@ def visit_callable_type(self, t: CallableType) -> Type: return meet_types(t, call) return self.default(self.s) - def visit_overloaded(self, t: Overloaded) -> Type: + def visit_overloaded(self, t: Overloaded) -> ProperType: # TODO: Implement a better algorithm that covers at least the same cases # as TypeJoinVisitor.visit_overloaded(). s = self.s @@ -528,7 +540,7 @@ def visit_overloaded(self, t: Overloaded) -> Type: return meet_types(t, call) return meet_types(t.fallback, s) - def visit_tuple_type(self, t: TupleType) -> Type: + def visit_tuple_type(self, t: TupleType) -> ProperType: if isinstance(self.s, TupleType) and self.s.length() == t.length(): items = [] # type: List[Type] for i in range(t.length()): @@ -544,7 +556,7 @@ def visit_tuple_type(self, t: TupleType) -> Type: return t return self.default(self.s) - def visit_typeddict_type(self, t: TypedDictType) -> Type: + def visit_typeddict_type(self, t: TypedDictType) -> ProperType: if isinstance(self.s, TypedDictType): for (name, l, r) in self.s.zip(t): if (not is_equivalent(l, r) or @@ -568,7 +580,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: else: return self.default(self.s) - def visit_literal_type(self, t: LiteralType) -> Type: + def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType) and self.s == t: return t elif isinstance(self.s, Instance) and is_subtype(t.fallback, self.s): @@ -576,11 +588,11 @@ def visit_literal_type(self, t: LiteralType) -> Type: else: return self.default(self.s) - def visit_partial_type(self, t: PartialType) -> Type: + def visit_partial_type(self, t: PartialType) -> ProperType: # We can't determine the meet of partial types. We should never get here. assert False, 'Internal error' - def visit_type_type(self, t: TypeType) -> Type: + def visit_type_type(self, t: TypeType) -> ProperType: if isinstance(self.s, TypeType): typ = self.meet(t.item, self.s.item) if not isinstance(typ, NoneType): @@ -593,10 +605,10 @@ def visit_type_type(self, t: TypeType) -> Type: else: return self.default(self.s) - def meet(self, s: Type, t: Type) -> Type: + def meet(self, s: Type, t: Type) -> ProperType: return meet_types(s, t) - def default(self, typ: Type) -> Type: + def default(self, typ: Type) -> ProperType: if isinstance(typ, UnboundType): return AnyType(TypeOfAny.special_form) else: @@ -624,7 +636,7 @@ def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType: name=None) -def typed_dict_mapping_pair(left: Type, right: Type) -> bool: +def typed_dict_mapping_pair(left: ProperType, right: ProperType) -> bool: """Is this a pair where one type is a TypedDict and another one is an instance of Mapping? This case requires a precise/principled consideration because there are two use cases @@ -643,7 +655,7 @@ def typed_dict_mapping_pair(left: Type, right: Type) -> bool: return isinstance(other, Instance) and other.type.has_base('typing.Mapping') -def typed_dict_mapping_overlap(left: Type, right: Type, +def typed_dict_mapping_overlap(left: ProperType, right: ProperType, overlapping: Callable[[Type, Type], bool]) -> bool: """Check if a TypedDict type is overlapping with a Mapping. @@ -685,7 +697,7 @@ def typed_dict_mapping_overlap(left: Type, right: Type, mapping = next(base for base in other.type.mro if base.fullname() == 'typing.Mapping') other = map_instance_to_supertype(other, mapping) - key_type, value_type = other.args + key_type, value_type = get_proper_types(other.args) # TODO: is there a cleaner way to get str_type here? fallback = typed.as_anonymous().fallback diff --git a/mypy/messages.py b/mypy/messages.py index 5ab629e9b32b..20082f767537 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -443,7 +443,7 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type: arg_label, target, quote_type_string(arg_type_str), quote_type_string(expected_type_str)) if isinstance(expected_type, UnionType): - expected_types = expected_type.items + expected_types = list(expected_type.items) # type: List[Type] else: expected_types = [expected_type] for type in expected_types: diff --git a/mypy/nodes.py b/mypy/nodes.py index 0934ff520f3e..5f8c83a50631 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -449,9 +449,9 @@ def __init__(self) -> None: super().__init__() # Type signature. This is usually CallableType or Overloaded, but it can be # something else for decorated functions. - self.type = None # type: Optional[mypy.types.Type] + self.type = None # type: Optional[mypy.types.ProperType] # Original, not semantically analyzed type (used for reprocessing) - self.unanalyzed_type = None # type: Optional[mypy.types.Type] + self.unanalyzed_type = None # type: Optional[mypy.types.ProperType] # If method, reference to TypeInfo # TODO: Type should be Optional[TypeInfo] self.info = FUNC_NO_INFO @@ -528,7 +528,9 @@ def deserialize(cls, data: JsonDict) -> 'OverloadedFuncDef': if len(res.items) > 0: res.set_line(res.impl.line) if data.get('type') is not None: - res.type = mypy.types.deserialize_type(data['type']) + typ = mypy.types.deserialize_type(data['type']) + assert isinstance(typ, mypy.types.ProperType) + res.type = typ res._fullname = data['fullname'] set_flags(res, data['flags']) # NOTE: res.info will be set in the fixup phase. @@ -3079,7 +3081,7 @@ def get_member_expr_fullname(expr: MemberExpr) -> Optional[str]: deserialize_map = { - key: obj.deserialize # type: ignore + key: obj.deserialize for key, obj in globals().items() if type(obj) is not FakeInfo and isinstance(obj, type) and issubclass(obj, SymbolNode) and obj is not SymbolNode diff --git a/mypy/plugin.py b/mypy/plugin.py index aac6a41f58df..6c34d20ebafb 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -130,7 +130,7 @@ class C: pass Expression, Context, ClassDef, SymbolTableNode, MypyFile, CallExpr ) from mypy.tvar_scope import TypeVarScope -from mypy.types import Type, Instance, CallableType, TypeList, UnboundType +from mypy.types import Type, Instance, CallableType, TypeList, UnboundType, ProperType from mypy.messages import MessageBuilder from mypy.options import Options from mypy.lookup import lookup_fully_qualified @@ -382,9 +382,10 @@ def final_iteration(self) -> bool: # A context for a method signature hook that infers a better signature for a # method. Note that argument types aren't available yet. If you need them, # you have to use a method hook instead. +# TODO: document ProperType in the plugin changelog/update issue. MethodSigContext = NamedTuple( 'MethodSigContext', [ - ('type', Type), # Base object type for method call + ('type', ProperType), # Base object type for method call ('args', List[List[Expression]]), # Actual expressions for each formal argument ('default_signature', CallableType), # Original signature of the method ('context', Context), # Relevant location context (e.g. for error messages) @@ -396,7 +397,7 @@ def final_iteration(self) -> bool: # This is very similar to FunctionContext (only differences are documented). MethodContext = NamedTuple( 'MethodContext', [ - ('type', Type), # Base object type for method call + ('type', ProperType), # Base object type for method call ('arg_types', List[List[Type]]), # List of actual caller types for each formal argument # see FunctionContext for details about names and kinds ('arg_kinds', List[List[int]]), @@ -410,7 +411,7 @@ def final_iteration(self) -> bool: # A context for an attribute type hook that infers the type of an attribute. AttributeContext = NamedTuple( 'AttributeContext', [ - ('type', Type), # Type of object with attribute + ('type', ProperType), # Type of object with attribute ('default_attr_type', Type), # Original attribute type ('context', Context), # Relevant location context (e.g. for error messages) ('api', CheckerPluginInterface)]) diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index b0059226f963..796b3d74ee4b 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -20,7 +20,7 @@ ) from mypy.types import ( Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarDef, TypeVarType, - Overloaded, UnionType, FunctionLike + Overloaded, UnionType, FunctionLike, get_proper_type ) from mypy.typevars import fill_typevars from mypy.util import unmangle @@ -94,6 +94,7 @@ def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument: converter_type = converter.type init_type = None + converter_type = get_proper_type(converter_type) if isinstance(converter_type, CallableType) and converter_type.arg_types: init_type = ctx.api.anal_type(converter_type.arg_types[0]) elif isinstance(converter_type, Overloaded): diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index b8bfeb02a623..ce1111b19035 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -6,7 +6,10 @@ ) from mypy.plugin import ClassDefContext from mypy.semanal import set_callable_name -from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance, UnionType +from mypy.types import ( + CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance, UnionType, + get_proper_type, get_proper_types +) from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name @@ -55,7 +58,7 @@ def _get_argument(call: CallExpr, name: str) -> Optional[Expression]: callee_node = call.callee.node if (isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type): - callee_node_type = callee_node.type + callee_node_type = get_proper_type(callee_node.type) if isinstance(callee_node_type, Overloaded): # We take the last overload. callee_type = callee_node_type.items()[-1] @@ -141,18 +144,20 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]] 2. 'typ' is a LiteralType containing a string 3. 'typ' is a UnionType containing only LiteralType of strings """ + typ = get_proper_type(typ) + if isinstance(expr, StrExpr): return [expr.value] if isinstance(typ, Instance) and typ.last_known_value is not None: possible_literals = [typ.last_known_value] # type: List[Type] elif isinstance(typ, UnionType): - possible_literals = typ.items + possible_literals = list(typ.items) else: possible_literals = [typ] strings = [] - for lit in possible_literals: + for lit in get_proper_types(possible_literals): if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str': val = lit.value assert isinstance(val, str) diff --git a/mypy/plugins/ctypes.py b/mypy/plugins/ctypes.py index 119689ae3b20..446fb2cac8ad 100644 --- a/mypy/plugins/ctypes.py +++ b/mypy/plugins/ctypes.py @@ -8,7 +8,8 @@ from mypy.maptype import map_instance_to_supertype from mypy.subtypes import is_subtype from mypy.types import ( - AnyType, CallableType, Instance, NoneType, Type, TypeOfAny, UnionType, union_items + AnyType, CallableType, Instance, NoneType, Type, TypeOfAny, UnionType, + union_items, ProperType, get_proper_type ) @@ -31,7 +32,7 @@ def _get_text_type(api: 'mypy.plugin.CheckerPluginInterface') -> Instance: def _find_simplecdata_base_arg(tp: Instance, api: 'mypy.plugin.CheckerPluginInterface' - ) -> Optional[Type]: + ) -> Optional[ProperType]: """Try to find a parametrized _SimpleCData in tp's bases and return its single type argument. None is returned if _SimpleCData appears nowhere in tp's (direct or indirect) bases. @@ -40,7 +41,7 @@ def _find_simplecdata_base_arg(tp: Instance, api: 'mypy.plugin.CheckerPluginInte simplecdata_base = map_instance_to_supertype(tp, api.named_generic_type('ctypes._SimpleCData', [AnyType(TypeOfAny.special_form)]).type) assert len(simplecdata_base.args) == 1, '_SimpleCData takes exactly one type argument' - return simplecdata_base.args[0] + return get_proper_type(simplecdata_base.args[0]) return None @@ -78,13 +79,15 @@ def _autoconvertible_to_cdata(tp: Type, api: 'mypy.plugin.CheckerPluginInterface return UnionType.make_simplified_union(allowed_types) -def _autounboxed_cdata(tp: Type) -> Type: +def _autounboxed_cdata(tp: Type) -> ProperType: """Get the auto-unboxed version of a CData type, if applicable. For *direct* _SimpleCData subclasses, the only type argument of _SimpleCData in the bases list is returned. For all other CData types, including indirect _SimpleCData subclasses, tp is returned as-is. """ + tp = get_proper_type(tp) + if isinstance(tp, UnionType): return UnionType.make_simplified_union([_autounboxed_cdata(t) for t in tp.items]) elif isinstance(tp, Instance): @@ -93,18 +96,19 @@ def _autounboxed_cdata(tp: Type) -> Type: # If tp has _SimpleCData as a direct base class, # the auto-unboxed type is the single type argument of the _SimpleCData type. assert len(base.args) == 1 - return base.args[0] + return get_proper_type(base.args[0]) # If tp is not a concrete type, or if there is no _SimpleCData in the bases, # the type is not auto-unboxed. return tp -def _get_array_element_type(tp: Type) -> Optional[Type]: +def _get_array_element_type(tp: Type) -> Optional[ProperType]: """Get the element type of the Array type tp, or None if not specified.""" + tp = get_proper_type(tp) if isinstance(tp, Instance): assert tp.type.fullname() == 'ctypes.Array' if len(tp.args) == 1: - return tp.args[0] + return get_proper_type(tp.args[0]) return None @@ -145,7 +149,7 @@ def array_getitem_callback(ctx: 'mypy.plugin.MethodContext') -> Type: 'The stub of ctypes.Array.__getitem__ should have exactly one parameter' assert len(ctx.arg_types[0]) == 1, \ "ctypes.Array.__getitem__'s parameter should not be variadic" - index_type = ctx.arg_types[0][0] + index_type = get_proper_type(ctx.arg_types[0][0]) if isinstance(index_type, Instance): if index_type.type.has_base('builtins.int'): return unboxed @@ -160,7 +164,7 @@ def array_setitem_callback(ctx: 'mypy.plugin.MethodSigContext') -> CallableType: if et is not None: allowed = _autoconvertible_to_cdata(et, ctx.api) assert len(ctx.default_signature.arg_types) == 2 - index_type = ctx.default_signature.arg_types[0] + index_type = get_proper_type(ctx.default_signature.arg_types[0]) if isinstance(index_type, Instance): arg_type = None if index_type.type.has_base('builtins.int'): diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index d3adeebf79d2..9f8aba90cf39 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -12,7 +12,7 @@ ) from mypy.plugin import ClassDefContext from mypy.plugins.common import add_method, _get_decorator_bool_argument -from mypy.types import Instance, NoneType, TypeVarDef, TypeVarType +from mypy.types import Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type from mypy.server.trigger import make_wildcard_trigger # The set of decorators that generate dataclasses. @@ -234,12 +234,11 @@ def collect_attributes(self) -> List[DataclassAttribute]: # x: InitVar[int] is turned into x: int and is removed from the class. is_init_var = False - if ( - isinstance(node.type, Instance) and - node.type.type.fullname() == 'dataclasses.InitVar' - ): + node_type = get_proper_type(node.type) + if (isinstance(node_type, Instance) and + node_type.type.fullname() == 'dataclasses.InitVar'): is_init_var = True - node.type = node.type.args[0] + node.type = node_type.args[0] has_field_call, field_args = _collect_field_args(stmt.rvalue) diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 4bfd9129188b..e67cd1fa3aa0 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable, Optional +from typing import Callable, Optional, List from mypy import message_registry from mypy.nodes import StrExpr, IntExpr, DictExpr, UnaryExpr @@ -9,7 +9,7 @@ from mypy.plugins.common import try_getting_str_literals from mypy.types import ( Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, UnionType, TypedDictType, - TypeVarType, TPDICT_FB_NAMES + TypeVarType, TPDICT_FB_NAMES, get_proper_type ) from mypy.subtypes import is_subtype @@ -111,7 +111,7 @@ def open_callback(ctx: FunctionContext) -> Type: elif isinstance(ctx.args[1][0], StrExpr): mode = ctx.args[1][0].value if mode is not None: - assert isinstance(ctx.default_return_type, Instance) + assert isinstance(ctx.default_return_type, Instance) # type: ignore if 'b' in mode: return ctx.api.named_generic_type('typing.BinaryIO', []) else: @@ -123,12 +123,13 @@ def contextmanager_callback(ctx: FunctionContext) -> Type: """Infer a better return type for 'contextlib.contextmanager'.""" # Be defensive, just in case. if ctx.arg_types and len(ctx.arg_types[0]) == 1: - arg_type = ctx.arg_types[0][0] + arg_type = get_proper_type(ctx.arg_types[0][0]) + default_return = get_proper_type(ctx.default_return_type) if (isinstance(arg_type, CallableType) - and isinstance(ctx.default_return_type, CallableType)): + and isinstance(default_return, CallableType)): # The stub signature doesn't preserve information about arguments so # add them back here. - return ctx.default_return_type.copy_modified( + return default_return.copy_modified( arg_types=arg_type.arg_types, arg_kinds=arg_type.arg_kinds, arg_names=arg_type.arg_names, @@ -152,7 +153,7 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType: and len(signature.variables) == 1 and len(ctx.args[1]) == 1): key = ctx.args[0][0].value - value_type = ctx.type.items.get(key) + value_type = get_proper_type(ctx.type.items.get(key)) ret_type = signature.ret_type if value_type: default_arg = ctx.args[1][0] @@ -181,9 +182,9 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type: if keys is None: return ctx.default_return_type - output_types = [] + output_types = [] # type: List[Type] for key in keys: - value_type = ctx.type.items.get(key) + value_type = get_proper_type(ctx.type.items.get(key)) if value_type is None: ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) return AnyType(TypeOfAny.from_error) @@ -353,7 +354,7 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType: signature = ctx.default_signature if (isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1): - arg_type = signature.arg_types[0] + arg_type = get_proper_type(signature.arg_types[0]) assert isinstance(arg_type, TypedDictType) arg_type = arg_type.as_anonymous() arg_type = arg_type.copy_modified(required_keys=set()) diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index eeb3070c529c..e842fed1f32f 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -14,7 +14,7 @@ from typing_extensions import Final import mypy.plugin # To avoid circular imports. -from mypy.types import Type, Instance, LiteralType +from mypy.types import Type, Instance, LiteralType, get_proper_type # Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use # enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes. @@ -86,7 +86,7 @@ class SomeEnum: if stnode is None: return ctx.default_attr_type - underlying_type = stnode.type + underlying_type = get_proper_type(stnode.type) if underlying_type is None: # TODO: Deduce the inferred type if the user omits adding their own default types. # TODO: Consider using the return type of `Enum._generate_next_value_` here? @@ -111,7 +111,7 @@ def _extract_underlying_field_name(typ: Type) -> Optional[str]: We can examine this Literal fallback to retrieve the string. """ - + typ = get_proper_type(typ) if not isinstance(typ, Instance): return None diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 0777865421e0..cc0884ac42e7 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -4,12 +4,15 @@ Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, + ProperType, get_proper_type ) from mypy.typeops import tuple_fallback def is_same_type(left: Type, right: Type) -> bool: """Is 'left' the same type as 'right'?""" + left = get_proper_type(left) + right = get_proper_type(right) if isinstance(right, UnboundType): # Make unbound types same as anything else to reduce the number of @@ -29,7 +32,8 @@ def is_same_type(left: Type, right: Type) -> bool: return left.accept(SameTypeVisitor(right)) -def simplify_union(t: Type) -> Type: +def simplify_union(t: Type) -> ProperType: + t = get_proper_type(t) if isinstance(t, UnionType): return UnionType.make_simplified_union(t.items) return t @@ -47,7 +51,7 @@ def is_same_types(a1: Sequence[Type], a2: Sequence[Type]) -> bool: class SameTypeVisitor(TypeVisitor[bool]): """Visitor for checking whether two types are the 'same' type.""" - def __init__(self, right: Type) -> None: + def __init__(self, right: ProperType) -> None: self.right = right # visit_x(left) means: is left (which is an instance of X) the same type as diff --git a/mypy/semanal.py b/mypy/semanal.py index 48455a82051a..88203cec73bf 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -86,7 +86,7 @@ from mypy.types import ( FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, function_type, CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue, - TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES + TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType ) from mypy.type_visitor import TypeQuery from mypy.nodes import implicit_module_attrs @@ -571,6 +571,7 @@ def analyze_func_def(self, defn: FuncDef) -> None: if self.found_incomplete_ref(tag) or has_placeholder(result): self.defer(defn) return + assert isinstance(result, ProperType) defn.type = result self.add_type_alias_deps(analyzer.aliases_used) self.check_function_signature(defn) diff --git a/mypy/semanal_infer.py b/mypy/semanal_infer.py index 402214dbaffe..8c7d2926cce8 100644 --- a/mypy/semanal_infer.py +++ b/mypy/semanal_infer.py @@ -3,7 +3,9 @@ from typing import Optional from mypy.nodes import Expression, Decorator, CallExpr, FuncDef, RefExpr, Var, ARG_POS -from mypy.types import Type, CallableType, AnyType, TypeOfAny, TypeVarType, function_type +from mypy.types import ( + Type, CallableType, AnyType, TypeOfAny, TypeVarType, function_type, ProperType, get_proper_type +) from mypy.typevars import has_no_typevars from mypy.semanal_shared import SemanticAnalyzerInterface @@ -62,13 +64,14 @@ def infer_decorator_signature_if_simple(dec: Decorator, def is_identity_signature(sig: Type) -> bool: """Is type a callable of form T -> T (where T is a type variable)?""" + sig = get_proper_type(sig) if isinstance(sig, CallableType) and sig.arg_kinds == [ARG_POS]: if isinstance(sig.arg_types[0], TypeVarType) and isinstance(sig.ret_type, TypeVarType): return sig.arg_types[0].id == sig.ret_type.id return False -def calculate_return_type(expr: Expression) -> Optional[Type]: +def calculate_return_type(expr: Expression) -> Optional[ProperType]: """Return the return type if we can calculate it. This only uses information available during semantic analysis so this @@ -83,10 +86,10 @@ def calculate_return_type(expr: Expression) -> Optional[Type]: return AnyType(TypeOfAny.unannotated) # Explicit Any return? if isinstance(typ, CallableType): - return typ.ret_type + return get_proper_type(typ.ret_type) return None elif isinstance(expr.node, Var): - return expr.node.type + return get_proper_type(expr.node.type) elif isinstance(expr, CallExpr): return calculate_return_type(expr.callee) return None @@ -104,11 +107,13 @@ def find_fixed_callable_return(expr: Expression) -> Optional[CallableType]: typ = expr.node.type if typ: if isinstance(typ, CallableType) and has_no_typevars(typ.ret_type): - if isinstance(typ.ret_type, CallableType): - return typ.ret_type + ret_type = get_proper_type(typ.ret_type) + if isinstance(ret_type, CallableType): + return ret_type elif isinstance(expr, CallExpr): t = find_fixed_callable_return(expr.callee) if t: - if isinstance(t.ret_type, CallableType): - return t.ret_type + ret_type = get_proper_type(t.ret_type) + if isinstance(ret_type, CallableType): + return ret_type return None diff --git a/mypy/semanal_newtype.py b/mypy/semanal_newtype.py index ff9532dcd2ae..10e8adeba8fa 100644 --- a/mypy/semanal_newtype.py +++ b/mypy/semanal_newtype.py @@ -7,7 +7,7 @@ from mypy.types import ( Type, Instance, CallableType, NoneType, TupleType, AnyType, PlaceholderType, - TypeOfAny + TypeOfAny, get_proper_type ) from mypy.nodes import ( AssignmentStmt, NewTypeExpr, CallExpr, NameExpr, RefExpr, Context, StrExpr, BytesExpr, @@ -54,6 +54,7 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool: self.api.add_symbol(name, placeholder, s, can_defer=False) old_type, should_defer = self.check_newtype_args(name, call, s) + old_type = get_proper_type(old_type) if not call.analyzed: call.analyzed = NewTypeExpr(name, old_type, line=call.line, column=call.column) if old_type is None: @@ -159,7 +160,8 @@ def check_newtype_args(self, name: str, call: CallExpr, # We want to use our custom error message (see above), so we suppress # the default error message for invalid types here. - old_type = self.api.anal_type(unanalyzed_type, report_invalid_types=False) + old_type = get_proper_type(self.api.anal_type(unanalyzed_type, + report_invalid_types=False)) should_defer = False if old_type is None or isinstance(old_type, PlaceholderType): should_defer = True diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index f2a40139a7ff..3c3cfa882212 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -11,7 +11,9 @@ SymbolNode, SymbolTable ) from mypy.util import correct_relative_import -from mypy.types import Type, FunctionLike, Instance, TupleType, TPDICT_FB_NAMES +from mypy.types import ( + Type, FunctionLike, Instance, TupleType, TPDICT_FB_NAMES, ProperType, get_proper_type +) from mypy.tvar_scope import TypeVarScope from mypy import join @@ -179,7 +181,8 @@ def create_indirect_imported_name(file_node: MypyFile, return SymbolTableNode(GDEF, link) -def set_callable_name(sig: Type, fdef: FuncDef) -> Type: +def set_callable_name(sig: Type, fdef: FuncDef) -> ProperType: + sig = get_proper_type(sig) if isinstance(sig, FunctionLike): if fdef.info: if fdef.info.fullname() in TPDICT_FB_NAMES: @@ -192,7 +195,7 @@ def set_callable_name(sig: Type, fdef: FuncDef) -> Type: else: return sig.with_name(fdef.name()) else: - return sig + return get_proper_type(sig) def calculate_tuple_fallback(typ: TupleType) -> None: diff --git a/mypy/semanal_typeargs.py b/mypy/semanal_typeargs.py index c54c3bb6c601..b6d22538dbb3 100644 --- a/mypy/semanal_typeargs.py +++ b/mypy/semanal_typeargs.py @@ -8,7 +8,7 @@ from typing import List from mypy.nodes import TypeInfo, Context, MypyFile, FuncItem, ClassDef, Block -from mypy.types import Type, Instance, TypeVarType, AnyType +from mypy.types import Type, Instance, TypeVarType, AnyType, get_proper_types from mypy.mixedtraverser import MixedTraverserVisitor from mypy.subtypes import is_subtype from mypy.sametypes import is_same_type @@ -71,7 +71,7 @@ def visit_instance(self, t: Instance) -> None: def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str, valids: List[Type], arg_number: int, context: Context) -> None: - for actual in actuals: + for actual in get_proper_types(actuals): if (not isinstance(actual, AnyType) and not any(is_same_type(actual, value) for value in valids)): diff --git a/mypy/server/deps.py b/mypy/server/deps.py index 885fb1709711..6a35f1541af4 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -96,7 +96,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a from mypy.types import ( Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, - FunctionLike, Overloaded, TypeOfAny, LiteralType, + FunctionLike, Overloaded, TypeOfAny, LiteralType, get_proper_type, ProperType ) from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.util import correct_relative_import @@ -387,7 +387,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: assert len(o.lvalues) == 1 lvalue = o.lvalues[0] assert isinstance(lvalue, NameExpr) - typ = self.type_map.get(lvalue) + typ = get_proper_type(self.type_map.get(lvalue)) if isinstance(typ, FunctionLike) and typ.is_type_obj(): class_name = typ.type_object().fullname() self.add_dependency(make_trigger(class_name + '.__init__')) @@ -482,10 +482,10 @@ def get_non_partial_lvalue_type(self, lvalue: RefExpr) -> Type: if lvalue not in self.type_map: # Likely a block considered unreachable during type checking. return UninhabitedType() - lvalue_type = self.type_map[lvalue] + lvalue_type = get_proper_type(self.type_map[lvalue]) if isinstance(lvalue_type, PartialType): if isinstance(lvalue.node, Var) and lvalue.node.type: - lvalue_type = lvalue.node.type + lvalue_type = get_proper_type(lvalue.node.type) else: # Probably a secondary, non-definition assignment that doesn't # result in a non-partial type. We won't be able to infer any @@ -566,7 +566,7 @@ def process_global_ref_expr(self, o: RefExpr) -> None: # constructor. # IDEA: Avoid generating spurious dependencies for except statements, # class attribute references, etc., if performance is a problem. - typ = self.type_map.get(o) + typ = get_proper_type(self.type_map.get(o)) if isinstance(typ, FunctionLike) and typ.is_type_obj(): class_name = typ.type_object().fullname() self.add_dependency(make_trigger(class_name + '.__init__')) @@ -602,7 +602,7 @@ def visit_member_expr(self, e: MemberExpr) -> None: # Special case: reference to a missing module attribute. self.add_dependency(make_trigger(e.expr.node.fullname() + '.' + e.name)) return - typ = self.type_map[e.expr] + typ = get_proper_type(self.type_map[e.expr]) self.add_attribute_dependency(typ, e.name) if self.use_logical_deps() and isinstance(typ, AnyType): name = self.get_unimported_fullname(e, typ) @@ -632,7 +632,7 @@ def get_unimported_fullname(self, e: MemberExpr, typ: AnyType) -> Optional[str]: e = e.expr if e.expr not in self.type_map: return None - obj_type = self.type_map[e.expr] + obj_type = get_proper_type(self.type_map[e.expr]) if not isinstance(obj_type, AnyType): # Can't find the base reference to the unimported module. return None @@ -723,15 +723,15 @@ def process_binary_op(self, op: str, left: Expression, right: Expression) -> Non self.add_operator_method_dependency(right, rev_method) def add_operator_method_dependency(self, e: Expression, method: str) -> None: - typ = self.type_map.get(e) + typ = get_proper_type(self.type_map.get(e)) if typ is not None: self.add_operator_method_dependency_for_type(typ, method) - def add_operator_method_dependency_for_type(self, typ: Type, method: str) -> None: + def add_operator_method_dependency_for_type(self, typ: ProperType, method: str) -> None: # Note that operator methods can't be (non-metaclass) methods of type objects # (that is, TypeType objects or Callables representing a type). if isinstance(typ, TypeVarType): - typ = typ.upper_bound + typ = get_proper_type(typ.upper_bound) if isinstance(typ, TupleType): typ = typ.partial_fallback if isinstance(typ, Instance): @@ -811,8 +811,9 @@ def add_attribute_dependency(self, typ: Type, name: str) -> None: def attribute_triggers(self, typ: Type, name: str) -> List[str]: """Return all triggers associated with the attribute of a type.""" + typ = get_proper_type(typ) if isinstance(typ, TypeVarType): - typ = typ.upper_bound + typ = get_proper_type(typ.upper_bound) if isinstance(typ, TupleType): typ = typ.partial_fallback if isinstance(typ, Instance): diff --git a/mypy/solve.py b/mypy/solve.py index d29b1b0f5020..b89c8f35f350 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -3,7 +3,7 @@ from typing import List, Dict, Optional from collections import defaultdict -from mypy.types import Type, AnyType, UninhabitedType, TypeVarId, TypeOfAny +from mypy.types import Type, AnyType, UninhabitedType, TypeVarId, TypeOfAny, get_proper_type from mypy.constraints import Constraint, SUPERTYPE_OF from mypy.join import join_types from mypy.meet import meet_types @@ -49,6 +49,8 @@ def solve_constraints(vars: List[TypeVarId], constraints: List[Constraint], else: top = meet_types(top, c.target) + top = get_proper_type(top) + bottom = get_proper_type(bottom) if isinstance(top, AnyType) or isinstance(bottom, AnyType): source_any = top if isinstance(top, AnyType) else bottom assert isinstance(source_any, AnyType) diff --git a/mypy/stats.py b/mypy/stats.py index ab99868c5f8c..bd2f26ced0bf 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -12,7 +12,7 @@ from mypy.typeanal import collect_all_inner_types from mypy.types import ( Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType, TypeQuery, CallableType, - TypeOfAny + TypeOfAny, get_proper_type, get_proper_types ) from mypy import nodes from mypy.nodes import ( @@ -229,7 +229,7 @@ def record_call_target_precision(self, o: CallExpr) -> None: if not self.typemap or o.callee not in self.typemap: # Type not availabe. return - callee_type = self.typemap[o.callee] + callee_type = get_proper_type(self.typemap[o.callee]) if isinstance(callee_type, CallableType): self.record_callable_target_precision(o, callee_type) else: @@ -253,7 +253,7 @@ def record_callable_target_precision(self, o: CallExpr, callee: CallableType) -> lambda n: typemap[o.args[n]]) for formals in actual_to_formal: for n in formals: - formal = callee.arg_types[n] + formal = get_proper_type(callee.arg_types[n]) if isinstance(formal, AnyType): self.record_line(o.line, TYPE_ANY) elif is_imprecise(formal): @@ -318,6 +318,8 @@ def record_precise_if_checked_scope(self, node: Node) -> None: self.record_line(node.line, kind) def type(self, t: Optional[Type]) -> None: + t = get_proper_type(t) + if not t: # If an expression does not have a type, it is often due to dead code. # Don't count these because there can be an unanalyzed value on a line with other @@ -342,7 +344,7 @@ def type(self, t: Optional[Type]) -> None: self.num_precise_exprs += 1 self.record_line(self.line, TYPE_PRECISE) - for typ in collect_all_inner_types(t) + [t]: + for typ in get_proper_types(collect_all_inner_types(t)) + [t]: if isinstance(typ, AnyType): typ = get_original_any(typ) if is_special_form_any(typ): @@ -436,10 +438,12 @@ def visit_callable_type(self, t: CallableType) -> bool: def is_generic(t: Type) -> bool: + t = get_proper_type(t) return isinstance(t, Instance) and bool(t.args) def is_complex(t: Type) -> bool: + t = get_proper_type(t) return is_generic(t) or isinstance(t, (FunctionLike, TupleType, TypeVarType)) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 34004a472cce..c32d1719e9e4 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -569,7 +569,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: continue if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): items = lvalue.items - if isinstance(o.unanalyzed_type, TupleType): + if isinstance(o.unanalyzed_type, TupleType): # type: ignore annotations = o.unanalyzed_type.items # type: Iterable[Optional[Type]] else: annotations = [None] * len(items) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index d03d1ec6ae76..0a40b09fff32 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -7,7 +7,7 @@ Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneType, function_type, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, - FunctionLike, TypeOfAny, LiteralType, + FunctionLike, TypeOfAny, LiteralType, ProperType, get_proper_type ) import mypy.applytype import mypy.constraints @@ -64,6 +64,9 @@ def is_subtype(left: Type, right: Type, between the type arguments (e.g., A and B), taking the variance of the type var into account. """ + left = get_proper_type(left) + right = get_proper_type(right) + if (isinstance(right, AnyType) or isinstance(right, UnboundType) or isinstance(right, ErasedType)): return True @@ -113,7 +116,7 @@ def is_equivalent(a: Type, b: Type, class SubtypeVisitor(TypeVisitor[bool]): - def __init__(self, right: Type, + def __init__(self, right: ProperType, *, ignore_type_params: bool, ignore_pos_arg_names: bool = False, @@ -424,7 +427,7 @@ def visit_type_type(self, left: TypeType) -> bool: return True item = left.item if isinstance(item, TypeVarType): - item = item.upper_bound + item = get_proper_type(item.upper_bound) if isinstance(item, Instance): metaclass = item.type.metaclass_type return metaclass is not None and self._is_subtype(metaclass, right) @@ -474,9 +477,9 @@ def f(self) -> A: ... ignore_names = member != '__call__' # __call__ can be passed kwargs # The third argument below indicates to what self type is bound. # We always bind self to the subtype. (Similarly to nominal types). - supertype = find_member(member, right, left) + supertype = get_proper_type(find_member(member, right, left)) assert supertype is not None - subtype = find_member(member, left, left) + subtype = get_proper_type(find_member(member, left, left)) # Useful for debugging: # print(member, 'of', left, 'has type', subtype) # print(member, 'of', right, 'has type', supertype) @@ -554,7 +557,7 @@ def find_member(name: str, itype: Instance, subtype: Type) -> Optional[Type]: # structural subtyping. method = info.get_method(method_name) if method and method.info.fullname() != 'builtins.object': - getattr_type = find_node_type(method, itype, subtype) + getattr_type = get_proper_type(find_node_type(method, itype, subtype)) if isinstance(getattr_type, CallableType): return getattr_type.ret_type if itype.type.fallback_to_any: @@ -612,6 +615,7 @@ def find_node_type(node: Union[Var, FuncBase], itype: Instance, subtype: Type) - fallback=Instance(itype.type.mro[-1], [])) # type: Optional[Type] else: typ = node.type + typ = get_proper_type(typ) if typ is None: return AnyType(TypeOfAny.from_error) # We don't need to bind 'self' for static methods, since there is no 'self'. @@ -640,7 +644,7 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]: instance = Instance(tp, [anytype] * len(tp.defn.type_vars)) for member in tp.protocol_members: - typ = find_member(member, instance, instance) + typ = get_proper_type(find_member(member, instance, instance)) if not isinstance(typ, CallableType): result.append(member) return result @@ -1015,6 +1019,9 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) This is used for type inference of runtime type checks such as isinstance(). Currently this just removes elements of a union type. """ + t = get_proper_type(t) + s = get_proper_type(s) + if isinstance(t, UnionType): new_items = [item for item in t.relevant_items() if (isinstance(item, AnyType) or @@ -1026,6 +1033,8 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) def covers_at_runtime(item: Type, supertype: Type, ignore_promotions: bool) -> bool: """Will isinstance(item, supertype) always return True at runtime?""" + item = get_proper_type(item) + # Since runtime type checks will ignore type arguments, erase the types. supertype = erase_type(supertype) if is_proper_subtype(erase_type(item), supertype, ignore_promotions=ignore_promotions, @@ -1053,6 +1062,9 @@ def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = Fals If erase_instances is True, erase left instance *after* mapping it to supertype (this is useful for runtime isinstance() checks). """ + left = get_proper_type(left) + right = get_proper_type(right) + if isinstance(right, UnionType) and not isinstance(left, UnionType): return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions, erase_instances=erase_instances) @@ -1062,7 +1074,7 @@ def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = Fals class ProperSubtypeVisitor(TypeVisitor[bool]): - def __init__(self, right: Type, *, + def __init__(self, right: ProperType, *, ignore_promotions: bool = False, erase_instances: bool = False) -> None: self.right = right @@ -1183,7 +1195,7 @@ def visit_tuple_type(self, left: TupleType) -> bool: is_named_instance(right, 'typing.Reversible')): if not right.args: return False - iter_type = right.args[0] + iter_type = get_proper_type(right.args[0]) if is_named_instance(right, 'builtins.tuple') and isinstance(iter_type, AnyType): # TODO: We shouldn't need this special case. This is currently needed # for isinstance(x, tuple), though it's unclear why. @@ -1249,7 +1261,7 @@ def visit_type_type(self, left: TypeType) -> bool: return True item = left.item if isinstance(item, TypeVarType): - item = item.upper_bound + item = get_proper_type(item.upper_bound) if isinstance(item, Instance): metaclass = item.type.metaclass_type return metaclass is not None and self._is_proper_subtype(metaclass, right) @@ -1264,6 +1276,7 @@ def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) any left. """ # TODO Should List[int] be more precise than List[Any]? + right = get_proper_type(right) if isinstance(right, AnyType): return True return is_proper_subtype(left, right, ignore_promotions=ignore_promotions) diff --git a/mypy/suggestions.py b/mypy/suggestions.py index cd39b38d3e41..51143e8b9e98 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -31,7 +31,7 @@ Type, AnyType, TypeOfAny, CallableType, UnionType, NoneType, Instance, TupleType, TypeVarType, FunctionLike, TypeStrVisitor, TypeTranslator, - is_optional, + is_optional, ProperType, get_proper_type, get_proper_types ) from mypy.build import State, Graph from mypy.nodes import ( @@ -137,6 +137,7 @@ def is_explicit_any(typ: AnyType) -> bool: def is_implicit_any(typ: Type) -> bool: + typ = get_proper_type(typ) return isinstance(typ, AnyType) and not is_explicit_any(typ) @@ -419,12 +420,12 @@ def extract_from_decorator(self, node: Decorator) -> Optional[FuncDef]: typ = None if (isinstance(dec, RefExpr) and isinstance(dec.node, FuncDef)): - typ = dec.node.type + typ = get_proper_type(dec.node.type) elif (isinstance(dec, CallExpr) and isinstance(dec.callee, RefExpr) and isinstance(dec.callee.node, FuncDef) and isinstance(dec.callee.node.type, CallableType)): - typ = dec.callee.node.type.ret_type + typ = get_proper_type(dec.callee.node.type.ret_type) if not isinstance(typ, FunctionLike): return None @@ -436,7 +437,7 @@ def extract_from_decorator(self, node: Decorator) -> Optional[FuncDef]: return node.func - def try_type(self, func: FuncDef, typ: Type) -> List[str]: + def try_type(self, func: FuncDef, typ: ProperType) -> List[str]: """Recheck a function while assuming it has type typ. Return all error messages. @@ -513,6 +514,7 @@ def score_type(self, t: Type) -> int: Lower is better, prefer non-union/non-any types. Don't penalize optionals. """ + t = get_proper_type(t) if isinstance(t, AnyType): return 20 if isinstance(t, UnionType): @@ -610,7 +612,8 @@ def count_errors(msgs: List[str]) -> int: def callable_has_any(t: CallableType) -> int: # We count a bare None in argument position as Any, since # pyannotate turns it into Optional[Any] - return any(isinstance(at, NoneType) for at in t.arg_types) or has_any_type(t) + return (any(isinstance(at, NoneType) for at in get_proper_types(t.arg_types)) + or has_any_type(t)) T = TypeVar('T') diff --git a/mypy/test/testsolve.py b/mypy/test/testsolve.py index 172e4e4743c4..4daebe7811e2 100644 --- a/mypy/test/testsolve.py +++ b/mypy/test/testsolve.py @@ -117,7 +117,7 @@ def assert_solve(self, ) -> None: res = [] # type: List[Optional[Type]] for r in results: - if isinstance(r, tuple): + if isinstance(r, tuple): # type: ignore res.append(r[0]) else: res.append(r) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 2ab1e8789330..400d9761608b 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -22,7 +22,7 @@ YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, OverloadPart, EnumCallExpr, REVEAL_TYPE ) -from mypy.types import Type, FunctionLike +from mypy.types import Type, FunctionLike, ProperType from mypy.traverser import TraverserVisitor from mypy.visitor import NodeVisitor from mypy.util import replace_object_state @@ -154,7 +154,9 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDe newitem.line = olditem.line new = OverloadedFuncDef(items) new._fullname = node._fullname - new.type = self.optional_type(node.type) + new_type = self.optional_type(node.type) + assert isinstance(new_type, ProperType) + new.type = new_type new.info = node.info new.is_static = node.is_static new.is_class = node.is_class diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 3c8a0545f154..17cf2b059cb6 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -23,7 +23,7 @@ RawExpressionType, Instance, NoneType, TypeType, UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, - PlaceholderType, + PlaceholderType, TypeAliasType ) @@ -102,6 +102,9 @@ def visit_partial_type(self, t: PartialType) -> T: def visit_type_type(self, t: TypeType) -> T: pass + def visit_type_alias_type(self, t: TypeAliasType) -> T: + raise NotImplementedError('TODO') + def visit_placeholder_type(self, t: PlaceholderType) -> T: raise RuntimeError('Internal error: unresolved placeholder type {}'.format(t.fullname)) @@ -163,7 +166,7 @@ def visit_instance(self, t: Instance) -> Type: last_known_value = None # type: Optional[LiteralType] if t.last_known_value is not None: raw_last_known_value = t.last_known_value.accept(self) - assert isinstance(raw_last_known_value, LiteralType) + assert isinstance(raw_last_known_value, LiteralType) # type: ignore last_known_value = raw_last_known_value return Instance( typ=t.type, @@ -203,7 +206,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: def visit_literal_type(self, t: LiteralType) -> Type: fallback = t.fallback.accept(self) - assert isinstance(fallback, Instance) + assert isinstance(fallback, Instance) # type: ignore return LiteralType( value=t.value, fallback=fallback, @@ -214,7 +217,7 @@ def visit_literal_type(self, t: LiteralType) -> Type: def visit_union_type(self, t: UnionType) -> Type: return UnionType(self.translate_types(t.items), t.line, t.column) - def translate_types(self, types: List[Type]) -> List[Type]: + def translate_types(self, types: Iterable[Type]) -> List[Type]: return [t.accept(self) for t in types] def translate_variables(self, @@ -225,10 +228,8 @@ def visit_overloaded(self, t: Overloaded) -> Type: items = [] # type: List[CallableType] for item in t.items(): new = item.accept(self) - if isinstance(new, CallableType): - items.append(new) - else: - raise RuntimeError('CallableType expected, but got {}'.format(type(new))) + assert isinstance(new, CallableType) # type: ignore + items.append(new) return Overloaded(items=items) def visit_type_type(self, t: TypeType) -> Type: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 7e8fc3b92909..5cd5db41ebf7 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -13,7 +13,7 @@ from mypy.types import ( Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType, CallableType, NoneType, DeletedType, TypeList, TypeVarDef, SyntheticTypeVisitor, - StarType, PartialType, EllipsisType, UninhabitedType, TypeType, get_typ_args, set_typ_args, + StarType, PartialType, EllipsisType, UninhabitedType, TypeType, replace_alias_tvars, CallableArgument, get_type_vars, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, PlaceholderType ) @@ -814,7 +814,7 @@ def is_defined_type_var(self, tvar: str, context: Context) -> bool: return False return self.tvar_scope.get_binding(tvar_node) is not None - def anal_array(self, a: List[Type], nested: bool = True) -> List[Type]: + def anal_array(self, a: Iterable[Type], nested: bool = True) -> List[Type]: res = [] # type: List[Type] for t in a: res.append(self.anal_type(t, nested)) @@ -968,27 +968,6 @@ def expand_type_alias(target: Type, alias_tvars: List[str], args: List[Type], return typ -def replace_alias_tvars(tp: Type, vars: List[str], subs: List[Type], - newline: int, newcolumn: int) -> Type: - """Replace type variables in a generic type alias tp with substitutions subs - resetting context. Length of subs should be already checked. - """ - typ_args = get_typ_args(tp) - new_args = typ_args[:] - for i, arg in enumerate(typ_args): - if isinstance(arg, (UnboundType, TypeVarType)): - tvar = arg.name # type: Optional[str] - else: - tvar = None - if tvar and tvar in vars: - # Perform actual substitution... - new_args[i] = subs[vars.index(tvar)] - else: - # ...recursively, if needed. - new_args[i] = replace_alias_tvars(arg, vars, subs, newline, newcolumn) - return set_typ_args(tp, new_args, newline, newcolumn) - - def set_any_tvars(tp: Type, vars: List[str], newline: int, newcolumn: int, *, from_error: bool = False, diff --git a/mypy/types.py b/mypy/types.py index 071b6decf54f..a47a45f0949a 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -6,7 +6,7 @@ from typing import ( Any, TypeVar, Dict, List, Tuple, cast, Set, Optional, Union, Iterable, NamedTuple, - Sequence, Iterator, + Sequence, Iterator, overload ) from typing_extensions import ClassVar, Final, TYPE_CHECKING @@ -20,6 +20,18 @@ from mypy.util import IdMapper, replace_object_state from mypy.bogus_type import Bogus + +# Older versions of typing don't allow using overload outside stubs, +# so provide a dummy. +# mypyc doesn't like function declarations nested in if statements +def _overload(x: Any) -> Any: + return x + + +# mypyc doesn't like unreachable code, so trick mypy into thinking the branch is reachable +if bool() or sys.version_info < (3, 6): + overload = _overload # noqa + T = TypeVar('T') JsonDict = Dict[str, Any] @@ -144,6 +156,102 @@ def deserialize(cls, data: JsonDict) -> 'Type': raise NotImplementedError('Cannot deserialize {} instance'.format(cls.__name__)) +class TypeAliasType(Type): + """A type alias to another type. + + NOTE: this is not being used yet, and the implementation is still incomplete. + + To support recursive type aliases we don't immediately expand a type alias + during semantic analysis, but create an instance of this type that records the target alias + definition node (mypy.nodes.TypeAlias) and type arguments (for generic aliases). + + This is very similar to how TypeInfo vs Instance interact. + """ + + def __init__(self, alias: Optional[mypy.nodes.TypeAlias], args: List[Type], + line: int = -1, column: int = -1) -> None: + super().__init__(line, column) + self.alias = alias + self.args = args + self.type_ref = None # type: Optional[str] + + def _expand_once(self) -> Type: + """Expand to the target type exactly once. + + This doesn't do full expansion, i.e. the result can contain another + (or even this same) type alias. Use this internal helper only when really needed, + its public wrapper mypy.types.get_proper_type() is preferred. + """ + assert self.alias is not None + return replace_alias_tvars(self.alias.target, self.alias.alias_tvars, self.args, + self.line, self.column) + + def expand_all_if_possible(self) -> Optional['ProperType']: + """Attempt a full expansion of the type alias (including nested aliases). + + If the expansion is not possible, i.e. the alias is (mutually-)recursive, + return None. + """ + raise NotImplementedError('TODO') + + # TODO: remove ignore caused by https://github.com/python/mypy/issues/6759 + @property + def can_be_true(self) -> bool: # type: ignore + assert self.alias is not None + return self.alias.target.can_be_true + + # TODO: remove ignore caused by https://github.com/python/mypy/issues/6759 + @property + def can_be_false(self) -> bool: # type: ignore + assert self.alias is not None + return self.alias.target.can_be_false + + def accept(self, visitor: 'TypeVisitor[T]') -> T: + return visitor.visit_type_alias_type(self) + + def __hash__(self) -> int: + return hash((self.alias, tuple(self.args))) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypeAliasType): + return NotImplemented + return (self.alias == other.alias + and self.args == other.args) + + def serialize(self) -> JsonDict: + assert self.alias is not None + data = {'.class': 'TypeAliasType', + 'type_ref': self.alias.fullname(), + 'args': [arg.serialize() for arg in self.args]} # type: JsonDict + return data + + @classmethod + def deserialize(cls, data: JsonDict) -> 'TypeAliasType': + assert data['.class'] == 'TypeAliasType' + args = [] # type: List[Type] + if 'args' in data: + args_list = data['args'] + assert isinstance(args_list, list) + args = [deserialize_type(arg) for arg in args_list] + alias = TypeAliasType(None, args) + alias.type_ref = data['type_ref'] # TODO: fix this up in fixup.py. + return alias + + def copy_modified(self, *, + args: Optional[List[Type]] = None) -> 'TypeAliasType': + return TypeAliasType( + self.alias, + args if args is not None else self.args.copy(), + self.line, self.column) + + +class ProperType(Type): + """Not a type alias. + + Every type except TypeAliasType must inherit from this type. + """ + + class TypeVarId: # A type variable is uniquely identified by its raw id and meta level. @@ -262,7 +370,7 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarDef': ) -class UnboundType(Type): +class UnboundType(ProperType): """Instance type that has not been bound during semantic analysis.""" __slots__ = ('name', 'args', 'optional', 'empty_tuple_index', @@ -351,7 +459,7 @@ def deserialize(cls, data: JsonDict) -> 'UnboundType': ) -class CallableArgument(Type): +class CallableArgument(ProperType): """Represents a Arg(type, 'name') inside a Callable's type list. Note that this is a synthetic type for helping parse ASTs, not a real type. @@ -375,7 +483,7 @@ def serialize(self) -> JsonDict: assert False, "Synthetic types don't serialize" -class TypeList(Type): +class TypeList(ProperType): """Information about argument types and names [...]. This is used for the arguments of a Callable type, i.e. for @@ -398,7 +506,7 @@ def serialize(self) -> JsonDict: assert False, "Synthetic types don't serialize" -class AnyType(Type): +class AnyType(ProperType): """The type 'Any'.""" __slots__ = ('type_of_any', 'source_any', 'missing_import_name') @@ -470,7 +578,7 @@ def deserialize(cls, data: JsonDict) -> 'AnyType': data['missing_import_name']) -class UninhabitedType(Type): +class UninhabitedType(ProperType): """This type has no members. This type is the bottom type. @@ -519,7 +627,7 @@ def deserialize(cls, data: JsonDict) -> 'UninhabitedType': return UninhabitedType(is_noreturn=data['is_noreturn']) -class NoneType(Type): +class NoneType(ProperType): """The type of 'None'. This type can be written by users as 'None'. @@ -556,7 +664,7 @@ def deserialize(cls, data: JsonDict) -> 'NoneType': NoneTyp = NoneType -class ErasedType(Type): +class ErasedType(ProperType): """Placeholder for an erased type. This is used during type inference. This has the special property that @@ -567,7 +675,7 @@ def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_erased_type(self) -class DeletedType(Type): +class DeletedType(ProperType): """Type of deleted variables. These can be used as lvalues but not rvalues. @@ -596,7 +704,7 @@ def deserialize(cls, data: JsonDict) -> 'DeletedType': NOT_READY = mypy.nodes.FakeInfo('De-serialization failure: TypeInfo not fixed') # type: Final -class Instance(Type): +class Instance(ProperType): """An instance type of form C[T1, ..., Tn]. The list of type variables may be empty. @@ -723,7 +831,7 @@ def has_readable_member(self, name: str) -> bool: return self.type.has_readable_member(name) -class TypeVarType(Type): +class TypeVarType(ProperType): """A type variable type. This refers to either a class type variable (id > 0) or a function @@ -784,7 +892,7 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarType': return TypeVarType(tvdef) -class FunctionLike(Type): +class FunctionLike(ProperType): """Abstract base class for function types.""" __slots__ = ('fallback',) @@ -847,7 +955,7 @@ class CallableType(FunctionLike): ) def __init__(self, - arg_types: List[Type], + arg_types: Sequence[Type], arg_kinds: List[int], arg_names: Sequence[Optional[str]], ret_type: Type, @@ -868,7 +976,7 @@ def __init__(self, assert len(arg_types) == len(arg_kinds) == len(arg_names) if variables is None: variables = [] - self.arg_types = arg_types + self.arg_types = list(arg_types) self.arg_kinds = arg_kinds self.arg_names = list(arg_names) self.min_args = arg_kinds.count(ARG_POS) @@ -899,7 +1007,7 @@ def __init__(self, self.def_extras = {} def copy_modified(self, - arg_types: Bogus[List[Type]] = _dummy, + arg_types: Bogus[Sequence[Type]] = _dummy, arg_kinds: Bogus[List[int]] = _dummy, arg_names: Bogus[List[Optional[str]]] = _dummy, ret_type: Bogus[Type] = _dummy, @@ -964,9 +1072,9 @@ def is_type_obj(self) -> bool: def type_object(self) -> mypy.nodes.TypeInfo: assert self.is_type_obj() - ret = self.ret_type + ret = get_proper_type(self.ret_type) if isinstance(ret, TypeVarType): - ret = ret.upper_bound + ret = get_proper_type(ret.upper_bound) if isinstance(ret, TupleType): ret = ret.partial_fallback assert isinstance(ret, Instance) @@ -1216,7 +1324,7 @@ def deserialize(cls, data: JsonDict) -> 'Overloaded': return Overloaded([CallableType.deserialize(t) for t in data['items']]) -class TupleType(Type): +class TupleType(ProperType): """The tuple type Tuple[T1, ..., Tn] (at least one type argument). Instance variables: @@ -1284,7 +1392,7 @@ def slice(self, begin: Optional[int], end: Optional[int], self.line, self.column, self.implicit) -class TypedDictType(Type): +class TypedDictType(ProperType): """Type of TypedDict object {'k1': v1, ..., 'kn': vn}. A TypedDict object is a dictionary with specific string (literal) keys. Each @@ -1398,7 +1506,7 @@ def zipall(self, right: 'TypedDictType') \ yield (item_name, None, right_item_type) -class RawExpressionType(Type): +class RawExpressionType(ProperType): """A synthetic type representing some arbitrary expression that does not cleanly translate into a type. @@ -1474,7 +1582,7 @@ def __eq__(self, other: object) -> bool: return NotImplemented -class LiteralType(Type): +class LiteralType(ProperType): """The type of a Literal instance. Literal[Value] A Literal always consists of: @@ -1556,7 +1664,7 @@ def deserialize(cls, data: JsonDict) -> 'LiteralType': ) -class StarType(Type): +class StarType(ProperType): """The star type *type_parameter. This is not a real type but a syntactic AST construct. @@ -1576,14 +1684,14 @@ def serialize(self) -> JsonDict: assert False, "Synthetic types don't serialize" -class UnionType(Type): +class UnionType(ProperType): """The union type Union[T1, ..., Tn] (at least one type argument).""" __slots__ = ('items',) - def __init__(self, items: List[Type], line: int = -1, column: int = -1) -> None: + def __init__(self, items: Sequence[Type], line: int = -1, column: int = -1) -> None: super().__init__(line, column) - self.items = flatten_nested_unions(items) # type: List[Type] + self.items = flatten_nested_unions(items) self.can_be_true = any(item.can_be_true for item in items) self.can_be_false = any(item.can_be_false for item in items) @@ -1595,8 +1703,15 @@ def __eq__(self, other: object) -> bool: return NotImplemented return frozenset(self.items) == frozenset(other.items) + @overload + @staticmethod + def make_union(items: List[ProperType], line: int = -1, column: int = -1) -> ProperType: ... + @overload # noqa @staticmethod - def make_union(items: List[Type], line: int = -1, column: int = -1) -> Type: + def make_union(items: List[Type], line: int = -1, column: int = -1) -> Type: ... + + @staticmethod # noqa + def make_union(items: Sequence[Type], line: int = -1, column: int = -1) -> Type: if len(items) > 1: return UnionType(items, line, column) elif len(items) == 1: @@ -1605,7 +1720,8 @@ def make_union(items: List[Type], line: int = -1, column: int = -1) -> Type: return UninhabitedType() @staticmethod - def make_simplified_union(items: Sequence[Type], line: int = -1, column: int = -1) -> Type: + def make_simplified_union(items: Sequence[Type], + line: int = -1, column: int = -1) -> ProperType: """Build union type with redundant union items removed. If only a single item remains, this may return a non-union type. @@ -1623,9 +1739,9 @@ def make_simplified_union(items: Sequence[Type], line: int = -1, column: int = - """ # TODO: Make this a function living somewhere outside mypy.types. Most other non-trivial # type operations are not static methods, so this is inconsistent. - items = list(items) + items = get_proper_types(items) while any(isinstance(typ, UnionType) for typ in items): - all_items = [] # type: List[Type] + all_items = [] # type: List[ProperType] for typ in items: if isinstance(typ, UnionType): all_items.extend(typ.items) @@ -1671,7 +1787,7 @@ def has_readable_member(self, name: str) -> bool: (isinstance(x, Instance) and x.type.has_readable_member(name)) for x in self.relevant_items()) - def relevant_items(self) -> List[Type]: + def relevant_items(self) -> List[ProperType]: """Removes NoneTypes from Unions when strict Optional checking is off.""" if state.strict_optional: return self.items @@ -1689,7 +1805,7 @@ def deserialize(cls, data: JsonDict) -> 'UnionType': return UnionType([deserialize_type(t) for t in data['items']]) -class PartialType(Type): +class PartialType(ProperType): """Type such as List[?] where type arguments are unknown, or partial None type. These are used for inferring types in multiphase initialization such as this: @@ -1722,7 +1838,7 @@ def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_partial_type(self) -class EllipsisType(Type): +class EllipsisType(ProperType): """The type ... (ellipsis). This is not a real type but a syntactic AST construct, used in Callable[..., T], for example. @@ -1738,7 +1854,7 @@ def serialize(self) -> JsonDict: assert False, "Synthetic types don't serialize" -class TypeType(Type): +class TypeType(ProperType): """For types like Type[User]. This annotates variables that are class objects, constrained by @@ -1768,7 +1884,7 @@ class TypeType(Type): # This can't be everything, but it can be a class reference, # a generic class instance, a union, Any, a type variable... - item = None # type: Type + item = None # type: ProperType def __init__(self, item: Bogus[Union[Instance, AnyType, TypeVarType, TupleType, NoneType, CallableType]], *, @@ -1780,7 +1896,8 @@ def __init__(self, item: Bogus[Union[Instance, AnyType, TypeVarType, TupleType, self.item = item @staticmethod - def make_normalized(item: Type, *, line: int = -1, column: int = -1) -> Type: + def make_normalized(item: Type, *, line: int = -1, column: int = -1) -> ProperType: + item = get_proper_type(item) if isinstance(item, UnionType): return UnionType.make_union( [TypeType.make_normalized(union_item) for union_item in item.items], @@ -1808,7 +1925,7 @@ def deserialize(cls, data: JsonDict) -> Type: return TypeType.make_normalized(deserialize_type(data['item'])) -class PlaceholderType(Type): +class PlaceholderType(ProperType): """Temporary, yet-unknown type during semantic analysis. This is needed when there's a reference to a type before the real symbol @@ -1942,7 +2059,7 @@ def visit_callable_type(self, t: CallableType) -> str: s = '({})'.format(s) - if not isinstance(t.ret_type, NoneType): + if not isinstance(get_proper_type(t.ret_type), NoneType): s += ' -> {}'.format(t.ret_type.accept(self)) if t.variables: @@ -2019,7 +2136,12 @@ def visit_type_type(self, t: TypeType) -> str: def visit_placeholder_type(self, t: PlaceholderType) -> str: return ''.format(t.fullname) - def list_str(self, a: List[Type]) -> str: + def visit_type_alias_type(self, t: TypeAliasType) -> str: + if t.alias is not None: + return ''.format(t.alias.fullname()) + return '' + + def list_str(self, a: Iterable[Type]) -> str: """Convert items of an array to strings (pretty-print types) and join the results with commas. """ @@ -2029,9 +2151,9 @@ def list_str(self, a: List[Type]) -> str: return ', '.join(res) -def strip_type(typ: Type) -> Type: +def strip_type(typ: Type) -> ProperType: """Make a copy of type without 'debugging info' (function name).""" - + typ = get_proper_type(typ) if isinstance(typ, CallableType): return typ.copy_modified(name=None) elif isinstance(typ, Overloaded): @@ -2042,10 +2164,14 @@ def strip_type(typ: Type) -> Type: def is_named_instance(t: Type, fullname: str) -> bool: + t = get_proper_type(t) return isinstance(t, Instance) and t.type.fullname() == fullname -def copy_type(t: Type) -> Type: +TP = TypeVar('TP', bound=Type) + + +def copy_type(t: TP) -> TP: """ Build a copy of the type; used to mutate the copy with truthiness information """ @@ -2058,10 +2184,12 @@ def copy_type(t: Type) -> Type: return nt -def true_only(t: Type) -> Type: +def true_only(t: Type) -> ProperType: """ Restricted version of t with only True-ish values """ + t = get_proper_type(t) + if not t.can_be_true: # All values of t are False-ish, so there are no true values in it return UninhabitedType(line=t.line, column=t.column) @@ -2078,10 +2206,12 @@ def true_only(t: Type) -> Type: return new_t -def false_only(t: Type) -> Type: +def false_only(t: Type) -> ProperType: """ Restricted version of t with only False-ish values """ + t = get_proper_type(t) + if not t.can_be_false: if state.strict_optional: # All values of t are True-ish, so there are no false values in it @@ -2103,10 +2233,12 @@ def false_only(t: Type) -> Type: return new_t -def true_or_false(t: Type) -> Type: +def true_or_false(t: Type) -> ProperType: """ Unrestricted version of t with both True-ish and False-ish values """ + t = get_proper_type(t) + if isinstance(t, UnionType): new_items = [true_or_false(item) for item in t.items] return UnionType.make_simplified_union(new_items, line=t.line, column=t.column) @@ -2155,18 +2287,44 @@ def callable_type(fdef: mypy.nodes.FuncItem, fallback: Instance, ) +def replace_alias_tvars(tp: Type, vars: List[str], subs: List[Type], + newline: int, newcolumn: int) -> Type: + """Replace type variables in a generic type alias tp with substitutions subs + resetting context. Length of subs should be already checked. + """ + typ_args = get_typ_args(tp) + new_args = typ_args[:] + for i, arg in enumerate(typ_args): + if isinstance(arg, (UnboundType, TypeVarType)): # type: ignore + tvar = arg.name # type: Optional[str] + else: + tvar = None + if tvar and tvar in vars: + # Perform actual substitution... + new_args[i] = subs[vars.index(tvar)] + else: + # ...recursively, if needed. + new_args[i] = replace_alias_tvars(arg, vars, subs, newline, newcolumn) + return set_typ_args(tp, new_args, newline, newcolumn) + + def get_typ_args(tp: Type) -> List[Type]: """Get all type arguments from a parametrizable Type.""" + # TODO: replace this and related functions with proper visitors. + tp = get_proper_type(tp) # TODO: is this really needed? + if not isinstance(tp, (Instance, UnionType, TupleType, CallableType)): return [] typ_args = (tp.args if isinstance(tp, Instance) else tp.items if not isinstance(tp, CallableType) else tp.arg_types + [tp.ret_type]) - return typ_args + return cast(List[Type], typ_args) def set_typ_args(tp: Type, new_args: List[Type], line: int = -1, column: int = -1) -> Type: """Return a copy of a parametrizable Type with arguments set to new_args.""" + tp = get_proper_type(tp) # TODO: is this really needed? + if isinstance(tp, Instance): return Instance(tp.type, new_args, line, column) if isinstance(tp, TupleType): @@ -2200,10 +2358,13 @@ def get_type_vars(typ: Type) -> List[TypeVarType]: return tvars -def flatten_nested_unions(types: Iterable[Type]) -> List[Type]: +def flatten_nested_unions(types: Iterable[Type]) -> List[ProperType]: """Flatten nested unions in a type list.""" - flat_items = [] # type: List[Type] - for tp in types: + # This and similar functions on unions can cause infinite recursion + # if passed a "pathological" alias like A = Union[int, A] or similar. + # TODO: ban such aliases in semantic analyzer. + flat_items = [] # type: List[ProperType] + for tp in get_proper_types(types): if isinstance(tp, UnionType): flat_items.extend(flatten_nested_unions(tp.items)) else: @@ -2211,11 +2372,12 @@ def flatten_nested_unions(types: Iterable[Type]) -> List[Type]: return flat_items -def union_items(typ: Type) -> List[Type]: +def union_items(typ: Type) -> List[ProperType]: """Return the flattened items of a union type. For non-union types, return a list containing just the argument. """ + typ = get_proper_type(typ) if isinstance(typ, UnionType): items = [] for item in typ.items: @@ -2226,24 +2388,52 @@ def union_items(typ: Type) -> List[Type]: def is_generic_instance(tp: Type) -> bool: + tp = get_proper_type(tp) return isinstance(tp, Instance) and bool(tp.args) def is_optional(t: Type) -> bool: + t = get_proper_type(t) return isinstance(t, UnionType) and any(isinstance(e, NoneType) for e in t.items) -def remove_optional(typ: Type) -> Type: +def remove_optional(typ: Type) -> ProperType: + typ = get_proper_type(typ) if isinstance(typ, UnionType): return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneType)]) else: return typ +@overload +def get_proper_type(typ: None) -> None: ... +@overload # noqa +def get_proper_type(typ: Type) -> ProperType: ... + + +def get_proper_type(typ: Optional[Type]) -> Optional[ProperType]: # noqa + if typ is None: + return None + while isinstance(typ, TypeAliasType): + typ = typ._expand_once() + assert isinstance(typ, ProperType), typ + return typ + + +@overload +def get_proper_types(it: Iterable[Type]) -> List[ProperType]: ... +@overload # noqa +def get_proper_types(typ: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: ... + + +def get_proper_types(it: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: # type: ignore # noqa + return [get_proper_type(t) for t in it] + + names = globals().copy() # type: Final names.pop('NOT_READY', None) deserialize_map = { - key: obj.deserialize # type: ignore + key: obj.deserialize for key, obj in names.items() if isinstance(obj, type) and issubclass(obj, Type) and obj is not Type } # type: Final diff --git a/mypy_self_check.ini b/mypy_self_check.ini index fcc26f49a4b6..b21036da1905 100644 --- a/mypy_self_check.ini +++ b/mypy_self_check.ini @@ -13,3 +13,4 @@ warn_redundant_casts = True warn_unused_configs = True show_traceback = True always_false = MYPYC +plugins = misc/proper_plugin.py