Skip to content

Commit

Permalink
fix: support new datatree (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang authored Oct 28, 2024
1 parent af13de2 commit 397d15a
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 60 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.24",
"pandas>=2",
"xarray>=2023.04.0",
"xarray>=2024.10.0",
"scikit-learn>=1.0.2",
"tqdm>=4.64.0",
"dask>=2023.0.1",
"typing-extensions>=4.8.0",
"xarray-datatree>=0.0.12",
]

[project.optional-dependencies]
Expand Down
15 changes: 5 additions & 10 deletions xeofs/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
from .utils.io import insert_placeholders, open_model_tree, write_model_tree
from .utils.xarray_utils import data_is_dask

try:
from xarray.core.datatree import DataTree # type: ignore
except ImportError:
from datatree import DataTree

# Ignore warnings from numpy casting with additional coordinates
warnings.filterwarnings("ignore", message=r"^invalid value encountered in cast*")

Expand Down Expand Up @@ -80,7 +75,7 @@ def compute(self, **kwargs):
(data_objs,) = dask.base.compute(data_objs, **kwargs)

for k, v in data_objs.items():
dt[k] = DataTree(v)
dt[k] = xr.DataTree(v)

# then rebuild the trained model from the computed results
self._deserialize_attrs(dt)
Expand All @@ -94,11 +89,11 @@ def get_params(self) -> dict[str, Any]:
"""Get the model parameters."""
return self._params

def serialize(self) -> DataTree:
def serialize(self) -> xr.DataTree:
"""Serialize a complete model with its preprocessor."""
# Create a root node for this object with its params as attrs
ds_root = xr.Dataset(attrs=dict(params=self.get_params()))
dt = DataTree(ds_root, name=type(self).__name__)
dt = xr.DataTree(ds_root, name=type(self).__name__)

# Retrieve the tree representation of each attached object, or set basic attrs
for key, attr in self.get_serialization_attrs().items():
Expand Down Expand Up @@ -149,14 +144,14 @@ def save(
write_model_tree(dt, path, overwrite=overwrite, engine=engine, **kwargs)

@classmethod
def deserialize(cls, dt: DataTree) -> Self:
def deserialize(cls, dt: xr.DataTree) -> Self:
"""Deserialize the model and its preprocessors from a DataTree."""
# Recreate the model with parameters set by root level attrs
model = cls(**dt.attrs["params"])
model._deserialize_attrs(dt)
return model

def _deserialize_attrs(self, dt: DataTree):
def _deserialize_attrs(self, dt: xr.DataTree):
"""Set the necessary attributes of the model from a DataTree."""
for key, attr in dt.attrs.items():
if key == "params":
Expand Down
14 changes: 5 additions & 9 deletions xeofs/data_container/data_container.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import dask
import xarray as xr
from typing_extensions import Self

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree

from ..utils.data_types import DataArray


Expand All @@ -31,18 +27,18 @@ def __getitem__(self, __key: str) -> DataArray:
f"Cannot find data '{__key}'. Please fit the model first by calling .fit()."
)

def serialize(self) -> DataTree:
dt = DataTree(name="data")
def serialize(self) -> xr.DataTree:
dt = xr.DataTree(name="data")
for key, data in self.items():
if not data.name:
data.name = key
dt[key] = DataTree(data.to_dataset())
dt[key] = xr.DataTree(data.to_dataset())
dt[key].attrs = {key: "_is_node", "allow_compute": self._allow_compute[key]}

return dt

@classmethod
def deserialize(cls, dt: DataTree) -> Self:
def deserialize(cls, dt: xr.DataTree) -> Self:
container = cls()
for key, node in dt.items():
container[key] = node[key]
Expand Down
14 changes: 5 additions & 9 deletions xeofs/preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import xarray as xr
from typing_extensions import Self

from ..utils.data_types import (
Expand All @@ -23,11 +24,6 @@
from .stacker import Stacker
from .transformer import Transformer

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree


def extract_new_dim_names(X: list[DimensionRenamer]) -> tuple[Dims, DimsList]:
"""Extract the new dimension names from a list of DimensionRenamer objects.
Expand Down Expand Up @@ -369,7 +365,7 @@ def _set_return_list(self, X):
else:
self.return_list = False

def serialize(self) -> DataTree:
def serialize(self) -> xr.DataTree:
"""Serialize the necessary attributes of the fitted pre-processor
and all transformers to a Dataset."""
# Serialize the preprocessor as the root node
Expand All @@ -381,9 +377,9 @@ def serialize(self) -> DataTree:
transformers = self.get_transformers()

for name, transformer_obj in zip(names, transformers):
dt_transformer = DataTree()
dt_transformer = xr.DataTree()
if isinstance(transformer_obj, GenericListTransformer):
dt_transformer["transformers"] = DataTree()
dt_transformer["transformers"] = xr.DataTree()
# Loop through list transformer objects and assign a dummy key
for i, transformer in enumerate(transformer_obj.transformers):
dt_transformer.transformers[str(i)] = transformer.serialize()
Expand All @@ -395,7 +391,7 @@ def serialize(self) -> DataTree:
return dt

@classmethod
def deserialize(cls, dt: DataTree) -> Self:
def deserialize(cls, dt: xr.DataTree) -> Self:
"""Deserialize from a DataTree representation of the preprocessor
and all attached Transformers."""
# Create the parent preprocessor
Expand Down
29 changes: 12 additions & 17 deletions xeofs/preprocessing/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@
from sklearn.base import BaseEstimator, TransformerMixin
from typing_extensions import Self

try:
from xarray.core.datatree import DataTree
except ImportError:
from datatree import DataTree

from ..utils.data_types import Data, DataArray, DataSet, Dims


Expand Down Expand Up @@ -120,15 +115,15 @@ def _serialize_data(self, key: str, data: Data) -> DataSet:

return ds

def serialize(self) -> DataTree:
def serialize(self) -> xr.DataTree:
"""Serialize a transformer to a DataTree."""
return self._serialize()

def _serialize(self) -> DataTree:
def _serialize(self) -> xr.DataTree:
"""Serialize a transformer to a DataTree. Use an internal
method so we can override the public one in subclasesses but
still use this."""
dt = DataTree()
dt = xr.DataTree()
params = self.get_params()
attrs = self.get_serialization_attrs()

Expand All @@ -140,16 +135,16 @@ def _serialize(self) -> DataTree:
if isinstance(attr, (xr.DataArray, xr.Dataset)):
# attach data to data_vars or coords
ds = self._serialize_data(key, attr)
dt[key] = DataTree(ds, name=key)
dt[key] = xr.DataTree(ds, name=key)
dt.attrs[key] = "_is_node"
elif isinstance(attr, dict) and any(
[isinstance(val, xr.DataArray) for val in attr.values()]
):
# attach dict of data as branching tree
dt_attr = DataTree()
dt_attr = xr.DataTree()
for k, v in attr.items():
ds = self._serialize_data(k, v)
dt_attr[k] = DataTree(ds, name=k)
dt_attr[k] = xr.DataTree(ds, name=k)
dt[key] = dt_attr
dt.attrs[key] = "_is_tree"
else:
Expand All @@ -158,24 +153,24 @@ def _serialize(self) -> DataTree:

return dt

def _deserialize_data_node(self, key: str, dt: DataTree) -> DataArray:
def _deserialize_data_node(self, key: str, dt: xr.DataTree) -> DataArray:
# Rebuild multiindexes
dt = dt.set_index(dt.attrs.get("multiindexes", {}))
dt.dataset = dt.dataset.set_index(dt.attrs.get("multiindexes", {}))
# Extract the DataArray or coord from the Dataset
data_key = dt.attrs["name_map"][key]
if data_key is not None:
return dt[data_key]
else:
return dt.ds
return dt.dataset

@classmethod
def deserialize(cls, dt: DataTree) -> Self:
def deserialize(cls, dt: xr.DataTree) -> Self:
"""Deserialize a saved transformer from a DataTree."""
return cls._deserialize(dt)

@classmethod
def _deserialize(cls, dt: DataTree) -> Self:
"""Deserialize a saved transformer from a DataTree. Use an internal
def _deserialize(cls, dt: xr.DataTree) -> Self:
"""Deserialize a saved transformer from a xr.DataTree. Use an internal
method so we can override the public one in subclasesses but
still use this."""
# Create the object from params
Expand Down
20 changes: 7 additions & 13 deletions xeofs/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,9 @@
import numpy as np
import xarray as xr

try:
from xarray.core.datatree import DataTree
from xarray.backends.api import open_datatree
except ImportError:
from datatree import DataTree, open_datatree


def write_model_tree(
dt: DataTree, path: str, overwrite: bool = False, engine: str = "zarr", **kwargs
dt: xr.DataTree, path: str, overwrite: bool = False, engine: str = "zarr", **kwargs
):
"""Write a DataTree to a file."""
write_mode = "w" if overwrite else "w-"
Expand All @@ -25,21 +19,21 @@ def write_model_tree(
raise ValueError(f"Unknown engine {engine}")


def open_model_tree(path: str, engine: str = "zarr", **kwargs) -> DataTree:
def open_model_tree(path: str, engine: str = "zarr", **kwargs) -> xr.DataTree:
"""Open a DataTree from a file."""
if engine == "zarr" and "chunks" not in kwargs:
kwargs["chunks"] = {}
dt = open_datatree(path, engine=engine, **kwargs)
dt = xr.open_datatree(path, engine=engine, **kwargs)
if engine in ["netcdf4", "h5netcdf"]:
dt = _desanitize_attrs_nc(dt)
return dt


def insert_placeholders(dt: DataTree) -> DataTree:
def insert_placeholders(dt: xr.DataTree) -> xr.DataTree:
"""Insert placeholders for data that we don't want to compute."""
for node in dt.subtree:
if not node.attrs.get("allow_compute", True):
dt[node.path] = DataTree(
dt[node.path] = xr.DataTree(
xr.Dataset(
data_vars={
node.name: xr.DataArray(np.nan, attrs={"placeholder": True})
Expand All @@ -50,7 +44,7 @@ def insert_placeholders(dt: DataTree) -> DataTree:
return dt


def _sanitize_attrs_nc(dt: DataTree) -> DataTree:
def _sanitize_attrs_nc(dt: xr.DataTree) -> xr.DataTree:
"""Sanitize both node-level and variable-level attrs to strings for netcdf."""
sanitized_types = (dict, list, bool, type(None))
for node in dt.subtree:
Expand All @@ -76,7 +70,7 @@ def _should_desanitize(attr: Any) -> bool:
return False


def _desanitize_attrs_nc(dt: DataTree) -> DataTree:
def _desanitize_attrs_nc(dt: xr.DataTree) -> xr.DataTree:
"""Desanitize both node-level and variable-level attrs from strings for netcdf."""
for node in dt.subtree:
for key, attr in node.attrs.items():
Expand Down

0 comments on commit 397d15a

Please sign in to comment.