diff --git a/aesara/tensor/blockwise.py b/aesara/tensor/blockwise.py index 7d91116328..91521b29f3 100644 --- a/aesara/tensor/blockwise.py +++ b/aesara/tensor/blockwise.py @@ -9,8 +9,10 @@ from aesara.graph.op import Op from aesara.tensor import get_scalar_constant_value from aesara.tensor.basic import atleast_Nd +from aesara.tensor.elemwise import DimShuffle from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.extra_ops import broadcast_shape +from aesara.tensor.math import sum as at_sum from aesara.tensor.shape import shape_tuple from aesara.tensor.type import TensorType @@ -120,19 +122,44 @@ def __init__(self, op, signature=None): self.op = op self.signature = signature or self.op.gufunc_sig - def make_node(self, *inputs): + def get_output_info(self, *inputs): + """Return the output dtype, the output shapes and the + dimshuffled inputs. + + """ + target_length = max(input.type.ndim for input in inputs) + args = [] + for input in inputs: + length = input.type.ndim + difference = target_length - length + if not difference: + args.append(input) + else: + # TODO: use LComplete instead + args.append( + DimShuffle( + input.type.broadcastable, + ["x"] * difference + list(range(length)), + )(input) + ) + inputs = args + + # TODO: Correct this + out_dtype = inputs[0].dtype + bcast_shape, dim_sizes = _parse_input_dimensions(inputs, self.signature[0]) + output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1]) + + return out_dtype, output_shapes, inputs + + def make_node(self, *inputs): num_expected_inps = len(self.signature[0]) if len(inputs) != num_expected_inps: raise ValueError( f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}" ) - # TODO: Correct this - out_dtype = inputs[0].dtype - - bcast_shape, dim_sizes = _parse_input_dimensions(inputs, self.signature[0]) - output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1]) + out_dtype, output_shapes, inputs = self.get_output_info(*inputs) 
def safe_const_val(x): try: @@ -151,7 +178,31 @@ def infer_shape(self, fgraph, node, shapes): output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1]) return output_shapes - def L_op( + def L_op(self, inputs, outs, ograds): + # Compute grad with respect to broadcasted input + rval = self._bgrad(inputs, outs, ograds) + + # sum out the broadcasted dimensions + for i, ipt in enumerate(inputs): + if isinstance(rval[i].type, (NullType, DisconnectedType)): + continue + + # List of all the dimensions that are broadcastable for input[i] so + # we can sum over them + # TODO: only count dimensions that were effectively broadcasted + to_sum = [ + j + for j, bcast in enumerate(ipt.type.broadcastable) + if bcast and not outs[0].broadcastable[j] + ] + + if to_sum: + sr = at_sum(rval[i], axis=to_sum, keepdims=True) + rval[i] = sr + + return rval + + def _bgrad( self, inputs: Sequence[Variable], outputs: Sequence[Variable], @@ -162,13 +213,14 @@ def L_op( core_inputs = [] for _inp, _inp_sig in zip(inputs, self.signature[0]): curr_dtype = _inp.type.dtype - curr_static_shape = _inp.type.shape[: len(_inp_sig)] + # extract the core dimensions + curr_static_shape = _inp.type.shape[-len(_inp_sig) :] core_inputs.append(TensorType(curr_dtype, curr_static_shape)()) core_out_grads = [] for _out_grad, _out_sig in zip(ograds, self.signature[1]): curr_dtype = _out_grad.type.dtype - curr_static_shape = _out_grad.type.shape[: len(_out_sig)] + curr_static_shape = _out_grad.type.shape[-len(_out_sig) :] core_out_grads.append(TensorType(curr_dtype, curr_static_shape)()) core_outputs: Sequence[Variable] = self.op.make_node(*core_inputs).outputs @@ -209,9 +261,19 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable: grad_signature = getattr(node.op, "gufunc_sig", None) if grad_signature is None: - # TODO: Can we manually derive gufunc signatures for any `Op` - # in this situation? 
- grad_signature = None + if isinstance(node.op, DimShuffle): + # remove the extra dimensions that + # we have added during op creation + new_order = [i for i in node.op.new_order if i != "x"] + + # derive gufunc signature for DimShuffle + input_signature = tuple([f"a{i}" for i in range(len(new_order))]) + output_signature = tuple([f"a{i}" for i in new_order]) + grad_signature = ((input_signature,), (output_signature,)) + else: + raise ValueError( + f"'{node.op}' object has no attribute 'gufunc_sig'" + ) new_r = Blockwise(node.op, signature=grad_signature)(*blocked_inputs) assert isinstance(new_r, Variable) @@ -229,7 +291,10 @@ def py_func(*inner_inputs): # TODO:This can be avoided by making a single dummy node # But will that cover all cases? inner_node = self.op.make_node(*inner_inputs) - self.op.perform(inner_node, inner_inputs, res) + if isinstance(self.op, DimShuffle): + self.op.perform(inner_node, inner_inputs, res, params=None) + else: + self.op.perform(inner_node, inner_inputs, res) # Numpy always expects outputs to be Numpy arrays # And since we have a variable number of outputs