Skip to content

Commit

Permalink
added various std traits for PyBackedStr and PyBackedBytes (#4020)
Browse files Browse the repository at this point in the history
* added various std traits for `PyBackedStr` and `PyBackedBytes`

* add newsfragment

* add tests
  • Loading branch information
Icxolu authored Apr 1, 2024
1 parent 336b1c9 commit 63ba371
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 3 deletions.
1 change: 1 addition & 0 deletions newsfragments/4020.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adds `Clone`, `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord` and `Hash` implementation for `PyBackedBytes` and `PyBackedStr`, and `Display` for `PyBackedStr`.
263 changes: 260 additions & 3 deletions src/pybacked.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Contains types for working with Python objects that own the underlying data.
use std::{ops::Deref, ptr::NonNull};
use std::{ops::Deref, ptr::NonNull, sync::Arc};

use crate::{
types::{
Expand All @@ -13,6 +13,7 @@ use crate::{
/// A wrapper around `str` where the storage is owned by a Python `bytes` or `str` object.
///
/// This type gives access to the underlying data via a `Deref` implementation.
#[derive(Clone)]
pub struct PyBackedStr {
#[allow(dead_code)] // only held so that the storage is not dropped
storage: Py<PyAny>,
Expand Down Expand Up @@ -44,6 +45,14 @@ impl AsRef<[u8]> for PyBackedStr {
unsafe impl Send for PyBackedStr {}
unsafe impl Sync for PyBackedStr {}

impl std::fmt::Display for PyBackedStr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.deref().fmt(f)
}
}

impl_traits!(PyBackedStr, str);

impl TryFrom<Bound<'_, PyString>> for PyBackedStr {
type Error = PyErr;
fn try_from(py_string: Bound<'_, PyString>) -> Result<Self, Self::Error> {
Expand Down Expand Up @@ -79,16 +88,18 @@ impl FromPyObject<'_> for PyBackedStr {
/// A wrapper around `[u8]` where the storage is either owned by a Python `bytes` object, or a Rust `Box<[u8]>`.
///
/// This type gives access to the underlying data via a `Deref` implementation.
#[derive(Clone)]
pub struct PyBackedBytes {
#[allow(dead_code)] // only held so that the storage is not dropped
storage: PyBackedBytesStorage,
data: NonNull<[u8]>,
}

#[allow(dead_code)]
#[derive(Clone)]
enum PyBackedBytesStorage {
Python(Py<PyBytes>),
Rust(Box<[u8]>),
Rust(Arc<[u8]>),
}

impl Deref for PyBackedBytes {
Expand All @@ -110,6 +121,32 @@ impl AsRef<[u8]> for PyBackedBytes {
unsafe impl Send for PyBackedBytes {}
unsafe impl Sync for PyBackedBytes {}

impl<const N: usize> PartialEq<[u8; N]> for PyBackedBytes {
fn eq(&self, other: &[u8; N]) -> bool {
self.deref() == other
}
}

impl<const N: usize> PartialEq<PyBackedBytes> for [u8; N] {
fn eq(&self, other: &PyBackedBytes) -> bool {
self == other.deref()
}
}

impl<const N: usize> PartialEq<&[u8; N]> for PyBackedBytes {
fn eq(&self, other: &&[u8; N]) -> bool {
self.deref() == *other
}
}

impl<const N: usize> PartialEq<PyBackedBytes> for &[u8; N] {
fn eq(&self, other: &PyBackedBytes) -> bool {
self == &other.deref()
}
}

impl_traits!(PyBackedBytes, [u8]);

impl From<Bound<'_, PyBytes>> for PyBackedBytes {
fn from(py_bytes: Bound<'_, PyBytes>) -> Self {
let b = py_bytes.as_bytes();
Expand All @@ -123,7 +160,7 @@ impl From<Bound<'_, PyBytes>> for PyBackedBytes {

impl From<Bound<'_, PyByteArray>> for PyBackedBytes {
fn from(py_bytearray: Bound<'_, PyByteArray>) -> Self {
let s = py_bytearray.to_vec().into_boxed_slice();
let s = Arc::<[u8]>::from(py_bytearray.to_vec());
let data = NonNull::from(s.as_ref());
Self {
storage: PyBackedBytesStorage::Rust(s),
Expand All @@ -144,10 +181,85 @@ impl FromPyObject<'_> for PyBackedBytes {
}
}

macro_rules! impl_traits {
($slf:ty, $equiv:ty) => {
impl std::fmt::Debug for $slf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.deref().fmt(f)
}
}

impl PartialEq for $slf {
fn eq(&self, other: &Self) -> bool {
self.deref() == other.deref()
}
}

impl PartialEq<$equiv> for $slf {
fn eq(&self, other: &$equiv) -> bool {
self.deref() == other
}
}

impl PartialEq<&$equiv> for $slf {
fn eq(&self, other: &&$equiv) -> bool {
self.deref() == *other
}
}

impl PartialEq<$slf> for $equiv {
fn eq(&self, other: &$slf) -> bool {
self == other.deref()
}
}

impl PartialEq<$slf> for &$equiv {
fn eq(&self, other: &$slf) -> bool {
self == &other.deref()
}
}

impl Eq for $slf {}

impl PartialOrd for $slf {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl PartialOrd<$equiv> for $slf {
fn partial_cmp(&self, other: &$equiv) -> Option<std::cmp::Ordering> {
self.deref().partial_cmp(other)
}
}

impl PartialOrd<$slf> for $equiv {
fn partial_cmp(&self, other: &$slf) -> Option<std::cmp::Ordering> {
self.partial_cmp(other.deref())
}
}

impl Ord for $slf {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.deref().cmp(other.deref())
}
}

impl std::hash::Hash for $slf {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.deref().hash(state)
}
}
};
}
use impl_traits;

#[cfg(test)]
mod test {
use super::*;
use crate::Python;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

#[test]
fn py_backed_str_empty() {
Expand Down Expand Up @@ -223,4 +335,149 @@ mod test {
is_send::<PyBackedBytes>();
is_sync::<PyBackedBytes>();
}

#[test]
fn test_backed_str_clone() {
Python::with_gil(|py| {
let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap();
let s2 = s1.clone();
assert_eq!(s1, s2);

drop(s1);
assert_eq!(s2, "hello");
});
}

#[test]
fn test_backed_str_eq() {
Python::with_gil(|py| {
let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap();
let s2: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap();
assert_eq!(s1, "hello");
assert_eq!(s1, s2);

let s3: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap();
assert_eq!("abcde", s3);
assert_ne!(s1, s3);
});
}

#[test]
fn test_backed_str_hash() {
Python::with_gil(|py| {
let h = {
let mut hasher = DefaultHasher::new();
"abcde".hash(&mut hasher);
hasher.finish()
};

let s1: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap();
let h1 = {
let mut hasher = DefaultHasher::new();
s1.hash(&mut hasher);
hasher.finish()
};

assert_eq!(h, h1);
});
}

#[test]
fn test_backed_str_ord() {
Python::with_gil(|py| {
let mut a = vec!["a", "c", "d", "b", "f", "g", "e"];
let mut b = a
.iter()
.map(|s| PyString::new_bound(py, s).try_into().unwrap())
.collect::<Vec<PyBackedStr>>();

a.sort();
b.sort();

assert_eq!(a, b);
})
}

#[test]
fn test_backed_bytes_from_bytes_clone() {
Python::with_gil(|py| {
let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into();
let b2 = b1.clone();
assert_eq!(b1, b2);

drop(b1);
assert_eq!(b2, b"abcde");
});
}

#[test]
fn test_backed_bytes_from_bytearray_clone() {
Python::with_gil(|py| {
let b1: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into();
let b2 = b1.clone();
assert_eq!(b1, b2);

drop(b1);
assert_eq!(b2, b"abcde");
});
}

#[test]
fn test_backed_bytes_eq() {
Python::with_gil(|py| {
let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into();
let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into();

assert_eq!(b1, b"abcde");
assert_eq!(b1, b2);

let b3: PyBackedBytes = PyBytes::new_bound(py, b"hello").into();
assert_eq!(b"hello", b3);
assert_ne!(b1, b3);
});
}

#[test]
fn test_backed_bytes_hash() {
Python::with_gil(|py| {
let h = {
let mut hasher = DefaultHasher::new();
b"abcde".hash(&mut hasher);
hasher.finish()
};

let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into();
let h1 = {
let mut hasher = DefaultHasher::new();
b1.hash(&mut hasher);
hasher.finish()
};

let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into();
let h2 = {
let mut hasher = DefaultHasher::new();
b2.hash(&mut hasher);
hasher.finish()
};

assert_eq!(h, h1);
assert_eq!(h, h2);
});
}

#[test]
fn test_backed_bytes_ord() {
Python::with_gil(|py| {
let mut a = vec![b"a", b"c", b"d", b"b", b"f", b"g", b"e"];
let mut b = a
.iter()
.map(|&b| PyBytes::new_bound(py, b).into())
.collect::<Vec<PyBackedBytes>>();

a.sort();
b.sort();

assert_eq!(a, b);
})
}
}

0 comments on commit 63ba371

Please sign in to comment.