Skip to content

Commit

Permalink
Merge pull request #3506 from davidhewitt/default-ne
Browse files Browse the repository at this point in the history
Fix bug in default implementation of `__ne__`
  • Loading branch information
davidhewitt authored Oct 11, 2023
2 parents b73c069 + e1d4173 commit b03c4cb
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 22 deletions.
29 changes: 16 additions & 13 deletions guide/src/class/object.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl Number {

In the `__repr__`, we used a hard-coded class name. This is sometimes not ideal,
because if the class is subclassed in Python, we would like the repr to reflect
the subclass name. This is typically done in Python code by accessing
the subclass name. This is typically done in Python code by accessing
`self.__class__.__name__`. In order to be able to access the Python type information
*and* the Rust struct, we need to use a `PyCell` as the `self` argument.

Expand Down Expand Up @@ -149,8 +149,8 @@ impl Number {
### Comparisons
Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`,
`__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`.
PyO3 supports the usual magic comparison methods available in Python such as `__eq__`, `__lt__`
and so on. It is also possible to support all six operations at once with `__richcmp__`.
This method will be called with a value of `CompareOp` depending on the operation.
```rust
Expand Down Expand Up @@ -198,28 +198,31 @@ impl Number {
It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches
the given `CompareOp`.

Alternatively, if you want to leave some operations unimplemented, you can
return `py.NotImplemented()` for some of the operations:
Alternatively, you can implement just equality using `__eq__`:


```rust
use pyo3::class::basic::CompareOp;

# use pyo3::prelude::*;
#
# #[pyclass]
# struct Number(i32);
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => (self.0 == other.0).into_py(py),
CompareOp::Ne => (self.0 != other.0).into_py(py),
_ => py.NotImplemented(),
}
fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}
}

# fn main() -> PyResult<()> {
# Python::with_gil(|py| {
# let x = PyCell::new(py, Number(4))?;
# let y = PyCell::new(py, Number(4))?;
# assert!(x.eq(y)?);
# assert!(!x.ne(y)?);
# Ok(())
# })
# }
```

### Truthyness
Expand Down
30 changes: 26 additions & 4 deletions guide/src/class/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,41 @@ given signatures should be interpreted as follows:
- `__richcmp__(<self>, object, pyo3::basic::CompareOp) -> object`

Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method.
The `CompareOp` argument indicates the comparison operation being performed.
The `CompareOp` argument indicates the comparison operation being performed. You can use
[`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result to the requested comparison.

_This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._

_Note that implementing `__richcmp__` will cause Python not to generate a default `__hash__` implementation, so consider implementing `__hash__` when implementing `__richcmp__`._
<details>
<summary>Return type</summary>
The return type will normally be `PyResult<bool>`, but any Python object can be returned.

If you want to leave some operations unimplemented, you can return `py.NotImplemented()`
for some of the operations:

```rust
use pyo3::class::basic::CompareOp;

# use pyo3::prelude::*;
#
# #[pyclass]
# struct Number(i32);
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => (self.0 == other.0).into_py(py),
CompareOp::Ne => (self.0 != other.0).into_py(py),
_ => py.NotImplemented(),
}
}
}
```

If the second argument `object` is not of the type specified in the
signature, the generated code will automatically `return NotImplemented`.

You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result
to the requested comparison.
</details>

- `__getattr__(<self>, object) -> object`
Expand Down
16 changes: 15 additions & 1 deletion pytests/tests/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
c = ty(1)

assert a == b
assert not (a != b)
assert a != c
assert not (a == c)

assert b == a
assert not (a != b)
assert b != c
assert not (b == c)

with pytest.raises(TypeError):
assert a <= b
Expand All @@ -49,17 +53,21 @@ def __eq__(self, other: Self) -> bool:
return self.x == other.x


@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python"))
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
a = ty(0)
b = ty(0)
c = ty(1)

assert a == b
assert not (a != b)
assert a != c
assert not (a == c)

assert b == a
assert not (a != b)
assert b != c
assert not (b == c)

with pytest.raises(TypeError):
assert a <= b
Expand Down Expand Up @@ -152,19 +160,25 @@ def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]
c = ty(1)

assert a == b
assert not (a != b)
assert a <= b
assert a >= b
assert a != c
assert not (a == c)
assert a <= c

assert b == a
assert not (b != a)
assert b <= a
assert b >= a
assert b != c
assert not (b == c)
assert b <= c

assert c != a
assert not (c == a)
assert c != b
assert not (c == b)
assert c > a
assert c >= a
assert c > b
Expand Down
12 changes: 8 additions & 4 deletions src/impl_/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
internal_tricks::extract_c_string,
pycell::PyCellLayout,
pyclass_init::PyObjectInit,
types::PyBool,
Py, PyAny, PyCell, PyClass, PyErr, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python,
};
use std::{
Expand Down Expand Up @@ -805,11 +806,14 @@ slot_fragment_trait! {
#[inline]
unsafe fn __ne__(
self,
_py: Python<'_>,
_slf: *mut ffi::PyObject,
_other: *mut ffi::PyObject,
py: Python<'_>,
slf: *mut ffi::PyObject,
other: *mut ffi::PyObject,
) -> PyResult<*mut ffi::PyObject> {
Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented()))
// By default `__ne__` will try `__eq__` and invert the result
let slf: &PyAny = py.from_borrowed_ptr(slf);
let other: &PyAny = py.from_borrowed_ptr(other);
slf.eq(other).map(|is_eq| PyBool::new(py, !is_eq).into_ptr())
}
}

Expand Down

0 comments on commit b03c4cb

Please sign in to comment.