Skip to content

Commit

Permalink
Comparing a proto message with an object of unknown returns NotImplem…
Browse files Browse the repository at this point in the history
…ented

The Python comparison protocol requires that if an object doesn't know how to
compare itself to an object of a different type, it returns NotImplemented
rather than False. The interpreter will then try performing the comparison using
the other operand. This translates, for protos, to:
If a proto message doesn't know how to compare itself to an object of
non-message type, it returns NotImplemented. This way, the interpreter will then
try performing the comparison using the comparison methods of the other object,
which may know how to compare itself to a message. If not, then Python will
return the combined result (e.g., if both objects don't know how to perform
__eq__, then the equality operator `==` return false).
This change allows one to compare a proto with custom matchers such as mock.ANY
that the message doesn't know how to compare to, regardless of whether
mock.ANY is on the right-hand side or left-hand side of the equality (prior to
this change, it only worked with mock.ANY on the left-hand side).

Fixes #9173

PiperOrigin-RevId: 561728156
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Aug 31, 2023
1 parent 14222b3 commit 12d4f41
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 18 deletions.
37 changes: 37 additions & 0 deletions python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import pydoc
import sys
import unittest
from unittest import mock
import warnings

cmp = lambda x, y: (x > y) - (x < y)
Expand Down Expand Up @@ -1268,6 +1269,42 @@ def testReturningType(self, message_module):
self.assertEqual(bool, type(m.repeated_bool[0]))
self.assertEqual(True, m.repeated_bool[0])

def testEquality(self, message_module):
m = message_module.TestAllTypes()
m2 = message_module.TestAllTypes()
self.assertEqual(m, m)
self.assertEqual(m, m2)
self.assertEqual(m2, m)

different_m = message_module.TestAllTypes()
different_m.repeated_float.append(1)
self.assertNotEqual(m, different_m)
self.assertNotEqual(different_m, m)

self.assertIsNotNone(m)
self.assertIsNotNone(m)
self.assertNotEqual(42, m)
self.assertNotEqual(m, 42)
self.assertNotEqual('foo', m)
self.assertNotEqual(m, 'foo')

self.assertEqual(mock.ANY, m)
self.assertEqual(m, mock.ANY)

class ComparesWithFoo(object):

def __eq__(self, other):
if getattr(other, 'optional_string', 'not_foo') == 'foo':
return True
return NotImplemented

m.optional_string = 'foo'
self.assertEqual(m, ComparesWithFoo())
self.assertEqual(ComparesWithFoo(), m)
m.optional_string = 'bar'
self.assertNotEqual(m, ComparesWithFoo())
self.assertNotEqual(ComparesWithFoo(), m)


# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
Expand Down
2 changes: 1 addition & 1 deletion python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ def _AddEqualsMethod(message_descriptor, cls):
def __eq__(self, other):
if (not isinstance(other, message_mod.Message) or
other.DESCRIPTOR != self.DESCRIPTOR):
return False
return NotImplemented

if self is other:
return True
Expand Down
35 changes: 18 additions & 17 deletions python/google/protobuf/pyext/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2045,25 +2045,26 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
bool equals = true;
// If other is not a message, it cannot be equal.
// If other is not a message, this implementation doesn't know how to perform
// comparisons.
if (!PyObject_TypeCheck(other, CMessage_Type)) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
// Otherwise, we have a CMessage whose message we can inspect.
bool equals = true;
const google::protobuf::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
self->message->GetDescriptor() != other_message->GetDescriptor()) {
equals = false;
}
// Check the message contents.
if (equals &&
!google::protobuf::util::MessageDifferencer::Equals(
*self->message, *reinterpret_cast<CMessage*>(other)->message)) {
equals = false;
} else {
// Otherwise, we have a CMessage whose message we can inspect.
const google::protobuf::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
self->message->GetDescriptor() != other_message->GetDescriptor()) {
equals = false;
}
// Check the message contents.
if (equals &&
!google::protobuf::util::MessageDifferencer::Equals(
*self->message, *reinterpret_cast<CMessage*>(other)->message)) {
equals = false;
}
}

if (equals ^ (opid == Py_EQ)) {
Expand Down

0 comments on commit 12d4f41

Please sign in to comment.