
Commit fde3e37

Add rewrite to remove transposed kwargs
1 parent 0af2102 commit fde3e37

4 files changed: 99 additions, 17 deletions

pytensor/tensor/rewriting/linalg.py

Lines changed: 45 additions & 1 deletion
@@ -49,6 +49,7 @@
     Cholesky,
     Solve,
     SolveBase,
+    SolveTriangular,
     _bilinear_solve_discrete_lyapunov,
     block_diag,
     cholesky,
@@ -957,7 +958,8 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
 @node_rewriter([det])
 def slogdet_specialization(fgraph, node):
     """
-    This rewrite targets specific operations related to slogdet i.e sign(det), log(det) and log(abs(det)) and rewrites them using the SLogDet operation.
+    This rewrite targets specific operations related to slogdet, i.e. sign(det), log(det) and log(abs(det)), and
+    rewrites them using the SLogDet operation.
 
     Parameters
     ----------
@@ -1013,3 +1015,45 @@ def slogdet_specialization(fgraph, node):
         k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
     }
     return replacements
+
+
+@register_specialize
+@node_rewriter([Blockwise])
+def rewrite_transposed_argument_to_symbolic(fgraph, node):
+    """
+    Replace solve(A, b, transposed=True) with solve(A.T, b, transposed=False). In testing this was slightly faster,
+    and it results in more readable graphs, because the transpose Op is explicitly included.
+    """
+    solve_op = node.op.core_op
+    if not isinstance(solve_op, Solve | SolveTriangular):
+        return None
+
+    solve_kwargs = solve_op._props_dict()
+    A, b = node.inputs
+
+    if isinstance(solve_op, Solve):
+        if not solve_kwargs["transposed"]:
+            return None
+
+        solve_kwargs["transposed"] = False
+        solve_kwargs["lower"] = not solve_kwargs["lower"]
+        return [Blockwise(Solve(**solve_kwargs))(A.T, b)]
+
+    if isinstance(solve_op, SolveTriangular):
+        if solve_kwargs["trans"] in [False, 0]:
+            return None
+
+        elif solve_kwargs["trans"] in [True, 1, "T"]:
+            solve_kwargs["trans"] = 0
+            solve_kwargs["lower"] = not solve_kwargs["lower"]
+            return [Blockwise(SolveTriangular(**solve_kwargs))(A.T, b)]
+
+        elif solve_kwargs["trans"] in [2, "C"]:
+            solve_kwargs["trans"] = 0
+            solve_kwargs["lower"] = not solve_kwargs["lower"]
+            return [Blockwise(SolveTriangular(**solve_kwargs))(A.conj().T, b)]
+
+        else:
+            return None
+
+    return None
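
A minimal numerical sketch of the identity this rewrite relies on (my illustration, not part of the commit; assumes only NumPy and SciPy): solving with a transpose flag gives the same answer as solving against an explicit transpose, with lower flipped in the triangular case.

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5)) + 5 * np.eye(5)  # well-conditioned test matrix
b = rng.normal(size=5)

# General solve: both spellings solve A^T x = b
x1 = linalg.solve(A, b, transposed=True)
x2 = linalg.solve(A.T, b)
np.testing.assert_allclose(x1, x2)

# Triangular solve: transposing moves the data to the other triangle,
# which is why the rewrite flips lower alongside trans
U = np.triu(A)
y1 = linalg.solve_triangular(U, b, trans=1, lower=False)
y2 = linalg.solve_triangular(U.T, b, trans=0, lower=True)
np.testing.assert_allclose(y1, y2)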

pytensor/tensor/slinalg.py

Lines changed: 11 additions & 11 deletions
@@ -284,6 +284,13 @@ def L_op(self, inputs, outputs, output_gradients):
 
         Symbolic expression for updates taken from [#]_.
 
+        .. math::
+
+            \begin{align}
+                \bar{b} &= A^{-T} \bar{c} \\
+                \bar{A} &= -A^{-T} \bar{c} c^T = -\bar{b} c^T
+            \end{align}
+
         References
         ----------
         .. [#] M. B. Giles, "An extended collection of matrix derivative results
@@ -294,31 +301,24 @@
         A, b = inputs
 
         c = outputs[0]
-        # C is a scalar representing the entire graph
-        # `output_gradients` is (dC/dc,)
-        # We need to return (dC/d[inv(A)], dC/db)
         c_bar = output_gradients[0]
 
         props_dict = self._props_dict()
-
+        # Note that we do *not* reverse the lower argument here. This is because LAPACK handles this internally when
+        # transposed=True. If we try to "help", it will end up solving with the wrong triangle.
         if isinstance(self, SolveTriangular):
             # SolveTriangular has a special trans argument we have to handle
-            transposed = props_dict.pop("trans") in [1, "T"]
+            transposed = props_dict.pop("trans") in [1, "T", 2, "C"]
             props_dict["trans"] = not transposed
         else:
             transposed = props_dict.pop("transposed")
             props_dict["transposed"] = not transposed
 
-        # TODO: We were flipping lower before, but it doesn't appear we need to -- all tests pass without taking it into
-        # account.
-        # props_dict['lower'] = not self.lower
-
         solve_op = type(self)(**props_dict)
-
         b_bar = solve_op(A, c_bar)
+
         # force outer product if vector second input
         A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
-
         if transposed:
             A_bar = A_bar.T
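
The adjoint formulas in the new docstring can be checked with finite differences. A standalone sketch (mine, not from the commit): for the scalar loss L = c_bar . solve(A, b), the updates should be b_bar = A^{-T} c_bar and A_bar = -b_bar c^T.

import numpy as np
from scipy import linalg

rng = np.random.default_rng(1)
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)
b = rng.normal(size=4)
c_bar = rng.normal(size=4)  # arbitrary upstream gradient dL/dc

c = linalg.solve(A, b)
b_bar = linalg.solve(A.T, c_bar)  # b_bar = A^{-T} c_bar
A_bar = -np.outer(b_bar, c)       # A_bar = -b_bar c^T

def loss(A_, b_):
    return c_bar @ linalg.solve(A_, b_)

eps = 1e-6
# central differences for dL/db, one basis vector at a time
fd_b = np.array([(loss(A, b + eps * e) - loss(A, b - eps * e)) / (2 * eps) for e in np.eye(4)])
np.testing.assert_allclose(fd_b, b_bar, rtol=1e-4)

# spot-check a single entry of dL/dA
E = np.zeros_like(A)
E[1, 2] = 1.0
fd_A = (loss(A + eps * E, b) - loss(A - eps * E, b)) / (2 * eps)
np.testing.assert_allclose(fd_A, A_bar[1, 2], rtol=1e-4)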

tests/tensor/rewriting/test_linalg.py

Lines changed: 33 additions & 0 deletions
@@ -10,6 +10,7 @@
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
+from pytensor.graph import FunctionGraph
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
 from pytensor.tensor.blockwise import Blockwise
@@ -993,3 +994,35 @@ def test_slogdet_specialization():
     f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
     nodes = f.maker.fgraph.apply_nodes
     assert not any(isinstance(node.op, SLogDet) for node in nodes)
+
+
+@pytest.mark.parametrize(
+    "solve_op, kwargs",
+    [
+        (pt.linalg.solve, {"transposed": True, "lower": False, "assume_a": "sym"}),
+        (pt.linalg.solve_triangular, {"trans": 1, "lower": False}),
+    ],
+)
+def test_rewrite_away_transposed_argument(solve_op, kwargs):
+    A = pt.tensor("A", shape=(5, 5))
+    b = pt.tensor("b", shape=(5,))
+    x = solve_op(A, b, **kwargs)
+
+    fg = FunctionGraph([A, b], [x])
+    assert not any(isinstance(node.op, DimShuffle) for node in fg.toposort())
+
+    f = function([A, b], x, mode="FAST_RUN")
+    assert any(isinstance(node.op, DimShuffle) for node in f.maker.fgraph.toposort())
+
+    A_val = np.triu(np.random.normal(size=(5, 5)))
+    b_val = np.random.normal(size=(5,))
+
+    new_kwargs = kwargs.copy()
+    if "transposed" in kwargs:
+        new_kwargs["transposed"] = not kwargs["transposed"]
+    if "trans" in kwargs:
+        new_kwargs["trans"] = int(not bool(kwargs["trans"]))
+    new_kwargs["lower"] = not kwargs["lower"]
+
+    g = function([A, b], solve_op(A.T, b, **new_kwargs), mode="FAST_COMPILE")
+    np.testing.assert_allclose(f(A_val, b_val), g(A_val, b_val))
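
A short usage sketch (assuming the rewrite is registered under the default FAST_RUN rewrites, as its @register_specialize decorator suggests): printing a compiled graph should show a DimShuffle (transpose) of A feeding a Solve with transposed=False, mirroring what the test asserts.

import pytensor
from pytensor import tensor as pt

A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=(5,))
x = pt.linalg.solve(A, b, transposed=True, lower=False, assume_a="sym")

f = pytensor.function([A, b], x, mode="FAST_RUN")
pytensor.dprint(f)  # expect an explicit transpose of A and no transposed=True solve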

tests/tensor/test_slinalg.py

Lines changed: 10 additions & 5 deletions
@@ -332,13 +332,18 @@ def test_solve_gradient(
 
     def A_func(x):
         if assume_a == "pos":
-            return x @ x.T
+            x = x @ x.T
         elif assume_a == "sym":
-            return (x + x.T) / 2
-        else:
-            return x
+            x = (x + x.T) / 2
+        return x
 
-    solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
+    solve_op = functools.partial(
+        solve,
+        assume_a=assume_a,
+        lower=lower,
+        transposed=transposed,
+        b_ndim=len(b_size),
+    )
 
     # To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
     # (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
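
The closing comment describes differentiating through a constraint map, so that every perturbed matrix reaching the solver is still valid for its assume_a. A standalone NumPy sketch of the same idea (mine, not from the commit), for the assume_a="sym" case:

import numpy as np
from scipy import linalg

rng = np.random.default_rng(2)
x0 = rng.normal(size=(3, 3)) + 3 * np.eye(3)  # unconstrained parameter
b = rng.normal(size=3)

def A_func(x):
    # maps any square matrix to a symmetric one, as in the test helper
    return (x + x.T) / 2

def loss(x):
    # perturbations of x still reach the solver as valid symmetric matrices
    return linalg.solve(A_func(x), b, assume_a="sym").sum()

eps = 1e-6
grad = np.zeros_like(x0)
for i in range(3):
    for j in range(3):
        E = np.zeros_like(x0)
        E[i, j] = eps
        grad[i, j] = (loss(x0 + E) - loss(x0 - E)) / (2 * eps)
# grad now holds the finite-difference gradient w.r.t. the unconstrained x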
