diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py new file mode 100644 index 0000000000..98d893a28b --- /dev/null +++ b/pytensor/tensor/blockwise.py @@ -0,0 +1,467 @@ +import re +from collections.abc import Sequence +from functools import singledispatch +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +import numpy as np + +from pytensor import config +from pytensor.gradient import DisconnectedType +from pytensor.graph.basic import Apply, Constant +from pytensor.graph.op import Op +from pytensor.graph.null_type import NullType +from pytensor.tensor.var import TensorVariable, as_tensor_variable +from pytensor.tensor.shape import shape_padleft +from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor + +from pytensor.tensor.elemwise import Elemwise, DimShuffle + +# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad) + +# TODO: Add github link +# Copied verbatim from numpy.lib.function_base +_DIMENSION_NAME = r"\w+" +_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) +_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" +_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) +_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) + + +def safe_signature( + core_inputs: Sequence[TensorVariable], + core_outputs: Sequence[TensorVariable], +) -> str: + def operand_sig(operand: TensorVariable, prefix: str) -> str: + operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim)) + return f"({operands})" + + inputs_sig = ",".join( + operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs) + ) + outputs_sig = ",".join( + operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs) + ) + return f"{inputs_sig}->{outputs_sig}" + + +@singledispatch +def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: + if hasattr(op, "gufunc_signature"): + signature = op.gufunc_signature + else: + # TODO: This is pretty bad for shape inference and merge optimization! + # Should get better as we add signatures to our Ops + signature = safe_signature(node.inputs, node.outputs) + return Blockwise(op, signature=signature).make_node(*bached_inputs) + + +def vectorize_node(node: Apply, *batched_inputs) -> Apply: + """Returns vectorized version of node with new batched inputs.""" + + # Special cases for most common `Op`s that don't really need to be "vectorized" + # TODO: Other simple cases include Reshape, Alloc, ? + + op = node.op + + if isinstance(op, Elemwise): + return op.make_node(*batched_inputs) + + if isinstance(op, Blockwise): + return op.make_node(*batched_inputs) + + if isinstance(op, DimShuffle): + [x] = batched_inputs + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + if not batched_ndims: + return node + input_broadcastable = ( + x.type.broadcastable[:batched_ndims] + op.input_broadcastable + ) + new_order = list(range(batched_ndims)) + [ + "x" if o == "x" else o + batched_ndims for o in op.new_order + ] + return DimShuffle(input_broadcastable, new_order).make_node(x) + + from pytensor.tensor.random.op import RandomVariable + + if isinstance(op, RandomVariable): + return op.make_node(*batched_inputs) + + # Fallback to dispatch implementation so users can override behavior + return _vectorize_node(op, node, *batched_inputs) + + +def _parse_gufunc_signature(signature): + """ + Parse string signatures for a generalized universal function. + + Arguments + --------- + signature : string + Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)`` + for ``np.matmul``. + + Returns + ------- + Tuple of input and output core dimensions parsed from the signature, each + of the form List[Tuple[str, ...]]. + """ + signature = re.sub(r"\s+", "", signature) + + if not re.match(_SIGNATURE, signature): + raise ValueError(f"not a valid gufunc signature: {signature}") + return tuple( + [ + tuple(re.findall(_DIMENSION_NAME, arg)) + for arg in re.findall(_ARGUMENT, arg_list) + ] + for arg_list in signature.split("->") + ) + + +class Blockwise(Op): + """Generalizes a core `Op` to work with batched dimensions. + + TODO: Add rewrites for Blockwise of Dimshuffle and CAReduce + TODO: Dispatch JAX (should be easy with the vectorize macro) + TODO: Dispatch Numba + TODO: C implementation? + TODO: Fuse Blockwise? + """ + + __props__ = ("core_op", "signature") + + def __init__( + self, + core_op: Op, + signature: Optional[str] = None, + name: Optional[str] = None, + **kwargs, + ): + """ + + Parameters + ---------- + core_op + An instance of a subclass of `Op` which works on the core case. + signature + Generalized universal function signature, + e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication + + """ + # Some Ops are implemented in a way that they already batch natively + # TODO: Consider refactoring them into a shared class + from pytensor.tensor.random.op import RandomVariable + + natively_batched_ops = (Elemwise, RandomVariable) + + if isinstance(core_op, type(self)): + raise TypeError("core_op cannot be a Blockwise") + if isinstance(core_op, natively_batched_ops): + raise TypeError(f"{core_op} already works as a Blockwise") + + if signature is None: + signature = getattr(core_op, "gufunc_signature", None) + if signature is None: + raise ValueError( + f"Signature not provided nor found in core_op {core_op}" + ) + + self.core_op = core_op + self.signature = signature + self.name = name + self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self._gufunc = None + super().__init__(**kwargs) + + def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: + core_input_types = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if inp.type.ndim < len(sig): + raise ValueError( + f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" + ) + # ndim_supp = 0 case + if not sig: + core_shape = () + else: + core_shape = inp.type.shape[-len(sig) :] + core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape)) + + core_node = self.core_op.make_node(*core_input_types) + + if len(core_node.outputs) != len(self.outputs_sig): + raise ValueError( + f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}" + ) + for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)): + if core_out.type.ndim != len(sig): + raise ValueError( + f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" + ) + + return core_node + + def make_node(self, *inputs): + inputs = [as_tensor_variable(i) for i in inputs] + + core_node = self._create_dummy_core_node(inputs) + + batch_ndims = max( + inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig) + ) + + # Don't pollute the graph with useless BlockWise + # TODO: Do we want to do this? Or leave it as a Blockwise and later have a rewrite that removes useless casse + # A reason to not eagerly avoid Blockwise is that we could make all rewrites track the Blockwise version, + # instead of having to track both or only the more restricted core case. + if not batch_ndims: + return self.core_op.make_node(*inputs) + + batched_inputs = [] + batch_shapes = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + # Append missing dims to the left + missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig)) + if missing_batch_ndims: + inp = shape_padleft(inp, missing_batch_ndims) + batched_inputs.append(inp) + + if not sig: + batch_shapes.append(inp.type.shape) + else: + batch_shapes.append(inp.type.shape[: -len(sig)]) + + def get_most_specialized_batch_shape( + dims: Sequence[Union[None, int]] + ) -> Union[None, int]: + dims_set = set(dims) + # All dims are the same + if len(dims_set) == 1: + return tuple(dims_set)[0] + + # Only valid indeterminate case + if dims_set == {None, 1}: + return None + + dims_set.discard(1) + dims_set.discard(None) + if len(dims_set) > 1: + raise ValueError + return tuple(dims_set)[0] + + try: + batch_shape = tuple( + [ + get_most_specialized_batch_shape(batch_dims) + for batch_dims in zip(*batch_shapes) + ] + ) + except ValueError: + raise ValueError( + f"Incompatible Blockwise batch input shapes {[inp.type.shape for inp in inputs]}" + ) + + batched_outputs = [ + tensor(dtype=core_out.type.dtype, shape=batch_shape + core_out.type.shape) + for core_out in core_node.outputs + ] + + return Apply(self, batched_inputs, batched_outputs) + + def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int: + return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0])) + + def infer_shape( + self, fgraph, node, input_shapes + ) -> List[Tuple[TensorVariable, ...]]: + from pytensor.tensor import broadcast_shape + from pytensor.tensor.shape import Shape_i + + batch_ndims = self._batch_ndim_from_outputs(node.outputs) + core_dims: Dict[str, Any] = {} + batch_shapes = [] + for input_shape, sig in zip(input_shapes, self.inputs_sig): + batch_shapes.append(input_shape[:batch_ndims]) + core_shape = input_shape[batch_ndims:] + + for core_dim, dim_name in zip(core_shape, sig): + prev_core_dim = core_dims.get(core_dim) + if prev_core_dim is None: + core_dims[dim_name] = core_dim + # Prefer constants + elif not isinstance(prev_core_dim, Constant): + core_dims[dim_name] = core_dim + + batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) + + out_shapes = [] + for output, sig in zip(node.outputs, self.outputs_sig): + core_out_shape = [] + for i, dim_name in enumerate(sig): + # The output dim is the same as another input dim + if dim_name in core_dims: + core_out_shape.append(core_dims[dim_name]) + else: + # TODO: We could try to make use of infer_shape of core_op + core_out_shape.append(Shape_i(batch_ndims + i)(output)) + out_shapes.append((*batch_shape, *core_out_shape)) + + return out_shapes + + def connection_pattern(self, node): + if hasattr(self.core_op, "connection_pattern"): + return self.core_op.connection_pattern(node) + + return [[True for _ in node.outputs] for _ in node.inputs] + + def _bgrad(self, inputs, outputs, ograds): + # Grad, with respect to broadcasted versions of inputs + + def as_core(t, core_t): + # Inputs could be NullType or DisconnectedType + if isinstance(t.type, (NullType, DisconnectedType)): + return t + return core_t.type() + + with config.change_flags(compute_test_value="off"): + safe_inputs = [ + tensor(dtype=inp.type.dtype, shape=(None,) * len(sig)) + for inp, sig in zip(inputs, self.inputs_sig) + ] + core_node = self._create_dummy_core_node(safe_inputs) + + core_inputs = [ + as_core(inp, core_inp) + for inp, core_inp in zip(inputs, core_node.inputs) + ] + core_ograds = [ + as_core(ograd, core_ograd) + for ograd, core_ograd in zip(ograds, core_node.outputs) + ] + core_outputs = core_node.outputs + + core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) + + batch_ndims = self._batch_ndim_from_outputs(outputs) + + def transform(var): + # From a graph of ScalarOps, make a graph of Broadcast ops. + if isinstance(var.type, (NullType, DisconnectedType)): + return var + if var in core_inputs: + return inputs[core_inputs.index(var)] + if var in core_outputs: + return outputs[core_outputs.index(var)] + if var in core_ograds: + return ograds[core_ograds.index(var)] + + node = var.owner + + # The gradient contains a constant, which may be responsible for broadcasting + if node is None: + if batch_ndims: + var = shape_padleft(var, batch_ndims) + return var + + batched_inputs = [transform(inp) for inp in node.inputs] + batched_node = vectorize_node(node, *batched_inputs) + batched_var = batched_node.outputs[var.owner.outputs.index(var)] + + return batched_var + + ret = [] + for core_igrad, ipt in zip(core_igrads, inputs): + # Undefined gradient + if core_igrad is None: + ret.append(None) + else: + ret.append(transform(core_igrad)) + + return ret + + def L_op(self, inputs, outs, ograds): + from pytensor.tensor.math import sum as pt_sum + + # Compute grad with respect to broadcasted input + rval = self._bgrad(inputs, outs, ograds) + + # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable + # to the gradient.grad method when the outputs have + # some integer and some floating point outputs + if any(out.type.dtype not in continuous_dtypes for out in outs): + # For integer output, return value may only be zero or undefined + # We don't bother with trying to check that the scalar ops + # correctly returned something that evaluates to 0, we just make + # the return value obviously zero so that gradient.grad can tell + # this op did the right thing. + new_rval = [] + for elem, inp in zip(rval, inputs): + if isinstance(elem.type, (NullType, DisconnectedType)): + new_rval.append(elem) + else: + elem = inp.zeros_like() + if str(elem.type.dtype) not in continuous_dtypes: + elem = elem.astype(config.floatX) + assert str(elem.type.dtype) not in discrete_dtypes + new_rval.append(elem) + return new_rval + + # Sum out the broadcasted dimensions + batch_ndims = self._batch_ndim_from_outputs(outs) + batch_shape = outs[0].type.shape[:batch_ndims] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if isinstance(rval[i].type, (NullType, DisconnectedType)): + continue + + assert inp.type.ndim == batch_ndims + len(sig) + + to_sum = [ + j + for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape)) + if inp_s == 1 and out_s != 1 + ] + if to_sum: + rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True) + + return rval + + def _create_gufunc(self, node): + # TODO: Use `impl` numpy versions just like Elemwise and ScalarOps do + + n_outs = len(self.outputs_sig) + core_node = self._create_dummy_core_node(node.inputs) + + def core_func(*inner_inputs): + inner_outputs = [[None] for _ in range(n_outs)] + + inner_inputs = [np.asarray(inp) for inp in inner_inputs] + self.core_op.perform(core_node, inner_inputs, inner_outputs) + + if len(inner_outputs) == 1: + return inner_outputs[0][0] + else: + return tuple(r[0] for r in inner_outputs) + + self._gufunc = np.vectorize(core_func, signature=self.signature) + return self._gufunc + + def perform(self, node, inputs, outputs): + gufunc = self._gufunc + + if gufunc is None: + gufunc = self._create_gufunc(node) + + res = gufunc(*inputs) + + if isinstance(res, tuple): + for i, out in enumerate(outputs): + outputs[i][0] = res[i] + else: + outputs[0][0] = res + + def __str__(self): + if self.name is None: + return f"{type(self).__name__}{{{self.core_op}, {self.signature}}}" + else: + return self.name diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index aa77326638..b4d1469ef8 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -220,7 +220,7 @@ def __str__(self): else: return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order) - def perform(self, node, inp, out, params): + def perform(self, node, inp, out, params=None): (res,) = inp (storage,) = out diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 5514cd6d1c..ff680a6e1f 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -15,6 +15,7 @@ class MatrixPinv(Op): __props__ = ("hermitian",) + gufunc_signature = "(m,n)->(n,m)" def __init__(self, hermitian): self.hermitian = hermitian @@ -81,6 +82,9 @@ def pinv(x, hermitian=False): class Inv(Op): """Computes the inverse of one or more matrices.""" + # TODO: This Op is already natively vectorized, dispatch on `vectorized_node` to avoid useless Blockwise + gufunc_signature = "(m,m)->(m,m)" + def make_node(self, x): x = as_tensor_variable(x) return Apply(self, [x], [x.type()]) @@ -112,6 +116,7 @@ class MatrixInverse(Op): """ __props__ = () + gufunc_signature = "(m,m)->(m,m)" def __init__(self): pass @@ -200,6 +205,7 @@ class Det(Op): """ __props__ = () + gufunc_signature = "(m,m)->()" def make_node(self, x): x = as_tensor_variable(x) @@ -237,6 +243,7 @@ class SLogDet(Op): """ __props__ = () + gufunc_signature = "(m, m)->(),()" def make_node(self, x): x = as_tensor_variable(x) @@ -272,6 +279,7 @@ class Eig(Op): _numop = staticmethod(np.linalg.eig) __props__: Tuple[str, ...] = () + gufunc_signature = "(m,m)->(m),(m,m)" def make_node(self, x): x = as_tensor_variable(x) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py new file mode 100644 index 0000000000..ffc5fe476a --- /dev/null +++ b/tests/tensor/test_blockwise.py @@ -0,0 +1,284 @@ +from itertools import product +from typing import Tuple, Union + +import numpy as np +import pytest + +import pytensor +from pytensor import config +from pytensor.gradient import grad +from pytensor.graph import Op, Apply +from pytensor.tensor import tensor, exp +from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.nlinalg import MatrixInverse +from pytensor.tensor.random import normal +from pytensor.tensor.slinalg import Cholesky + + +def test_vectorize_node(): + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + tns = tensor(shape=(None, None, None)) + + node = exp(vec).owner + vect_node = vectorize_node(node, mat) + assert vect_node.op == exp + assert vect_node.inputs[0] is mat + + col_mat = tensor(shape=(None, 1)) + tcol_mat = tensor(shape=(None, None, 1)) + node = col_mat.dimshuffle(0).owner # drop column + vect_node = vectorize_node(node, tcol_mat) + assert isinstance(vect_node.op, DimShuffle) + assert vect_node.inputs[0] is tcol_mat + assert vect_node.outputs[0].type.shape == (None, None) + + node = normal(vec).owner + new_inputs = node.inputs[:3] + [mat] + node.inputs[4:] + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert vect_node.inputs[3] is mat + + # Something that falls back to blockwise + node = MatrixInverse()(mat).owner + vect_node = vectorize_node(node, tns) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, MatrixInverse + ) + assert vect_node.op.signature == ("(m,m)->(m,m)") + assert vect_node.inputs[0] is tns + + # Useless blockwise + tns4 = tensor(shape=(5, None, None, None)) + new_vect_node = vectorize_node(vect_node, tns4) + assert new_vect_node.op is vect_node.op + assert isinstance(new_vect_node.op, Blockwise) and isinstance( + new_vect_node.op.core_op, MatrixInverse + ) + assert new_vect_node.inputs[0] is tns4 + + +def test_useless_blockwise(): + cop = MatrixInverse() + bop = Blockwise(cop, signature=("(m, m) -> (m, m)")) + + inp = tensor(shape=(None, None, None)) + out = bop(inp) + assert out.owner.op is bop + assert out.owner.inputs[0] is inp + + inp = tensor(shape=(None, None)) + out = bop(inp) + assert out.owner.op is cop + assert out.owner.inputs[0] is inp + + +class TestOp(Op): + def make_node(self, *inputs): + return Apply(self, inputs, [i.type() for i in inputs]) + + def perform(self, *args, **kwargs): + raise NotImplementedError("Test Op should not be present in final graph") + + +test_op = TestOp() + + +def test_vectorize_node_default_signature(): + vec = tensor(shape=(None,)) + mat = tensor(shape=(5, None)) + node = test_op.make_node(vec, mat) + + vect_node = vectorize_node(node, mat, mat) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, TestOp + ) + assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)") + + with pytest.raises( + ValueError, match="Signature not provided nor found in core_op TestOp" + ): + Blockwise(test_op) + + vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat) + assert vect_node.outputs[0].type.shape == ( + 5, + None, + ) + assert vect_node.outputs[0].type.shape == ( + 5, + None, + ) + + +def test_blockwise_shape(): + # Single output + inp = tensor(shape=(5, None, None)) + inp_test = np.zeros((5, 4, 3), dtype=config.floatX) + + # Shape can be inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (n, m)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 3, 4) + + # Shape can only be partially inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (m, k)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp], out.shape[:-1]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 4) + + # Mutiple outputs + inp1 = tensor(shape=(7, 1, None, None)) + inp2 = tensor(shape=(1, 5, None, None)) + inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX) + inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX) + + op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)") + outs = op(inp1, inp2) + assert outs[0].type.shape == (7, 5, None, None) + assert outs[1].type.shape == (7, 5, None, None) + + shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs]) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp1, inp2], outs[0].shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4) + + shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4) + assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4) + + +class BlockwiseOpTester: + """Base class to test Blockwise works for specific Ops""" + + core_op = None + signature = None + batcheable_axes = None + + @classmethod + def setup_class(cls): + seed = sum(map(ord, cls.__class__.__name__)) + cls.rng = np.random.default_rng(seed) + cls.params_sig, cls.outputs_sig = _parse_gufunc_signature(cls.signature) + if cls.batcheable_axes is None: + cls.batcheable_axes = list(range(len(cls.outputs_sig))) + batch_shapes = [(), (1,), (5,), (1, 1), (1, 3), (3, 1), (3, 5)] + cls.test_batch_shapes = list( + product(batch_shapes, repeat=len(cls.batcheable_axes)) + ) + cls.block_op = Blockwise(core_op=cls.core_op, signature=cls.signature) + + @staticmethod + def parse_shape(shape: Tuple[Union[str, int], ...]) -> Tuple[int, ...]: + """ + Convert (5, "m", "n") -> (5, 7, 11) + """ + mapping = {"m": 7, "n": 11, "k": 19} + return tuple(mapping.get(p, p) for p in shape) + + def create_testvals(self, shape): + return self.rng.normal(size=self.parse_shape(shape)).astype(config.floatX) + + def create_batched_inputs(self): + for batch_shapes in self.test_batch_shapes: + vec_inputs = [] + vec_inputs_testvals = [] + for batch_shape, param_sig in zip(batch_shapes, self.params_sig): + vec_inputs.append(tensor(shape=batch_shape + (None,) * len(param_sig))) + vec_inputs_testvals.append( + self.create_testvals(shape=batch_shape + param_sig) + ) + yield vec_inputs, vec_inputs_testvals + + def test_perform(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + core_func = pytensor.function(base_inputs, self.core_op(*base_inputs)) + np_func = np.vectorize(core_func, signature=self.signature) + + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs(): + pt_func = pytensor.function(vec_inputs, self.block_op(*vec_inputs)) + if len(self.outputs_sig) != 1: + raise NotImplementedError("Did not implement test for multi-output Ops") + np.testing.assert_allclose( + pt_func(*vec_inputs_testvals), + np_func(*vec_inputs_testvals), + ) + + def test_grad(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + out = self.core_op(*base_inputs).sum() + if len(base_inputs) == 1: + core_grad_func = pytensor.function( + base_inputs, grad(out, wrt=base_inputs[0]) + ) + else: + core_grad_func = pytensor.function(base_inputs, grad(out, wrt=base_inputs)) + + [param_sig, _] = self.signature.split("->") + grad_sig = f"{param_sig}->{param_sig}" + np_func_raw = np.vectorize(core_grad_func, signature=grad_sig) + if len(base_inputs): + np_func = lambda *args: [np_func_raw(*args)] # noqa: E731 + else: + np_func = np_func_raw + + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs(): + out = self.block_op(*vec_inputs).sum() + pt_func = pytensor.function(vec_inputs, grad(out, wrt=vec_inputs)) + pt_outs = pt_func(*vec_inputs_testvals) + np_outs = np_func(*vec_inputs_testvals) + for pt_out, np_out in zip(pt_outs, np_outs): + np.testing.assert_allclose(pt_out, np_out) + + +class MatrixOpBlockwiseTester(BlockwiseOpTester): + def create_testvals(self, shape): + # Return a posdef matrix + X = super().create_testvals(shape) + return np.einsum("...ij,...kj->...ik", X, X) + + +class TestCholesky(MatrixOpBlockwiseTester): + core_op = Cholesky(lower=True) + signature = "(m, m) -> (m, m)" + + +class TestMatrixInverse(MatrixOpBlockwiseTester): + core_op = MatrixInverse() + signature = "(m, m) -> (m, m)"