Add graph optimizer that batches scalar computations #287
I would be interested in contributing to this issue. Are you accepting PRs?
Yes, we are. A note about this proposed optimization: there are already some optimizations that address similar situations, so we need to be very clear about the what and when of this new optimization. For example, there's a "fusion" optimization that essentially creates a new scalar ufunc (in the NumPy sense) and broadcasts that across its arguments. When the arguments are combinations of vectors and scalars, this operation is applied. Here's an illustration:

```python
import numpy as np

import aesara
import aesara.tensor as aet
from aesara.printing import debugprint

# Disable C compilation
aesara.config.cxx = ""
aesara.config.compute_test_value = "warn"

a = aet.vector("a")
a.tag.test_value = np.r_[1, 2, 3]
b = aet.scalar("b")
b.tag.test_value = 4
```

We'll start with a simple combination of a vector and two scalars:

```python
>>> output = a + b + 1
>>> f = aesara.function([a, b], output)
>>> debugprint(f.maker.fgraph)
Elemwise{add,no_inplace} [id A] '' 1
 |TensorConstant{(1,) of 1.0} [id B]
 |a [id C]
 |InplaceDimShuffle{x} [id D] '' 0
   |b [id E]
```

The resulting "optimized" graph implies a computation similar to a single broadcasted `add` over the constant, `a`, and `b`. Now, if we introduce another element-wise `Op`, like `log`:

```python
>>> output = aet.log(a) + b + 1
>>> f = aesara.function([a, b], output)
>>> debugprint(f.maker.fgraph)
Elemwise{Composite{(i0 + log(i1) + i2)}} [id A] '' 1
 |TensorConstant{(1,) of 1.0} [id B]
 |a [id C]
 |InplaceDimShuffle{x} [id D] '' 0
   |b [id E]
```

The `Composite` `Op` above is the new scalar ufunc created by the fusion optimization, broadcast over its arguments by the enclosing `Elemwise`. At the moment, I don't see a reason why this proposed optimization couldn't work well with broadcast-based optimizations like fusion, but these are the kinds of interactions we need to consider.
@tmcclintock Any progress or questions?
@twiecki @brandonwillard sadly no progress. Other obligations have eaten up my time, so it may be O(month) before I get back to this. Feel free to tackle this, and I can find another issue when I have more time. PS, thank you for the tips!
This same idea applies to multiplications, which can be batched in the same way.
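For instance, here is a small, hedged sketch reusing the symbolic variables from the example above; `c` is a hypothetical extra scalar introduced only for illustration:

```python
# A hypothetical second scalar, introduced only for this illustration.
c = aet.scalar("c")
c.tag.test_value = 2

# Written naively, both multiplications involve the vector `a`.
output_naive = a * b * c

# Grouping the scalars first leaves a single vector-sized multiplication,
# the multiplicative analogue of the addition case below.
output_grouped = a * (b * c)
```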
If we have

```python
vector_a + scalar_b + scalar_c
```

we have two vector additions. If we were to rewrite this to first sum the scalars, i.e. `(scalar_b + scalar_c) + vector_a`, it would be roughly twice as fast, since we only do one vector addition. It should be pretty easy to add a graph optimizer that reshuffles the terms so that scalars are grouped together.
Thanks @fehiepsi for pointing this out.
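As a rough, hedged illustration of the expected payoff (plain NumPy rather than an Aesara graph, with an arbitrarily chosen vector size), grouping the scalars first removes one full pass over the vector:

```python
import timeit

import numpy as np

vector_a = np.random.rand(10_000_000)
scalar_b, scalar_c = 2.5, 3.5

# Two vector-sized additions: `vector_a + scalar_b` allocates a temporary
# array, and adding `scalar_c` walks over it again.
t_naive = timeit.timeit(lambda: vector_a + scalar_b + scalar_c, number=20)

# One scalar addition, then a single vector-sized addition.
t_grouped = timeit.timeit(lambda: (scalar_b + scalar_c) + vector_a, number=20)

print(f"naive:   {t_naive:.3f} s")
print(f"grouped: {t_grouped:.3f} s")
```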