diff --git a/newsfragments/4317.added.md b/newsfragments/4317.added.md new file mode 100644 index 00000000000..99849236101 --- /dev/null +++ b/newsfragments/4317.added.md @@ -0,0 +1 @@ +Implement `PartialEq` for `Bound<'py, PyInt>` with `u8`, `u16`, `u32`, `u64`, `u128`, `usize`, `i8`, `i16`, `i32`, `i64`, `i128` and `isize`. diff --git a/src/types/num.rs b/src/types/num.rs index 517e769742b..f33d85f4a1c 100644 --- a/src/types/num.rs +++ b/src/types/num.rs @@ -1,4 +1,6 @@ -use crate::{ffi, PyAny}; +use super::any::PyAnyMethods; + +use crate::{ffi, instance::Bound, PyAny}; /// Represents a Python `int` object. /// @@ -17,3 +19,113 @@ pyobject_native_type_core!(PyInt, pyobject_native_static_type_object!(ffi::PyLon /// Deprecated alias for [`PyInt`]. #[deprecated(since = "0.23.0", note = "use `PyInt` instead")] pub type PyLong = PyInt; + +macro_rules! int_compare { + ($rust_type: ty) => { + impl PartialEq<$rust_type> for Bound<'_, PyInt> { + #[inline] + fn eq(&self, other: &$rust_type) -> bool { + if let Ok(value) = self.extract::<$rust_type>() { + value == *other + } else { + false + } + } + } + impl PartialEq> for $rust_type { + #[inline] + fn eq(&self, other: &Bound<'_, PyInt>) -> bool { + if let Ok(value) = other.extract::<$rust_type>() { + value == *self + } else { + false + } + } + } + }; +} + +int_compare!(i8); +int_compare!(u8); +int_compare!(i16); +int_compare!(u16); +int_compare!(i32); +int_compare!(u32); +int_compare!(i64); +int_compare!(u64); +int_compare!(i128); +int_compare!(u128); +int_compare!(isize); +int_compare!(usize); + +#[cfg(test)] +mod tests { + use super::PyInt; + use crate::{types::PyAnyMethods, IntoPy, Python}; + + #[test] + fn test_partial_eq() { + Python::with_gil(|py| { + let v_i8 = 123i8; + let v_u8 = 123i8; + let v_i16 = 123i16; + let v_u16 = 123u16; + let v_i32 = 123i32; + let v_u32 = 123u32; + let v_i64 = 123i64; + let v_u64 = 123u64; + let v_i128 = 123i128; + let v_u128 = 123u128; + let v_isize = 123isize; + let v_usize = 123usize; + let obj = 123_i64.into_py(py).downcast_bound(py).unwrap().clone(); + assert_eq!(v_i8, obj); + assert_eq!(obj, v_i8); + + assert_eq!(v_u8, obj); + assert_eq!(obj, v_u8); + + assert_eq!(v_i16, obj); + assert_eq!(obj, v_i16); + + assert_eq!(v_u16, obj); + assert_eq!(obj, v_u16); + + assert_eq!(v_i32, obj); + assert_eq!(obj, v_i32); + + assert_eq!(v_u32, obj); + assert_eq!(obj, v_u32); + + assert_eq!(v_i64, obj); + assert_eq!(obj, v_i64); + + assert_eq!(v_u64, obj); + assert_eq!(obj, v_u64); + + assert_eq!(v_i128, obj); + assert_eq!(obj, v_i128); + + assert_eq!(v_u128, obj); + assert_eq!(obj, v_u128); + + assert_eq!(v_isize, obj); + assert_eq!(obj, v_isize); + + assert_eq!(v_usize, obj); + assert_eq!(obj, v_usize); + + let big_num = (u8::MAX as u16) + 1; + let big_obj = big_num + .into_py(py) + .into_bound(py) + .downcast_into::() + .unwrap(); + + for x in 0u8..=u8::MAX { + assert_ne!(x, big_obj); + assert_ne!(big_obj, x); + } + }); + } +} diff --git a/tests/test_class_attributes.rs b/tests/test_class_attributes.rs index 9e544211c3c..a2e099549c0 100644 --- a/tests/test_class_attributes.rs +++ b/tests/test_class_attributes.rs @@ -224,7 +224,7 @@ macro_rules! test_case { let struct_obj = struct_class.call0().unwrap(); assert!(struct_obj.setattr($renamed_field_name, 2).is_ok()); let attr = struct_obj.getattr($renamed_field_name).unwrap(); - assert_eq!(2, attr.extract().unwrap()); + assert_eq!(2, attr.extract::().unwrap()); }); } };