Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer shape of advanced boolean indexing #329

Merged
merged 3 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod
from pytensor.tensor.math import ge, lt
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import switch
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.var import TensorVariable
Expand Down Expand Up @@ -1063,7 +1067,7 @@ def grad(self, inp, cost_grad):
# only valid for matrices
wr_a = fill_diagonal_offset(grad, 0, offset)

offset_abs = at_abs(offset)
offset_abs = pt_abs(offset)
pos_offset_flag = ge(offset, 0)
neg_offset_flag = lt(offset, 0)
min_wh = minimum(width, height)
Expand Down Expand Up @@ -1442,6 +1446,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape."
)
_runtime_broadcast_assert = Assert("Could not broadcast dimensions.")


def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
Expand All @@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False,
allow_runtime_broadcast: bool = False,
) -> Tuple[aes.ScalarVariable, ...]:
r"""Compute the shape resulting from broadcasting arrays.

Expand All @@ -1480,22 +1486,24 @@ def broadcast_shape_iter(
arrays
An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed.
arrays_are_shapes
arrays_are_shapes: bool, default False
Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1``--or simply the integer
``1``.
``1``. This is not relevant if `allow_runtime_broadcast` is True.
allow_runtime_broadcast: bool, default False
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
Whether to allow non-statically known broadcast on the shape computation.

"""
one_at = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)
one = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)

if arrays_are_shapes:
max_dims = max(len(a) for a in arrays)

array_shapes = [
(one_at,) * (max_dims - len(a))
(one,) * (max_dims - len(a))
+ tuple(
one_at
one
if sh == 1 or isinstance(sh, Constant) and sh.value == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a
Expand All @@ -1508,10 +1516,8 @@ def broadcast_shape_iter(
_arrays = tuple(at.as_tensor_variable(a) for a in arrays)

array_shapes = [
(one_at,) * (max_dims - a.ndim)
+ tuple(
one_at if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)
)
(one,) * (max_dims - a.ndim)
+ tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape))
for a in _arrays
]

Expand All @@ -1520,11 +1526,11 @@ def broadcast_shape_iter(
for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
non_bcast_shapes = [shape for shape in dim_shapes if shape != one]

if len(non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension
result_dims.append(one_at)
result_dims.append(one)
elif len(non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension
result_dims.extend(non_bcast_shapes)
Expand Down Expand Up @@ -1554,9 +1560,26 @@ def broadcast_shape_iter(
result_dims.append(first_length)
continue

# Add assert that all remaining shapes are equal
condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
result_dims.append(_broadcast_assert(first_length, condition))
if not allow_runtime_broadcast:
# Add assert that all remaining shapes are equal
condition = pt_all(
[pt_eq(first_length, other) for other in other_lengths]
)
result_dims.append(_broadcast_assert(first_length, condition))
else:
lengths = as_tensor_variable((first_length, *other_lengths))
runtime_broadcastable = pt_eq(lengths, one)
result_dim = pt_abs(
pt_max(switch(runtime_broadcastable, -one, lengths))
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
)
condition = pt_all(
switch(
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
~runtime_broadcastable,
pt_eq(lengths, result_dim),
np.array(True),
)
)
result_dims.append(_runtime_broadcast_assert(result_dim, condition))

return tuple(result_dims)

Expand Down
172 changes: 111 additions & 61 deletions pytensor/tensor/random/rewriting/basic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from itertools import zip_longest
from itertools import chain

from pytensor.compile import optdb
from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scalar import integer_types
from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
Expand All @@ -18,7 +19,6 @@
Subtensor,
as_index_variable,
get_idx_list,
indexed_result_shape,
)
from pytensor.tensor.type_other import SliceType

Expand Down Expand Up @@ -207,98 +207,148 @@ def local_subtensor_rv_lift(fgraph, node):
``mvnormal(mu, cov, size=(2,))[0, 0]``.
"""

st_op = node.op
def is_nd_advanced_idx(idx, dtype):
if isinstance(dtype, str):
return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
else:
return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)

if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
return False
subtensor_op = node.op

old_subtensor = node.outputs[0]
rv = node.inputs[0]
rv_node = rv.owner

if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False

shape_feature = getattr(fgraph, "shape_feature", None)
if not shape_feature:
return None

# Use shape_feature to facilitate inferring final shape.
# Check that neither the RV nor the old Subtensor are in the shape graph.
output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
return None

rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs

# Parse indices
idx_list = getattr(st_op, "idx_list", None)
idx_list = getattr(subtensor_op, "idx_list", None)
if idx_list:
cdata = get_idx_list(node.inputs, idx_list)
idx_vars = get_idx_list(node.inputs, idx_list)
else:
cdata = node.inputs[1:]
st_indices, st_is_bool = zip(
*tuple(
(as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
)
)
idx_vars = node.inputs[1:]
indices = tuple(as_index_variable(idx) for idx in idx_vars)

# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
integer_dtypes = {type.dtype for type in integer_types}
if any(
is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
for idx in indices
):
return False

# Check that indexing does not act on support dims
batched_ndims = rv.ndim - rv_op.ndim_supp
if len(st_indices) > batched_ndims:
# If the last indexes are just dummy `slice(None)` we discard them
st_is_bool = st_is_bool[:batched_ndims]
st_indices, supp_indices = (
st_indices[:batched_ndims],
st_indices[batched_ndims:],
batch_ndims = rv.ndim - rv_op.ndim_supp
# We decompose the boolean indexes, which makes it clear whether they act on support dims or not
non_bool_indices = tuple(
chain.from_iterable(
idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,)
for idx in indices
)
)
if len(non_bool_indices) > batch_ndims:
# If the last indexes are just dummy `slice(None)` we discard them instead of quitting
non_bool_indices, supp_indices = (
non_bool_indices[:batch_ndims],
non_bool_indices[batch_ndims:],
)
for index in supp_indices:
for idx in supp_indices:
if not (
isinstance(index.type, SliceType)
and all(NoneConst.equals(i) for i in index.owner.inputs)
isinstance(idx.type, SliceType)
and all(NoneConst.equals(i) for i in idx.owner.inputs)
):
return False
n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs]

# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(rv, node, fgraph):
return False

# Update the size to reflect the indexed dimensions
# TODO: Could use `ShapeFeature` info. We would need to be sure that
# `node` isn't in the results, though.
# if hasattr(fgraph, "shape_feature"):
# output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
# else:
output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
new_size_ignoring_boolean = (
output_shape_ignoring_bool
if rv_op.ndim_supp == 0
else output_shape_ignoring_bool[: -rv_op.ndim_supp]
)
new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]

# Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
# The `indexed_result_shape` helper does not consider this
if any(st_is_bool):
new_size = tuple(
at_sum(idx) if is_bool else s
for s, is_bool, idx in zip_longest(
new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
)
)
else:
new_size = new_size_ignoring_boolean

# Update the parameters to reflect the indexed dimensions
# Propagate indexing to the parameters' batch dims.
# We try to avoid broadcasting the parameters together (and with size), by only indexing
# non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
# should still correctly broadcast any degenerate parameter dims.
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Apply indexing on the batched dimensions of the parameter
batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
batched_param = shape_padleft(param, batched_param_dims_missing)
batched_st_indices = []
for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
# If we have a degenerate dimension indexing it should always do the job
if batched_param_shape == 1:
batched_st_indices.append(0)
# We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
batch_param = (
shape_padleft(param, batch_param_dims_missing)
if batch_param_dims_missing
else param
)
# Check which dims are actually broadcasted
bcast_batch_param_dims = tuple(
dim
for dim, (param_dim, output_dim) in enumerate(
zip(batch_param.type.shape, rv.type.shape)
)
if (param_dim == 1) and (output_dim != 1)
)
batch_indices = []
curr_dim = 0
for idx in indices:
# Advanced boolean indexing
if is_nd_advanced_idx(idx, "bool"):
# Check if any broadcasted dim overlaps with advanced boolean indexing.
# If not, we use that directly, instead of the more inefficient `nonzero` form
bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
# There's an overlap, we have to decompose the boolean mask as a `nonzero`
if set(bool_dims) & set(bcast_batch_param_dims):
int_indices = list(idx.nonzero())
# Indexing by 0 drops the degenerate dims
for bool_dim in bool_dims:
if bool_dim in bcast_batch_param_dims:
int_indices[bool_dim - curr_dim] = 0
batch_indices.extend(int_indices)
# No overlap, use index as is
else:
batch_indices.append(idx)
curr_dim += len(bool_dims)
# Basic-indexing (slice or integer)
else:
batched_st_indices.append(st_index)
new_dist_params.append(batched_param[tuple(batched_st_indices)])
# Broadcasted dim
if curr_dim in bcast_batch_param_dims:
# Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx.type, SliceType):
batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing
else:
batch_indices.append(0)
# Non-broadcasted dim
else:
# Use index as is
batch_indices.append(idx)
curr_dim += 1

new_dist_params.append(batch_param[tuple(batch_indices)])

# Create new RV
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_rv = new_node.default_output()

if config.compute_test_value != "off":
compute_test_value(new_node)
copy_stack_trace(rv, new_rv)

return [new_rv]
12 changes: 0 additions & 12 deletions pytensor/tensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,6 @@ class ShapeFeature(Feature):

Notes
-----
Right now there is only the ConvOp that could really take
advantage of this shape inference, but it is worth it even
just for the ConvOp. All that's necessary to do shape
inference is 1) to mark shared inputs as having a particular
shape, either via a .tag or some similar hacking; and 2) to
add an optional In() argument to promise that inputs will
have a certain shape (or even to have certain shapes in
certain dimensions).

We can't automatically infer the shape of shared variables as they can
change of shape during the execution by default.

To use this shape information in rewrites, use the
``shape_of`` dictionary.

Expand Down
Loading