Use static-only broadcasting rules to compute shape of broadcasting #345

Merged 2 commits on Jun 17, 2023
131 changes: 36 additions & 95 deletions pytensor/tensor/extra_ops.py
@@ -1,9 +1,7 @@
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Set, Tuple, Union

import numpy as np
import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index

import pytensor
@@ -14,7 +12,7 @@
disconnected_type,
grad_undefined,
)
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
@@ -23,12 +21,12 @@
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.scalar.basic import Composite
from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):

if assert_nonneg:
assert_op = Assert("Input to bincount has negative values!")
x = assert_op(x, at_all(x >= 0))
x = assert_op(x, pt_all(x >= 0))

max_value = at.cast(x.max() + 1, "int64")

@@ -1510,8 +1508,8 @@ def broadcast_shape_iter(
result_dims = []

for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not definitively
# broadcastable (i.e. not symbolically known to be broadcastable)
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
Member:
Remove the "maybe"?

Suggested change:
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]


if len(maybe_non_bcast_shapes) == 0:
@@ -1532,97 +1530,40 @@ def broadcast_shape_iter(
nonconst_nb_shapes.add(shape)

if len(const_nb_shapes) > 1:
raise ValueError("Could not broadcast dimensions")
elif len(const_nb_shapes) == 1:
(const_nb_shape,) = const_nb_shapes

assert const_nb_shape != 1

const_nt_shape_var = pytensor.scalar.ScalarConstant(
pytensor.scalar.int64, const_nb_shape
raise ValueError(
f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
)

if len(nonconst_nb_shapes) > 0:
# All the potential non-broadcast shapes need to either
# be broadcastable or equal to the one non-broadcastable
# constant `const_nt_shape_var`.
assert_dim = Assert("Could not broadcast dimensions")

scalar_nonconst_nb_shapes = [
at.scalar_from_tensor(s)
if isinstance(s.type, TensorType)
else s
for s in nonconst_nb_shapes
]

dummy_nonconst_nb_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_nonconst_nb_shapes
]
assert_cond = reduce(
aes.and_,
(
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
for nbv in dummy_nonconst_nb_shapes
),
)
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])

bcast_dim = assert_dim(
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
)
else:
bcast_dim = const_nt_shape_var
assert_op = Assert("Could not dynamically broadcast dimensions.")
Member:
Maybe create this Op once at the module level?
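(A sketch of that suggestion, using a hypothetical name; the merged code may differ. `Assert` is already imported in this module:)

# Hypothetical module-level constant in pytensor/tensor/extra_ops.py:
_broadcast_assert = Assert(
    "Could not broadcast dimensions. Broadcasting is only allowed along "
    "axes that have a statically known length 1."
)

# ...then inside broadcast_shape_iter, reuse it instead of rebuilding it per call:
# result_dims.append(_broadcast_assert(first_length, condition))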

Member:
Suggested change:
assert_op = Assert("Could not dynamically broadcast dimensions.")
assert_op = Assert("Could not broadcast dimensions. If a variable should broadcast use `specify_shape` to inform PyTensor.")

Member (Author):

I made it even more specific:
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use specify_shape to "
"inform PyTensor of a known shape."

if len(const_nb_shapes) == 1:
(first_length,) = const_nb_shapes
other_lengths = nonconst_nb_shapes
first_length = aes.as_scalar(first_length)
else:
# There are no constant, non-broadcastable shapes in this
# dimension.

all_dims_equal = all(
# TODO FIXME: This is a largely deficient, and expensive, means
# of comparing graphs (and especially shapes)
equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:]
)

if all_dims_equal:
result_dims.append(maybe_non_bcast_shapes[0])
continue

scalar_maybe_non_bcast_shapes = [
at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
for s in maybe_non_bcast_shapes
]
dummy_maybe_non_bcast_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_maybe_non_bcast_shapes
]
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in dummy_maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])

dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)

assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
for nbv in non_bcast_vec
),
)
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])

bcast_dim = assert_dim(
dim_max_op(*scalar_maybe_non_bcast_shapes),
assert_cond_op(*scalar_maybe_non_bcast_shapes),
first_length, *other_lengths = nonconst_nb_shapes

if len(other_lengths) == 0:
result_dims.append(first_length)
continue

# Add assert that all remaining shapes are equal
use_scalars = False
Member:
Let's remove the use_scalars block?

if use_scalars:
condition = None
for other in other_lengths:
cond = aes.eq(first_length, other)
if condition is None:
condition = cond
else:
condition = aes.and_(condition, cond)
else:
condition = pt_all(
[pt_eq(first_length, other) for other in other_lengths]
)

result_dims.append(bcast_dim)
if condition is None:
result_dims.append(first_length)
else:
result_dims.append(assert_op(first_length, condition))

return tuple(result_dims)

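Taken together, the new logic broadcasts a dimension at graph-construction time only when its static length is known to be 1, and otherwise defers an equality check to a runtime Assert. A minimal sketch of the resulting behavior, assuming the public broadcast_shape helper:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import broadcast_shape

x = pt.tensor(dtype="float64", shape=(1, None))  # axis 0 statically 1
y = pt.tensor(dtype="float64", shape=(None, None))

dims = broadcast_shape(x, y)
# dims[0] is y's axis-0 length (x's static 1 broadcasts at graph time);
# dims[1] carries an Assert that the two unknown axis-1 lengths match.
f = pytensor.function([x, y], [pt.as_tensor(d) for d in dims])
print(f(np.ones((1, 3)), np.ones((4, 3))))  # -> [array(4), array(3)]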
11 changes: 8 additions & 3 deletions tests/tensor/rewriting/test_basic.py
@@ -1703,8 +1703,12 @@ def verify_op_count(f, count, cls):
],
)
def test_basic(self, expr, x_shape, y_shape):
x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
x = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
)
y = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in y_shape), name="y"
)
z = expr(x, y)

z_opt = pytensor.function(
@@ -1878,7 +1882,8 @@ def test_multi_input_single_alloc(self):
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
# The second assert is from the shape check...
self.verify_op_count(func, 2, Assert)

def test_misc(self):
x = row(dtype=self.dtype)
14 changes: 9 additions & 5 deletions tests/tensor/rewriting/test_math.py
@@ -608,9 +608,10 @@ def test_mul_div_cases(self):
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
# The broadcast leads to an extra elemwise to check compatibility
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
]
):
@@ -621,9 +622,12 @@
elem = [t for t in topo if isinstance(t.op, Elemwise)]
assert len(elem) == nb_elemwise
assert isinstance(elem[0].op, (Elemwise,))
assert isinstance(
elem[0].op.scalar_op,
(aes.basic.Reciprocal, aes.basic.TrueDiv),
assert any(
isinstance(
el.op.scalar_op,
(aes.basic.Reciprocal, aes.basic.TrueDiv),
)
for el in elem
)
assert out_dtype == out.dtype

16 changes: 11 additions & 5 deletions tests/tensor/test_extra_ops.py
@@ -1086,7 +1086,9 @@ def shape_tuple(x, use_bcast=True):
assert any(
isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
)
assert np.array_equal([z.eval() for z in b_at], b.shape)
# This should fail because it would need dynamic broadcasting
with pytest.raises(AssertionError):
assert np.array_equal([z.eval() for z in b_at], b.shape)
b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_at], b.shape)

@@ -1183,8 +1185,8 @@ def test_broadcast_shape_constants():
@pytest.mark.parametrize(
("s1_vals", "s2_vals", "exp_res"),
[
((2, 2), (1, 2), (2, 2)),
((0, 2), (1, 2), (0, 2)),
((2, 2), (1, 2), AssertionError),
((0, 2), (1, 2), AssertionError),
((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
],
)
@@ -1203,7 +1205,11 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
res = at.as_tensor(res)

assert tuple(res.eval(eval_point)) == exp_res
if exp_res is AssertionError:
with pytest.raises(AssertionError):
res.eval(eval_point)
else:
assert tuple(res.eval(eval_point)) == exp_res
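This is why the first two parameter sets now expect AssertionError: when the shapes are symbolic scalars, the 1 fed in at runtime is not statically known, so the dimension is not treated as broadcastable. Roughly, under the same assumptions as the test:

import pytensor.tensor as pt
from pytensor.tensor.extra_ops import broadcast_shape

s1 = [pt.lscalar(f"s1_{i}") for i in range(2)]
s2 = [pt.lscalar(f"s2_{i}") for i in range(2)]
res = pt.as_tensor(broadcast_shape(s1, s2, arrays_are_shapes=True))

# Feeding s2 = (1, 2) does not help: the 1 was never static, so the
# runtime shape-equality Assert fires instead of broadcasting.
res.eval({s1[0]: 2, s1[1]: 2, s2[0]: 1, s2[1]: 2})  # raises AssertionError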


def test_broadcast_shape_symbolic_one_symbolic():
@@ -1395,7 +1401,7 @@ def test_inplace(self):


def test_broadcast_arrays():
x, y = at.dvector(), at.dmatrix()
x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
x_bcast, y_bcast = broadcast_arrays(x, y)

py_mode = Mode("py", None)
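The dvector-to-shape=(1,) switch in test_broadcast_arrays follows the same rule: broadcast_arrays only treats an input as broadcastable when its length is statically 1. A sketch:

import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_arrays

x = at.tensor(shape=(1,), dtype="float64")  # statically length-1
y = at.dmatrix()
x_bcast, y_bcast = broadcast_arrays(x, y)   # x is broadcast to y's shape

# With the old at.dvector() (static shape (None,)), the graph could not
# assume x broadcasts and would instead assert that the lengths match.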
2 changes: 1 addition & 1 deletion tests/unittest_tools.py
@@ -255,7 +255,7 @@ def _compile_and_check(
# Check that the Op is removed from the compiled function.
if check_topo:
topo_shape = shapes_function.maker.fgraph.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
assert not any(t in outputs for t in topo_shape)
topo_out = outputs_function.maker.fgraph.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.