Skip to content

Refactor nlinalg.norm to match np.linalg.norm signature and functionality #588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 207 additions & 31 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import warnings
from functools import partial
from typing import Callable, Literal, Optional, Union

import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore

from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType
Expand Down Expand Up @@ -523,6 +525,7 @@ def qr(a, mode="reduced"):

class SVD(Op):
"""
Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V

Parameters
----------
Expand All @@ -543,13 +546,23 @@ class SVD(Op):
def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
self.full_matrices = bool(full_matrices)
self.compute_uv = bool(compute_uv)
if self.compute_uv:
if self.full_matrices:
self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)"
else:
self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)"
else:
self.gufunc_signature = "(m,n)->(k)"

def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of svd function should be a matrix."

in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
if in_dtype.name.startswith("int"):
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
else:
out_dtype = in_dtype

s = vector(dtype=out_dtype)

Expand Down Expand Up @@ -603,7 +616,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
U, V, D : matrices

"""
return SVD(full_matrices, compute_uv)(a)
return Blockwise(SVD(full_matrices, compute_uv))(a)


class Lstsq(Op):
Expand Down Expand Up @@ -677,41 +690,204 @@ def matrix_power(M, n):
return result


def norm(x, ord):
x = as_tensor_variable(x)
def _multi_svd_norm(
    x: ptb.TensorVariable, row_axis: int, col_axis: int, reduce_op: Callable
):
    """Reduce the singular values of the 2-D matrices stored in `x`.

    Private helper for `pytensor.tensor.nlinalg.norm()`, modelled on
    `np.linalg._multi_svd_norm`.

    Parameters
    ----------
    x : TensorVariable
        Input tensor.
    row_axis, col_axis : int
        The axes of `x` that hold the 2-D matrices. They are moved to the
        last two positions before the decomposition is taken.
    reduce_op : callable
        Reduction invoked as ``reduce_op(singular_values, axis=-1)``.
        `norm` passes `ptm.max`, `ptm.min` or `ptm.sum`.

    Returns
    -------
    result
        Whatever `reduce_op` returns when applied over the last axis of the
        singular values, i.e. one value per matrix; the matrix axes are
        reduced away, leaving ``x.ndim - 2`` dimensions.

    """
    # Bring the matrix axes to the end so `svd` sees them as the core dims.
    matrices = ptb.moveaxis(x, (row_axis, col_axis), (-2, -1))
    singular_values = svd(matrices, compute_uv=False)
    return reduce_op(singular_values, axis=-1)


# Literal string and integer values used in the type annotation of the `ord` argument of `norm`.
VALID_ORD = Literal["fro", "f", "nuc", "inf", "-inf", 0, 1, -1, 2, -2]


def norm(
x: ptb.TensorVariable,
ord: Optional[Union[float, VALID_ORD]] = None,
axis: Optional[Union[int, tuple[int, ...]]] = None,
keepdims: bool = False,
):
"""
Matrix or vector norm.

Parameters
----------
x: TensorVariable
Tensor to take norm of.

ord: float, str or int, optional
Order of norm. If `ord` is a str, it must be one of the following:
- 'fro' or 'f' : Frobenius norm
- 'nuc' : nuclear norm
- 'inf' : Infinity norm
- '-inf' : Negative infinity norm
If an integer, order can be one of -2, -1, 0, 1, or 2.
Otherwise `ord` must be a float.

Default is the Frobenius (L2) norm.

axis: int or tuple of int, optional
Axes over which to compute the norm. If None, norm of entire matrix (or vector) is computed. Row or column
norms can be computed by passing a single integer; this will treat a matrix like a batch of vectors.

keepdims: bool
If True, dummy axes will be inserted into the output so that norm.ndim == x.ndim. Default is False.

Returns
-------
TensorVariable
Norm of `x` along axes specified by `axis`.

Raises
------
TypeError
If `axis` is not None, not a tuple, and cannot be converted with `int()`.
ValueError
If `ord` is a string that is not handled for a single-axis (vector) norm, or if `ord` is not one of the
orders handled for a two-axis (matrix) norm.

Notes
-----
Batched dimensions are supported to the left of the core dimensions. For example, if `x` is a 3D tensor with
shape (2, 3, 4), then `norm(x)` will compute the norm of each 3x4 matrix in the batch.

If the input is a 2D tensor and should be treated as a batch of vectors, the `axis` argument must be specified.
"""
x = ptb.as_tensor_variable(x)

# NOTE(review): this span appears to interleave pre-change and post-change lines of a rendered diff
# (e.g. a `return` immediately followed by another statement under the same `elif`). Confirm the
# statement order against the repository before relying on it.
ndim = x.ndim
if ndim == 0:
raise ValueError("'axis' entry is out of bounds.")
elif ndim == 1:
if ord is None:
return ptm.sum(x**2) ** 0.5
elif ord == "inf":
return ptm.max(abs(x))
elif ord == "-inf":
return ptm.min(abs(x))
# At most the two trailing dimensions are treated as core dimensions; the remaining leading ones are batch.
core_ndim = min(2, ndim)
batch_ndim = ndim - core_ndim

if axis is None:
# Handle some common cases first. These can be computed more quickly than the default SVD way, so we always
# want to check for them.
if (
(ord is None)
or (ord in ("f", "fro") and core_ndim == 2)
or (ord == 2 and core_ndim == 1)
):
# Collapse the core dimensions into one flattened axis (plus a trailing length-1 axis when core_ndim == 2).
x = x.reshape(tuple(x.shape[:-2]) + (-1,) + (1,) * (core_ndim - 1))
# Permutation that keeps the batch axes in place and reverses the order of the core axes.
batch_T_dim_order = tuple(range(batch_ndim)) + tuple(
range(batch_ndim + core_ndim - 1, batch_ndim - 1, -1)
)

if x.dtype.startswith("complex"):
# Squared magnitude of a complex input: real and imaginary parts contribute separately.
x_real = x.real # type: ignore
x_imag = x.imag # type: ignore
sqnorm = (
ptb.transpose(x_real, batch_T_dim_order) @ x_real
+ ptb.transpose(x_imag, batch_T_dim_order) @ x_imag
)
else:
sqnorm = ptb.transpose(x, batch_T_dim_order) @ x
ret = ptm.sqrt(sqnorm).squeeze()
if keepdims:
ret = ptb.shape_padright(ret, core_ndim)
return ret

# No special computation to exploit -- set default axis before continuing
axis = tuple(range(core_ndim))

elif not isinstance(axis, tuple):
# Anything that is not a tuple must be convertible to a single integer axis.
try:
axis = int(axis)
except Exception as e:
raise TypeError(
"'axis' must be None, an integer, or a tuple of integers"
) from e

axis = (axis,)

if len(axis) == 1:
# Vector norms
if ord in [None, "fro", "f"] and (core_ndim == 2):
# This is here to catch the case where X is a 2D tensor but the user wants to treat it as a batch of
# vectors. Other vector norms will work fine in this case.
ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis, keepdims=keepdims))
elif (ord == "inf") or (ord == np.inf):
ret = ptm.max(ptm.abs(x), axis=axis, keepdims=keepdims)
elif (ord == "-inf") or (ord == -np.inf):
ret = ptm.min(ptm.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return x[x.nonzero()].shape[0]
ret = ptm.neq(x, 0).sum(axis=axis, keepdims=keepdims)
elif ord == 1:
ret = ptm.sum(ptm.abs(x), axis=axis, keepdims=keepdims)
elif isinstance(ord, str):
raise ValueError(f"Invalid norm order '{ord}' for vectors")
else:
try:
z = ptm.sum(abs(x**ord)) ** (1.0 / ord)
except TypeError:
raise ValueError("Invalid norm order for vectors.")
return z
elif ndim == 2:
if ord is None or ord == "fro":
return ptm.sum(abs(x**2)) ** (0.5)
elif ord == "inf":
return ptm.max(ptm.sum(abs(x), 1))
elif ord == "-inf":
return ptm.min(ptm.sum(abs(x), 1))
# General p-norm: (sum |x| ** ord) ** (1 / ord).
ret = ptm.sum(ptm.abs(x) ** ord, axis=axis, keepdims=keepdims)
ret **= ptm.reciprocal(ord)

return ret

elif len(axis) == 2:
# Matrix norms
# Normalize the two axes against core_ndim, then offset them by the number of batch dimensions.
# NOTE: `x` inside the generator expression is its own loop variable, not the input tensor.
row_axis, col_axis = (
batch_ndim + x for x in normalize_axis_tuple(axis, core_ndim)
)
axis = (row_axis, col_axis)

if ord in [None, "fro", "f"]:
ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis))

elif (ord == "inf") or (ord == np.inf):
# Summing over col_axis removes that axis, so a later row_axis shifts down by one.
if row_axis > col_axis:
row_axis -= 1
ret = ptm.max(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)

elif (ord == "-inf") or (ord == -np.inf):
if row_axis > col_axis:
row_axis -= 1
ret = ptm.min(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)

elif ord == 1:
return ptm.max(ptm.sum(abs(x), 0))
# Summing over row_axis removes that axis, so a later col_axis shifts down by one.
if col_axis > row_axis:
col_axis -= 1
ret = ptm.max(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)

elif ord == -1:
return ptm.min(ptm.sum(abs(x), 0))
if col_axis > row_axis:
col_axis -= 1
ret = ptm.min(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)

# The remaining orders are reductions over the singular values of each matrix.
elif ord == 2:
ret = _multi_svd_norm(x, row_axis, col_axis, ptm.max)

elif ord == -2:
ret = _multi_svd_norm(x, row_axis, col_axis, ptm.min)

elif ord == "nuc":
ret = _multi_svd_norm(x, row_axis, col_axis, ptm.sum)

else:
raise ValueError(0)
elif ndim > 2:
raise NotImplementedError("We don't support norm with ndim > 2")
raise ValueError(f"Invalid norm order for matrices: {ord}")

# The matrix-norm reductions above do not pass keepdims, so the reduced axes are re-inserted here.
if keepdims:
ret = ptb.expand_dims(ret, axis)

return ret
# NOTE(review): presumably pairs with the `len(axis)` checks above, yet the message reports core_ndim
# rather than the number of axes supplied -- confirm against the repository.
else:
raise ValueError(
f"Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = {core_ndim}"
)


class TensorInv(Op):
Expand Down
Loading