Rewrite interp to use apply_ufunc
#9881
Conversation
@@ -4127,18 +4119,6 @@ def interp(

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = dict(self._validate_interp_indexers(coords))

if coords:
Handled by vectorize=True. This is possibly a perf regression with numpy arrays, but a massive improvement with chunked arrays.
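For readers unfamiliar with the pattern being referenced, here is a minimal sketch (not the PR's actual code) of apply_ufunc with vectorize=True and dask="parallelized": the interpolator is written for the core dimension only and looped over the remaining dimensions. interp_1d, the dimension names, and the shapes are all invented for illustration; it assumes numpy, dask, and xarray are installed.

    # Illustrative sketch only -- not the code path added in this PR.
    import numpy as np
    import xarray as xr

    def interp_1d(y, x, new_x):
        # operates on 1-D core-dim slices; vectorize=True loops it over the
        # other dimensions, and dask="parallelized" maps it chunk by chunk
        return np.interp(new_x, x, y)

    da = xr.DataArray(
        np.random.rand(4, 10),
        dims=("time", "x"),
        coords={"x": np.linspace(0, 1, 10)},
    ).chunk({"time": 1})
    new_x = xr.DataArray(np.linspace(0, 1, 25), dims="x_new")

    out = xr.apply_ufunc(
        interp_1d,
        da,
        da["x"],
        new_x,
        input_core_dims=[["x"], ["x"], ["x_new"]],
        output_core_dims=[["x_new"]],
        vectorize=True,        # wraps interp_1d with np.vectorize
        dask="parallelized",   # one task per chunk of the non-core dims
        output_dtypes=[da.dtype],
    )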
For posterity, the bad thing about this approach is that it can greatly expand the number of core dimensions for the problem, limiting the potential for parallelism.
Consider the problem in #6799 (comment). In the following, dimension names are listed in [].
da[time, q, lat, lon].interp(q=bar[lat, lon]) gets rewritten to da[time, q, lat, lon].interp(q=bar[lat, lon], lat=lat[lat], lon=lon[lon]), which, thanks to our automatic rechunking, makes dask merge chunks in lat and lon too, for no benefit. A sketch reproducing this is below.
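A minimal reproduction sketch of this scenario (array sizes, chunk sizes, and values are invented; requires scipy and dask):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.random.rand(3, 5, 100, 100),
        dims=("time", "q", "lat", "lon"),
        coords={"q": np.arange(5.0), "lat": np.arange(100), "lon": np.arange(100)},
    ).chunk({"lat": 25, "lon": 25})

    # target q values vary over (lat, lon)
    bar = xr.DataArray(
        np.random.uniform(0, 4, size=(100, 100)),
        dims=("lat", "lon"),
        coords={"lat": da["lat"], "lon": da["lon"]},
    )

    # Only q is being interpolated, but because bar carries lat/lon dims,
    # lat and lon also become core dimensions of the vectorized problem,
    # so the chunked lat/lon axes may be merged into single chunks.
    out = da.interp(q=bar)
    print(out.chunksizes)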
def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
    """Wrapper for `_interpnd` through `blockwise` for chunked arrays.

def _interpnd(
I merged two functions together here to reduce indirection and make this easier to read.
exclude_dims=all_in_core_dims,
dask="parallelized",
kwargs=dict(interp_func=func, interp_kwargs=kwargs),
dask_gufunc_kwargs=dict(output_sizes=output_sizes, allow_rechunk=True),
allow_rechunk=True matches the current behaviour where we rechunk along all core dimensions to a single chunk.
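Illustratively, "rechunk along a core dimension to a single chunk" corresponds to the following dask-level operation (axis and sizes are made up); allow_rechunk=True lets apply_gufunc perform this merge itself instead of raising on multi-chunk core dimensions:

    import numpy as np
    import dask.array

    arr = dask.array.from_array(np.random.rand(100, 100), chunks=(25, 25))
    merged = arr.rechunk({1: -1})  # axis 1 stands in for a core dimension
    print(arr.chunks, "->", merged.chunks)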
Merging on Thursday if there are no comments. IMO this is a big win for maintainability.
@@ -566,29 +577,30 @@ def _get_valid_fill_mask(arr, dim, limit):
) <= limit

-def _localize(var, indexes_coords):
+def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]:
Probably should use T_Xarray instead of a plain T to get rid of the type ignore at the return.
That doesn't have Variable, so I'd have to make a new T_DatasetOrVariable, or perhaps a protocol with .isel?
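As a sketch of the "protocol with .isel" option (SupportsIsel and the other names below are hypothetical, purely to illustrate the typing approach; they do not exist in xarray):

    from typing import Any, Protocol, Self, TypeVar  # Self needs Python >= 3.11

    class SupportsIsel(Protocol):
        def isel(self, indexers: Any = None, **indexers_kwargs: Any) -> Self: ...

    T = TypeVar("T", bound=SupportsIsel)

    def _localize_sketch(obj: T, indexes_coords: dict) -> tuple[T, dict]:
        # pretend we computed integer slices bracketing the target coordinates
        indexers = {dim: slice(None) for dim in indexes_coords}
        return obj.isel(indexers), indexes_coords  # no "type: ignore" needed

Since Dataset, DataArray, and Variable all structurally provide isel, a TypeVar bound to such a protocol would cover Variable without introducing a new union alias.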
Benchmarks still look good. Nice work!
# TODO: narrow interp_func to interpolator here
return _interp1d(var, x_list, new_x_list, interp_func, interp_kwargs)  # type: ignore[arg-type]
Mypy is correct to error here, right?
_interp1d calls interp_func(...)(...), and that should crash with an InterpCallable?
Is there a pytest with interp_func: InterpCallable?
Is InterpCallable necessary? It would be nice to just remove it...
It depends on whether we end up using get_interpolator or get_interpolator_nd. I'm sure there's a test, but I can't remember which one off the top of my head.
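For context, a schematic of the two callable flavours being discussed; the aliases below are illustrative stand-ins, not xarray's actual type definitions:

    from typing import Callable
    import numpy as np

    InterpnLike = Callable[..., np.ndarray]                      # returns values directly
    InterpolatorLike = Callable[..., Callable[..., np.ndarray]]  # returns an interpolator

    def _interp1d_sketch(
        y: np.ndarray,
        x: np.ndarray,
        new_x: np.ndarray,
        interp_func: InterpolatorLike,  # the double call below is only valid here
    ) -> np.ndarray:
        return interp_func(x, y)(new_x)

Passing an InterpnLike-style function instead would make the second call fail at runtime, which is the mismatch mypy flags at the type-ignored return.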
- Use apply_ufunc instead of blockwise directly.
- Use vectorize=True to get sane dask graphs for vectorized interpolation to chunked arrays (interp performance with chunked dimensions #6799 (comment)).
- Closes interp performance with chunked dimensions #6799 (comment)
- User visible changes are documented in whats-new.rst
- New functions/methods are listed in api.rst

cc @ks905383 your vectorized interpolation example now has this graph: [dask task graph image] instead of this quadratic monstrosity: [dask task graph image]