8 changes: 4 additions & 4 deletions src/chunk_item.rs

@@ -14,7 +14,7 @@ use zarrs::{
     storage::StoreKey,
 };
 
-use crate::utils::PyErrExt;
+use crate::map_py_err::PyErrStrExt as _;
 
 pub(crate) trait ChunksItem {
     fn key(&self) -> &StoreKey;
@@ -76,7 +76,7 @@ impl Basic {
         let fill_value: Bound<'_, PyAny> = chunk_spec.getattr("fill_value")?;
         let fill_value_bytes = fill_value_to_bytes(&dtype, &fill_value)?;
         Ok(Self {
-            key: StoreKey::new(path).map_py_err::<PyValueError>()?,
+            key: StoreKey::new(path).map_py_err_from_str::<PyValueError>()?,
             representation: get_chunk_representation(chunk_shape, &dtype, fill_value_bytes)?,
         })
     }
@@ -148,14 +148,14 @@ fn get_chunk_representation(
         &MetadataV3::new(dtype),
         zarrs::config::global_config().data_type_aliases_v3(),
     )
-    .map_py_err::<PyRuntimeError>()?;
+    .map_py_err_from_str::<PyRuntimeError>()?;
     let chunk_shape = chunk_shape
         .into_iter()
         .map(|x| NonZeroU64::new(x).expect("chunk shapes should always be non-zero"))
        .collect();
     let chunk_representation =
         ChunkRepresentation::new(chunk_shape, data_type, FillValue::new(fill_value))
-            .map_py_err::<PyValueError>()?;
+            .map_py_err_from_str::<PyValueError>()?;
     Ok(chunk_representation)
 }
 
5 changes: 3 additions & 2 deletions src/concurrency.rs

@@ -4,7 +4,8 @@ use zarrs::array::{
     RecommendedConcurrency,
 };
 
-use crate::{chunk_item::ChunksItem, utils::PyCodecErrExt as _, CodecPipelineImpl};
+use crate::map_py_err::PyErrExt as _;
+use crate::{chunk_item::ChunksItem, CodecPipelineImpl};
 
 pub trait ChunkConcurrentLimitAndCodecOptions {
     fn get_chunk_concurrent_limit_and_codec_options(
@@ -30,7 +31,7 @@ where
         let codec_concurrency = codec_pipeline_impl
            .codec_chain
            .recommended_concurrency(chunk_representation)
-            .map_codec_err()?;
+            .map_py_err()?;
 
         let min_concurrent_chunks =
             std::cmp::min(codec_pipeline_impl.chunk_concurrent_minimum, num_chunks);
138 changes: 83 additions & 55 deletions src/lib.rs

@@ -37,6 +37,7 @@ use zarrs::storage::{ReadableWritableListableStorage, StorageHandle, StoreKey};
 
 mod chunk_item;
 mod concurrency;
+mod map_py_err;
 mod runtime;
 mod store;
 #[cfg(test)]
@@ -45,8 +46,9 @@ mod utils;
 
 use crate::chunk_item::ChunksItem;
 use crate::concurrency::ChunkConcurrentLimitAndCodecOptions;
+use crate::map_py_err::{PyErrExt as _, PyErrStrExt as _};
 use crate::store::StoreConfig;
-use crate::utils::{PyCodecErrExt, PyErrExt as _, PyUntypedArrayExt as _};
+use crate::utils::PyUntypedArrayExt as _;
 
 // TODO: Use a OnceLock for store with get_or_try_init when stabilised?
 #[gen_stub_pyclass]
@@ -67,12 +69,15 @@ impl CodecPipelineImpl {
         codec_chain: &CodecChain,
         codec_options: &CodecOptions,
     ) -> PyResult<ArrayBytes<'a>> {
-        let value_encoded = self.store.get(item.key()).map_py_err::<PyRuntimeError>()?;
+        let value_encoded = self
+            .store
+            .get(item.key())
+            .map_py_err_from_str::<PyRuntimeError>()?;
         let value_decoded = if let Some(value_encoded) = value_encoded {
             let value_encoded: Vec<u8> = value_encoded.into(); // zero-copy in this case
             codec_chain
                 .decode(value_encoded.into(), item.representation(), codec_options)
-                .map_codec_err()?
+                .map_py_err()?
         } else {
             let array_size = ArraySize::new(
                 item.representation().data_type().size(),
@@ -95,20 +100,22 @@ impl CodecPipelineImpl {
                 item.representation().num_elements(),
                 item.representation().data_type().size(),
             )
-            .map_codec_err()?;
+            .map_py_err()?;
 
         if value_decoded.is_fill_value(item.representation().fill_value()) {
-            self.store.erase(item.key()).map_py_err::<PyRuntimeError>()
+            self.store
+                .erase(item.key())
+                .map_py_err_from_str::<PyRuntimeError>()
         } else {
             let value_encoded = codec_chain
                 .encode(value_decoded, item.representation(), codec_options)
                 .map(Cow::into_owned)
-                .map_codec_err()?;
+                .map_py_err()?;
 
             // Store the encoded chunk
             self.store
                 .set(item.key(), value_encoded.into())
-                .map_py_err::<PyRuntimeError>()
+                .map_py_err_from_str::<PyRuntimeError>()
         }
     }
 
@@ -135,7 +142,7 @@ impl CodecPipelineImpl {
         // Validate the chunk subset bytes
         chunk_subset_bytes
             .validate(chunk_subset.num_elements(), data_type_size)
-            .map_codec_err()?;
+            .map_py_err()?;
 
         // Retrieve the chunk
         let chunk_bytes_old = self.retrieve_chunk_bytes(item, codec_chain, codec_options)?;
@@ -148,7 +155,7 @@ impl CodecPipelineImpl {
             &chunk_subset_bytes,
             data_type_size,
         )
-        .map_codec_err()?;
+        .map_py_err()?;
 
         // Store the updated chunk
         self.store_chunk_bytes(item, codec_chain, chunk_bytes_new, codec_options)
@@ -169,7 +176,7 @@ impl CodecPipelineImpl {
         array_object
     }
 
-    fn nparray_to_slice<'a>(value: &'a Bound<'_, PyUntypedArray>) -> Result<&'a [u8], PyErr> {
+    fn nparray_to_slice<'a>(value: &'a Bound<'_, PyUntypedArray>) -> PyResult<&'a [u8]> {
         if !value.is_c_contiguous() {
             return Err(PyErr::new::<PyValueError, _>(
                 "input array must be a C contiguous array".to_string(),
@@ -188,7 +195,7 @@ impl CodecPipelineImpl {
 
     fn nparray_to_unsafe_cell_slice<'a>(
         value: &'a Bound<'_, PyUntypedArray>,
-    ) -> Result<UnsafeCellSlice<'a, u8>, PyErr> {
+    ) -> PyResult<UnsafeCellSlice<'a, u8>> {
         if !value.is_c_contiguous() {
             return Err(PyErr::new::<PyValueError, _>(
                 "input array must be a C contiguous array".to_string(),
@@ -204,6 +211,48 @@ impl CodecPipelineImpl {
         };
         Ok(UnsafeCellSlice::new(output))
     }
+
+    /// Assemble partial decoders in parallel
+    fn assemble_partial_decoders(
+        &self,
+        chunk_descriptions: &[chunk_item::WithSubset],
+        chunk_concurrent_limit: usize,
+        codec_options: &CodecOptions,
+    ) -> PyResult<HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>>> {
+        let partial_chunk_descriptions = chunk_descriptions
+            .iter()
+            .filter(|item| !(is_whole_chunk(item)))
+            .unique_by(|item| item.key())
+            .collect::<Vec<_>>();
+        let mut partial_decoder_cache: HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>> =
+            HashMap::new();
+        if !partial_chunk_descriptions.is_empty() {
+            let key_decoder_pairs = iter_concurrent_limit!(
+                chunk_concurrent_limit,
+                partial_chunk_descriptions,
+                map,
+                |item| {
+                    let storage_handle = Arc::new(StorageHandle::new(self.store.clone()));
+                    let input_handle =
+                        StoragePartialDecoder::new(storage_handle, item.key().clone());
+                    let partial_decoder = self
+                        .codec_chain
+                        .clone()
+                        .partial_decoder(
+                            Arc::new(input_handle),
+                            item.representation(),
+                            codec_options,
+                        )
+                        .map_py_err()?;
+                    Ok((item.key().clone(), partial_decoder))
+                }
+            )
+            .collect::<PyResult<Vec<_>>>()?;
+            partial_decoder_cache.extend(key_decoder_pairs);
+        }
+
+        Ok(partial_decoder_cache)
+    }
 }
 
 fn array_metadata_to_codec_metadata_v3(
@@ -238,6 +287,7 @@ fn array_metadata_to_codec_metadata_v3(
 #[gen_stub_pymethods]
 #[pymethods]
 impl CodecPipelineImpl {
+    #[allow(clippy::needless_pass_by_value)]
     #[pyo3(signature = (
         array_metadata,
         store_config,
@@ -257,11 +307,12 @@ impl CodecPipelineImpl {
         num_threads: Option<usize>,
     ) -> PyResult<Self> {
         let metadata: ArrayMetadata =
-            serde_json::from_str(array_metadata).map_py_err::<PyTypeError>()?;
+            serde_json::from_str(array_metadata).map_py_err_from_str::<PyTypeError>()?;
         let codec_metadata =
-            array_metadata_to_codec_metadata_v3(metadata).map_py_err::<PyTypeError>()?;
-        let codec_chain =
-            Arc::new(CodecChain::from_metadata(&codec_metadata).map_py_err::<PyTypeError>()?);
+            array_metadata_to_codec_metadata_v3(metadata).map_py_err_from_str::<PyTypeError>()?;
+        let codec_chain = Arc::new(
+            CodecChain::from_metadata(&codec_metadata).map_py_err_from_str::<PyTypeError>()?,
+        );
 
         let mut codec_options = CodecOptionsBuilder::new();
         if let Some(validate_checksums) = validate_checksums {
@@ -275,8 +326,9 @@ impl CodecPipelineImpl {
             chunk_concurrent_maximum.unwrap_or(rayon::current_num_threads());
         let num_threads = num_threads.unwrap_or(rayon::current_num_threads());
 
-        let store: ReadableWritableListableStorage =
-            (&store_config).try_into().map_py_err::<PyTypeError>()?;
+        let store: ReadableWritableListableStorage = (&store_config)
+            .try_into()
+            .map_py_err_from_str::<PyTypeError>()?;
 
         Ok(Self {
             store,
@@ -305,38 +357,12 @@ impl CodecPipelineImpl {
             return Ok(());
         };
 
-        // Assemble partial decoders ahead of time and in parallel
-        let partial_chunk_descriptions = chunk_descriptions
-            .iter()
-            .filter(|item| !(is_whole_chunk(item)))
-            .unique_by(|item| item.key())
-            .collect::<Vec<_>>();
-        let mut partial_decoder_cache: HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>> =
-            HashMap::new();
-        if !partial_chunk_descriptions.is_empty() {
-            let key_decoder_pairs = iter_concurrent_limit!(
-                chunk_concurrent_limit,
-                partial_chunk_descriptions,
-                map,
-                |item| {
-                    let storage_handle = Arc::new(StorageHandle::new(self.store.clone()));
-                    let input_handle =
-                        StoragePartialDecoder::new(storage_handle, item.key().clone());
-                    let partial_decoder = self
-                        .codec_chain
-                        .clone()
-                        .partial_decoder(
-                            Arc::new(input_handle),
-                            item.representation(),
-                            &codec_options,
-                        )
-                        .map_codec_err()?;
-                    Ok((item.key().clone(), partial_decoder))
-                }
-            )
-            .collect::<PyResult<Vec<_>>>()?;
-            partial_decoder_cache.extend(key_decoder_pairs);
-        }
+        // Assemble partial decoders ahead of time
+        let partial_decoder_cache = self.assemble_partial_decoders(
+            chunk_descriptions.as_ref(),
+            chunk_concurrent_limit,
+            &codec_options,
+        )?;
 
         py.allow_threads(move || {
             // FIXME: the `decode_into` methods only support fixed length data types.
@@ -359,20 +385,22 @@ impl CodecPipelineImpl {
                         .data_type()
                         .fixed_size()
                         .ok_or("variable length data type not supported")
-                        .map_py_err::<PyTypeError>()?,
+                        .map_py_err_from_str::<PyTypeError>()?,
                     &output_shape,
                     subset,
                 )
-                .map_py_err::<PyRuntimeError>()?
+                .map_py_err_from_str::<PyRuntimeError>()?
             };
 
             // See zarrs::array::Array::retrieve_chunk_subset_into
             if chunk_subset.start().iter().all(|&o| o == 0)
                 && chunk_subset.shape() == item.representation().shape_u64()
             {
                 // See zarrs::array::Array::retrieve_chunk_into
-                if let Some(chunk_encoded) =
-                    self.store.get(item.key()).map_py_err::<PyRuntimeError>()?
+                if let Some(chunk_encoded) = self
+                    .store
+                    .get(item.key())
+                    .map_py_err_from_str::<PyRuntimeError>()?
                 {
                     // Decode the encoded data into the output buffer
                     let chunk_encoded: Vec<u8> = chunk_encoded.into();
@@ -401,7 +429,7 @@ impl CodecPipelineImpl {
                         &codec_options,
                     )
                 }
-                .map_codec_err()
+                .map_py_err()
             };
 
             iter_concurrent_limit!(
@@ -454,7 +482,7 @@ impl CodecPipelineImpl {
                 &input_shape,
                 item.item.representation().data_type(),
             )
-            .map_codec_err()?;
+            .map_py_err()?;
             self.store_chunk_subset_bytes(
                 &item,
                 &self.codec_chain,
 
36 changes: 36 additions & 0 deletions src/map_py_err.rs

@@ -0,0 +1,36 @@
+use std::fmt::Display;
+
+use pyo3::{PyErr, PyResult, PyTypeInfo};
+use zarrs::array::codec::CodecError;
+
+pub(crate) trait PyErrStrExt<T> {
+    fn map_py_err_from_str<PE: PyTypeInfo>(self) -> PyResult<T>;
+}
+
+impl<T, E: Display> PyErrStrExt<T> for Result<T, E> {
+    fn map_py_err_from_str<PE: PyTypeInfo>(self) -> PyResult<T> {
+        self.map_err(|e| PyErr::new::<PE, _>(format!("{e}")))
+    }
+}
+
+pub(crate) trait PyErrExt<T> {
+    fn map_py_err(self) -> PyResult<T>;
+}
+
+impl<T> PyErrExt<T> for Result<T, CodecError> {
+    fn map_py_err(self) -> PyResult<T> {
+        // see https://docs.python.org/3/library/exceptions.html#exception-hierarchy
+        self.map_err(|e| match e {
+            // requested indexing operation doesn’t match shape
+            CodecError::IncompatibleIndexer(_)
+            | CodecError::IncompatibleDimensionalityError(_)
+            | CodecError::InvalidByteRangeError(_) => {
+                PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!("{e}"))
+            }
+            // some pipe, file, or subprocess failed
+            CodecError::IOError(_) => PyErr::new::<pyo3::exceptions::PyOSError, _>(format!("{e}")),
+            // all the rest: some unknown runtime problem
+            e => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}")),
+        })
+    }
+}
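
Note on the new error-mapping API: `PyErrStrExt::map_py_err_from_str` converts any `Display`-able error into a caller-chosen Python exception type, while `PyErrExt::map_py_err` is specific to `CodecError` and derives the exception type from the error variant (PyIndexError, PyOSError, or PyRuntimeError). A minimal usage sketch follows; the `decode_chunk` helper and `example` function are hypothetical stand-ins for the call sites in this PR, not part of the change:

    use pyo3::exceptions::PyValueError;
    use pyo3::PyResult;
    use zarrs::array::codec::CodecError;
    use zarrs::storage::StoreKey;

    use crate::map_py_err::{PyErrExt as _, PyErrStrExt as _};

    // Hypothetical stand-in for any codec-chain call returning Result<_, CodecError>.
    fn decode_chunk() -> Result<Vec<u8>, CodecError> {
        Ok(vec![0u8; 4])
    }

    // Hypothetical caller mirroring the call sites changed in this PR.
    fn example(path: String) -> PyResult<Vec<u8>> {
        // Display-based mapping: the caller names the Python exception type.
        let _key = StoreKey::new(path).map_py_err_from_str::<PyValueError>()?;
        // CodecError-based mapping: the exception type follows from the variant.
        decode_chunk().map_py_err()
    }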
5 changes: 3 additions & 2 deletions src/store.rs

@@ -10,7 +10,8 @@ use zarrs::storage::{
     storage_adapter::async_to_sync::AsyncToSyncStorageAdapter, ReadableWritableListableStorage,
 };
 
-use crate::{runtime::tokio_block_on, utils::PyErrExt};
+use crate::map_py_err::PyErrStrExt as _;
+use crate::runtime::tokio_block_on;
 
 mod filesystem;
 mod http;
@@ -79,7 +80,7 @@ fn opendal_builder_to_sync_store<B: Builder>(
     builder: B,
 ) -> PyResult<ReadableWritableListableStorage> {
     let operator = opendal::Operator::new(builder)
-        .map_py_err::<PyValueError>()?
+        .map_py_err_from_str::<PyValueError>()?
        .finish();
    let store = Arc::new(zarrs_opendal::AsyncOpendalStore::new(operator));
    let store = Arc::new(AsyncToSyncStorageAdapter::new(store, tokio_block_on()));