Skip to content

Commit

Permalink
Extend log_softmax rewrite and run it in stabilize
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 20, 2023
1 parent 3fe07f3 commit 9aa9d06
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 52 deletions.
84 changes: 55 additions & 29 deletions pytensor/tensor/rewriting/special.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,75 @@
from pytensor import scalar as aes
from numpy.core.numeric import normalize_axis_index # type: ignore

from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Sum, exp
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Sum, exp, log
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.basic import register_stabilize
from pytensor.tensor.rewriting.math import local_mul_canonizer
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.subtensor import AdvancedIncSubtensor
from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
Subtensor,
)
from pytensor.tensor.type import (
values_eq_approx_remove_inf,
values_eq_approx_remove_nan,
)


# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
subtensor_ops = (
Subtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
)


@register_stabilize
@node_rewriter([log])
def local_logsoftmax(fgraph, node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
This also lifts Subtensor or Dimshuffle operations that could be in between log and softmax
Note: only forward pass is affected
"""
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1
and node.inputs[0].owner is not None
and isinstance(node.inputs[0].owner.op, Softmax)
):
inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
ret = new_op(inVars)
ret.tag.values_eq_approx = values_eq_approx_remove_inf
copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
return [ret]
def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
if inp_node is None:
return

if isinstance(inp_node.op, Softmax):
return inp_node

if isinstance(inp_node.op, subtensor_ops):
ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
return find_softmax_under_lifteable_ops(inp_node.inputs[0].owner, ops_to_lift)

if isinstance(inp_node.op, DimShuffle):
ops_to_lift.append((inp_node.op, ()))
return find_softmax_under_lifteable_ops(inp_node.inputs[0].owner, ops_to_lift)

ops_to_lift = []
softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)

if softmax_node is None:
return

ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
ret.tag.values_eq_approx = values_eq_approx_remove_inf

# Lift ops that used to be between log and softmax
for op_to_lift, parameters in reversed(ops_to_lift):
ret = op_to_lift(ret, *parameters)

copy_stack_trace(node.outputs, ret)
return [ret]


# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@register_stabilize
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
"""
Expand All @@ -50,9 +78,7 @@ def local_logsoftmax_grad(fgraph, node):
Note: only grad is affected
"""
if (
isinstance(node.op, SoftmaxGrad)
and len(node.inputs) == 2
and node.inputs[0].owner is not None
node.inputs[0].owner is not None
and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None
Expand Down
53 changes: 30 additions & 23 deletions tests/tensor/rewriting/test_special.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import scipy.special

import pytensor
from pytensor import shared
Expand All @@ -11,7 +12,8 @@
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.math import add, exp, log, true_div
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
from pytensor.tensor.type import matrix
from pytensor.tensor.subtensor import AdvancedSubtensor, group_indices
from pytensor.tensor.type import matrix, tensor
from tests import unittest_tools as utt


Expand All @@ -35,6 +37,33 @@ def test_local_logsoftmax_rewrite(self, axis):
_fast_run_rewrites.rewrite(fgraph)
assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)
assert check_stack_trace(fgraph, ops_to_check="all")

@pytest.mark.parametrize("axis", [None, 0, -1])
@pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None)])
@pytest.mark.parametrize("idx1", [None, [0, 1, 1, -1]])
def test_logsoftmax_subtensor_dimshuffle(self, axis, idx0, idx1):
"""Test that stabilization is introduced even when subtensor or dimshuffle operations
are present between log and softmax.
"""

logit_p = matrix("logit_p")
p = softmax(logit_p, axis=axis)
p_indexed = p[(idx0, idx1)]
out = log(p_indexed)

# Don't waste time with C compilation
with config.change_flags(cxx=""):
fn = pytensor.function([logit_p], out)

assert not any(isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes)

# This range would lead to underflow to -inf without the stabilization
test_logit_p = np.array([[-10., -10., 999.], [999., 990., -10.]])
np.testing.assert_allclose(
fn(logit_p=test_logit_p),
scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1)],
)

@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_rewrite(self, axis):
Expand Down Expand Up @@ -91,28 +120,6 @@ def test_logsoftmax_grad_true_div_elemwise(self):
assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]


def test_log_softmax_stabilization():
mode = pytensor.compile.mode.get_default_mode()
mode = mode.including("local_log_softmax", "specialize")

x = matrix()
y = softmax(x, axis=-1)
z = log(y)

fgraph = FunctionGraph([x], [z])
_fast_run_rewrites(fgraph)
assert check_stack_trace(fgraph, ops_to_check="all")

# Check that the softmax has been rewritten
for node in fgraph.toposort():
assert not isinstance(node.op, Softmax)

# Call the function so debug mode can verify the rewritten version matches
# the un-rewritten version
f = pytensor.function([x], z, mode=mode)
rng = np.random.default_rng(utt.fetch_seed())
f(np.cast[config.floatX](rng.random((2, 3))))


def test_softmax_graph():
"""Make sure that sotfmax expressions are turned into
Expand Down

0 comments on commit 9aa9d06

Please sign in to comment.