make L_op work
purna135 authored and brandonwillard committed Oct 6, 2022
1 parent bd04d54 commit 73bffe8
Showing 1 changed file with 78 additions and 13 deletions.
91 changes: 78 additions & 13 deletions aesara/tensor/blockwise.py
@@ -9,8 +9,10 @@
from aesara.graph.op import Op
from aesara.tensor import get_scalar_constant_value
from aesara.tensor.basic import atleast_Nd
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import broadcast_shape
from aesara.tensor.math import sum as at_sum
from aesara.tensor.shape import shape_tuple
from aesara.tensor.type import TensorType

@@ -120,19 +122,44 @@ def __init__(self, op, signature=None):
self.op = op
self.signature = signature or self.op.gufunc_sig

def make_node(self, *inputs):
def get_output_info(self, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
dimshuffled inputs.
"""
target_length = max(input.type.ndim for input in inputs)
args = []
for input in inputs:
length = input.type.ndim
difference = target_length - length
if not difference:
args.append(input)
else:
# TODO: use LComplete instead
args.append(
DimShuffle(
input.type.broadcastable,
["x"] * difference + list(range(length)),
)(input)
)
inputs = args

# TODO: Correct this
out_dtype = inputs[0].dtype

bcast_shape, dim_sizes = _parse_input_dimensions(inputs, self.signature[0])
output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1])

return out_dtype, output_shapes, inputs
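
As a rough illustration of the padding step above (hypothetical variable names, not part of the file): a rank-2 input promoted to rank 4 gets two leading broadcastable dimensions, the same effect NumPy's implicit broadcasting has.

    import aesara.tensor as at
    from aesara.tensor.elemwise import DimShuffle

    x = at.matrix("x")  # a rank-2 input
    x4 = DimShuffle(x.type.broadcastable, ["x", "x", 0, 1])(x)
    assert x4.type.ndim == 4
    assert x4.type.broadcastable[:2] == (True, True)  # the padded dims broadcast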

def make_node(self, *inputs):
num_expected_inps = len(self.signature[0])
if len(inputs) != num_expected_inps:
raise ValueError(
f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
)

# TODO: Correct this
out_dtype = inputs[0].dtype

bcast_shape, dim_sizes = _parse_input_dimensions(inputs, self.signature[0])
output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1])
out_dtype, output_shapes, inputs = self.get_output_info(*inputs)

def safe_const_val(x):
try:
@@ -151,7 +178,31 @@ def infer_shape(self, fgraph, node, shapes):
output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1])
return output_shapes

def L_op(
def L_op(self, inputs, outs, ograds):
# Compute grad with respect to broadcasted input
rval = self._bgrad(inputs, outs, ograds)

# sum out the broadcasted dimensions
for i, ipt in enumerate(inputs):
if isinstance(rval[i].type, (NullType, DisconnectedType)):
continue

# List of all the dimensions that are broadcastable for input[i] so
# we can sum over them
# TODO: only count dimensions that were effectively broadcasted
to_sum = [
j
for j, bcast in enumerate(ipt.type.broadcastable)
if bcast and not outs[0].broadcastable[j]
]

if to_sum:
sr = at_sum(rval[i], axis=to_sum, keepdims=True)
rval[i] = sr

return rval
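
The reduction above follows the usual broadcasting rule for gradients: the gradient with respect to an input that was broadcast must be summed over the broadcast dimensions so it regains the input's shape. A minimal NumPy analogue, with hypothetical shapes:

    import numpy as np

    x = np.ones((1, 3))                # broadcastable along axis 0
    y = np.ones((4, 3))
    g = np.ones_like(x + y)            # incoming gradient, shape (4, 3)
    gx = g.sum(axis=0, keepdims=True)  # reduce over the broadcast axis
    assert gx.shape == x.shape         # (1, 3)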

def _bgrad(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
@@ -162,13 +213,14 @@ def L_op(
core_inputs = []
for _inp, _inp_sig in zip(inputs, self.signature[0]):
curr_dtype = _inp.type.dtype
curr_static_shape = _inp.type.shape[: len(_inp_sig)]
# extract the core dimensions
curr_static_shape = _inp.type.shape[-len(_inp_sig) :]
core_inputs.append(TensorType(curr_dtype, curr_static_shape)())

core_out_grads = []
for _out_grad, _out_sig in zip(ograds, self.signature[1]):
curr_dtype = _out_grad.type.dtype
curr_static_shape = _out_grad.type.shape[: len(_out_sig)]
curr_static_shape = _out_grad.type.shape[-len(_out_sig) :]
core_out_grads.append(TensorType(curr_dtype, curr_static_shape)())

core_outputs: Sequence[Variable] = self.op.make_node(*core_inputs).outputs
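
Slicing the static shape from the end is what isolates the core (signature) dimensions, since any leading dimensions are loop/broadcast dimensions. A small sketch with a hypothetical batched input whose signature entry is ("m", "n"):

    static_shape = (None, 5, 3)  # (batch, m, n) for a rank-3 input
    inp_sig = ("m", "n")
    core_shape = static_shape[-len(inp_sig):]          # (5, 3), the core dimensions
    assert core_shape == (5, 3)
    assert static_shape[: len(inp_sig)] == (None, 5)   # the old slice mixed in the batch dim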
@@ -209,9 +261,19 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
grad_signature = getattr(node.op, "gufunc_sig", None)

if grad_signature is None:
# TODO: Can we manually derive gufunc signatures for any `Op`
# in this situation?
grad_signature = None
if isinstance(node.op, DimShuffle):
# remove the extra dimensions that
# we have added during op creation
new_order = [i for i in node.op.new_order if i != "x"]

# derive gufunc signature for DimShuffle
input_signature = tuple([f"a{i}" for i in range(len(new_order))])
output_signature = tuple([f"a{i}" for i in new_order])
grad_signature = ((input_signature,), (output_signature,))
else:
raise ValueError(
f"'{node.op}' object has no attribute 'gufunc_sig'"
)

new_r = Blockwise(node.op, signature=grad_signature)(*blocked_inputs)
assert isinstance(new_r, Variable)
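
For example (hypothetical new_order, mirroring the derivation above), a DimShuffle whose remaining core axes are transposed would get a signature mapping (a0, a1) to (a1, a0):

    new_order = ["x", 1, 0]  # one padded axis plus a transpose of the core axes
    core_order = [i for i in new_order if i != "x"]                    # [1, 0]
    input_signature = tuple(f"a{i}" for i in range(len(core_order)))   # ("a0", "a1")
    output_signature = tuple(f"a{i}" for i in core_order)              # ("a1", "a0")
    assert ((input_signature,), (output_signature,)) == ((("a0", "a1"),), (("a1", "a0"),))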
@@ -229,7 +291,10 @@ def py_func(*inner_inputs):
# TODO: This can be avoided by making a single dummy node
# But will that cover all cases?
inner_node = self.op.make_node(*inner_inputs)
self.op.perform(inner_node, inner_inputs, res)
if isinstance(self.op, DimShuffle):
self.op.perform(inner_node, inner_inputs, res, params=None)
else:
self.op.perform(inner_node, inner_inputs, res)
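
The surrounding perform implementation is not shown in this hunk; presumably py_func is mapped over the loop dimensions in the style of NumPy's gufunc vectorization. A sketch of that style with a hypothetical core function (an assumption, not the actual Blockwise.perform code):

    import numpy as np

    def core_matvec(a, b):
        # operates only on the core (signature) dimensions
        return a @ b

    blocked = np.vectorize(core_matvec, signature="(m,n),(n)->(m)")
    out = blocked(np.ones((4, 2, 3)), np.ones((4, 3)))  # loops over the leading dim
    assert out.shape == (4, 2)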

# Numpy always expects outputs to be Numpy arrays
# And since we have a variable number of outputs
