Wrong gradients when inputs are dynamically broadcasted #1089
Comments
It seems like we might need a new Op that unbroadcasts (reduces) arrays to a given shape or leaves the input unchanged. Check https://mostafa-samir.github.io/auto-diff-pt2/#unbroadcasting-adjoints

Something like:

```python
x = at.matrix("x")

# at.reduce_to is probably a better name
# maybe sum is all we will ever need
y1 = at.unbroadcast_to(x, shape=(1, 5), reduce_op="sum")
y1.eval({x: np.ones((5, 5))})  # [[5, 5, 5, 5, 5]]

# This won't do anything, but the shape may only be
# known at runtime, as in the example in this issue!
y2 = at.unbroadcast_to(x, shape=(5, 5))
y2.eval({x: np.ones((5, 5))})  # np.ones((5, 5))

# If the shape is not compatible with something that could
# have been broadcasted to the input shape, an error is raised
y3 = at.unbroadcast_to(x, shape=(2, 5))
y3.eval({x: np.ones((5, 5))})  # ValueError
```

This was also brought up by @Sayam753 and @purna135 in Slack in relation to their work on batched solve, where dynamic unbroadcasting of gradients also crops up. It was that discussion that led me to suspect this bug!

Edit: This may already be possible without a specialized Op, if sum allows for a symbolic axis. Does it? In that case we could cook up a helper pretty quickly, and perhaps add a rewrite for when the axes are constant-folded during compilation, since a sum with constant axes is more efficient.

Edit: Sum does not allow for a variable axis.
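For the constant-axis case mentioned in that edit, a helper is indeed straightforward. The sketch below is mine (the name `unbroadcast_to_static` is made up) and only covers target shapes that are known statically, which is exactly the part this issue is *not* about, but it shows what a constant-folded rewrite would reduce to:

```python
import aesara.tensor as at


def unbroadcast_to_static(x, static_shape):
    # Sum over the axes that are statically known to be 1 in the target
    # shape; with constant axes and keepdims=True, at.sum handles this today.
    axes = [i for i, s in enumerate(static_shape) if s == 1]
    if not axes:
        return x
    return at.sum(x, axis=axes, keepdims=True)
```

The missing piece is the runtime-shape version of the same reduction, which is what the proposed Op (or a sum with symbolic axes) would have to provide.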
Edit: mentioned other related issues in the top comment.
Yeah, it looks like we might need to add symbolic conditions for those dimensions and let them be simplified later via shape inference and/or constant folding. This is similar to what we do when constructing broadcasted shape graphs.
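As an illustration (my sketch, not the actual shape-graph code), the kind of symbolic condition meant here looks like this when computing a broadcasted dimension length:

```python
import aesara.tensor as at

# For two dimension lengths s1 and s2, the broadcasted length is s2 where
# s1 == 1, else s1; shape inference or constant folding can later simplify
# this switch away when the lengths become known.
s1 = at.lscalar("s1")
s2 = at.lscalar("s2")
bcast_dim = at.switch(at.eq(s1, 1), s2, s1)

assert bcast_dim.eval({s1: 1, s2: 5}) == 5
assert bcast_dim.eval({s1: 5, s2: 5}) == 5
```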
What's up with that shape of (2, 5, 0)? I think that combination is invalid under NumPy broadcasting rules because of the zero?
That sounds valid, but it seems a more convoluted answer if you ONLY want to fix this problem. In these cases we have the input and the broadcasted gradient output, so the only thing we need is to reduce that gradient along the dimensions that were of size 1 in the input. Actually, in the Elemwise case we don't even have to worry about new dims, because make_node adds Dimshuffles to the inputs to align the number of dims (but we will perhaps remove that one day). So what we need is just:

```python
import numpy as np

# The perform method of the new Op boils down to this
def unbroadcast_to(x, shape):
    axis_to_sum = [
        i
        for i, (s, xs) in enumerate(zip(shape, x.shape))
        if s == 1 and xs != 1
    ]
    if not axis_to_sum:
        return x
    return np.sum(x, axis=axis_to_sum, keepdims=True)
```

In the grad where this issue would crop up we would do something like:

```python
grad = unbroadcast_to(bcast_grad, shape=input.shape)
```

And then we could have some rewrites to try to get rid of this Op during compilation, for instance when the target shape can be statically shown to match the input shape.
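As a quick sanity check (my addition, plain NumPy), the `unbroadcast_to` sketch above reproduces the behaviour of the earlier `y1`/`y2` examples:

```python
import numpy as np

# A (5, 5) array of ones reduced to a (1, 5) target sums over axis 0,
# matching the `y1` example above.
assert np.array_equal(
    unbroadcast_to(np.ones((5, 5)), shape=(1, 5)),
    np.full((1, 5), 5.0),
)

# A target shape that already matches is returned unchanged, as in `y2`.
assert np.array_equal(
    unbroadcast_to(np.ones((5, 5)), shape=(5, 5)),
    np.ones((5, 5)),
)
```

(The `ValueError` case from `y3` would still need an explicit check in the real Op; the sketch above does not validate incompatible shapes.)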
The problem is that without an Op like the one I sketched, the only safe thing to do when you can't be sure ahead of time whether something will have had a shape of 1 (or certainly not 1) is to raise in the grad method. If IfElse allowed for different shapes in the two branches we could also write a symbolic graph that applies the needed logic, but from one of the open issues it seems that both branches must have the same shape.
Allowing sum to have a symbolic axis (as long as keepdims is used, this should be fine for Aesara) would also allow for a simple solution without new Ops. But maybe that would raise a whole new set of problems.
Simply put, if we don't have the information at compile time, then it needs to be handled at run-time.
Additionally, the description of this issue needs to clarify which result is correct and why.
The gradients should have the same shape as the inputs, so the case where row is used is correct. This issue will arise for any Op that may or may not broadcast its inputs at runtime. If broadcasting occurs you need to sum the gradient across the broadcasted dimensions; otherwise you should not. However, Aesara does not provide any building blocks that can achieve this branch logic AFAICT.
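To make that rule concrete, here is a plain-NumPy illustration (my addition): when a (1, 5) input is broadcast to (5, 5) in the forward pass, the adjoint of that broadcast is a sum over the broadcast axis, so the gradient must be reduced back to the input's shape; if no broadcast happens at runtime, no reduction should be applied.

```python
import numpy as np

x = np.ones((1, 5))           # input that gets broadcast at runtime
y = np.ones((5, 5))
z = x + y                     # x is broadcast to (5, 5) here

# d(sum(z))/dz is ones of the broadcasted shape; the gradient w.r.t. x
# is obtained by summing over the axis that was broadcast.
g_z = np.ones_like(z)
g_x = g_z.sum(axis=0, keepdims=True)   # shape (1, 5), all values 5.0

assert g_x.shape == x.shape
# Had x been (5, 5) at runtime, g_x would just be g_z, with no summation.
```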
Note that explicitly broadcasting all the inputs (like the explicit Dimshuffles introduced by Elemwise) wouldn't fix this either. The gradient of BroadcastTo shares the same limitations as Elemwise.
It does.
Via what?
To be clear, we need something that can do the following.

```python
import aesara
import aesara.tensor as at
import numpy as np


def foo(x, y):
    ...


x = at.matrix("x")
y = np.random.normal(size=(5, 5))
f = aesara.function([x], foo(x, y))

assert np.array_equal(f(np.ones((1, 5))), np.sum(y, axis=0, keepdims=True))
assert np.array_equal(f(np.ones((5, 1))), np.sum(y, axis=1, keepdims=True))
assert np.array_equal(f(np.ones((1, 1))), np.sum(y, axis=(0, 1), keepdims=True))
assert np.array_equal(f(np.ones((5, 5))), y)
```
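For reference, here is a plain-NumPy version (my assumption, not the actual `foo`) of the behaviour those asserts describe, i.e. reduce `y` to the runtime shape of `x`:

```python
import numpy as np


def foo_reference(x_val, y):
    # Sum `y` over the axes along which `x_val` was broadcast at runtime.
    axes = tuple(
        i
        for i, (xs, ys) in enumerate(zip(x_val.shape, y.shape))
        if xs == 1 and ys != 1
    )
    return np.sum(y, axis=axes, keepdims=True) if axes else y


y = np.random.normal(size=(5, 5))
assert np.array_equal(
    foo_reference(np.ones((1, 5)), y), np.sum(y, axis=0, keepdims=True)
)
assert np.array_equal(
    foo_reference(np.ones((1, 1)), y), np.sum(y, axis=(0, 1), keepdims=True)
)
assert np.array_equal(foo_reference(np.ones((5, 5)), y), y)
```

Turning this into a graph is the hard part, since the summed axes depend on shapes that are only known at runtime.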
This helper will probably also have to be updated when/if the ops that use it in the grad allow for runtime broadcasting: aesara/aesara/tensor/subtensor.py, lines 1893 to 1924 (at 3500fec)
To clarify the relationship between this Aesara issue and Theano's old broadcasting assumptions:

```python
import theano
import theano.tensor as tt
import numpy as np

theano.__version__
# '1.0.5'

X_row = tt.row("X_row")
X_matrix = tt.matrix("X_matrix")


def X_grad_fn_constructor(X):
    Y = tt.matrix("Y")
    X_sum = tt.sum(X + Y)
    X_grad = tt.grad(X_sum, wrt=X)
    X_grad_fn = theano.function([X, Y], X_grad)
    return X_grad_fn


X_grad_fn_row = X_grad_fn_constructor(X_row)
X_grad_fn_matrix = X_grad_fn_constructor(X_matrix)

# This input is broadcastable in the first dimension, but the `Type`-level
# representation of that fact is lacking in the `X_matrix` case. Let's see how
# Theano handles this broadcast information disparity.
# To be clear, *both* cases should theoretically return the same values for the
# same inputs.
X_val = np.ones((1, 5))
Y_val = np.ones((5, 5))

# The "row"-`Type` case (i.e. we tell Theano that the first dimension of `X` is
# broadcastable)
row_res = X_grad_fn_row(X_val, Y_val)
row_res
# array([[5., 5., 5., 5., 5.]])

# The "matrix"-`Type` case (i.e. we *don't* tell Theano that the first
# dimension of `X` is actually broadcastable)
matrix_res = X_grad_fn_matrix(X_val, Y_val)
matrix_res
# array([[1., 1., 1., 1., 1.]])

assert np.array_equal(matrix_res, row_res)
# AssertionError
```

In other words, Theano's assumptions were not capable of solving the issue raised here. Instead, the same shape inference problem(s) simply took different forms. (N.B. this also means that no amount of "reverting" will fix this issue.) The thing we're trying to fix in this issue has always been an issue; it has simply become more visible now.
To reiterate, this issue can be closed by implementing one of the two approaches discussed above (roughly: a new Op that reduces the gradient at runtime, or symbolic shape conditions that get simplified via shape inference and constant folding).
This issue concerns itself with the relevant details of the above two approaches, and any new ones we may not have considered yet. Conversations about anything else are off-topic here.
To be fair, from an outsider's perspective the discussion here is impossible to follow. The concerns that the issue raises are valid, but the discussion has since diverged to questions that are more fundamental and/or historical in nature. To move forward, I suggest we settle on a minimal example that everyone can run and agree on.
I just ran @ricardoV94's example both on the current HEAD and on the commit preceding #928:

```python
def test_ambiguous_broadcast():
    import aesara
    import aesara.tensor as at
    import numpy as np

    x_row = at.row("x_row")
    x_matrix = at.matrix("x_matrix")
    y = at.matrix("y")

    x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
    x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

    f_row = aesara.function([x_row, y], x_row_grad)
    row_res = f_row(np.ones((1, 5)), np.ones((5, 5)))

    f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
    assert np.array_equal(f_matrix(np.ones((1, 5)), np.ones((5, 5))), row_res)


test_ambiguous_broadcast()
```

and it fails in both situations. This means that this issue is not a consequence of #928; in other words, reverting that change is not going to fix this bug. Try this for yourself (don't forget to clear the cache), so we can all agree on this point moving forward. Comment "I see it" (and only that for now) if you can reproduce it; comment something else only if you can't.
I see it
@rlouf This is clearly explained in one of the "off-topic" comments as arising from an "inconsistent" rewrite: #1089 (comment)

Running the following:

```python
import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
import numpy as np

x_matrix = at.matrix("x_matrix")
y = at.matrix("y")
x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

f_matrix = aesara.function(
    [x_matrix, y],
    x_matrix_grad,
    mode=Mode().excluding("local_fill_to_alloc"),
)
matrix_res = f_matrix(np.ones((1, 5)), np.ones((5, 5)))  # ValueError
```

It's one consequence of the "rewrites assume original graphs were valid" mindset: https://theano-pymc.readthedocs.io/en/latest/tutorial/shape_info.html#shape-inference-problem (they mention the same applies to …)

If you want an example that clearly fails before that commit and passes after, make it just a bit more complex:

```python
import aesara
import aesara.tensor as at
import numpy as np

x_matrix = at.matrix("x_matrix")
y = at.matrix("y")
x_matrix_grad = at.grad(at.sum(at.exp(x_matrix + y)), wrt=x_matrix)

f_matrix = aesara.function(
    [x_matrix, y],
    x_matrix_grad,
)

# ValueError before `b60cf7240` and not after
matrix_res = f_matrix(np.ones((1, 5)), np.ones((5, 5)))
```
The reason I am writing this is that newcomers to the issue will take your first comment at face value, assume #928 is the problem, and get confused if they try to run it with the commit before that. We're not writing these just for ourselves. If the example you gave me indeed fails after #928 and not before, could you please edit your original comment?
Thanks. I updated the original message to link to that comment.
Doing that changes the entire nature of this issue, which was never originally about that. It would make more sense to describe the errant rewrites in the opening comment.
Unless I'm mistaken, you left the original code snippet that fails both before and after #928? The reason I'm asking is that issues and PRs in the repository are not only read by those who interacted with them. We should always keep this in mind and strive to make them understandable to anyone vaguely familiar with Aesara. I spent hours trying to understand what was going on in this issue before realizing yesterday that the assertions in the original comment did not hold (namely, that the particular example you provided fails because of #928). Most people will not do this legwork and will take your original comment at face value. General confusion ensues.
This bug is an unexpected consequence of #928 and rewrites that make certain assumptions: #1089 (comment)

The faulty logic is found here:
aesara/aesara/tensor/elemwise.py, lines 552 to 570 (at 7f4e0ab)

This is also likely a problem in the grad of `BroadcastTo`, which calls `infer_broadcastable`, which defaults to assuming something will not have been broadcast if a static shape of 1 can't be inferred:
aesara/aesara/tensor/extra_ops.py, line 1593 (at 7f8af9b)
aesara/aesara/tensor/basic.py, line 1313 (at 7f8af9b)

And also GEMM since #986.

I am not sure if there's a good solution to this problem, as we would need an expression with different output shapes depending on whether the runtime inputs are broadcasted or not.

A solution might look something like: #1089 (comment)