Compute dimension sums in `Elemwise.grad` at run-time #1260
This PR provides initial fixes for #1089 by supporting ambiguous broadcasting cases. In its current state, this PR is an investigation into the requirements and unforeseen issues behind adding support for those cases.
The approach used here moves the compile/construction-time conditional `sum` logic in `Elemwise.grad` into the graph in order to handle dimensions lacking complete broadcast information (see here). In this context, having complete broadcast information means that we know definitively whether or not the shape of a dimension is 1 (i.e. we know `var.type.shape[d] != 1` for a dimension `d`).

Since we want to interpret
`var.type.shape[d] == None` as `var.type.shape[d] != 1` at some points, we need to address the old logic in `Elemwise.grad` when it encounters `Variable`s with `TensorType`s that cannot be used to determine which axes need to be summed when the gradient graph is constructed. By moving the axes summing conditions into the graph, we can always perform the correct computations, especially when the requisite broadcasting information is only available at run-time.

More importantly, the extra logic can be removed whenever the requisite information becomes available at compile-time. By using existing
`Op`s, we can leverage existing rewrites and our shape inference to automatically remove the extra logic. How we choose to represent that logic (i.e. which `Op`s to use) will be important, so expect this PR to iterate on that.

One important point that needs to be addressed, for at least an `ifelse` approach, is the forking of conditional branches introduced by nested conditionals and iterated gradients. There's ultimately no avoiding this issue when the requisite shape information isn't available, but there are definitely better ways to handle it.

Regardless, we can always make different default assumptions (e.g. that `var.type.shape != 1` in certain cases) and/or represent constraints that provide complete broadcast information and, as a result, produce graphs without the extra logic in the vast majority of use cases. #1122 and #1170 cover this topic in different ways. (N.B. The topic is effectively independent of the new support being added in this PR, although the issue this PR addresses could be, and is, largely mitigated by such changes.)

It looks like second order gradients generate graphs that make use of `DimShuffle.input_broadcastable`
and `DimShuffle`'s broadcastable dimension dropping feature, and that currently requires complete broadcasting information. `DimShuffle` will probably need to be changed so that it can at least attempt to drop dimensions that aren't known to be broadcastable.

This issue has to do with gradients of
`IfElse` when one branch is broadcastable and the other isn't. More specifically, it arises when the `IfElse` gradients are evaluated at a point that conforms to/is intended for only one branch and its distinct shape. As a work-around, `specify_shape` is being used to make sure that the gradients in each branch match their original shapes. Something significantly better seems possible, but we need to investigate that separately.

`shape=(..., 1, ...)` when broadcastable input values are used). This will allow us to produce the original graphs expected by the tests.
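
For readers less familiar with why the gradient of a broadcasted element-wise operation needs a sum at all, here is a minimal NumPy sketch of the idea behind moving the axis-summing decision to run-time. It is not the code in this PR and does not use Aesara's `Op`s; the helper name `sum_broadcasted_axes` is hypothetical. It only illustrates that the axes to sum can be determined from the concrete shapes when static shape information is missing.

```python
import numpy as np


def sum_broadcasted_axes(output_grad, input_shape):
    """Reduce `output_grad` to `input_shape` by summing broadcasted axes.

    Hypothetical helper: the axes are chosen from the concrete run-time
    shapes, not from static type information.
    """
    # Left-pad the input shape with 1s, as NumPy broadcasting does.
    n_missing = output_grad.ndim - len(input_shape)
    padded_shape = (1,) * n_missing + tuple(input_shape)

    # An axis was broadcast if the input had length 1 there and the
    # output did not.
    axes = tuple(
        d
        for d, (in_len, out_len) in enumerate(zip(padded_shape, output_grad.shape))
        if in_len == 1 and out_len != 1
    )

    summed = output_grad.sum(axis=axes, keepdims=True) if axes else output_grad
    return summed.reshape(input_shape)


# `x` is broadcast along its first axis when added to `y`.
x = np.ones((1, 3))
y = np.ones((4, 3))

# Gradient of `(x + y).sum()` with respect to `x + y`.
g_out = np.ones_like(x + y)

g_x = sum_broadcasted_axes(g_out, x.shape)  # shape (1, 3), every entry 4.0
g_y = sum_broadcasted_axes(g_out, y.shape)  # shape (4, 3), every entry 1.0
```

If the first dimension of `x` were only known as `None` when the gradient graph is built, the choice between the `g_x` and `g_y` behavior could not be made at construction time, which is the situation the in-graph conditions are meant to cover.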