Skip to content

Commit

Permalink
Py/PyAny: remove PartialEq impl and add is() (#2183)
Browse files Browse the repository at this point in the history
  • Loading branch information
birkenfeld authored Feb 25, 2022
1 parent 4873459 commit 03dc96b
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 55 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133)
- Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159)
- Add support for `from_py_with` on struct tuples and enums to override the default from-Python conversion. [#2181](https://github.com/PyO3/pyo3/pull/2181)
- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`.
- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`. [#2175](https://github.com/PyO3/pyo3/pull/2175)
- Add `Py::is` and `PyAny::is` methods to check for object identity. [#2183](https://github.com/PyO3/pyo3/pull/2183)

### Changed

Expand Down Expand Up @@ -81,7 +82,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed

- Remove all functionality deprecated in PyO3 0.14. [#2007](https://github.com/PyO3/pyo3/pull/2007)
- Remove `Default` impl for `PyMethodDef` [2166](https://github.com/PyO3/pyo3/pull/2166)
- Remove `Default` impl for `PyMethodDef`. [#2166](https://github.com/PyO3/pyo3/pull/2166)
- Remove `PartialEq` impl for `Py` and `PyAny` (use the new `is()` instead). [#2183](https://github.com/PyO3/pyo3/pull/2183)

### Fixed

Expand Down
13 changes: 13 additions & 0 deletions guide/src/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ impl MyClass {
}
```

### Removed `PartialEq` for object wrappers

The Python object wrappers `Py` and `PyAny` had implementations of `PartialEq`
so that `object_a == object_b` would compare the Python objects for pointer
equality, which corresponds to the `is` operator, not the `==` operator in
Python. This has been removed in favor of a new method: use
`object_a.is(object_b)`. This also has the advantage of not requiring the same
wrapper type for `object_a` and `object_b`; you can now directly compare a
`Py<T>` with a `&PyAny` without having to convert.

To check for Python object equality (the Python `==` operator), use the new
method `eq()`.

### Container magic methods now match Python behavior

In PyO3 0.15, `__getitem__`, `__setitem__` and `__delitem__` in `#[pymethods]` would generate only the _mapping_ implementation for a `#[pyclass]`. To match the Python behavior, these methods now generate both the _mapping_ **and** _sequence_ implementations.
Expand Down
2 changes: 1 addition & 1 deletion src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ mod tests {
Python::with_gil(|py| {
let list = PyList::new(py, &[1, 2, 3]);
let val = unsafe { <PyList as PyTryFrom>::try_from_unchecked(list.as_ref()) };
assert_eq!(list, val);
assert!(list.is(val));
});
}

Expand Down
35 changes: 22 additions & 13 deletions src/err/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl PyErr {
///
/// Python::with_gil(|py| {
/// let err: PyErr = PyTypeError::new_err(("some type error",));
/// assert_eq!(err.get_type(py), PyType::new::<PyTypeError>(py));
/// assert!(err.get_type(py).is(PyType::new::<PyTypeError>(py)));
/// });
/// ```
pub fn get_type<'py>(&'py self, py: Python<'py>) -> &'py PyType {
Expand Down Expand Up @@ -231,7 +231,7 @@ impl PyErr {
///
/// Python::with_gil(|py| {
/// let err = PyTypeError::new_err(("some type error",));
/// assert_eq!(err.traceback(py), None);
/// assert!(err.traceback(py).is_none());
/// });
/// ```
pub fn traceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
Expand Down Expand Up @@ -469,9 +469,12 @@ impl PyErr {
/// Python::with_gil(|py| {
/// let err: PyErr = PyTypeError::new_err(("some type error",));
/// let err_clone = err.clone_ref(py);
/// assert_eq!(err.get_type(py), err_clone.get_type(py));
/// assert_eq!(err.value(py), err_clone.value(py));
/// assert_eq!(err.traceback(py), err_clone.traceback(py));
/// assert!(err.get_type(py).is(err_clone.get_type(py)));
/// assert!(err.value(py).is(err_clone.value(py)));
/// match err.traceback(py) {
/// None => assert!(err_clone.traceback(py).is_none()),
/// Some(tb) => assert!(err_clone.traceback(py).unwrap().is(tb)),
/// }
/// });
/// ```
#[inline]
Expand Down Expand Up @@ -706,7 +709,7 @@ fn exceptions_must_derive_from_base_exception(py: Python) -> PyErr {
mod tests {
use super::PyErrState;
use crate::exceptions;
use crate::{PyErr, Python};
use crate::{AsPyPointer, PyErr, Python};

#[test]
fn no_error() {
Expand Down Expand Up @@ -857,16 +860,22 @@ mod tests {
fn deprecations() {
let err = exceptions::PyValueError::new_err("an error");
Python::with_gil(|py| {
assert_eq!(err.ptype(py), err.get_type(py));
assert_eq!(err.pvalue(py), err.value(py));
assert_eq!(err.instance(py), err.value(py));
assert_eq!(err.ptraceback(py), err.traceback(py));
assert_eq!(err.ptype(py).as_ptr(), err.get_type(py).as_ptr());
assert_eq!(err.pvalue(py).as_ptr(), err.value(py).as_ptr());
assert_eq!(err.instance(py).as_ptr(), err.value(py).as_ptr());
assert_eq!(
err.ptraceback(py).map(|t| t.as_ptr()),
err.traceback(py).map(|t| t.as_ptr())
);

assert_eq!(
err.clone_ref(py).into_instance(py).as_ref(py),
err.value(py)
err.clone_ref(py).into_instance(py).as_ref(py).as_ptr(),
err.value(py).as_ptr()
);
assert_eq!(
PyErr::from_instance(err.value(py)).value(py).as_ptr(),
err.value(py).as_ptr()
);
assert_eq!(PyErr::from_instance(err.value(py)).value(py), err.value(py));
});
}
}
2 changes: 1 addition & 1 deletion src/impl_/extract_argument.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn from_py_with_with_default<'py, T>(
#[doc(hidden)]
#[cold]
pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr {
if error.get_type(py) == PyTypeError::type_object(py) {
if error.get_type(py).is(PyTypeError::type_object(py)) {
let remapped_error =
PyTypeError::new_err(format!("argument '{}': {}", arg_name, error.value(py)));
remapped_error.set_cause(py, error.cause(py));
Expand Down
16 changes: 9 additions & 7 deletions src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,15 @@ where
}

impl<T> Py<T> {
/// Returns whether `self` and `other` point to the same object. To compare
/// the equality of two objects (the `==` operator), use [`eq`](PyAny::eq).
///
/// This is equivalent to the Python expression `self is other`.
#[inline]
pub fn is<U: AsPyPointer>(&self, o: &U) -> bool {
self.as_ptr() == o.as_ptr()
}

/// Gets the reference count of the `ffi::PyObject` pointer.
#[inline]
pub fn get_refcnt(&self, _py: Python) -> isize {
Expand Down Expand Up @@ -829,13 +838,6 @@ where
}
}

impl<T> PartialEq for Py<T> {
#[inline]
fn eq(&self, o: &Py<T>) -> bool {
self.0 == o.0
}
}

/// If the GIL is held this increments `self`'s reference count.
/// Otherwise this registers the [`Py`]`<T>` instance to have its reference count
/// incremented the next time PyO3 acquires the GIL.
Expand Down
9 changes: 9 additions & 0 deletions src/types/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ impl PyAny {
<T as PyTryFrom>::try_from(self)
}

/// Returns whether `self` and `other` point to the same object. To compare
/// the equality of two objects (the `==` operator), use [`eq`](PyAny::eq).
///
/// This is equivalent to the Python expression `self is other`.
#[inline]
pub fn is<T: AsPyPointer>(&self, other: &T) -> bool {
self.as_ptr() == other.as_ptr()
}

/// Determines whether this object has the given attribute.
///
/// This is equivalent to the Python expression `hasattr(self, attr_name)`.
Expand Down
4 changes: 2 additions & 2 deletions src/types/boolobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ mod tests {
assert!(PyBool::new(py, true).is_true());
let t: &PyAny = PyBool::new(py, true).into();
assert!(t.extract::<bool>().unwrap());
assert_eq!(true.to_object(py), PyBool::new(py, true).into());
assert!(true.to_object(py).is(PyBool::new(py, true)));
});
}

Expand All @@ -79,7 +79,7 @@ mod tests {
assert!(!PyBool::new(py, false).is_true());
let t: &PyAny = PyBool::new(py, false).into();
assert!(!t.extract::<bool>().unwrap());
assert_eq!(false.to_object(py), PyBool::new(py, false).into());
assert!(false.to_object(py).is(PyBool::new(py, false)));
});
}
}
8 changes: 4 additions & 4 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ mod tests {
Python::with_gil(|py| {
let dict = [(7, 32)].into_py_dict(py);
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert_eq!(None, dict.get_item(8i32));
assert!(dict.get_item(8i32).is_none());
let map: HashMap<i32, i32> = [(7, 32)].iter().cloned().collect();
assert_eq!(map, dict.extract().unwrap());
let map: BTreeMap<i32, i32> = [(7, 32)].iter().cloned().collect();
Expand Down Expand Up @@ -426,7 +426,7 @@ mod tests {

let ndict = dict.copy().unwrap();
assert_eq!(32, ndict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert_eq!(None, ndict.get_item(8i32));
assert!(ndict.get_item(8i32).is_none());
});
}

Expand Down Expand Up @@ -464,7 +464,7 @@ mod tests {
let ob = v.to_object(py);
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert_eq!(None, dict.get_item(8i32));
assert!(dict.get_item(8i32).is_none());
});
}

Expand Down Expand Up @@ -527,7 +527,7 @@ mod tests {
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();
assert!(dict.del_item(7i32).is_ok());
assert_eq!(0, dict.len());
assert_eq!(None, dict.get_item(7i32));
assert!(dict.get_item(7i32).is_none());
});
}

Expand Down
2 changes: 1 addition & 1 deletion src/types/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def fibonacci(target):
Python::with_gil(|py| {
let obj: Py<PyAny> = vec![10, 20].to_object(py).as_ref(py).iter().unwrap().into();
let iter: &PyIterator = PyIterator::try_from(obj.as_ref(py)).unwrap();
assert_eq!(obj, iter.into());
assert!(obj.is(iter));
});
}

Expand Down
9 changes: 0 additions & 9 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,6 @@ macro_rules! pyobject_native_type_base(
unsafe { $crate::PyObject::from_borrowed_ptr(py, self.as_ptr()) }
}
}

impl<$($generics,)*> ::std::cmp::PartialEq for $name {
#[inline]
fn eq(&self, o: &$name) -> bool {
use $crate::AsPyPointer;

self.as_ptr() == o.as_ptr()
}
}
};
);

Expand Down
4 changes: 2 additions & 2 deletions src/types/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,11 +725,11 @@ mod tests {
let seq = ob.cast_as::<PySequence>(py).unwrap();
let rep_seq = seq.in_place_repeat(3).unwrap();
assert_eq!(6, seq.len().unwrap());
assert_eq!(seq, rep_seq);
assert!(seq.is(rep_seq));

let conc_seq = seq.in_place_concat(seq).unwrap();
assert_eq!(12, seq.len().unwrap());
assert_eq!(seq, conc_seq);
assert!(seq.is(conc_seq));
});
}

Expand Down
6 changes: 3 additions & 3 deletions src/types/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ mod tests {
let data = unsafe { s.data().unwrap() };
assert_eq!(data, PyStringData::Ucs1(b"f\xfe"));
let err = data.to_string(py).unwrap_err();
assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py));
assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py)));
assert!(err
.to_string()
.contains("'utf-8' codec can't decode byte 0xfe in position 1"));
Expand Down Expand Up @@ -546,7 +546,7 @@ mod tests {
let data = unsafe { s.data().unwrap() };
assert_eq!(data, PyStringData::Ucs2(&[0xff22, 0xd800]));
let err = data.to_string(py).unwrap_err();
assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py));
assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py)));
assert!(err
.to_string()
.contains("'utf-16' codec can't decode bytes in position 0-3"));
Expand Down Expand Up @@ -585,7 +585,7 @@ mod tests {
let data = unsafe { s.data().unwrap() };
assert_eq!(data, PyStringData::Ucs4(&[0x20000, 0xd800]));
let err = data.to_string(py).unwrap_err();
assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py));
assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py)));
assert!(err
.to_string()
.contains("'utf-32' codec can't decode bytes in position 0-7"));
Expand Down
10 changes: 6 additions & 4 deletions tests/test_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,12 @@ fn test_generic_list_set() {
let list = PyCell::new(py, GenericList { items: vec![] }).unwrap();

py_run!(py, list, "list.items = [1, 2, 3]");
assert_eq!(
list.borrow().items,
vec![1.to_object(py), 2.to_object(py), 3.to_object(py)]
);
assert!(list
.borrow()
.items
.iter()
.zip(&[1u32, 2, 3])
.all(|(a, b)| a.as_ref(py).eq(&b.into_py(py)).unwrap()));
}

#[pyclass]
Expand Down
10 changes: 6 additions & 4 deletions tests/test_sequence_pyproto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,12 @@ fn test_generic_list_set() {
let list = PyCell::new(py, GenericList { items: vec![] }).unwrap();

py_run!(py, list, "list.items = [1, 2, 3]");
assert_eq!(
list.borrow().items,
vec![1.to_object(py), 2.to_object(py), 3.to_object(py)]
);
assert!(list
.borrow()
.items
.iter()
.zip(&[1u32, 2, 3])
.all(|(a, b)| a.as_ref(py).eq(&b.into_py(py)).unwrap()));
}

#[pyclass]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ mod test_serde {

#[test]
fn test_deserialize() {
let serialized = r#"{"username": "danya", "friends":
let serialized = r#"{"username": "danya", "friends":
[{"username": "friend", "group": {"name": "danya's friends"}, "friends": []}]}"#;
let user: User = serde_json::from_str(serialized).expect("failed to deserialize");

assert_eq!(user.username, "danya");
assert_eq!(user.group, None);
assert!(user.group.is_none());
assert_eq!(user.friends.len(), 1usize);
let friend = user.friends.get(0).unwrap();

Expand Down

0 comments on commit 03dc96b

Please sign in to comment.