Align batch_ndims of RandomVariable inputs in make_node
Also fixes a bug in vectorize of RandomVariable, which wrongly used the first parameter to infer the new size dims, even though that parameter was only expanded with new dims rather than broadcasted against the others.
ricardoV94 committed Apr 23, 2024
1 parent 16a4f3b commit bae694d
Showing 3 changed files with 58 additions and 50 deletions.
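
For orientation, a minimal sketch of the behavior being fixed, mirroring the last case of the updated test_vectorize below (variable names are illustrative, and the asserted shapes assume the post-commit behavior): when a RandomVariable is vectorized and only some parameters gain new batch dimensions, the new size must come from the broadcasted batch shape of all parameters, not from the first parameter alone.

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.random import normal

mu = pt.tensor("mu", shape=(None,))
sigma = pt.as_tensor(np.array(1.0))
out = normal(mu, sigma, size=(2, 5))

# Swap in batched parameters; only sigma carries the new batch dimension (10),
# so inferring the extra size dims from mu alone would give the wrong result.
new_mu = pt.tensor("mu", shape=(1, 5))
new_sigma = pt.tensor("sigma", shape=(10,))
vect_out = vectorize_graph(out, {mu: new_mu, sigma: new_sigma})
assert vect_out.type.shape == (10, 2, 5)
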
38 changes: 20 additions & 18 deletions pytensor/tensor/random/op.py
@@ -7,7 +7,7 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.op import Op
- from pytensor.graph.replace import _vectorize_node, vectorize_graph
+ from pytensor.graph.replace import _vectorize_node
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import (
@@ -20,6 +20,7 @@
)
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import (
+ compute_batch_shape,
explicit_expand_dims,
normalize_size_param,
)
@@ -130,6 +131,9 @@ def __str__(self):
props_str = ", ".join(f"{getattr(self, prop)}" for prop in self.__props__[1:])
return f"{self.name}_rv{{{props_str}}}"

+ def batch_ndim(self, node):
+     return node.default_output().type.ndim - self.ndim_supp
+
def _infer_shape(
self,
size: TensorVariable,
Expand Down Expand Up @@ -298,8 +302,12 @@ def make_node(self, rng, size, dtype, *dist_params):
dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]

- outtype = TensorType(dtype=dtype, shape=static_shape)
- out_var = outtype()
+ out_var = TensorType(dtype=dtype, shape=static_shape)()

+ # Add expand_dims to align batch dimensions
+ dist_params = explicit_expand_dims(
+     dist_params, self.ndims_params, size_length=size.type.shape[0]
+ )
inputs = (rng, size, dtype_idx, *dist_params)
outputs = (rng.type(), out_var)

@@ -390,28 +398,22 @@ def vectorize_random_variable(
# We extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values

- # Need to make parameters implicit broadcasting explicit
- original_dist_params = node.inputs[3:]
+ old_dist_params = node.inputs[3:]
old_size = node.inputs[1]
len_old_size = get_vector_length(old_size)

- original_expanded_dist_params = explicit_expand_dims(
-     original_dist_params, op.ndims_params, len_old_size
- )
- # We call vectorize_graph to automatically handle any new explicit expand_dims
- dist_params = vectorize_graph(
-     original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
- )
+ new_dist_params = explicit_expand_dims(dist_params, op.ndims_params)

if len_old_size and equal_computations([old_size], [size]):
# If the original RV had a size variable and a new one has not been provided,
# we need to define a new size as the concatenation of the original size dimensions
# and the novel ones implied by new broadcasted batched parameters dimensions.
# We use the first broadcasted batch dimension for reference.
- bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
- new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
- if new_param_ndim >= 0:
-     new_size_dims = bcasted_param.shape[:new_param_ndim]
+ new_ndim = new_dist_params[0].type.ndim - old_dist_params[0].type.ndim
+ if new_ndim >= 0:
+     new_size = compute_batch_shape(
+         new_dist_params, ndims_params=op.ndims_params
+     )
+     new_size_dims = new_size[:new_ndim]
size = concatenate([new_size_dims, size])

- return op.make_node(rng, size, dtype, *dist_params)
+ return op.make_node(rng, size, dtype, *new_dist_params)
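
As a rough illustration of the make_node change (a sketch assuming the post-commit behavior; names are illustrative): every distribution parameter stored on the node is now expand_dims'ed up to the batch dimensionality implied by size and the other parameters, and the new batch_ndim helper reports that value.

import pytensor.tensor as pt
from pytensor.tensor.random import normal

mu = pt.tensor("mu", shape=(None,))  # 1 batch dim
sigma = pt.tensor("sigma", shape=(None, None))  # 2 batch dims
node = normal(mu, sigma, size=(7, 3, 2)).owner

# All parameter inputs are aligned to the same batch ndim (3 here), so
# downstream shape logic no longer needs to re-derive implicit broadcasting.
assert node.op.batch_ndim(node) == 3
assert all(param.type.ndim == 3 for param in node.inputs[3:])
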
13 changes: 11 additions & 2 deletions pytensor/tensor/random/utils.py
@@ -11,7 +11,7 @@
from pytensor.scalar import ScalarVariable
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant
- from pytensor.tensor.extra_ops import broadcast_to
+ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from pytensor.tensor.math import maximum
from pytensor.tensor.shape import shape_padleft, specify_shape
from pytensor.tensor.type import int_dtypes
@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):

def explicit_expand_dims(
params: Sequence[TensorVariable],
- ndim_params: tuple[int],
+ ndim_params: Sequence[int],
size_length: int = 0,
) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
@@ -149,6 +149,15 @@
return new_params


+ def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
+     params = explicit_expand_dims(params, ndims_params)
+     batch_params = [
+         param[..., *[(0,) for _ in range(core_ndim)]]
+         for param, core_ndim in zip(params, ndims_params)
+     ]
+     return broadcast_arrays(*batch_params)[0].shape
+

def normalize_size_param(
size: int | np.ndarray | Variable | Sequence | None,
) -> Variable:
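
A possible usage sketch for the new compute_batch_shape helper (shapes and names are illustrative): it returns the symbolic shape obtained by broadcasting the batch dimensions of all parameters, which vectorize_random_variable above uses to extend size.

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.random.utils import compute_batch_shape

mu = pt.tensor("mu", shape=(4, 1), dtype="float64")
sigma = pt.tensor("sigma", shape=(5,), dtype="float64")

# A normal RV has scalar-core parameters (ndims_params=[0, 0]), so the batch
# shape is just the broadcast of the two parameter shapes: (4, 5).
batch_shape = compute_batch_shape([mu, sigma], ndims_params=[0, 0])
print(batch_shape.eval({mu: np.zeros((4, 1)), sigma: np.ones(5)}))  # [4 5]
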
57 changes: 27 additions & 30 deletions tests/tensor/random/test_op.py
@@ -4,7 +4,7 @@
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad
- from pytensor.graph.replace import vectorize_node
+ from pytensor.graph.replace import vectorize_graph
from pytensor.raise_op import Assert
from pytensor.tensor.math import eq
from pytensor.tensor.random import normal
@@ -241,62 +241,59 @@ def test_multivariate_rv_infer_static_shape():
assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)


- def test_vectorize_node():
+ def test_vectorize():
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))

# Test without size
- node = normal(vec).owner
- new_inputs = node.inputs.copy()
- new_inputs[3] = mat # mu
- vect_node = vectorize_node(node, *new_inputs)
+ out = normal(vec)
+ vect_node = vectorize_graph(out, {vec: mat}).owner
assert vect_node.op is normal
assert vect_node.inputs[3] is mat

# Test with size, new size provided
- node = normal(vec, size=(3,)).owner
- new_inputs = node.inputs.copy()
- new_inputs[1] = (2, 3) # size
- new_inputs[3] = mat # mu
- vect_node = vectorize_node(node, *new_inputs)
+ size = pt.as_tensor(np.array((3,), dtype="int64"))
+ out = normal(vec, size=size)
+ vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner
assert vect_node.op is normal
assert tuple(vect_node.inputs[1].eval()) == (2, 3)
assert vect_node.inputs[3] is mat

# Test with size, new size not provided
- node = normal(vec, size=(3,)).owner
- new_inputs = node.inputs.copy()
- new_inputs[3] = mat # mu
- vect_node = vectorize_node(node, *new_inputs)
+ out = normal(vec, size=(3,))
+ vect_node = vectorize_graph(out, {vec: mat}).owner
assert vect_node.op is normal
assert vect_node.inputs[3] is mat
assert tuple(
vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
) == (2, 3)

# Test parameter broadcasting
- node = normal(vec).owner
- new_inputs = node.inputs.copy()
- new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
- new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
- vect_node = vectorize_node(node, *new_inputs)
+ mu = vec
+ sigma = pt.as_tensor(np.array(1.0))
+ out = normal(mu, sigma)
+ new_mu = tensor("mu", shape=(10, 5))
+ new_sigma = tensor("sigma", shape=(10,))
+ vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5)

# Test parameter broadcasting with non-expanding size
- node = normal(vec, size=(5,)).owner
- new_inputs = node.inputs.copy()
- new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
- new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
- vect_node = vectorize_node(node, *new_inputs)
+ mu = vec
+ sigma = pt.as_tensor(np.array(1.0))
+ out = normal(mu, sigma, size=(5,))
+ new_mu = tensor("mu", shape=(10, 5))
+ new_sigma = tensor("sigma", shape=(10,))
+ vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5)

# Test parameter broadcasting with expanding size
- node = normal(vec, size=(2, 5)).owner
- new_inputs = node.inputs.copy()
- new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
- new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
- vect_node = vectorize_node(node, *new_inputs)
+ mu = vec
+ sigma = pt.as_tensor(np.array(1.0))
+ out = normal(mu, sigma, size=(2, 5))
+ new_mu = tensor("mu", shape=(1, 5))
+ new_sigma = tensor("sigma", shape=(10,))
+ vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 2, 5)
