Skip to content

Commit

Permalink
Cosmetic improvements to dynamic broadcast checks
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jun 16, 2023
1 parent 0fbec99 commit 53c26d1
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
return RavelMultiIndex(mode=mode, order=order)(*args)


_broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to "
"inform PyTensor of a known shape."
)


def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
"""Compute the shape resulting from broadcasting arrays.
Expand Down Expand Up @@ -1510,20 +1517,19 @@ def broadcast_shape_iter(
for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]

if len(maybe_non_bcast_shapes) == 0:
if len(non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension
result_dims.append(one_at)
elif len(maybe_non_bcast_shapes) == 1:
elif len(non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension
result_dims.extend(maybe_non_bcast_shapes)
result_dims.extend(non_bcast_shapes)
else:
# More than one shape might not be broadcastable in this dimension

nonconst_nb_shapes: Set[int] = set()
const_nb_shapes: Set[Variable] = set()
for shape in maybe_non_bcast_shapes:
for shape in non_bcast_shapes:
if isinstance(shape, Constant):
const_nb_shapes.add(shape.value.item())
else:
Expand All @@ -1534,7 +1540,6 @@ def broadcast_shape_iter(
f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
)

assert_op = Assert("Could not dynamically broadcast dimensions.")
if len(const_nb_shapes) == 1:
(first_length,) = const_nb_shapes
other_lengths = nonconst_nb_shapes
Expand All @@ -1547,23 +1552,8 @@ def broadcast_shape_iter(
continue

# Add assert that all remaining shapes are equal
use_scalars = False
if use_scalars:
condition = None
for other in other_lengths:
cond = aes.eq(first_length, other)
if condition is None:
condition = cond
else:
condition = aes.and_(condition, cond)
else:
condition = pt_all(
[pt_eq(first_length, other) for other in other_lengths]
)
if condition is None:
result_dims.append(first_length)
else:
result_dims.append(assert_op(first_length, condition))
condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
result_dims.append(_broadcast_assert(first_length, condition))

return tuple(result_dims)

Expand Down

0 comments on commit 53c26d1

Please sign in to comment.