From 7c2f5e80de3a08bacd125164178da5f739b1b379 Mon Sep 17 00:00:00 2001 From: jatoben Date: Tue, 25 Jun 2024 22:41:42 -0700 Subject: [PATCH] Don't raise `TypeError` from generated equality method (#4287) * Don't raise TypeError in derived equality method * Add newsfragment --- newsfragments/4287.changed.md | 1 + pyo3-macros-backend/src/pyclass.rs | 10 ++++++--- pytests/src/comparisons.rs | 13 ++++++++++++ pytests/tests/test_comparisons.py | 33 ++++++++++++++++++++++++------ 4 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 newsfragments/4287.changed.md diff --git a/newsfragments/4287.changed.md b/newsfragments/4287.changed.md new file mode 100644 index 00000000000..440e123cebf --- /dev/null +++ b/newsfragments/4287.changed.md @@ -0,0 +1 @@ +Return `NotImplemented` from generated equality method when comparing different types. diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 07d9e32e528..fd85cfa3bb6 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -1844,9 +1844,13 @@ fn pyclass_richcmp( op: #pyo3_path::pyclass::CompareOp ) -> #pyo3_path::PyResult<#pyo3_path::PyObject> { let self_val = self; - let other = &*#pyo3_path::types::PyAnyMethods::downcast::(other)?.borrow(); - match op { - #arms + if let Ok(other) = #pyo3_path::types::PyAnyMethods::downcast::(other) { + let other = &*other.borrow(); + match op { + #arms + } + } else { + ::std::result::Result::Ok(py.NotImplemented()) } } }; diff --git a/pytests/src/comparisons.rs b/pytests/src/comparisons.rs index fa35acf8e1a..5c7f659c9b3 100644 --- a/pytests/src/comparisons.rs +++ b/pytests/src/comparisons.rs @@ -34,6 +34,18 @@ impl EqDefaultNe { } } +#[pyclass(eq)] +#[derive(PartialEq, Eq)] +struct EqDerived(i64); + +#[pymethods] +impl EqDerived { + #[new] + fn new(value: i64) -> Self { + Self(value) + } +} + #[pyclass] struct Ordered(i64); @@ -104,6 +116,7 @@ impl OrderedDefaultNe { pub fn comparisons(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) diff --git a/pytests/tests/test_comparisons.py b/pytests/tests/test_comparisons.py index 508cdeb2465..50bba81cb1a 100644 --- a/pytests/tests/test_comparisons.py +++ b/pytests/tests/test_comparisons.py @@ -1,7 +1,13 @@ from typing import Type, Union import pytest -from pyo3_pytests.comparisons import Eq, EqDefaultNe, Ordered, OrderedDefaultNe +from pyo3_pytests.comparisons import ( + Eq, + EqDefaultNe, + EqDerived, + Ordered, + OrderedDefaultNe, +) from typing_extensions import Self @@ -9,15 +15,23 @@ class PyEq: def __init__(self, x: int) -> None: self.x = x - def __eq__(self, other: Self) -> bool: - return self.x == other.x + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return self.x == other.x + else: + return NotImplemented def __ne__(self, other: Self) -> bool: - return self.x != other.x + if isinstance(other, self.__class__): + return self.x != other.x + else: + return NotImplemented -@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python")) -def test_eq(ty: Type[Union[Eq, PyEq]]): +@pytest.mark.parametrize( + "ty", (Eq, EqDerived, PyEq), ids=("rust", "rust-derived", "python") +) +def test_eq(ty: Type[Union[Eq, EqDerived, PyEq]]): a = ty(0) b = ty(0) c = ty(1) @@ -32,6 +46,13 @@ def test_eq(ty: Type[Union[Eq, PyEq]]): assert b != c assert not (b == c) + assert not a == 0 + assert a != 0 + assert not b == 0 + assert b != 1 + assert not c == 1 + assert c != 1 + with pytest.raises(TypeError): assert a <= b