Skip to content

Make --strict-equality stricter with literals #7310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,17 +1988,25 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
# testCustomEqCheckStrictEquality for an example.
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
right_type = self.accept(right)
# We suppress the error if there is a custom __eq__() method on either
# side. User defined (or even standard library) classes can define this
# to return True for comparisons between non-overlapping types.
if (not custom_equality_method(left_type) and
not custom_equality_method(right_type)):
# We suppress the error if there is a custom __eq__() method on either
# side. User defined (or even standard library) classes can define this
# to return True for comparisons between non-overlapping types.
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)

elif operator == 'is' or operator == 'is not':
right_type = self.accept(right) # validate the right operand
sub_result = self.bool_type()
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'identity', e)
method_type = None
Expand Down Expand Up @@ -4017,6 +4025,13 @@ def is_literal_type_like(t: Optional[Type]) -> bool:
return False


def try_getting_literal(typ: Type) -> Type:
"""If possible, get a more precise literal type for a given type."""
if isinstance(typ, Instance) and typ.last_known_value is not None:
return typ.last_known_value
return typ


def is_expr_literal_type(node: Expression) -> bool:
"""Returns 'true' if the given node is a Literal"""
valid = ('typing.Literal', 'typing_extensions.Literal')
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-columns.test
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ if int():

[case testColumnNonOverlappingEqualityCheck]
# flags: --strict-equality
if 1 == '': # E:4: Non-overlapping equality check (left operand type: "int", right operand type: "str")
if 1 == '': # E:4: Non-overlapping equality check (left operand type: "Literal[1]", right operand type: "Literal['']")
pass
[builtins fixtures/bool.pyi]

Expand Down
26 changes: 21 additions & 5 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2029,11 +2029,11 @@ class B: ...
a: Union[int, str]
b: Union[A, B]

a == 42
b == 42 # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int")
a == int()
b == int() # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int")

a is 42
b is 42 # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int")
a is int()
b is int() # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int")

ca: Union[Container[int], Container[str]]
cb: Union[Container[A], Container[B]]
Expand Down Expand Up @@ -2061,7 +2061,7 @@ x in b'abc'

[case testStrictEqualityNoPromotePy3]
# flags: --strict-equality
'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes")
'a' == b'a' # E: Non-overlapping equality check (left operand type: "Literal['a']", right operand type: "Literal[b'a']")
b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str")

x: str
Expand Down Expand Up @@ -2271,6 +2271,22 @@ def f(x: T) -> T:
return x
[builtins fixtures/bool.pyi]

[case testStrictEqualityWithALiteral]
# flags: --strict-equality
from typing_extensions import Literal, Final

def returns_a_or_b() -> Literal['a', 'b']:
...
def returns_1_or_2() -> Literal[1, 2]:
...
THREE: Final = 3

if returns_a_or_b() == 'c': # E: Non-overlapping equality check (left operand type: "Union[Literal['a'], Literal['b']]", right operand type: "Literal['c']")
...
if returns_1_or_2() is THREE: # E: Non-overlapping identity check (left operand type: "Union[Literal[1], Literal[2]]", right operand type: "Literal[3]")
...
[builtins fixtures/bool.pyi]

[case testUnimportedHintAny]
def f(x: Any) -> None: # E: Name 'Any' is not defined \
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-flags.test
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ def f(c: A) -> None: # E: Missing type parameters for generic type "A"
[case testStrictEqualityPerFile]
# flags: --config-file tmp/mypy.ini
import b
42 == 'no' # E: Non-overlapping equality check (left operand type: "int", right operand type: "str")
42 == 'no' # E: Non-overlapping equality check (left operand type: "Literal[42]", right operand type: "Literal['no']")
[file b.py]
42 == 'no'
[file mypy.ini]
Expand Down