Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for arbitrary arrays #1128

Merged
merged 5 commits into from
May 1, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Support PyPy 3.7. [#1538](https://github.com/PyO3/pyo3/pull/1538)

### Added
- Add conversions for `[T; N]` for all `N` on Rust 1.51 and up. [#1128](https://github.com/PyO3/pyo3/pull/1128)
- Add conversions between `OsStr`/`OsString`/`Path`/`PathBuf` and Python strings. [#1379](https://github.com/PyO3/pyo3/pull/1379)
- Add `#[pyo3(from_py_with = "...")]` attribute for function arguments and struct fields to override the default from-Python conversion. [#1411](https://github.com/PyO3/pyo3/pull/1411)
- Add FFI definition `PyCFunction_CheckExact` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425)
Expand Down
17 changes: 17 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,17 @@ fn ensure_python_version(interpreter_config: &InterpreterConfig) -> Result<()> {
Ok(())
}

fn rustc_minor_version() -> Option<u32> {
let rustc = env::var_os("RUSTC")?;
let output = Command::new(rustc).arg("--version").output().ok()?;
let version = core::str::from_utf8(&output.stdout).ok()?;
let mut pieces = version.split('.');
if pieces.next() != Some("rustc 1") {
return None;
}
pieces.next()?.parse().ok()
}

fn emit_cargo_configuration(interpreter_config: &InterpreterConfig) -> Result<()> {
let target_os = cargo_env_var("CARGO_CFG_TARGET_OS").unwrap();
let is_extension_module = cargo_env_var("CARGO_FEATURE_EXTENSION_MODULE").is_some();
Expand Down Expand Up @@ -850,6 +861,12 @@ fn emit_cargo_configuration(interpreter_config: &InterpreterConfig) -> Result<()
println!("cargo:rustc-cfg=py_sys_config=\"{}\"", flag)
}

// Enable use of const generics on Rust 1.51 and greater

if rustc_minor_version().unwrap_or(0) >= 51 {
println!("cargo:rustc-cfg=min_const_generics");
}

Ok(())
}

Expand Down
282 changes: 282 additions & 0 deletions src/conversions/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
use crate::{
exceptions, FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, PyTryFrom, Python,
ToPyObject,
};

#[cfg(min_const_generics)]
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
where
T: ToPyObject,
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}

#[cfg(min_const_generics)]
impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
where
T: FromPyObject<'a>,
{
#[cfg(not(feature = "nightly"))]
fn extract(obj: &'a PyAny) -> PyResult<Self> {
create_array_from_obj(obj)
}

#[cfg(feature = "nightly")]
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
create_array_from_obj(obj)
}
}

#[cfg(all(min_const_generics, feature = "nightly"))]
impl<'source, T, const N: usize> FromPyObject<'source> for [T; N]
where
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
use crate::{AsPyPointer, PyNativeType};
// first try buffer protocol
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
let mut array = [T::default(); N];
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
return Ok(array);
}
buf.release(obj.py());
}
}
create_array_from_obj(obj)
}
}

#[cfg(min_const_generics)]
fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]>
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let seq_len = seq.len()? as usize;
if seq_len != N {
return Err(invalid_sequence_length(N, seq_len));
}
array_try_from_fn(|idx| seq.get_item(idx as isize).and_then(PyAny::extract))
}

// TODO use std::array::try_from_fn, if that stabilises:
// (https://github.com/rust-lang/rust/pull/75644)
#[cfg(min_const_generics)]
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
F: FnMut(usize) -> Result<T, E>,
{
// Helper to safely create arrays since the standard library doesn't
// provide one yet. Shouldn't be necessary in the future.
struct ArrayGuard<T, const N: usize> {
dst: *mut T,
initialized: usize,
}

impl<T, const N: usize> Drop for ArrayGuard<T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);
let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
unsafe {
core::ptr::drop_in_place(initialized_part);
}
}
}

// [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
// APIs which would make this easier.
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
let mut guard: ArrayGuard<T, N> = ArrayGuard {
dst: array.as_mut_ptr() as _,
initialized: 0,
};
unsafe {
let mut value_ptr = array.as_mut_ptr() as *mut T;
for i in 0..N {
core::ptr::write(value_ptr, cb(i)?);
value_ptr = value_ptr.offset(1);
guard.initialized += 1;
}
core::mem::forget(guard);
Ok(array.assume_init())
}
}

#[cfg(not(min_const_generics))]
macro_rules! array_impls {
($($N:expr),+) => {
$(
impl<T> IntoPy<PyObject> for [T; $N]
where
T: ToPyObject
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}

impl<'a, T> FromPyObject<'a> for [T; $N]
where
T: Copy + Default + FromPyObject<'a>,
{
#[cfg(not(feature = "nightly"))]
fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}

#[cfg(feature = "nightly")]
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}

#[cfg(feature = "nightly")]
impl<'source, T> FromPyObject<'source> for [T; $N]
where
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
// first try buffer protocol
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
return Ok(array);
}
buf.release(obj.py());
}
}
// fall back to sequence protocol
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}
)+
}
}

#[cfg(not(min_const_generics))]
array_impls!(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32
);

#[cfg(not(min_const_generics))]
fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let seq_len = seq.len()? as usize;
if seq_len != slice.len() {
return Err(invalid_sequence_length(slice.len(), seq_len));
}
for (value, item) in slice.iter_mut().zip(seq.iter()?) {
*value = item?.extract::<T>()?;
}
Ok(())
}

fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
exceptions::PyValueError::new_err(format!(
"expected a sequence of length {} (got {})",
expected, actual
))
}

#[cfg(test)]
mod test {
use crate::{PyResult, Python};
#[cfg(min_const_generics)]
use std::{
panic,
sync::atomic::{AtomicUsize, Ordering},
};

#[cfg(min_const_generics)]
#[test]
fn array_try_from_fn() {
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
struct CountDrop;
impl Drop for CountDrop {
fn drop(&mut self) {
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
}
}
let _ = catch_unwind_silent(move || {
let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
if idx == 2 {
panic!("peek a boo");
}
Ok(CountDrop)
});
});
assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
}

#[test]
fn test_extract_small_bytearray_to_array() {
let gil = Python::acquire_gil();
let py = gil.python();
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
let v: [u8; 3] = py
.eval("bytearray(b'abc')", None, None)
.unwrap()
.extract()
.unwrap();
assert!(&v == b"abc");
}

#[test]
fn test_extract_invalid_sequence_length() {
let gil = Python::acquire_gil();
let py = gil.python();
let v: PyResult<[u8; 3]> = py
.eval("bytearray(b'abcdefg')", None, None)
.unwrap()
.extract();
assert_eq!(
v.unwrap_err().to_string(),
"ValueError: expected a sequence of length 3 (got 7)"
);
}

#[cfg(min_const_generics)]
#[test]
fn test_extract_bytearray_to_array() {
let gil = Python::acquire_gil();
let py = gil.python();
let v: [u8; 33] = py
.eval(
"bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
None,
None,
)
.unwrap()
.extract()
.unwrap();
assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
}

// https://stackoverflow.com/a/59211505
#[cfg(min_const_generics)]
fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
where
F: FnOnce() -> R + panic::UnwindSafe,
{
let prev_hook = panic::take_hook();
panic::set_hook(Box::new(|_| {}));
let result = panic::catch_unwind(f);
panic::set_hook(prev_hook);
result
}
}
4 changes: 2 additions & 2 deletions src/conversions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! This module contains conversions between non-String Rust object and their string representation
//! in Python
//! This module contains conversions between various Rust object and their representation in Python.

mod array;
mod osstr;
mod path;
20 changes: 0 additions & 20 deletions src/types/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,26 +178,6 @@ where
}
}

macro_rules! array_impls {
($($N:expr),+) => {
$(
impl<T> IntoPy<PyObject> for [T; $N]
where
T: ToPyObject
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}
)+
}
}

array_impls!(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32
);

impl<T> ToPyObject for Vec<T>
where
T: ToPyObject,
Expand Down
Loading