diff --git a/mypy/binder.py b/mypy/binder.py index 1c711ce9c631..44533cc860a1 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -203,7 +203,7 @@ def update_from_options(self, frames: List[Frame]) -> bool: for other in resulting_values[1:]: assert other is not None # Ignore the error about using get_proper_type(). - if not isinstance(other, TypeGuardType): # type: ignore[misc] + if not contains_type_guard(other): type = join_simple(self.declarations[key], type, other) if current_value is None or not is_same_type(type, current_value): self._put(key, type) @@ -431,3 +431,13 @@ def get_declaration(expr: BindableExpression) -> Optional[Type]: if not isinstance(type, PartialType): return type return None + + +def contains_type_guard(other: Type) -> bool: + # Ignore the error about using get_proper_type(). + if isinstance(other, TypeGuardType): # type: ignore[misc] + return True + other = get_proper_type(other) + if isinstance(other, UnionType): + return any(contains_type_guard(item) for item in other.relevant_items()) + return False diff --git a/mypy/checker.py b/mypy/checker.py index ba020f5d97d5..3c506d081dbf 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4167,6 +4167,11 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM self.fail("Type guard requires positional argument", node) return {}, {} if literal(expr) == LITERAL_TYPE: + # Note: we wrap the target type, so that we can special case later. + # Namely, for isinstance() we use a normal meet, while TypeGuard is + # considered "always right" (i.e. even if the types are not overlapping). + # Also note that a care must be taken to unwrap this back at read places + # where we use this to narrow down declared type. return {expr: TypeGuardType(node.callee.type_guard)}, {} elif isinstance(node, ComparisonExpr): # Step 1: Obtain the types of each operand and whether or not we can diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5aef881aa2ff..f1e31dd070ba 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4179,10 +4179,6 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type, """ if literal(expr) >= LITERAL_TYPE: restriction = self.chk.binder.get(expr) - # Ignore the error about using get_proper_type(). - if isinstance(restriction, TypeGuardType): # type: ignore[misc] - # A type guard forces the new type even if it doesn't overlap the old. - return restriction.type_guard # If the current node is deferred, some variables may get Any types that they # otherwise wouldn't have. We don't want to narrow down these since it may # produce invalid inferred Optional[Any] types, at least. diff --git a/mypy/meet.py b/mypy/meet.py index 70d75a2570bf..465b25dd7c0d 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -56,7 +56,11 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: if declared == narrowed: return declared - if isinstance(declared, UnionType): + # Ignore the error about using get_proper_type(). + if isinstance(narrowed, TypeGuardType): # type: ignore[misc] + # A type guard forces the new type even if it doesn't overlap the old. + return narrowed.type_guard + elif isinstance(declared, UnionType): return make_simplified_union([narrow_declared_type(x, narrowed) for x in declared.relevant_items()]) elif not is_overlapping_types(declared, narrowed, @@ -157,6 +161,11 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: if isinstance(left, PartialType) or isinstance(right, PartialType): assert False, "Unexpectedly encountered partial type" + # Ignore the error about using get_proper_type(). + if isinstance(left, TypeGuardType) or isinstance(right, TypeGuardType): # type: ignore[misc] + # A type guard forces the new type even if it doesn't overlap the old. + return True + # We should also never encounter these types, but it's possible a few # have snuck through due to unrelated bugs. For now, we handle these # in the same way we handle 'Any'. diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index 0c3456184794..c4f88ca3f018 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -370,3 +370,47 @@ if guard(a): reveal_type(a) # N: Revealed type is "__main__.A" reveal_type(a) # N: Revealed type is "__main__.A" [builtins fixtures/tuple.pyi] + +[case testTypeGuardNestedRestrictionAny] +from typing_extensions import TypeGuard +from typing import Any + +class A: ... +def f(x: object) -> TypeGuard[A]: ... +def g(x: object) -> None: ... + +def test(x: Any) -> None: + if not(f(x) or x): + return + g(reveal_type(x)) # N: Revealed type is "Union[__main__.A, Any]" +[builtins fixtures/tuple.pyi] + +[case testTypeGuardNestedRestrictionUnionOther] +from typing_extensions import TypeGuard +from typing import Any + +class A: ... +class B: ... +def f(x: object) -> TypeGuard[A]: ... +def f2(x: object) -> TypeGuard[B]: ... +def g(x: object) -> None: ... + +def test(x: object) -> None: + if not(f(x) or f2(x)): + return + g(reveal_type(x)) # N: Revealed type is "Union[__main__.A, __main__.B]" +[builtins fixtures/tuple.pyi] + +[case testTypeGuardNestedRestrictionUnionIsInstance] +from typing_extensions import TypeGuard +from typing import Any, List + +class A: ... +def f(x: List[object]) -> TypeGuard[List[str]]: ... +def g(x: object) -> None: ... + +def test(x: List[object]) -> None: + if not(f(x) or isinstance(x, A)): + return + g(reveal_type(x)) # N: Revealed type is "Union[builtins.list[builtins.str], __main__.]" +[builtins fixtures/tuple.pyi]