Skip to content

Commit

Permalink
Add a SolveTriangular Op
Browse files Browse the repository at this point in the history
`Solve` has also been changed to match SciPy.
  • Loading branch information
fshart authored and brandonwillard committed Dec 13, 2021
1 parent 6fce270 commit 79961a6
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 169 deletions.
220 changes: 149 additions & 71 deletions aesara/tensor/slinalg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import warnings
from typing import Union

import numpy as np
import scipy.linalg
Expand All @@ -11,6 +12,7 @@
from aesara.tensor import basic as aet
from aesara.tensor import math as atm
from aesara.tensor.type import matrix, tensor, vector
from aesara.tensor.var import TensorVariable


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True):
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)


class Solve(Op):
"""
Solve a system of linear equations.
For on CPU and GPU.
"""
class SolveBase(Op):
"""Base class for `scipy.linalg` matrix equation solvers."""

__props__ = (
"assume_a",
"lower",
"check_finite", # "transposed"
"check_finite",
)

def __init__(
self,
assume_a="gen",
lower=False,
check_finite=True, # transposed=False
check_finite=True,
):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
self.assume_a = assume_a
self.lower = lower
self.check_finite = check_finite
# self.transposed = transposed

def __repr__(self):
return "Solve{%s}" % str(self._props())
def perform(self, node, inputs, outputs):
pass

def make_node(self, A, b):
A = as_tensor_variable(A)
b = as_tensor_variable(b)
assert A.ndim == 2
assert b.ndim in [1, 2]

# infer dtype by solving the most simple
# case with (1, 1) matrices
if A.ndim != 2:
raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
if b.ndim not in [1, 2]:
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")

# Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve(
np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
).dtype
x = tensor(broadcastable=b.broadcastable, dtype=o_dtype)
return Apply(self, [A, b], [x])

def perform(self, node, inputs, output_storage):
A, b = inputs

if self.assume_a != "gen":
# if self.transposed:
# if self.assume_a == "her":
# trans = "C"
# else:
# trans = "T"
# else:
# trans = "N"

rval = scipy.linalg.solve_triangular(
A,
b,
lower=self.lower,
check_finite=self.check_finite,
# trans=trans
)
else:
rval = scipy.linalg.solve(
A,
b,
assume_a=self.assume_a,
lower=self.lower,
check_finite=self.check_finite,
# transposed=self.transposed,
)

output_storage[0][0] = rval

# computes shape of x where x = inv(A) * b
def infer_shape(self, fgraph, node, shapes):
Ashape, Bshape = shapes
rows = Ashape[1]
if len(Bshape) == 1: # b is a Vector
if len(Bshape) == 1:
return [(rows,)]
else:
cols = Bshape[1] # b is a Matrix
cols = Bshape[1]
return [(rows, cols)]

def L_op(self, inputs, outputs, output_gradients):
r"""
Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
Symbolic expression for updates taken from [#]_.
Expand All @@ -364,31 +325,148 @@ def L_op(self, inputs, outputs, output_gradients):
# We need to return (dC/d[inv(A)], dC/db)
c_bar = output_gradients[0]

trans_solve_op = Solve(
assume_a=self.assume_a,
check_finite=self.check_finite,
lower=not self.lower,
trans_solve_op = type(self)(
**{
k: (not getattr(self, k) if k == "lower" else getattr(self, k))
for k in self.__props__
}
)
b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)

if self.assume_a != "gen":
if self.lower:
A_bar = aet.tril(A_bar)
else:
A_bar = aet.triu(A_bar)

return [A_bar, b_bar]

def __repr__(self):
return f"{type(self).__name__}{self._props()}"


class SolveTriangular(SolveBase):
    """Solve a linear system whose coefficient matrix is triangular.

    The numerical work is delegated to `scipy.linalg.solve_triangular`.

    Parameters
    ----------
    trans
        Forwarded as the ``trans`` argument of
        `scipy.linalg.solve_triangular` (``0``/``'N'``, ``1``/``'T'`` or
        ``2``/``'C'``).
    lower
        Forwarded as ``lower``; selects which triangle of the matrix is used.
    unit_diagonal
        Forwarded as ``unit_diagonal``.
    check_finite
        Forwarded as ``check_finite``.
    """

    __props__ = (
        "lower",
        "trans",
        "unit_diagonal",
        "check_finite",
    )

    def __init__(
        self,
        trans=0,
        lower=False,
        unit_diagonal=False,
        check_finite=True,
    ):
        # The base class stores the options shared by every solver.
        super().__init__(lower=lower, check_finite=check_finite)
        self.unit_diagonal = unit_diagonal
        self.trans = trans

    def perform(self, node, inputs, outputs):
        """Compute the solution numerically and store it in ``outputs[0][0]``."""
        matrix_value, rhs_value = inputs
        solution = scipy.linalg.solve_triangular(
            matrix_value,
            rhs_value,
            trans=self.trans,
            lower=self.lower,
            unit_diagonal=self.unit_diagonal,
            check_finite=self.check_finite,
        )
        outputs[0][0] = solution

    def L_op(self, inputs, outputs, output_gradients):
        """Return the base-class gradients with the first one made triangular.

        The first gradient returned by the parent `L_op` is masked with
        `aet.tril` when ``self.lower`` is true and with `aet.triu` otherwise;
        the remaining gradients are passed through unchanged.
        """
        gradients = super().L_op(inputs, outputs, output_gradients)

        keep_triangle = aet.tril if self.lower else aet.triu
        gradients[0] = keep_triangle(gradients[0])

        return gradients


# Module-level `SolveTriangular` instance built with the constructor defaults
# (`trans=0`, `lower=False`, `unit_diagonal=False`, `check_finite=True`).
solvetriangular = SolveTriangular()


def solve_triangular(
    a: TensorVariable,
    b: TensorVariable,
    trans: Union[int, str] = 0,
    lower: bool = False,
    unit_diagonal: bool = False,
    check_finite: bool = True,
) -> TensorVariable:
    """Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.

    Parameters
    ----------
    a
        Square input data.
    b
        Input data for the right hand side.
    trans : {0, 1, 2, 'N', 'T', 'C'}, optional
        Type of system to solve:

        ========  =========
        trans     system
        ========  =========
        0 or 'N'  a x = b
        1 or 'T'  a^T x = b
        2 or 'C'  a^H x = b
        ========  =========
    lower : bool, optional
        Use only data contained in the lower triangle of `a`. Default is to
        use the upper triangle.
    unit_diagonal : bool, optional
        If True, diagonal elements of `a` are assumed to be 1 and will not be
        referenced.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    TensorVariable
        The output of a `SolveTriangular` `Op`, configured with the given
        options, applied to `a` and `b`.
    """
    triangular_op = SolveTriangular(
        trans=trans,
        lower=lower,
        unit_diagonal=unit_diagonal,
        check_finite=check_finite,
    )
    return triangular_op(a, b)


class Solve(SolveBase):
    """Solve a system of linear equations with `scipy.linalg.solve`.

    Parameters
    ----------
    assume_a
        Structure assumed for the coefficient matrix; one of ``"gen"``,
        ``"sym"``, ``"her"`` or ``"pos"``.  Forwarded as the ``assume_a``
        argument of `scipy.linalg.solve`.
    lower
        Forwarded as the ``lower`` argument of `scipy.linalg.solve`.
    check_finite
        Forwarded as the ``check_finite`` argument of `scipy.linalg.solve`.

    Raises
    ------
    ValueError
        If `assume_a` is not one of the recognized structure names.
    """

    __props__ = (
        "assume_a",
        "lower",
        "check_finite",
    )

    def __init__(
        self,
        assume_a="gen",
        lower=False,
        check_finite=True,
    ):
        # Reject unknown structure names before any state is set.
        recognized_structures = ("gen", "sym", "her", "pos")
        if assume_a not in recognized_structures:
            raise ValueError(f"{assume_a} is not a recognized matrix structure")

        super().__init__(lower=lower, check_finite=check_finite)
        self.assume_a = assume_a

    def perform(self, node, inputs, outputs):
        """Compute the solution numerically and store it in ``outputs[0][0]``."""
        matrix_value, rhs_value = inputs
        solution = scipy.linalg.solve(
            matrix_value,
            rhs_value,
            assume_a=self.assume_a,
            lower=self.lower,
            check_finite=self.check_finite,
        )
        outputs[0][0] = solution


# Module-level `Solve` instance built with the constructor defaults.
# NOTE(review): the name `solve` is rebound by the `solve` function defined
# right after this assignment, so this instance does not appear to stay
# reachable under that name -- confirm whether the assignment is still needed.
solve = Solve()


def solve(a, b, assume_a="gen", lower=False, check_finite=True):
"""
Solves the linear equation set ``a * x = b`` for the unknown ``x``
for square ``a`` matrix.
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the
corresponding string to ``assume_a`` key chooses the dedicated solver.
Expand Down Expand Up @@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):


# TODO: These are deprecated; emit a warning
solve_lower_triangular = Solve(assume_a="sym", lower=True)
solve_upper_triangular = Solve(assume_a="sym", lower=False)
solve_lower_triangular = SolveTriangular(lower=True)
solve_upper_triangular = SolveTriangular(lower=False)
solve_symmetric = Solve(assume_a="sym")

# TODO: Optimizations to replace multiplication by matrix inverse
Expand Down
29 changes: 27 additions & 2 deletions tests/link/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -2174,6 +2174,31 @@ def test_Cholesky(x, lower, exc):
"gen",
None,
),
],
)
def test_Solve(A, x, lower, exc):
    """Build a `Solve` graph and hand it to `compare_numba_and_py`.

    When `exc` is not ``None`` the comparison runs inside
    ``pytest.warns(exc)``; otherwise it runs without any extra checks.
    """
    # NOTE(review): `lower` is passed positionally, so it fills the first
    # constructor parameter of `slinalg.Solve` -- confirm that this is the
    # parameter the test intends to set.
    g = slinalg.Solve(lower)(A, x)

    # `FunctionGraph` is always given a list of outputs.
    graph_outputs = g if isinstance(g, list) else [g]
    g_fg = FunctionGraph(outputs=graph_outputs)

    if exc is None:
        # An empty `suppress()` acts as a do-nothing context manager.
        cm = contextlib.suppress()
    else:
        cm = pytest.warns(exc)

    with cm:
        # Shared variables and constants are skipped when collecting the
        # test values handed to the comparison helper.
        test_values = [
            graph_input.tag.test_value
            for graph_input in g_fg.inputs
            if not isinstance(graph_input, (SharedVariable, Constant))
        ]
        compare_numba_and_py(g_fg, test_values)


@pytest.mark.parametrize(
"A, x, lower, exc",
[
(
set_test_value(
aet.dmatrix(),
Expand All @@ -2185,8 +2210,8 @@ def test_Cholesky(x, lower, exc):
),
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x)
def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower)(A, x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
Expand Down
Loading

0 comments on commit 79961a6

Please sign in to comment.