diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py index 03e6fffbf3..28b54ba75c 100644 --- a/aesara/tensor/extra_ops.py +++ b/aesara/tensor/extra_ops.py @@ -23,6 +23,7 @@ from aesara.raise_op import Assert from aesara.scalar import int32 as int_t from aesara.scalar import upcast +from aesara.scalar.basic import Composite from aesara.tensor import basic as at from aesara.tensor import get_vector_length from aesara.tensor.exceptions import NotScalarConstantError @@ -1552,16 +1553,29 @@ def broadcast_shape_iter( # be broadcastable or equal to the one non-broadcastable # constant `const_nt_shape_var`. assert_dim = Assert("Could not broadcast dimensions") + + scalar_nonconst_nb_shapes = [ + at.scalar_from_tensor(s) if isinstance(s, TensorVariable) else s + for s in nonconst_nb_shapes + ] + + dummy_nonconst_nb_shapes = [ + v.type() for v in scalar_nonconst_nb_shapes + ] assert_cond = reduce( aes.and_, ( aes.or_( aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var) ) - for nbv in nonconst_nb_shapes + for nbv in dummy_nonconst_nb_shapes ), ) - bcast_dim = assert_dim(const_nt_shape_var, assert_cond) + assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond]) + + bcast_dim = assert_dim( + const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes) + ) else: bcast_dim = const_nt_shape_var else: @@ -1579,21 +1593,36 @@ def broadcast_shape_iter( result_dims.append(maybe_non_bcast_shapes[0]) continue + scalar_maybe_non_bcast_shapes = [ + at.scalar_from_tensor(s) if isinstance(s, TensorVariable) else s + for s in maybe_non_bcast_shapes + ] + dummy_maybe_non_bcast_shapes = [ + v.type() for v in scalar_maybe_non_bcast_shapes + ] non_bcast_vec = [ aes.switch(aes.eq(nbv, 1), -one_at, nbv) - for nbv in maybe_non_bcast_shapes + for nbv in dummy_maybe_non_bcast_shapes ] dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec)) + dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max]) + + dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes) assert_dim = Assert("Could not broadcast dimensions") assert_cond = reduce( aes.and_, ( - aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max)) + aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max)) for nbv in non_bcast_vec ), ) - bcast_dim = assert_dim(dim_max, assert_cond) + assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond]) + + bcast_dim = assert_dim( + dim_max_op(*scalar_maybe_non_bcast_shapes), + assert_cond_op(*scalar_maybe_non_bcast_shapes), + ) result_dims.append(bcast_dim)