diff --git a/aesara/tensor/slinalg.py b/aesara/tensor/slinalg.py
index ed373497af..4ee711ddd9 100644
--- a/aesara/tensor/slinalg.py
+++ b/aesara/tensor/slinalg.py
@@ -1,5 +1,6 @@
 import logging
 import warnings
+from typing import Union
 
 import numpy as np
 import scipy.linalg
@@ -11,6 +12,7 @@
 from aesara.tensor import basic as aet
 from aesara.tensor import math as atm
 from aesara.tensor.type import matrix, tensor, vector
+from aesara.tensor.var import TensorVariable
 
 
 logger = logging.getLogger(__name__)
@@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True):
     return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
 
 
-class Solve(Op):
-    """
-    Solve a system of linear equations.
-
-    For on CPU and GPU.
-    """
+class SolveBase(Op):
+    """Base class for `scipy.linalg` matrix equation solvers."""
 
     __props__ = (
-        "assume_a",
         "lower",
-        "check_finite",  # "transposed"
+        "check_finite",
     )
 
     def __init__(
         self,
-        assume_a="gen",
         lower=False,
-        check_finite=True,  # transposed=False
+        check_finite=True,
     ):
-        if assume_a not in ("gen", "sym", "her", "pos"):
-            raise ValueError(f"{assume_a} is not a recognized matrix structure")
-
-        self.assume_a = assume_a
         self.lower = lower
         self.check_finite = check_finite
-        # self.transposed = transposed
 
-    def __repr__(self):
-        return "Solve{%s}" % str(self._props())
+    def perform(self, node, inputs, outputs):
+        # `SolveBase` is abstract: concrete solvers must override `perform`,
+        # so fail loudly instead of silently producing no output.
+        raise NotImplementedError()
 
     def make_node(self, A, b):
         A = as_tensor_variable(A)
         b = as_tensor_variable(b)
-        assert A.ndim == 2
-        assert b.ndim in [1, 2]
 
-        # infer dtype by solving the most simple
-        # case with (1, 1) matrices
+        if A.ndim != 2:
+            raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
+        if b.ndim not in [1, 2]:
+            raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")
+
+        # Infer the output dtype by solving the simplest possible case: 1x1 matrices
         o_dtype = scipy.linalg.solve(
             np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
         ).dtype
         x = tensor(broadcastable=b.broadcastable, dtype=o_dtype)
         return Apply(self, [A, b], [x])
 
-    def perform(self, node, inputs, output_storage):
-        A, b = inputs
-
-        if self.assume_a != "gen":
-            # if self.transposed:
-            #     if self.assume_a == "her":
-            #         trans = "C"
-            #     else:
-            #         trans = "T"
-            # else:
-            #     trans = "N"
-
-            rval = scipy.linalg.solve_triangular(
-                A,
-                b,
-                lower=self.lower,
-                check_finite=self.check_finite,
-                # trans=trans
-            )
-        else:
-            rval = scipy.linalg.solve(
-                A,
-                b,
-                assume_a=self.assume_a,
-                lower=self.lower,
-                check_finite=self.check_finite,
-                # transposed=self.transposed,
-            )
-
-        output_storage[0][0] = rval
-
-    # computes shape of x where x = inv(A) * b
     def infer_shape(self, fgraph, node, shapes):
         Ashape, Bshape = shapes
         rows = Ashape[1]
-        if len(Bshape) == 1:  # b is a Vector
+        if len(Bshape) == 1:
             return [(rows,)]
         else:
-            cols = Bshape[1]  # b is a Matrix
+            cols = Bshape[1]
             return [(rows, cols)]
 
     def L_op(self, inputs, outputs, output_gradients):
-        r"""
-        Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
+        r"""Reverse-mode gradient updates for the matrix solve operation :math:`c = A^{-1} b`.
 
         Symbolic expression for updates taken from [#]_.
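The identities that `SolveBase.L_op` encodes can be checked numerically: for `c = solve(A, b)` and an upstream gradient `c_bar`, reverse mode gives `b_bar = solve(A.T, c_bar)` and `A_bar = -outer(b_bar, c)` when `b` is a vector. The following is a minimal NumPy/SciPy sketch, independent of the patch; the matrices, indices, and tolerance are illustrative only.

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3)) + 3.0 * np.eye(3)  # well-conditioned illustrative matrix
b = rng.normal(size=3)
c = scipy.linalg.solve(A, b)

c_bar = rng.normal(size=3)  # arbitrary upstream gradient
b_bar = scipy.linalg.solve(A.T, c_bar)  # solve against the transposed system
A_bar = -np.outer(b_bar, c)

# Finite-difference check of a single entry of A_bar
eps, i, j = 1e-6, 1, 2
A_pert = A.copy()
A_pert[i, j] += eps
fd = c_bar @ (scipy.linalg.solve(A_pert, b) - c) / eps
assert np.isclose(A_bar[i, j], fd, atol=1e-4)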
@@ -364,31 +325,148 @@ def L_op(self, inputs, outputs, output_gradients):
 
         # We need to return (dC/d[inv(A)], dC/db)
         c_bar = output_gradients[0]
 
-        trans_solve_op = Solve(
-            assume_a=self.assume_a,
-            check_finite=self.check_finite,
-            lower=not self.lower,
+        trans_solve_op = type(self)(
+            # Build a solver for the transposed system: every prop carries
+            # over except `lower`, which flips under transposition.
+            **{
+                k: (not getattr(self, k) if k == "lower" else getattr(self, k))
+                for k in self.__props__
+            }
         )
         b_bar = trans_solve_op(A.T, c_bar)
         # force outer product if vector second input
         A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
 
-        if self.assume_a != "gen":
-            if self.lower:
-                A_bar = aet.tril(A_bar)
-            else:
-                A_bar = aet.triu(A_bar)
-
         return [A_bar, b_bar]
 
+    def __repr__(self):
+        return f"{type(self).__name__}{self._props()}"
+
+
+class SolveTriangular(SolveBase):
+    """Solve the system ``a x = b``, where ``a`` is a triangular matrix."""
+
+    __props__ = (
+        "lower",
+        "trans",
+        "unit_diagonal",
+        "check_finite",
+    )
+
+    def __init__(
+        self,
+        trans=0,
+        lower=False,
+        unit_diagonal=False,
+        check_finite=True,
+    ):
+        super().__init__(lower=lower, check_finite=check_finite)
+        self.trans = trans
+        self.unit_diagonal = unit_diagonal
+
+    def perform(self, node, inputs, outputs):
+        A, b = inputs
+        outputs[0][0] = scipy.linalg.solve_triangular(
+            A,
+            b,
+            lower=self.lower,
+            trans=self.trans,
+            unit_diagonal=self.unit_diagonal,
+            check_finite=self.check_finite,
+        )
+
+    def L_op(self, inputs, outputs, output_gradients):
+        res = super().L_op(inputs, outputs, output_gradients)
+
+        # The gradient w.r.t. a triangular `A` only lives in its relevant triangle
+        if self.lower:
+            res[0] = aet.tril(res[0])
+        else:
+            res[0] = aet.triu(res[0])
+
+        return res
+
+
+solvetriangular = SolveTriangular()
+
+
+def solve_triangular(
+    a: TensorVariable,
+    b: TensorVariable,
+    trans: Union[int, str] = 0,
+    lower: bool = False,
+    unit_diagonal: bool = False,
+    check_finite: bool = True,
+) -> TensorVariable:
+    """Solve the equation ``a x = b`` for ``x``, assuming ``a`` is a triangular matrix.
+
+    Parameters
+    ----------
+    a
+        Square input data.
+    b
+        Input data for the right hand side.
+    lower : bool, optional
+        Use only data contained in the lower triangle of `a`. Default is to
+        use the upper triangle.
+    trans : {0, 1, 2, 'N', 'T', 'C'}, optional
+        Type of system to solve:
+
+        ========  =========
+        trans     system
+        ========  =========
+        0 or 'N'  a x = b
+        1 or 'T'  a^T x = b
+        2 or 'C'  a^H x = b
+        ========  =========
+    unit_diagonal : bool, optional
+        If True, diagonal elements of `a` are assumed to be 1 and will not be
+        referenced.
+    check_finite : bool, optional
+        Whether to check that the input matrices contain only finite numbers.
+        Disabling may give a performance gain, but may result in problems
+        (crashes, non-termination) if the inputs do contain infinities or NaNs.
+    """
+    return SolveTriangular(
+        lower=lower,
+        trans=trans,
+        unit_diagonal=unit_diagonal,
+        check_finite=check_finite,
+    )(a, b)
+
+
+class Solve(SolveBase):
+    """
+    Solve a system of linear equations.
+
+    For use on CPU and GPU.
+ """ + + __props__ = ( + "assume_a", + "lower", + "check_finite", + ) + + def __init__( + self, + assume_a="gen", + lower=False, + check_finite=True, + ): + if assume_a not in ("gen", "sym", "her", "pos"): + raise ValueError(f"{assume_a} is not a recognized matrix structure") + + super().__init__(lower=lower, check_finite=check_finite) + self.assume_a = assume_a + + def perform(self, node, inputs, outputs): + a, b = inputs + outputs[0][0] = scipy.linalg.solve( + a=a, + b=b, + lower=self.lower, + check_finite=self.check_finite, + assume_a=self.assume_a, + ) + solve = Solve() def solve(a, b, assume_a="gen", lower=False, check_finite=True): - """ - Solves the linear equation set ``a * x = b`` for the unknown ``x`` - for square ``a`` matrix. + """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix. If the data matrix is known to be a particular type then supplying the corresponding string to ``assume_a`` key chooses the dedicated solver. @@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True): # TODO: These are deprecated; emit a warning -solve_lower_triangular = Solve(assume_a="sym", lower=True) -solve_upper_triangular = Solve(assume_a="sym", lower=False) +solve_lower_triangular = SolveTriangular(lower=True) +solve_upper_triangular = SolveTriangular(lower=False) solve_symmetric = Solve(assume_a="sym") # TODO: Optimizations to replace multiplication by matrix inverse diff --git a/tests/link/test_numba.py b/tests/link/test_numba.py index 6845333a6b..0be2caab24 100644 --- a/tests/link/test_numba.py +++ b/tests/link/test_numba.py @@ -2174,6 +2174,31 @@ def test_Cholesky(x, lower, exc): "gen", None, ), + ], +) +def test_Solve(A, x, lower, exc): + g = slinalg.Solve(lower)(A, x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "A, x, lower, exc", + [ ( set_test_value( aet.dmatrix(), @@ -2185,8 +2210,8 @@ def test_Cholesky(x, lower, exc): ), ], ) -def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower)(A, x) +def test_SolveTriangular(A, x, lower, exc): + g = slinalg.SolveTriangular(lower)(A, x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index ab3b9aea66..a32d953785 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1,3 +1,4 @@ +import functools import itertools import numpy as np @@ -14,12 +15,15 @@ CholeskyGrad, CholeskySolve, Solve, + SolveBase, + SolveTriangular, cho_solve, cholesky, eigvalsh, expm, kron, solve, + solve_triangular, ) from aesara.tensor.type import dmatrix, matrix, tensor, vector from tests import unittest_tools as utt @@ -170,122 +174,107 @@ def test_eigvalsh_grad(): ) -class TestSolve(utt.InferShapeTester): - def setup_method(self): - self.op_class = Solve - self.op = Solve() - super().setup_method() - - def test_infer_shape(self): - rng = np.random.default_rng(utt.fetch_seed()) +class TestSolveBase(utt.InferShapeTester): + @pytest.mark.parametrize( + "A_func, b_func, error_message", + [ + (vector, matrix, "`A` must be a matrix.*"), + ( + functools.partial(tensor, dtype="floatX", broadcastable=(False,) * 3), + matrix, + "`A` must be a matrix.*", + ), + ( + matrix, + functools.partial(tensor, 
dtype="floatX", broadcastable=(False,) * 3), + "`b` must be a matrix or a vector.*", + ), + ], + ) + def test_make_node(self, A_func, b_func, error_message): + np.random.default_rng(utt.fetch_seed()) + with pytest.raises(ValueError, match=error_message): + A = A_func() + b = b_func() + SolveBase()(A, b) + + def test__repr__(self): + np.random.default_rng(utt.fetch_seed()) A = matrix() b = matrix() - self._compile_and_check( - [A, b], # aesara.function inputs - [self.op(A, b)], # aesara.function outputs - # A must be square - [ - np.asarray(rng.random((5, 5)), dtype=config.floatX), - np.asarray(rng.random((5, 1)), dtype=config.floatX), - ], - self.op_class, - warn=False, - ) + y = SolveBase()(A, b) + assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0" + + +class TestSolve(utt.InferShapeTester): + def test__init__(self): + with pytest.raises(ValueError) as excinfo: + Solve(assume_a="test") + assert "is not a recognized matrix structure" in str(excinfo.value) + + @pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) + def test_infer_shape(self, b_shape): rng = np.random.default_rng(utt.fetch_seed()) A = matrix() - b = vector() + b_val = np.asarray(rng.random(b_shape), dtype=config.floatX) + b = aet.as_tensor_variable(b_val).type() self._compile_and_check( - [A, b], # aesara.function inputs - [self.op(A, b)], # aesara.function outputs - # A must be square + [A, b], + [solve(A, b)], [ np.asarray(rng.random((5, 5)), dtype=config.floatX), - np.asarray(rng.random((5)), dtype=config.floatX), + b_val, ], - self.op_class, + Solve, warn=False, ) - def test_solve_correctness(self): + def test_correctness(self): rng = np.random.default_rng(utt.fetch_seed()) A = matrix() b = matrix() - y = self.op(A, b) + y = solve(A, b) gen_solve_func = aesara.function([A, b], y) - cholesky_lower = Cholesky(lower=True) - L = cholesky_lower(A) - y_lower = self.op(L, b) - lower_solve_func = aesara.function([L, b], y_lower) - - cholesky_upper = Cholesky(lower=False) - U = cholesky_upper(A) - y_upper = self.op(U, b) - upper_solve_func = aesara.function([U, b], y_upper) - b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) - # 1-test general case A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) - # positive definite matrix: A_val = np.dot(A_val.transpose(), A_val) + assert np.allclose( scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val) ) - # 2-test lower traingular case - L_val = scipy.linalg.cholesky(A_val, lower=True) - assert np.allclose( - scipy.linalg.solve_triangular(L_val, b_val, lower=True), - lower_solve_func(L_val, b_val), + A_undef = np.array( + [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 1], + [0, 0, 0, 1, 0], + ], + dtype=config.floatX, ) - - # 3-test upper traingular case - U_val = scipy.linalg.cholesky(A_val, lower=False) assert np.allclose( - scipy.linalg.solve_triangular(U_val, b_val, lower=False), - upper_solve_func(U_val, b_val), + scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val) ) - def test_solve_dtype(self): - dtypes = [ - "uint8", - "uint16", - "uint32", - "uint64", - "int8", - "int16", - "int32", - "int64", - "float16", - "float32", - "float64", - ] - - A_val = np.eye(2) - b_val = np.ones((2, 1)) - - # try all dtype combinations - for A_dtype, b_dtype in itertools.product(dtypes, dtypes): - A = matrix(dtype=A_dtype) - b = matrix(dtype=b_dtype) - x = solve(A, b) - fn = function([A, b], x) - x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype)) - - assert x.dtype == x_result.dtype + 
+    @pytest.mark.parametrize(
+        "m, n, assume_a, lower",
+        [
+            (5, None, "gen", False),
+            (5, None, "gen", True),
+            (4, 2, "gen", False),
+            (4, 2, "gen", True),
+        ],
+    )
+    def test_solve_grad(self, m, n, assume_a, lower):
+        rng = np.random.default_rng(utt.fetch_seed())
 
-    def verify_solve_grad(self, m, n, assume_a, lower, rng):
-        # ensure diagonal elements of A relatively large to avoid numerical
-        # precision issues
+        # Ensure diagonal elements of `A` are relatively large to avoid
+        # numerical precision issues
         A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
 
-        if assume_a != "gen":
-            if lower:
-                A_val = np.tril(A_val)
-            else:
-                A_val = np.triu(A_val)
-
         if n is None:
             b_val = rng.normal(size=m).astype(config.floatX)
         else:
@@ -298,22 +287,76 @@ def verify_solve_grad(self, m, n, assume_a, lower, rng):
         solve_op = Solve(assume_a=assume_a, lower=lower)
         utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
 
+
+class TestSolveTriangular(utt.InferShapeTester):
+    @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
+    def test_infer_shape(self, b_shape):
+        rng = np.random.default_rng(utt.fetch_seed())
+        A = matrix()
+        b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
+        b = aet.as_tensor_variable(b_val).type()
+        self._compile_and_check(
+            [A, b],
+            [solve_triangular(A, b)],
+            [
+                np.asarray(rng.random((5, 5)), dtype=config.floatX),
+                b_val,
+            ],
+            SolveTriangular,
+            warn=False,
+        )
+
+    @pytest.mark.parametrize("lower", [True, False])
+    def test_correctness(self, lower):
+        rng = np.random.default_rng(utt.fetch_seed())
+
+        b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
+
+        A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
+        A_val = np.dot(A_val.transpose(), A_val)
+
+        C_val = scipy.linalg.cholesky(A_val, lower=lower)
+
+        A = matrix()
+        b = matrix()
+
+        cholesky = Cholesky(lower=lower)
+        C = cholesky(A)
+        y = solve_triangular(C, b, lower=lower)
+        solve_func = aesara.function([C, b], y)
+
+        assert np.allclose(
+            scipy.linalg.solve_triangular(C_val, b_val, lower=lower),
+            solve_func(C_val, b_val),
+        )
 
     @pytest.mark.parametrize(
-        "m, n, assume_a, lower",
+        "m, n, lower",
         [
-            (5, None, "gen", False),
-            (5, None, "gen", True),
-            (4, 2, "gen", False),
-            (4, 2, "gen", True),
-            (5, None, "sym", False),
-            (5, None, "sym", True),
-            (4, 2, "sym", False),
-            (4, 2, "sym", True),
+            (5, None, False),
+            (5, None, True),
+            (4, 2, False),
+            (4, 2, True),
         ],
     )
-    def test_solve_grad(self, m, n, assume_a, lower):
+    def test_solve_grad(self, m, n, lower):
         rng = np.random.default_rng(utt.fetch_seed())
-        self.verify_solve_grad(m, n, assume_a, lower, rng)
+
+        # Ensure diagonal elements of `A` are relatively large to avoid
+        # numerical precision issues
+        A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX)
+
+        if n is None:
+            b_val = rng.normal(size=m).astype(config.floatX)
+        else:
+            b_val = rng.normal(size=(m, n)).astype(config.floatX)
+
+        eps = None
+        if config.floatX == "float64":
+            eps = 2e-8
+
+        solve_op = SolveTriangular(lower=lower)
+        utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
 
 
 class TestCholeskySolve(utt.InferShapeTester):
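With the patch applied, the new `solve_triangular` helper mirrors `scipy.linalg.solve_triangular`. A small usage sketch follows; the data below is illustrative and not taken from the test suite.

import numpy as np
import scipy.linalg
import aesara
import aesara.tensor as aet
from aesara.tensor.slinalg import solve_triangular

# Build a symbolic graph that solves L x = b for a lower-triangular L
L = aet.dmatrix("L")
b = aet.dmatrix("b")
x = solve_triangular(L, b, lower=True)
f = aesara.function([L, b], x)

rng = np.random.default_rng(42)
M = rng.normal(size=(4, 4))
L_val = np.linalg.cholesky(M @ M.T + 4.0 * np.eye(4))  # a lower-triangular factor
b_val = rng.normal(size=(4, 1))

assert np.allclose(
    f(L_val, b_val),
    scipy.linalg.solve_triangular(L_val, b_val, lower=True),
)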