From c1fbd0851d03256096707aa1a0e25a1dea38ef34 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 6 Jun 2023 11:14:23 +0200
Subject: [PATCH] Implement shape inference for boolean advanced indexing

---
 pytensor/tensor/subtensor.py   | 55 +++++++++++++++++---------
 tests/tensor/test_subtensor.py | 72 ++++++++++++++++++++++++++++++++--
 2 files changed, 105 insertions(+), 22 deletions(-)

diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index e0c0060058..f82c800844 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -20,15 +20,11 @@
 from pytensor.printing import Printer, pprint, set_precedence
 from pytensor.scalar.basic import ScalarConstant
 from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
-from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
+from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.exceptions import (
-    AdvancedIndexingError,
-    NotScalarConstantError,
-    ShapeError,
-)
+from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
 from pytensor.tensor.math import clip
-from pytensor.tensor.shape import Reshape, specify_broadcastable
+from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
     bscalar,
@@ -2584,26 +2580,47 @@ def R_op(self, inputs, eval_points):
         return self.make_node(eval_points[0], *inputs[1:]).outputs
 
     def infer_shape(self, fgraph, node, ishapes):
-        indices = node.inputs[1:]
-        index_shapes = list(ishapes[1:])
-        for i, idx in enumerate(indices):
-            if (
+        def is_bool_index(idx):
+            return (
                 isinstance(idx, (np.bool_, bool))
                 or getattr(idx, "dtype", None) == "bool"
-            ):
-                raise ShapeError(
-                    "Shape inference for boolean indices is not implemented"
+            )
+
+        indices = node.inputs[1:]
+        index_shapes = []
+        for idx, ishape in zip(indices, ishapes[1:]):
+            # Mixed bool indexes are converted to nonzero entries
+            if is_bool_index(idx):
+                index_shapes.extend(
+                    (shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
                 )
             # The `ishapes` entries for `SliceType`s will be None, and
             # we need to give `indexed_result_shape` the actual slices.
-            if isinstance(getattr(idx, "type", None), SliceType):
-                index_shapes[i] = idx
+            elif isinstance(getattr(idx, "type", None), SliceType):
+                index_shapes.append(idx)
+            else:
+                index_shapes.append(ishape)
 
-        res_shape = indexed_result_shape(
-            ishapes[0], index_shapes, indices_are_shapes=True
+        res_shape = list(
+            indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
         )
+
+        adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
+        bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
+
+        # Special logic when the only advanced index group is of bool type.
+        # We can replace the nonzeros by a sum of the whole bool variable.
+        if len(bool_indices) == 1 and len(adv_indices) == 1:
+            [bool_index] = bool_indices
+            # Find the output dim associated with the bool index group
+            # Because there are no more advanced index groups, there is exactly
+            # one output dim per index variable up to the bool group.
+            # Note: Scalar integer indexing counts as advanced indexing.
+            start_dim = indices.index(bool_index)
+            res_shape[start_dim] = bool_index.sum()
+
         assert node.outputs[0].ndim == len(res_shape)
-        return [list(res_shape)]
+        return [res_shape]
 
     def perform(self, node, inputs, out_):
         (out,) = out_
diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py
index 0f60ac4ed5..6bd66d11aa 100644
--- a/tests/tensor/test_subtensor.py
+++ b/tests/tensor/test_subtensor.py
@@ -63,6 +63,7 @@
     tensor,
     tensor3,
     tensor4,
+    tensor5,
     vector,
 )
 from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
@@ -2150,6 +2151,12 @@ def fun(x, y):
 
 
 class TestInferShape(utt.InferShapeTester):
+    @staticmethod
+    def random_bool_mask(shape, rng=None):
+        if rng is None:
+            rng = np.random.default_rng()
+        return rng.binomial(n=1, p=0.5, size=shape).astype(bool)
+
     def test_IncSubtensor(self):
         admat = dmatrix()
         bdmat = dmatrix()
@@ -2439,25 +2446,84 @@ def test_AdvancedSubtensor_bool(self):
         n = dmatrix()
         n_val = np.arange(6).reshape((2, 3))
 
-        # infer_shape is not implemented, but it should not crash
         self._compile_and_check(
             [n],
             [n[n[:, 0] > 2, n[0, :] > 2]],
             [n_val],
             AdvancedSubtensor,
-            check_topo=False,
         )
         self._compile_and_check(
             [n],
             [n[n[:, 0] > 2]],
             [n_val],
             AdvancedSubtensor,
-            check_topo=False,
+        )
+        self._compile_and_check(
+            [n],
+            [n[:, np.array([True, False, True])]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[np.array([False, False]), 1:]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[np.array([True, True]), 0]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[self.random_bool_mask(n_val.shape)]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[None, self.random_bool_mask(n_val.shape), None]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[slice(5, None), self.random_bool_mask(n_val.shape[1])]],
+            [n_val],
+            AdvancedSubtensor,
         )
 
         abs_res = n[~isinf(n)]
         assert abs_res.type.shape == (None,)
 
+    def test_AdvancedSubtensor_bool_mixed(self):
+        n = tensor5("x", dtype="float64")
+        shape = (18, 3, 4, 5, 6)
+        n_val = np.arange(np.prod(shape)).reshape(shape)
+        self._compile_and_check(
+            [n],
+            # Consecutive advanced index
+            [n[1:, self.random_bool_mask((3, 4)), 0, 1:]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            # Non-consecutive advanced index
+            [n[1:, self.random_bool_mask((3, 4)), 1:, 0]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            # Non-consecutive advanced index
+            [n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+
 
 @config.change_flags(compute_test_value="raise")
 def test_basic_shape():
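Editorial note (not part of the patch): a minimal sketch of the special case handled at the end of the new infer_shape, assuming a PyTensor build with this change applied. When a boolean mask is the only advanced index, the length of the corresponding output dimension is simply mask.sum(), so the shape can be reported without evaluating the indexing itself. All names below (x, mask, shape_fn, x_val) are illustrative only.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
mask = x > 2          # symbolic boolean mask with the same shape as x
y = x[mask]           # boolean advanced indexing -> 1-D result

# The length of y equals the number of True entries in the mask, i.e. mask.sum().
# With the patch, AdvancedSubtensor.infer_shape can report that length
# symbolically instead of raising "Shape inference for boolean indices is not implemented".
shape_fn = pytensor.function([x], y.shape)

x_val = np.arange(6).reshape(2, 3).astype(x.dtype)
print(shape_fn(x_val))  # [3] -- three entries of x_val are greater than 2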
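A second editorial sketch, for the general path in the new infer_shape where boolean indices are mixed with other indices: the index shapes are derived from nonzero(idx), mirroring NumPy's equivalence between a boolean index and its nonzero() expansion. Names (x, row_mask, x_val, mask_val) are illustrative; the value-level equivalence below holds with or without the patch, which only lets the shape graph exploit it.

import numpy as np
import pytensor.tensor as pt

x = pt.matrix("x")
row_mask = pt.vector("row_mask", dtype="bool")

# A boolean index combined with basic indexing behaves like its nonzero()
# expansion; the shape inference above leans on exactly this equivalence.
a = x[row_mask, 1:]
b = x[row_mask.nonzero()[0], 1:]

x_val = np.arange(12).reshape(4, 3).astype(x.dtype)
mask_val = np.array([True, False, True, False])
assert a.shape.eval({x: x_val, row_mask: mask_val}).tolist() == [2, 2]
assert np.array_equal(
    a.eval({x: x_val, row_mask: mask_val}),
    b.eval({x: x_val, row_mask: mask_val}),
)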