From 397d15a5e2a9a766ee624923b48dca296c5cf24f Mon Sep 17 00:00:00 2001 From: Sam Levang <39069044+slevang@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:58:02 +0000 Subject: [PATCH] fix: support new datatree (#238) --- pyproject.toml | 3 +-- xeofs/base_model.py | 15 +++++-------- xeofs/data_container/data_container.py | 14 +++++-------- xeofs/preprocessing/preprocessor.py | 14 +++++-------- xeofs/preprocessing/transformer.py | 29 +++++++++++--------------- xeofs/utils/io.py | 20 +++++++----------- 6 files changed, 35 insertions(+), 60 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 91d07ab..6c28e2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,12 +13,11 @@ requires-python = ">=3.10" dependencies = [ "numpy>=1.24", "pandas>=2", - "xarray>=2023.04.0", + "xarray>=2024.10.0", "scikit-learn>=1.0.2", "tqdm>=4.64.0", "dask>=2023.0.1", "typing-extensions>=4.8.0", - "xarray-datatree>=0.0.12", ] [project.optional-dependencies] diff --git a/xeofs/base_model.py b/xeofs/base_model.py index dffb9c5..001925a 100644 --- a/xeofs/base_model.py +++ b/xeofs/base_model.py @@ -13,11 +13,6 @@ from .utils.io import insert_placeholders, open_model_tree, write_model_tree from .utils.xarray_utils import data_is_dask -try: - from xarray.core.datatree import DataTree # type: ignore -except ImportError: - from datatree import DataTree - # Ignore warnings from numpy casting with additional coordinates warnings.filterwarnings("ignore", message=r"^invalid value encountered in cast*") @@ -80,7 +75,7 @@ def compute(self, **kwargs): (data_objs,) = dask.base.compute(data_objs, **kwargs) for k, v in data_objs.items(): - dt[k] = DataTree(v) + dt[k] = xr.DataTree(v) # then rebuild the trained model from the computed results self._deserialize_attrs(dt) @@ -94,11 +89,11 @@ def get_params(self) -> dict[str, Any]: """Get the model parameters.""" return self._params - def serialize(self) -> DataTree: + def serialize(self) -> xr.DataTree: """Serialize a complete model with its 
preprocessor.""" # Create a root node for this object with its params as attrs ds_root = xr.Dataset(attrs=dict(params=self.get_params())) - dt = DataTree(ds_root, name=type(self).__name__) + dt = xr.DataTree(ds_root, name=type(self).__name__) # Retrieve the tree representation of each attached object, or set basic attrs for key, attr in self.get_serialization_attrs().items(): @@ -149,14 +144,14 @@ def save( write_model_tree(dt, path, overwrite=overwrite, engine=engine, **kwargs) @classmethod - def deserialize(cls, dt: DataTree) -> Self: + def deserialize(cls, dt: xr.DataTree) -> Self: """Deserialize the model and its preprocessors from a DataTree.""" # Recreate the model with parameters set by root level attrs model = cls(**dt.attrs["params"]) model._deserialize_attrs(dt) return model - def _deserialize_attrs(self, dt: DataTree): + def _deserialize_attrs(self, dt: xr.DataTree): """Set the necessary attributes of the model from a DataTree.""" for key, attr in dt.attrs.items(): if key == "params": diff --git a/xeofs/data_container/data_container.py b/xeofs/data_container/data_container.py index 2729e02..4e3d356 100644 --- a/xeofs/data_container/data_container.py +++ b/xeofs/data_container/data_container.py @@ -1,11 +1,7 @@ import dask +import xarray as xr from typing_extensions import Self -try: - from xarray.core.datatree import DataTree -except ImportError: - from datatree import DataTree - from ..utils.data_types import DataArray @@ -31,18 +27,18 @@ def __getitem__(self, __key: str) -> DataArray: f"Cannot find data '{__key}'. Please fit the model first by calling .fit()." 
) - def serialize(self) -> DataTree: - dt = DataTree(name="data") + def serialize(self) -> xr.DataTree: + dt = xr.DataTree(name="data") for key, data in self.items(): if not data.name: data.name = key - dt[key] = DataTree(data.to_dataset()) + dt[key] = xr.DataTree(data.to_dataset()) dt[key].attrs = {key: "_is_node", "allow_compute": self._allow_compute[key]} return dt @classmethod - def deserialize(cls, dt: DataTree) -> Self: + def deserialize(cls, dt: xr.DataTree) -> Self: container = cls() for key, node in dt.items(): container[key] = node[key] diff --git a/xeofs/preprocessing/preprocessor.py b/xeofs/preprocessing/preprocessor.py index 9d8a9d5..a2c422d 100644 --- a/xeofs/preprocessing/preprocessor.py +++ b/xeofs/preprocessing/preprocessor.py @@ -1,4 +1,5 @@ import numpy as np +import xarray as xr from typing_extensions import Self from ..utils.data_types import ( @@ -23,11 +24,6 @@ from .stacker import Stacker from .transformer import Transformer -try: - from xarray.core.datatree import DataTree -except ImportError: - from datatree import DataTree - def extract_new_dim_names(X: list[DimensionRenamer]) -> tuple[Dims, DimsList]: """Extract the new dimension names from a list of DimensionRenamer objects. 
@@ -369,7 +365,7 @@ def _set_return_list(self, X): else: self.return_list = False - def serialize(self) -> DataTree: + def serialize(self) -> xr.DataTree: """Serialize the necessary attributes of the fitted pre-processor and all transformers to a Dataset.""" # Serialize the preprocessor as the root node @@ -381,9 +377,9 @@ def serialize(self) -> DataTree: transformers = self.get_transformers() for name, transformer_obj in zip(names, transformers): - dt_transformer = DataTree() + dt_transformer = xr.DataTree() if isinstance(transformer_obj, GenericListTransformer): - dt_transformer["transformers"] = DataTree() + dt_transformer["transformers"] = xr.DataTree() # Loop through list transformer objects and assign a dummy key for i, transformer in enumerate(transformer_obj.transformers): dt_transformer.transformers[str(i)] = transformer.serialize() @@ -395,7 +391,7 @@ def serialize(self) -> DataTree: return dt @classmethod - def deserialize(cls, dt: DataTree) -> Self: + def deserialize(cls, dt: xr.DataTree) -> Self: """Deserialize from a DataTree representation of the preprocessor and all attached Transformers.""" # Create the parent preprocessor diff --git a/xeofs/preprocessing/transformer.py b/xeofs/preprocessing/transformer.py index 6666428..92d2b9e 100644 --- a/xeofs/preprocessing/transformer.py +++ b/xeofs/preprocessing/transformer.py @@ -5,11 +5,6 @@ from sklearn.base import BaseEstimator, TransformerMixin from typing_extensions import Self -try: - from xarray.core.datatree import DataTree -except ImportError: - from datatree import DataTree - from ..utils.data_types import Data, DataArray, DataSet, Dims @@ -120,15 +115,15 @@ def _serialize_data(self, key: str, data: Data) -> DataSet: return ds - def serialize(self) -> DataTree: + def serialize(self) -> xr.DataTree: """Serialize a transformer to a DataTree.""" return self._serialize() - def _serialize(self) -> DataTree: + def _serialize(self) -> xr.DataTree: """Serialize a transformer to a DataTree. 
Use an internal method so we can override the public one in subclasesses but still use this.""" - dt = DataTree() + dt = xr.DataTree() params = self.get_params() attrs = self.get_serialization_attrs() @@ -140,16 +135,16 @@ def _serialize(self) -> DataTree: if isinstance(attr, (xr.DataArray, xr.Dataset)): # attach data to data_vars or coords ds = self._serialize_data(key, attr) - dt[key] = DataTree(ds, name=key) + dt[key] = xr.DataTree(ds, name=key) dt.attrs[key] = "_is_node" elif isinstance(attr, dict) and any( [isinstance(val, xr.DataArray) for val in attr.values()] ): # attach dict of data as branching tree - dt_attr = DataTree() + dt_attr = xr.DataTree() for k, v in attr.items(): ds = self._serialize_data(k, v) - dt_attr[k] = DataTree(ds, name=k) + dt_attr[k] = xr.DataTree(ds, name=k) dt[key] = dt_attr dt.attrs[key] = "_is_tree" else: @@ -158,24 +153,24 @@ def _serialize(self) -> DataTree: return dt - def _deserialize_data_node(self, key: str, dt: DataTree) -> DataArray: + def _deserialize_data_node(self, key: str, dt: xr.DataTree) -> DataArray: # Rebuild multiindexes - dt = dt.set_index(dt.attrs.get("multiindexes", {})) + dt.dataset = dt.dataset.set_index(dt.attrs.get("multiindexes", {})) # Extract the DataArray or coord from the Dataset data_key = dt.attrs["name_map"][key] if data_key is not None: return dt[data_key] else: - return dt.ds + return dt.dataset @classmethod - def deserialize(cls, dt: DataTree) -> Self: + def deserialize(cls, dt: xr.DataTree) -> Self: """Deserialize a saved transformer from a DataTree.""" return cls._deserialize(dt) @classmethod - def _deserialize(cls, dt: DataTree) -> Self: - """Deserialize a saved transformer from a DataTree. Use an internal + def _deserialize(cls, dt: xr.DataTree) -> Self: + """Deserialize a saved transformer from a DataTree. 
Use an internal method so we can override the public one in subclasesses but still use this.""" # Create the object from params diff --git a/xeofs/utils/io.py b/xeofs/utils/io.py index 0cd6932..12ea2b4 100644 --- a/xeofs/utils/io.py +++ b/xeofs/utils/io.py @@ -4,15 +4,9 @@ import numpy as np import xarray as xr -try: - from xarray.core.datatree import DataTree - from xarray.backends.api import open_datatree -except ImportError: - from datatree import DataTree, open_datatree - def write_model_tree( - dt: DataTree, path: str, overwrite: bool = False, engine: str = "zarr", **kwargs + dt: xr.DataTree, path: str, overwrite: bool = False, engine: str = "zarr", **kwargs ): """Write a DataTree to a file.""" write_mode = "w" if overwrite else "w-" @@ -25,21 +19,21 @@ def write_model_tree( raise ValueError(f"Unknown engine {engine}") -def open_model_tree(path: str, engine: str = "zarr", **kwargs) -> DataTree: +def open_model_tree(path: str, engine: str = "zarr", **kwargs) -> xr.DataTree: """Open a DataTree from a file.""" if engine == "zarr" and "chunks" not in kwargs: kwargs["chunks"] = {} - dt = open_datatree(path, engine=engine, **kwargs) + dt = xr.open_datatree(path, engine=engine, **kwargs) if engine in ["netcdf4", "h5netcdf"]: dt = _desanitize_attrs_nc(dt) return dt -def insert_placeholders(dt: DataTree) -> DataTree: +def insert_placeholders(dt: xr.DataTree) -> xr.DataTree: """Insert placeholders for data that we don't want to compute.""" for node in dt.subtree: if not node.attrs.get("allow_compute", True): - dt[node.path] = DataTree( + dt[node.path] = xr.DataTree( xr.Dataset( data_vars={ node.name: xr.DataArray(np.nan, attrs={"placeholder": True}) @@ -50,7 +44,7 @@ def insert_placeholders(dt: DataTree) -> DataTree: return dt -def _sanitize_attrs_nc(dt: DataTree) -> DataTree: +def _sanitize_attrs_nc(dt: xr.DataTree) -> xr.DataTree: """Sanitize both node-level and variable-level attrs to strings for netcdf.""" sanitized_types = (dict, list, bool, type(None)) for node in 
dt.subtree: @@ -76,7 +70,7 @@ def _should_desanitize(attr: Any) -> bool: return False -def _desanitize_attrs_nc(dt: DataTree) -> DataTree: +def _desanitize_attrs_nc(dt: xr.DataTree) -> xr.DataTree: """Desanitize both node-level and variable-level attrs from strings for netcdf.""" for node in dt.subtree: for key, attr in node.attrs.items():