
Added rewrite for matrix inv(inv(x)) -> x #893


Merged (18 commits) on Jul 19, 2024
42 changes: 42 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -569,3 +569,45 @@ def svd_uv_merge(fgraph, node):
or len(fgraph.clients[cl.outputs[2]]) > 0
):
return [cl.outputs[1]]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
"""
This rewrite takes advantage of the fact that two consecutive inverse operations, inv(inv(input)), return the original input without either inverse having to be computed.

Here, we check for direct inverse operations (inv/pinv) and allow any combination of these "inverse" nodes to be rewritten away.

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
valid_inverses = (MatrixInverse, MatrixPinv)
# Check if it's a valid inverse operation (either inv or pinv)
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
# If the outer operation is not a valid inverse, we do not apply this rewrite
if not isinstance(node.op.core_op, valid_inverses):
return None
Comment on lines +576 to +600
Member
If we predefine the two possible pinv Ops as Blockwise (like we do for matrix_inverse), the rewrite can track them more specifically and avoid being called for every Blockwise it sees:

Suggested change

-@node_rewriter([Blockwise])
+@node_rewriter([matrix_inverse, matrix_pinv_hermitian, matrix_pinv_non_hermitian])
 def rewrite_inv_inv(fgraph, node):
     ...

(The docstring is unchanged; the suggestion also drops the valid_inverses isinstance check at the top of the body, since the rewriter would then only fire on those specific Ops.)

Need to predefine those matrix_pinv* Ops. The helper "pinv" should return the predefined Ops instead of creating new ones, to avoid Op duplication.

@tanish1729 (Contributor Author), Jul 19, 2024
I don't quite understand the last line you wrote about:

Need to predefine those matrix_pinv* Ops. The helper "pinv" should return the predefined Ops instead of creating new ones, to avoid Op duplication.

Also, as a general rule for rewrites, is it better that they track more specific Ops instead of just Blockwise?

Member

Also, as a general rule for rewrites, is it better that they track more specific Ops instead of just Blockwise?

Is that a question? The answer is yes. It avoids useless calls to the rewrite function when the Op is not in the graph.

Contributor Author

But for the check of the inner op, I will have to use the method I am already using.

@ricardoV94 (Member), Jul 19, 2024

Regarding pinv, the helper function is here:

def pinv(x, hermitian=False):

The idea is instead:

pinv_hermitian = Blockwise(Pinv(hermitian=True))
pinv_non_hermitian = Blockwise(Pinv(hermitian=False))

def pinv(x, hermitian=False):
  ...
  return pinv_hermitian(x) if hermitian else pinv_non_hermitian(x)
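The instance-reuse idea above can be illustrated without pytensor. This is a minimal, hypothetical sketch (the class and names below are illustrative, not the real pytensor API): a rewriter that tracks Ops by instance only fires when the graph holds the exact predefined object, so the helper must return the shared instances rather than construct fresh ones.

```python
class Pinv:
    """Stand-in for an Op parametrized by a flag (illustrative only)."""

    def __init__(self, hermitian=False):
        self.hermitian = hermitian


# Predefined singletons, mirroring the suggested pinv_hermitian / pinv_non_hermitian
PINV_HERMITIAN = Pinv(hermitian=True)
PINV_NON_HERMITIAN = Pinv(hermitian=False)


def pinv_op(hermitian=False):
    # Returning the predefined instance avoids duplicate Op objects,
    # so instance-based tracking by a rewriter can match reliably.
    return PINV_HERMITIAN if hermitian else PINV_NON_HERMITIAN


# Every call reuses one instance, so identity checks succeed
assert pinv_op(True) is pinv_op(True)
# A naive helper that did `return Pinv(hermitian)` would break this:
assert Pinv(True) is not Pinv(True)
```

This is also why the reviewer notes a cost: the singletons are built at import time, which adds a little initialization work.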

Member

But actually let's not do that now; I have to think. This makes initialization a bit slower because we have to create more instances...

Contributor Author

Oh, alright. But this is essentially the same thing as first checking that the op is a Blockwise and then that the core op is one of the valid inverses. What's the difference between these two ideas?

Member

The difference is that the rewrite is not even considered if the node is not a Blockwise(Pinv). With the current approach, the rewrite is attempted on every single Blockwise in the graph.

It's just an optimization, not a question of correctness


potential_inner_inv = node.inputs[0].owner
if potential_inner_inv is None or potential_inner_inv.op is None:
return None

# Check if the inner op is a Blockwise and a possible inverse
if not (
potential_inner_inv
and isinstance(potential_inner_inv.op, Blockwise)
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
):
return None
return [potential_inner_inv.inputs[0]]
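The pattern match performed by the rewrite can be sketched without pytensor. The following is a self-contained toy model (the Var/Apply/Inv/Pinv classes are illustrative stand-ins, not the real pytensor types): collapse a double inversion by returning the inner node's input.

```python
class Var:
    """A variable in a tiny expression graph."""

    def __init__(self, name):
        self.name = name
        self.owner = None  # a leaf variable has no owner Apply node


class Apply:
    """The node that records which op produced a variable from its input."""

    def __init__(self, op, inp):
        self.op = op
        self.inputs = [inp]


class Inv: ...
class Pinv: ...

VALID_INVERSES = (Inv, Pinv)


def apply_op(op, inp):
    out = Var(f"{type(op).__name__.lower()}({inp.name})")
    out.owner = Apply(op, inp)
    return out


def rewrite_inv_inv(node):
    # The outer node must itself be an inverse
    if not isinstance(node.op, VALID_INVERSES):
        return None
    inner = node.inputs[0].owner
    # Its input must also be produced by an inverse
    if inner is None or not isinstance(inner.op, VALID_INVERSES):
        return None
    return inner.inputs[0]  # skip both inversions


x = Var("x")
out = apply_op(Inv(), apply_op(Pinv(), x))
assert rewrite_inv_inv(out.owner) is x  # inv(pinv(x)) -> x
```

As in the PR, any combination of the two inverse ops is collapsed, and the rewrite declines (returns None) when the inner input is not itself an inverse.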
14 changes: 14 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
@@ -10,6 +10,7 @@
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
@@ -554,3 +555,16 @@ def test_svd_uv_merge():
assert node.op.compute_uv
svd_counter += 1
assert svd_counter == 1


@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
def get_pt_function(x, op_name):
return getattr(pt.linalg, op_name)(x)

x = pt.matrix("x")
op1 = get_pt_function(x, inv_op_1)
op2 = get_pt_function(op1, inv_op_2)
rewritten_out = rewrite_graph(op2)
assert rewritten_out == x
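The algebraic identity the parametrized test exercises can also be checked numerically. A quick NumPy sanity check (not part of the PR; note that inv/pinv combinations coincide only when the matrix is invertible, so a well-conditioned matrix is constructed here):

```python
import numpy as np

# For an invertible matrix A: inv(inv(A)) == A and pinv(pinv(A)) == A.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4)) + 4 * np.eye(4)  # diagonally boosted, invertible

assert np.allclose(np.linalg.inv(np.linalg.inv(A)), A)
assert np.allclose(np.linalg.pinv(np.linalg.pinv(A)), A)
# Mixed combinations, as covered by the inv_op_1/inv_op_2 parametrization, also hold
assert np.allclose(np.linalg.inv(np.linalg.pinv(A)), A)
assert np.allclose(np.linalg.pinv(np.linalg.inv(A)), A)
```

This is the numerical fact that makes it safe for the rewrite to drop both inversions entirely.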