diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 005e1d55fa..f4c4403bc7 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -7,7 +7,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
-from pytensor.graph.replace import _vectorize_node, vectorize_graph
+from pytensor.graph.replace import _vectorize_node
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
@@ -20,6 +20,7 @@
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
+    compute_batch_shape,
     explicit_expand_dims,
     normalize_size_param,
 )
@@ -130,6 +131,9 @@ def __str__(self):
         props_str = ", ".join(f"{getattr(self, prop)}" for prop in self.__props__[1:])
         return f"{self.name}_rv{{{props_str}}}"
 
+    def batch_ndim(self, node):
+        return node.default_output().type.ndim - self.ndim_supp
+
     def _infer_shape(
         self,
         size: TensorVariable,
@@ -298,8 +302,12 @@ def make_node(self, rng, size, dtype, *dist_params):
             dtype_idx = constant(dtype, dtype="int64")
             dtype = all_dtypes[dtype_idx.data]
 
-        outtype = TensorType(dtype=dtype, shape=static_shape)
-        out_var = outtype()
+        out_var = TensorType(dtype=dtype, shape=static_shape)()
+
+        # Add expand_dims to align batch dimensions
+        dist_params = explicit_expand_dims(
+            dist_params, self.ndims_params, size_length=size.type.shape[0]
+        )
         inputs = (rng, size, dtype_idx, *dist_params)
         outputs = (rng.type(), out_var)
 
@@ -390,28 +398,22 @@ def vectorize_random_variable(
     # If size was provided originally and a new size hasn't been provided,
     # We extend it to accommodate the new input batch dimensions.
     # Otherwise, we assume the new size already has the right values
 
-    # Need to make parameters implicit broadcasting explicit
-    original_dist_params = node.inputs[3:]
+    old_dist_params = node.inputs[3:]
     old_size = node.inputs[1]
     len_old_size = get_vector_length(old_size)
-    original_expanded_dist_params = explicit_expand_dims(
-        original_dist_params, op.ndims_params, len_old_size
-    )
-    # We call vectorize_graph to automatically handle any new explicit expand_dims
-    dist_params = vectorize_graph(
-        original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
-    )
+    new_dist_params = explicit_expand_dims(dist_params, op.ndims_params)
 
     if len_old_size and equal_computations([old_size], [size]):
         # If the original RV had a size variable and a new one has not been provided,
         # we need to define a new size as the concatenation of the original size dimensions
         # and the novel ones implied by new broadcasted batched parameters dimensions.
-        # We use the first broadcasted batch dimension for reference.
-        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
-        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
-        if new_param_ndim >= 0:
-            new_size_dims = bcasted_param.shape[:new_param_ndim]
+        new_ndim = new_dist_params[0].type.ndim - old_dist_params[0].type.ndim
+        if new_ndim >= 0:
+            new_size = compute_batch_shape(
+                new_dist_params, ndims_params=op.ndims_params
+            )
+            new_size_dims = new_size[:new_ndim]
             size = concatenate([new_size_dims, size])
 
-    return op.make_node(rng, size, dtype, *dist_params)
+    return op.make_node(rng, size, dtype, *new_dist_params)
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 5d74a16e20..dfaf2fc2e7 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -11,7 +11,7 @@
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast, constant
-from pytensor.tensor.extra_ops import broadcast_to
+from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
 
 def explicit_expand_dims(
     params: Sequence[TensorVariable],
-    ndim_params: tuple[int],
+    ndim_params: Sequence[int],
     size_length: int = 0,
 ) -> list[TensorVariable]:
     """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
@@ -149,6 +149,15 @@
     return new_params
 
 
+def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
+    params = explicit_expand_dims(params, ndims_params)
+    # Index each core dimension at 0 so only batch dimensions remain,
+    # then broadcast the trimmed parameters to find the batch shape.
+    batch_params = [
+        param[(..., *(0,) * core_ndim)]
+        for param, core_ndim in zip(params, ndims_params)
+    ]
+    return broadcast_arrays(*batch_params)[0].shape
+
+
 def normalize_size_param(
     size: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index f801ab731e..43914afd55 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -4,7 +4,7 @@
 import pytensor.tensor as pt
 from pytensor import config, function
 from pytensor.gradient import NullTypeGradError, grad
-from pytensor.graph.replace import vectorize_node
+from pytensor.graph.replace import vectorize_graph
 from pytensor.raise_op import Assert
 from pytensor.tensor.math import eq
 from pytensor.tensor.random import normal
@@ -241,33 +241,27 @@ def test_multivariate_rv_infer_static_shape():
     assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)
 
 
-def test_vectorize_node():
+def test_vectorize():
     vec = tensor(shape=(None,))
     mat = tensor(shape=(None, None))
 
     # Test without size
-    node = normal(vec).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = mat  # mu
-    vect_node = vectorize_node(node, *new_inputs)
+    out = normal(vec)
+    vect_node = vectorize_graph(out, {vec: mat}).owner
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
 
     # Test with size, new size provided
-    node = normal(vec, size=(3,)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[1] = (2, 3)  # size
-    new_inputs[3] = mat  # mu
-    vect_node = vectorize_node(node, *new_inputs)
+    size = pt.as_tensor(np.array((3,), dtype="int64"))
+    out = normal(vec, size=size)
+    vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner
     assert vect_node.op is normal
     assert tuple(vect_node.inputs[1].eval()) == (2, 3)
     assert vect_node.inputs[3] is mat
 
     # Test with size, new size not provided
-    node = normal(vec, size=(3,)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = mat  # mu
-    vect_node = vectorize_node(node, *new_inputs)
+    out = normal(vec, size=(3,))
+    vect_node = vectorize_graph(out, {vec: mat}).owner
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
     assert tuple(
@@ -275,28 +269,31 @@ def test_vectorize_node():
     ) == (2, 3)
 
     # Test parameter broadcasting
-    node = normal(vec).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
-    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
-    vect_node = vectorize_node(node, *new_inputs)
+    mu = vec
+    sigma = pt.as_tensor(np.array(1.0))
+    out = normal(mu, sigma)
+    new_mu = tensor("mu", shape=(10, 5))
+    new_sigma = tensor("sigma", shape=(10,))
+    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
     assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 5)
 
     # Test parameter broadcasting with non-expanding size
-    node = normal(vec, size=(5,)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
-    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
-    vect_node = vectorize_node(node, *new_inputs)
+    mu = vec
+    sigma = pt.as_tensor(np.array(1.0))
+    out = normal(mu, sigma, size=(5,))
+    new_mu = tensor("mu", shape=(10, 5))
+    new_sigma = tensor("sigma", shape=(10,))
+    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
     assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 5)
 
     # Test parameter broadcasting with expanding size
-    node = normal(vec, size=(2, 5)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
-    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
-    vect_node = vectorize_node(node, *new_inputs)
+    mu = vec
+    sigma = pt.as_tensor(np.array(1.0))
+    out = normal(mu, sigma, size=(2, 5))
+    new_mu = tensor("mu", shape=(1, 5))
+    new_sigma = tensor("sigma", shape=(10,))
+    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
    assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 2, 5)
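
A minimal sketch of the first behaviour change, assuming the patch is applied: `make_node` now calls `explicit_expand_dims` on the distribution parameters, so implicit broadcasting between parameters becomes visible on the node inputs themselves. The variable names here are illustrative, and the printed shapes are what the new code should produce:

```python
import pytensor.tensor as pt

vec = pt.tensor("vec", shape=(None,))
rv = pt.random.normal(vec, 1.0)

# Node inputs are (rng, size, dtype_idx, *dist_params), so mu is
# inputs[3] and sigma is inputs[4]. The scalar sigma is left-padded
# with expand_dims to match mu's batch ndim.
print(rv.owner.inputs[3].type.shape)  # (None,) -- mu, unchanged
print(rv.owner.inputs[4].type.shape)  # (1,)    -- sigma, explicitly expanded
```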
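And a sketch of the new `compute_batch_shape` helper that the vectorization path uses to extend an explicit `size`: it aligns the parameters with `explicit_expand_dims`, drops their core dimensions, and returns the symbolic broadcasted batch shape. The shapes and variable names below are illustrative:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.random.utils import compute_batch_shape, explicit_expand_dims

mu = pt.tensor("mu", shape=(2, 5))
sigma = pt.tensor("sigma", shape=(5,))

# For a scalar RV like normal, ndims_params is (0, 0): every parameter
# dimension is a batch dimension, so sigma is left-padded to (1, 5).
new_mu, new_sigma = explicit_expand_dims([mu, sigma], ndim_params=(0, 0))
assert new_sigma.type.shape == (1, 5)

# Symbolic batch shape implied by broadcasting the two parameters.
batch_shape = compute_batch_shape([mu, sigma], ndims_params=(0, 0))
print(
    batch_shape.eval(
        {mu: np.zeros((2, 5), dtype=mu.dtype), sigma: np.ones(5, dtype=sigma.dtype)}
    )
)  # [2 5]
```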