Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix writing of DataTree subgroups to zarr or netCDF #9677

Merged
merged 9 commits into from
Nov 4, 2024
90 changes: 25 additions & 65 deletions xarray/core/datatree_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,6 @@
T_DataTreeNetcdfTypes = Literal["NETCDF4"]


def _get_nc_dataset_class(engine: T_DataTreeNetcdfEngine | None):
if engine == "netcdf4":
from netCDF4 import Dataset
elif engine == "h5netcdf":
from h5netcdf.legacyapi import Dataset
elif engine is None:
try:
from netCDF4 import Dataset
except ImportError:
from h5netcdf.legacyapi import Dataset
else:
raise ValueError(f"unsupported engine: {engine}")
return Dataset


def _create_empty_netcdf_group(
filename: str | PathLike,
group: str,
mode: NetcdfWriteModes,
engine: T_DataTreeNetcdfEngine | None,
):
ncDataset = _get_nc_dataset_class(engine)

with ncDataset(filename, mode=mode) as rootgrp:
rootgrp.createGroup(group)


def _datatree_to_netcdf(
dt: DataTree,
filepath: str | PathLike,
Expand Down Expand Up @@ -85,34 +58,23 @@ def _datatree_to_netcdf(
unlimited_dims = {}

for node in dt.subtree:
ds = node.to_dataset(inherit=False)
group_path = node.path
if ds is None:
_create_empty_netcdf_group(filepath, group_path, mode, engine)
else:
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
unlimited_dims=unlimited_dims.get(node.path),
engine=engine,
format=format,
compute=compute,
**kwargs,
)
at_root = node is dt
ds = node.to_dataset(inherit=at_root)
group_path = None if at_root else "/" + node.relative_to(dt)
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
unlimited_dims=unlimited_dims.get(node.path),
engine=engine,
format=format,
compute=compute,
**kwargs,
)
mode = "a"


def _create_empty_zarr_group(
store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes
):
import zarr

root = zarr.open_group(store, mode=mode)
root.create_group(group, overwrite=True)


def _datatree_to_zarr(
dt: DataTree,
store: MutableMapping | str | PathLike[str],
Expand Down Expand Up @@ -151,19 +113,17 @@ def _datatree_to_zarr(
)

for node in dt.subtree:
ds = node.to_dataset(inherit=False)
group_path = node.path
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
else:
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
consolidated=False,
**kwargs,
)
at_root = node is dt
ds = node.to_dataset(inherit=at_root)
group_path = None if at_root else "/" + node.relative_to(dt)
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
consolidated=False,
**kwargs,
)
if "w" in mode:
mode = "a"

Expand Down
36 changes: 36 additions & 0 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,24 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)

def test_write_subgroup(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
).children["child"]

expected_dt = original_dt.copy()
expected_dt.name = None

filepath = tmpdir / "test.zarr"
original_dt.to_netcdf(filepath, engine=self.engine)

with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
assert_identical(expected_dt, roundtrip_dt)


@requires_netCDF4
class TestNetCDF4DatatreeIO(DatatreeIOBase):
Expand Down Expand Up @@ -532,3 +550,21 @@ def test_open_groups_chunks(self, tmpdir) -> None:

for ds in dict_of_datasets.values():
ds.close()

def test_write_subgroup(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
).children["child"]

expected_dt = original_dt.copy()
expected_dt.name = None

filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath)

with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
assert_identical(expected_dt, roundtrip_dt)
Loading