diff --git a/python/zarrs_python/pipeline.py b/python/zarrs_python/pipeline.py index 1175413..bc81ddd 100644 --- a/python/zarrs_python/pipeline.py +++ b/python/zarrs_python/pipeline.py @@ -24,7 +24,13 @@ from zarr.core.indexing import SelectorTuple from ._internal import CodecPipelineImpl -from .utils import get_max_threads, make_chunk_info_for_rust +from .utils import ( + CollapsedDimensionError, + DiscontiguousArrayError, + get_max_threads, + make_chunk_info_for_rust, + make_chunk_info_for_rust_with_indices, +) @dataclass(frozen=True) @@ -94,12 +100,32 @@ async def read( out = out.as_ndarray_like() # FIXME: Error if array is not in host memory if not out.dtype.isnative: raise RuntimeError("Non-native byte order not supported") - - chunks_desc = make_chunk_info_for_rust(batch_info) - await asyncio.to_thread( - self.impl.retrieve_chunks, chunks_desc, out, chunk_concurrent_limit + try: + chunks_desc = make_chunk_info_for_rust_with_indices(batch_info, drop_axes) + index_in_rust = True + except (DiscontiguousArrayError, CollapsedDimensionError): + chunks_desc = make_chunk_info_for_rust(batch_info) + index_in_rust = False + if index_in_rust: + await asyncio.to_thread( + self.impl.retrieve_chunks_and_apply_index, + chunks_desc, + out, + chunk_concurrent_limit, + ) + return None + chunks = await asyncio.to_thread( + self.impl.retrieve_chunks, chunks_desc, chunk_concurrent_limit ) - return None + for chunk, chunk_info in zip(chunks, batch_info): + out_selection = chunk_info[3] + selection = chunk_info[2] + spec = chunk_info[1] + chunk_reshaped = chunk.view(spec.dtype).reshape(spec.shape) + chunk_selected = chunk_reshaped[selection] + if drop_axes: + chunk_selected = np.squeeze(chunk_selected, axis=drop_axes) + out[out_selection] = chunk_selected async def write( self, @@ -117,8 +143,11 @@ async def write( value = np.ascontiguousarray(value, dtype=value.dtype.newbyteorder("=")) elif not value.flags.c_contiguous: value = np.ascontiguousarray(value) - chunks_desc = make_chunk_info_for_rust(batch_info) + chunks_desc = make_chunk_info_for_rust_with_indices(batch_info, drop_axes) await asyncio.to_thread( - self.impl.store_chunks, chunks_desc, value, chunk_concurrent_limit + self.impl.store_chunks_with_indices, + chunks_desc, + value, + chunk_concurrent_limit, ) return None diff --git a/python/zarrs_python/utils.py b/python/zarrs_python/utils.py index 8d267a2..43e8b73 100644 --- a/python/zarrs_python/utils.py +++ b/python/zarrs_python/utils.py @@ -1,15 +1,18 @@ from __future__ import annotations +import operator import os +from functools import reduce from typing import TYPE_CHECKING, Any import numpy as np -from zarr.core.indexing import ArrayIndexError, SelectorTuple, is_integer +from zarr.core.indexing import SelectorTuple, is_integer if TYPE_CHECKING: from collections.abc import Iterable + from types import EllipsisType - from zarr.abc.store import ByteSetter + from zarr.abc.store import ByteGetter, ByteSetter from zarr.core.array_spec import ArraySpec from zarr.core.common import ChunkCoords @@ -19,45 +22,164 @@ def get_max_threads() -> int: return (os.cpu_count() or 1) + 4 -# This is a copy of the function from zarr.core.indexing that fixes: +class DiscontiguousArrayError(Exception): + pass + + +class CollapsedDimensionError(Exception): + pass + + +# This is a (mostly) copy of the function from zarr.core.indexing that fixes: # DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated # TODO: Upstream this fix -def make_slice_selection(selection: Any) -> list[slice]: +def make_slice_selection(selection: tuple[np.ndarray | float]) -> list[slice]: ls: list[slice] = [] for dim_selection in selection: if is_integer(dim_selection): ls.append(slice(int(dim_selection), int(dim_selection) + 1, 1)) elif isinstance(dim_selection, np.ndarray): + dim_selection = dim_selection.ravel() if len(dim_selection) == 1: ls.append( slice(int(dim_selection.item()), int(dim_selection.item()) + 1, 1) ) else: - raise ArrayIndexError + diff = np.diff(dim_selection) + if (diff != 1).any() and (diff != 0).any(): + raise DiscontiguousArrayError(diff) + ls.append(slice(dim_selection[0], dim_selection[-1] + 1, 1)) else: ls.append(dim_selection) return ls def selector_tuple_to_slice_selection(selector_tuple: SelectorTuple) -> list[slice]: + if isinstance(selector_tuple, slice): + return [selector_tuple] + if all(isinstance(s, slice) for s in selector_tuple): + return list(selector_tuple) + return make_slice_selection(selector_tuple) + + +def convert_chunk_to_primitive( + byte_getter: ByteGetter | ByteSetter, chunk_spec: ArraySpec +) -> tuple[str, ChunkCoords, str, Any]: return ( - [selector_tuple] - if isinstance(selector_tuple, slice) - else make_slice_selection(selector_tuple) + str(byte_getter), + chunk_spec.shape, + str(chunk_spec.dtype), + chunk_spec.fill_value.tobytes(), ) -def make_chunk_info_for_rust( - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]], -) -> list[tuple[str, ChunkCoords, str, Any, list[slice], list[slice]]]: - return list( - ( - str(byte_getter), +def resulting_shape_from_index( + array_shape: tuple[int, ...], + index_tuple: tuple[int | slice | EllipsisType | np.ndarray], + drop_axes: tuple[int, ...], + *, + pad: bool, +) -> tuple[int, ...]: + result_shape = [] + advanced_index_shapes = [ + idx.shape for idx in index_tuple if isinstance(idx, np.ndarray) + ] + basic_shape_index = 0 + + # Broadcast all advanced indices, if any + if advanced_index_shapes: + result_shape += np.broadcast_shapes(*advanced_index_shapes) + # Consume dimensions from array_shape + basic_shape_index += len(advanced_index_shapes) + + # Process each remaining index in index_tuple + for idx in index_tuple: + if isinstance(idx, int): + # Integer index reduces dimension, so skip this dimension in array_shape + basic_shape_index += 1 + elif isinstance(idx, slice): + if idx.step is not None and idx.step > 1: + raise DiscontiguousArrayError( + "Step size greater than 1 is not supported" + ) + # Slice keeps dimension, adjust size accordingly + start, stop, _ = idx.indices(array_shape[basic_shape_index]) + result_shape.append(stop - start) + basic_shape_index += 1 + elif idx is Ellipsis: + # Calculate number of dimensions that Ellipsis should fill + num_to_fill = len(array_shape) - len(index_tuple) + 1 + result_shape += array_shape[ + basic_shape_index : basic_shape_index + num_to_fill + ] + basic_shape_index += num_to_fill + elif not isinstance(idx, np.ndarray): + raise ValueError(f"Invalid index type: {type(idx)}") + + # Step 4: Append remaining dimensions from array_shape if fewer indices were used + if basic_shape_index < len(array_shape) and pad: + result_shape += array_shape[basic_shape_index:] + + return tuple(size for idx, size in enumerate(result_shape) if idx not in drop_axes) + + +def prod_op(x: Iterable[int]) -> int: + return reduce(operator.mul, x, 1) + + +def get_shape_for_selector( + selector_tuple: SelectorTuple, + shape: tuple[int, ...], + *, + pad: bool, + drop_axes: tuple[int, ...] = (), +) -> tuple[int, ...]: + if isinstance(selector_tuple, slice | np.ndarray): + return resulting_shape_from_index( + shape, + (selector_tuple,), + drop_axes, + pad=pad, + ) + return resulting_shape_from_index(shape, selector_tuple, drop_axes, pad=pad) + + +def make_chunk_info_for_rust_with_indices( + batch_info: Iterable[ + tuple[ByteGetter | ByteSetter, ArraySpec, SelectorTuple, SelectorTuple] + ], + drop_axes: tuple[int, ...], +) -> list[tuple[tuple[str, ChunkCoords, str, Any], list[slice], list[slice]]]: + chunk_info_with_indices = [] + for byte_getter, chunk_spec, chunk_selection, out_selection in batch_info: + chunk_info = convert_chunk_to_primitive(byte_getter, chunk_spec) + out_selection_as_slices = selector_tuple_to_slice_selection(out_selection) + chunk_selection_as_slices = selector_tuple_to_slice_selection(chunk_selection) + shape_chunk_selection_slices = get_shape_for_selector( + tuple(chunk_selection_as_slices), chunk_spec.shape, - str(chunk_spec.dtype), - chunk_spec.fill_value.tobytes(), - selector_tuple_to_slice_selection(out_selection), - selector_tuple_to_slice_selection(chunk_selection), + pad=True, + drop_axes=drop_axes, ) - for (byte_getter, chunk_spec, chunk_selection, out_selection) in batch_info + shape_chunk_selection = get_shape_for_selector( + chunk_selection, chunk_spec.shape, pad=True, drop_axes=drop_axes + ) + if prod_op(shape_chunk_selection) != prod_op(shape_chunk_selection_slices): + raise CollapsedDimensionError( + f"{shape_chunk_selection} != {shape_chunk_selection_slices}" + ) + chunk_info_with_indices.append( + (chunk_info, out_selection_as_slices, chunk_selection_as_slices) + ) + return chunk_info_with_indices + + +def make_chunk_info_for_rust( + batch_info: Iterable[ + tuple[ByteGetter | ByteSetter, ArraySpec, SelectorTuple, SelectorTuple] + ], +) -> list[tuple[str, ChunkCoords, str, Any]]: + return list( + convert_chunk_to_primitive(byte_getter, chunk_spec) + for (byte_getter, chunk_spec, _, _) in batch_info ) diff --git a/src/chunk_item.rs b/src/chunk_item.rs new file mode 100644 index 0000000..c381f9b --- /dev/null +++ b/src/chunk_item.rs @@ -0,0 +1,185 @@ +use std::{num::NonZeroU64, sync::Arc}; + +use pyo3::{ + exceptions::{PyRuntimeError, PyValueError}, + types::{PySlice, PySliceMethods}, + Bound, PyErr, PyResult, +}; +use zarrs::{ + array::{ChunkRepresentation, DataType, FillValue}, + array_subset::ArraySubset, + metadata::v3::{array::data_type::DataTypeMetadataV3, MetadataV3}, + storage::{MaybeBytes, ReadableWritableListableStorageTraits, StorageError, StoreKey}, +}; + +use crate::utils::PyErrExt; + +pub(crate) type Raw<'a> = ( + // store path + String, + // shape + Vec, + // data type + String, + // fill value bytes + Vec, +); + +pub(crate) type RawWithIndices<'a> = ( + Raw<'a>, + // out selection + Vec>, + // chunk selection + Vec>, +); + +pub(crate) trait IntoItem: std::marker::Sized { + fn store_path(&self) -> &str; + fn into_item( + self, + store: Arc, + key: StoreKey, + shape: S, + ) -> PyResult; +} + +pub(crate) trait ChunksItem { + fn store(&self) -> Arc; + fn key(&self) -> &StoreKey; + fn representation(&self) -> &ChunkRepresentation; + + fn get(&self) -> Result { + self.store().get(self.key()) + } +} + +pub(crate) struct Basic { + store: Arc, + key: StoreKey, + representation: ChunkRepresentation, +} + +pub(crate) struct WithSubset { + pub item: Basic, + pub chunk_subset: ArraySubset, + pub subset: ArraySubset, +} + +impl ChunksItem for Basic { + fn store(&self) -> Arc { + self.store.clone() + } + fn key(&self) -> &StoreKey { + &self.key + } + fn representation(&self) -> &ChunkRepresentation { + &self.representation + } +} + +impl ChunksItem for WithSubset { + fn store(&self) -> Arc { + self.item.store.clone() + } + fn key(&self) -> &StoreKey { + &self.item.key + } + fn representation(&self) -> &ChunkRepresentation { + &self.item.representation + } +} + +impl<'a> IntoItem for Raw<'a> { + fn store_path(&self) -> &str { + &self.0 + } + fn into_item( + self, + store: Arc, + key: StoreKey, + (): (), + ) -> PyResult { + let (_, chunk_shape, dtype, fill_value) = self; + let representation = get_chunk_representation(chunk_shape, &dtype, fill_value)?; + Ok(Basic { + store, + key, + representation, + }) + } +} + +impl IntoItem for RawWithIndices<'_> { + fn store_path(&self) -> &str { + &self.0 .0 + } + fn into_item( + self, + store: Arc, + key: StoreKey, + shape: &[u64], + ) -> PyResult { + let (raw, selection, chunk_selection) = self; + let chunk_shape = raw.1.clone(); + let item = raw.into_item(store.clone(), key, ())?; + Ok(WithSubset { + item, + chunk_subset: selection_to_array_subset(&chunk_selection, &chunk_shape)?, + subset: selection_to_array_subset(&selection, shape)?, + }) + } +} + +fn get_chunk_representation( + chunk_shape: Vec, + dtype: &str, + fill_value: Vec, +) -> PyResult { + // Get the chunk representation + let data_type = + DataType::from_metadata(&DataTypeMetadataV3::from_metadata(&MetadataV3::new(dtype))) + .map_py_err::()?; + let chunk_shape = chunk_shape + .into_iter() + .map(|x| NonZeroU64::new(x).expect("chunk shapes should always be non-zero")) + .collect(); + let chunk_representation = + ChunkRepresentation::new(chunk_shape, data_type, FillValue::new(fill_value)) + .map_py_err::()?; + Ok(chunk_representation) +} + +fn slice_to_range(slice: &Bound<'_, PySlice>, length: isize) -> PyResult> { + let indices = slice.indices(length)?; + if indices.start < 0 { + Err(PyErr::new::( + "slice start must be greater than or equal to 0".to_string(), + )) + } else if indices.stop < 0 { + Err(PyErr::new::( + "slice stop must be greater than or equal to 0".to_string(), + )) + } else if indices.step != 1 { + Err(PyErr::new::( + "slice step must be equal to 1".to_string(), + )) + } else { + Ok(u64::try_from(indices.start)?..u64::try_from(indices.stop)?) + } +} + +fn selection_to_array_subset( + selection: &[Bound<'_, PySlice>], + shape: &[u64], +) -> PyResult { + if selection.is_empty() { + Ok(ArraySubset::new_with_shape(vec![1; shape.len()])) + } else { + let chunk_ranges = selection + .iter() + .zip(shape) + .map(|(selection, &shape)| slice_to_range(selection, isize::try_from(shape)?)) + .collect::>>()?; + Ok(ArraySubset::new_with_ranges(&chunk_ranges)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 8888a5f..3b988ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,28 +1,26 @@ #![warn(clippy::pedantic)] +use chunk_item::{ChunksItem, IntoItem}; use numpy::npyffi::PyArrayObject; -use numpy::{PyUntypedArray, PyUntypedArrayMethods}; +use numpy::{IntoPyArray, PyArray1, PyUntypedArray, PyUntypedArrayMethods}; use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::PySlice; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use rayon_iter_concurrent_limit::iter_concurrent_limit; use std::borrow::Cow; -use std::num::NonZeroU64; use std::sync::{Arc, Mutex}; use unsafe_cell_slice::UnsafeCellSlice; use zarrs::array::codec::{ ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder, StoragePartialDecoder, }; use zarrs::array::{ - copy_fill_value_into, update_array_bytes, ArrayBytes, ArraySize, ChunkRepresentation, - CodecChain, DataType, FillValue, + copy_fill_value_into, update_array_bytes, ArrayBytes, ArraySize, CodecChain, FillValue, }; use zarrs::array_subset::ArraySubset; -use zarrs::metadata::v3::array::data_type::DataTypeMetadataV3; use zarrs::metadata::v3::MetadataV3; use zarrs::storage::{ReadableWritableListableStorageTraits, StorageHandle, StoreKey}; +mod chunk_item; mod codec_pipeline_store_filesystem; #[cfg(test)] mod tests; @@ -70,190 +68,117 @@ impl CodecPipelineImpl { } } - fn collect_chunk_descriptions( + fn collect_chunk_descriptions, I, S: Copy>( &self, - chunk_descriptions: Vec, - shape: &[u64], - ) -> PyResult> { + chunk_descriptions: Vec, + shape: S, + ) -> PyResult> { chunk_descriptions .into_iter() - .map( - |(store_path, chunk_shape, dtype, fill_value, selection, chunk_selection)| { - let (store, path) = self.get_store_and_path(&store_path)?; - let key = StoreKey::new(path).map_py_err::()?; - Ok(ChunksItem { - store, - key, - chunk_subset: Self::selection_to_array_subset( - &chunk_selection, - &chunk_shape, - )?, - subset: Self::selection_to_array_subset(&selection, shape)?, - representation: Self::get_chunk_representation( - chunk_shape, - &dtype, - fill_value, - )?, - }) - }, - ) + .map(|raw| { + let (store, path) = self.get_store_and_path(raw.store_path())?; + let key = StoreKey::new(path).map_py_err::()?; + raw.into_item(store, key, shape) + }) .collect() } - fn get_chunk_representation( - chunk_shape: Vec, - dtype: &str, - fill_value: Vec, - ) -> PyResult { - // Get the chunk representation - let data_type = - DataType::from_metadata(&DataTypeMetadataV3::from_metadata(&MetadataV3::new(dtype))) - .map_py_err::()?; - let chunk_shape = chunk_shape - .into_iter() - .map(|x| NonZeroU64::new(x).expect("chunk shapes should always be non-zero")) - .collect(); - let chunk_representation = - ChunkRepresentation::new(chunk_shape, data_type, FillValue::new(fill_value)) - .map_py_err::()?; - Ok(chunk_representation) - } - - fn retrieve_chunk_bytes<'a>( - store: &dyn ReadableWritableListableStorageTraits, - key: &StoreKey, + fn retrieve_chunk_bytes<'a, I: ChunksItem>( + item: &I, codec_chain: &CodecChain, - chunk_representation: &ChunkRepresentation, codec_options: &CodecOptions, ) -> PyResult> { - let value_encoded = store.get(key).map_py_err::()?; + let value_encoded = item.get().map_py_err::()?; let value_decoded = if let Some(value_encoded) = value_encoded { let value_encoded: Vec = value_encoded.into(); // zero-copy in this case codec_chain - .decode(value_encoded.into(), chunk_representation, codec_options) + .decode(value_encoded.into(), item.representation(), codec_options) .map_py_err::()? } else { let array_size = ArraySize::new( - chunk_representation.data_type().size(), - chunk_representation.num_elements(), + item.representation().data_type().size(), + item.representation().num_elements(), ); - ArrayBytes::new_fill_value(array_size, chunk_representation.fill_value()) + ArrayBytes::new_fill_value(array_size, item.representation().fill_value()) }; Ok(value_decoded) } - fn store_chunk_bytes( - store: &dyn ReadableWritableListableStorageTraits, - key: &StoreKey, + fn store_chunk_bytes( + item: &I, codec_chain: &CodecChain, - chunk_representation: &ChunkRepresentation, value_decoded: ArrayBytes, codec_options: &CodecOptions, ) -> PyResult<()> { - if value_decoded.is_fill_value(chunk_representation.fill_value()) { - store.erase(key) + value_decoded + .validate( + item.representation().num_elements(), + item.representation().data_type().size(), + ) + .map_py_err::()?; + + if value_decoded.is_fill_value(item.representation().fill_value()) { + item.store().erase(item.key()) } else { let value_encoded = codec_chain - .encode(value_decoded, chunk_representation, codec_options) + .encode(value_decoded, item.representation(), codec_options) .map(Cow::into_owned) .map_py_err::()?; // Store the encoded chunk - store.set(key, value_encoded.into()) + item.store().set(item.key(), value_encoded.into()) } .map_py_err::() } - fn store_chunk_subset_bytes( - store: &dyn ReadableWritableListableStorageTraits, - key: &StoreKey, + fn store_chunk_subset_bytes( + item: &I, codec_chain: &CodecChain, - chunk_representation: &ChunkRepresentation, - chunk_subset_bytes: &ArrayBytes, + chunk_subset_bytes: ArrayBytes, chunk_subset: &ArraySubset, codec_options: &CodecOptions, ) -> PyResult<()> { - // Validate the inputs - chunk_subset_bytes - .validate( - chunk_subset.num_elements(), - chunk_representation.data_type().size(), - ) - .map_py_err::()?; - if !chunk_subset.inbounds(&chunk_representation.shape_u64()) { + if !chunk_subset.inbounds(&item.representation().shape_u64()) { return Err(PyErr::new::( "chunk subset is out of bounds".to_string(), )); } - // Retrieve the chunk - let chunk_bytes_old = Self::retrieve_chunk_bytes( - store, - key, - codec_chain, - chunk_representation, - codec_options, - )?; - - // Update the chunk - let chunk_bytes_new = unsafe { - // SAFETY: - // - chunk_bytes_old is compatible with the chunk shape and data type size (validated on decoding) - // - chunk_subset is compatible with chunk_subset_bytes and the data type size (validated above) - // - chunk_subset is within the bounds of the chunk shape (validated above) - // - output bytes and output subset bytes are compatible (same data type) - update_array_bytes( - chunk_bytes_old, - &chunk_representation.shape_u64(), - chunk_subset, - chunk_subset_bytes, - chunk_representation.data_type().size(), - ) - }; - - // Store the updated chunk - Self::store_chunk_bytes( - store, - key, - codec_chain, - chunk_representation, - chunk_bytes_new, - codec_options, - ) - } - - fn slice_to_range(slice: &Bound<'_, PySlice>, length: isize) -> PyResult> { - let indices = slice.indices(length)?; - if indices.start < 0 { - Err(PyErr::new::( - "slice start must be greater than or equal to 0".to_string(), - )) - } else if indices.stop < 0 { - Err(PyErr::new::( - "slice stop must be greater than or equal to 0".to_string(), - )) - } else if indices.step != 1 { - Err(PyErr::new::( - "slice step must be equal to 1".to_string(), - )) + if chunk_subset.start().iter().all(|&o| o == 0) + && chunk_subset.shape() == item.representation().shape_u64() + { + // Fast path if the chunk subset spans the entire chunk, no read required + Self::store_chunk_bytes(item, codec_chain, chunk_subset_bytes, codec_options) } else { - Ok(u64::try_from(indices.start)?..u64::try_from(indices.stop)?) - } - } + // Validate the chunk subset bytes + chunk_subset_bytes + .validate( + chunk_subset.num_elements(), + item.representation().data_type().size(), + ) + .map_py_err::()?; - fn selection_to_array_subset( - selection: &[Bound<'_, PySlice>], - shape: &[u64], - ) -> PyResult { - if selection.is_empty() { - Ok(ArraySubset::new_with_shape(vec![1; shape.len()])) - } else { - let chunk_ranges = selection - .iter() - .zip(shape) - .map(|(selection, &shape)| Self::slice_to_range(selection, isize::try_from(shape)?)) - .collect::>>()?; - Ok(ArraySubset::new_with_ranges(&chunk_ranges)) + // Retrieve the chunk + let chunk_bytes_old = Self::retrieve_chunk_bytes(item, codec_chain, codec_options)?; + + // Update the chunk + let chunk_bytes_new = unsafe { + // SAFETY: + // - chunk_bytes_old is compatible with the chunk shape and data type size (validated on decoding) + // - chunk_subset is compatible with chunk_subset_bytes and the data type size (validated above) + // - chunk_subset is within the bounds of the chunk shape (validated above) + // - output bytes and output subset bytes are compatible (same data type) + update_array_bytes( + chunk_bytes_old, + &item.representation().shape_u64(), + chunk_subset, + &chunk_subset_bytes, + item.representation().data_type().size(), + ) + }; + + // Store the updated chunk + Self::store_chunk_bytes(item, codec_chain, chunk_bytes_new, codec_options) } } @@ -304,29 +229,6 @@ impl CodecPipelineImpl { } } -type ChunksItemRaw<'a> = ( - // store path - String, - // shape - Vec, - // data type - String, - // fill value bytes - Vec, - // out selection - Vec>, - // chunk selection - Vec>, -); - -struct ChunksItem { - store: Arc, - key: StoreKey, - chunk_subset: ArraySubset, - subset: ArraySubset, - representation: ChunkRepresentation, -} - #[pymethods] impl CodecPipelineImpl { #[pyo3(signature = (metadata, validate_checksums=None, store_empty_chunks=None, concurrent_target=None))] @@ -360,10 +262,10 @@ impl CodecPipelineImpl { }) } - fn retrieve_chunks( + fn retrieve_chunks_and_apply_index( &self, py: Python, - chunk_descriptions: Vec, // FIXME: Ref / iterable? + chunk_descriptions: Vec, // FIXME: Ref / iterable? value: &Bound<'_, PyUntypedArray>, chunk_concurrent_limit: usize, ) -> PyResult<()> { @@ -381,13 +283,13 @@ impl CodecPipelineImpl { py.allow_threads(move || { let codec_options = &self.codec_options; - let update_chunk_subset = |item: ChunksItem| { + let update_chunk_subset = |item: chunk_item::WithSubset| { // See zarrs::array::Array::retrieve_chunk_subset_into if item.chunk_subset.start().iter().all(|&o| o == 0) - && item.chunk_subset.shape() == item.representation.shape_u64() + && item.chunk_subset.shape() == item.representation().shape_u64() { // See zarrs::array::Array::retrieve_chunk_into - let chunk_encoded = item.store.get(&item.key).map_py_err::()?; + let chunk_encoded = item.get().map_py_err::()?; if let Some(chunk_encoded) = chunk_encoded { // Decode the encoded data into the output buffer let chunk_encoded: Vec = chunk_encoded.into(); @@ -397,7 +299,7 @@ impl CodecPipelineImpl { // - item.subset is within the bounds of output_shape. self.codec_chain.decode_into( Cow::Owned(chunk_encoded), - &item.representation, + item.representation(), &output, &output_shape, &item.subset, @@ -412,8 +314,8 @@ impl CodecPipelineImpl { // - output is an array with output_shape elements of the item.representation data type, // - item.subset is within the bounds of output_shape. copy_fill_value_into( - item.representation.data_type(), - item.representation.fill_value(), + item.representation().data_type(), + item.representation().fill_value(), &output, &output_shape, &item.subset, @@ -422,15 +324,17 @@ impl CodecPipelineImpl { } } else { // Partially decode the chunk into the output buffer - let storage_handle = Arc::new(StorageHandle::new(item.store.clone())); + let storage_handle = Arc::new(StorageHandle::new(item.store().clone())); // NOTE: Normally a storage transformer would exist between the storage handle and the input handle // but zarr-python does not support them nor forward them to the codec pipeline - let input_handle = - Arc::new(StoragePartialDecoder::new(storage_handle, item.key)); + let input_handle = Arc::new(StoragePartialDecoder::new( + storage_handle, + item.key().clone(), + )); let partial_decoder = self .codec_chain .clone() - .partial_decoder(input_handle, &item.representation, codec_options) + .partial_decoder(input_handle, item.representation(), codec_options) .map_py_err::()?; unsafe { // SAFETY: @@ -460,10 +364,57 @@ impl CodecPipelineImpl { }) } - fn store_chunks( + fn retrieve_chunks<'py>( + &self, + py: Python<'py>, + chunk_descriptions: Vec, // FIXME: Ref / iterable? + chunk_concurrent_limit: usize, + ) -> PyResult>>> { + let chunk_descriptions = self.collect_chunk_descriptions(chunk_descriptions, ())?; + + let chunk_bytes = py.allow_threads(move || { + let codec_options = &self.codec_options; + + let get_chunk_subset = |item: chunk_item::Basic| { + let chunk_encoded = item.get().map_py_err::()?; + Ok(if let Some(chunk_encoded) = chunk_encoded { + let chunk_encoded: Vec = chunk_encoded.into(); + self.codec_chain + .decode( + Cow::Owned(chunk_encoded), + item.representation(), + codec_options, + ) + .map_py_err::()? + } else { + // The chunk is missing so we need to create one. + let num_elements = item.representation().num_elements(); + let data_type_size = item.representation().data_type().size(); + let chunk_shape = ArraySize::new(data_type_size, num_elements); + ArrayBytes::new_fill_value(chunk_shape, item.representation().fill_value()) + } + .into_fixed() + .map_py_err::()? + .into_owned()) + }; + iter_concurrent_limit!( + chunk_concurrent_limit, + chunk_descriptions, + map, + get_chunk_subset + ) + .collect::>>>() + })?; + Ok(chunk_bytes + .into_iter() + .map(|x| x.into_pyarray_bound(py)) + .collect()) + } + + fn store_chunks_with_indices( &self, py: Python, - chunk_descriptions: Vec, + chunk_descriptions: Vec, value: &Bound<'_, PyUntypedArray>, chunk_concurrent_limit: usize, ) -> PyResult<()> { @@ -493,21 +444,19 @@ impl CodecPipelineImpl { py.allow_threads(move || { let codec_options = &self.codec_options; - let store_chunk = |item: ChunksItem| match &input { + let store_chunk = |item: chunk_item::WithSubset| match &input { InputValue::Array(input) => { let chunk_subset_bytes = input .extract_array_subset( &item.subset, &input_shape, - item.representation.data_type(), + item.item.representation().data_type(), ) .map_py_err::()?; Self::store_chunk_subset_bytes( - item.store.as_ref(), - &item.key, + &item, &self.codec_chain, - &item.representation, - &chunk_subset_bytes, + chunk_subset_bytes, &item.chunk_subset, codec_options, ) @@ -515,18 +464,16 @@ impl CodecPipelineImpl { InputValue::Constant(constant_value) => { let chunk_subset_bytes = ArrayBytes::new_fill_value( ArraySize::new( - item.representation.data_type().size(), + item.representation().data_type().size(), item.chunk_subset.num_elements(), ), constant_value, ); Self::store_chunk_subset_bytes( - item.store.as_ref(), - &item.key, + &item, &self.codec_chain, - &item.representation, - &chunk_subset_bytes, + chunk_subset_bytes, &item.chunk_subset, codec_options, ) diff --git a/tests/conftest.py b/tests/conftest.py index 24e016c..1ac9d7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,13 @@ from zarr.storage import LocalStore, MemoryStore, ZipStore from zarr.storage.remote import RemoteStore -import zarrs_python # noqa: F401 +from zarrs_python.utils import ( # noqa: F401 + CollapsedDimensionError, + DiscontiguousArrayError, +) if TYPE_CHECKING: + from collections.abc import Iterable from typing import Any, Literal from zarr.abc.store import Store @@ -60,3 +64,92 @@ def array_fixture(request: pytest.FixtureRequest) -> npt.NDArray[Any]: .reshape(array_request.shape, order=array_request.order) .astype(array_request.dtype) ) + + +# tests that also fail with zarr-python's default codec pipeline +zarr_python_default_codec_pipeline_failures = [ + # ellipsis weirdness, need to report + "test_roundtrip[oindex-2d-contiguous_in_chunk_array-ellipsis]", + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-ellipsis]", + "test_roundtrip[vindex-2d-contiguous_in_chunk_array-ellipsis]", + "test_roundtrip[vindex-2d-discontinuous_in_chunk_array-ellipsis]", + "test_roundtrip[oindex-2d-across_chunks_indices_array-ellipsis]", + "test_roundtrip[vindex-2d-ellipsis-across_chunks_indices_array]", + "test_roundtrip[vindex-2d-across_chunks_indices_array-ellipsis]", + "test_roundtrip[vindex-2d-ellipsis-contiguous_in_chunk_array]", + "test_roundtrip[vindex-2d-ellipsis-discontinuous_in_chunk_array]", + "test_roundtrip_read_only_zarrs[oindex-2d-contiguous_in_chunk_array-ellipsis]", + "test_roundtrip_read_only_zarrs[oindex-2d-discontinuous_in_chunk_array-ellipsis]", + "test_roundtrip_read_only_zarrs[vindex-2d-contiguous_in_chunk_array-ellipsis]", + "test_roundtrip_read_only_zarrs[vindex-2d-discontinuous_in_chunk_array-ellipsis]", + "test_roundtrip_read_only_zarrs[oindex-2d-across_chunks_indices_array-ellipsis]", + "test_roundtrip_read_only_zarrs[vindex-2d-ellipsis-across_chunks_indices_array]", + "test_roundtrip_read_only_zarrs[vindex-2d-across_chunks_indices_array-ellipsis]", + "test_roundtrip_read_only_zarrs[vindex-2d-ellipsis-contiguous_in_chunk_array]", + "test_roundtrip_read_only_zarrs[vindex-2d-ellipsis-discontinuous_in_chunk_array]", + # need to investigate this one - it seems to fail with the default pipeline + # but it makes some sense that it succeeds with ours since we fall-back to numpy indexing + # in the case of a collapsed dimension + # "test_roundtrip_read_only_zarrs[vindex-2d-contiguous_in_chunk_array-contiguous_in_chunk_array]", +] + +zarrs_python_no_discontinuous_writes = [ + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-slice_in_chunk]", + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-slice_across_chunks]", + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-full_slice]", + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-int]", + "test_roundtrip[oindex-2d-slice_in_chunk-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-slice_across_chunks-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-full_slice-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-int-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-ellipsis-discontinuous_in_chunk_array]", + "test_roundtrip[vindex-2d-discontinuous_in_chunk_array-slice_in_chunk]", + "test_roundtrip[vindex-2d-discontinuous_in_chunk_array-slice_across_chunks]", + "test_roundtrip[vindex-2d-discontinuous_in_chunk_array-full_slice]", + "test_roundtrip[vindex-2d-discontinuous_in_chunk_array-int]", + "test_roundtrip[vindex-2d-slice_in_chunk-discontinuous_in_chunk_array]", + "test_roundtrip[vindex-2d-slice_across_chunks-discontinuous_in_chunk_array]", + "test_roundtrip[vindex-2d-full_slice-discontinuous_in_chunk_array]", + "test_roundtrip[vindex-2d-int-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-contiguous_in_chunk_array]", + "test_roundtrip[oindex-2d-contiguous_in_chunk_array-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-across_chunks_indices_array-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-discontinuous_in_chunk_array]", + "test_roundtrip[vindex-2d-contiguous_in_chunk_array-discontinuous_in_chunk_array]", + "test_roundtrip[vindex-2d-discontinuous_in_chunk_array-discontinuous_in_chunk_array]", + "test_roundtrip[oindex-2d-discontinuous_in_chunk_array-across_chunks_indices_array]", + "test_roundtrip[vindex-2d-discontinuous_in_chunk_array-contiguous_in_chunk_array]", + "test_roundtrip[oindex-1d-discontinuous_in_chunk_array]", + "test_roundtrip[vindex-1d-discontinuous_in_chunk_array]", +] + +# vindexing with two contiguous arrays would be converted to two slices but +# in numpy indexing actually requires dropping a dimension, which in turn boils +# down to integer indexing, which we can't do i.e., [np.array(1, 2), np.array(1, 2)] -> [slice(1, 3), slice(1, 3)] +# is not a correct conversion, and thus we don't support the write operation +zarrs_python_no_collapsed_dim = [ + "test_roundtrip[vindex-2d-contiguous_in_chunk_array-contiguous_in_chunk_array]" +] + + +def pytest_collection_modifyitems( + config: pytest.Config, items: Iterable[pytest.Item] +) -> None: + for item in items: + if item.name in zarr_python_default_codec_pipeline_failures: + xfail_marker = pytest.mark.xfail( + reason="This test fails with the zarr-python default codec pipeline." + ) + item.add_marker(xfail_marker) + if item.name in zarrs_python_no_discontinuous_writes: + xfail_marker = pytest.mark.xfail( + raises=DiscontiguousArrayError, + reason="zarrs discontinuous writes are not supported.", + ) + item.add_marker(xfail_marker) + if item.name in zarrs_python_no_collapsed_dim: + xfail_marker = pytest.mark.xfail( + raises=CollapsedDimensionError, + reason="zarrs vindexing with multiple contiguous arrays is not supported.", + ) + item.add_marker(xfail_marker) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 8efc5db..07b2016 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,6 +1,13 @@ #!/usr/bin/env python3 +import operator import tempfile +from collections.abc import Callable +from contextlib import contextmanager +from functools import reduce +from itertools import product +from pathlib import Path +from types import EllipsisType import numpy as np import pytest @@ -9,122 +16,216 @@ import zarrs_python # noqa: F401 +axis_size_ = 10 +chunk_size_ = axis_size_ // 2 +fill_value_ = 32767 +dimensionalities_ = list(range(1, 5)) + @pytest.fixture def fill_value() -> int: - return 32767 + return fill_value_ + + +non_numpy_indices = [ + pytest.param(slice(1, 3), id="slice_in_chunk"), + pytest.param(slice(1, 7), id="slice_across_chunks"), + pytest.param(2, id="int"), + pytest.param(slice(None), id="full_slice"), + pytest.param(Ellipsis, id="ellipsis"), +] + +numpy_indices = [ + pytest.param(np.array([1, 2]), id="contiguous_in_chunk_array"), + pytest.param(np.array([0, 3]), id="discontinuous_in_chunk_array"), + pytest.param(np.array([0, 6]), id="across_chunks_indices_array"), +] + +all_indices = numpy_indices + non_numpy_indices + +indexing_method_params = [ + pytest.param(lambda x: getattr(x, "oindex"), id="oindex"), + pytest.param(lambda x: x, id="vindex"), +] + + +def pytest_generate_tests(metafunc): + old_pipeline_path = zarr.config.get("codec_pipeline.path") + # need to set the codec pipeline to the zarrs pipeline because the autouse fixture doesn't apply here + zarr.config.set({"codec_pipeline.path": "zarrs_python.ZarrsCodecPipeline"}) + if "test_roundtrip" in metafunc.function.__name__: + arrs = [] + indices = [] + store_values = [] + indexing_methods = [] + ids = [] + for dimensionality in dimensionalities_: + indexers = non_numpy_indices if dimensionality > 2 else all_indices + for index_param_prod in product(indexers, repeat=dimensionality): + index = tuple(index_param.values[0] for index_param in index_param_prod) + # multi-ellipsis indexing is not supported + if sum(isinstance(i, EllipsisType) for i in index) > 1: + continue + for indexing_method_param in indexing_method_params: + arr = gen_arr(fill_value_, Path(tempfile.mktemp()), dimensionality) + indexing_method = indexing_method_param.values[0] + dimensionality_id = f"{dimensionality}d" + id = "-".join( + [indexing_method_param.id, dimensionality_id] + + [index_param.id for index_param in index_param_prod] + ) + ids.append(id) + store_values.append( + gen_store_values( + indexing_method, + index, + full_array((axis_size_,) * dimensionality), + ) + ) + indexing_methods.append(indexing_method) + indices.append(index) + arrs.append(arr) + # array is used as param name to prevent collision with arr fixture + metafunc.parametrize( + ["array", "index", "store_values", "indexing_method"], + zip(arrs, indices, store_values, indexing_methods), + ids=ids, + ) + zarr.config.set({"codec_pipeline.path": old_pipeline_path}) + + +def full_array(shape) -> np.ndarray: + return np.arange(reduce(operator.mul, shape, 1)).reshape(shape) + + +def gen_store_values( + indexing_method: Callable, + index: tuple[int | slice | np.ndarray | EllipsisType, ...], + full_array: np.ndarray, +) -> np.ndarray: + class smoke: + oindex = "oindex" + + def maybe_convert( + i: int | np.ndarray | slice | EllipsisType, axis: int + ) -> np.ndarray: + if isinstance(i, np.ndarray): + return i + if isinstance(i, slice): + return np.arange( + i.start if i.start is not None else 0, + i.stop if i.stop is not None else full_array.shape[axis], + ) + if isinstance(i, int): + return np.array([i]) + if isinstance(i, EllipsisType): + return np.arange(full_array.shape[axis]) + raise ValueError(f"Invalid index {i}") + + if not isinstance(index, EllipsisType) and indexing_method(smoke()) == "oindex": + index: tuple[np.ndarray, ...] = tuple( + maybe_convert(i, axis) for axis, i in enumerate(index) + ) + res = full_array[np.ix_(*index)] + # squeeze out extra dims from integer indexers + if all(i.shape == (1,) for i in index): + res = res.squeeze() + return res + res = res.squeeze( + axis=tuple(axis for axis, i in enumerate(index) if i.shape == (1,)) + ) + return res + return full_array[index] -@pytest.fixture -def chunks() -> tuple[int, ...]: - return (2, 2) +def gen_arr(fill_value, tmp_path, dimensionality) -> zarr.Array: + return zarr.create( + (axis_size_,) * dimensionality, + store=LocalStore(root=tmp_path / ".zarr", mode="w"), + chunks=(chunk_size_,) * dimensionality, + dtype=np.int16, + fill_value=fill_value, + codecs=[zarr.codecs.BytesCodec(), zarr.codecs.BloscCodec()], + ) -@pytest.fixture(params=[np.array([1, 2]), slice(1, 3)], ids=["array", "slice"]) -def indexer(request) -> slice | np.ndarray: +@pytest.fixture(params=dimensionalities_) +def dimensionality(request): return request.param -indexer_2 = indexer - - @pytest.fixture -def arr(fill_value, chunks) -> zarr.Array: - shape = (4, 4) - - tmp = tempfile.TemporaryDirectory() - return zarr.create( - shape, - store=LocalStore(root=tmp.name, mode="w"), - chunks=chunks, - dtype=np.int16, - fill_value=fill_value, - codecs=[zarr.codecs.BytesCodec(), zarr.codecs.BloscCodec()], - ) +def arr(dimensionality, tmp_path) -> zarr.Array: + return gen_arr(fill_value_, tmp_path, dimensionality) -def test_fill_value(arr: zarr.Array, fill_value: int): - assert np.all(arr[:] == fill_value) +def test_fill_value(arr: zarr.Array): + assert np.all(arr[:] == fill_value_) -def test_roundtrip_constant(arr: zarr.Array): +def test_constant(arr: zarr.Array): arr[:] = 42 assert np.all(arr[:] == 42) -def test_roundtrip_singleton(arr: zarr.Array): - arr[1, 1] = 42 - assert arr[1, 1] == 42 - assert arr[0, 0] != 42 +def test_singleton(arr: zarr.Array): + singleton_index = (1,) * len(arr.shape) + non_singleton_index = (0,) * len(arr.shape) + arr[singleton_index] = 42 + assert arr[singleton_index] == 42 + assert arr[non_singleton_index] != 42 -def test_roundtrip_full_array(arr: zarr.Array): - stored_values = np.arange(16).reshape(4, 4) +def test_full_array(arr: zarr.Array): + stored_values = full_array(arr.shape) arr[:] = stored_values assert np.all(arr[:] == stored_values) -def test_roundtrip_partial( - arr: zarr.Array, - indexer: slice | np.ndarray, - indexer_2: slice | np.ndarray, +def test_roundtrip( + array: zarr.Array, + store_values: np.ndarray, + index: tuple[int | slice | np.ndarray | EllipsisType, ...], + indexing_method: Callable, ): - if isinstance(indexer, np.ndarray) and isinstance(indexer_2, np.ndarray): - pytest.skip( - "indexing across two axes with arrays seems to have strange behavior even in normal zarr" - ) - stored_value = np.array([[-1, -2], [-3, -4]]) - arr[indexer, indexer_2] = stored_value - res = arr[indexer, indexer_2] + indexing_method(array)[index] = store_values + res = indexing_method(array)[index] assert np.all( - res == stored_value, + res == store_values, ), res -def test_roundtrip_1d_axis(arr: zarr.Array, indexer: slice | np.ndarray): - stored_value = np.array([-3, -4]) - arr[2, indexer] = stored_value - res = arr[2, indexer] - assert np.all(res == stored_value), res - - -def test_roundtrip_orthogonal_indexing( - arr: zarr.Array, indexer: slice | np.ndarray, indexer_2: np.ndarray | slice -): - stored_value = np.array([[-1, -2], [-3, -4]]) - arr.oindex[indexer, indexer_2] = stored_value - res = arr.oindex[indexer, indexer_2] - assert np.all(res == stored_value), res +@contextmanager +def use_zarr_default_codec_reader(): + zarr.config.set( + {"codec_pipeline.path": "zarr.codecs.pipeline.BatchedCodecPipeline"} + ) + yield + zarr.config.set({"codec_pipeline.path": "zarrs_python.ZarrsCodecPipeline"}) -def test_roundtrip_orthogonal_indexing_1d_axis( - arr: zarr.Array, indexer: slice | np.ndarray +def test_roundtrip_read_only_zarrs( + array: zarr.Array, + store_values: np.ndarray, + index: tuple[int | slice | np.ndarray | EllipsisType, ...], + indexing_method: Callable, ): - stored_value = np.array([-3, -4]) - arr.oindex[2, indexer] = stored_value - res = arr.oindex[2, indexer] - assert np.all(res == stored_value), res - - -def test_roundtrip_ellipsis_indexing_2d(arr: zarr.Array): - stored_value = np.arange(arr.size).reshape(arr.shape) - arr[...] = stored_value - res = arr[...] - assert np.all(res == stored_value), res - - -def test_roundtrip_ellipsis_indexing_1d(arr: zarr.Array): - stored_value = np.array([1, 2, 3, 4]) - arr[2, ...] = stored_value - res = arr[2, ...] - assert np.all(res == stored_value), res + with use_zarr_default_codec_reader(): + arr_default = zarr.open(array.store, mode="r+") + indexing_method(arr_default)[index] = store_values + res = indexing_method(zarr.open(array.store))[index] + assert np.all( + res == store_values, + ), res -def test_roundtrip_ellipsis_indexing_1d_invalid(arr: zarr.Array): +def test_ellipsis_indexing_invalid(arr: zarr.Array): + if len(arr.shape) <= 2: + pytest.skip( + "Ellipsis indexing works for 1D and 2D arrays in zarr-python despite a shape mismatch" + ) stored_value = np.array([1, 2, 3]) - with pytest.raises( - BaseException # TODO: ValueError, but this raises pyo3_runtime.PanicException - ): + with pytest.raises(ValueError): # noqa: PT011 # zarrs-python error: ValueError: operands could not be broadcast together with shapes (4,) (3,) # numpy error: ValueError: could not broadcast input array from shape (3,) into shape (4,) arr[2, ...] = stored_value diff --git a/zarrs_python b/zarrs_python new file mode 160000 index 0000000..66bca18 --- /dev/null +++ b/zarrs_python @@ -0,0 +1 @@ +Subproject commit 66bca1812b78ea630a24124fbca044318722487e