diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index c8805c38e1..cf510fb065 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -1,7 +1,9 @@
 import warnings
 from functools import partial
+from typing import Callable, Literal, Optional, Union

 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple  # type: ignore

 from pytensor import scalar as ps
 from pytensor.gradient import DisconnectedType
@@ -523,6 +525,7 @@ def qr(a, mode="reduced"):

 class SVD(Op):
     """
+    Computes the singular value decomposition of a matrix A into U, S, V, such that ``A = U @ diag(S) @ V``

     Parameters
     ----------
@@ -543,13 +546,23 @@ class SVD(Op):
     def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
         self.full_matrices = bool(full_matrices)
         self.compute_uv = bool(compute_uv)
+        if self.compute_uv:
+            if self.full_matrices:
+                self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)"
+            else:
+                self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)"
+        else:
+            self.gufunc_signature = "(m,n)->(k)"

     def make_node(self, x):
         x = as_tensor_variable(x)
         assert x.ndim == 2, "The input of svd function should be a matrix."

         in_dtype = x.type.numpy_dtype
-        out_dtype = np.dtype(f"f{in_dtype.itemsize}")
+        if in_dtype.name.startswith("int"):
+            out_dtype = np.dtype(f"f{in_dtype.itemsize}")
+        else:
+            out_dtype = in_dtype

         s = vector(dtype=out_dtype)
@@ -603,7 +616,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     U, V, D : matrices

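+    Examples
+    --------
+    A minimal sketch of the new batched behaviour (the names and shape below
+    are illustrative, not part of the API):
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        from pytensor.tensor.nlinalg import svd
+
+        # A batch of five 3x4 matrices: Blockwise maps the core SVD over the
+        # leading dimension, one decomposition per matrix.
+        x = pt.tensor("x", shape=(5, 3, 4))
+        u, s, vt = svd(x, full_matrices=False)
+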
     """
-    return SVD(full_matrices, compute_uv)(a)
+    return Blockwise(SVD(full_matrices, compute_uv))(a)


 class Lstsq(Op):
@@ -677,41 +690,204 @@ def matrix_power(M, n):
     return result


-def norm(x, ord):
-    x = as_tensor_variable(x)
+def _multi_svd_norm(
+    x: ptb.TensorVariable, row_axis: int, col_axis: int, reduce_op: Callable
+):
+    """Compute a function of the singular values of the 2-D matrices in `x`.
+
+    This is a private utility function used by `pytensor.tensor.nlinalg.norm()`.
+
+    Copied from `np.linalg._multi_svd_norm`.
+
+    Parameters
+    ----------
+    x : TensorVariable
+        Input tensor.
+    row_axis, col_axis : int
+        The axes of `x` that hold the 2-D matrices.
+    reduce_op : callable
+        Reduction op. Should be one of `pt.min`, `pt.max`, or `pt.sum`.
+
+    Returns
+    -------
+    result : TensorVariable
+        If `x` is 2-D, the return value is a scalar variable.
+        Otherwise, it is an array with ``x.ndim - 2`` dimensions.
+        The return values are the minimum, maximum, or sum of the
+        singular values of the matrices, depending on whether `reduce_op`
+        is `pt.min`, `pt.max`, or `pt.sum`.
+
+    """
+    y = ptb.moveaxis(x, (row_axis, col_axis), (-2, -1))
+    result = reduce_op(svd(y, compute_uv=False), axis=-1)
+    return result
+
+
+VALID_ORD = Literal["fro", "f", "nuc", "inf", "-inf", 0, 1, -1, 2, -2]
+
+
+def norm(
+    x: ptb.TensorVariable,
+    ord: Optional[Union[float, VALID_ORD]] = None,
+    axis: Optional[Union[int, tuple[int, ...]]] = None,
+    keepdims: bool = False,
+):
+    """
+    Matrix or vector norm.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        Tensor to take the norm of.
+
+    ord: float, str or int, optional
+        Order of the norm. If `ord` is a str, it must be one of the following:
+            - 'fro' or 'f' : Frobenius norm
+            - 'nuc' : nuclear norm
+            - 'inf' : Infinity norm
+            - '-inf' : Negative infinity norm
+        If an integer, order can be one of -2, -1, 0, 1, or 2.
+        Otherwise `ord` must be a float.
+
+        Default is the Frobenius (L2) norm.
+
+    axis: int or tuple of int, optional
+        Axes over which to compute the norm. If None, the norm of the entire
+        matrix (or vector) is computed. Row or column norms can be computed by
+        passing a single integer; this will treat a matrix like a batch of vectors.
+
+    keepdims: bool
+        If True, dummy axes will be inserted into the output so that
+        ``norm.ndim == x.ndim``. Default is False.
+
+    Returns
+    -------
+    TensorVariable
+        Norm of `x` along the axes specified by `axis`.
+
+    Notes
+    -----
+    Batched dimensions are supported to the left of the core dimensions. For
+    example, if `x` is a 3D tensor with shape (2, 3, 4), then `norm(x)` will
+    compute the norm of each 3x4 matrix in the batch.
+
+    If the input is a 2D tensor that should be treated as a batch of vectors,
+    the `axis` argument must be specified.
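+
+    Examples
+    --------
+    A hedged sketch of common calls (names and shapes are illustrative only):
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        from pytensor.tensor.nlinalg import norm
+
+        x = pt.matrix("x")
+        frob = norm(x)                     # Frobenius norm of the whole matrix
+        rows = norm(x, ord=2, axis=(-1,))  # L2 norm of each row vector
+        nuc = norm(x, ord="nuc", axis=(-2, -1))  # sum of singular values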
+    """
+    x = ptb.as_tensor_variable(x)
+
     ndim = x.ndim
-    if ndim == 0:
-        raise ValueError("'axis' entry is out of bounds.")
-    elif ndim == 1:
-        if ord is None:
-            return ptm.sum(x**2) ** 0.5
-        elif ord == "inf":
-            return ptm.max(abs(x))
-        elif ord == "-inf":
-            return ptm.min(abs(x))
+    core_ndim = min(2, ndim)
+    batch_ndim = ndim - core_ndim
+
+    if axis is None:
+        # Handle some common cases first. These can be computed more quickly than
+        # the default SVD way, so we always want to check for them.
+        if (
+            (ord is None)
+            or (ord in ("f", "fro") and core_ndim == 2)
+            or (ord == 2 and core_ndim == 1)
+        ):
+            x = x.reshape(tuple(x.shape[:-2]) + (-1,) + (1,) * (core_ndim - 1))
+            batch_T_dim_order = tuple(range(batch_ndim)) + tuple(
+                range(batch_ndim + core_ndim - 1, batch_ndim - 1, -1)
+            )
+
+            if x.dtype.startswith("complex"):
+                x_real = x.real  # type: ignore
+                x_imag = x.imag  # type: ignore
+                sqnorm = (
+                    ptb.transpose(x_real, batch_T_dim_order) @ x_real
+                    + ptb.transpose(x_imag, batch_T_dim_order) @ x_imag
+                )
+            else:
+                sqnorm = ptb.transpose(x, batch_T_dim_order) @ x
+            ret = ptm.sqrt(sqnorm).squeeze()
+            if keepdims:
+                ret = ptb.shape_padright(ret, core_ndim)
+            return ret
+
+        # No special computation to exploit -- set default axis before continuing
+        axis = tuple(range(core_ndim))
+
+    elif not isinstance(axis, tuple):
+        try:
+            axis = int(axis)
+        except Exception as e:
+            raise TypeError(
+                "'axis' must be None, an integer, or a tuple of integers"
+            ) from e
+
+        axis = (axis,)
+
+    if len(axis) == 1:
+        # Vector norms
+        if ord in [None, "fro", "f"] and (core_ndim == 2):
+            # This is here to catch the case where x is a 2D tensor but the user
+            # wants to treat it as a batch of vectors. Other vector norms will
+            # work fine in this case.
+            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis, keepdims=keepdims))
+        elif (ord == "inf") or (ord == np.inf):
+            ret = ptm.max(ptm.abs(x), axis=axis, keepdims=keepdims)
+        elif (ord == "-inf") or (ord == -np.inf):
+            ret = ptm.min(ptm.abs(x), axis=axis, keepdims=keepdims)
         elif ord == 0:
-            return x[x.nonzero()].shape[0]
+            ret = ptm.neq(x, 0).sum(axis=axis, keepdims=keepdims)
+        elif ord == 1:
+            ret = ptm.sum(ptm.abs(x), axis=axis, keepdims=keepdims)
+        elif isinstance(ord, str):
+            raise ValueError(f"Invalid norm order '{ord}' for vectors")
         else:
-            try:
-                z = ptm.sum(abs(x**ord)) ** (1.0 / ord)
-            except TypeError:
-                raise ValueError("Invalid norm order for vectors.")
-            return z
-    elif ndim == 2:
-        if ord is None or ord == "fro":
-            return ptm.sum(abs(x**2)) ** (0.5)
-        elif ord == "inf":
-            return ptm.max(ptm.sum(abs(x), 1))
-        elif ord == "-inf":
-            return ptm.min(ptm.sum(abs(x), 1))
+            ret = ptm.sum(ptm.abs(x) ** ord, axis=axis, keepdims=keepdims)
+            ret **= ptm.reciprocal(ord)
+
+        return ret
+
+    elif len(axis) == 2:
+        # Matrix norms
+        row_axis, col_axis = (
+            batch_ndim + x for x in normalize_axis_tuple(axis, core_ndim)
+        )
+        axis = (row_axis, col_axis)
+
+        if ord in [None, "fro", "f"]:
+            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis))
+
+        elif (ord == "inf") or (ord == np.inf):
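+            # ptm.sum removes col_axis from the result, so if row_axis sits to
+            # the right of col_axis its index shifts down by one before the
+            # outer reduction over rows. The same adjustment applies below.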
4)], ids=["square", "tall", "wide"] + ) + @pytest.mark.parametrize( + "full_matrix", [True, False], ids=["full=True", "full=False"] + ) + @pytest.mark.parametrize( + "compute_uv", [True, False], ids=["compute_uv=True", "compute_uv=False"] + ) + @pytest.mark.parametrize( + "batched", [True, False], ids=["batched=True", "batched=False"] + ) + @pytest.mark.parametrize( + "test_imag", [True, False], ids=["test_imag=True", "test_imag=False"] + ) + def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag): + dtype = config.floatX + if test_imag: + dtype = "complex128" if dtype.endswith("64") else "complex64" + shape = core_shape if not batched else (10, *core_shape) + A = tensor("A", shape=shape, dtype=dtype) + a = self.rng.random(shape).astype(dtype) + + outputs = svd(A, compute_uv=compute_uv, full_matrices=full_matrix) + outputs = outputs if isinstance(outputs, list) else [outputs] + fn = function(inputs=[A], outputs=outputs) + + np_fn = np.vectorize( + partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=full_matrix), + signature=outputs[0].owner.op.core_op.gufunc_signature, + ) + + np_outputs = np_fn(a) + pt_outputs = fn(a) - assert _allclose(n_u, t_u) - assert _allclose(n_s, t_s) - assert _allclose(n_vt, t_vt) + np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs] - fn = function([A], svd(A, compute_uv=False)) - t_s = fn(a) - assert _allclose(n_s, t_s) + for np_val, pt_val in zip(np_outputs, pt_outputs): + assert _allclose(np_val, pt_val) def test_svd_infer_shape(self): self.validate_shape((4, 4), full_matrices=True, compute_uv=True) @@ -183,7 +208,7 @@ def test_svd_infer_shape(self): def validate_shape(self, shape, compute_uv=True, full_matrices=True): A = self.A - A_v = self.rng.random(shape).astype(self.dtype) + A_v = self.rng.random(shape).astype(config.floatX) outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv) if not compute_uv: outputs = [outputs] @@ -437,44 +462,82 @@ def test_non_square_matrix(self): f(a) -class TestNormTests: +class TestNorm: def test_wrong_type_of_ord_for_vector(self): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid norm order 'fro' for vectors"): norm([2, 1], "fro") def test_wrong_type_of_ord_for_matrix(self): - with pytest.raises(ValueError): - norm([[2, 1], [3, 4]], 0) + ord = 0 + with pytest.raises(ValueError, match=f"Invalid norm order for matrices: {ord}"): + norm([[2, 1], [3, 4]], ord) def test_non_tensorial_input(self): - with pytest.raises(ValueError): - norm(3, None) - - def test_tensor_input(self): - with pytest.raises(NotImplementedError): - norm(np.random.random((3, 4, 5)), None) + with pytest.raises( + ValueError, + match="Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = 0", + ): + norm(3, ord=2) + + def test_invalid_axis_input(self): + axis = scalar("i", dtype="int") + with pytest.raises( + TypeError, match="'axis' must be None, an integer, or a tuple of integers" + ): + norm([[1, 2], [3, 4]], axis=axis) + + @pytest.mark.parametrize( + "ord", + [None, np.inf, -np.inf, 1, -1, 2, -2], + ids=["None", "inf", "-inf", "1", "-1", "2", "-2"], + ) + @pytest.mark.parametrize("core_dims", [(4,), (4, 3)], ids=["vector", "matrix"]) + @pytest.mark.parametrize("batch_dims", [(), (2,)], ids=["no_batch", "batch"]) + @pytest.mark.parametrize("test_imag", [True, False], ids=["complex", "real"]) + @pytest.mark.parametrize( + "keepdims", [True, False], ids=["keep_dims=True", "keep_dims=False"] + ) + def test_numpy_compare( + self, + ord: 
+        np_fn = np.vectorize(
+            partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=full_matrix),
+            signature=outputs[0].owner.op.core_op.gufunc_signature,
+        )
+
+        np_outputs = np_fn(a)
+        pt_outputs = fn(a)

-        assert _allclose(n_u, t_u)
-        assert _allclose(n_s, t_s)
-        assert _allclose(n_vt, t_vt)
+        np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]

-        fn = function([A], svd(A, compute_uv=False))
-        t_s = fn(a)
-        assert _allclose(n_s, t_s)
+        for np_val, pt_val in zip(np_outputs, pt_outputs):
+            assert _allclose(np_val, pt_val)

     def test_svd_infer_shape(self):
         self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
@@ -183,7 +208,7 @@ def test_svd_infer_shape(self):

     def validate_shape(self, shape, compute_uv=True, full_matrices=True):
         A = self.A
-        A_v = self.rng.random(shape).astype(self.dtype)
+        A_v = self.rng.random(shape).astype(config.floatX)
         outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
         if not compute_uv:
             outputs = [outputs]
@@ -437,44 +462,82 @@ def test_non_square_matrix(self):
         f(a)


-class TestNormTests:
+class TestNorm:
     def test_wrong_type_of_ord_for_vector(self):
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError, match="Invalid norm order 'fro' for vectors"):
             norm([2, 1], "fro")

     def test_wrong_type_of_ord_for_matrix(self):
-        with pytest.raises(ValueError):
-            norm([[2, 1], [3, 4]], 0)
+        ord = 0
+        with pytest.raises(ValueError, match=f"Invalid norm order for matrices: {ord}"):
+            norm([[2, 1], [3, 4]], ord)

     def test_non_tensorial_input(self):
-        with pytest.raises(ValueError):
-            norm(3, None)
-
-    def test_tensor_input(self):
-        with pytest.raises(NotImplementedError):
-            norm(np.random.random((3, 4, 5)), None)
+        with pytest.raises(
+            ValueError,
+            match="Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = 0",
+        ):
+            norm(3, ord=2)
+
+    def test_invalid_axis_input(self):
+        axis = scalar("i", dtype="int")
+        with pytest.raises(
+            TypeError, match="'axis' must be None, an integer, or a tuple of integers"
+        ):
+            norm([[1, 2], [3, 4]], axis=axis)
+
+    @pytest.mark.parametrize(
+        "ord",
+        [None, np.inf, -np.inf, 1, -1, 2, -2],
+        ids=["None", "inf", "-inf", "1", "-1", "2", "-2"],
+    )
+    @pytest.mark.parametrize("core_dims", [(4,), (4, 3)], ids=["vector", "matrix"])
+    @pytest.mark.parametrize("batch_dims", [(), (2,)], ids=["no_batch", "batch"])
+    @pytest.mark.parametrize("test_imag", [True, False], ids=["complex", "real"])
+    @pytest.mark.parametrize(
+        "keepdims", [True, False], ids=["keep_dims=True", "keep_dims=False"]
+    )
+    def test_numpy_compare(
+        self,
+        ord: float,
+        core_dims: tuple[int, ...],
+        batch_dims: tuple[int, ...],
+        test_imag: bool,
+        keepdims: bool,
+        axis=None,
+    ):
+        is_matrix = len(core_dims) == 2
+        has_batch = len(batch_dims) > 0
+        if ord in [np.inf, -np.inf] and not is_matrix:
+            pytest.skip("Infinity norm not defined for vectors")
+        if test_imag and is_matrix and ord == -2:
+            pytest.skip("Complex matrices not supported")
+        if has_batch and not is_matrix:
+            # Handle batched vectors by row-normalizing a matrix
+            axis = (-1,)

-    def test_numpy_compare(self):
         rng = np.random.default_rng(utt.fetch_seed())
-        M = matrix("A", dtype=config.floatX)
-        V = vector("V", dtype=config.floatX)
+        if test_imag:
+            x_real, x_imag = rng.standard_normal((2, *batch_dims, *core_dims)).astype(
+                config.floatX
+            )
+            dtype = "complex128" if config.floatX.endswith("64") else "complex64"
+            X = (x_real + 1j * x_imag).astype(dtype)
+        else:
+            X = rng.standard_normal(batch_dims + core_dims).astype(config.floatX)

-        a = rng.random((4, 4)).astype(config.floatX)
-        b = rng.random(4).astype(config.floatX)
+        if batch_dims == ():
+            np_norm = np.linalg.norm(X, ord=ord, axis=axis, keepdims=keepdims)
+        else:
+            np_norm = np.stack(
+                [np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) for x in X]
+            )

-        A = (
-            [None, "fro", "inf", "-inf", 1, -1, None, "inf", "-inf", 0, 1, -1, 2, -2],
-            [M, M, M, M, M, M, V, V, V, V, V, V, V, V],
-            [a, a, a, a, a, a, b, b, b, b, b, b, b, b],
-            [None, "fro", inf, -inf, 1, -1, None, inf, -inf, 0, 1, -1, 2, -2],
-        )
+        pt_norm = norm(X, ord=ord, axis=axis, keepdims=keepdims)
+        f = function([], pt_norm, mode="FAST_COMPILE")

-        for i in range(0, 14):
-            f = function([A[1][i]], norm(A[1][i], A[0][i]))
-            t_n = f(A[2][i])
-            n_n = np.linalg.norm(A[2][i], A[3][i])
-            assert _allclose(n_n, t_n)
+        utt.assert_allclose(np_norm, f())


 class TestTensorInv(utt.InferShapeTester):