diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index bffc65a58c..ca57fee85c 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1434,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): return RavelMultiIndex(mode=mode, order=order)(*args) +_broadcast_assert = Assert( + "Could not broadcast dimensions. Broadcasting is only allowed along " + "axes that have a statically known length 1. Use `specify_shape` to " + "inform PyTensor of a known shape." +) + + def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]: """Compute the shape resulting from broadcasting arrays. @@ -1510,20 +1517,19 @@ def broadcast_shape_iter( for dim_shapes in zip(*array_shapes): # Get the shapes in this dimension that are not broadcastable # (i.e. not symbolically known to be broadcastable) - maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] + non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] - if len(maybe_non_bcast_shapes) == 0: + if len(non_bcast_shapes) == 0: # Every shape was broadcastable in this dimension result_dims.append(one_at) - elif len(maybe_non_bcast_shapes) == 1: + elif len(non_bcast_shapes) == 1: # Only one shape might not be broadcastable in this dimension - result_dims.extend(maybe_non_bcast_shapes) + result_dims.extend(non_bcast_shapes) else: # More than one shape might not be broadcastable in this dimension - nonconst_nb_shapes: Set[int] = set() const_nb_shapes: Set[Variable] = set() - for shape in maybe_non_bcast_shapes: + for shape in non_bcast_shapes: if isinstance(shape, Constant): const_nb_shapes.add(shape.value.item()) else: @@ -1534,7 +1540,6 @@ def broadcast_shape_iter( f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}." ) - assert_op = Assert("Could not dynamically broadcast dimensions.") if len(const_nb_shapes) == 1: (first_length,) = const_nb_shapes other_lengths = nonconst_nb_shapes @@ -1547,23 +1552,8 @@ def broadcast_shape_iter( continue # Add assert that all remaining shapes are equal - use_scalars = False - if use_scalars: - condition = None - for other in other_lengths: - cond = aes.eq(first_length, other) - if condition is None: - condition = cond - else: - condition = aes.and_(condition, cond) - else: - condition = pt_all( - [pt_eq(first_length, other) for other in other_lengths] - ) - if condition is None: - result_dims.append(first_length) - else: - result_dims.append(assert_op(first_length, condition)) + condition = pt_all([pt_eq(first_length, other) for other in other_lengths]) + result_dims.append(_broadcast_assert(first_length, condition)) return tuple(result_dims)