Replace push implementation with map_overlap for Dask #9712
Conversation
This can't work in general. We switched from map_overlap to cumreduction intentionally: #6118. I fully support improving
See this test: xarray/xarray/tests/test_duck_array_ops.py, lines 1008 to 1028 in a00bc91.
Oh, I didn't consider None as a limit, sorry about that.
You could experiment with
I thought more about this; I didn't consider large limit values properly.
There is a bunch of operations that can't be fused properly because of the interactions between the two different arrays. I'll think a bit more about whether we can reduce this down somehow, but there isn't anything obvious right away (at least not to me).
Can you add Also, you should be able to write this as a single
Oh, I'll check that one out. Sure, there shouldn't be any reason not to add this. I'll check out the flox implementation; getting the number of tasks down would be nice, but adding it in Dask should be a good option anyway.
Ah, I guess the core issue is a dependency on
That should be fine; we can just raise if neither is installed, similar to what you are doing here.
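As a side note, a minimal sketch of the kind of guard being suggested; the two optional dependencies are truncated in the thread, so `numba` and `bottleneck` below are only assumptions:

```python
# Hypothetical guard: raise a clear error if neither optional backend is
# installed. The dependency names (numba, bottleneck) are assumptions, not
# taken from the thread.
from importlib.util import find_spec


def _require_push_backend():
    if find_spec("numba") is None and find_spec("bottleneck") is None:
        raise ImportError(
            "push/ffill with dask requires numba or bottleneck to be installed"
        )
```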
Hi, when I wrote that code I tried to preserve the map_overlap version as much as possible because it is, in theory, faster. The main issue is that it is not efficient in the general case: it is always possible that more than one chunk is entirely NaN, so a single pass of the push operation does not reach the following chunks; to make it work in general you would have to run it N - 1 times, where N is the number of chunks. If the main issue is that the code I wrote to apply a limit on the forward fill generates many tasks, I think you could replace it with a kind of cumulative-reset logic (I would really love to have a bitmask parameter on the cumulative operations that allows restarting the accumulated value). This operation accumulates a count of contiguous NaNs and restarts the count at 0 whenever it finds a non-NaN. Please see the code attached; I think it illustrates the idea and it is also functional code. The only issue is that it requires Numba, and I'm not sure you want to force users to have Numba.

```python
def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np

    from xarray.core.duck_array_ops import _push

    def _fill_with_last_one(a, b):
        # cumreduction applies the push func over all the blocks first, so the
        # only missing part is filling the missing values using the last data
        # of the previous chunk.
        return np.where(~np.isnan(b), b, a)

    # The method parameter makes the tests for Python 3.7 fail.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )

    if n is not None and 0 < n < array.shape[axis] - 1:
        import numba as nb

        @nb.vectorize([nb.int64(nb.int64, nb.int64)])
        def reset_cumsum(x, y):
            # Running count of contiguous NaNs: keep adding while y is 1,
            # restart at 0 whenever y is 0 (a non-NaN position).
            return x + y if y else y

        def combine_reset_cumsum(x, y):
            # Add the previous result (x) only to the positions before the
            # first zero of the current block.
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)

        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum.accumulate,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)

    return pushed_array
```
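A short usage sketch of the function above, assuming bottleneck (backing `_push`) and numba are installed; the array values and the expected result are illustrative, derived from ffill-with-limit semantics rather than quoted from the thread:

```python
import numpy as np
import dask.array as da

# Forward-fill along axis 0, carrying values at most n=2 positions forward.
data = da.from_array(
    np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan]), chunks=2
)
result = push(data, n=2, axis=0)
print(result.compute())
# Expected under ffill-with-limit semantics: [ 1.  1.  1. nan  5.  5.]
```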
See #9229. You can reset the accumulated value at the group edges.
Yes, pretty sure you can accumulate both the values and the count in a single scan, and then add a blockwise task that applies the mask.
I have been waiting for that segmented-scan functionality for some time; it will be quite useful, and it could be used to replace the function that I wrote with Numba to reset the cumulative sum. If someone needs it, I can implement the same function using only NumPy and send a PR; as it is, it should cut the number of tasks in half.
Hi @phofl, I made the following improvements to the code to reduce the number of tasks produced by the limit parameter and avoid using Numba. I hope this helps solve the performance issue that you are facing. I can send a PR to Dask if you need it (I saw that a push method was added there).

```python
def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np

    from xarray.core.duck_array_ops import _push

    def _fill_with_last_one(a, b):
        # cumreduction applies the push func over all the blocks first, so the
        # only missing part is filling the missing values using the last data
        # of the previous chunk.
        return np.where(np.isnan(b), a, b)

    # The method parameter makes the tests for Python 3.7 fail.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )

    if n is not None and 0 < n < array.shape[axis] - 1:

        def reset_cumsum(x, axis):
            # Cumulative sum that restarts at every zero: subtract, at each
            # position, the cumulative sum recorded at the most recent zero.
            cumsum = np.cumsum(x, axis=axis)
            reset_points = np.maximum.accumulate(
                np.where(x == 0, cumsum, 0), axis=axis
            )
            return cumsum - reset_points

        def combine_reset_cumsum(x, y):
            # Add the previous result (x) only to the positions before the
            # first zero of the current block.
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)

        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)

    return pushed_array
```
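To make the NumPy-only trick concrete, here is a small standalone demonstration of the resettable cumulative sum; the input plays the role of `da.isnan(array).astype(int)` for a 1-D array, and the printed values follow from the definition rather than from the thread:

```python
import numpy as np

# 1 marks a NaN position, 0 marks a valid value.
x = np.array([0, 1, 1, 1, 0, 1, 0, 1, 1])

cumsum = np.cumsum(x)                    # [0 1 2 3 3 4 4 5 6]
reset_points = np.maximum.accumulate(    # cumsum value at the latest zero
    np.where(x == 0, cumsum, 0)
)                                        # [0 0 0 0 3 3 4 4 4]
run_length = cumsum - reset_points       # [0 1 2 3 0 1 0 1 2]

# run_length counts the contiguous NaNs seen so far and restarts at each
# valid value, which is exactly what the ffill limit check needs.
print(run_length)
```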
Hah, very nice work @josephnowak 👏 👏 👏
It'd be nice to get this to work with
Nice! I think it makes sense to add this to both repositories for now, since xarray can only dispatch to Dask if the dask version is 2024.11 or newer.
Our benchmarks here showed us that ffill alone adds 4.5 million tasks to the graph, which isn't great (the dataset has 550k chunks, so roughly a factor of 9).
Rewriting this with map_overlap gets it down to 1.5 million tasks, which is basically the number of chunks times 3 and the minimum we can get to at the moment.
We merged a few map_overlap improvements on the dask side today to make this possible, and it's now a nice improvement (it also makes the code on the xarray side easier).
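For reference, a minimal sketch of what a map_overlap-based forward fill might look like; it assumes a finite limit `n` no larger than the chunk size along the fill axis, and it is not the code that was actually merged in this PR:

```python
import dask.array as da
import numpy as np
import bottleneck


def push_map_overlap(array, n, axis):
    # Sketch only: pull n elements from the preceding chunk along `axis`
    # so each block can forward-fill across its left boundary locally.
    # Assumes n is not None and no larger than the chunk size on `axis`,
    # so a fill never needs to cross more than one chunk boundary.
    return da.map_overlap(
        lambda block: bottleneck.push(block, n=n, axis=axis),
        array,
        depth={axis: (n, 0)},   # only look backwards along the fill axis
        boundary="none",
        meta=array._meta,
    )
```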
cc @dcherian