diff --git a/mypy/checker.py b/mypy/checker.py index 11499d6b570e..29dd2d2c0130 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4473,14 +4473,14 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map = {} else_map = {} - if_assignment_map, else_assignment_map = self.find_isinstance_check_helper(node.target) + if_assignment_map, else_assignment_map = self.find_isinstance_check(node.target) if if_assignment_map is not None: if_map.update(if_assignment_map) if else_assignment_map is not None: else_map.update(else_assignment_map) - if_condition_map, else_condition_map = self.find_isinstance_check_helper(node.value) + if_condition_map, else_condition_map = self.find_isinstance_check(node.value) if if_condition_map is not None: if_map.update(if_condition_map) @@ -4492,23 +4492,23 @@ def has_no_custom_eq_checks(t: Type) -> bool: (None if else_assignment_map is None or else_condition_map is None else else_map), ) elif isinstance(node, OpExpr) and node.op == 'and': - left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left) - right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right) + left_if_vars, left_else_vars = self.find_isinstance_check(node.left) + right_if_vars, right_else_vars = self.find_isinstance_check(node.right) # (e1 and e2) is true if both e1 and e2 are true, # and false if at least one of e1 and e2 is false. return (and_conditional_maps(left_if_vars, right_if_vars), or_conditional_maps(left_else_vars, right_else_vars)) elif isinstance(node, OpExpr) and node.op == 'or': - left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left) - right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right) + left_if_vars, left_else_vars = self.find_isinstance_check(node.left) + right_if_vars, right_else_vars = self.find_isinstance_check(node.right) # (e1 or e2) is true if at least one of e1 or e2 is true, # and false if both e1 and e2 are false. return (or_conditional_maps(left_if_vars, right_if_vars), and_conditional_maps(left_else_vars, right_else_vars)) elif isinstance(node, UnaryExpr) and node.op == 'not': - left, right = self.find_isinstance_check_helper(node.expr) + left, right = self.find_isinstance_check(node.expr) return right, left # Restrict the type of the variable to True-ish/False-ish in the if and else branches diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7c8415b75fe1..70320728bf88 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -655,6 +655,27 @@ else: reveal_type(y["model"]) # N: Revealed type is "Union[TypedDict('__main__.Model1', {'key': Literal['A']}), TypedDict('__main__.Model2', {'key': Literal['B']})]" [builtins fixtures/primitives.pyi] +[case testNarrowingExprPropagation] +from typing import Union +from typing_extensions import Literal + +class A: + tag: Literal['A'] + +class B: + tag: Literal['B'] + +abo: Union[A, B, None] + +if abo is not None and abo.tag == "A": + reveal_type(abo.tag) # N: Revealed type is "Literal['A']" + reveal_type(abo) # N: Revealed type is "__main__.A" + +if not (abo is None or abo.tag != "B"): + reveal_type(abo.tag) # N: Revealed type is "Literal['B']" + reveal_type(abo) # N: Revealed type is "__main__.B" +[builtins fixtures/primitives.pyi] + [case testNarrowingEqualityFlipFlop] # flags: --warn-unreachable --strict-equality from typing_extensions import Literal, Final