From 8e909e438e39b3bb4dc488dd8d78f0e721e81ae7 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Fri, 21 May 2021 06:42:48 -0700 Subject: [PATCH] Fix crash on TypeGuard plus "and" (#10496) In python/typeshed#5473, I tried to switch a number of `inspect` functions to use the new `TypeGuard` functionality. Unfortunately, mypy-primer found a number of crashes in third-party libraries in places where a TypeGuard function was ANDed together with some other check. Examples: - https://github.com/sphinx-doc/sphinx/blob/4.x/sphinx/util/inspect.py#L252 - https://github.com/sphinx-doc/sphinx/blob/4.x/sphinx/ext/coverage.py#L212 - https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/elements/doc_string.py#L105 The problems trace back to the decision in #9865 to make TypeGuardType not inherit from ProperType: in various conditions that are more complicated than a simple `if` check, mypy wants everything to become a ProperType. Therefore, to fix the crashes I had to make TypeGuardType a ProperType and support it in various visitors. --- mypy/constraints.py | 5 ++++- mypy/erasetype.py | 5 ++++- mypy/expandtype.py | 5 ++++- mypy/fixup.py | 5 ++++- mypy/indirection.py | 3 +++ mypy/join.py | 5 ++++- mypy/meet.py | 5 ++++- mypy/sametypes.py | 9 ++++++++- mypy/server/astdiff.py | 5 ++++- mypy/server/astmerge.py | 5 ++++- mypy/server/deps.py | 6 +++++- mypy/subtypes.py | 13 ++++++++++++- mypy/type_visitor.py | 12 +++++++++++- mypy/typeanal.py | 5 ++++- mypy/types.py | 20 +++++++++++++------- mypy/typetraverser.py | 5 ++++- test-data/unit/check-typeguard.test | 21 +++++++++++++++++++++ 17 files changed, 113 insertions(+), 21 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 7e1de292abec..074f038a30bc 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -7,7 +7,7 @@ CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, - ProperType, get_proper_type, TypeAliasType + ProperType, get_proper_type, TypeAliasType, TypeGuardType ) from mypy.maptype import map_instance_to_supertype import mypy.subtypes @@ -534,6 +534,9 @@ def visit_union_type(self, template: UnionType) -> List[Constraint]: def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]: assert False, "This should be never called, got {}".format(template) + def visit_type_guard_type(self, template: TypeGuardType) -> List[Constraint]: + assert False, "This should be never called, got {}".format(template) + def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Constraint]: res = [] # type: List[Constraint] for t in types: diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 7a56eceacf5f..70b7c3b6de32 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -4,7 +4,7 @@ Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, - get_proper_type, TypeAliasType + get_proper_type, TypeAliasType, TypeGuardType ) from mypy.nodes import ARG_STAR, ARG_STAR2 @@ -90,6 +90,9 @@ def visit_union_type(self, t: UnionType) -> ProperType: from mypy.typeops import make_simplified_union return make_simplified_union(erased_items) + def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: + return TypeGuardType(t.type_guard.accept(self)) + def visit_type_type(self, t: TypeType) -> ProperType: return TypeType.make_normalized(t.item.accept(self), line=t.line) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index f98e0750743b..c9a1a2430afb 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -1,7 +1,7 @@ from typing import Dict, Iterable, List, TypeVar, Mapping, cast from mypy.types import ( - Type, Instance, CallableType, TypeVisitor, UnboundType, AnyType, + Type, Instance, CallableType, TypeGuardType, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Overloaded, TupleType, TypedDictType, UnionType, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, FunctionLike, TypeVarDef, LiteralType, get_proper_type, ProperType, @@ -126,6 +126,9 @@ def visit_union_type(self, t: UnionType) -> Type: from mypy.typeops import make_simplified_union # asdf return make_simplified_union(self.expand_types(t.items), t.line, t.column) + def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: + return TypeGuardType(t.type_guard.accept(self)) + def visit_partial_type(self, t: PartialType) -> Type: return t diff --git a/mypy/fixup.py b/mypy/fixup.py index b90dba971e4f..f995ad36f0f6 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -9,7 +9,7 @@ TypeVarExpr, ClassDef, Block, TypeAlias, ) from mypy.types import ( - CallableType, Instance, Overloaded, TupleType, TypedDictType, + CallableType, Instance, Overloaded, TupleType, TypeGuardType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, TypeVarDef ) @@ -254,6 +254,9 @@ def visit_union_type(self, ut: UnionType) -> None: for it in ut.items: it.accept(self) + def visit_type_guard_type(self, t: TypeGuardType) -> None: + t.type_guard.accept(self) + def visit_void(self, o: Any) -> None: pass # Nothing to descend into. diff --git a/mypy/indirection.py b/mypy/indirection.py index 307628c2abc5..aff942ce9393 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -97,6 +97,9 @@ def visit_literal_type(self, t: types.LiteralType) -> Set[str]: def visit_union_type(self, t: types.UnionType) -> Set[str]: return self._visit(t.items) + def visit_type_guard_type(self, t: types.TypeGuardType) -> Set[str]: + return self._visit(t.type_guard) + def visit_partial_type(self, t: types.PartialType) -> Set[str]: return set() diff --git a/mypy/join.py b/mypy/join.py index d4e6051b55af..53a1fce973dc 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -7,7 +7,7 @@ Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType, TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type, - ProperType, get_proper_types, TypeAliasType, PlaceholderType + ProperType, get_proper_types, TypeAliasType, PlaceholderType, TypeGuardType ) from mypy.maptype import map_instance_to_supertype from mypy.subtypes import ( @@ -340,6 +340,9 @@ def visit_type_type(self, t: TypeType) -> ProperType: def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: assert False, "This should be never called, got {}".format(t) + def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: + assert False, "This should be never called, got {}".format(t) + def join(self, s: Type, t: Type) -> ProperType: return join_types(s, t) diff --git a/mypy/meet.py b/mypy/meet.py index 6170396517b9..558de6ec92c9 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -8,7 +8,7 @@ Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, - ProperType, get_proper_type, get_proper_types, TypeAliasType + ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardType ) from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype from mypy.erasetype import erase_type @@ -648,6 +648,9 @@ def visit_type_type(self, t: TypeType) -> ProperType: def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: assert False, "This should be never called, got {}".format(t) + def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: + assert False, "This should be never called, got {}".format(t) + def meet(self, s: Type, t: Type) -> ProperType: return meet_types(s, t) diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 024333a13ec8..f599cc2f7b14 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -1,7 +1,7 @@ from typing import Sequence from mypy.types import ( - Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, + Type, TypeGuardType, UnboundType, AnyType, NoneType, TupleType, TypedDictType, UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, ProperType, get_proper_type, TypeAliasType) @@ -10,6 +10,7 @@ def is_same_type(left: Type, right: Type) -> bool: """Is 'left' the same type as 'right'?""" + left = get_proper_type(left) right = get_proper_type(right) @@ -150,6 +151,12 @@ def visit_union_type(self, left: UnionType) -> bool: else: return False + def visit_type_guard_type(self, left: TypeGuardType) -> bool: + if isinstance(self.right, TypeGuardType): + return is_same_type(left.type_guard, self.right.type_guard) + else: + return False + def visit_overloaded(self, left: Overloaded) -> bool: if isinstance(self.right, Overloaded): return is_same_types(left.items(), self.right.items()) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 9893092882b5..f74f3f35c7e1 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -57,7 +57,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' FuncBase, OverloadedFuncDef, FuncItem, MypyFile, UNBOUND_IMPORTED ) from mypy.types import ( - Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, + Type, TypeGuardType, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType ) @@ -335,6 +335,9 @@ def visit_union_type(self, typ: UnionType) -> SnapshotItem: normalized = tuple(sorted(items)) return ('UnionType', normalized) + def visit_type_guard_type(self, typ: TypeGuardType) -> SnapshotItem: + return ('TypeGuardType', snapshot_type(typ.type_guard)) + def visit_overloaded(self, typ: Overloaded) -> SnapshotItem: return ('Overloaded', snapshot_types(typ.items())) diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 1c411886ac7d..8b9726019224 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -59,7 +59,7 @@ Type, SyntheticTypeVisitor, Instance, AnyType, NoneType, CallableType, ErasedType, DeletedType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, Overloaded, TypeVarDef, TypeList, CallableArgument, EllipsisType, StarType, LiteralType, - RawExpressionType, PartialType, PlaceholderType, TypeAliasType + RawExpressionType, PartialType, PlaceholderType, TypeAliasType, TypeGuardType ) from mypy.util import get_prefix, replace_object_state from mypy.typestate import TypeState @@ -389,6 +389,9 @@ def visit_erased_type(self, t: ErasedType) -> None: def visit_deleted_type(self, typ: DeletedType) -> None: pass + def visit_type_guard_type(self, typ: TypeGuardType) -> None: + raise RuntimeError + def visit_partial_type(self, typ: PartialType) -> None: raise RuntimeError diff --git a/mypy/server/deps.py b/mypy/server/deps.py index 78acc1d9e376..9aee82664bd2 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -97,7 +97,8 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType, - TypeAliasType) + TypeAliasType, TypeGuardType +) from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.util import correct_relative_import from mypy.scope import Scope @@ -970,6 +971,9 @@ def visit_unbound_type(self, typ: UnboundType) -> List[str]: def visit_uninhabited_type(self, typ: UninhabitedType) -> List[str]: return [] + def visit_type_guard_type(self, typ: TypeGuardType) -> List[str]: + return typ.type_guard.accept(self) + def visit_union_type(self, typ: UnionType) -> List[str]: triggers = [] for item in typ.items: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index c3b8b82a3c2c..ffcaf8f2bc92 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -4,7 +4,7 @@ from typing_extensions import Final from mypy.types import ( - Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneType, + Type, AnyType, TypeGuardType, UnboundType, TypeVisitor, FormalArgument, NoneType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType @@ -475,6 +475,9 @@ def visit_overloaded(self, left: Overloaded) -> bool: def visit_union_type(self, left: UnionType) -> bool: return all(self._is_subtype(item, self.orig_right) for item in left.items) + def visit_type_guard_type(self, left: TypeGuardType) -> bool: + raise RuntimeError("TypeGuard should not appear here") + def visit_partial_type(self, left: PartialType) -> bool: # This is indeterminate as we don't really know the complete type yet. raise RuntimeError @@ -1374,6 +1377,14 @@ def visit_overloaded(self, left: Overloaded) -> bool: def visit_union_type(self, left: UnionType) -> bool: return all([self._is_proper_subtype(item, self.orig_right) for item in left.items]) + def visit_type_guard_type(self, left: TypeGuardType) -> bool: + if isinstance(self.right, TypeGuardType): + # TypeGuard[bool] is a subtype of TypeGuard[int] + return self._is_proper_subtype(left.type_guard, self.right.type_guard) + else: + # TypeGuards aren't a subtype of anything else for now (but see #10489) + return False + def visit_partial_type(self, left: PartialType) -> bool: # TODO: What's the right thing to do here? return False diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 8a95ceb049af..a0e6299a5a8a 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -19,7 +19,7 @@ T = TypeVar('T') from mypy.types import ( - Type, AnyType, CallableType, Overloaded, TupleType, TypedDictType, LiteralType, + Type, AnyType, CallableType, Overloaded, TupleType, TypeGuardType, TypedDictType, LiteralType, RawExpressionType, Instance, NoneType, TypeType, UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeDef, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, @@ -103,6 +103,10 @@ def visit_type_type(self, t: TypeType) -> T: def visit_type_alias_type(self, t: TypeAliasType) -> T: pass + @abstractmethod + def visit_type_guard_type(self, t: TypeGuardType) -> T: + pass + @trait @mypyc_attr(allow_interpreted_subclasses=True) @@ -220,6 +224,9 @@ def visit_union_type(self, t: UnionType) -> Type: def translate_types(self, types: Iterable[Type]) -> List[Type]: return [t.accept(self) for t in types] + def visit_type_guard_type(self, t: TypeGuardType) -> Type: + return TypeGuardType(t.type_guard.accept(self)) + def translate_variables(self, variables: Sequence[TypeVarLikeDef]) -> Sequence[TypeVarLikeDef]: return variables @@ -319,6 +326,9 @@ def visit_star_type(self, t: StarType) -> T: def visit_union_type(self, t: UnionType) -> T: return self.query_types(t.items) + def visit_type_guard_type(self, t: TypeGuardType) -> T: + return t.type_guard.accept(self) + def visit_overloaded(self, t: Overloaded) -> T: return self.query_types(t.items()) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 42d4dbf61115..d9e7764ba3f8 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -14,7 +14,7 @@ from mypy.types import ( Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType, CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarDef, SyntheticTypeVisitor, - StarType, PartialType, EllipsisType, UninhabitedType, TypeType, + StarType, PartialType, EllipsisType, UninhabitedType, TypeType, TypeGuardType, CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, PlaceholderType, Overloaded, get_proper_type, TypeAliasType, TypeVarLikeDef, ParamSpecDef ) @@ -542,6 +542,9 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: ) return ret + def visit_type_guard_type(self, t: TypeGuardType) -> Type: + return t + def anal_type_guard(self, t: Type) -> Optional[Type]: if isinstance(t, UnboundType): sym = self.lookup_qualified(t.name, t) diff --git a/mypy/types.py b/mypy/types.py index 4587dadfd885..d9c71fbcf7ee 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -270,7 +270,14 @@ def copy_modified(self, *, self.line, self.column) -class TypeGuardType(Type): +class ProperType(Type): + """Not a type alias. + + Every type except TypeAliasType must inherit from this type. + """ + + +class TypeGuardType(ProperType): """Only used by find_instance_check() etc.""" def __init__(self, type_guard: Type): super().__init__(line=type_guard.line, column=type_guard.column) @@ -279,12 +286,8 @@ def __init__(self, type_guard: Type): def __repr__(self) -> str: return "TypeGuard({})".format(self.type_guard) - -class ProperType(Type): - """Not a type alias. - - Every type except TypeAliasType must inherit from this type. - """ + def accept(self, visitor: 'TypeVisitor[T]') -> T: + return visitor.visit_type_guard_type(self) class TypeVarId: @@ -2183,6 +2186,9 @@ def visit_union_type(self, t: UnionType) -> str: s = self.list_str(t.items) return 'Union[{}]'.format(s) + def visit_type_guard_type(self, t: TypeGuardType) -> str: + return 'TypeGuard[{}]'.format(t.type_guard.accept(self)) + def visit_partial_type(self, t: PartialType) -> str: if t.type is None: return '' diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 8d7459f7a551..e8f22a62e7c4 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -6,7 +6,7 @@ Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType, TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, - PlaceholderType, PartialType, RawExpressionType, TypeAliasType + PlaceholderType, PartialType, RawExpressionType, TypeAliasType, TypeGuardType ) @@ -62,6 +62,9 @@ def visit_typeddict_type(self, t: TypedDictType) -> None: def visit_union_type(self, t: UnionType) -> None: self.traverse_types(t.items) + def visit_type_guard_type(self, t: TypeGuardType) -> None: + t.type_guard.accept(self) + def visit_overloaded(self, t: Overloaded) -> None: self.traverse_types(t.items()) diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index 52beb2836485..fa340cb04044 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -294,3 +294,24 @@ class C: class D(C): def is_float(self, a: object) -> bool: pass # E: Signature of "is_float" incompatible with supertype "C" [builtins fixtures/tuple.pyi] + +[case testTypeGuardInAnd] +from typing import Any +from typing_extensions import TypeGuard +import types +def isclass(a: object) -> bool: + pass +def ismethod(a: object) -> TypeGuard[float]: + pass +def isfunction(a: object) -> TypeGuard[str]: + pass +def isclassmethod(obj: Any) -> bool: + if ismethod(obj) and obj.__self__ is not None and isclass(obj.__self__): # E: "float" has no attribute "__self__" + return True + + return False +def coverage(obj: Any) -> bool: + if not (ismethod(obj) or isfunction(obj)): + return True + return False +[builtins fixtures/classmethod.pyi]