From a31e8279786644e0626c672f07de90440c32b92e Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 20 Jun 2023 14:12:47 +0200
Subject: [PATCH] Extend log_softmax rewrite and run it in `stabilize`

---
 pytensor/tensor/rewriting/special.py   | 87 +++++++++++++++++---------
 tests/tensor/rewriting/test_special.py | 57 ++++++++++-------
 2 files changed, 91 insertions(+), 53 deletions(-)

diff --git a/pytensor/tensor/rewriting/special.py b/pytensor/tensor/rewriting/special.py
index 78dc5390ab..c893439e4d 100644
--- a/pytensor/tensor/rewriting/special.py
+++ b/pytensor/tensor/rewriting/special.py
@@ -1,47 +1,78 @@
-from pytensor import scalar as aes
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.math import Sum, exp, log
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.basic import register_stabilize
 from pytensor.tensor.rewriting.math import local_mul_canonizer
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )
 
 
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)
+
+
+@register_stabilize
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
 
+    This also lifts Subtensor or DimShuffle operations that appear between the log and the softmax.
+
     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
-    ):
-        inVars = node.inputs[0].owner.inputs[0]
-        new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
-        ret = new_op(inVars)
-        ret.tag.values_eq_approx = values_eq_approx_remove_inf
-        copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
-        return [ret]
+
+    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
+        if inp_node is None:
+            return
+
+        if isinstance(inp_node.op, Softmax):
+            return inp_node
+
+        if isinstance(inp_node.op, subtensor_ops):
+            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+        if isinstance(inp_node.op, DimShuffle):
+            ops_to_lift.append((inp_node.op, ()))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+    ops_to_lift = []
+    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)
+
+    if softmax_node is None:
+        return
+
+    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
+    ret.tag.values_eq_approx = values_eq_approx_remove_inf
+
+    # Lift ops that used to be between log and softmax
+    for op_to_lift, parameters in reversed(ops_to_lift):
+        ret = op_to_lift(ret, *parameters)
+
+    copy_stack_trace(node.outputs, ret)
+    return [ret]
 
 
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
+@register_stabilize
 @node_rewriter([SoftmaxGrad])
 def local_logsoftmax_grad(fgraph, node):
     """
@@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
     Note: only grad is affected
     """
     if (
-        isinstance(node.op, SoftmaxGrad)
-        and len(node.inputs) == 2
-        and node.inputs[0].owner is not None
+        node.inputs[0].owner is not None
         and node.inputs[0].owner.op == true_div
         and len(node.inputs[0].owner.inputs) >= 2
         and node.inputs[0].owner.inputs[1].owner is not None
diff --git a/tests/tensor/rewriting/test_special.py b/tests/tensor/rewriting/test_special.py
index 089a6d0a6b..799b805aa4 100644
--- a/tests/tensor/rewriting/test_special.py
+++ b/tests/tensor/rewriting/test_special.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+import scipy.special
 
 import pytensor
 from pytensor import shared
@@ -35,6 +36,37 @@ def test_local_logsoftmax_rewrite(self, axis):
         _fast_run_rewrites.rewrite(fgraph)
         assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
         assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)
+        assert check_stack_trace(fgraph, ops_to_check="all")
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    @pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None)])
+    @pytest.mark.parametrize("idx1", [None, [0, 1, 1, -1]])
+    def test_logsoftmax_subtensor_dimshuffle(self, axis, idx0, idx1):
+        """Test that stabilization is introduced even when subtensor or dimshuffle operations
+        are present between log and softmax.
+        """
+        logit_p = matrix("logit_p")
+        p = softmax(logit_p, axis=axis)
+        p_indexed = p[(idx0, idx1)]
+        out = log(p_indexed)
+
+        # Don't waste time with C compilation
+        with config.change_flags(cxx=""):
+            mode = get_mode(None).including("stabilize")
+            fn = pytensor.function([logit_p], out, mode=mode)
+
+        assert not any(
+            isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes
+        )
+
+        # This range would lead to underflow to -inf without the stabilization
+        test_logit_p = np.array(
+            [[-10.0, -10.0, 999.0], [999.0, 990.0, -10.0]], dtype=config.floatX
+        )
+        np.testing.assert_allclose(
+            fn(logit_p=test_logit_p),
+            scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1)],
+        )
 
     @pytest.mark.parametrize("axis", [None, 0, -1])
     def test_local_logsoftmax_grad_rewrite(self, axis):
@@ -46,7 +78,7 @@ def test_local_logsoftmax_grad_rewrite(self, axis):
         """
 
         m = config.mode
-        m = get_mode(m)
+        m = get_mode(m).including("stabilize")
         m.check_isfinite = False
         # some inputs that are large to make the gradient explode in the non
         # rewritten case
@@ -91,29 +123,6 @@ def test_logsoftmax_grad_true_div_elemwise(self):
         assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
 
 
-def test_log_softmax_stabilization():
-    mode = pytensor.compile.mode.get_default_mode()
-    mode = mode.including("local_log_softmax", "specialize")
-
-    x = matrix()
-    y = softmax(x, axis=-1)
-    z = log(y)
-
-    fgraph = FunctionGraph([x], [z])
-    _fast_run_rewrites(fgraph)
-    assert check_stack_trace(fgraph, ops_to_check="all")
-
-    # Check that the softmax has been rewritten
-    for node in fgraph.toposort():
-        assert not isinstance(node.op, Softmax)
-
-    # Call the function so debug mode can verify the rewritten version matches
-    # the un-rewritten version
-    f = pytensor.function([x], z, mode=mode)
-    rng = np.random.default_rng(utt.fetch_seed())
-    f(np.cast[config.floatX](rng.random((2, 3))))
-
-
 def test_softmax_graph():
     """Make sure that sotfmax expressions are turned into a softmax Op.