Skip to content

Commit

Permalink
working specialised rewrite + test
Browse files Browse the repository at this point in the history
  • Loading branch information
tanish1729 committed Nov 1, 2024
1 parent 6767600 commit 816cc77
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
26 changes: 16 additions & 10 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from collections.abc import Callable
from typing import cast

import numpy as np

from pytensor import Variable
from pytensor import tensor as pt
from pytensor.graph import Apply, FunctionGraph
Expand Down Expand Up @@ -967,23 +969,24 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)]


# SLogDet Rewrites
def check_log_abs_det(fgraph, client):
def _check_log_abs_det(fgraph, client):
# First, we find abs
if not (isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)):
return False

# Check whether log is a client of abs
for client_2 in fgraph.clients[client.outputs[0]]:
if not (
isinstance(client_2.op, Elemwise) and isinstance(client_2.op.scalar_op, Log)
isinstance(client_2[0].op, Elemwise)
and isinstance(client_2[0].op.scalar_op, Log)
):
return False

return True


@node_rewriter(tracks=[det])
@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
replacements = {}
for client in fgraph.clients[node.outputs[0]]:
Expand All @@ -996,19 +999,22 @@ def slogdet_specialization(fgraph, node):
replacements[client[0].outputs[0]] = sign_det_x

# Check for log(abs(det))
elif check_log_abs_det(fgraph, client[0]):
elif _check_log_abs_det(fgraph, client[0]):
x = node.inputs[0]
sign_det_x, slog_det_x = SLogDet()(x)
replacements[fgraph.clients[client[0].outputs[0]][0][0].outputs[0]] = (
slog_det_x
)

# Check for log(det)
# elif isinstance(client[0].op, Elemwise) and isinstance(
# client[0].op.scalar_op, Log
# ):
# pass
# replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
elif isinstance(client[0].op, Elemwise) and isinstance(
client[0].op.scalar_op, Log
):
x = node.inputs[0]
sign_det_x, slog_det_x = SLogDet()(x)
replacements[client[0].outputs[0]] = pt.where(
pt.eq(sign_det_x, -1), np.nan, slog_det_x
)

# Det is used directly for something else, don't rewrite to avoid computing two dets
else:
Expand Down
21 changes: 21 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
matrix_inverse,
svd,
)
Expand Down Expand Up @@ -900,3 +901,23 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)


def test_slogdet_specialisation():
x = pt.dmatrix("x")
det_x = pt.linalg.det(x)
log_abs_det_x = pt.log(pt.abs(det_x))
sign_det_x = pt.sign(det_x)
exp_det_x = pt.exp(det_x)
# sign(det(x))
f = function([x], [sign_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert any(isinstance(node.op, SLogDet) for node in nodes)
# log(abs(det(x)))
f = function([x], [log_abs_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert any(isinstance(node.op, SLogDet) for node in nodes)
# other functions (rewrite shouldnt be applied to these)
f = function([x], [exp_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, SLogDet) for node in nodes)

0 comments on commit 816cc77

Please sign in to comment.