Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use unbound regridding method for dask wrapping #39

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ New features
`conservative_normed` to properly handle normalization (:pull:`1`).
By `Raphael Dussin <https://github.com/raphaeldussin>`_

Bug fixes
~~~~~~~~~
* Fix serialization bug when using dask's distributed scheduler (:pull:`39`).
By `Pascal Bourgault <https://github.com/aulemahal>`_.



0.3.0 (06-03-2020)
Expand Down
29 changes: 19 additions & 10 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,20 @@ def __call__(self, indata, keep_attrs=False):
"input must be numpy array, dask array, "
"xarray DataArray or Dataset!")

def regrid_numpy(self, indata):
"""See __call__()."""

if self.locstream_in:
@staticmethod
def _regrid_array(indata, *, weights, shape_in, shape_out, locstream_in):
if locstream_in:
indata = np.expand_dims(indata, axis=-2)
return apply_weights(weights, indata, shape_in, shape_out)

outdata = apply_weights(self.weights, indata,
self.shape_in, self.shape_out)
@property
def _regrid_kwargs(self):
return {'weights': self.weights, 'locstream_in': self.locstream_in,
'shape_in': self.shape_in, 'shape_out': self.shape_out}

def regrid_numpy(self, indata):
"""See __call__()."""
outdata = self._regrid_array(indata, **self._regrid_kwargs)
return outdata

def regrid_dask(self, indata):
Expand All @@ -418,10 +424,11 @@ def regrid_dask(self, indata):
output_chunk_shape = extra_chunk_shape + self.shape_out

outdata = da.map_blocks(
self.regrid_numpy,
self._regrid_array,
indata,
dtype=float,
chunks=output_chunk_shape
chunks=output_chunk_shape,
**self._regrid_kwargs
)

return outdata
Expand All @@ -448,7 +455,8 @@ def regrid_dataarray(self, dr_in, keep_attrs=False):


dr_out = xr.apply_ufunc(
self.regrid_numpy, dr_in,
self._regrid_array, dr_in,
kwargs=self._regrid_kwargs,
input_core_dims=[input_horiz_dims],
output_core_dims=[temp_horiz_dims],
dask='parallelized',
Expand Down Expand Up @@ -513,7 +521,8 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
)

ds_out = xr.apply_ufunc(
self.regrid_numpy, ds_in,
self._regrid_array, ds_in,
kwargs=self._regrid_kwargs,
input_core_dims=[input_horiz_dims],
output_core_dims=[temp_horiz_dims],
dask='parallelized',
Expand Down