Skip to content

[WIP]: Force Zarr coordinate reads to be on the host #10079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -179,14 +179,31 @@ def encode_zarr_attr_value(value):
return encoded


def _is_coordinate_variable(zarr_array, name):
if _zarr_v3():
if zarr_array.metadata.zarr_format == 2:
is_coordinate = name in zarr_array.metadata.attributes.get(
"_ARRAY_DIMENSIONS", []
)
else:
is_coordinate = name in (zarr_array.metadata.dimension_names or [])
else:
is_coordinate = name in zarr_array.attrs.get("_ARRAY_DIMENSIONS", [])
return is_coordinate


class ZarrArrayWrapper(BackendArray):
__slots__ = ("_array", "dtype", "shape")
__slots__ = ("_array", "coords_buffer_prototype", "dtype", "is_coordinate", "shape")

def __init__(self, zarr_array):
def __init__(
self, zarr_array, is_coordinate: bool, coords_buffer_prototype: Any | None
):
# some callers attempt to evaluate an array if an `array` property exists on the object.
# we prefix with _ to avoid this inference.
self._array = zarr_array
self.shape = self._array.shape
self.is_coordinate = is_coordinate
self.coords_buffer_prototype = coords_buffer_prototype

# preserve vlen string object dtype (GH 7328)
if (
@@ -210,7 +227,14 @@ def _vindex(self, key):
return self._array.vindex[key]

def _getitem(self, key):
return self._array[key]
kwargs = {}
if _zarr_v3():
if self.is_coordinate:
prototype = self.coords_buffer_prototype
else:
prototype = None
kwargs["prototype"] = prototype
return self._array.get_basic_selection(key, **kwargs)

def __getitem__(self, key):
array = self._array
@@ -605,6 +629,7 @@ class ZarrStore(AbstractWritableDataStore):
"_cache_members",
"_close_store_on_close",
"_consolidate_on_close",
"_coords_buffer_prototype",
"_group",
"_members",
"_mode",
@@ -636,6 +661,7 @@ def open_store(
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
):
(
zarr_group,
@@ -668,6 +694,7 @@ def open_store(
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_members=cache_members,
coords_buffer_prototype=coords_buffer_prototype,
)
for group in group_paths
}
@@ -691,6 +718,7 @@ def open_group(
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
):
(
zarr_group,
@@ -722,6 +750,7 @@ def open_group(
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_members,
coords_buffer_prototype,
)

def __init__(
@@ -736,6 +765,7 @@ def __init__(
close_store_on_close: bool = False,
use_zarr_fill_value_as_mask=None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
):
self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
@@ -751,6 +781,14 @@ def __init__(
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask
self._cache_members: bool = cache_members
self._members: dict[str, ZarrArray | ZarrGroup] = {}
if _zarr_v3() and coords_buffer_prototype is None:
# Once zarr-v3 is required we can just have this as the default
# https://github.com/zarr-developers/zarr-python/issues/2871
# Use the public API once available
from zarr.core.buffer.cpu import buffer_prototype

coords_buffer_prototype = buffer_prototype
self._coords_buffer_prototype = coords_buffer_prototype

if self._cache_members:
# initialize the cache
@@ -809,7 +847,15 @@ def ds(self):

def open_store_variable(self, name):
zarr_array = self.members[name]
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
is_coordinate = _is_coordinate_variable(zarr_array, name)

data = indexing.LazilyIndexedArray(
ZarrArrayWrapper(
zarr_array,
is_coordinate=is_coordinate,
coords_buffer_prototype=self._coords_buffer_prototype,
)
)
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
zarr_array, DIMENSION_KEY, try_nczarr
@@ -1332,6 +1378,7 @@ def open_zarr(
use_zarr_fill_value_as_mask=None,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
coords_buffer_prototype: Any | None = None,
**kwargs,
):
"""Load and decode a dataset from a Zarr store.
@@ -1442,6 +1489,12 @@ def open_zarr(
chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg.
Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to
:py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
coords_buffer_prototype : zarr.buffer.BufferPrototype, optional
The buffer prototype to use for loading coordinate arrays. Zarr offers control over
which device's memory buffers are read into. By default, xarray will always load
*coordinate* buffers into host (CPU) memory, regardless of the global zarr
configuration. To override this behavior, explicitly pass the buffer prototype
to use for coordinates here.

Returns
-------
@@ -1485,6 +1538,7 @@ def open_zarr(
"storage_options": storage_options,
"zarr_version": zarr_version,
"zarr_format": zarr_format,
"coords_buffer_prototype": coords_buffer_prototype,
}

ds = open_dataset(
@@ -1557,6 +1611,7 @@ def open_dataset(
engine=None,
use_zarr_fill_value_as_mask=None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
) -> Dataset:
filename_or_obj = _normalize_path(filename_or_obj)
if not store:
@@ -1573,6 +1628,7 @@ def open_dataset(
use_zarr_fill_value_as_mask=None,
zarr_format=zarr_format,
cache_members=cache_members,
coords_buffer_prototype=coords_buffer_prototype,
)

store_entrypoint = StoreBackendEntrypoint()
@@ -1608,6 +1664,7 @@ def open_datatree(
storage_options=None,
zarr_version=None,
zarr_format=None,
coords_buffer_prototype: Any | None = None,
) -> DataTree:
filename_or_obj = _normalize_path(filename_or_obj)
groups_dict = self.open_groups_as_dict(
@@ -1627,6 +1684,7 @@ def open_datatree(
storage_options=storage_options,
zarr_version=zarr_version,
zarr_format=zarr_format,
coords_buffer_prototype=coords_buffer_prototype,
)

return datatree_from_dict_with_io_cleanup(groups_dict)
@@ -1650,6 +1708,7 @@ def open_groups_as_dict(
storage_options=None,
zarr_version=None,
zarr_format=None,
coords_buffer_prototype: Any | None = None,
) -> dict[str, Dataset]:
from xarray.core.treenode import NodePath

@@ -1672,6 +1731,7 @@ def open_groups_as_dict(
storage_options=storage_options,
zarr_version=zarr_version,
zarr_format=zarr_format,
coords_buffer_prototype=coords_buffer_prototype,
)

groups_dict = {}
31 changes: 31 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -3766,6 +3766,37 @@
xr.open_zarr(store=store, zarr_version=2, zarr_format=3)


@requires_zarr
def test_coords_buffer_prototype() -> None:
pytest.importorskip("zarr", minversion="3")

from zarr.core.buffer import cpu
from zarr.core.buffer.core import BufferPrototype

counter = 0

class Buffer(cpu.Buffer):
def __init__(self, *args, **kwargs):
nonlocal counter
counter += 1
super().__init__(*args, **kwargs)

class NDBuffer(cpu.NDBuffer):
def __init__(self, *args, **kwargs):
nonlocal counter
counter += 1
super().__init__(*args, **kwargs)

prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)

ds = create_test_data()
store = KVStore()
# type-ignore for zarr v2/v3 compat, even though this test is skipped for v2
ds.to_zarr(store=store, zarr_format=3) # type: ignore[call-overload, unused-ignore]
xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) # type: ignore[arg-type, unused-ignore]
assert counter > 0


@requires_scipy
class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only):
engine: T_NetcdfEngine = "scipy"
@@ -4185,7 +4216,7 @@
fx.create_dataset(k, data=v)
with pytest.warns(UserWarning, match="The 'phony_dims' kwarg"):
with xr.open_dataset(tmp_file, engine="h5netcdf", group="bar") as ds:
assert ds.dims == {

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / macos-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / macos-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / ubuntu-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / ubuntu-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.
"phony_dim_0": 5,
"phony_dim_1": 5,
"phony_dim_2": 5,

Unchanged files with check annotations Beta

attrs: _AttrsLike = None,
):
self._data = data
self._dims = self._parse_dimensions(dims)

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / ubuntu-latest py3.10 bare-minimum

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / ubuntu-latest py3.10 bare-minimum

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / macos-latest py3.10

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / macos-latest py3.10

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / ubuntu-latest py3.10

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / ubuntu-latest py3.10

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / ubuntu-latest py3.13

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.

Check warning on line 264 in xarray/namedarray/core.py

GitHub Actions / ubuntu-latest py3.13

Duplicate dimension names present: dimensions {'x'} appear more than once in dims=('x', 'x'). We do not yet support duplicate dimension names, but we do allow initial construction of the object. We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.
self._attrs = dict(attrs) if attrs else None
def __init_subclass__(cls, **kwargs: Any) -> None:
xp = get_array_namespace(data)
if xp == np:
# numpy currently doesn't have a astype:
return data.astype(dtype, **kwargs)

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13 all-but-numba

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13 all-but-numba

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.13

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.13

invalid value encountered in cast
return xp.astype(data, dtype, **kwargs)
return data.astype(dtype, **kwargs)
np_arr = xr.DataArray(np.array([1, 0]), dims="x")
xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")
expected = xr.where(np_arr, 1, 0)
actual = xr.where(xp_arr, 1, 0)

Check failure on line 144 in xarray/tests/test_array_api.py

GitHub Actions / macos-latest py3.10

test_where TypeError: `condition` must be have a boolean data type

Check failure on line 144 in xarray/tests/test_array_api.py

GitHub Actions / ubuntu-latest py3.12 all-but-dask

test_where TypeError: `condition` must be have a boolean data type

Check failure on line 144 in xarray/tests/test_array_api.py

GitHub Actions / macos-latest py3.13

test_where TypeError: `condition` must be have a boolean data type

Check failure on line 144 in xarray/tests/test_array_api.py

GitHub Actions / ubuntu-latest py3.13 all-but-numba

test_where TypeError: `condition` must be have a boolean data type

Check failure on line 144 in xarray/tests/test_array_api.py

GitHub Actions / ubuntu-latest py3.10

test_where TypeError: `condition` must be have a boolean data type

Check failure on line 144 in xarray/tests/test_array_api.py

GitHub Actions / ubuntu-latest py3.13

test_where TypeError: `condition` must be have a boolean data type
assert isinstance(actual.data, Array)
assert_equal(actual, expected)
# otherwise numpy unsigned ints will silently cast to the signed counterpart
fill_value = fill_value.item()
# passes if provided fill value fits in encoded on-disk type
new_fill = encoded_dtype.type(fill_value)

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).

Check warning on line 348 in xarray/coding/variables.py

GitHub Actions / ubuntu-latest py3.10 min-all-deps

NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 255 to int8 will fail in the future. For the old behavior, usually: np.array(value).astype(dtype)` will give the desired result (the cast overflows).
except OverflowError:
encoded_kind_str = "signed" if encoded_dtype.kind == "i" else "unsigned"
warnings.warn(