Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove strict TensorType.broadcastable usage from local_elemwise_alloc #1102

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 60 additions & 90 deletions aesara/tensor/basic_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.extra_ops import (
BroadcastTo,
Repeat,
Unique,
broadcast_shape,
broadcast_to,
)
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import (
Expand Down Expand Up @@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node):
introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions.
"""
if not isinstance(node.op, Elemwise):
return False

# Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1:
return None
return False

if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true.
assert all(
o.type.broadcastable == node.outputs[0].type.broadcastable
for o in node.outputs[1:]
)

# The broadcast pattern of the output must match the broadcast
# pattern of at least one of the inputs.
if not any(
i.type.broadcastable == node.outputs[0].type.broadcastable for i in node.inputs
):
return False

def dimshuffled_alloc(i):
Expand All @@ -1523,103 +1514,82 @@ def dimshuffled_alloc(i):
# At least one input must have an owner that is either an `Alloc` or a
# `DimShuffle` with an owner that is an `Alloc` -- otherwise there is
# nothing to optimize.
if not any(
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
for i in node.inputs
):
alloc_idxs = [
idx
for idx, i in enumerate(node.inputs)
if i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
]
if len(alloc_idxs) == 0:
return False

# Search for an input that is neither an `Alloc` nor a `DimShuffle` of an
# `Alloc` that we can use as a baseline for the dimensions.
assert_op_idx = None
ref_var_idx = None
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
# `Alloc` so that all `Alloc`s can be optimized.
if not (
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
):
assert_op_idx = idx
# Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc`, so that all `Alloc`s can be optimized.
if idx not in alloc_idxs:
ref_var_idx = idx
break

# If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if assert_op_idx is None:
if ref_var_idx is None:
for idx, i in enumerate(node.inputs):
if (i.type.broadcastable == node.outputs[0].type.broadcastable) and (
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
):
assert_op_idx = idx
# XXX: This broadcastable comparison doesn't work
if (
i.type.broadcastable == node.outputs[0].type.broadcastable
) and idx in alloc_idxs:
ref_var_idx = idx
break

assert_op_in = node.inputs[assert_op_idx]
cmp_op = assert_op_in
new_i = []
same_shape = fgraph.shape_feature.same_shape
for i in node.inputs:
if not hasattr(fgraph, "shape_feature"):
return False

input_shapes = [
tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
for i in node.inputs
]
bcasted_shape = broadcast_shape(
*input_shapes,
arrays_are_shapes=True,
)

new_inputs = list(node.inputs)
for idx in alloc_idxs:
i = node.inputs[idx]

# Remove `Alloc`
if i.owner and isinstance(i.owner.op, Alloc):
assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
cond = []
for idx in range(i.type.ndim):
if not i.type.broadcastable[idx] and not same_shape(
i, cmp_op, idx, idx
):
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(eq(i_shp, cmp_shp))
if cond:
assert_op_in = assert_op(assert_op_in, *cond)
alloc_input = i.owner.inputs[0]
if alloc_input.ndim != i.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = i.ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
copy_stack_trace(i, alloc_input)
new_i.append(alloc_input)
if isinstance(i.owner.op, Alloc):
new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)

# TODO FIXME: This shouldn't be handled here.
# `DimShuffle`s should be lifted through `Alloc`s
# by other, more general rewrites.
# Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim
if config.experimental__local_alloc_elemwise_assert:
assert_cond = [
eq(i.shape[idx], cmp_op.shape[idx])
for idx in range(i.type.ndim)
if not i.type.broadcastable[idx]
and not same_shape(i, cmp_op, idx, idx)
]
if assert_cond:
assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim:
elif isinstance(i.owner.op, DimShuffle):
old_alloc = i.owner.inputs[0]
new_alloc = old_alloc.owner.inputs[0]
# We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere.
if new_alloc.ndim != old_alloc.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
# We let later optimizations merge the nested `DimShuffle`s
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
new_alloc = new_alloc.dimshuffle(
["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
)
new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)

# We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere.
r_i = i.owner.op(alloc_input)
copy_stack_trace(i, r_i)
new_i.append(r_i)

else:
new_i.append(i)
new_i[assert_op_idx] = assert_op_in
copy_stack_trace(i, new_alloc)
new_inputs[idx] = new_alloc

# If every new input is the same object as the corresponding old input, it
# means we are recreating an equivalent graph, which would result in a
# cyclical merge optimization.
if all(new is old for new, old in zip(new_i, node.inputs)):
if all(new is old for new, old in zip(new_inputs, node.inputs)):
return

ret = node.op(*new_i, return_list=True)
ret = node.op(*new_inputs, return_list=True)
copy_stack_trace(node.outputs, ret)
return ret

Expand Down
86 changes: 60 additions & 26 deletions aesara/tensor/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Tuple, Union
from typing import Iterable, Set, Tuple, Union

import numpy as np
import numpy.core.numeric
Expand All @@ -14,7 +14,7 @@
disconnected_type,
grad_undefined,
)
from aesara.graph.basic import Apply, Variable, equal_computations
from aesara.graph.basic import Apply, Constant, Variable, equal_computations
from aesara.graph.op import Op
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
Expand Down Expand Up @@ -1491,7 +1491,12 @@ def broadcast_shape_iter(

array_shapes = [
(one_at,) * (max_dims - len(a))
+ tuple(one_at if getattr(sh, "value", sh) == 1 else sh for sh in a)
+ tuple(
one_at
if getattr(sh, "value", sh) == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a
)
for a in arrays
]
else:
Expand Down Expand Up @@ -1523,32 +1528,61 @@ def broadcast_shape_iter(
else:
# More than one shape might not be broadcastable in this dimension

all_dims_equal = all(
# TODO FIXME: This is a largely deficient means of comparing graphs
# (and especially shapes)
equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:]
)
nonconst_nb_shapes: Set[int] = set()
const_nb_shapes: Set[Variable] = set()
for shape in maybe_non_bcast_shapes:
if isinstance(shape, Constant):
const_nb_shapes.add(shape.value.item())
else:
nonconst_nb_shapes.add(shape)

if all_dims_equal:
result_dims.append(maybe_non_bcast_shapes[0])
continue
if len(const_nb_shapes) > 1:
raise ValueError("Could not broadcast dimensions")
elif len(const_nb_shapes) == 1:
(const_nb_shape,) = const_nb_shapes

non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
assert const_nb_shape != 1

assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
for nbv in non_bcast_vec
),
)
bcast_dim = assert_dim(dim_max, assert_cond)
const_nt_shape_var = aesara.scalar.ScalarConstant(
aesara.scalar.int64, const_nb_shape
)

if len(nonconst_nb_shapes) > 0:
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(aes.eq(nbv, const_nt_shape_var) for nbv in nonconst_nb_shapes),
)
bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
else:
bcast_dim = const_nt_shape_var
else:
all_dims_equal = all(
# TODO FIXME: This is a largely deficient, and expensive, means
# of comparing graphs (and especially shapes)
equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:]
)

if all_dims_equal:
result_dims.append(maybe_non_bcast_shapes[0])
continue

non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))

assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
for nbv in non_bcast_vec
),
)
bcast_dim = assert_dim(dim_max, assert_cond)

result_dims.append(bcast_dim)

Expand Down
Loading