Skip to content

Commit

Permalink
Disallow dynamic broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jun 15, 2023
1 parent f4536c3 commit 0fbec99
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 109 deletions.
131 changes: 36 additions & 95 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Set, Tuple, Union

import numpy as np
import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index

import pytensor
Expand All @@ -14,7 +12,7 @@
disconnected_type,
grad_undefined,
)
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
Expand All @@ -23,12 +21,12 @@
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.scalar.basic import Composite
from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
Expand Down Expand Up @@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):

if assert_nonneg:
assert_op = Assert("Input to bincount has negative values!")
x = assert_op(x, at_all(x >= 0))
x = assert_op(x, pt_all(x >= 0))

max_value = at.cast(x.max() + 1, "int64")

Expand Down Expand Up @@ -1510,8 +1508,8 @@ def broadcast_shape_iter(
result_dims = []

for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not definitively
# broadcastable (i.e. not symbolically known to be broadcastable)
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]

if len(maybe_non_bcast_shapes) == 0:
Expand All @@ -1532,97 +1530,40 @@ def broadcast_shape_iter(
nonconst_nb_shapes.add(shape)

if len(const_nb_shapes) > 1:
raise ValueError("Could not broadcast dimensions")
elif len(const_nb_shapes) == 1:
(const_nb_shape,) = const_nb_shapes

assert const_nb_shape != 1

const_nt_shape_var = pytensor.scalar.ScalarConstant(
pytensor.scalar.int64, const_nb_shape
raise ValueError(
f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
)

if len(nonconst_nb_shapes) > 0:
# All the potential non-broadcast shapes need to either
# be broadcastable or equal to the one non-broadcastable
# constant `const_nt_shape_var`.
assert_dim = Assert("Could not broadcast dimensions")

scalar_nonconst_nb_shapes = [
at.scalar_from_tensor(s)
if isinstance(s.type, TensorType)
else s
for s in nonconst_nb_shapes
]

dummy_nonconst_nb_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_nonconst_nb_shapes
]
assert_cond = reduce(
aes.and_,
(
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
for nbv in dummy_nonconst_nb_shapes
),
)
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])

bcast_dim = assert_dim(
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
)
else:
bcast_dim = const_nt_shape_var
assert_op = Assert("Could not dynamically broadcast dimensions.")
if len(const_nb_shapes) == 1:
(first_length,) = const_nb_shapes
other_lengths = nonconst_nb_shapes
first_length = aes.as_scalar(first_length)
else:
# There are no constant, non-broadcastable shapes in this
# dimension.

all_dims_equal = all(
# TODO FIXME: This is a largely deficient, and expensive, means
# of comparing graphs (and especially shapes)
equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:]
)

if all_dims_equal:
result_dims.append(maybe_non_bcast_shapes[0])
continue

scalar_maybe_non_bcast_shapes = [
at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
for s in maybe_non_bcast_shapes
]
dummy_maybe_non_bcast_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_maybe_non_bcast_shapes
]
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in dummy_maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])

dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)

assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
for nbv in non_bcast_vec
),
)
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])

bcast_dim = assert_dim(
dim_max_op(*scalar_maybe_non_bcast_shapes),
assert_cond_op(*scalar_maybe_non_bcast_shapes),
first_length, *other_lengths = nonconst_nb_shapes

if len(other_lengths) == 0:
result_dims.append(first_length)
continue

# Add assert that all remaining shapes are equal
use_scalars = False
if use_scalars:
condition = None
for other in other_lengths:
cond = aes.eq(first_length, other)
if condition is None:
condition = cond
else:
condition = aes.and_(condition, cond)
else:
condition = pt_all(
[pt_eq(first_length, other) for other in other_lengths]
)

result_dims.append(bcast_dim)
if condition is None:
result_dims.append(first_length)
else:
result_dims.append(assert_op(first_length, condition))

return tuple(result_dims)

Expand Down
11 changes: 8 additions & 3 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,8 +1703,12 @@ def verify_op_count(f, count, cls):
],
)
def test_basic(self, expr, x_shape, y_shape):
x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
x = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
)
y = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in y_shape), name="y"
)
z = expr(x, y)

z_opt = pytensor.function(
Expand Down Expand Up @@ -1878,7 +1882,8 @@ def test_multi_input_single_alloc(self):
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
# The second assert is from the shape check...
self.verify_op_count(func, 2, Assert)

def test_misc(self):
x = row(dtype=self.dtype)
Expand Down
14 changes: 9 additions & 5 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,9 +608,10 @@ def test_mul_div_cases(self):
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as their is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
# The broadcast leads to an extra elemwise to check compatibility
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
]
):
Expand All @@ -621,9 +622,12 @@ def test_mul_div_cases(self):
elem = [t for t in topo if isinstance(t.op, Elemwise)]
assert len(elem) == nb_elemwise
assert isinstance(elem[0].op, (Elemwise,))
assert isinstance(
elem[0].op.scalar_op,
(aes.basic.Reciprocal, aes.basic.TrueDiv),
assert any(
isinstance(
el.op.scalar_op,
(aes.basic.Reciprocal, aes.basic.TrueDiv),
)
for el in elem
)
assert out_dtype == out.dtype

Expand Down
16 changes: 11 additions & 5 deletions tests/tensor/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,9 @@ def shape_tuple(x, use_bcast=True):
assert any(
isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
)
assert np.array_equal([z.eval() for z in b_at], b.shape)
# This should fail because it would need dynamic broadcasting
with pytest.raises(AssertionError):
assert np.array_equal([z.eval() for z in b_at], b.shape)
b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_at], b.shape)

Expand Down Expand Up @@ -1183,8 +1185,8 @@ def test_broadcast_shape_constants():
@pytest.mark.parametrize(
("s1_vals", "s2_vals", "exp_res"),
[
((2, 2), (1, 2), (2, 2)),
((0, 2), (1, 2), (0, 2)),
((2, 2), (1, 2), AssertionError),
((0, 2), (1, 2), AssertionError),
((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
],
)
Expand All @@ -1203,7 +1205,11 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
res = at.as_tensor(res)

assert tuple(res.eval(eval_point)) == exp_res
if exp_res is AssertionError:
with pytest.raises(AssertionError):
res.eval(eval_point)
else:
assert tuple(res.eval(eval_point)) == exp_res


def test_broadcast_shape_symbolic_one_symbolic():
Expand Down Expand Up @@ -1395,7 +1401,7 @@ def test_inplace(self):


def test_broadcast_arrays():
x, y = at.dvector(), at.dmatrix()
x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
x_bcast, y_bcast = broadcast_arrays(x, y)

py_mode = Mode("py", None)
Expand Down
2 changes: 1 addition & 1 deletion tests/unittest_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _compile_and_check(
# Check that the Op is removed from the compiled function.
if check_topo:
topo_shape = shapes_function.maker.fgraph.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
assert not any(t in outputs for t in topo_shape)
topo_out = outputs_function.maker.fgraph.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.
Expand Down

0 comments on commit 0fbec99

Please sign in to comment.