Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace push implementation with map_overlap for Dask #9712

Closed
wants to merge 1 commit into from

Conversation

phofl
Copy link
Contributor

@phofl phofl commented Nov 4, 2024

  • Tests added
  • User visible changes (including notable bug fixes) are documented in whats-new.rst
  • New functions/methods are listed in api.rst

Our benchmarks here showed us that ffill alone adds 4.5 million tasks to the graph which isn't great (the dataset has 550k chunks, so a multiplication of 9).

Rewriting this with map_overlap gets this down to 1.5 million tasks, which is basically the number of chunks times 3, which is the minimum that we can get to at the moment.

We merged a few map_overlap improvements today on the dask side to make this possible, but it's now a nice improvement (also makes code on the xarray side easier).

cc @dcherian

@dcherian
Copy link
Contributor

dcherian commented Nov 4, 2024

This can't work in general.

We switched from map_overlap to cumreduction intentionally: #6118 . I fully support improving cumreduction. It's a fundamental parallel primitive.

@dcherian
Copy link
Contributor

dcherian commented Nov 4, 2024

See this test:

@requires_dask
@requires_bottleneck
def test_push_dask():
import bottleneck
import dask.array
array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6])
for n in [None, 1, 2, 3, 4, 5, 11]:
expected = bottleneck.push(array, axis=0, n=n)
for c in range(1, 11):
with raise_if_dask_computes():
actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n)
np.testing.assert_equal(actual, expected)
# some chunks of size-1 with NaN
with raise_if_dask_computes():
actual = push(
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n
)
np.testing.assert_equal(actual, expected)

@phofl
Copy link
Contributor Author

phofl commented Nov 4, 2024

Oh I didn't consider None as limit, sorry about that

@dcherian
Copy link
Contributor

dcherian commented Nov 4, 2024

You could experiment with method="blelloch" vs method="sequential" in cumreduction.

@phofl
Copy link
Contributor Author

phofl commented Nov 4, 2024

I thought more about this, I didn't consider large limit values properly. cumreduction itself works ok-ish, the issue that makes the task-graph that large is the section here:


     if n is not None and 0 < n < array.shape[axis] - 1:
         arange = da.broadcast_to(
             da.arange(
                 array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
             ).reshape(
                 tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
             ),
             array.shape,
             array.chunks,
         )
         valid_arange = da.where(da.notnull(array), arange, np.nan)
         valid_limits = (arange - push(valid_arange, None, axis)) <= n
         # omit the forward fill that violate the limit
         return da.where(valid_limits, push(array, None, axis), np.nan)

Bunch of operations that can't be fused properly because of the interactions between the 2 different arrays. I'll think a bit more if we can reduce this down somehow, but there isn't anything obvious right away (at least not to me).

cumreduction itself is equivalent to map_overlap from a topology perspective if the overlapping part only reaches a single neighbouring chunk, would you be open to calling overlap in these cases? Makes it a bit uglier, but that seems to be a reasonably common use-case from what I have seen so far (I might be totally wrong here)?

@dcherian
Copy link
Contributor

dcherian commented Nov 4, 2024

Can you add push to dask.array? Then you can add whatever optimization you want :). We'd be happy to dispatch instead of vendoring all this if we can.

Also, you should be able to write this as a single cumreduction that takes a 1D array of axis-indices and the input array as inputs. I wrote it for grouped ffill in flox: https://github.com/xarray-contrib/flox/blob/672be8ceeebfa588a15ebdc9861999efa12fa44e/flox/aggregations.py#L651

@phofl
Copy link
Contributor Author

phofl commented Nov 4, 2024

Oh, I'll check that one out.

Sure, there shouldn't be any reason not to add this. I'll check out the flox implementation, getting the number of tasks down would be nice, but adding it in Dask should be a good option anyway

@dcherian
Copy link
Contributor

dcherian commented Nov 4, 2024

Ah I guess the core issue is a dependency on numbagg and/or bottleneck.

@phofl
Copy link
Contributor Author

phofl commented Nov 4, 2024

That should be fine, we can just raise if neither is installed, similar to what you are doing here

@josephnowak
Copy link
Contributor

josephnowak commented Nov 9, 2024

Hi,

When I did that code I tried to preserve as much as possible the map_overlap because it is, in theory, faster, but the main issue is that it is not efficient to use it in the general case because it is always possible that more than one chunk is empty making that the push operation does not reach the next chunks, or if you want to make it work for the general case you would have to run it N - 1 times, being N the number of chunks.

If the main issue is the code that I did to apply a limit on the forward fill generates many tasks, I think that you could replace it with a kind of cumulative reset logic (I would really love to have a bitmask parameter on the cumulative operations that allows restarting the accumulate value), this operation would accumulate a count based on the number of contiguous nan, and if it finds a nonnan then it restart the count to 0, please see the code attached, I think it illustrates the idea and it is also a functional code, the only issue is that it requires numba and I'm not sure if you want to force to the user to have numba.

def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np

    from xarray.core.duck_array_ops import _push

    def _fill_with_last_one(a, b):
        # cumreduction apply the push func over all the blocks first so, the only missing part is filling
        # the missing values using the last data of the previous chunk
        return np.where(~np.isnan(b), b, a)

    # The method parameter makes that the tests for python 3.7 fails.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )

    if n is not None and 0 < n < array.shape[axis] - 1:
        import numba as nb

        @nb.vectorize([nb.int64(nb.int64, nb.int64)])
        def reset_cumsum(x, y):
            return x + y if y else y

        def combine_reset_cumsum(x, y):
            # Sum the previous result (X) only to the positions before the first zero
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)

        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum.accumulate,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)

    return pushed_array

@dcherian
Copy link
Contributor

dcherian commented Nov 9, 2024

(I would really love to have a bitmask parameter on the cumulative operations that allows restarting the accumulate value)

See #9229. You can reset the accumulated value at the group edges.

this operation would accumulate a count based on the number of contiguous nan, and if it finds a nonnan then it restart the count to 0,

yes rpetty sure you can accumulate both the values and the count in a single scan, and then add a blockwise task that applies the mask.

@josephnowak
Copy link
Contributor

josephnowak commented Nov 9, 2024

I have some time waiting for that segmented scan functionality, it will be quite useful, and it could be used to replace the function that I did with numba to reset the cumulative sum, but if someone need it I can implement the same function that I did with numba using only numpy and sent a PR, as it is, it should reduce the number of tasks in half.

@josephnowak
Copy link
Contributor

josephnowak commented Nov 11, 2024

Hi @phofl, I made the following improvements to the code to reduce the number of tasks produced by the limit parameter and avoid using Numba. I hope this helps to solve the performance issue that you are facing. I can send a PR to Dask if you need it (I saw that a push method was added there).

def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np

    from xarray.core.duck_array_ops import _push

    def _fill_with_last_one(a, b):
        # cumreduction apply the push func over all the blocks first so,
        # the only missing part is filling the missing values using the
        # last data of the previous chunk
        return np.where(np.isnan(b), a, b)

    # The method parameter makes that the tests for python 3.7 fails.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )

    if n is not None and 0 < n < array.shape[axis] - 1:

        def reset_cumsum(x, axis):
            cumsum = np.cumsum(x, axis=axis)
            reset_points = np.maximum.accumulate(
                np.where(x == 0, cumsum, 0), axis=axis
            )
            return cumsum - reset_points

        def combine_reset_cumsum(x, y):
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)

        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)

    return pushed_array

@dcherian
Copy link
Contributor

Hah very nice work @josephnowak 👏 👏 👏

@dcherian
Copy link
Contributor

it'd be nice to get this to work with method="blelloch" which is the standard work-efficient scan algorithm.

@josephnowak
Copy link
Contributor

josephnowak commented Nov 11, 2024

The last time that I tried to use that method the tests of Xarray failed for Python 3.7, but I think that Xarray already dropped support for that version, so probably we can use it without any issue, I'm going to send a PR soon. Also, take into consideration that this is written on the Dask documentation, not sure why no one has done a proper benchmark.
image

@phofl
Copy link
Contributor Author

phofl commented Nov 12, 2024

Nice!

I think it makes sense to add this to both repositories for now since xarray can only dispatch to Dask if the dask version is 2024.11 or newer

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants