From 2afaa93888f4b245c43d325c6d192c40301e12b4 Mon Sep 17 00:00:00 2001 From: Caio Date: Thu, 25 Mar 2021 15:59:48 -0300 Subject: [PATCH] Add support for arbitrary arrays --- build.rs | 23 ++++++++++++++ src/types/list.rs | 12 ++++++++ src/types/mod.rs | 37 +++++++++++++++++++++++ src/types/sequence.rs | 70 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 142 insertions(+) diff --git a/build.rs b/build.rs index 1fdfa330fd1..63d69f74f92 100644 --- a/build.rs +++ b/build.rs @@ -902,6 +902,27 @@ fn abi3_without_interpreter() -> Result<()> { Ok(()) } +fn rustc_minor_version() -> Option { + let rustc = env::var_os("RUSTC")?; + let output = Command::new(rustc).arg("--version").output().ok()?; + let version = core::str::from_utf8(&output.stdout).ok()?; + let mut pieces = version.split('.'); + if pieces.next() != Some("rustc 1") { + return None; + } + pieces.next()?.parse().ok() +} + +fn manage_min_const_generics() { + let rustc_minor_version = match rustc_minor_version() { + Some(inner) => inner, + None => return, + }; + if rustc_minor_version >= 51 { + println!("cargo:rustc-cfg=min_const_generics"); + } +} + fn main() -> Result<()> { // If PYO3_NO_PYTHON is set with abi3, we can build PyO3 without calling Python. // We only check for the abi3-py3{ABI3_MAX_MINOR} because lower versions depend on it. @@ -961,5 +982,7 @@ fn main() -> Result<()> { println!("cargo:rustc-cfg=__pyo3_ci"); } + manage_min_const_generics(); + Ok(()) } diff --git a/src/types/list.rs b/src/types/list.rs index 4a0586e4fcd..bd2385367c8 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -178,6 +178,7 @@ where } } +#[cfg(min_const_generics)] macro_rules! array_impls { ($($N:expr),+) => { $( @@ -193,11 +194,22 @@ macro_rules! array_impls { } } +#[cfg(min_const_generics)] array_impls!( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 ); +#[cfg(not(min_const_generics))] +impl IntoPy for [T; N] +where + T: ToPyObject, +{ + fn into_py(self, py: Python) -> PyObject { + self.as_ref().to_object(py) + } +} + impl ToPyObject for Vec where T: ToPyObject, diff --git a/src/types/mod.rs b/src/types/mod.rs index 96459f37dc2..5d75594d00b 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -244,3 +244,40 @@ mod slice; mod string; mod tuple; mod typeobject; + +#[cfg(min_const_generics)] +struct ArrayGuard { + dst: *mut T, + initialized: usize, +} + +#[cfg(min_const_generics)] +impl Drop for ArrayGuard { + fn drop(&mut self) { + debug_assert!(self.initialized <= N); + let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized); + unsafe { + core::ptr::drop_in_place(initialized_part); + } + } +} + +#[cfg(min_const_generics)] +fn try_create_array(mut cb: F) -> Result<[T; N], E> +where + F: FnMut(usize) -> Result, +{ + let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit(); + let mut guard: ArrayGuard = ArrayGuard { + dst: array.as_mut_ptr() as _, + initialized: 0, + }; + unsafe { + for (idx, value_ptr) in (&mut *array.as_mut_ptr()).iter_mut().enumerate() { + core::ptr::write(value_ptr, cb(idx)?); + guard.initialized += 1; + } + core::mem::forget(guard); + Ok(array.assume_init()) + } +} diff --git a/src/types/sequence.rs b/src/types/sequence.rs index 423021d02a3..9bce20e94f5 100644 --- a/src/types/sequence.rs +++ b/src/types/sequence.rs @@ -257,6 +257,7 @@ impl PySequence { } } +#[cfg(not(min_const_generics))] macro_rules! array_impls { ($($N:expr),+) => { $( @@ -305,11 +306,46 @@ macro_rules! array_impls { } } +#[cfg(not(min_const_generics))] array_impls!( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 ); +#[cfg(all(min_const_generics, not(feature = "nightly")))] +impl<'a, T, const N: usize> FromPyObject<'a> for [T; N] +where + T: FromPyObject<'a>, +{ + #[cfg(not(feature = "nightly"))] + fn extract(obj: &'a PyAny) -> PyResult { + create_array_from_obj(obj) + } + + #[cfg(feature = "nightly")] + default fn extract(obj: &'a PyAny) -> PyResult { + create_array_from_obj(obj) + } +} + +#[cfg(all(min_const_generics, feature = "nightly"))] +impl<'source, T, const N: usize> FromPyObject<'source> for [T; N] +where + for<'a> T: FromPyObject<'a> + crate::buffer::Element, +{ + fn extract(obj: &'source PyAny) -> PyResult { + let mut array = create_array_from_obj(obj)?; + if let Ok(buf) = crate::buffer::PyBuffer::get(obj) { + if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() { + buf.release(obj.py()); + return Ok(array); + } + buf.release(obj.py()); + } + Ok(array) + } +} + impl<'a, T> FromPyObject<'a> for Vec where T: FromPyObject<'a>, @@ -345,6 +381,21 @@ where } } +#[cfg(min_const_generics)] +fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]> +where + T: FromPyObject<'s>, +{ + let seq = ::try_from(obj)?; + crate::types::try_create_array(|idx| { + seq.get_item(idx as isize) + .map_err(|_| { + exceptions::PyBufferError::new_err("Slice length does not match buffer length.") + })? + .extract::() + }) +} + fn extract_sequence<'s, T>(obj: &'s PyAny) -> PyResult> where T: FromPyObject<'s>, @@ -357,6 +408,7 @@ where Ok(v) } +#[cfg(not(min_const_generics))] fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()> where T: FromPyObject<'s>, @@ -706,6 +758,7 @@ mod test { assert!(v == [1, 2, 3, 4]); } + #[cfg(not(min_const_generics))] #[test] fn test_extract_bytearray_to_array() { let gil = Python::acquire_gil(); @@ -718,6 +771,23 @@ mod test { assert!(&v == b"abc"); } + #[cfg(min_const_generics)] + #[test] + fn test_extract_bytearray_to_array() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let v: [u8; 33] = py + .eval( + "bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')", + None, + None, + ) + .unwrap() + .extract() + .unwrap(); + assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc"); + } + #[test] fn test_extract_bytearray_to_vec() { let gil = Python::acquire_gil();