Infer shape of advanced boolean indexing #329

Merged
3 commits merged on Jul 13, 2023
Changes from 1 commit
57 changes: 40 additions & 17 deletions pytensor/tensor/extra_ops.py
@@ -21,14 +21,18 @@
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod
from pytensor.tensor.math import ge, lt
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import switch
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.var import TensorVariable
@@ -1063,7 +1067,7 @@ def grad(self, inp, cost_grad):
# only valid for matrices
wr_a = fill_diagonal_offset(grad, 0, offset)

offset_abs = at_abs(offset)
offset_abs = pt_abs(offset)
pos_offset_flag = ge(offset, 0)
neg_offset_flag = lt(offset, 0)
min_wh = minimum(width, height)
@@ -1442,6 +1446,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape."
)
_runtime_broadcast_assert = Assert("Could not broadcast dimensions.")


def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
@@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False,
allow_runtime_broadcast: bool = False,
) -> Tuple[aes.ScalarVariable, ...]:
r"""Compute the shape resulting from broadcasting arrays.

@@ -1480,22 +1486,24 @@
arrays
An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed.
arrays_are_shapes
arrays_are_shapes: bool, default False
Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1``--or simply the integer
``1``.
``1``. This is not relevant if `allow_runtime_broadcast` is True.
allow_runtime_broadcast: bool, default False
Whether to allow broadcasting of dimensions that are not statically known in the shape computation.

"""
one_at = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)
one = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)

if arrays_are_shapes:
max_dims = max(len(a) for a in arrays)

array_shapes = [
(one_at,) * (max_dims - len(a))
(one,) * (max_dims - len(a))
+ tuple(
one_at
one
if sh == 1 or isinstance(sh, Constant) and sh.value == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a
@@ -1508,10 +1516,8 @@
_arrays = tuple(at.as_tensor_variable(a) for a in arrays)

array_shapes = [
(one_at,) * (max_dims - a.ndim)
+ tuple(
one_at if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)
)
(one,) * (max_dims - a.ndim)
+ tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape))
for a in _arrays
]

@@ -1520,11 +1526,11 @@
for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
non_bcast_shapes = [shape for shape in dim_shapes if shape != one]

if len(non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension
result_dims.append(one_at)
result_dims.append(one)
elif len(non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension
result_dims.extend(non_bcast_shapes)
@@ -1554,9 +1560,26 @@
result_dims.append(first_length)
continue

# Add assert that all remaining shapes are equal
condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
result_dims.append(_broadcast_assert(first_length, condition))
if not allow_runtime_broadcast:
# Add assert that all remaining shapes are equal
condition = pt_all(
[pt_eq(first_length, other) for other in other_lengths]
)
result_dims.append(_broadcast_assert(first_length, condition))
else:
lengths = as_tensor_variable((first_length, *other_lengths))
runtime_broadcastable = pt_eq(lengths, one)
result_dim = pt_abs(
pt_max(switch(runtime_broadcastable, -one, lengths))
)
condition = pt_all(
switch(
~runtime_broadcastable,
pt_eq(lengths, result_dim),
np.array(True),
)
)
result_dims.append(_runtime_broadcast_assert(result_dim, condition))

return tuple(result_dims)

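For intuition, here is a minimal NumPy sketch of the runtime-broadcast rule implemented above (the helper name runtime_broadcast_dim is illustrative, not part of PyTensor): lengths equal to 1 are mapped to -1 before taking the maximum, so a concrete non-1 length wins if one exists, and abs() recovers 1 when every length is 1; any remaining non-1 length must then equal the result.

import numpy as np

def runtime_broadcast_dim(lengths):
    # Concrete-integer mirror of the symbolic switch/max/abs computation above.
    lengths = np.asarray(lengths)
    runtime_broadcastable = lengths == 1
    result = np.abs(np.max(np.where(runtime_broadcastable, -1, lengths)))
    # Every non-broadcastable length must match the inferred dimension.
    if not np.all(np.where(~runtime_broadcastable, lengths == result, True)):
        raise AssertionError("Could not broadcast dimensions.")
    return result

runtime_broadcast_dim([1, 5, 5])  # 5
runtime_broadcast_dim([1, 1, 1])  # 1
runtime_broadcast_dim([3, 5])     # raises AssertionError
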
61 changes: 41 additions & 20 deletions pytensor/tensor/subtensor.py
@@ -20,15 +20,11 @@
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import (
AdvancedIndexingError,
NotScalarConstantError,
ShapeError,
)
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip
from pytensor.tensor.shape import Reshape, specify_broadcastable
from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
from pytensor.tensor.type import (
TensorType,
bscalar,
@@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
from pytensor.tensor.extra_ops import broadcast_shape

res_shape += broadcast_shape(
*grp_indices, arrays_are_shapes=indices_are_shapes
*grp_indices,
arrays_are_shapes=indices_are_shapes,
# The AdvancedIndexing Op relies on the NumPy implementation, which allows runtime broadcasting.
# As long as that is true, the shape inference has to respect that this is not an error.
allow_runtime_broadcast=True,
)

res_shape += tuple(array_shape[dim] for dim in remaining_dims)
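
This allowance matters because NumPy broadcasts advanced index arrays against each other at runtime, so index shapes that are not statically known to be 1 can still broadcast when the graph is evaluated. A small sketch (shapes chosen purely for illustration):

import numpy as np

x = np.arange(12).reshape(3, 4)
rows = np.array([[0], [2]])   # shape (2, 1)
cols = np.array([1, 3])       # shape (2,), broadcasts against rows to (2, 2)
print(x[rows, cols].shape)    # (2, 2)
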
@@ -2584,26 +2584,47 @@ def R_op(self, inputs, eval_points):
return self.make_node(eval_points[0], *inputs[1:]).outputs

def infer_shape(self, fgraph, node, ishapes):
indices = node.inputs[1:]
index_shapes = list(ishapes[1:])
for i, idx in enumerate(indices):
if (
def is_bool_index(idx):
return (
isinstance(idx, (np.bool_, bool))
or getattr(idx, "dtype", None) == "bool"
):
raise ShapeError(
"Shape inference for boolean indices is not implemented"
)

indices = node.inputs[1:]
index_shapes = []
for idx, ishape in zip(indices, ishapes[1:]):
# Mixed bool indexes are converted to nonzero entries
if is_bool_index(idx):
index_shapes.extend(
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
)
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
if isinstance(getattr(idx, "type", None), SliceType):
index_shapes[i] = idx
elif isinstance(getattr(idx, "type", None), SliceType):
index_shapes.append(idx)
else:
index_shapes.append(ishape)

res_shape = indexed_result_shape(
ishapes[0], index_shapes, indices_are_shapes=True
res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
)

adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]

# Special logic when the only advanced index group is of bool type.
# We can replace the nonzeros by a sum of the whole bool variable.
if len(bool_indices) == 1 and len(adv_indices) == 1:
[bool_index] = bool_indices
# Find the output dim associated with the bool index group
# Because there are no more advanced index groups, there is exactly
# one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing.
start_dim = indices.index(bool_index)
res_shape[start_dim] = bool_index.sum()

assert node.outputs[0].ndim == len(res_shape)
return [list(res_shape)]
return [res_shape]

def perform(self, node, inputs, out_):
(out,) = out_
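A minimal NumPy check of the two behaviours the new infer_shape encodes (array names and shapes are illustrative only): when a boolean mask is the only advanced index, its axes collapse to a single output axis of length mask.sum(); in mixed indexing the mask behaves like the tuple of its nonzero() index arrays.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 4))
mask = rng.binomial(1, 0.5, size=(3, 4)).astype(bool)

# The mask is the only advanced index: its two axes collapse to length mask.sum().
assert x[:, mask].shape == (5, mask.sum())

# Mixed with other indices: the mask is equivalent to its nonzero() index arrays,
# whose common length determines the corresponding output dimension.
rows, cols = mask.nonzero()
assert np.array_equal(x[1:, mask], x[1:, rows, cols])
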
10 changes: 9 additions & 1 deletion tests/tensor/test_extra_ops.py
@@ -1087,9 +1087,17 @@ def shape_tuple(x, use_bcast=True):
assert any(
isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
)
# This should fail because it would need dynamic broadcasting
with pytest.raises(AssertionError):
assert np.array_equal([z.eval() for z in b_at], b.shape)
# But fine if we allow_runtime_broadcast
b_at = broadcast_shape(
shape_tuple(x_at, use_bcast=False),
shape_tuple(y_at, use_bcast=False),
arrays_are_shapes=True,
allow_runtime_broadcast=True,
)
assert np.array_equal([z.eval() for z in b_at], b.shape)
# Or if static bcast is known
b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_at], b.shape)

73 changes: 70 additions & 3 deletions tests/tensor/test_subtensor.py
@@ -63,6 +63,7 @@
tensor,
tensor3,
tensor4,
tensor5,
vector,
)
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
@@ -2150,6 +2151,12 @@ def fun(x, y):


class TestInferShape(utt.InferShapeTester):
@staticmethod
def random_bool_mask(shape, rng=None):
if rng is None:
rng = np.random.default_rng()
return rng.binomial(n=1, p=0.5, size=shape).astype(bool)

def test_IncSubtensor(self):
admat = dmatrix()
bdmat = dmatrix()
@@ -2439,25 +2446,85 @@ def test_AdvancedSubtensor_bool(self):
n = dmatrix()
n_val = np.arange(6).reshape((2, 3))

# infer_shape is not implemented, but it should not crash
# Shape inference requires runtime broadcasting between the nonzero() shapes
self._compile_and_check(
[n],
[n[n[:, 0] > 2, n[0, :] > 2]],
[n_val],
AdvancedSubtensor,
check_topo=False,
)
self._compile_and_check(
[n],
[n[n[:, 0] > 2]],
[n_val],
AdvancedSubtensor,
check_topo=False,
)
self._compile_and_check(
[n],
[n[:, np.array([True, False, True])]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[np.array([False, False]), 1:]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[np.array([True, True]), 0]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[self.random_bool_mask(n_val.shape)]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[None, self.random_bool_mask(n_val.shape), None]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[slice(5, None), self.random_bool_mask(n_val.shape[1])]],
[n_val],
AdvancedSubtensor,
)

abs_res = n[~isinf(n)]
assert abs_res.type.shape == (None,)

def test_AdvancedSubtensor_bool_mixed(self):
n = tensor5("x", dtype="float64")
shape = (18, 3, 4, 5, 6)
n_val = np.arange(np.prod(shape)).reshape(shape)
self._compile_and_check(
[n],
# Consecutive advanced index
[n[1:, self.random_bool_mask((3, 4)), 0, 1:]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
# Non-consecutive advanced index
[n[1:, self.random_bool_mask((3, 4)), 1:, 0]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
# Non-consecutive advanced index
[n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]],
[n_val],
AdvancedSubtensor,
)


@config.change_flags(compute_test_value="raise")
def test_basic_shape():