diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 26f1faa3e7..c13e514d53 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -21,14 +21,18 @@ from pytensor.raise_op import Assert from pytensor.scalar import int32 as int_t from pytensor.scalar import upcast +from pytensor.tensor import as_tensor_variable from pytensor.tensor import basic as at from pytensor.tensor import get_vector_length from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.math import abs as at_abs +from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import eq as pt_eq -from pytensor.tensor.math import ge, lt, maximum, minimum, prod +from pytensor.tensor.math import ge, lt +from pytensor.tensor.math import max as pt_max +from pytensor.tensor.math import maximum, minimum, prod from pytensor.tensor.math import sum as at_sum +from pytensor.tensor.math import switch from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.var import TensorVariable @@ -1063,7 +1067,7 @@ def grad(self, inp, cost_grad): # only valid for matrices wr_a = fill_diagonal_offset(grad, 0, offset) - offset_abs = at_abs(offset) + offset_abs = pt_abs(offset) pos_offset_flag = ge(offset, 0) neg_offset_flag = lt(offset, 0) min_wh = minimum(width, height) @@ -1442,6 +1446,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): "axes that have a statically known length 1. Use `specify_broadcastable` to " "inform PyTensor of a known shape." ) +_runtime_broadcast_assert = Assert("Could not broadcast dimensions.") def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]: @@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]: def broadcast_shape_iter( arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]], arrays_are_shapes: bool = False, + allow_runtime_broadcast: bool = False, ) -> Tuple[aes.ScalarVariable, ...]: r"""Compute the shape resulting from broadcasting arrays. @@ -1480,22 +1486,24 @@ def broadcast_shape_iter( arrays An iterable of tensors, or a tuple of shapes (as tuples), for which the broadcast shape is computed. - arrays_are_shapes + arrays_are_shapes: bool, default False Indicates whether or not the `arrays` contains shape tuples. If you use this approach, make sure that the broadcastable dimensions are (scalar) constants with the value ``1``--or simply the integer - ``1``. + ``1``. This is not relevant if `allow_runtime_broadcast` is True. + allow_runtime_broadcast: bool, default False + Whether to allow runtime broadcasting of dimensions whose length is not statically known to be 1 during the shape computation.
""" - one_at = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1) + one = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1) if arrays_are_shapes: max_dims = max(len(a) for a in arrays) array_shapes = [ - (one_at,) * (max_dims - len(a)) + (one,) * (max_dims - len(a)) + tuple( - one_at + one if sh == 1 or isinstance(sh, Constant) and sh.value == 1 else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh) for sh in a @@ -1508,10 +1516,8 @@ def broadcast_shape_iter( _arrays = tuple(at.as_tensor_variable(a) for a in arrays) array_shapes = [ - (one_at,) * (max_dims - a.ndim) - + tuple( - one_at if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape) - ) + (one,) * (max_dims - a.ndim) + + tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)) for a in _arrays ] @@ -1520,11 +1526,11 @@ def broadcast_shape_iter( for dim_shapes in zip(*array_shapes): # Get the shapes in this dimension that are not broadcastable # (i.e. not symbolically known to be broadcastable) - non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] + non_bcast_shapes = [shape for shape in dim_shapes if shape != one] if len(non_bcast_shapes) == 0: # Every shape was broadcastable in this dimension - result_dims.append(one_at) + result_dims.append(one) elif len(non_bcast_shapes) == 1: # Only one shape might not be broadcastable in this dimension result_dims.extend(non_bcast_shapes) @@ -1554,9 +1560,26 @@ def broadcast_shape_iter( result_dims.append(first_length) continue - # Add assert that all remaining shapes are equal - condition = pt_all([pt_eq(first_length, other) for other in other_lengths]) - result_dims.append(_broadcast_assert(first_length, condition)) + if not allow_runtime_broadcast: + # Add assert that all remaining shapes are equal + condition = pt_all( + [pt_eq(first_length, other) for other in other_lengths] + ) + result_dims.append(_broadcast_assert(first_length, condition)) + else: + lengths = as_tensor_variable((first_length, *other_lengths)) + runtime_broadcastable = pt_eq(lengths, one) + result_dim = pt_abs( + pt_max(switch(runtime_broadcastable, -one, lengths)) + ) + condition = pt_all( + switch( + ~runtime_broadcastable, + pt_eq(lengths, result_dim), + np.array(True), + ) + ) + result_dims.append(_runtime_broadcast_assert(result_dim, condition)) return tuple(result_dims) diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index bdef569905..1cd93a29e0 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -1,14 +1,15 @@ -from itertools import zip_longest +from itertools import chain from pytensor.compile import optdb from pytensor.configdefaults import config +from pytensor.graph import ancestors from pytensor.graph.op import compute_test_value -from pytensor.graph.rewriting.basic import in2out, node_rewriter +from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter +from pytensor.scalar import integer_types from pytensor.tensor import NoneConst from pytensor.tensor.basic import constant, get_vector_length from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.extra_ops import broadcast_to -from pytensor.tensor.math import sum as at_sum from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.utils import broadcast_params from pytensor.tensor.shape import Shape, Shape_i, shape_padleft @@ -18,7 +19,6 @@ Subtensor, as_index_variable, get_idx_list, - indexed_result_shape, ) from 
pytensor.tensor.type_other import SliceType @@ -207,47 +207,76 @@ def local_subtensor_rv_lift(fgraph, node): ``mvnormal(mu, cov, size=(2,))[0, 0]``. """ - st_op = node.op + def is_nd_advanced_idx(idx, dtype): + if isinstance(dtype, str): + return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1) + else: + return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1) - if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): - return False + subtensor_op = node.op + old_subtensor = node.outputs[0] rv = node.inputs[0] rv_node = rv.owner if not (rv_node and isinstance(rv_node.op, RandomVariable)): return False + shape_feature = getattr(fgraph, "shape_feature", None) + if not shape_feature: + return None + + # Use shape_feature to facilitate inferring final shape. + # Check that neither the RV nor the old Subtensor are in the shape graph. + output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None) + if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)): + return None + rv_op = rv_node.op rng, size, dtype, *dist_params = rv_node.inputs # Parse indices - idx_list = getattr(st_op, "idx_list", None) + idx_list = getattr(subtensor_op, "idx_list", None) if idx_list: - cdata = get_idx_list(node.inputs, idx_list) + idx_vars = get_idx_list(node.inputs, idx_list) else: - cdata = node.inputs[1:] - st_indices, st_is_bool = zip( - *tuple( - (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata - ) - ) + idx_vars = node.inputs[1:] + indices = tuple(as_index_variable(idx) for idx in idx_vars) + + # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) + # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). + # If we wanted to support that we could rewrite it as subtensor + dimshuffle + # and make use of the dimshuffle lift rewrite + integer_dtypes = {type.dtype for type in integer_types} + if any( + is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx) + for idx in indices + ): + return False # Check that indexing does not act on support dims - batched_ndims = rv.ndim - rv_op.ndim_supp - if len(st_indices) > batched_ndims: - # If the last indexes are just dummy `slice(None)` we discard them - st_is_bool = st_is_bool[:batched_ndims] - st_indices, supp_indices = ( - st_indices[:batched_ndims], - st_indices[batched_ndims:], + batch_ndims = rv.ndim - rv_op.ndim_supp + # We decompose the boolean indexes, which makes it clear whether they act on support dims or not + non_bool_indices = tuple( + chain.from_iterable( + idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,) + for idx in indices + ) + ) + if len(non_bool_indices) > batch_ndims: + # If the last indexes are just dummy `slice(None)` we discard them instead of quitting + non_bool_indices, supp_indices = ( + non_bool_indices[:batch_ndims], + non_bool_indices[batch_ndims:], ) - for index in supp_indices: + for idx in supp_indices: if not ( - isinstance(index.type, SliceType) - and all(NoneConst.equals(i) for i in index.owner.inputs) + isinstance(idx.type, SliceType) + and all(NoneConst.equals(i) for i in idx.owner.inputs) ): return False + n_discarded_idxs = len(supp_indices) + indices = indices[:-n_discarded_idxs] # If no one else is using the underlying `RandomVariable`, then we can # do this; otherwise, the graph would be internally inconsistent. 
@@ -255,50 +284,71 @@ def local_subtensor_rv_lift(fgraph, node): return False # Update the size to reflect the indexed dimensions - # TODO: Could use `ShapeFeature` info. We would need to be sure that - # `node` isn't in the results, though. - # if hasattr(fgraph, "shape_feature"): - # output_shape = fgraph.shape_feature.shape_of(node.outputs[0]) - # else: - output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices) - new_size_ignoring_boolean = ( - output_shape_ignoring_bool - if rv_op.ndim_supp == 0 - else output_shape_ignoring_bool[: -rv_op.ndim_supp] - ) + new_size = output_shape[: len(output_shape) - rv_op.ndim_supp] - # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used). - # The `indexed_result_shape` helper does not consider this - if any(st_is_bool): - new_size = tuple( - at_sum(idx) if is_bool else s - for s, is_bool, idx in zip_longest( - new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False - ) - ) - else: - new_size = new_size_ignoring_boolean - - # Update the parameters to reflect the indexed dimensions + # Propagate indexing to the parameters' batch dims. + # We try to avoid broadcasting the parameters together (and with size), by only indexing + # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size + # should still correctly broadcast any degenerate parameter dims. new_dist_params = [] for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): - # Apply indexing on the batched dimensions of the parameter - batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp) - batched_param = shape_padleft(param, batched_param_dims_missing) - batched_st_indices = [] - for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape): - # If we have a degenerate dimension indexing it should always do the job - if batched_param_shape == 1: - batched_st_indices.append(0) + # We first expand any missing parameter dims (and later index them away or keep them with none-slicing) + batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp) + batch_param = ( + shape_padleft(param, batch_param_dims_missing) + if batch_param_dims_missing + else param + ) + # Check which dims are actually broadcasted + bcast_batch_param_dims = tuple( + dim + for dim, (param_dim, output_dim) in enumerate( + zip(batch_param.type.shape, rv.type.shape) + ) + if (param_dim == 1) and (output_dim != 1) + ) + batch_indices = [] + curr_dim = 0 + for idx in indices: + # Advanced boolean indexing + if is_nd_advanced_idx(idx, "bool"): + # Check if any broadcasted dim overlaps with advanced boolean indexing. 
+ # If not, we use that directly, instead of the more inefficient `nonzero` form + bool_dims = range(curr_dim, curr_dim + idx.type.ndim) + # There's an overlap, we have to decompose the boolean mask as a `nonzero` + if set(bool_dims) & set(bcast_batch_param_dims): + int_indices = list(idx.nonzero()) + # Indexing by 0 drops the degenerate dims + for bool_dim in bool_dims: + if bool_dim in bcast_batch_param_dims: + int_indices[bool_dim - curr_dim] = 0 + batch_indices.extend(int_indices) + # No overlap, use index as is + else: + batch_indices.append(idx) + curr_dim += len(bool_dims) + # Basic-indexing (slice or integer) else: - batched_st_indices.append(st_index) - new_dist_params.append(batched_param[tuple(batched_st_indices)]) + # Broadcasted dim + if curr_dim in bcast_batch_param_dims: + # Slice indexing, keep degenerate dim by none-slicing + if isinstance(idx.type, SliceType): + batch_indices.append(slice(None)) + # Integer indexing, drop degenerate dim by 0-indexing + else: + batch_indices.append(0) + # Non-broadcasted dim + else: + # Use index as is + batch_indices.append(idx) + curr_dim += 1 + + new_dist_params.append(batch_param[tuple(batch_indices)]) # Create new RV new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) new_rv = new_node.default_output() - if config.compute_test_value != "off": - compute_test_value(new_node) + copy_stack_trace(rv, new_rv) return [new_rv] diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 817c6d0d63..e0d4963388 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -99,18 +99,6 @@ class ShapeFeature(Feature): Notes ----- - Right now there is only the ConvOp that could really take - advantage of this shape inference, but it is worth it even - just for the ConvOp. All that's necessary to do shape - inference is 1) to mark shared inputs as having a particular - shape, either via a .tag or some similar hacking; and 2) to - add an optional In() argument to promise that inputs will - have a certain shape (or even to have certain shapes in - certain dimensions). - - We can't automatically infer the shape of shared variables as they can - change of shape during the execution by default. - To use this shape information in rewrites, use the ``shape_of`` dictionary. 
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index e0c0060058..3779b1411a 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -20,15 +20,11 @@ from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length -from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value +from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.exceptions import ( - AdvancedIndexingError, - NotScalarConstantError, - ShapeError, -) +from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError from pytensor.tensor.math import clip -from pytensor.tensor.shape import Reshape, specify_broadcastable +from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable from pytensor.tensor.type import ( TensorType, bscalar, @@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): from pytensor.tensor.extra_ops import broadcast_shape res_shape += broadcast_shape( - *grp_indices, arrays_are_shapes=indices_are_shapes + *grp_indices, + arrays_are_shapes=indices_are_shapes, + # The AdvancedIndexing Op relies on the Numpy implementation which allows runtime broadcasting. + # As long as that is true, the shape inference has to respect that this is not an error. + allow_runtime_broadcast=True, ) res_shape += tuple(array_shape[dim] for dim in remaining_dims) @@ -2584,26 +2584,47 @@ def R_op(self, inputs, eval_points): return self.make_node(eval_points[0], *inputs[1:]).outputs def infer_shape(self, fgraph, node, ishapes): - indices = node.inputs[1:] - index_shapes = list(ishapes[1:]) - for i, idx in enumerate(indices): - if ( + def is_bool_index(idx): + return ( isinstance(idx, (np.bool_, bool)) or getattr(idx, "dtype", None) == "bool" - ): - raise ShapeError( - "Shape inference for boolean indices is not implemented" + ) + + indices = node.inputs[1:] + index_shapes = [] + for idx, ishape in zip(indices, ishapes[1:]): + # Mixed bool indexes are converted to nonzero entries + if is_bool_index(idx): + index_shapes.extend( + (shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx) ) # The `ishapes` entries for `SliceType`s will be None, and # we need to give `indexed_result_shape` the actual slices. - if isinstance(getattr(idx, "type", None), SliceType): - index_shapes[i] = idx + elif isinstance(getattr(idx, "type", None), SliceType): + index_shapes.append(idx) + else: + index_shapes.append(ishape) - res_shape = indexed_result_shape( - ishapes[0], index_shapes, indices_are_shapes=True + res_shape = list( + indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True) ) + + adv_indices = [idx for idx in indices if not is_basic_idx(idx)] + bool_indices = [idx for idx in adv_indices if is_bool_index(idx)] + + # Special logic when the only advanced index group is of bool type. + # We can replace the nonzeros by a sum of the whole bool variable. + if len(bool_indices) == 1 and len(adv_indices) == 1: + [bool_index] = bool_indices + # Find the output dim associated with the bool index group + # Because there are no more advanced index groups, there is exactly + # one output dim per index variable up to the bool group. + # Note: Scalar integer indexing counts as advanced indexing. 
+ start_dim = indices.index(bool_index) + res_shape[start_dim] = bool_index.sum() + assert node.outputs[0].ndim == len(res_shape) - return [list(res_shape)] + return [res_shape] def perform(self, node, inputs, out_): (out,) = out_ diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index ef9cf8b3b3..337cd6cf91 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -26,6 +26,7 @@ local_rv_size_lift, local_subtensor_rv_lift, ) +from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from pytensor.tensor.type import iscalar, vector @@ -58,7 +59,9 @@ def apply_local_rewrite_to_rv( p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant)) ] - mode = Mode("py", EquilibriumGraphRewriter([rewrite], max_use_ratio=100)) + mode = Mode( + "py", EquilibriumGraphRewriter([ShapeOptimizer(), rewrite], max_use_ratio=100) + ) f_rewritten = function( f_inputs, @@ -440,30 +443,48 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol) +def rand_bool_mask(shape, rng=None): + if rng is None: + rng = np.random.default_rng() + return rng.binomial(n=1, p=0.5, size=shape).astype(bool) + + @pytest.mark.parametrize( "indices, lifted, dist_op, dist_params, size", [ + # 0 ( - # `size`-less advanced boolean indexing - (np.r_[True, False, False, True],), + # `size`-less simple integer indexing + (slice(None), 2), True, - uniform, + normal, ( - (0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX), - 0.1 * np.arange(4).astype(dtype=config.floatX), + np.arange(30, dtype=config.floatX).reshape(3, 5, 2), + np.full((1, 5, 1), 1e-6), ), (), ), ( - # `size`-only advanced boolean indexing - (np.r_[True, False, False, True],), + # `size`-only slice + (2, -1), True, uniform, ( np.array(0.9 - 1e-5, dtype=config.floatX), np.array(0.9, dtype=config.floatX), ), - (4,), + (5, 2), + ), + ( + # `size`-less slice + (slice(None), slice(4, -6, -1), slice(1, None)), + True, + normal, + ( + np.arange(30, dtype=config.floatX).reshape(3, 5, 2), + np.full((1, 5, 1), 1e-6), + ), + (), ), ( # `size`-only slice @@ -477,8 +498,32 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): (5, 2), ), ( - (slice(1, None), [0, 2]), + # `size`-less advanced boolean indexing + (np.r_[True, False, False, True],), + True, + uniform, + ( + (0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX), + 0.1 * np.arange(4).astype(dtype=config.floatX), + ), + (), + ), + # 5 + ( + # `size`-only advanced boolean indexing + (np.r_[True, False, False, True],), True, + uniform, + ( + np.array(0.9 - 1e-5, dtype=config.floatX), + np.array(0.9, dtype=config.floatX), + ), + (4,), + ), + ( + # Advanced integer indexing + (slice(1, None), [0, 2]), + False, # Could have duplicates normal, ( np.array([1, 10, 100], dtype=config.floatX), @@ -487,8 +532,9 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): (4, 3), ), ( + # Advanced integer indexing (np.array([1]), 0), - True, + False, # We don't support expand_dims normal, ( np.array([[-1, 20], [300, -4000]], dtype=config.floatX), @@ -496,23 +542,39 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ), (3, 2, 2), ), - # Only one distribution parameter ( - (0,), + # Advanced integer-boolean indexing + (0, np.r_[True, False]), True, - 
poisson, - (np.array([[1, 2], [3, 4]], dtype=config.floatX),), + normal, + ( + np.array([[1, 2], [3, 4]], dtype=config.floatX), + np.array([1e-6], dtype=config.floatX), + ), (3, 2, 2), ), - # Univariate distribution with vector parameters ( - (np.array([0, 2]),), + # Advanced non-consecutive integer-boolean indexing + (slice(None), 0, slice(None), np.r_[True, False]), + True, + normal, + ( + np.array([[1, 2], [3, 4]], dtype=config.floatX), + np.array([[1e-6]], dtype=config.floatX), + ), + (7, 3, 2, 2), + ), + # 10 + ( + # Univariate distribution with core-vector parameters + (1,), True, categorical, (np.array([0.0, 0.0, 1.0], dtype=config.floatX),), (4,), ), ( + # Univariate distribution with core-vector parameters (np.array([True, False, True, True]),), True, categorical, @@ -520,6 +582,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): (4,), ), ( + # Univariate distribution with core-vector parameters (np.array([True, False, True]),), True, categorical, @@ -532,10 +595,8 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): (), ), ( - ( - slice(None), - np.array([True, False, True]), - ), + # Univariate distribution with core-vector parameters + (slice(None), np.array([True, False, True])), True, categorical, ( @@ -546,16 +607,18 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ), (4, 3), ), - # Boolean indexing where output is empty ( + # Boolean indexing where output is empty (np.array([False, False]),), True, normal, (np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),), (2, 3), ), + # 15 ( - (np.array([False, False]),), + # Boolean indexing where output is empty + (np.array([False, False]), slice(1, None)), True, categorical, ( @@ -566,10 +629,107 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ), (2, 3), ), - # Multivariate cases, indexing only supported if it does not affect core dimensions ( - # Indexing dips into core dimension - (np.array([1]), 0), + # Empty-slice + (slice(None), slice(10, None), slice(1, None)), + True, + normal, + ( + np.arange(30).reshape(2, 3, 5), + np.full((1, 5), 1e-6), + ), + (2, 3, 5), + ), + ( + # Multidimensional boolean indexing + (rand_bool_mask((5, 3, 2)),), + True, + normal, + ( + np.arange(30).reshape(5, 3, 2), + 1e-6, + ), + (), + ), + ( + # Multidimensional boolean indexing + (rand_bool_mask((5, 3)),), + True, + normal, + ( + np.arange(30).reshape(5, 3, 2), + 1e-6, + ), + (), + ), + ( + # Multidimensional boolean indexing + (rand_bool_mask((5, 3)), slice(None)), + True, + normal, + ( + np.arange(30).reshape(5, 3, 2), + 1e-6, + ), + (), + ), + # 20 + ( + # Multidimensional boolean indexing + (slice(None), rand_bool_mask((3, 2))), + True, + normal, + ( + np.arange(30).reshape(5, 3, 2), + 1e-6, + ), + (), + ), + ( + # Multidimensional boolean indexing + (rand_bool_mask((5, 3)),), + True, + normal, + ( + np.arange(3).reshape(1, 3, 1), + np.full((5, 1, 2), 1e-6), + ), + (5, 3, 2), + ), + ( + # Multidimensional boolean indexing + ( + np.array([True, False, True, False, False]), + slice(None), + (np.array([True, True])), + ), + True, + normal, + ( + np.arange(30).reshape(5, 3, 2), + 1e-6, + ), + (), + ), + ( + # Multidimensional boolean indexing, + # requires runtime broadcasting of the zeros arrays + ( + np.array([True, False, True, False, False]), # nonzero().shape == (2,) + slice(None), + (np.array([True, False])), # nonzero().shape == (1,) + ), + True, + normal, + ( + np.arange(30).reshape(5, 3, 2), + 1e-6, + ), + (), + ), + ( + 
# Multivariate distribution: indexing dips into core dimension + (1, 0), False, multivariate_normal, ( @@ -578,9 +738,22 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ), (), ), + # 25 ( - (np.array([0, 2]),), - True, + # Multivariate distribution: indexing dips into core dimension + (rand_bool_mask((2, 2)),), + False, + multivariate_normal, + ( + np.array([[-1, 20], [300, -4000]], dtype=config.floatX), + np.eye(2).astype(config.floatX) * 1e-6, + ), + (), + ), + ( + # Multivariate distribution: advanced integer indexing + (np.array([0, 0]),), + False, # Could have duplicates (it has in this case)! multivariate_normal, ( np.array( @@ -592,6 +765,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): (), ), ( + # Multivariate distribution: dummy slice "dips" into core dimension (np.array([True, False, True]), slice(None)), True, multivariate_normal, @@ -603,6 +777,17 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ), (3,), ), + ( + # Multivariate distribution + (0, slice(1, None), rand_bool_mask((4, 3))), + True, + multivariate_normal, + ( + np.arange(4 * 3 * 2).reshape(4, 3, 2).astype(dtype=config.floatX), + np.eye(2) * 1e-6, + ), + (5, 3, 4, 3), + ), ], ) @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise") @@ -650,7 +835,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): res_base = f_base(*arg_values) res_rewritten = f_rewritten(*arg_values) - np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3) + np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3, atol=1e-2) def test_Subtensor_lift_restrictions(): @@ -664,7 +849,7 @@ def test_Subtensor_lift_restrictions(): # the lift z = x - y - fg = FunctionGraph([rng], [z], clone=False) + fg = FunctionGraph([rng], [z], clone=False, features=[ShapeFeature()]) _ = EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner @@ -676,7 +861,7 @@ def test_Subtensor_lift_restrictions(): # We add `x` as an output to make sure that `is_rv_used_in_graph` handles # `"output"` "nodes" correctly. 
- fg = FunctionGraph([rng], [z, x], clone=False) + fg = FunctionGraph([rng], [z, x], clone=False, features=[ShapeFeature()]) EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) assert fg.outputs[0] == z @@ -684,7 +869,7 @@ def test_Subtensor_lift_restrictions(): # The non-`Subtensor` client doesn't depend on the RNG state, so we can # perform the lift - fg = FunctionGraph([rng], [z], clone=False) + fg = FunctionGraph([rng], [z], clone=False, features=[ShapeFeature()]) EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 0723498486..4d2c3fec9e 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -1087,9 +1087,17 @@ def shape_tuple(x, use_bcast=True): assert any( isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at) ) - # This should fail because it would need dynamic broadcasting with pytest.raises(AssertionError): assert np.array_equal([z.eval() for z in b_at], b.shape) + # But fine if we allow_runtime_broadcast + b_at = broadcast_shape( + shape_tuple(x_at, use_bcast=False), + shape_tuple(y_at, use_bcast=False), + arrays_are_shapes=True, + allow_runtime_broadcast=True, + ) + assert np.array_equal([z.eval() for z in b_at], b.shape) + # Or if static bcast is known b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True) assert np.array_equal([z.eval() for z in b_at], b.shape) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 0f60ac4ed5..652be812e9 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -63,6 +63,7 @@ tensor, tensor3, tensor4, + tensor5, vector, ) from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype @@ -2150,6 +2151,12 @@ def fun(x, y): class TestInferShape(utt.InferShapeTester): + @staticmethod + def random_bool_mask(shape, rng=None): + if rng is None: + rng = np.random.default_rng() + return rng.binomial(n=1, p=0.5, size=shape).astype(bool) + def test_IncSubtensor(self): admat = dmatrix() bdmat = dmatrix() @@ -2439,25 +2446,85 @@ def test_AdvancedSubtensor_bool(self): n = dmatrix() n_val = np.arange(6).reshape((2, 3)) - # infer_shape is not implemented, but it should not crash + # Shape inference requires runtime broadcasting between the nonzero() shapes self._compile_and_check( [n], [n[n[:, 0] > 2, n[0, :] > 2]], [n_val], AdvancedSubtensor, - check_topo=False, ) self._compile_and_check( [n], [n[n[:, 0] > 2]], [n_val], AdvancedSubtensor, - check_topo=False, + ) + self._compile_and_check( + [n], + [n[:, np.array([True, False, True])]], + [n_val], + AdvancedSubtensor, + ) + self._compile_and_check( + [n], + [n[np.array([False, False]), 1:]], + [n_val], + AdvancedSubtensor, + ) + self._compile_and_check( + [n], + [n[np.array([True, True]), 0]], + [n_val], + AdvancedSubtensor, + ) + self._compile_and_check( + [n], + [n[self.random_bool_mask(n_val.shape)]], + [n_val], + AdvancedSubtensor, + ) + self._compile_and_check( + [n], + [n[None, self.random_bool_mask(n_val.shape), None]], + [n_val], + AdvancedSubtensor, + ) + self._compile_and_check( + [n], + [n[slice(5, None), self.random_bool_mask(n_val.shape[1])]], + [n_val], + AdvancedSubtensor, ) abs_res = n[~isinf(n)] assert abs_res.type.shape == (None,) + def test_AdvancedSubtensor_bool_mixed(self): + n = tensor5("x", dtype="float64") + shape = (18, 3, 4, 5, 6) 
+ n_val = np.arange(np.prod(shape)).reshape(shape) + self._compile_and_check( + [n], + # Consecutive advanced index + [n[1:, self.random_bool_mask((3, 4)), 0, 1:]], + [n_val], + AdvancedSubtensor, + ) + self._compile_and_check( + [n], + # Non-consecutive advanced index + [n[1:, self.random_bool_mask((3, 4)), 1:, 0]], + [n_val], + AdvancedSubtensor, + ) + self._compile_and_check( + [n], + # Non-consecutive advanced index + [n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]], + [n_val], + AdvancedSubtensor, + ) + @config.change_flags(compute_test_value="raise") def test_basic_shape():
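Usage sketches (not part of the patch): the snippets below only illustrate the behavior the changes above target; the variable names and inputs are illustrative assumptions, while `broadcast_shape(..., allow_runtime_broadcast=...)` and the boolean-index shape inference are taken from the diff itself.

# Sketch 1: `allow_runtime_broadcast` in `broadcast_shape`/`broadcast_shape_iter`.
# With the flag enabled, dimensions that are not statically known to be 1 may
# still broadcast at runtime (mirroring NumPy), instead of failing the
# "Could not broadcast dimensions" assert.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import broadcast_shape

x = pt.dmatrix("x")  # static shape (None, None)
y = pt.dmatrix("y")  # static shape (None, None)
dims = broadcast_shape(x, y, allow_runtime_broadcast=True)
dims_fn = pytensor.function([x, y], list(dims))
print(dims_fn(np.zeros((1, 3)), np.zeros((5, 1))))  # -> [array(5), array(3)]

# Sketch 2: shape inference for boolean advanced indices.
# `AdvancedSubtensor.infer_shape` now expresses the result shape through
# `nonzero`/`sum`, so a graph that only needs the shape should not have to
# execute the indexing itself (this is what the updated `_compile_and_check`
# calls assert by no longer passing `check_topo=False`).
m = pt.dmatrix("m")
masked = m[m[:, 0] > 0]  # boolean advanced indexing
shape_fn = pytensor.function([m], masked.shape)
print(shape_fn(np.array([[1.0, 2.0], [-1.0, 3.0], [4.0, 5.0]])))  # -> [2 2]
# pytensor.dprint(shape_fn)  # the shape graph should not contain AdvancedSubtensor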