diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2f052522e5c1..2757d8d60a60 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1988,17 +1988,25 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # testCustomEqCheckStrictEquality for an example. if self.msg.errors.total_errors() == err_count and operator in ('==', '!='): right_type = self.accept(right) + # We suppress the error if there is a custom __eq__() method on either + # side. User defined (or even standard library) classes can define this + # to return True for comparisons between non-overlapping types. if (not custom_equality_method(left_type) and not custom_equality_method(right_type)): - # We suppress the error if there is a custom __eq__() method on either - # side. User defined (or even standard library) classes can define this - # to return True for comparisons between non-overlapping types. + # Also flag non-overlapping literals in situations like: + # x: Literal['a', 'b'] + # if x == 'c': + # ... + left_type = try_getting_literal(left_type) + right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): self.msg.dangerous_comparison(left_type, right_type, 'equality', e) elif operator == 'is' or operator == 'is not': right_type = self.accept(right) # validate the right operand sub_result = self.bool_type() + left_type = try_getting_literal(left_type) + right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): self.msg.dangerous_comparison(left_type, right_type, 'identity', e) method_type = None @@ -4017,6 +4025,13 @@ def is_literal_type_like(t: Optional[Type]) -> bool: return False +def try_getting_literal(typ: Type) -> Type: + """If possible, get a more precise literal type for a given type.""" + if isinstance(typ, Instance) and typ.last_known_value is not None: + return typ.last_known_value + return typ + + def is_expr_literal_type(node: Expression) -> bool: """Returns 'true' if the given node is a Literal""" valid = ('typing.Literal', 'typing_extensions.Literal') diff --git a/test-data/unit/check-columns.test b/test-data/unit/check-columns.test index 1ed2b051f745..877e6114d2eb 100644 --- a/test-data/unit/check-columns.test +++ b/test-data/unit/check-columns.test @@ -226,7 +226,7 @@ if int(): [case testColumnNonOverlappingEqualityCheck] # flags: --strict-equality -if 1 == '': # E:4: Non-overlapping equality check (left operand type: "int", right operand type: "str") +if 1 == '': # E:4: Non-overlapping equality check (left operand type: "Literal[1]", right operand type: "Literal['']") pass [builtins fixtures/bool.pyi] diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 69614287c1fc..65cc863f4b97 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2029,11 +2029,11 @@ class B: ... a: Union[int, str] b: Union[A, B] -a == 42 -b == 42 # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int") +a == int() +b == int() # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int") -a is 42 -b is 42 # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int") +a is int() +b is int() # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int") ca: Union[Container[int], Container[str]] cb: Union[Container[A], Container[B]] @@ -2061,7 +2061,7 @@ x in b'abc' [case testStrictEqualityNoPromotePy3] # flags: --strict-equality -'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes") +'a' == b'a' # E: Non-overlapping equality check (left operand type: "Literal['a']", right operand type: "Literal[b'a']") b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str") x: str @@ -2271,6 +2271,22 @@ def f(x: T) -> T: return x [builtins fixtures/bool.pyi] +[case testStrictEqualityWithALiteral] +# flags: --strict-equality +from typing_extensions import Literal, Final + +def returns_a_or_b() -> Literal['a', 'b']: + ... +def returns_1_or_2() -> Literal[1, 2]: + ... +THREE: Final = 3 + +if returns_a_or_b() == 'c': # E: Non-overlapping equality check (left operand type: "Union[Literal['a'], Literal['b']]", right operand type: "Literal['c']") + ... +if returns_1_or_2() is THREE: # E: Non-overlapping identity check (left operand type: "Union[Literal[1], Literal[2]]", right operand type: "Literal[3]") + ... +[builtins fixtures/bool.pyi] + [case testUnimportedHintAny] def f(x: Any) -> None: # E: Name 'Any' is not defined \ # N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any") diff --git a/test-data/unit/check-flags.test b/test-data/unit/check-flags.test index d49c428503cd..7e8cf53c8c89 100644 --- a/test-data/unit/check-flags.test +++ b/test-data/unit/check-flags.test @@ -1125,7 +1125,7 @@ def f(c: A) -> None: # E: Missing type parameters for generic type "A" [case testStrictEqualityPerFile] # flags: --config-file tmp/mypy.ini import b -42 == 'no' # E: Non-overlapping equality check (left operand type: "int", right operand type: "str") +42 == 'no' # E: Non-overlapping equality check (left operand type: "Literal[42]", right operand type: "Literal['no']") [file b.py] 42 == 'no' [file mypy.ini]