Skip to content

Commit

Permalink
Implement structured arrays as Zarr group of arrays (#603)
Browse files Browse the repository at this point in the history
* Implement structured arrays as Zarr group of arrays

* Update JAX tests to include ones that used to use structured arrays
  • Loading branch information
tomwhite authored Oct 31, 2024
1 parent 2eead17 commit c372711
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/jax-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ jobs:
- name: Run tests
run: |
# exclude tests that rely on structured types since JAX doesn't support these
pytest -k "not argmax and not argmin and not mean and not std and not var and not apply_reduction and not broadcast_trick and not groupby and not object_dtype"
# exclude a few tests that don't work on JAX
pytest -k "not broadcast_trick and not object_dtype"
env:
CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
JAX_ENABLE_X64: True
7 changes: 2 additions & 5 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,9 @@ def max(x, /, *, axis=None, keepdims=False, split_every=None):
def mean(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
# This implementation uses a Zarr group of two arrays to store a
# pair of fields needed to keep per-chunk counts and totals for computing
# the mean. Structured arrays are row-based, so are less efficient than
# regular arrays, but for a function that reduces the amount of data stored,
# this is usually OK. An alternative would be to add support for multiple
# outputs.
# the mean.
dtype = x.dtype
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype)
Expand Down
2 changes: 1 addition & 1 deletion cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None:
out_chunk_key = key_to_slices(
out_coords_tuple, config.writes_list[i].array, config.writes_list[i].chunks
)
if isinstance(result, dict): # structured array with named fields
if isinstance(result, dict): # group of arrays with named fields
for k, v in result.items():
v = backend_array_to_numpy_array(v)
config.writes_list[i].open().set_basic_selection(
Expand Down
59 changes: 49 additions & 10 deletions cubed/storage/backends/zarr_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@
from numcodecs.registry import get_codec

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
from cubed.utils import join_path


class ZarrArrayGroup(dict):
def __init__(
self,
shape: Optional[T_Shape] = None,
dtype: Optional[T_DType] = None,
chunks: Optional[T_RegularChunks] = None,
):
dict.__init__(self)
self.shape = shape
self.dtype = dtype
self.chunks = chunks

def __getitem__(self, key):
if isinstance(key, str):
return super().__getitem__(key)
return {field: zarray[key] for field, zarray in self.items()}

def set_basic_selection(self, selection, value, fields=None):
self[fields][selection] = value


def open_zarr_array(
Expand All @@ -20,13 +42,30 @@ def open_zarr_array(
if isinstance(compressor, dict):
compressor = get_codec(compressor)

return zarr.open_array(
store,
mode=mode,
shape=shape,
dtype=dtype,
chunks=chunks,
path=path,
compressor=compressor,
**kwargs,
)
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
return zarr.open_array(
store,
mode=mode,
shape=shape,
dtype=dtype,
chunks=chunks,
path=path,
compressor=compressor,
**kwargs,
)
else:
ret = ZarrArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
for field in dtype.fields:
field_dtype, _ = dtype.fields[field]
field_path = field if path is None else join_path(path, field)
ret[field] = zarr.open_array(
store,
mode=mode,
shape=shape,
dtype=field_dtype,
chunks=chunks,
path=field_path,
compressor=compressor,
**kwargs,
)
return ret
2 changes: 1 addition & 1 deletion cubed/tests/test_gufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def stats(x):
r = np.random.normal(size=(10, 20, 30))
a = cubed.from_array(r, chunks=(5, 5, 30), spec=spec)
actual = apply_gufunc(stats, "(i)->()", a, output_dtypes="f", vectorize=vectorize)
expected = np.mean(r, axis=-1, dtype=np.float32)
expected = nxp.mean(r, axis=-1, dtype=np.float32)

assert actual.compute().shape == expected.shape
assert_allclose(actual.compute(), expected)
Expand Down

0 comments on commit c372711

Please sign in to comment.