diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py
index 34f9ea5459..bee96f4312 100644
--- a/aesara/tensor/elemwise.py
+++ b/aesara/tensor/elemwise.py
@@ -542,7 +542,6 @@ def connection_pattern(self, node):
         return [[True for output in node.outputs] for ipt in node.inputs]
 
     def L_op(self, inputs, outs, ograds):
-        from aesara.tensor.math import sum as at_sum
 
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)
@@ -573,18 +572,9 @@ def L_op(self, inputs, outs, ograds):
             if isinstance(rval[i].type, (NullType, DisconnectedType)):
                 continue
 
-            # List of all the dimensions that are broadcastable for input[i] so
-            # we can sum over them
-            # TODO: only count dimensions that were effectively broadcasted
-            to_sum = [
-                j
-                for j, bcast in enumerate(ipt.type.broadcastable)
-                if bcast and not outs[0].broadcastable[j]
-            ]
-
-            if to_sum:
-                sr = at_sum(rval[i], axis=to_sum, keepdims=True)
-                rval[i] = sr
+            rval[i] = aesara.tensor.extra_ops.sum_broadcastable_dims(
+                rval[i], ipt.shape, outs[0].shape
+            )
 
         return rval
 
diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py
index 28b54ba75c..a62bcff3aa 100644
--- a/aesara/tensor/extra_ops.py
+++ b/aesara/tensor/extra_ops.py
@@ -1,6 +1,6 @@
 from collections.abc import Collection
 from functools import reduce
-from typing import Iterable, Set, Tuple, Union
+from typing import Iterable, Sequence, Set, Tuple, Union
 
 import numpy as np
 import numpy.core.numeric
@@ -1665,19 +1665,8 @@ def grad(self, inputs, outputs_gradients):
 
         d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
 
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
+        # Determine the dimensions that were broadcast and sum them
+        d_wrt_a = sum_broadcastable_dims(d_wrt_a, a.shape, shape[-a.ndim :])
 
         return [d_wrt_a] + [
             grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
@@ -1804,6 +1793,33 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
     return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
 
 
+def sum_broadcastable_dims(
+    value: TensorVariable,
+    shape_1: Sequence[Variable],
+    shape_2: Sequence[Variable],
+) -> TensorVariable:
+    """Sum dimensions in `value` that are broadcasted between `shape_1` and `shape_2`."""
+    from aesara.ifelse import ifelse
+
+    for i, (s1, s2) in enumerate(zip(shape_1, shape_2)):
+        dummy_s1 = aes.get_scalar_type(dtype=s1.type.dtype)()
+        dummy_s2 = aes.get_scalar_type(dtype=s2.type.dtype)()
+        cond_op = Composite(
+            [dummy_s1, dummy_s2],
+            [
+                aesara.scalar.and_(
+                    aesara.scalar.eq(dummy_s1, 1), aesara.scalar.neq(dummy_s2, 1)
+                )
+            ],
+        )
+        value = ifelse(
+            cond_op(at.scalar_from_tensor(s1), at.scalar_from_tensor(s2)),
+            at_sum(value, axis=i, keepdims=True),
+            value,
+        )
+    return value
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 6bd514f277..48c6f6eb79 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -9,7 +9,7 @@
 import aesara
 import aesara.scalar as aes
 import tests.unittest_tools as utt
-from aesara.compile.mode import Mode
+from aesara.compile.mode import Mode, get_default_mode
 from aesara.configdefaults import config
 from aesara.graph.basic import Apply, Variable
 from aesara.graph.fg import FunctionGraph
@@ -889,6 +889,30 @@ def test_invalid_static_shape(self):
         ):
             x + y
 
+    def test_grad_sum_bcast_input_dims(self):
+        """Make sure broadcasted dimensions in the gradients are summed when static shape information isn't available."""
+        Y = matrix("Y")
+        X = matrix("X")
+        X_grad = aesara.grad((X + Y).sum(), wrt=X)
+
+        mode = get_default_mode().including("fast_run")
+
+        X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
+        res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
+        assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))
+
+        # When the shapes are known at compile-time, the compiled graph should
+        # simplify
+        Y = tensor(np.float64, shape=(5, None), name="Y")
+        X = tensor(np.float64, shape=(1, 5), name="X")
+        X_grad = aesara.grad((X + Y).sum(), wrt=X)
+
+        X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
+        res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
+        assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))
+
+        assert X_grad_fn.maker.fgraph.apply_nodes
+
 
 def test_not_implemented_elemwise_grad():
     # Regression test for unimplemented gradient in an Elemwise Op.
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index 22759226bf..f010e3d97c 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -1312,6 +1312,7 @@ def test_memory_leak(self):
     [
         [lambda x: broadcast_to(x, (1,)), (1,)],
         [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
+        [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
        [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
         [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
     ],
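
For reference, here is a minimal NumPy sketch of the semantics that the new `sum_broadcastable_dims` helper implements symbolically: sum the incoming gradient over every axis where the input had length 1 but the broadcasted output did not, keeping the reduced axes so the result matches the input's shape. The real helper defers each shape comparison to run time via `ifelse`, so it also works when static shapes are unknown. The `sum_broadcasted_dims_ref` name and the eager NumPy shapes below are illustrative only and are not part of the change.

import numpy as np


def sum_broadcasted_dims_ref(grad, in_shape, out_shape):
    """Sum `grad` over every axis where `in_shape` is 1 but `out_shape` is not,
    keeping the reduced axes so the result matches the input's shape."""
    axes = tuple(
        i
        for i, (s_in, s_out) in enumerate(zip(in_shape, out_shape))
        if s_in == 1 and s_out != 1
    )
    return grad.sum(axis=axes, keepdims=True) if axes else grad


# An input of shape (1, 5) broadcast against (5, 5) yields a gradient of
# shape (5, 5), which must be summed over axis 0 to match the input.
grad = np.ones((5, 5))
print(sum_broadcasted_dims_ref(grad, (1, 5), (5, 5)))  # [[5. 5. 5. 5. 5.]]

This mirrors the expected result in the new `test_grad_sum_bcast_input_dims` test above.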