diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 690bb44df5..79f0508588 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -596,6 +596,16 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
         # By default, do nothing
         return self
 
+    def infer_shape(self, fgraph, node, input_shapes):
+        if hasattr(self, "gufunc_signature"):
+            from pytensor.tensor.utils import _gufunc_to_out_shape
+
+            return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
+        else:
+            from pytensor.tensor.exceptions import ShapeError
+
+            raise ShapeError(f"Op {self} does not implement infer_shape")
+
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
 
diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index ee33f6533c..446cf5ab44 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -62,9 +62,6 @@ def L_op(self, inputs, outputs, g_outputs):
         ).T
         return [grad]
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [list(reversed(shapes[0]))]
-
 
 def pinv(x, hermitian=False):
     """Computes the pseudo-inverse of a matrix :math:`A`.
@@ -155,9 +152,6 @@ def R_op(self, inputs, eval_points):
             return [None]
         return [-matrix_dot(xi, ev, xi)]
 
-    def infer_shape(self, fgraph, node, shapes):
-        return shapes
-
 
 inv = matrix_inverse = Blockwise(MatrixInverse())
 
@@ -224,9 +218,6 @@ def grad(self, inputs, g_outputs):
         (x,) = inputs
         return [gz * self(x) * matrix_inverse(x).T]
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [()]
-
     def __str__(self):
         return "Det"
 
@@ -258,9 +249,6 @@ def perform(self, node, inputs, outputs):
         except Exception as e:
             raise ValueError("Failed to compute determinant", x) from e
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [(), ()]
-
     def __str__(self):
         return "SLogDet"
 
@@ -316,10 +304,6 @@ def perform(self, node, inputs, outputs):
         (w, v) = outputs
         w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
 
-    def infer_shape(self, fgraph, node, shapes):
-        n = shapes[0][0]
-        return [(n,), (n, n)]
-
 
 eig = Blockwise(Eig())
 
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index a8f9377170..ddbd74a9d4 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -50,9 +50,6 @@ def __init__(
         if self.overwrite_a:
             self.destroy_map = {0: [0]}
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
-
     def make_node(self, x):
         x = as_tensor_variable(x)
         if x.type.ndim != 2:
@@ -268,15 +265,6 @@ def make_node(self, A, b):
         x = tensor(dtype=o_dtype, shape=b.type.shape)
         return Apply(self, [A, b], [x])
 
-    def infer_shape(self, fgraph, node, shapes):
-        Ashape, Bshape = shapes
-        rows = Ashape[1]
-        if len(Bshape) == 1:
-            return [(rows,)]
-        else:
-            cols = Bshape[1]
-            return [(rows, cols)]
-
     def L_op(self, inputs, outputs, output_gradients):
         r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
 
@@ -890,9 +878,6 @@ def perform(self, node, inputs, output_storage):
         out_dtype = node.outputs[0].type.dtype
         X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         # Note that they write the equation as AX + XA.H + Q = 0, while scipy uses AX + XA^H = Q,
@@ -962,9 +947,6 @@ def perform(self, node, inputs, output_storage):
             out_dtype
         )
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         A, Q = inputs
@@ -1082,9 +1064,6 @@ def perform(self, node, inputs, output_storage):
         out_dtype = node.outputs[0].type.dtype
         X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
 
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         A, B, Q, R = inputs
diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py
index 0ebb2e5434..2dbfa9b8ea 100644
--- a/pytensor/tensor/utils.py
+++ b/pytensor/tensor/utils.py
@@ -7,6 +7,8 @@
+from typing import Any
+
 import pytensor
-from pytensor.graph import FunctionGraph, Variable
+from pytensor.graph import Constant, FunctionGraph, Variable
 from pytensor.npy_2_compat import normalize_axis_tuple
 from pytensor.utils import hash_from_code
 
 
@@ -202,6 +204,58 @@ def _parse_gufunc_signature(
     )
 
 
+def _gufunc_to_out_shape(
+    signature: str, shapes: list[tuple[Any, ...]]
+) -> list[tuple[Any, ...]]:
+    """
+    Compute the shape of the output of an Op given its gufunc signature and the
+    shapes of its inputs.
+
+    Parameters
+    ----------
+    signature : str
+        The gufunc signature of the Op,
+        e.g. "(m,n),(n,p)->(m,p)".
+
+    shapes : list of tuple of Any
+        The list of shapes of the inputs.
+
+    Returns
+    -------
+    out_shapes : list of tuple of Any
+        The list of shapes of the outputs.
+
+    Raises
+    ------
+    ValueError
+        If the signature is invalid for the shapes of the inputs.
+    """
+    input_sig, output_sig = _parse_gufunc_signature(signature)
+    dim_to_size: dict[str, Any] = {}
+    for input_shape, sig in zip(shapes, input_sig, strict=True):
+        for size, dim_name in zip(input_shape, sig, strict=True):
+            prev_size = dim_to_size.get(dim_name)
+            if prev_size is None:
+                dim_to_size[dim_name] = size
+            # Prefer constant sizes when a dimension appears more than once
+            elif not isinstance(prev_size, Constant):
+                dim_to_size[dim_name] = size
+
+    out_shapes = []
+    for output_shape in output_sig:
+        temp_list = []
+        for dim in output_shape:
+            if dim not in dim_to_size:
+                raise ValueError(
+                    f"Invalid signature {signature} for shapes {shapes}. "
+                    f"Dimension {dim} not in input dimensions."
+                )
+            else:
+                temp_list.append(dim_to_size[dim])
+        out_shapes.append(tuple(temp_list))
+    return out_shapes
+
+
 def safe_signature(
     core_inputs_ndim: Sequence[int],
     core_outputs_ndim: Sequence[int],
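For reviewers, here is a minimal sketch of the new fallback in action. The `MatMulLike` Op and its `perform` body below are hypothetical, written only to exercise the default; `Op`, `Apply`, `pt.tensor`, and `pytensor.function` are the existing public API, and `_gufunc_to_out_shape` is the helper added in this PR:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class MatMulLike(Op):
    """Hypothetical Op: no custom infer_shape, only a gufunc_signature."""

    gufunc_signature = "(m,n),(n,p)->(m,p)"

    def make_node(self, x, y):
        x, y = pt.as_tensor_variable(x), pt.as_tensor_variable(y)
        out = pt.tensor(dtype=x.type.dtype, shape=(None, None))
        return Apply(self, [x, y], [out])

    def perform(self, node, inputs, output_storage):
        x, y = inputs
        output_storage[0][0] = x @ y


x = pt.matrix("x")
y = pt.matrix("y")
z = MatMulLike()(x, y)

# The default infer_shape maps "(m,n),(n,p)->(m,p)" onto the input shapes,
# which should let the shape rewrites compute z.shape without a
# hand-written infer_shape on the Op.
f = pytensor.function([x, y], z.shape)
print(f(np.zeros((2, 3), dtype=x.dtype), np.zeros((3, 5), dtype=y.dtype)))  # [2 5]

# The helper can also be called directly:
from pytensor.tensor.utils import _gufunc_to_out_shape

print(_gufunc_to_out_shape("(m,n),(n,p)->(m,p)", [(2, 3), (3, 5)]))  # [(2, 5)]
```

One note on the design as I read it: the imports inside the default `Op.infer_shape` are deliberately deferred to the function body, presumably because `pytensor.graph` cannot import from `pytensor.tensor` at module level without creating a circular import.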