Fix advanced indexing in subtensor_rv_lift
Also excludes the following cases:
1. expand_dims via broadcasting
2. multi-dimensional integer indexing (could lead to duplicates, which is inconsistent with the lifted RV graph)
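
As an illustration (not part of the commit), a minimal numpy sketch of why case 2 is excluded: a multi-dimensional integer index can select the same draw more than once, so the indexed output contains duplicated values, whereas a lifted RV built from the indexed parameters would draw every entry independently. The variable names below are hypothetical; case 1 (expand_dims via np.newaxis) is excluded for simplicity, as noted in the code comments of the diff.

import numpy as np

rng = np.random.default_rng(123)
x = rng.normal(size=3)                 # three independent draws (stands in for an RV's output)
idx = np.array([[0, 0], [1, 2]])       # multi-dimensional integer index that repeats position 0

indexed = x[idx]                       # shape (2, 2); indexed[0, 0] and indexed[0, 1] are the *same* draw
lifted = rng.normal(size=(2, 2))       # an RV of the indexed shape would draw four *independent* values

assert indexed[0, 0] == indexed[0, 1]  # duplicates by construction; no such constraint holds for `lifted`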
ricardoV94 committed Jun 7, 2023
1 parent a85ecce commit d34dbbc
Showing 2 changed files with 268 additions and 82 deletions.
170 changes: 109 additions & 61 deletions pytensor/tensor/random/rewriting/basic.py
@@ -1,14 +1,15 @@
from itertools import zip_longest
from itertools import chain

from pytensor.compile import optdb
from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.scalar import integer_types
from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
@@ -18,7 +19,6 @@
Subtensor,
as_index_variable,
get_idx_list,
indexed_result_shape,
)
from pytensor.tensor.type_other import SliceType

@@ -207,98 +207,146 @@ def local_subtensor_rv_lift(fgraph, node):
``mvnormal(mu, cov, size=(2,))[0, 0]``.
"""

st_op = node.op
def is_nd_advanced_idx(idx, dtype):
if isinstance(dtype, str):
return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
else:
return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)

if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
return False
subtensor_op = node.op

old_subtensor = node.outputs[0]
rv = node.inputs[0]
rv_node = rv.owner

if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False

shape_feature = getattr(fgraph, "shape_feature", None)
if not shape_feature:
return None

# Use shape_feature to facilitate inferring final shape.
# Check that neither the RV nor the old Subtensor are in the shape graph.
output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
return None

rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs

# Parse indices
idx_list = getattr(st_op, "idx_list", None)
idx_list = getattr(subtensor_op, "idx_list", None)
if idx_list:
cdata = get_idx_list(node.inputs, idx_list)
idx_vars = get_idx_list(node.inputs, idx_list)
else:
cdata = node.inputs[1:]
st_indices, st_is_bool = zip(
*tuple(
(as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
)
)
idx_vars = node.inputs[1:]
indices = tuple(as_index_variable(idx) for idx in idx_vars)

# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
integer_dtypes = {type.dtype for type in integer_types}
if any(
is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
for idx in indices
):
return False

# Check that indexing does not act on support dims
batched_ndims = rv.ndim - rv_op.ndim_supp
if len(st_indices) > batched_ndims:
# If the last indexes are just dummy `slice(None)` we discard them
st_is_bool = st_is_bool[:batched_ndims]
st_indices, supp_indices = (
st_indices[:batched_ndims],
st_indices[batched_ndims:],
batch_ndims = rv.ndim - rv_op.ndim_supp
# We decompose the boolean indexes, which makes it clear whether they act on support dims or not
non_bool_indices = tuple(
chain.from_iterable(
idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,)
for idx in indices
)
)
if len(non_bool_indices) > batch_ndims:
# If the last indexes are just dummy `slice(None)` we discard them instead of quitting
non_bool_indices, supp_indices = (
non_bool_indices[:batch_ndims],
non_bool_indices[batch_ndims:],
)
for index in supp_indices:
for idx in supp_indices:
if not (
isinstance(index.type, SliceType)
and all(NoneConst.equals(i) for i in index.owner.inputs)
isinstance(idx.type, SliceType)
and all(NoneConst.equals(i) for i in idx.owner.inputs)
):
return False
n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs]

# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(rv, node, fgraph):
return False

# Update the size to reflect the indexed dimensions
# TODO: Could use `ShapeFeature` info. We would need to be sure that
# `node` isn't in the results, though.
# if hasattr(fgraph, "shape_feature"):
# output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
# else:
output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
new_size_ignoring_boolean = (
output_shape_ignoring_bool
if rv_op.ndim_supp == 0
else output_shape_ignoring_bool[: -rv_op.ndim_supp]
)
new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]

# Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
# The `indexed_result_shape` helper does not consider this
if any(st_is_bool):
new_size = tuple(
at_sum(idx) if is_bool else s
for s, is_bool, idx in zip_longest(
new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
)
)
else:
new_size = new_size_ignoring_boolean

# Update the parameters to reflect the indexed dimensions
# Propagate indexing to the parameters' batch dims.
# We try to avoid broadcasting the parameters together (and with size), by only indexing
# non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
# should still correctly broadcast any degenerate parameter dims.
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Apply indexing on the batched dimensions of the parameter
batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
batched_param = shape_padleft(param, batched_param_dims_missing)
batched_st_indices = []
for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
# If we have a degenerate dimension indexing it should always do the job
if batched_param_shape == 1:
batched_st_indices.append(0)
# We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
batch_param = (
shape_padleft(param, batch_param_dims_missing)
if batch_param_dims_missing
else param
)
# Check which dims are actually broadcasted
bcast_batch_param_dims = tuple(
dim
for dim, (param_dim, output_dim) in enumerate(
zip(batch_param.type.shape, rv.type.shape)
)
if (param_dim == 1) and (output_dim != 1)
)
batch_indices = []
curr_dim = 0
for idx in indices:
# Advanced boolean indexing
if is_nd_advanced_idx(idx, "bool"):
# Check if any broadcasted dim overlaps with advanced boolean indexing.
# If not, we use that directly, instead of the more inefficient `nonzero` form
bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
# There's an overlap, we have to decompose the boolean mask as a `nonzero`
if set(bool_dims) & set(bcast_batch_param_dims):
int_indices = list(idx.nonzero())
# Indexing by 0 drops the degenerate dims
for bool_dim in bool_dims:
if bool_dim in bcast_batch_param_dims:
int_indices[bool_dim - curr_dim] = 0
batch_indices.extend(int_indices)
# No overlap, use index as is
else:
batch_indices.append(idx)
curr_dim += len(bool_dims)
# Basic-indexing (slice or integer)
else:
batched_st_indices.append(st_index)
new_dist_params.append(batched_param[tuple(batched_st_indices)])
# Broadcasted dim
if curr_dim in bcast_batch_param_dims:
# Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx.type, SliceType):
batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing
else:
batch_indices.append(0)
# Non-broadcasted dim
else:
# Use index as is
batch_indices.append(idx)
curr_dim += 1

new_dist_params.append(batch_param[tuple(batch_indices)])

# Create new RV
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_rv = new_node.default_output()

if config.compute_test_value != "off":
compute_test_value(new_node)

return [new_rv]
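
For context, a hypothetical before/after sketch of the graphs this rewrite targets (assumed variable names, not taken from the commit or its tests): indexing the draws of a batched RV is replaced by drawing from the indexed batch dimensions of its parameters.

import pytensor.tensor as pt
from pytensor.tensor.random.basic import normal

mu = pt.vector("mu")
sigma = pt.vector("sigma")

before = normal(mu, sigma)[1:]      # index the draws of a batched normal
after = normal(mu[1:], sigma[1:])   # lifted form: index the parameters' batch dims instead

# Both graphs describe the same distribution for the kept entries; the rewrite
# produces the second form so that discarded draws are never generated. The
# excluded cases above (np.newaxis, multi-dimensional integer indices) are the
# ones where this equivalence would break.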