Skip to content

Commit

Permalink
Use Composite graphs in aesara.tensor.extra_ops.broadcast_shape_iter
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Oct 16, 2022
1 parent 63ca73d commit 9b7021c
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions aesara/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from aesara.raise_op import Assert
from aesara.scalar import int32 as int_t
from aesara.scalar import upcast
from aesara.scalar.basic import Composite
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
Expand Down Expand Up @@ -1552,16 +1553,29 @@ def broadcast_shape_iter(
# be broadcastable or equal to the one non-broadcastable
# constant `const_nt_shape_var`.
assert_dim = Assert("Could not broadcast dimensions")

scalar_nonconst_nb_shapes = [
at.scalar_from_tensor(s) if isinstance(s, TensorVariable) else s
for s in nonconst_nb_shapes
]

dummy_nonconst_nb_shapes = [
v.type() for v in scalar_nonconst_nb_shapes
]
assert_cond = reduce(
aes.and_,
(
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
for nbv in nonconst_nb_shapes
for nbv in dummy_nonconst_nb_shapes
),
)
bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])

bcast_dim = assert_dim(
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
)
else:
bcast_dim = const_nt_shape_var
else:
Expand All @@ -1579,21 +1593,36 @@ def broadcast_shape_iter(
result_dims.append(maybe_non_bcast_shapes[0])
continue

scalar_maybe_non_bcast_shapes = [
at.scalar_from_tensor(s) if isinstance(s, TensorVariable) else s
for s in maybe_non_bcast_shapes
]
dummy_maybe_non_bcast_shapes = [
v.type() for v in scalar_maybe_non_bcast_shapes
]
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in maybe_non_bcast_shapes
for nbv in dummy_maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])

dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)

assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
for nbv in non_bcast_vec
),
)
bcast_dim = assert_dim(dim_max, assert_cond)
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])

bcast_dim = assert_dim(
dim_max_op(*scalar_maybe_non_bcast_shapes),
assert_cond_op(*scalar_maybe_non_bcast_shapes),
)

result_dims.append(bcast_dim)

Expand Down

0 comments on commit 9b7021c

Please sign in to comment.