Skip to content

Commit

Permalink
Migrate datatree io.py and common.py into xarray/core (#9011)
Browse files Browse the repository at this point in the history
  • Loading branch information
owenlittlejohns authored Jun 12, 2024
1 parent d9e4de6 commit cb3663d
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 152 deletions.
13 changes: 8 additions & 5 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ What's New
.. _whats-new.2024.05.1:

v2024.05.1 (unreleased)
v2024.06 (unreleased)
-----------------------

New Features
Expand Down Expand Up @@ -59,6 +59,10 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates remainder of ``io.py`` to ``xarray/core/datatree_io.py`` and
``TreeAttrAccessMixin`` into ``xarray/core/common.py`` (:pull:`9011`)
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_ and
`Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.05.0:
Expand Down Expand Up @@ -141,10 +145,9 @@ Internal Changes
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_, `Matt Savoie
<https://github.com/flamingbear>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.
- ``transpose``, ``set_dims``, ``stack`` & ``unstack`` now use a ``dim`` kwarg
rather than ``dims`` or ``dimensions``. This is the final change to unify
xarray functions to use ``dim``. Using the existing kwarg will raise a
warning.
By `Maximilian Roos <https://github.com/max-sixty>`_
rather than ``dims`` or ``dimensions``. This is the final change to make xarray methods
consistent with their use of ``dim``. Using the existing kwarg will raise a
warning. By `Maximilian Roos <https://github.com/max-sixty>`_

.. _whats-new.2024.03.0:

Expand Down
18 changes: 9 additions & 9 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.indexes import Index
from xarray.core.types import ZarrWriteModes
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
from xarray.core.utils import is_remote_uri
from xarray.namedarray.daskmanager import DaskManager
from xarray.namedarray.parallelcompat import guess_chunkmanager
Expand Down Expand Up @@ -1120,7 +1120,7 @@ def open_mfdataset(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1138,7 +1138,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1155,7 +1155,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1173,7 +1173,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1191,7 +1191,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1209,7 +1209,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1226,7 +1226,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -1241,7 +1241,7 @@ def to_netcdf(
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand Down
18 changes: 18 additions & 0 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,24 @@ def _ipython_key_completions_(self) -> list[str]:
return list(items)


class TreeAttrAccessMixin(AttrAccessMixin):
    """Mixin giving tree nodes attribute-style access to their keys.

    Inherits the access behaviour of ``AttrAccessMixin`` but replaces its
    subclass hook so that subclasses are not required to be ``__dict__``-free.
    """

    # TODO: Ensure ipython tab completion can include both child datatrees and
    # variables from Dataset objects on relevant nodes.

    __slots__ = ()

    def __init_subclass__(cls, **kwargs):
        """Accept every subclass without enforcing a slots-only layout.

        This overrides the check from ``AttrAccessMixin`` that ensures
        ``__dict__`` is absent in a class, with ``__slots__`` used instead.
        ``DataTree`` has some dynamically defined attributes in addition to
        those defined in ``__slots__``. (GH9068)

        The parent hook is intentionally not invoked, so ``kwargs`` are
        accepted and ignored here.
        """
        # Build a bare instance (bypassing ``__init__``) to probe its layout.
        probe = object.__new__(cls)
        if hasattr(probe, "__dict__"):
            # Dynamic attributes are allowed for tree classes: nothing to do.
            return
        # Slots-only subclasses are accepted as well; no further checks run.


def get_squeeze_dims(
xarray_obj,
dim: Hashable | Iterable[Hashable] | None = None,
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
DaCompatible,
NetcdfWriteModes,
T_DataArray,
T_DataArrayOrSet,
ZarrWriteModes,
Expand Down Expand Up @@ -3945,7 +3946,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
def to_netcdf(
self,
path: None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -3960,7 +3961,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -3976,7 +3977,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -3992,7 +3993,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -4005,7 +4006,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from xarray.core.missing import get_clean_interp_index
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
NetcdfWriteModes,
QuantileMethods,
Self,
T_ChunkDim,
Expand Down Expand Up @@ -2171,7 +2172,7 @@ def dump_to_store(self, store: AbstractDataStore, **kwargs) -> None:
def to_netcdf(
self,
path: None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2186,7 +2187,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2202,7 +2203,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2218,7 +2219,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand All @@ -2231,7 +2232,7 @@ def to_netcdf(
def to_netcdf(
self,
path: str | PathLike | None = None,
mode: Literal["w", "a"] = "w",
mode: NetcdfWriteModes = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
Expand Down
53 changes: 47 additions & 6 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
Any,
Callable,
Generic,
Literal,
NoReturn,
Union,
overload,
)

from xarray.core import utils
from xarray.core.common import TreeAttrAccessMixin
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
Expand Down Expand Up @@ -46,7 +48,6 @@
maybe_wrap_array,
)
from xarray.core.variable import Variable
from xarray.datatree_.datatree.common import TreeAttrAccessMixin

try:
from xarray.core.variable import calculate_dimensions
Expand All @@ -57,8 +58,9 @@
if TYPE_CHECKING:
import pandas as pd

from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
from xarray.core.merge import CoercibleValue
from xarray.core.types import ErrorOptions
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes

# """
# DEVELOPERS' NOTE
Expand Down Expand Up @@ -1475,7 +1477,16 @@ def groups(self):
return tuple(node.path for node in self.subtree)

def to_netcdf(
self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs
self,
filepath,
mode: NetcdfWriteModes = "w",
encoding=None,
unlimited_dims=None,
format: T_DataTreeNetcdfTypes | None = None,
engine: T_DataTreeNetcdfEngine | None = None,
group: str | None = None,
compute: bool = True,
**kwargs,
):
"""
Write datatree contents to a netCDF file.
Expand All @@ -1499,26 +1510,47 @@ def to_netcdf(
By default, no dimensions are treated as unlimited dimensions.
Note that unlimited_dims may also be set via
``dataset.encoding["unlimited_dims"]``.
format : {"NETCDF4"}, optional
File format for the resulting netCDF file:
* NETCDF4: Data is stored in an HDF5 file, using netCDF4 API features.
engine : {"netcdf4", "h5netcdf"}, optional
Engine to use when writing netCDF files. If not provided, the
default engine is chosen based on available dependencies, with a
preference for "netcdf4" if writing to a file on disk.
group : str, optional
Path to the netCDF4 group in the given file to open as the root group
of the ``DataTree``. Currently, specifying a group is not supported.
compute : bool, default: True
If true compute immediately, otherwise return a
``dask.delayed.Delayed`` object that can be computed later.
Currently, ``compute=False`` is not supported.
kwargs :
Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
"""
from xarray.datatree_.datatree.io import _datatree_to_netcdf
from xarray.core.datatree_io import _datatree_to_netcdf

_datatree_to_netcdf(
self,
filepath,
mode=mode,
encoding=encoding,
unlimited_dims=unlimited_dims,
format=format,
engine=engine,
group=group,
compute=compute,
**kwargs,
)

def to_zarr(
self,
store,
mode: str = "w-",
mode: ZarrWriteModes = "w-",
encoding=None,
consolidated: bool = True,
group: str | None = None,
compute: Literal[True] = True,
**kwargs,
):
"""
Expand All @@ -1541,17 +1573,26 @@ def to_zarr(
consolidated : bool
If True, apply zarr's `consolidate_metadata` function to the store
after writing metadata for all groups.
group : str, optional
Group path. (a.k.a. `path` in zarr terminology.)
compute : bool, default: True
If true compute immediately, otherwise return a
``dask.delayed.Delayed`` object that can be computed later. Metadata
is always updated eagerly. Currently, ``compute=False`` is not
supported.
kwargs :
Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr``
"""
from xarray.datatree_.datatree.io import _datatree_to_zarr
from xarray.core.datatree_io import _datatree_to_zarr

_datatree_to_zarr(
self,
store,
mode=mode,
encoding=encoding,
consolidated=consolidated,
group=group,
compute=compute,
**kwargs,
)

Expand Down
Loading

0 comments on commit cb3663d

Please sign in to comment.