diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 5f2e8cf388..1de6dbb373 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -569,3 +569,45 @@ def svd_uv_merge(fgraph, node):
                 or len(fgraph.clients[cl.outputs[2]]) > 0
             ):
                 return [cl.outputs[1]]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Blockwise])
+def rewrite_inv_inv(fgraph, node):
+    """
+    This rewrite takes advantage of the fact that two consecutive inverse operations, inv(inv(input)), return the original input without having to compute either inverse.
+
+    Here, we check for direct inverse operations (inv/pinv) and allow any combination of these "inverse" nodes to be rewritten.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    valid_inverses = (MatrixInverse, MatrixPinv)
+    # Check if the outer operation is a valid inverse (either inv/pinv)
+    # If it is, we go on to look for an inner inverse operation
+    # If it is not a valid inverse, we do not apply this rewrite
+    if not isinstance(node.op.core_op, valid_inverses):
+        return None
+
+    potential_inner_inv = node.inputs[0].owner
+    if potential_inner_inv is None or potential_inner_inv.op is None:
+        return None
+
+    # Check if the inner op is a Blockwise of a valid inverse
+    if not (
+        potential_inner_inv
+        and isinstance(potential_inner_inv.op, Blockwise)
+        and isinstance(potential_inner_inv.op.core_op, valid_inverses)
+    ):
+        return None
+    return [potential_inner_inv.inputs[0]]
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 0bc064fe65..7353a82be0 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -10,6 +10,7 @@
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
+from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
@@ -554,3 +555,16 @@ def test_svd_uv_merge():
             assert node.op.compute_uv
             svd_counter += 1
     assert svd_counter == 1
+
+
+@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
+@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
+def test_inv_inv_rewrite(inv_op_1, inv_op_2):
+    def get_pt_function(x, op_name):
+        return getattr(pt.linalg, op_name)(x)
+
+    x = pt.matrix("x")
+    op1 = get_pt_function(x, inv_op_1)
+    op2 = get_pt_function(op1, inv_op_2)
+    rewritten_out = rewrite_graph(op2)
+    assert rewritten_out == x
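Not part of the diff: a minimal end-to-end sketch of how the new rewrite should show up in a compiled function, assuming the default FAST_RUN mode (where canonicalize/stabilize rewrites run). The variable names and the random test matrix are illustrative only.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse, MatrixPinv

x = pt.matrix("x")
out = pt.linalg.inv(pt.linalg.inv(x))

# After rewriting, inv(inv(x)) collapses to x, so the compiled graph
# should contain no inverse ops at all.
fn = pytensor.function([x], out)
core_ops = [
    node.op.core_op if isinstance(node.op, Blockwise) else node.op
    for node in fn.maker.fgraph.toposort()
]
assert not any(isinstance(op, (MatrixInverse, MatrixPinv)) for op in core_ops)

# Numerically, the function is just the identity on its input.
a = np.random.default_rng(0).normal(size=(3, 3))
np.testing.assert_allclose(fn(a), a)
```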