Implement shape inference for boolean advanced indexing
ricardoV94 committed Jul 12, 2023
1 parent d108ebb commit c1fbd08
Showing 2 changed files with 105 additions and 22 deletions.
55 changes: 36 additions & 19 deletions pytensor/tensor/subtensor.py
@@ -20,15 +20,11 @@
 from pytensor.printing import Printer, pprint, set_precedence
 from pytensor.scalar.basic import ScalarConstant
 from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
-from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
+from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.exceptions import (
-    AdvancedIndexingError,
-    NotScalarConstantError,
-    ShapeError,
-)
+from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
 from pytensor.tensor.math import clip
-from pytensor.tensor.shape import Reshape, specify_broadcastable
+from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
     bscalar,
@@ -2584,26 +2580,47 @@ def R_op(self, inputs, eval_points):
         return self.make_node(eval_points[0], *inputs[1:]).outputs
 
     def infer_shape(self, fgraph, node, ishapes):
-        indices = node.inputs[1:]
-        index_shapes = list(ishapes[1:])
-        for i, idx in enumerate(indices):
-            if (
+        def is_bool_index(idx):
+            return (
                 isinstance(idx, (np.bool_, bool))
                 or getattr(idx, "dtype", None) == "bool"
-            ):
-                raise ShapeError(
-                    "Shape inference for boolean indices is not implemented"
+            )
+
+        indices = node.inputs[1:]
+        index_shapes = []
+        for idx, ishape in zip(indices, ishapes[1:]):
+            # Mixed bool indexes are converted to nonzero entries
+            if is_bool_index(idx):
+                index_shapes.extend(
+                    (shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
                 )
             # The `ishapes` entries for `SliceType`s will be None, and
             # we need to give `indexed_result_shape` the actual slices.
-            if isinstance(getattr(idx, "type", None), SliceType):
-                index_shapes[i] = idx
-        res_shape = indexed_result_shape(
-            ishapes[0], index_shapes, indices_are_shapes=True
+            elif isinstance(getattr(idx, "type", None), SliceType):
+                index_shapes.append(idx)
+            else:
+                index_shapes.append(ishape)
+
+        res_shape = list(
+            indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
         )
 
+        adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
+        bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
+
+        # Special logic when the only advanced index group is of bool type.
+        # We can replace the nonzeros by a sum of the whole bool variable.
+        if len(bool_indices) == 1 and len(adv_indices) == 1:
+            [bool_index] = bool_indices
+            # Find the output dim associated with the bool index group
+            # Because there are no more advanced index groups, there is exactly
+            # one output dim per index variable up to the bool group.
+            # Note: Scalar integer indexing counts as advanced indexing.
+            start_dim = indices.index(bool_index)
+            res_shape[start_dim] = bool_index.sum()
+
         assert node.outputs[0].ndim == len(res_shape)
-        return [list(res_shape)]
+        return [res_shape]
 
     def perform(self, node, inputs, out_):
         (out,) = out_
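The comments in the new infer_shape ("converted to nonzero entries", "replace the nonzeros by a sum of the whole bool variable") mirror two NumPy indexing rules. The following small NumPy check is an illustration of those rules, not part of the commit:

import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
mask = np.array([[True, False, True], [False, True, True]])  # boolean index spanning dims 0 and 1

# Rule 1: a boolean index is equivalent to the tuple of integer arrays returned
# by nonzero(), so each nonzero output contributes one 1-d index to the shape
# computation.
assert np.array_equal(x[mask], x[mask.nonzero()])

# Rule 2: when the boolean mask is the only advanced index group, the resulting
# dimension has length mask.sum(), i.e. the number of True entries.
assert x[mask].shape == (int(mask.sum()), 4)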
72 changes: 69 additions & 3 deletions tests/tensor/test_subtensor.py
@@ -63,6 +63,7 @@
     tensor,
     tensor3,
     tensor4,
+    tensor5,
     vector,
 )
 from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
@@ -2150,6 +2151,12 @@ def fun(x, y):
 
 
 class TestInferShape(utt.InferShapeTester):
+    @staticmethod
+    def random_bool_mask(shape, rng=None):
+        if rng is None:
+            rng = np.random.default_rng()
+        return rng.binomial(n=1, p=0.5, size=shape).astype(bool)
+
     def test_IncSubtensor(self):
         admat = dmatrix()
         bdmat = dmatrix()
@@ -2439,25 +2446,84 @@ def test_AdvancedSubtensor_bool(self):
         n = dmatrix()
         n_val = np.arange(6).reshape((2, 3))
 
-        # infer_shape is not implemented, but it should not crash
         self._compile_and_check(
             [n],
             [n[n[:, 0] > 2, n[0, :] > 2]],
             [n_val],
             AdvancedSubtensor,
-            check_topo=False,
         )
         self._compile_and_check(
             [n],
             [n[n[:, 0] > 2]],
             [n_val],
             AdvancedSubtensor,
-            check_topo=False,
         )
+        self._compile_and_check(
+            [n],
+            [n[:, np.array([True, False, True])]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[np.array([False, False]), 1:]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[np.array([True, True]), 0]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[self.random_bool_mask(n_val.shape)]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[None, self.random_bool_mask(n_val.shape), None]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[slice(5, None), self.random_bool_mask(n_val.shape[1])]],
+            [n_val],
+            AdvancedSubtensor,
+        )
 
         abs_res = n[~isinf(n)]
         assert abs_res.type.shape == (None,)
 
+    def test_AdvancedSubtensor_bool_mixed(self):
+        n = tensor5("x", dtype="float64")
+        shape = (18, 3, 4, 5, 6)
+        n_val = np.arange(np.prod(shape)).reshape(shape)
+        self._compile_and_check(
+            [n],
+            # Consecutive advanced index
+            [n[1:, self.random_bool_mask((3, 4)), 0, 1:]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            # Non-consecutive advanced index
+            [n[1:, self.random_bool_mask((3, 4)), 1:, 0]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            # Non-consecutive advanced index
+            [n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+
 
 @config.change_flags(compute_test_value="raise")
 def test_basic_shape():
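For context, here is a usage-level sketch of what the tests above exercise, assuming only the public pytensor API (pytensor.tensor, pytensor.function); it is illustrative and not taken from this commit:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
mask = x > 0   # symbolic boolean mask covering all of `x`
y = x[mask]    # a single boolean advanced index group

# The shape of `y` can now be reasoned about symbolically; for a single boolean
# index group it amounts to `mask.sum()`.
shape_fn = pytensor.function([x], y.shape)
x_val = np.array([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]])
assert tuple(shape_fn(x_val)) == (3,)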
