diff --git a/ci/environment.yml b/ci/environment.yml index a040881..eb54cec 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -18,6 +18,6 @@ dependencies: - scipy - sparse>=0.13.0 - xarray - - xarray-datatree>=0.0.4 + - xarray-datatree>=0.0.5 - xesmf - zarr diff --git a/ndpyramid/core.py b/ndpyramid/core.py index 2a3d292..b73cc30 100644 --- a/ndpyramid/core.py +++ b/ndpyramid/core.py @@ -40,15 +40,16 @@ def pyramid_coarsen( } # set up pyramid - root = xr.Dataset(attrs=attrs) - pyramid = dt.DataTree(data=root, name='root') + levels = {} # pyramid data for key, factor in enumerate(factors): - skey = str(key) kwargs.update({d: factor for d in dims}) - pyramid[skey] = ds.coarsen(**kwargs).mean() + levels[str(key)] = ds.coarsen(**kwargs).mean() + + pyramid = dt.DataTree.from_dict(levels) + pyramid.ds = xr.Dataset(attrs=attrs) return pyramid @@ -109,8 +110,7 @@ def pyramid_reproject( resampling_dict = resampling # set up pyramid - root = xr.Dataset(attrs=attrs) - pyramid = dt.DataTree(data=root, name='root') + plevels = {} # pyramid data for level in range(levels): @@ -128,7 +128,7 @@ def reproject(da, var): transform=dst_transform, ) - pyramid[lkey] = xr.Dataset(attrs=ds.attrs) + plevels[lkey] = xr.Dataset(attrs=ds.attrs) for k, da in ds.items(): if len(da.shape) == 4: if extra_dim is None: @@ -137,9 +137,12 @@ def reproject(da, var): for index in ds[extra_dim]: da_reprojected = reproject(da.sel({extra_dim: index}), k) da_all.append(da_reprojected) - pyramid[lkey].ds[k] = xr.concat(da_all, ds[extra_dim]) + plevels[lkey][k] = xr.concat(da_all, ds[extra_dim]) else: - pyramid[lkey].ds[k] = reproject(da, k) + plevels[lkey][k] = reproject(da, k) + + pyramid = dt.DataTree.from_dict(plevels) + pyramid.ds = xr.Dataset(attrs=attrs) pyramid = add_metadata_and_zarr_encoding( pyramid, levels=levels, pixels_per_tile=pixels_per_tile, other_chunks=other_chunks diff --git a/ndpyramid/regrid.py b/ndpyramid/regrid.py index 9fc6f14..a20627a 100644 --- a/ndpyramid/regrid.py +++ b/ndpyramid/regrid.py @@ -119,9 +119,11 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree: pyramid : dt.DataTree Multiscale grid definition """ - data = dt.DataTree() + plevels = {} for level in range(levels): - data[str(level)] = make_grid_ds(level).chunk(-1) + plevels[str(level)] = make_grid_ds(level).chunk(-1) + data = dt.DataTree.from_dict(plevels) + return data @@ -152,14 +154,15 @@ def generate_weights_pyramid( regridder_kws = {} if regridder_kws is None else regridder_kws regridder_kws = {'periodic': True, **regridder_kws} - weights_pyramid = datatree.DataTree() + plevels = {} for level in range(levels): ds_out = make_grid_ds(level=level) regridder = xe.Regridder(ds_in, ds_out, method, **regridder_kws) ds = xesmf_weights_to_xarray(regridder) - weights_pyramid[str(level)] = ds + plevels[str(level)] = ds + weights_pyramid = datatree.DataTree.from_dict(plevels) weights_pyramid.ds.attrs['levels'] = levels weights_pyramid.ds.attrs['regrid_method'] = method @@ -238,7 +241,7 @@ def pyramid_regrid( # set up pyramid root = xr.Dataset(attrs=attrs) - pyramid = dt.DataTree(data=root, name='root') + plevels = {} # pyramid data for level in range(levels): @@ -260,7 +263,10 @@ def pyramid_regrid( if regridder_apply_kws is None: regridder_apply_kws = {} regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws} - pyramid[str(level)] = regridder(ds, **regridder_apply_kws) + plevels[str(level)] = regridder(ds, **regridder_apply_kws) + + pyramid = dt.DataTree.from_dict(plevels) + pyramid.ds = xr.Dataset(attrs=attrs) pyramid = add_metadata_and_zarr_encoding( pyramid, levels=levels, other_chunks=other_chunks, pixels_per_tile=pixels_per_tile diff --git a/requirements.txt b/requirements.txt index d1f2e41..8ba9292 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ cf_xarray numpy xarray -xarray-datatree >= 0.0.4 +xarray-datatree >= 0.0.5 zarr