Skip to content
39 changes: 17 additions & 22 deletions python/zarrs/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ from enum import Enum, auto
import numpy
import numpy.typing

class Basic:
    """Description of a single chunk handed to the codec pipeline.

    Instances are built from one chunk's byte getter/setter and its array
    spec, and are accepted by ``CodecPipelineImpl.retrieve_chunks`` and as
    the ``item`` of a ``WithSubset``.

    NOTE(review): which fields are read from ``byte_interface`` and
    ``chunk_spec`` is decided by the native implementation and is not
    visible from this stub -- confirm there.
    """

    def __new__(
        cls,
        byte_interface: typing.Any,
        chunk_spec: typing.Any,
    ): ...

class CodecPipelineImpl:
def __new__(
cls,
Expand All @@ -20,34 +24,15 @@ class CodecPipelineImpl:
): ...
def retrieve_chunks_and_apply_index(
self,
chunk_descriptions: typing.Sequence[
tuple[
tuple[
StoreConfig, str, typing.Sequence[int], str, typing.Sequence[int]
],
typing.Sequence[slice],
typing.Sequence[slice],
]
],
chunk_descriptions: typing.Sequence[WithSubset],
value: numpy.NDArray[typing.Any],
) -> None: ...
def retrieve_chunks(
self,
chunk_descriptions: typing.Sequence[
tuple[StoreConfig, str, typing.Sequence[int], str, typing.Sequence[int]]
],
self, chunk_descriptions: typing.Sequence[Basic]
) -> list[numpy.typing.NDArray[numpy.uint8]]: ...
def store_chunks_with_indices(
self,
chunk_descriptions: typing.Sequence[
tuple[
tuple[
StoreConfig, str, typing.Sequence[int], str, typing.Sequence[int]
],
typing.Sequence[slice],
typing.Sequence[slice],
]
],
chunk_descriptions: typing.Sequence[WithSubset],
value: numpy.NDArray[typing.Any],
) -> None: ...

Expand All @@ -57,6 +42,16 @@ class FilesystemStoreConfig:
# Configuration for an HTTP-backed store (counterpart of the ``Http``
# variant of ``StoreConfig``).
class HttpStoreConfig:
    # NOTE(review): presumably the URL of the remote store -- confirm
    # against the native implementation.
    endpoint: str

class WithSubset:
    """A chunk description paired with the selections to apply to it.

    Accepted by ``CodecPipelineImpl.retrieve_chunks_and_apply_index`` and
    ``CodecPipelineImpl.store_chunks_with_indices``.

    Parameters
    ----------
    item
        The ``Basic`` description of the chunk.
    chunk_subset
        One slice per dimension selecting the region inside the chunk.
    subset
        One slice per dimension selecting the matching region inside the
        array passed as ``value`` to the pipeline call.
    shape
        Shape of that ``value`` array.
    """

    def __new__(
        cls,
        item: Basic,
        chunk_subset: typing.Sequence[slice],
        subset: typing.Sequence[slice],
        shape: typing.Sequence[int],
    ): ...

class StoreConfig(Enum):
Filesystem = auto()
Http = auto()
38 changes: 16 additions & 22 deletions python/zarrs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from typing import TYPE_CHECKING, TypedDict

import numpy as np
from zarr.abc.codec import (
Codec,
CodecPipeline,
)
from zarr.abc.codec import Codec, CodecPipeline
from zarr.core.config import config

if TYPE_CHECKING:
Expand All @@ -18,7 +15,7 @@

from zarr.abc.store import ByteGetter, ByteSetter
from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import Buffer, NDBuffer
from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.common import ChunkCoords
from zarr.core.indexing import SelectorTuple
Expand Down Expand Up @@ -120,30 +117,28 @@ async def read(
batch_info: Iterable[
tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]
],
out: NDBuffer,
out: NDBuffer, # type: ignore
drop_axes: tuple[int, ...] = (), # FIXME: unused
) -> None:
out = out.as_ndarray_like() # FIXME: Error if array is not in host memory
# FIXME: Error if array is not in host memory
out: NDArrayLike = out.as_ndarray_like()
if not out.dtype.isnative:
raise RuntimeError("Non-native byte order not supported")
try:
chunks_desc = make_chunk_info_for_rust_with_indices(batch_info, drop_axes)
index_in_rust = True
chunks_desc = make_chunk_info_for_rust_with_indices(
batch_info, drop_axes, out.shape
)
except (DiscontiguousArrayError, CollapsedDimensionError):
chunks_desc = make_chunk_info_for_rust(batch_info)
index_in_rust = False
if index_in_rust:
else:
await asyncio.to_thread(
self.impl.retrieve_chunks_and_apply_index,
chunks_desc,
out,
)
return None
chunks = await asyncio.to_thread(self.impl.retrieve_chunks, chunks_desc)
for chunk, chunk_info in zip(chunks, batch_info):
out_selection = chunk_info[3]
selection = chunk_info[2]
spec = chunk_info[1]
for chunk, (_, spec, selection, out_selection) in zip(chunks, batch_info):
chunk_reshaped = chunk.view(spec.dtype).reshape(spec.shape)
chunk_selected = chunk_reshaped[selection]
if drop_axes:
Expand All @@ -155,18 +150,17 @@ async def write(
batch_info: Iterable[
tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]
],
value: NDBuffer,
value: NDBuffer, # type: ignore
drop_axes: tuple[int, ...] = (),
) -> None:
value = value.as_ndarray_like() # FIXME: Error if array is not in host memory
# FIXME: Error if array is not in host memory
value: NDArrayLike | np.ndarray = value.as_ndarray_like()
if not value.dtype.isnative:
value = np.ascontiguousarray(value, dtype=value.dtype.newbyteorder("="))
elif not value.flags.c_contiguous:
value = np.ascontiguousarray(value)
chunks_desc = make_chunk_info_for_rust_with_indices(batch_info, drop_axes)
await asyncio.to_thread(
self.impl.store_chunks_with_indices,
chunks_desc,
value,
chunks_desc = make_chunk_info_for_rust_with_indices(
batch_info, drop_axes, value.shape
)
await asyncio.to_thread(self.impl.store_chunks_with_indices, chunks_desc, value)
return None
44 changes: 20 additions & 24 deletions python/zarrs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import operator
import os
from functools import reduce
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import numpy as np
from zarr.core.indexing import SelectorTuple, is_integer

from zarrs._internal import Basic, WithSubset

if TYPE_CHECKING:
from collections.abc import Iterable
from types import EllipsisType

from zarr.abc.store import ByteGetter, ByteSetter, Store
from zarr.abc.store import ByteGetter, ByteSetter
from zarr.core.array_spec import ArraySpec
from zarr.core.common import ChunkCoords


# adapted from https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor
Expand Down Expand Up @@ -62,18 +63,6 @@ def selector_tuple_to_slice_selection(selector_tuple: SelectorTuple) -> list[sli
return make_slice_selection(selector_tuple)


def convert_chunk_to_primitive(
    byte_interface: ByteGetter | ByteSetter, chunk_spec: ArraySpec
) -> tuple[Store, str, ChunkCoords, str, Any]:
    """Flatten one chunk's byte interface and spec into a plain tuple.

    Returns ``(store, path, shape, dtype, fill_value)``, where ``store`` and
    ``path`` come from ``byte_interface``, ``shape`` comes from
    ``chunk_spec``, ``dtype`` is ``str(chunk_spec.dtype)`` and
    ``fill_value`` is the result of ``chunk_spec.fill_value.tobytes()``.
    """
    # Where the chunk lives.
    store = byte_interface.store
    path = byte_interface.path

    # What the chunk looks like, reduced to primitive values.
    shape = chunk_spec.shape
    dtype_name = str(chunk_spec.dtype)
    fill_bytes = chunk_spec.fill_value.tobytes()

    return store, path, shape, dtype_name, fill_bytes


def resulting_shape_from_index(
array_shape: tuple[int, ...],
index_tuple: tuple[int | slice | EllipsisType | np.ndarray],
Expand Down Expand Up @@ -150,10 +139,12 @@ def make_chunk_info_for_rust_with_indices(
tuple[ByteGetter | ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]
],
drop_axes: tuple[int, ...],
) -> list[tuple[tuple[Store, str, ChunkCoords, str, Any], list[slice], list[slice]]]:
chunk_info_with_indices = []
shape: tuple[int, ...],
) -> list[WithSubset]:
shape = shape if shape else (1,) # constant array
chunk_info_with_indices: list[WithSubset] = []
for byte_getter, chunk_spec, chunk_selection, out_selection in batch_info:
chunk_info = convert_chunk_to_primitive(byte_getter, chunk_spec)
chunk_info = Basic(byte_getter, chunk_spec)
out_selection_as_slices = selector_tuple_to_slice_selection(out_selection)
chunk_selection_as_slices = selector_tuple_to_slice_selection(chunk_selection)
shape_chunk_selection_slices = get_shape_for_selector(
Expand All @@ -170,7 +161,12 @@ def make_chunk_info_for_rust_with_indices(
f"{shape_chunk_selection} != {shape_chunk_selection_slices}"
)
chunk_info_with_indices.append(
(chunk_info, out_selection_as_slices, chunk_selection_as_slices)
WithSubset(
chunk_info,
chunk_subset=chunk_selection_as_slices,
subset=out_selection_as_slices,
shape=shape,
)
)
return chunk_info_with_indices

Expand All @@ -179,8 +175,8 @@ def make_chunk_info_for_rust(
batch_info: Iterable[
tuple[ByteGetter | ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]
],
) -> list[tuple[Store, str, ChunkCoords, str, Any]]:
return list(
convert_chunk_to_primitive(byte_getter, chunk_spec)
for (byte_getter, chunk_spec, _, _) in batch_info
)
) -> list[Basic]:
return [
Basic(byte_interface, chunk_spec)
for (byte_interface, chunk_spec, _, _) in batch_info
]
Loading
Loading