Skip to content

Commit

Permalink
Implement ToPyObject for [T; N] (#2313)
Browse files Browse the repository at this point in the history
  • Loading branch information
PigeonF authored Apr 19, 2022
1 parent 1d71d17 commit dea9eb7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Implement `ToPyObject` for `[T; N]`. [#2313](https://github.com/PyO3/pyo3/pull/2313)

## [0.16.4] - 2022-04-14

### Added
Expand Down
33 changes: 32 additions & 1 deletion src/conversions/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ mod min_const_generics {
}
}

impl<T, const N: usize> ToPyObject for [T; N]
where
T: ToPyObject,
{
fn to_object(&self, py: Python<'_>) -> PyObject {
self.as_ref().to_object(py)
}
}

impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
where
T: FromPyObject<'a>,
Expand Down Expand Up @@ -154,6 +163,15 @@ mod array_impls {
}
}

impl<T> ToPyObject for [T; $N]
where
T: ToPyObject,
{
fn to_object(&self, py: Python<'_>) -> PyObject {
self.as_ref().to_object(py)
}
}

impl<'a, T> FromPyObject<'a> for [T; $N]
where
T: Copy + Default + FromPyObject<'a>,
Expand Down Expand Up @@ -200,7 +218,7 @@ fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {

#[cfg(test)]
mod tests {
use crate::{PyResult, Python};
use crate::{types::PyList, PyResult, Python};

#[test]
fn test_extract_small_bytearray_to_array() {
Expand All @@ -213,6 +231,19 @@ mod tests {
assert!(&v == b"abc");
});
}
#[test]
fn test_topyobject_array_conversion() {
use crate::ToPyObject;
Python::with_gil(|py| {
let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
let pyobject = array.to_object(py);
let pylist: &PyList = pyobject.extract(py).unwrap();
assert_eq!(pylist[0].extract::<f32>().unwrap(), 0.0);
assert_eq!(pylist[1].extract::<f32>().unwrap(), -16.0);
assert_eq!(pylist[2].extract::<f32>().unwrap(), 16.0);
assert_eq!(pylist[3].extract::<f32>().unwrap(), 42.0);
});
}

#[test]
fn test_extract_invalid_sequence_length() {
Expand Down

0 comments on commit dea9eb7

Please sign in to comment.