From f25c3d043f5b2aa55064a3825182e61f72b0c530 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 14 Apr 2019 22:48:39 +0100 Subject: [PATCH 1/7] Allow type promotions with strict equality --- mypy/checkexpr.py | 2 +- test-data/unit/check-expressions.test | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 422a0f5c100e..6e047818a88a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1996,7 +1996,7 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool: if isinstance(left, UnionType) and isinstance(right, UnionType): left = remove_optional(left) right = remove_optional(right) - return not is_overlapping_types(left, right, ignore_promotions=True) + return not is_overlapping_types(left, right, ignore_promotions=False) def get_operator_method(self, op: str) -> str: if op == '/' and self.chk.options.python_version[0] == 2: diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index ee3ee71e1e88..e4c75a76e30f 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2024,7 +2024,7 @@ cb: Union[Container[A], Container[B]] [builtins fixtures/bool.pyi] [typing fixtures/typing-full.pyi] -[case testStrictEqualityNoPromote] +[case testStrictEqualityNoPromotePy3] # flags: --strict-equality 'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes") b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str") @@ -2035,6 +2035,16 @@ x != y # E: Non-overlapping equality check (left operand type: "str", right ope [builtins fixtures/primitives.pyi] [typing fixtures/typing-full.pyi] +[case testStrictEqualityOkPromote] +# flags: --strict-equality +from typing import Container +c: Container[int] + +1 == 1.0 # OK +1.0 in c # OK +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + [case testStrictEqualityAny] # flags: --strict-equality from typing import Any, Container From ee4147d09bf693211255c5537575747d95231028 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 14 Apr 2019 23:15:43 +0100 Subject: [PATCH 2/7] Special case bytes in bytes --- mypy/checkexpr.py | 17 ++++++++++++----- test-data/unit/check-expressions.test | 6 ++++++ test-data/unit/fixtures/primitives.pyi | 5 ++++- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 6e047818a88a..7e762c85c363 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1938,7 +1938,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: self.msg.unsupported_operand_types('in', left_type, right_type, e) # Only show dangerous overlap if there are no other errors. elif (not local_errors.is_errors() and cont_type and - self.dangerous_comparison(left_type, cont_type)): + self.dangerous_comparison(left_type, cont_type, + original_cont_type=right_type)): self.msg.dangerous_comparison(left_type, cont_type, 'container', e) else: self.msg.add_errors(local_errors) @@ -1974,9 +1975,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: assert result is not None return result - def dangerous_comparison(self, left: Type, right: Type) -> bool: + def dangerous_comparison(self, left: Type, right: Type, + original_cont_type: Optional[Type] = None) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. + The original_cont_type is the original container type for 'in' checks + (and None for equality checks). + Rules: * X and None are overlapping even in strict-optional mode. This is to allow 'assert x is not None' for x defined as 'x = None # type: str' in class body @@ -1985,9 +1990,7 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool: non-overlapping, although technically None is overlap, it is most likely an error. * Any overlaps with everything, i.e. always safe. - * Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0 - are errors. This is mostly needed for bytes vs unicode, and - int vs float are added just for consistency. + * Special case: b'abc' in b'cde' is safe. """ if not self.chk.options.strict_equality: return False @@ -1996,6 +1999,10 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool: if isinstance(left, UnionType) and isinstance(right, UnionType): left = remove_optional(left) right = remove_optional(right) + if (isinstance(original_cont_type, Instance) and + original_cont_type.type.fullname() == 'builtins.bytes' and + isinstance(left, Instance) and left.type.fullname() == 'builtins.bytes'): + return False return not is_overlapping_types(left, right, ignore_promotions=False) def get_operator_method(self, op: str) -> str: diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index e4c75a76e30f..b169624a7a05 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2024,6 +2024,12 @@ cb: Union[Container[A], Container[B]] [builtins fixtures/bool.pyi] [typing fixtures/typing-full.pyi] +[case testStrictEqualityBytesSpecial] +# flags: --strict-equality +b'abc' in b'abcde' +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + [case testStrictEqualityNoPromotePy3] # flags: --strict-equality 'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes") diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 796196fa08c6..f2c0cd03acfc 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -23,7 +23,10 @@ class str(Sequence[str]): def __contains__(self, other: object) -> bool: pass def __getitem__(self, item: int) -> str: pass def format(self, *args) -> str: pass -class bytes: pass +class bytes(Sequence[int]): + def __iter__(self) -> Iterator[int]: pass + def __contains__(self, other: object) -> bool: pass + def __getitem__(self, item: int) -> int: pass class bytearray: pass class tuple(Generic[T]): pass class function: pass From 07b689b9229ffcc8854b5627b924cb3ce6940a38 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 14 Apr 2019 23:43:24 +0100 Subject: [PATCH 3/7] Allow types with custom __eq__; add tests and refactor unions --- mypy/checkexpr.py | 27 +++++++++++++++++---- test-data/unit/check-expressions.test | 34 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7e762c85c363..09045c2eaeda 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1952,8 +1952,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # testCustomEqCheckStrictEquality for an example. if self.msg.errors.total_errors() == err_count and operator in ('==', '!='): right_type = self.accept(right) - if self.dangerous_comparison(left_type, right_type): - self.msg.dangerous_comparison(left_type, right_type, 'equality', e) + if (not self.custom_equality_method(left_type) and + not self.custom_equality_method(right_type)): + if self.dangerous_comparison(left_type, right_type): + self.msg.dangerous_comparison(left_type, right_type, 'equality', e) elif operator == 'is' or operator == 'is not': right_type = self.accept(right) # validate the right operand @@ -1975,6 +1977,23 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: assert result is not None return result + def custom_equality_method(self, typ: Type) -> bool: + if isinstance(typ, UnionType): + return any(self.custom_equality_method(t) for t in typ.items) + if isinstance(typ, Instance): + method = typ.type.get_method('__eq__') + if method and method.info: + return not method.info.fullname().startswith('builtins.') + return False + return False + + def has_bytes_component(self, typ: Type) -> bool: + if isinstance(typ, UnionType): + return any(self.has_bytes_component(t) for t in typ.items) + if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes': + return True + return False + def dangerous_comparison(self, left: Type, right: Type, original_cont_type: Optional[Type] = None) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. @@ -1999,9 +2018,7 @@ def dangerous_comparison(self, left: Type, right: Type, if isinstance(left, UnionType) and isinstance(right, UnionType): left = remove_optional(left) right = remove_optional(right) - if (isinstance(original_cont_type, Instance) and - original_cont_type.type.fullname() == 'builtins.bytes' and - isinstance(left, Instance) and left.type.fullname() == 'builtins.bytes'): + if self.has_bytes_component(original_cont_type) and self.has_bytes_component(left): return False return not is_overlapping_types(left, right, ignore_promotions=False) diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index b169624a7a05..72a9876006b0 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2030,6 +2030,16 @@ b'abc' in b'abcde' [builtins fixtures/primitives.pyi] [typing fixtures/typing-full.pyi] +[case testStrictEqualityBytesSpecialUnion] +# flags: --strict-equality +from typing import Union +x: Union[bytes, str] + +b'abc' in x +x in b'abc' +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + [case testStrictEqualityNoPromotePy3] # flags: --strict-equality 'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes") @@ -2102,6 +2112,30 @@ class B: A() == B() # E: Unsupported operand types for == ("A" and "B") [builtins fixtures/bool.pyi] +[case testCustomEqCheckStrictEqualityOKInstance] +# flags: --strict-equality +class A: + def __eq__(self, other: object) -> bool: + ... +class B: + def __eq__(self, other: object) -> bool: + ... + +A() == int() # OK +int() != B() # OK +[builtins fixtures/bool.pyi] + +[case testCustomEqCheckStrictEqualityOKUnion] +# flags: --strict-equality +from typing import Union +class A: + def __eq__(self, other: object) -> bool: + ... + +x: Union[A, str] +x == int() +[builtins fixtures/bool.pyi] + [case testCustomContainsCheckStrictEquality] # flags: --strict-equality class A: From c9b285685a08177e0f94716195cfee1098b52b38 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 14 Apr 2019 23:53:35 +0100 Subject: [PATCH 4/7] Add commnets and docstrings --- mypy/checkexpr.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 09045c2eaeda..d1cbbe39b4fb 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1954,6 +1954,9 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: right_type = self.accept(right) if (not self.custom_equality_method(left_type) and not self.custom_equality_method(right_type)): + # We suppress the error if there is a custom __eq__() method on either + # side. User defined (or even standard library) classes can define this + # to return True for comparisons between non-overlapping types. if self.dangerous_comparison(left_type, right_type): self.msg.dangerous_comparison(left_type, right_type, 'equality', e) @@ -1978,6 +1981,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: return result def custom_equality_method(self, typ: Type) -> bool: + """Does this type have a custom __eq__() method?""" if isinstance(typ, UnionType): return any(self.custom_equality_method(t) for t in typ.items) if isinstance(typ, Instance): @@ -1985,9 +1989,11 @@ def custom_equality_method(self, typ: Type) -> bool: if method and method.info: return not method.info.fullname().startswith('builtins.') return False + # TODO: support other types (see has_member())? return False def has_bytes_component(self, typ: Type) -> bool: + """Is this the builtin bytes type, or a union that contains it?""" if isinstance(typ, UnionType): return any(self.has_bytes_component(t) for t in typ.items) if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes': @@ -2019,6 +2025,8 @@ def dangerous_comparison(self, left: Type, right: Type, left = remove_optional(left) right = remove_optional(right) if self.has_bytes_component(original_cont_type) and self.has_bytes_component(left): + # We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc' + # return True (and we want to show the error only if the check can _never_ be True). return False return not is_overlapping_types(left, right, ignore_promotions=False) From f858c12a97c43ad7b1020a9637a03a5436291938 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 15 Apr 2019 00:12:31 +0100 Subject: [PATCH 5/7] Fix self-check --- mypy/checkexpr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d1cbbe39b4fb..084d8fd168b1 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2024,7 +2024,8 @@ def dangerous_comparison(self, left: Type, right: Type, if isinstance(left, UnionType) and isinstance(right, UnionType): left = remove_optional(left) right = remove_optional(right) - if self.has_bytes_component(original_cont_type) and self.has_bytes_component(left): + if (original_cont_type and self.has_bytes_component(original_cont_type) and + self.has_bytes_component(left)): # We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc' # return True (and we want to show the error only if the check can _never_ be True). return False From a2ee47144e9d9d00704e1272fb525e64580ffe08 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 27 Apr 2019 10:22:54 -0700 Subject: [PATCH 6/7] Address CR first part --- mypy/checkexpr.py | 64 ++++++++++++++++----------- test-data/unit/check-expressions.test | 28 ++++++++++++ 2 files changed, 65 insertions(+), 27 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 0951ebcd6100..adf1680f2945 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1939,7 +1939,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # Only show dangerous overlap if there are no other errors. elif (not local_errors.is_errors() and cont_type and self.dangerous_comparison(left_type, cont_type, - original_cont_type=right_type)): + original_container=right_type)): self.msg.dangerous_comparison(left_type, cont_type, 'container', e) else: self.msg.add_errors(local_errors) @@ -1952,8 +1952,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # testCustomEqCheckStrictEquality for an example. if self.msg.errors.total_errors() == err_count and operator in ('==', '!='): right_type = self.accept(right) - if (not self.custom_equality_method(left_type) and - not self.custom_equality_method(right_type)): + if (not custom_equality_method(left_type) and + not custom_equality_method(right_type)): # We suppress the error if there is a custom __eq__() method on either # side. User defined (or even standard library) classes can define this # to return True for comparisons between non-overlapping types. @@ -1980,31 +1980,11 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: assert result is not None return result - def custom_equality_method(self, typ: Type) -> bool: - """Does this type have a custom __eq__() method?""" - if isinstance(typ, UnionType): - return any(self.custom_equality_method(t) for t in typ.items) - if isinstance(typ, Instance): - method = typ.type.get_method('__eq__') - if method and method.info: - return not method.info.fullname().startswith('builtins.') - return False - # TODO: support other types (see has_member())? - return False - - def has_bytes_component(self, typ: Type) -> bool: - """Is this the builtin bytes type, or a union that contains it?""" - if isinstance(typ, UnionType): - return any(self.has_bytes_component(t) for t in typ.items) - if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes': - return True - return False - def dangerous_comparison(self, left: Type, right: Type, - original_cont_type: Optional[Type] = None) -> bool: + original_container: Optional[Type] = None) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. - The original_cont_type is the original container type for 'in' checks + The original_container is the original container type for 'in' checks (and None for equality checks). Rules: @@ -2024,8 +2004,8 @@ def dangerous_comparison(self, left: Type, right: Type, if isinstance(left, UnionType) and isinstance(right, UnionType): left = remove_optional(left) right = remove_optional(right) - if (original_cont_type and self.has_bytes_component(original_cont_type) and - self.has_bytes_component(left)): + if (original_container and has_bytes_component(original_container) and + has_bytes_component(left)): # We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc' # return True (and we want to show the error only if the check can _never_ be True). return False @@ -3842,3 +3822,33 @@ def is_expr_literal_type(node: Expression) -> bool: underlying = node.node return isinstance(underlying, TypeAlias) and isinstance(underlying.target, LiteralType) return False + + +def custom_equality_method(typ: Type) -> bool: + """Does this type have a custom __eq__() method?""" + if isinstance(typ, Instance): + method = typ.type.get_method('__eq__') + if method and method.info: + return not method.info.fullname().startswith('builtins.') + return False + if isinstance(typ, UnionType): + return any(custom_equality_method(t) for t in typ.items) + if isinstance(typ, TupleType): + return custom_equality_method(tuple_fallback(typ)) + if isinstance(typ, CallableType) and typ.is_type_obj(): + # Look up __eq__ on the metaclass for class objects. + return custom_equality_method(typ.fallback) + if isinstance(typ, AnyType): + # Avoid false positives in uncertain cases. + return True + # TODO: support other types (see ExpressionChecker.has_member())? + return False + + +def has_bytes_component(typ: Type) -> bool: + """Is this the builtin bytes type, or a union that contains it?""" + if isinstance(typ, UnionType): + return any(has_bytes_component(t) for t in typ.items) + if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes': + return True + return False diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 72a9876006b0..001abb546d58 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2136,6 +2136,34 @@ x: Union[A, str] x == int() [builtins fixtures/bool.pyi] +[case testCustomEqCheckStrictEqualityTuple] +# flags: --strict-equality +from typing import NamedTuple + +class Base(NamedTuple): + attr: int + +class Custom(Base): + def __eq__(self, other: object) -> bool: ... + +Base(int()) == int() # E: Non-overlapping equality check (left operand type: "Base", right operand type: "int") +Base(int()) == tuple() +Custom(int()) == int() +[builtins fixtures/bool.pyi] + +[case testCustomEqCheckStrictEqualityMeta] +# flags: --strict-equality +class CustomMeta(type): + def __eq__(self, other: object) -> bool: ... + +class Normal: ... +class Custom(metaclass=CustomMeta): ... + +Normal == int() # E: Non-overlapping equality check (left operand type: "Type[Normal]", right operand type: "int") +Normal == Normal +Custom == int() +[builtins fixtures/bool.pyi] + [case testCustomContainsCheckStrictEquality] # flags: --strict-equality class A: From b5ed8b8956921bf16ec6a67321a20aacc92575e6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 27 Apr 2019 10:47:58 -0700 Subject: [PATCH 7/7] Tweak the docs --- docs/source/command_line.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst index 1aa1924eea9f..fc523b959bcf 100644 --- a/docs/source/command_line.rst +++ b/docs/source/command_line.rst @@ -396,15 +396,17 @@ of the above sections. .. code-block:: python - from typing import Text + from typing import List, Text - text: Text - if b'some bytes' in text: # Error: non-overlapping check! + items: List[int] + if 'some string' in items: # Error: non-overlapping container check! ... - if text != b'other bytes': # Error: non-overlapping check! + + text: Text + if text != b'other bytes': # Error: non-overlapping equality check! ... - assert text is not None # OK, this special case is allowed. + assert text is not None # OK, check against None is allowed as a special case. ``--strict`` This flag mode enables all optional error checking flags. You can see the