
fast weighted sum #1224

Closed
crusaderky opened this issue Jan 23, 2017 · 5 comments

Comments

@crusaderky
Contributor

crusaderky commented Jan 23, 2017

In my project I'm struggling with weighted sums of 2000-4000 dask-backed xarray arrays. The time to build the final dask-backed array, the size of the final dask dict, and the time to compute the actual result are all horrendous.

So I wrote the code below which - as laborious as it may look - gives a performance boost nothing short of miraculous. At the bottom of the gist you'll find some benchmarks as well.

https://gist.github.com/crusaderky/62832a5ffc72ccb3e0954021b0996fdf

In my project, this deflated the size of the final dask dict from 5.2 million keys to 3.3 million and cut 30% off the time required to define it.

I think it's generic enough to be a good addition to the core xarray module. Impressions?

@shoyer
Member

shoyer commented Jan 23, 2017

Interesting -- thanks for sharing! I am interested in performance improvements but also a little reluctant to add specialized optimizations directly into xarray.

You write that this is equivalent to sum(a * w for a, w in zip(arrays, weights)). How does this compare to stacking and doing the sum in xarray, e.g., (arrays * weights).sum('stacked'), where arrays and weights are now DataArray objects with a 'stacked' dimension? Or maybe arrays.dot(weights)?
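For illustration, a minimal sketch of what that could look like on toy data (the array sizes and the weight values here are made up):

import numpy as np
import xarray as xr

# toy stand-ins: three 1D arrays stacked along a new 'stacked' dimension
arrays = xr.concat(
    [xr.DataArray(np.random.rand(5), dims="x") for _ in range(3)],
    dim="stacked",
)
# one scalar weight per stacked array
weights = xr.DataArray([0.2, 0.3, 0.5], dims="stacked")

# vectorized weighted sum over the 'stacked' dimension
result = (arrays * weights).sum("stacked")

# equivalent contraction over the shared 'stacked' dimension
result_dot = arrays.dot(weights)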

Using vectorized operations feels a bit more idiomatic (though also maybe more verbose). It also may be more performant. Note that the builtin sum is not optimized well by dask because it's basically equivalent to a loop:

def sum(xs):
    # sequential reduction: each addition depends on the previous result
    result = 0
    for x in xs:
        result += x
    return result

In contrast, dask.array.sum builds up a tree so it can do the sum in parallel.
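For comparison, a rough sketch of the two reduction styles on plain dask arrays (the sizes, chunking, and weights here are arbitrary):

import dask.array as da

# toy stand-ins: many chunked arrays with scalar weights
arrays = [da.random.random(500000, chunks=50000) for _ in range(100)]
weights = [0.01] * 100

# loop-style reduction: one long sequential chain of additions
loop_sum = sum(a * w for a, w in zip(arrays, weights))

# tree-style reduction: stack the weighted arrays, then let
# dask.array.sum reduce them in parallel along the new axis
tree_sum = da.stack([a * w for a, w in zip(arrays, weights)]).sum(axis=0)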

There has also been discussion in #422 about adding a dedicated method for weighted mean.

@crusaderky
Contributor Author

(arrays * weights).sum('stacked') was my first attempt. It performed considerably worse than sum(a * w for a, w in zip(arrays, weights)) - mostly because xarray.concat() is not terribly performant (I did not look deeper into it).

I did not try dask.array.sum() - worth playing with.

@shoyer
Member

shoyer commented Jan 23, 2017 via email

@crusaderky
Contributor Author

Both. One of the biggest problems is that the data I'm interested in is a mix of

  • 1D arrays with dims=(scenario, ) and shape=(500000, ) (stressed financial instruments under a Monte Carlo stress set)
  • 0D arrays with dims=() (financial instruments that are impervious to the Monte Carlo stresses and never change values)

So before you do concat(), you need to call broadcast(), which effectively means that doing the sums on your bunch of very fast 0D instruments suddenly requires repeating them over 500k points.

Even when keeping the two lots separate (which is what fastwsum does), the concat-based approach performed considerably slower.
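For illustration only, a rough sketch of the keep-the-lots-separate idea (this is not the actual fastwsum code from the gist; split_weighted_sum is a made-up name, and the inputs are assumed to be xarray DataArrays):

def split_weighted_sum(arrays, weights):
    # Accumulate 0D terms as a plain float so they are never
    # broadcast to the Monte Carlo ('scenario') dimension.
    scalar_total = 0.0
    vector_terms = []
    for a, w in zip(arrays, weights):
        if a.ndim == 0:
            scalar_total += float(a) * w
        else:
            vector_terms.append(a * w)
    # Sum the 1D terms, then add the scalar contribution once at the end.
    return sum(vector_terms) + scalar_total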

However, this was over a year ago and well before xarray.dot() and dask.einsum(), so I'll need to tinker with it again.

@crusaderky
Contributor Author

Retiring this as it is way too specialized for the main xarray library.
