diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 6ee2e2f6117d..50bc29a26ed0 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -20,7 +20,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance from mypy.sametypes import is_same_type -from mypy.typestate import TypeState +from mypy.typestate import TypeState, SubtypeKind from mypy import experiments @@ -46,7 +46,8 @@ def check_type_parameter(lefta: Type, righta: Type, variance: int) -> bool: def is_subtype(left: Type, right: Type, type_parameter_checker: TypeParameterChecker = check_type_parameter, *, ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False) -> bool: + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> bool: """Is 'left' subtype of 'right'? Also consider Any to be a subtype of any type, and vice versa. This @@ -66,7 +67,9 @@ def is_subtype(left: Type, right: Type, # 'left' can be a subtype of the union 'right' is if it is a # subtype of one of the items making up the union. is_subtype_of_item = any(is_subtype(left, item, type_parameter_checker, - ignore_pos_arg_names=ignore_pos_arg_names) + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) for item in right.items) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be @@ -81,7 +84,8 @@ def is_subtype(left: Type, right: Type, # otherwise, fall through return left.accept(SubtypeVisitor(right, type_parameter_checker, ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance)) + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions)) def is_subtype_ignoring_tvars(left: Type, right: Type) -> bool: @@ -106,11 +110,43 @@ class SubtypeVisitor(TypeVisitor[bool]): def __init__(self, right: Type, type_parameter_checker: TypeParameterChecker, *, ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False) -> None: + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> None: self.right = right self.check_type_parameter = type_parameter_checker self.ignore_pos_arg_names = ignore_pos_arg_names self.ignore_declared_variance = ignore_declared_variance + self.ignore_promotions = ignore_promotions + self._subtype_kind = SubtypeVisitor.build_subtype_kind( + type_parameter_checker=type_parameter_checker, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) + + @staticmethod + def build_subtype_kind(*, + type_parameter_checker: TypeParameterChecker = check_type_parameter, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> SubtypeKind: + return ('subtype', + type_parameter_checker, + ignore_pos_arg_names, + ignore_declared_variance, + ignore_promotions) + + def _lookup_cache(self, left: Instance, right: Instance) -> bool: + return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) + + def _record_cache(self, left: Instance, right: Instance) -> None: + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) + + def _is_subtype(self, left: Type, right: Type) -> bool: + return is_subtype(left, right, + type_parameter_checker=self.check_type_parameter, + ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_declared_variance=self.ignore_declared_variance, + ignore_promotions=self.ignore_promotions) # visit_x(left) means: is left (which is an instance of X) a subtype of # right? @@ -150,17 +186,15 @@ def visit_instance(self, left: Instance) -> bool: return True right = self.right if isinstance(right, TupleType) and right.fallback.type.is_enum: - return is_subtype(left, right.fallback) + return self._is_subtype(left, right.fallback) if isinstance(right, Instance): - if TypeState.is_cached_subtype_check(left, right): + if self._lookup_cache(left, right): return True - for base in left.type.mro: - # TODO: Also pass recursively ignore_declared_variance - if base._promote and is_subtype( - base._promote, self.right, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names): - TypeState.record_subtype_cache_entry(left, right) - return True + if not self.ignore_promotions: + for base in left.type.mro: + if base._promote and self._is_subtype(base._promote, self.right): + self._record_cache(left, right) + return True rname = right.type.fullname() # Always try a nominal check if possible, # there might be errors that a user wants to silence *once*. @@ -172,7 +206,7 @@ def visit_instance(self, left: Instance) -> bool: for lefta, righta, tvar in zip(t.args, right.args, right.type.defn.type_vars)) if nominal: - TypeState.record_subtype_cache_entry(left, right) + self._record_cache(left, right) return nominal if right.type.is_protocol and is_protocol_implementation(left, right): return True @@ -182,7 +216,7 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(item, TupleType): item = item.fallback if is_named_instance(left, 'builtins.type'): - return is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) + return self._is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) if left.type.is_metaclass(): if isinstance(item, AnyType): return True @@ -192,7 +226,7 @@ def visit_instance(self, left: Instance) -> bool: # Special case: Instance can be a subtype of Callable. call = find_member('__call__', left, left) if call: - return is_subtype(call, right) + return self._is_subtype(call, right) return False else: return False @@ -201,27 +235,24 @@ def visit_type_var(self, left: TypeVarType) -> bool: right = self.right if isinstance(right, TypeVarType) and left.id == right.id: return True - if left.values and is_subtype(UnionType.make_simplified_union(left.values), right): + if left.values and self._is_subtype(UnionType.make_simplified_union(left.values), right): return True - return is_subtype(left.upper_bound, self.right) + return self._is_subtype(left.upper_bound, self.right) def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): return is_callable_compatible( left, right, - is_compat=is_subtype, + is_compat=self._is_subtype, ignore_pos_arg_names=self.ignore_pos_arg_names) elif isinstance(right, Overloaded): - return all(is_subtype(left, item, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names) - for item in right.items()) + return all(self._is_subtype(left, item) for item in right.items()) elif isinstance(right, Instance): - return is_subtype(left.fallback, right, - ignore_pos_arg_names=self.ignore_pos_arg_names) + return self._is_subtype(left.fallback, right) elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and is_subtype(left.ret_type, right.item) + return left.is_type_obj() and self._is_subtype(left.ret_type, right.item) else: return False @@ -239,17 +270,17 @@ def visit_tuple_type(self, left: TupleType) -> bool: iter_type = right.args[0] else: iter_type = AnyType(TypeOfAny.special_form) - return all(is_subtype(li, iter_type) for li in left.items) - elif is_subtype(left.fallback, right, self.check_type_parameter): + return all(self._is_subtype(li, iter_type) for li in left.items) + elif self._is_subtype(left.fallback, right): return True return False elif isinstance(right, TupleType): if len(left.items) != len(right.items): return False for l, r in zip(left.items, right.items): - if not is_subtype(l, r, self.check_type_parameter): + if not self._is_subtype(l, r): return False - if not is_subtype(left.fallback, right.fallback, self.check_type_parameter): + if not self._is_subtype(left.fallback, right.fallback): return False return True else: @@ -258,7 +289,7 @@ def visit_tuple_type(self, left: TupleType) -> bool: def visit_typeddict_type(self, left: TypedDictType) -> bool: right = self.right if isinstance(right, Instance): - return is_subtype(left.fallback, right, self.check_type_parameter) + return self._is_subtype(left.fallback, right) elif isinstance(right, TypedDictType): if not left.names_are_wider_than(right): return False @@ -284,11 +315,10 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: def visit_overloaded(self, left: Overloaded) -> bool: right = self.right if isinstance(right, Instance): - return is_subtype(left.fallback, right) + return self._is_subtype(left.fallback, right) elif isinstance(right, CallableType): for item in left.items(): - if is_subtype(item, right, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names): + if self._is_subtype(item, right): return True return False elif isinstance(right, Overloaded): @@ -301,8 +331,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: found_match = False for left_index, left_item in enumerate(left.items()): - subtype_match = is_subtype(left_item, right_item, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names) + subtype_match = self._is_subtype(left_item, right_item)\ # Order matters: we need to make sure that the index of # this item is at least the index of the previous one. @@ -317,10 +346,10 @@ def visit_overloaded(self, left: Overloaded) -> bool: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. if (is_callable_compatible(left_item, right_item, - is_compat=is_subtype, ignore_return=True, + is_compat=self._is_subtype, ignore_return=True, ignore_pos_arg_names=self.ignore_pos_arg_names) or is_callable_compatible(right_item, left_item, - is_compat=is_subtype, ignore_return=True, + is_compat=self._is_subtype, ignore_return=True, ignore_pos_arg_names=self.ignore_pos_arg_names)): # If this is an overload that's already been matched, there's no # problem. @@ -341,13 +370,12 @@ def visit_overloaded(self, left: Overloaded) -> bool: # All the items must have the same type object status, so # it's sufficient to query only (any) one of them. # This is unsound, we don't check all the __init__ signatures. - return left.is_type_obj() and is_subtype(left.items()[0], right) + return left.is_type_obj() and self._is_subtype(left.items()[0], right) else: return False def visit_union_type(self, left: UnionType) -> bool: - return all(is_subtype(item, self.right, self.check_type_parameter) - for item in left.items) + return all(self._is_subtype(item, self.right) for item in left.items) def visit_partial_type(self, left: PartialType) -> bool: # This is indeterminate as we don't really know the complete type yet. @@ -356,10 +384,10 @@ def visit_partial_type(self, left: PartialType) -> bool: def visit_type_type(self, left: TypeType) -> bool: right = self.right if isinstance(right, TypeType): - return is_subtype(left.item, right.item) + return self._is_subtype(left.item, right.item) if isinstance(right, CallableType): # This is unsound, we don't check the __init__ signature. - return is_subtype(left.item, right.ret_type) + return self._is_subtype(left.item, right.ret_type) if isinstance(right, Instance): if right.type.fullname() in ['builtins.object', 'builtins.type']: return True @@ -368,7 +396,7 @@ def visit_type_type(self, left: TypeType) -> bool: item = item.upper_bound if isinstance(item, Instance): metaclass = item.type.metaclass_type - return metaclass is not None and is_subtype(metaclass, right) + return metaclass is not None and self._is_subtype(metaclass, right) return False @@ -423,6 +451,8 @@ def f(self) -> A: ... return False if not proper_subtype: # Nominal check currently ignores arg names + # NOTE: If we ever change this, be sure to also change the call to + # SubtypeVisitor.build_subtype_kind(...) down below. is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True) else: is_compat = is_proper_subtype(subtype, supertype) @@ -444,10 +474,13 @@ def f(self) -> A: ... # This rule is copied from nominal check in checker.py if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags: return False - if proper_subtype: - TypeState.record_proper_subtype_cache_entry(left, right) + + if not proper_subtype: + # Nominal check currently ignores arg names + subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=True) else: - TypeState.record_subtype_cache_entry(left, right) + subtype_kind = ProperSubtypeVisitor.build_subtype_kind() + TypeState.record_subtype_cache_entry(subtype_kind, left, right) return True @@ -961,21 +994,38 @@ def restrict_subtype_away(t: Type, s: Type) -> Type: return t -def is_proper_subtype(left: Type, right: Type) -> bool: +def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Is left a proper subtype of right? For proper subtypes, there's no need to rely on compatibility due to Any types. Every usable type is a proper subtype of itself. """ if isinstance(right, UnionType) and not isinstance(left, UnionType): - return any([is_proper_subtype(left, item) + return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions) for item in right.items]) - return left.accept(ProperSubtypeVisitor(right)) + return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions)) class ProperSubtypeVisitor(TypeVisitor[bool]): - def __init__(self, right: Type) -> None: + def __init__(self, right: Type, *, ignore_promotions: bool = False) -> None: self.right = right + self.ignore_promotions = ignore_promotions + self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( + ignore_promotions=ignore_promotions, + ) + + @staticmethod + def build_subtype_kind(*, ignore_promotions: bool = False) -> SubtypeKind: + return ('subtype_proper', ignore_promotions) + + def _lookup_cache(self, left: Instance, right: Instance) -> bool: + return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) + + def _record_cache(self, left: Instance, right: Instance) -> None: + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) + + def _is_proper_subtype(self, left: Type, right: Type) -> bool: + return is_proper_subtype(left, right, ignore_promotions=self.ignore_promotions) def visit_unbound_type(self, left: UnboundType) -> bool: # This can be called if there is a bad type annotation. The result probably @@ -1006,19 +1056,20 @@ def visit_deleted_type(self, left: DeletedType) -> bool: def visit_instance(self, left: Instance) -> bool: right = self.right if isinstance(right, Instance): - if TypeState.is_cached_proper_subtype_check(left, right): + if self._lookup_cache(left, right): return True - for base in left.type.mro: - if base._promote and is_proper_subtype(base._promote, right): - TypeState.record_proper_subtype_cache_entry(left, right) - return True + if not self.ignore_promotions: + for base in left.type.mro: + if base._promote and self._is_proper_subtype(base._promote, right): + self._record_cache(left, right) + return True if left.type.has_base(right.type.fullname()): def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: if variance == COVARIANT: - return is_proper_subtype(leftarg, rightarg) + return self._is_proper_subtype(leftarg, rightarg) elif variance == CONTRAVARIANT: - return is_proper_subtype(rightarg, leftarg) + return self._is_proper_subtype(rightarg, leftarg) else: return sametypes.is_same_type(leftarg, rightarg) # Map left type to corresponding right instances. @@ -1027,7 +1078,7 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in zip(left.args, right.args, right.type.defn.type_vars)) if nominal: - TypeState.record_proper_subtype_cache_entry(left, right) + self._record_cache(left, right) return nominal if (right.type.is_protocol and is_protocol_implementation(left, right, proper_subtype=True)): @@ -1036,29 +1087,30 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: if isinstance(right, CallableType): call = find_member('__call__', left, left) if call: - return is_proper_subtype(call, right) + return self._is_proper_subtype(call, right) return False return False def visit_type_var(self, left: TypeVarType) -> bool: if isinstance(self.right, TypeVarType) and left.id == self.right.id: return True - if left.values and is_subtype(UnionType.make_simplified_union(left.values), self.right): + if left.values and is_subtype(UnionType.make_simplified_union(left.values), self.right, + ignore_promotions=self.ignore_promotions): return True - return is_proper_subtype(left.upper_bound, self.right) + return self._is_proper_subtype(left.upper_bound, self.right) def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): - return is_callable_compatible(left, right, is_compat=is_proper_subtype) + return is_callable_compatible(left, right, is_compat=self._is_proper_subtype) elif isinstance(right, Overloaded): - return all(is_proper_subtype(left, item) + return all(self._is_proper_subtype(left, item) for item in right.items()) elif isinstance(right, Instance): - return is_proper_subtype(left.fallback, right) + return self._is_proper_subtype(left.fallback, right) elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and is_proper_subtype(left.ret_type, right.item) + return left.is_type_obj() and self._is_proper_subtype(left.ret_type, right.item) return False def visit_tuple_type(self, left: TupleType) -> bool: @@ -1076,15 +1128,15 @@ def visit_tuple_type(self, left: TupleType) -> bool: # TODO: We shouldn't need this special case. This is currently needed # for isinstance(x, tuple), though it's unclear why. return True - return all(is_proper_subtype(li, iter_type) for li in left.items) - return is_proper_subtype(left.fallback, right) + return all(self._is_proper_subtype(li, iter_type) for li in left.items) + return self._is_proper_subtype(left.fallback, right) elif isinstance(right, TupleType): if len(left.items) != len(right.items): return False for l, r in zip(left.items, right.items): - if not is_proper_subtype(l, r): + if not self._is_proper_subtype(l, r): return False - return is_proper_subtype(left.fallback, right.fallback) + return self._is_proper_subtype(left.fallback, right.fallback) return False def visit_typeddict_type(self, left: TypedDictType) -> bool: @@ -1097,14 +1149,14 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: if name not in left.items: return False return True - return is_proper_subtype(left.fallback, right) + return self._is_proper_subtype(left.fallback, right) def visit_overloaded(self, left: Overloaded) -> bool: # TODO: What's the right thing to do here? return False def visit_union_type(self, left: UnionType) -> bool: - return all([is_proper_subtype(item, self.right) for item in left.items]) + return all([self._is_proper_subtype(item, self.right) for item in left.items]) def visit_partial_type(self, left: PartialType) -> bool: # TODO: What's the right thing to do here? @@ -1115,10 +1167,10 @@ def visit_type_type(self, left: TypeType) -> bool: right = self.right if isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return is_proper_subtype(left.item, right.item) + return self._is_proper_subtype(left.item, right.item) if isinstance(right, CallableType): # This is also unsound because of __init__. - return right.is_type_obj() and is_proper_subtype(left.item, right.ret_type) + return right.is_type_obj() and self._is_proper_subtype(left.item, right.ret_type) if isinstance(right, Instance): if right.type.fullname() == 'builtins.type': # TODO: Strictly speaking, the type builtins.type is considered equivalent to @@ -1131,7 +1183,7 @@ def visit_type_type(self, left: TypeType) -> bool: return False -def is_more_precise(left: Type, right: Type) -> bool: +def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Check if left is a more precise type than right. A left is a proper subtype of right, left is also more precise than @@ -1141,4 +1193,4 @@ def is_more_precise(left: Type, right: Type) -> bool: # TODO Should List[int] be more precise than List[Any]? if isinstance(right, AnyType): return True - return is_proper_subtype(left, right) + return is_proper_subtype(left, right, ignore_promotions=ignore_promotions) diff --git a/mypy/typestate.py b/mypy/typestate.py index 337aac21d714..3c3515c7a7e5 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -3,7 +3,7 @@ and potentially other mutable TypeInfo state. This module contains mutable global state. """ -from typing import Dict, Set, Tuple, Optional +from typing import Any, Dict, Set, Tuple, Optional MYPY = False if MYPY: @@ -12,6 +12,17 @@ from mypy.types import Instance from mypy.server.trigger import make_trigger +# Represents that the 'left' instance is a subtype of the 'right' instance +SubtypeRelationship = Tuple[Instance, Instance] + +# A tuple encoding the specific conditions under which we performed the subtype check. +# (e.g. did we want a proper subtype? A regular subtype while ignoring variance?) +SubtypeKind = Tuple[Any, ...] + +# A cache that keeps track of whether the given TypeInfo is a part of a particular +# subtype relationship +SubtypeCache = Dict[TypeInfo, Dict[SubtypeKind, Set[SubtypeRelationship]]] + class TypeState: """This class provides subtype caching to improve performance of subtype checks. @@ -23,13 +34,11 @@ class TypeState: The protocol dependencies however are only stored here, and shouldn't be deleted unless not needed any more (e.g. during daemon shutdown). """ - # 'caches' and 'caches_proper' are subtype caches, implemented as sets of pairs - # of (subtype, supertype), where supertypes are instances of given TypeInfo. + # '_subtype_caches' keeps track of (subtype, supertype) pairs where supertypes are + # instances of the given TypeInfo. The cache also keeps track of the specific + # *kind* of subtyping relationship, which we represent as an arbitrary hashable tuple. # We need the caches, since subtype checks for structural types are very slow. - # _subtype_caches_proper is for caching proper subtype checks (i.e. not assuming that - # Any is consistent with every type). - _subtype_caches = {} # type: ClassVar[Dict[TypeInfo, Set[Tuple[Instance, Instance]]]] - _subtype_caches_proper = {} # type: ClassVar[Dict[TypeInfo, Set[Tuple[Instance, Instance]]]] + _subtype_caches = {} # type: ClassVar[SubtypeCache] # This contains protocol dependencies generated after running a full build, # or after an update. These dependencies are special because: @@ -70,13 +79,11 @@ class TypeState: def reset_all_subtype_caches(cls) -> None: """Completely reset all known subtype caches.""" cls._subtype_caches = {} - cls._subtype_caches_proper = {} @classmethod def reset_subtype_caches_for(cls, info: TypeInfo) -> None: """Reset subtype caches (if any) for a given supertype TypeInfo.""" - cls._subtype_caches.setdefault(info, set()).clear() - cls._subtype_caches_proper.setdefault(info, set()).clear() + cls._subtype_caches.setdefault(info, dict()).clear() @classmethod def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: @@ -85,20 +92,15 @@ def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: cls.reset_subtype_caches_for(item) @classmethod - def is_cached_subtype_check(cls, left: Instance, right: Instance) -> bool: - return (left, right) in cls._subtype_caches.setdefault(right.type, set()) - - @classmethod - def is_cached_proper_subtype_check(cls, left: Instance, right: Instance) -> bool: - return (left, right) in cls._subtype_caches_proper.setdefault(right.type, set()) - - @classmethod - def record_subtype_cache_entry(cls, left: Instance, right: Instance) -> None: - cls._subtype_caches.setdefault(right.type, set()).add((left, right)) + def is_cached_subtype_check(cls, kind: SubtypeKind, left: Instance, right: Instance) -> bool: + subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) + return (left, right) in subtype_kinds.setdefault(kind, set()) @classmethod - def record_proper_subtype_cache_entry(cls, left: Instance, right: Instance) -> None: - cls._subtype_caches_proper.setdefault(right.type, set()).add((left, right)) + def record_subtype_cache_entry(cls, kind: SubtypeKind, + left: Instance, right: Instance) -> None: + subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) + subtype_kinds.setdefault(kind, set()).add((left, right)) @classmethod def reset_protocol_deps(cls) -> None: