diff --git a/thrift/lib/python/types.pyx b/thrift/lib/python/types.pyx index b728318309b..e49189a52d3 100644 --- a/thrift/lib/python/types.pyx +++ b/thrift/lib/python/types.pyx @@ -25,6 +25,7 @@ from cpython.object cimport Py_LT, Py_EQ, PyCallable_Check from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM, PyTuple_GET_ITEM, PyTuple_Check from cpython.set cimport PyFrozenSet_New, PySet_Add from cpython.ref cimport Py_INCREF, Py_DECREF +from cpython.long cimport PyLong_AsLong from cpython.unicode cimport PyUnicode_AsUTF8String, PyUnicode_FromEncodedObject from cython.operator cimport dereference as deref @@ -2292,12 +2293,16 @@ class EnumMeta(type): for name, value in dct.items(): if not isinstance(value, int): attrs[name] = value + klass = super().__new__( metacls, classname, bases, attrs, ) + if int in bases: + type.__setattr__(klass, "__eq__", Enum.__int__eq__) + for name, value in dct.items(): if not isinstance(value, int): continue @@ -2361,6 +2366,23 @@ cdef inline _fbthrift_enum_equivalent(a, b): cdef int b_types = b_module.rfind(".") return a_module[:a_types] == b_module[:b_types] +cdef inline bint _enum_eq_(self, other): + if isinstance(other, Enum): + if self is other: + return True + # handle py3 vs python comparison + return ( + self._fbthrift_value_ == other._fbthrift_value_ and + _fbthrift_enum_equivalent(self, other) + ) + if cFollyIsDebug and isinstance(other, (bool, float)): + warnings.warn( + f"Did you really mean to compare {type(self)} and {type(other)}?", + RuntimeWarning, + stacklevel=1 + ) + return self._fbthrift_value_ == other + class Enum(metaclass=EnumMeta): def __init__(self, _): # pass on purpose to keep the __init__ interface consistent with the other base class (i.e. int) @@ -2391,21 +2413,7 @@ class Enum(metaclass=EnumMeta): return type(self), (self.value,) def __eq__(self, other): - if isinstance(other, Enum): - if self is other: - return True - # handle py3 vs python comparison - return ( - self._fbthrift_value_ == other._fbthrift_value_ and - _fbthrift_enum_equivalent(self, other) - ) - if cFollyIsDebug and isinstance(other, (bool, float)): - warnings.warn( - f"Did you really mean to compare {type(self)} and {type(other)}?", - RuntimeWarning, - stacklevel=1 - ) - return self._fbthrift_value_ == other + return _enum_eq_(self, other) # thrift-python enums have int base, so have to define # __ne__ to avoid __ne__ based on int value alone @@ -2430,6 +2438,11 @@ class Enum(metaclass=EnumMeta): def __bool__(self): return True + def __int__eq__(self, other): + if type(self) is type(other): + return PyLong_AsLong(self) == PyLong_AsLong(other) + return _enum_eq_(self, other) + class Flag(Enum): @classmethod