Skip to content

Commit

Permalink
Fix missing broadcast dimension sums in Elemwise, BroadcastTo gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Oct 13, 2022
1 parent c4a9ff8 commit d2a61ba
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 28 deletions.
16 changes: 3 additions & 13 deletions aesara/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ def connection_pattern(self, node):
return [[True for output in node.outputs] for ipt in node.inputs]

def L_op(self, inputs, outs, ograds):
from aesara.tensor.math import sum as at_sum

# Compute grad with respect to broadcasted input
rval = self._bgrad(inputs, outs, ograds)
Expand Down Expand Up @@ -573,18 +572,9 @@ def L_op(self, inputs, outs, ograds):
if isinstance(rval[i].type, (NullType, DisconnectedType)):
continue

# List of all the dimensions that are broadcastable for input[i] so
# we can sum over them
# TODO: only count dimensions that were effectively broadcasted
to_sum = [
j
for j, bcast in enumerate(ipt.type.broadcastable)
if bcast and not outs[0].broadcastable[j]
]

if to_sum:
sr = at_sum(rval[i], axis=to_sum, keepdims=True)
rval[i] = sr
rval[i] = aesara.tensor.extra_ops.sum_broadcastable_dims(
rval[i], ipt.shape, outs[0].shape
)

return rval

Expand Down
44 changes: 30 additions & 14 deletions aesara/tensor/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Set, Tuple, Union
from typing import Iterable, Sequence, Set, Tuple, Union

import numpy as np
import numpy.core.numeric
Expand Down Expand Up @@ -1665,19 +1665,8 @@ def grad(self, inputs, outputs_gradients):

d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

# Determine the dimensions that were broadcast
_, static_shape = at.infer_static_shape(shape)

# TODO: This needs to be performed at run-time when static shape
# information isn't available.
bcast_sums = [
i
for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
if a_s == 1 and s_s != 1
]

if bcast_sums:
d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
# Determine the dimensions that were broadcast and sum them
d_wrt_a = sum_broadcastable_dims(d_wrt_a, a.shape, shape[-a.ndim :])

return [d_wrt_a] + [
grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
Expand Down Expand Up @@ -1804,6 +1793,33 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)


def sum_broadcastable_dims(
value: TensorVariable,
shape_1: Sequence[Variable],
shape_2: Sequence[Variable],
) -> TensorVariable:
"""Sum dimensions in `value` that are broadcasted between `shape_1` and `shape_2`."""
from aesara.ifelse import ifelse

for i, (s1, s2) in enumerate(zip(shape_1, shape_2)):
dummy_s1 = aes.get_scalar_type(dtype=s1.type.dtype)()
dummy_s2 = aes.get_scalar_type(dtype=s2.type.dtype)()
cond_op = Composite(
[dummy_s1, dummy_s2],
[
aesara.scalar.and_(
aesara.scalar.eq(dummy_s1, 1), aesara.scalar.neq(dummy_s2, 1)
)
],
)
value = ifelse(
cond_op(at.scalar_from_tensor(s1), at.scalar_from_tensor(s2)),
at_sum(value, axis=i, keepdims=True),
value,
)
return value


__all__ = [
"searchsorted",
"cumsum",
Expand Down
26 changes: 25 additions & 1 deletion tests/tensor/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import aesara
import aesara.scalar as aes
import tests.unittest_tools as utt
from aesara.compile.mode import Mode
from aesara.compile.mode import Mode, get_default_mode
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
Expand Down Expand Up @@ -889,6 +889,30 @@ def test_invalid_static_shape(self):
):
x + y

def test_grad_sum_bcast_input_dims(self):
"""Make sure broadcasted dimensions in the gradients are summed when static shape information isn't available."""
Y = matrix("Y")
X = matrix("X")
X_grad = aesara.grad((X + Y).sum(), wrt=X)

mode = get_default_mode().including("fast_run")

X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))

# When the shapes are known at compile-time, the compiled graph should
# simplify
Y = tensor(np.float64, shape=(5, None), name="Y")
X = tensor(np.float64, shape=(1, 5), name="X")
X_grad = aesara.grad((X + Y).sum(), wrt=X)

X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))

assert X_grad_fn.maker.fgraph.apply_nodes


def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
Expand Down
1 change: 1 addition & 0 deletions tests/tensor/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,7 @@ def test_memory_leak(self):
[
[lambda x: broadcast_to(x, (1,)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
[lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
],
Expand Down

0 comments on commit d2a61ba

Please sign in to comment.