Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1064: Comparisons with __eq__ should not raise TypeError #1072

Merged
merged 18 commits into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- `PyType::as_type_ptr` is no longer `unsafe`. [#1047](https://github.com/PyO3/pyo3/pull/1047)
- Change `PyIterator::from_object` to return `PyResult<PyIterator>` instead of `Result<PyIterator, PyDowncastError>`. [#1051](https://github.com/PyO3/pyo3/pull/1051)
- Implement `Send + Sync` for `PyErr`. `PyErr::new`, `PyErr::from_type`, `PyException::py_err` and `PyException::into` have had these bounds added to their arguments. [#1067](https://github.com/PyO3/pyo3/pull/1067)
- Change `#[pyproto]` to return NotImplemented for operators for which Python can try a reversed operation. [1072](https://github.com/PyO3/pyo3/pull/1072)

### Removed
- Remove `PyString::as_bytes`. [#1023](https://github.com/PyO3/pyo3/pull/1023)
Expand Down
85 changes: 85 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,91 @@ Each method corresponds to Python's `self.attr`, `self.attr = value` and `del se

Determines the "truthyness" of the object.

### Emulating numeric types

The [`PyNumberProtocol`] trait allows [emulate numeric types](https://docs.python.org/3/reference/datamodel.html?highlight=__ipow__#emulating-numeric-types).
mvaled marked this conversation as resolved.
Show resolved Hide resolved

* `fn __add__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __sub__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __mul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __matmul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __truediv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __floordiv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __mod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __divmod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __pow__(lhs: impl FromPyObject, rhs: impl FromPyObject, modulo: Option<impl FromPyObject>) -> PyResult<impl ToPyObject>`
* `fn __lshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __and__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __or__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __xor__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`

These methods are called to implement the binary arithmetic operations (`+`,
`-`, `*`, `@`, `/`, `//`, `%`, `divmod()`, `pow()` and `**`, `<<`, `>>`, `&`,
`^`, and `|`).
mvaled marked this conversation as resolved.
Show resolved Hide resolved

If `rhs` is not of the type specified in the signature, the generated code
will automatically `return NotImplemented`. This is not the case for `lhs`
which must match signature or else raise a TypeError.


The reflected operations are also available:

* `fn __radd__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rsub__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rmul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rmatmul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rtruediv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rfloordiv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rmod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rdivmod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rpow__(lhs: impl FromPyObject, rhs: impl FromPyObject, modulo: Option<impl FromPyObject>) -> PyResult<impl ToPyObject>`
* `fn __rlshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rrshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rand__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __ror__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rxor__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`

The code generated for these methods expect that all arguments match the
signature, or raise a TypeError.


This trait also has support the augmented arithmetic assignments (`+=`, `-=`,
`*=`, `@=`, `/=`, `//=`, `%=`, `**=`, `<<=`, `>>=`, `&=`, `^=`, `|=`):

* `fn __iadd__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __isub__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imul__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imatmul__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __itruediv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ifloordiv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imod__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ipow__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ilshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __irshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __iand__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ior__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ixor__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`

The following implement the unary arithmetic operations (`-`, `+`, `abs()` and
`~`):
mvaled marked this conversation as resolved.
Show resolved Hide resolved

* `fn __neg__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __pos__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __abs__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __invert__(&'p self) -> PyResult<impl ToPyObject>`

Support for coercions:

* `fn __complex__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __int__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __float__(&'p self) -> PyResult<impl ToPyObject>`

Other:

* `fn __index__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __round__(&'p self, ndigits: Option<impl FromPyObject>) -> PyResult<impl ToPyObject>`

### Garbage Collector Integration

If your type owns references to other Python objects, you will need to
Expand Down
4 changes: 1 addition & 3 deletions src/class/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,8 @@ where
{
crate::callback_body!(py, {
let slf = py.from_borrowed_ptr::<crate::PyCell<T>>(slf);
let arg = py.from_borrowed_ptr::<PyAny>(arg);

let arg = extract_or_return_not_implemented!(py, arg);
let op = extract_op(op)?;
let arg = arg.extract()?;

slf.try_borrow()?.__richcmp__(arg, op).convert(py)
})
Expand Down
46 changes: 34 additions & 12 deletions src/class/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ macro_rules! py_binary_num_func {
{
$crate::callback_body!(py, {
let lhs = py.from_borrowed_ptr::<$crate::PyAny>(lhs);
let rhs = py.from_borrowed_ptr::<$crate::PyAny>(rhs);

$class::$f(lhs.extract()?, rhs.extract()?).convert(py)
let rhs = extract_or_return_not_implemented!(py, rhs);
$class::$f(lhs.extract()?, rhs).convert(py)
})
}
Some(wrap::<$class>)
Expand Down Expand Up @@ -138,7 +137,7 @@ macro_rules! py_binary_self_func {
$crate::callback_body!(py, {
let slf_ = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let arg = py.from_borrowed_ptr::<$crate::PyAny>(arg);
call_mut!(slf_, $f, arg).convert(py)?;
call_operator_mut!(py, slf_, $f, arg).convert(py)?;
ffi::Py_INCREF(slf);
Ok(slf)
})
Expand Down Expand Up @@ -222,13 +221,8 @@ macro_rules! py_ternary_num_func {
let arg1 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg1)
.extract()?;
let arg2 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg2)
.extract()?;
let arg3 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg3)
.extract()?;

let arg2 = extract_or_return_not_implemented!(py, arg2);
let arg3 = extract_or_return_not_implemented!(py, arg3);
$class::$f(arg1, arg2, arg3).convert(py)
})
}
Expand Down Expand Up @@ -279,7 +273,7 @@ macro_rules! py_dummy_ternary_self_func {
$crate::callback_body!(py, {
let slf_cell = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1);
call_mut!(slf_cell, $f, arg1).convert(py)?;
call_operator_mut!(py, slf_cell, $f, arg1).convert(py)?;
ffi::Py_INCREF(slf);
Ok(slf)
})
Expand Down Expand Up @@ -375,13 +369,35 @@ macro_rules! py_func_set_del {
}};
}

macro_rules! extract_or_return_not_implemented {
($py: ident, $arg: ident) => {
match $py
.from_borrowed_ptr::<$crate::types::PyAny>($arg)
.extract()
{
Ok(value) => value,
Err(_) => return $py.NotImplemented().convert($py),
}
};
}

macro_rules! _call_impl {
($slf: expr, $fn: ident $(; $args: expr)*) => {
$slf.$fn($($args,)*)
};
($slf: expr, $fn: ident, $raw_arg: expr $(,$raw_args: expr)* $(; $args: expr)*) => {
_call_impl!($slf, $fn $(,$raw_args)* $(;$args)* ;$raw_arg.extract()?)
};
(op $py:ident; $slf: expr, $fn: ident, $raw_arg: expr $(,$raw_args: expr)* $(; $args: expr)*) => {
_call_impl!(
$slf, $fn ;
(match $raw_arg.extract() {
Ok(res) => res,
_=> return Ok($py.NotImplemented().convert($py)?)
})
$(;$args)*
)
}
}

/// Call `slf.try_borrow()?.$fn(...)`
Expand All @@ -397,3 +413,9 @@ macro_rules! call_mut {
_call_impl!($slf.try_borrow_mut()?, $fn $(,$raw_args)* $(;$args)*)
};
}

macro_rules! call_operator_mut {
($py:ident, $slf: expr, $fn: ident $(,$raw_args: expr)* $(; $args: expr)*) => {
_call_impl!(op $py; $slf.try_borrow_mut()?, $fn $(,$raw_args)* $(;$args)*)
};
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ pub mod buffer;
pub mod callback;
pub mod class;
pub mod conversion;
#[macro_use]
#[doc(hidden)]
pub mod derive_utils;
mod err;
Expand Down
182 changes: 182 additions & 0 deletions tests/test_arithmetics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,185 @@ fn rich_comparisons_python_3_type_error() {
py_expect_exception!(py, c2, "c2 >= 1", PyTypeError);
py_expect_exception!(py, c2, "1 >= c2", PyTypeError);
}

mvaled marked this conversation as resolved.
Show resolved Hide resolved
// Checks that binary operations for which the arguments don't match the
// required type, return NotImplemented.
mod return_not_implemented {
use super::*;

#[pyclass]
struct RichComparisonToSelf {}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf {
fn __repr__(&self) -> &'static str {
"RC_Self"
}

fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject {
other.py().None()
}
}

#[pyproto]
impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf {
mvaled marked this conversation as resolved.
Show resolved Hide resolved
fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option<u8>) -> &'p PyAny {
lhs
}
fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}

// Inplace assignments
fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {}
}

fn _test_binary_dunder(dunder: &str) {
let gil = Python::acquire_gil();
let py = gil.python();
let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap();
py_run!(
py,
c2,
&format!(
"class Other: pass\nassert c2.__{}__(Other()) is NotImplemented",
dunder
)
);
}

fn _test_binary_operator(operator: &str, dunder: &str) {
_test_binary_dunder(dunder);

let gil = Python::acquire_gil();
let py = gil.python();
let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap();
py_expect_exception!(
py,
c2,
&format!("class Other: pass\nc2 {} Other()", operator),
PyTypeError
)
}

fn _test_inplace_binary_operator(operator: &str, dunder: &str) {
_test_binary_operator(operator, dunder);
}

#[test]
fn equality() {
_test_binary_dunder("eq");
_test_binary_dunder("ne");
}

#[test]
fn ordering() {
_test_binary_operator("<", "lt");
_test_binary_operator("<=", "le");
_test_binary_operator(">", "gt");
_test_binary_operator(">=", "ge");
}

#[test]
fn bitwise() {
_test_binary_operator("&", "and");
_test_binary_operator("|", "or");
_test_binary_operator("^", "xor");
_test_binary_operator("<<", "lshift");
_test_binary_operator(">>", "rshift");
}

#[test]
fn arith() {
_test_binary_operator("+", "add");
_test_binary_operator("-", "sub");
_test_binary_operator("*", "mul");
_test_binary_operator("@", "matmul");
_test_binary_operator("/", "truediv");
_test_binary_operator("//", "floordiv");
_test_binary_operator("%", "mod");
_test_binary_operator("**", "pow");
}

#[test]
#[ignore]
fn reverse_arith() {
_test_binary_dunder("radd");
_test_binary_dunder("rsub");
_test_binary_dunder("rmul");
_test_binary_dunder("rmatmul");
_test_binary_dunder("rtruediv");
_test_binary_dunder("rfloordiv");
_test_binary_dunder("rmod");
_test_binary_dunder("rpow");
}

#[test]
fn inplace_bitwise() {
_test_inplace_binary_operator("&=", "iand");
_test_inplace_binary_operator("|=", "ior");
_test_inplace_binary_operator("^=", "ixor");
_test_inplace_binary_operator("<<=", "ilshift");
_test_inplace_binary_operator(">>=", "irshift");
}

#[test]
fn inplace_arith() {
_test_inplace_binary_operator("+=", "iadd");
_test_inplace_binary_operator("-=", "isub");
_test_inplace_binary_operator("*=", "imul");
_test_inplace_binary_operator("@=", "imatmul");
_test_inplace_binary_operator("/=", "itruediv");
_test_inplace_binary_operator("//=", "ifloordiv");
_test_inplace_binary_operator("%=", "imod");
_test_inplace_binary_operator("**=", "ipow");
}
}