Skip to content

Commit

Permalink
update for datatree 0.5 (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Hamman authored May 31, 2022
1 parent 3cf38b8 commit 41f2bed
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ dependencies:
- scipy
- sparse>=0.13.0
- xarray
- xarray-datatree>=0.0.4
- xarray-datatree>=0.0.5
- xesmf
- zarr
21 changes: 12 additions & 9 deletions ndpyramid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,16 @@ def pyramid_coarsen(
}

# set up pyramid
root = xr.Dataset(attrs=attrs)
pyramid = dt.DataTree(data=root, name='root')
levels = {}

# pyramid data
for key, factor in enumerate(factors):

skey = str(key)
kwargs.update({d: factor for d in dims})
pyramid[skey] = ds.coarsen(**kwargs).mean()
levels[str(key)] = ds.coarsen(**kwargs).mean()

pyramid = dt.DataTree.from_dict(levels)
pyramid.ds = xr.Dataset(attrs=attrs)

return pyramid

Expand Down Expand Up @@ -109,8 +110,7 @@ def pyramid_reproject(
resampling_dict = resampling

# set up pyramid
root = xr.Dataset(attrs=attrs)
pyramid = dt.DataTree(data=root, name='root')
plevels = {}

# pyramid data
for level in range(levels):
Expand All @@ -128,7 +128,7 @@ def reproject(da, var):
transform=dst_transform,
)

pyramid[lkey] = xr.Dataset(attrs=ds.attrs)
plevels[lkey] = xr.Dataset(attrs=ds.attrs)
for k, da in ds.items():
if len(da.shape) == 4:
if extra_dim is None:
Expand All @@ -137,9 +137,12 @@ def reproject(da, var):
for index in ds[extra_dim]:
da_reprojected = reproject(da.sel({extra_dim: index}), k)
da_all.append(da_reprojected)
pyramid[lkey].ds[k] = xr.concat(da_all, ds[extra_dim])
plevels[lkey][k] = xr.concat(da_all, ds[extra_dim])
else:
pyramid[lkey].ds[k] = reproject(da, k)
plevels[lkey][k] = reproject(da, k)

pyramid = dt.DataTree.from_dict(plevels)
pyramid.ds = xr.Dataset(attrs=attrs)

pyramid = add_metadata_and_zarr_encoding(
pyramid, levels=levels, pixels_per_tile=pixels_per_tile, other_chunks=other_chunks
Expand Down
18 changes: 12 additions & 6 deletions ndpyramid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree:
pyramid : dt.DataTree
Multiscale grid definition
"""
data = dt.DataTree()
plevels = {}
for level in range(levels):
data[str(level)] = make_grid_ds(level).chunk(-1)
plevels[str(level)] = make_grid_ds(level).chunk(-1)
data = dt.DataTree.from_dict(plevels)

return data


Expand Down Expand Up @@ -152,14 +154,15 @@ def generate_weights_pyramid(
regridder_kws = {} if regridder_kws is None else regridder_kws
regridder_kws = {'periodic': True, **regridder_kws}

weights_pyramid = datatree.DataTree()
plevels = {}
for level in range(levels):
ds_out = make_grid_ds(level=level)
regridder = xe.Regridder(ds_in, ds_out, method, **regridder_kws)
ds = xesmf_weights_to_xarray(regridder)

weights_pyramid[str(level)] = ds
plevels[str(level)] = ds

weights_pyramid = datatree.DataTree.from_dict(plevels)
weights_pyramid.ds.attrs['levels'] = levels
weights_pyramid.ds.attrs['regrid_method'] = method

Expand Down Expand Up @@ -238,7 +241,7 @@ def pyramid_regrid(

# set up pyramid
root = xr.Dataset(attrs=attrs)
pyramid = dt.DataTree(data=root, name='root')
plevels = {}

# pyramid data
for level in range(levels):
Expand All @@ -260,7 +263,10 @@ def pyramid_regrid(
if regridder_apply_kws is None:
regridder_apply_kws = {}
regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws}
pyramid[str(level)] = regridder(ds, **regridder_apply_kws)
plevels[str(level)] = regridder(ds, **regridder_apply_kws)

pyramid = dt.DataTree.from_dict(plevels)
pyramid.ds = xr.Dataset(attrs=attrs)

pyramid = add_metadata_and_zarr_encoding(
pyramid, levels=levels, other_chunks=other_chunks, pixels_per_tile=pixels_per_tile
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cf_xarray
numpy
xarray
xarray-datatree >= 0.0.4
xarray-datatree >= 0.0.5
zarr

0 comments on commit 41f2bed

Please sign in to comment.