Remove strict TensorType.broadcastable usage from local_elemwise_alloc
brandonwillard committed Aug 13, 2022
1 parent 9ef9c8c commit f604e1f
Showing 3 changed files with 345 additions and 343 deletions.
aesara/tensor/basic_opt.py: 150 changes (60 additions, 90 deletions)
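For context, the strictness being removed is the direct comparison of static `TensorType.broadcastable` patterns; the rewritten code reasons about symbolic shapes instead. A minimal hand-written sketch of the difference (ours, not part of the commit, using only helpers that appear in the diff below):

import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_shape

x = at.matrix("x")  # static pattern (False, False)
y = at.row("y")     # static pattern (True, False), i.e. shape (1, n)

# A strict comparison of the static patterns fails even though the runtime
# shapes may broadcast together without issue.
print(x.broadcastable == y.broadcastable)  # False

# `broadcast_shape` instead produces a symbolic shape that is valid for
# whatever compatible shapes the inputs take at runtime.
print(broadcast_shape(x, y))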
@@ -68,7 +68,13 @@
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.extra_ops import (
BroadcastTo,
Repeat,
Unique,
broadcast_shape,
broadcast_to,
)
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import (
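The hunk above adds `broadcast_to` to the `extra_ops` imports alongside `broadcast_shape`. As a small illustration (ours, not from the commit) of why that helper is the natural replacement, both graphs below give `x` a larger shape, and the rewrite now emits the second form where it previously kept an `Alloc`:

import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_to

x = at.vector("x")

# Materialize `x` at shape (5, x.shape[0]) with an explicit `Alloc`.
via_alloc = at.alloc(x, 5, x.shape[0])

# Broadcast `x` to the same shape symbolically; this is the form the
# rewritten code below produces.
via_broadcast_to = broadcast_to(x, (5, x.shape[0]))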
@@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node):
introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions.
"""
if not isinstance(node.op, Elemwise):
return False

# Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1:
return None
return False

if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true.
assert all(
o.type.broadcastable == node.outputs[0].type.broadcastable
for o in node.outputs[1:]
)

# The broadcast pattern of the output must match the broadcast
# pattern of at least one of the inputs.
if not any(
i.type.broadcastable == node.outputs[0].type.broadcastable for i in node.inputs
):
return False

def dimshuffled_alloc(i):
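The next hunk rewrites the body that handles inputs matched by `dimshuffled_alloc`, i.e. an `Alloc` whose output is fed through a `DimShuffle`. For reference, a hand-built example of that pattern (hypothetical variable names, not from the commit):

import aesara.tensor as at
from aesara.tensor.basic import Alloc
from aesara.tensor.elemwise import DimShuffle

x = at.vector("x")

# An `Alloc` that is then transposed: the outer owner is a `DimShuffle`,
# and that node's input is in turn owned by an `Alloc`.
ds_of_alloc = at.alloc(x, 3, x.shape[0]).dimshuffle(1, 0)

assert isinstance(ds_of_alloc.owner.op, DimShuffle)
assert isinstance(ds_of_alloc.owner.inputs[0].owner.op, Alloc)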
@@ -1523,103 +1514,82 @@ def dimshuffled_alloc(i):
# At least one input must have an owner that is either a `Alloc` or a
# `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
# nothing to optimize.
if not any(
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
for i in node.inputs
):
alloc_idxs = [
idx
for idx, i in enumerate(node.inputs)
if i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
]
if len(alloc_idxs) == 0:
return False

# Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
# baseline for the dimensions.
assert_op_idx = None
ref_var_idx = None
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
# `Alloc` so that all `Alloc`s can be optimized.
if not (
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
):
assert_op_idx = idx
# Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc`, so that all `Alloc`s can be optimized.
if idx not in alloc_idxs:
ref_var_idx = idx
break

# If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if assert_op_idx is None:
if ref_var_idx is None:
for idx, i in enumerate(node.inputs):
if (i.type.broadcastable == node.outputs[0].type.broadcastable) and (
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
):
assert_op_idx = idx
# XXX: This broadcastable comparison doesn't work
if (
i.type.broadcastable == node.outputs[0].type.broadcastable
) and idx in alloc_idxs:
ref_var_idx = idx
break

assert_op_in = node.inputs[assert_op_idx]
cmp_op = assert_op_in
new_i = []
same_shape = fgraph.shape_feature.same_shape
for i in node.inputs:
if not hasattr(fgraph, "shape_feature"):
return False

input_shapes = [
tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
for i in node.inputs
]
bcasted_shape = broadcast_shape(
*input_shapes,
arrays_are_shapes=True,
)

new_inputs = list(node.inputs)
for idx in alloc_idxs:
i = node.inputs[idx]

# Remove `Alloc`
if i.owner and isinstance(i.owner.op, Alloc):
assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
cond = []
for idx in range(i.type.ndim):
if not i.type.broadcastable[idx] and not same_shape(
i, cmp_op, idx, idx
):
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(eq(i_shp, cmp_shp))
if cond:
assert_op_in = assert_op(assert_op_in, *cond)
alloc_input = i.owner.inputs[0]
if alloc_input.ndim != i.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = i.ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
copy_stack_trace(i, alloc_input)
new_i.append(alloc_input)
if isinstance(i.owner.op, Alloc):
new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)

# TODO FIXME: This shouldn't be handled here.
# `DimShuffle`s should be lifted through `Alloc`s
# by other, more general rewrites.
# Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim
if config.experimental__local_alloc_elemwise_assert:
assert_cond = [
eq(i.shape[idx], cmp_op.shape[idx])
for idx in range(i.type.ndim)
if not i.type.broadcastable[idx]
and not same_shape(i, cmp_op, idx, idx)
]
if assert_cond:
assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim:
elif isinstance(i.owner.op, DimShuffle):
old_alloc = i.owner.inputs[0]
new_alloc = old_alloc.owner.inputs[0]
# We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere.
if new_alloc.ndim != old_alloc.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
# We let later optimizations merge the nested `DimShuffle`s
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
new_alloc = new_alloc.dimshuffle(
["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
)
new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)

# We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere.
r_i = i.owner.op(alloc_input)
copy_stack_trace(i, r_i)
new_i.append(r_i)

else:
new_i.append(i)
new_i[assert_op_idx] = assert_op_in
copy_stack_trace(i, new_alloc)
new_inputs[idx] = new_alloc

# If this assert is triggered, it means we are recreating an equivalent graph
# which would result in a cyclical merge optimization.
if all(new is old for new, old in zip(new_i, node.inputs)):
if all(new is old for new, old in zip(new_inputs, node.inputs)):
return

ret = node.op(*new_i, return_list=True)
ret = node.op(*new_inputs, return_list=True)
copy_stack_trace(node.outputs, ret)
return ret
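To make the effect of the rewritten branches concrete, here is a rough before/after sketch assembled by hand (our variable names; the actual rewrite reads the input shapes from the `FunctionGraph`'s shape feature rather than rebuilding them as done here):

import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_shape, broadcast_to

x = at.vector("x")
y = at.matrix("y")

# Before: the `Alloc` materializes `x` at shape (y.shape[0], x.shape[0])
# just so it can be fed to the `Elemwise` add.
before = at.alloc(x, y.shape[0], x.shape[0]) + y

# After: the `Alloc` input is replaced by a `broadcast_to` of its value,
# broadcast to the shape obtained by broadcasting all of the `Elemwise`
# inputs' shapes together.
out_shape = broadcast_shape(at.alloc(x, y.shape[0], x.shape[0]), y)
after = broadcast_to(x, out_shape) + y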
