Skip to content

Commit

Permalink
(chore): allow list of slices
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Jun 11, 2024
1 parent a875eaf commit 736a24b
Showing 1 changed file with 38 additions and 20 deletions.
58 changes: 38 additions & 20 deletions src/array.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use pyo3::exceptions::{PyIndexError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use zarrs::array::{Array as RustArray};
use zarrs::array_subset::ArraySubset;
use zarrs::storage::ReadableStorageTraits;
use pyo3::types::PySlice;
use pyo3::types::{PyInt, PyList, PySlice};
use std::ops::Range;

#[pyclass]
pub struct Array {
Expand All @@ -11,38 +13,54 @@ pub struct Array {

impl Array {

fn bound_slice(&self, slice: &Bound<PySlice>) -> PyResult<Vec<u64>> {
let start: i32 = slice.getattr("start")?.extract().map_or(0, |x| x);
let mut start_u64: u64 = start as u64;
if start < 0 {
if self.arr.shape()[0] as i32 + start < 0 {
return Err(PyIndexError::new_err(format!("{0} out of bounds", start)))
}
start_u64 = u64::try_from(start).map_err(|_| PyErr::new::<PyIndexError, _>("Failed to extract start"))?;
}
let stop: i32 = slice.getattr("stop")?.extract().map_or(self.arr.shape()[0] as i32, |x| x);
let mut stop_u64: u64 = stop as u64;
if stop < 0 {
if self.arr.shape()[0] as i32 + stop < 0 {
return Err(PyIndexError::new_err(format!("{0} out of bounds", stop)))
fn maybe_convert_u64(&self, ind: i32, axis: usize) -> PyResult<u64> {
let mut ind_u64: u64 = ind as u64;
if ind < 0 {
if self.arr.shape()[axis] as i32 + ind < 0 {
return Err(PyIndexError::new_err(format!("{0} out of bounds", ind)))
}
stop_u64 = u64::try_from(stop).map_err(|_| PyErr::new::<PyIndexError, _>("Failed to extract stop"))?;
ind_u64 = u64::try_from(ind).map_err(|_| PyIndexError::new_err("Failed to extract start"))?;
}
let _step: u64 = slice.getattr("step")?.extract().map_or(1, |x| x);
let selection: Vec<u64> = (start_u64..stop_u64).step_by(_step.try_into().unwrap()).collect();
return Ok(ind_u64);
}

fn bound_slice(&self, slice: &Bound<PySlice>) -> PyResult<Range<u64>> {
let start: i32 = slice.getattr("start")?.extract().map_or(0, |x| x);
let stop: i32 = slice.getattr("stop")?.extract().map_or(0, |x| x);
let start_u64 = self.maybe_convert_u64(start, 0)?;
let stop_u64 = self.maybe_convert_u64(stop, 0)?;
// let _step: u64 = slice.getattr("step")?.extract().map_or(1, |x| x); // there is no way to use step it seems with zarrs?
let selection = start_u64..stop_u64;
return Ok(selection)
}

fn fill_from_slices(&self, slices: Vec<Range<u64>>) -> PyResult<Vec<Range<u64>>> {
Ok(self.arr.shape().iter().enumerate().map(|(index, &value)| { if index < slices.len() { slices[index].clone() } else { 0..value } }).collect())
}
}

#[pymethods]
impl Array {

pub fn __getitem__(&self, key: &Bound<'_, PyAny>) -> PyResult<Vec<u8>> {
let selection: ArraySubset;
if let Ok(slice) = key.downcast::<PySlice>() {
let selection = self.bound_slice(slice)?;
return self.arr.retrieve_chunk(&selection[..]).map_err(|x| PyErr::new::<PyTypeError, _>(x.to_string()));
selection = ArraySubset::new_with_ranges(&self.fill_from_slices(vec![self.bound_slice(slice)?])?);
} else if let Ok(list) = key.downcast::<PyList>(){
let ranges: Vec<Range<u64>> = list.into_iter().enumerate().map(|(index, val)| {
if let Ok(int) = val.downcast::<PyInt>() {
let end = self.maybe_convert_u64(int.extract()?, index)?;
Ok(end..(end + 1))
} else if let Ok(slice) = val.downcast::<PySlice>() {
Ok(self.bound_slice(slice)?)
} else {
return Err(PyValueError::new_err(format!("Cannot take {0}, must be int or slice", val.to_string())));
}
}).collect::<Result<Vec<Range<u64>>, _>>()?;
selection = ArraySubset::new_with_ranges(&self.fill_from_slices(ranges)?);
} else {
return Err(PyTypeError::new_err("Unsupported type"));
}
return self.arr.retrieve_chunks(&selection).map_err(|x| PyErr::new::<PyTypeError, _>(x.to_string()));
}
}

0 comments on commit 736a24b

Please sign in to comment.