From 577221db8479fe6175206648e3b0a0c9f22d0cfe Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 4 Nov 2024 13:43:03 -0800 Subject: [PATCH] Fix writing of DataTree subgroups to zarr or netCDF (#9677) * Fix writing of DataTree subgroups to zarr or netCDF Consider a DataTree with a group, e.g., `tree = DataTree.from_dict({'/': ... '/child': ...})` If we write `tree['/child']` to disk, the result should have groups relative to `'/child'`, so writing and reading from the same path restores the same object. In addition, coordinates defined at the root should be written to disk instead of being omitted. * Add write_inherited_coords for additional control in DataTree.to_zarr As discussed in the last xarray meeting, this defaults to write_inherited_coords=True, which has a little more overhead but means you always get coordinates when opening a sub-group. * Switch write_inherited_coords default to false * add whats new * remove unused import --- doc/whats-new.rst | 9 ++- xarray/core/datatree.py | 14 ++++ xarray/core/datatree_io.py | 106 +++++++------------------ xarray/tests/test_backends_datatree.py | 74 +++++++++++++++++ 4 files changed, 124 insertions(+), 79 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b32822d5129..4659978df8a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,9 @@ New Features ~~~~~~~~~~~~ - Added :py:meth:`DataTree.persist` method (:issue:`9675`, :pull:`9682`). By `Sam Levang `_. +- Added ``write_inherited_coords`` option to :py:meth:`DataTree.to_netcdf` + and :py:meth:`DataTree.to_zarr` (:pull:`9677`). + By `Stephan Hoyer `_. - Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])`` (:issue:`2852`, :issue:`757`). By `Deepak Cherian `_. @@ -42,7 +45,11 @@ Deprecations Bug fixes ~~~~~~~~~ -- Fix inadvertent deep-copying of child data in DataTree. +- Fix inadvertent deep-copying of child data in DataTree (:issue:`9683`, + :pull:`9684`). + By `Stephan Hoyer `_. +- Avoid including parent groups when writing DataTree subgroups to Zarr or + netCDF (:pull:`9682`). By `Stephan Hoyer `_. - Fix regression in the interoperability of :py:meth:`DataArray.polyfit` and :py:meth:`xr.polyval` for date-time coordinates. (:pull:`9691`). By `Pascal Bourgault `_. diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index eab5d30c7dc..efbdd6bc8eb 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1573,6 +1573,7 @@ def to_netcdf( format: T_DataTreeNetcdfTypes | None = None, engine: T_DataTreeNetcdfEngine | None = None, group: str | None = None, + write_inherited_coords: bool = False, compute: bool = True, **kwargs, ): @@ -1609,6 +1610,11 @@ def to_netcdf( group : str, optional Path to the netCDF4 group in the given file to open as the root group of the ``DataTree``. Currently, specifying a group is not supported. + write_inherited_coords : bool, default: False + If true, replicate inherited coordinates on all descendant nodes. + Otherwise, only write coordinates at the level at which they are + originally defined. This saves disk space, but requires opening the + full tree to load inherited coordinates. compute : bool, default: True If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. @@ -1632,6 +1638,7 @@ def to_netcdf( format=format, engine=engine, group=group, + write_inherited_coords=write_inherited_coords, compute=compute, **kwargs, ) @@ -1643,6 +1650,7 @@ def to_zarr( encoding=None, consolidated: bool = True, group: str | None = None, + write_inherited_coords: bool = False, compute: Literal[True] = True, **kwargs, ): @@ -1668,6 +1676,11 @@ def to_zarr( after writing metadata for all groups. group : str, optional Group path. (a.k.a. `path` in zarr terminology.) + write_inherited_coords : bool, default: False + If true, replicate inherited coordinates on all descendant nodes. + Otherwise, only write coordinates at the level at which they are + originally defined. This saves disk space, but requires opening the + full tree to load inherited coordinates. compute : bool, default: True If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. Metadata @@ -1690,6 +1703,7 @@ def to_zarr( encoding=encoding, consolidated=consolidated, group=group, + write_inherited_coords=write_inherited_coords, compute=compute, **kwargs, ) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index da1cc12c92a..3d0daa26b90 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -2,54 +2,15 @@ from collections.abc import Mapping, MutableMapping from os import PathLike -from typing import TYPE_CHECKING, Any, Literal, get_args +from typing import Any, Literal, get_args from xarray.core.datatree import DataTree from xarray.core.types import NetcdfWriteModes, ZarrWriteModes -if TYPE_CHECKING: - from h5netcdf.legacyapi import Dataset as h5Dataset - from netCDF4 import Dataset as ncDataset - T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"] T_DataTreeNetcdfTypes = Literal["NETCDF4"] -def _get_nc_dataset_class( - engine: T_DataTreeNetcdfEngine | None, -) -> type[ncDataset] | type[h5Dataset]: - if engine == "netcdf4": - from netCDF4 import Dataset as ncDataset - - return ncDataset - if engine == "h5netcdf": - from h5netcdf.legacyapi import Dataset as h5Dataset - - return h5Dataset - if engine is None: - try: - from netCDF4 import Dataset as ncDataset - - return ncDataset - except ImportError: - from h5netcdf.legacyapi import Dataset as h5Dataset - - return h5Dataset - raise ValueError(f"unsupported engine: {engine}") - - -def _create_empty_netcdf_group( - filename: str | PathLike, - group: str, - mode: NetcdfWriteModes, - engine: T_DataTreeNetcdfEngine | None, -) -> None: - ncDataset = _get_nc_dataset_class(engine) - - with ncDataset(filename, mode=mode) as rootgrp: - rootgrp.createGroup(group) - - def _datatree_to_netcdf( dt: DataTree, filepath: str | PathLike, @@ -59,6 +20,7 @@ def _datatree_to_netcdf( format: T_DataTreeNetcdfTypes | None = None, engine: T_DataTreeNetcdfEngine | None = None, group: str | None = None, + write_inherited_coords: bool = False, compute: bool = True, **kwargs, ) -> None: @@ -97,34 +59,23 @@ def _datatree_to_netcdf( unlimited_dims = {} for node in dt.subtree: - ds = node.to_dataset(inherit=False) - group_path = node.path - if ds is None: - _create_empty_netcdf_group(filepath, group_path, mode, engine) - else: - ds.to_netcdf( - filepath, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - unlimited_dims=unlimited_dims.get(node.path), - engine=engine, - format=format, - compute=compute, - **kwargs, - ) + at_root = node is dt + ds = node.to_dataset(inherit=write_inherited_coords or at_root) + group_path = None if at_root else "/" + node.relative_to(dt) + ds.to_netcdf( + filepath, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + unlimited_dims=unlimited_dims.get(node.path), + engine=engine, + format=format, + compute=compute, + **kwargs, + ) mode = "a" -def _create_empty_zarr_group( - store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes -): - import zarr - - root = zarr.open_group(store, mode=mode) - root.create_group(group, overwrite=True) - - def _datatree_to_zarr( dt: DataTree, store: MutableMapping | str | PathLike[str], @@ -132,6 +83,7 @@ def _datatree_to_zarr( encoding: Mapping[str, Any] | None = None, consolidated: bool = True, group: str | None = None, + write_inherited_coords: bool = False, compute: Literal[True] = True, **kwargs, ): @@ -163,19 +115,17 @@ def _datatree_to_zarr( ) for node in dt.subtree: - ds = node.to_dataset(inherit=False) - group_path = node.path - if ds is None: - _create_empty_zarr_group(store, group_path, mode) - else: - ds.to_zarr( - store, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - consolidated=False, - **kwargs, - ) + at_root = node is dt + ds = node.to_dataset(inherit=write_inherited_coords or at_root) + group_path = None if at_root else "/" + node.relative_to(dt) + ds.to_zarr( + store, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + consolidated=False, + **kwargs, + ) if "w" in mode: mode = "a" diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 6e2d25249fb..608f9645ce8 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -196,6 +196,24 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree): with pytest.raises(ValueError, match="unexpected encoding group.*"): original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) + def test_write_subgroup(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ).children["child"] + + expected_dt = original_dt.copy() + expected_dt.name = None + + filepath = tmpdir / "test.zarr" + original_dt.to_netcdf(filepath, engine=self.engine) + + with open_datatree(filepath, engine=self.engine) as roundtrip_dt: + assert_equal(original_dt, roundtrip_dt) + assert_identical(expected_dt, roundtrip_dt) + @requires_netCDF4 class TestNetCDF4DatatreeIO(DatatreeIOBase): @@ -556,3 +574,59 @@ def test_open_groups_chunks(self, tmpdir) -> None: for ds in dict_of_datasets.values(): ds.close() + + def test_write_subgroup(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ).children["child"] + + expected_dt = original_dt.copy() + expected_dt.name = None + + filepath = tmpdir / "test.zarr" + original_dt.to_zarr(filepath) + + with open_datatree(filepath, engine="zarr") as roundtrip_dt: + assert_equal(original_dt, roundtrip_dt) + assert_identical(expected_dt, roundtrip_dt) + + def test_write_inherited_coords_false(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ) + + filepath = tmpdir / "test.zarr" + original_dt.to_zarr(filepath, write_inherited_coords=False) + + with open_datatree(filepath, engine="zarr") as roundtrip_dt: + assert_identical(original_dt, roundtrip_dt) + + expected_child = original_dt.children["child"].copy(inherit=False) + expected_child.name = None + with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child: + assert_identical(expected_child, roundtrip_child) + + def test_write_inherited_coords_true(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ) + + filepath = tmpdir / "test.zarr" + original_dt.to_zarr(filepath, write_inherited_coords=True) + + with open_datatree(filepath, engine="zarr") as roundtrip_dt: + assert_identical(original_dt, roundtrip_dt) + + expected_child = original_dt.children["child"].copy(inherit=True) + expected_child.name = None + with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child: + assert_identical(expected_child, roundtrip_child)