diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index bd3cf71fb4..ca57fee85c 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -1,9 +1,7 @@
 from collections.abc import Collection
-from functools import reduce
 from typing import Iterable, Set, Tuple, Union
 
 import numpy as np
-import numpy.core.numeric
 from numpy.core.multiarray import normalize_axis_index
 
 import pytensor
@@ -14,7 +12,7 @@
     disconnected_type,
     grad_undefined,
 )
-from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
@@ -23,12 +21,12 @@
 from pytensor.raise_op import Assert
 from pytensor.scalar import int32 as int_t
 from pytensor.scalar import upcast
-from pytensor.scalar.basic import Composite
 from pytensor.tensor import basic as at
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as at_abs
-from pytensor.tensor.math import all as at_all
+from pytensor.tensor.math import all as pt_all
+from pytensor.tensor.math import eq as pt_eq
 from pytensor.tensor.math import ge, lt, maximum, minimum, prod
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
 
     if assert_nonneg:
         assert_op = Assert("Input to bincount has negative values!")
-        x = assert_op(x, at_all(x >= 0))
+        x = assert_op(x, pt_all(x >= 0))
 
     max_value = at.cast(x.max() + 1, "int64")
 
@@ -1436,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
     return RavelMultiIndex(mode=mode, order=order)(*args)
 
 
+_broadcast_assert = Assert(
+    "Could not broadcast dimensions. Broadcasting is only allowed along "
+    "axes that have a statically known length 1. Use `specify_shape` to "
+    "inform PyTensor of a known shape."
+)
+
+
 def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
     """Compute the shape resulting from broadcasting arrays.
 
@@ -1510,119 +1515,45 @@ def broadcast_shape_iter(
     result_dims = []
 
     for dim_shapes in zip(*array_shapes):
-        # Get the shapes in this dimension that are not definitively
-        # broadcastable (i.e. not symbolically known to be broadcastable)
-        maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
+        # Get the shapes in this dimension that are not broadcastable
+        # (i.e. not symbolically known to be broadcastable)
+        non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
 
-        if len(maybe_non_bcast_shapes) == 0:
+        if len(non_bcast_shapes) == 0:
             # Every shape was broadcastable in this dimension
             result_dims.append(one_at)
-        elif len(maybe_non_bcast_shapes) == 1:
+        elif len(non_bcast_shapes) == 1:
             # Only one shape might not be broadcastable in this dimension
-            result_dims.extend(maybe_non_bcast_shapes)
+            result_dims.extend(non_bcast_shapes)
         else:
             # More than one shape might not be broadcastable in this dimension
-
             nonconst_nb_shapes: Set[int] = set()
             const_nb_shapes: Set[Variable] = set()
-            for shape in maybe_non_bcast_shapes:
+            for shape in non_bcast_shapes:
                 if isinstance(shape, Constant):
                     const_nb_shapes.add(shape.value.item())
                 else:
                     nonconst_nb_shapes.add(shape)
 
             if len(const_nb_shapes) > 1:
-                raise ValueError("Could not broadcast dimensions")
-            elif len(const_nb_shapes) == 1:
-                (const_nb_shape,) = const_nb_shapes
-
-                assert const_nb_shape != 1
-
-                const_nt_shape_var = pytensor.scalar.ScalarConstant(
-                    pytensor.scalar.int64, const_nb_shape
+                raise ValueError(
+                    f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
                 )
 
-                if len(nonconst_nb_shapes) > 0:
-                    # All the potential non-broadcast shapes need to either
-                    # be broadcastable or equal to the one non-broadcastable
-                    # constant `const_nt_shape_var`.
-                    assert_dim = Assert("Could not broadcast dimensions")
-
-                    scalar_nonconst_nb_shapes = [
-                        at.scalar_from_tensor(s)
-                        if isinstance(s.type, TensorType)
-                        else s
-                        for s in nonconst_nb_shapes
-                    ]
-
-                    dummy_nonconst_nb_shapes = [
-                        aes.get_scalar_type(dtype=v.dtype)()
-                        for v in scalar_nonconst_nb_shapes
-                    ]
-                    assert_cond = reduce(
-                        aes.and_,
-                        (
-                            aes.or_(
-                                aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
-                            )
-                            for nbv in dummy_nonconst_nb_shapes
-                        ),
-                    )
-                    assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
-
-                    bcast_dim = assert_dim(
-                        const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
-                    )
-                else:
-                    bcast_dim = const_nt_shape_var
+            if len(const_nb_shapes) == 1:
+                (first_length,) = const_nb_shapes
+                other_lengths = nonconst_nb_shapes
+                first_length = aes.as_scalar(first_length)
             else:
-                # There are no constant, non-broadcastable shapes in this
-                # dimension.
-
-                all_dims_equal = all(
-                    # TODO FIXME: This is a largely deficient, and expensive, means
-                    # of comparing graphs (and especially shapes)
-                    equal_computations([maybe_non_bcast_shapes[0]], [dim])
-                    for dim in maybe_non_bcast_shapes[1:]
-                )
+                first_length, *other_lengths = nonconst_nb_shapes
 
-                if all_dims_equal:
-                    result_dims.append(maybe_non_bcast_shapes[0])
-                    continue
-
-                scalar_maybe_non_bcast_shapes = [
-                    at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
-                    for s in maybe_non_bcast_shapes
-                ]
-                dummy_maybe_non_bcast_shapes = [
-                    aes.get_scalar_type(dtype=v.dtype)()
-                    for v in scalar_maybe_non_bcast_shapes
-                ]
-                non_bcast_vec = [
-                    aes.switch(aes.eq(nbv, 1), -one_at, nbv)
-                    for nbv in dummy_maybe_non_bcast_shapes
-                ]
-                dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
-                dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
-
-                dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
-
-                assert_dim = Assert("Could not broadcast dimensions")
-                assert_cond = reduce(
-                    aes.and_,
-                    (
-                        aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
-                        for nbv in non_bcast_vec
-                    ),
-                )
-                assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
-
-                bcast_dim = assert_dim(
-                    dim_max_op(*scalar_maybe_non_bcast_shapes),
-                    assert_cond_op(*scalar_maybe_non_bcast_shapes),
-                )
+            if len(other_lengths) == 0:
+                result_dims.append(first_length)
+                continue
 
-            result_dims.append(bcast_dim)
+            # Add assert that all remaining shapes are equal
+            condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
+            result_dims.append(_broadcast_assert(first_length, condition))
 
     return tuple(result_dims)
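Note on the hunk above: `broadcast_shape_iter` no longer builds `Composite` graphs or compares shape graphs with `equal_computations`. When more than one candidate length survives in a dimension, it keeps the first and guards it with a single `_broadcast_assert` that all the others are equal. A minimal sketch of the resulting semantics (the variables here are illustrative, not from the patch):

```python
import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_shape

# A static length-1 dimension still broadcasts with no runtime check:
x = at.tensor(dtype="float64", shape=(1, 3), name="x")
y = at.tensor(dtype="float64", shape=(None, 3), name="y")
dims = broadcast_shape(x, y)  # dim 0 comes straight from y's shape graph

# Two unknown lengths are now simply asserted equal at runtime; a value
# of 1 that is only known at runtime no longer broadcasts:
u = at.tensor(dtype="float64", shape=(None, 3), name="u")
v = at.tensor(dtype="float64", shape=(None, 3), name="v")
dims = broadcast_shape(u, v)  # dim 0 is u's length wrapped in `_broadcast_assert`
```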
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index dd7c184073..fe2b795907 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1703,8 +1703,12 @@ def verify_op_count(f, count, cls):
         ],
     )
     def test_basic(self, expr, x_shape, y_shape):
-        x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
-        y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
+        x = at.tensor(
+            dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
+        )
+        y = at.tensor(
+            dtype="int64", shape=(1 if val == 1 else None for val in y_shape), name="y"
+        )
         z = expr(x, y)
 
         z_opt = pytensor.function(
@@ -1878,7 +1882,8 @@ def test_multi_input_single_alloc(self):
             mode=self.fast_run_mode,
         )
         self.verify_op_count(func, 0, Alloc)
-        self.verify_op_count(func, 1, Assert)
+        # The second assert is from the shape check...
+        self.verify_op_count(func, 2, Assert)
 
     def test_misc(self):
         x = row(dtype=self.dtype)
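The `test_basic` change above is needed because `(None,) * len(x_shape)` erases broadcastability: with dynamic broadcasting gone, the rewrite under test only fires when length-1 dimensions are part of the static type. An illustration of the encoding (hypothetical `x_shape`, not from the diff):

```python
import pytensor.tensor as at

x_shape = (1, 5)  # hypothetical test parameter
x = at.tensor(
    dtype="int64", shape=tuple(1 if val == 1 else None for val in x_shape), name="x"
)
assert x.type.shape == (1, None)
assert x.broadcastable == (True, False)  # axis 0 can broadcast statically
```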
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index f191e51357..f69879a51d 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -608,9 +608,10 @@ def test_mul_div_cases(self):
                 ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
                 ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
                 # must broadcast as their is a dimshuffle in the computation
-                ((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
+                # The broadcast leads to an extra elemwise to check compatibility
+                ((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
                 # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(), Alloc]
-                ((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
+                ((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
                 # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(), Alloc]
             ]
         ):
@@ -621,9 +622,12 @@ def test_mul_div_cases(self):
             elem = [t for t in topo if isinstance(t.op, Elemwise)]
             assert len(elem) == nb_elemwise
             assert isinstance(elem[0].op, (Elemwise,))
-            assert isinstance(
-                elem[0].op.scalar_op,
-                (aes.basic.Reciprocal, aes.basic.TrueDiv),
+            assert any(
+                isinstance(
+                    el.op.scalar_op,
+                    (aes.basic.Reciprocal, aes.basic.TrueDiv),
+                )
+                for el in elem
             )
             assert out_dtype == out.dtype
 
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index 286e173cad..de419aec0a 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -1086,7 +1086,9 @@ def shape_tuple(x, use_bcast=True):
     assert any(
         isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
     )
-    assert np.array_equal([z.eval() for z in b_at], b.shape)
+    # This should fail because it would need dynamic broadcasting
+    with pytest.raises(AssertionError):
+        assert np.array_equal([z.eval() for z in b_at], b.shape)
 
     b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
     assert np.array_equal([z.eval() for z in b_at], b.shape)
@@ -1183,8 +1185,8 @@ def test_broadcast_shape_constants():
 @pytest.mark.parametrize(
     ("s1_vals", "s2_vals", "exp_res"),
     [
-        ((2, 2), (1, 2), (2, 2)),
-        ((0, 2), (1, 2), (0, 2)),
+        ((2, 2), (1, 2), AssertionError),
+        ((0, 2), (1, 2), AssertionError),
         ((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
     ],
 )
@@ -1203,7 +1205,11 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
     res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
     res = at.as_tensor(res)
 
-    assert tuple(res.eval(eval_point)) == exp_res
+    if exp_res is AssertionError:
+        with pytest.raises(AssertionError):
+            res.eval(eval_point)
+    else:
+        assert tuple(res.eval(eval_point)) == exp_res
 
 
 def test_broadcast_shape_symbolic_one_symbolic():
@@ -1395,7 +1401,7 @@ def test_inplace(self):
 
 
 def test_broadcast_arrays():
-    x, y = at.dvector(), at.dmatrix()
+    x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
     x_bcast, y_bcast = broadcast_arrays(x, y)
 
     py_mode = Mode("py", None)
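The `test_extra_ops.py` expectations above encode the new failure mode: shape tuples that NumPy would broadcast at runtime now trip the assert when the length-1 dimension is not static. A standalone repro sketch (not part of the test suite; names are illustrative):

```python
import numpy as np
import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_shape

x = at.matrix("x")  # runtime shape (2, 2)
y = at.matrix("y")  # runtime shape (1, 2), but the 1 is not in the type
bshape = at.as_tensor(broadcast_shape(x, y))
try:
    bshape.eval({x: np.ones((2, 2)), y: np.ones((1, 2))})
except AssertionError:
    print("refused: dynamic broadcasting is no longer resolved")
```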
diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py
index e27e30003a..d8c1bd0876 100644
--- a/tests/unittest_tools.py
+++ b/tests/unittest_tools.py
@@ -255,7 +255,7 @@ def _compile_and_check(
         # Check that the Op is removed from the compiled function.
         if check_topo:
             topo_shape = shapes_function.maker.fgraph.toposort()
-            assert not any(isinstance(t.op, cls) for t in topo_shape)
+            assert not any(t in outputs for t in topo_shape)
         topo_out = outputs_function.maker.fgraph.toposort()
         assert any(isinstance(t.op, cls) for t in topo_out)
         # Check that the shape produced agrees with the actual shape.
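Finally, the escape hatch named in the new `_broadcast_assert` message: if the user knows a dimension has length 1, `specify_shape` makes that fact static and restores broadcasting. A sketch under that assumption (not code from this patch):

```python
import numpy as np
import pytensor
import pytensor.tensor as at

x = at.matrix("x")                   # leading dimension unknown: will not broadcast
x1 = at.specify_shape(x, (1, None))  # declare the leading dimension to be 1
y = at.matrix("y")
f = pytensor.function([x, y], x1 + y)  # broadcasts along axis 0 statically
assert f(np.ones((1, 3)), np.ones((5, 3))).shape == (5, 3)
```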