Extend log_softmax stabilization rewrite to graphs with indexing and expand_dims #352

Merged · 2 commits · Jun 29, 2023
87 changes: 58 additions & 29 deletions pytensor/tensor/rewriting/special.py
@@ -1,47 +1,78 @@
from pytensor import scalar as aes
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Sum, exp
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Sum, exp, log
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.basic import register_stabilize
from pytensor.tensor.rewriting.math import local_mul_canonizer
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.subtensor import AdvancedIncSubtensor
from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
from pytensor.tensor.subtensor import (
    AdvancedIncSubtensor,
    AdvancedSubtensor,
    AdvancedSubtensor1,
    Subtensor,
)
from pytensor.tensor.type import (
    values_eq_approx_remove_inf,
    values_eq_approx_remove_nan,
)


# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
subtensor_ops = (
    Subtensor,
    AdvancedSubtensor,
    AdvancedSubtensor1,
)


@register_stabilize
@node_rewriter([log])
def local_logsoftmax(fgraph, node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
This also lifts Subtensor or Dimshuffle operations that could be in between log and softmax
Note: only forward pass is affected
"""
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1
and node.inputs[0].owner is not None
and isinstance(node.inputs[0].owner.op, Softmax)
):
inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
ret = new_op(inVars)
ret.tag.values_eq_approx = values_eq_approx_remove_inf
copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
return [ret]

    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
        if inp_node is None:
            return

        if isinstance(inp_node.op, Softmax):
            return inp_node

        if isinstance(inp_node.op, subtensor_ops):
            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
            return find_softmax_under_lifteable_ops(
                inp_node.inputs[0].owner, ops_to_lift
            )

        if isinstance(inp_node.op, DimShuffle):
            ops_to_lift.append((inp_node.op, ()))
            return find_softmax_under_lifteable_ops(
                inp_node.inputs[0].owner, ops_to_lift
            )

    ops_to_lift = []
    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)

    if softmax_node is None:
        return

    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
    ret.tag.values_eq_approx = values_eq_approx_remove_inf

    # Lift ops that used to be between log and softmax
    for op_to_lift, parameters in reversed(ops_to_lift):
        ret = op_to_lift(ret, *parameters)

    copy_stack_trace(node.outputs, ret)
    return [ret]


# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@register_stabilize
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
"""
Expand All @@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
Note: only grad is affected
"""
if (
isinstance(node.op, SoftmaxGrad)
and len(node.inputs) == 2
and node.inputs[0].owner is not None
node.inputs[0].owner is not None
and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None
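
For context (not part of the diff), here is a minimal sketch of the kind of graph the extended rewrite above now stabilizes: a log applied to an indexed (or expand_dims-ed) softmax. It mirrors the new test added further down; the variable names and the explicit mode setup are illustrative assumptions rather than code from this PR.

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.mode import get_mode
    from pytensor.tensor.special import Softmax, softmax

    x = pt.matrix("x")
    # log() of an *indexed* softmax: before this change the stabilization only
    # fired on log(softmax(x)) directly, so this graph could underflow to -inf.
    out = pt.log(softmax(x, axis=-1)[0])

    fn = pytensor.function([x], out, mode=get_mode(None).including("stabilize"))

    # No Softmax node survives in the compiled graph ...
    assert not any(isinstance(n.op, Softmax) for n in fn.maker.fgraph.apply_nodes)

    # ... and extreme logits no longer underflow to -inf.
    test_x = np.array([[-10.0, -10.0, 999.0], [999.0, 990.0, -10.0]], dtype=x.dtype)
    print(fn(test_x))  # finite, matches scipy.special.log_softmax(test_x, axis=-1)[0]
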
36 changes: 2 additions & 34 deletions pytensor/tensor/special.py
@@ -1,4 +1,3 @@
import warnings
from textwrap import dedent

import numpy as np
@@ -483,25 +482,8 @@ def c_code_cache_version():
        return (4,)


UNSET_AXIS = object()


def softmax(c, axis=UNSET_AXIS):
    if axis is UNSET_AXIS:
        warnings.warn(
            "Softmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
            "but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
            FutureWarning,
        )
        axis = -1

def softmax(c, axis=None):
    c = as_tensor_variable(c)
    if c.ndim == 1:
        # TODO: Create Specific warning type that can be suppressed?
        warnings.warn(
            "Softmax no longer converts a vector to a row matrix.",
            UserWarning,
        )
    return Softmax(axis=axis)(c)


@@ -749,22 +731,8 @@ def c_code_cache_version():
        return (1,)


def log_softmax(c, axis=UNSET_AXIS):
    if axis is UNSET_AXIS:
        warnings.warn(
            "logsoftmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
            "but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
            FutureWarning,
        )
        axis = -1

def log_softmax(c, axis=None):
    c = as_tensor_variable(c)
    if c.ndim == 1:
        # TODO: Create Specific warning type that can be suppressed?
        warnings.warn(
            "Softmax no longer converts a vector to a row matrix.",
            UserWarning,
        )
    return LogSoftmax(axis=axis)(c)


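
To illustrate the signature change above (again, not part of the diff): axis now simply defaults to None, which normalizes over all elements as in other reductions, and both the UNSET_AXIS FutureWarning and the vector-to-row-matrix warning are gone. Callers that relied on the old implicit row-wise behaviour must pass axis=-1 explicitly, as the updated tests below do. A small sketch with assumed variable names:

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.special import softmax

    x = pt.matrix("x")
    f = pytensor.function([x], [softmax(x), softmax(x, axis=-1)])

    xv = np.ones((2, 3), dtype=x.dtype)
    over_all, per_row = f(xv)
    print(over_all.sum())        # ~1.0: axis=None normalizes over the whole matrix
    print(per_row.sum(axis=-1))  # ~[1.0, 1.0]: axis=-1 normalizes each row
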
2 changes: 1 addition & 1 deletion tests/d3viz/models.py
@@ -25,7 +25,7 @@ def __init__(self, nfeatures=100, noutputs=10, nhiddens=50, rng=None):

        wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs)))
        by = shared(np.zeros(noutputs), borrow=True)
        y = softmax(at.dot(h, wy) + by)
        y = softmax(at.dot(h, wy) + by, axis=-1)
        self.inputs = [x]
        self.outputs = [y]

59 changes: 34 additions & 25 deletions tests/tensor/rewriting/test_special.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import scipy.special

import pytensor
from pytensor import shared
@@ -35,6 +36,37 @@ def test_local_logsoftmax_rewrite(self, axis):
        _fast_run_rewrites.rewrite(fgraph)
        assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
        assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)
        assert check_stack_trace(fgraph, ops_to_check="all")

    @pytest.mark.parametrize("axis", [None, 0, -1])
    @pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None)])
    @pytest.mark.parametrize("idx1", [None, [0, 1, 1, -1]])
    def test_logsoftmax_subtensor_dimshuffle(self, axis, idx0, idx1):
        """Test that stabilization is introduced even when subtensor or dimshuffle operations
        are present between log and softmax.
        """
        logit_p = matrix("logit_p")
        p = softmax(logit_p, axis=axis)
        p_indexed = p[(idx0, idx1)]
        out = log(p_indexed)

        # Don't waste time with C compilation
        with config.change_flags(cxx=""):
            mode = get_mode(None).including("stabilize")
            fn = pytensor.function([logit_p], out, mode=mode)

        assert not any(
            isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes
        )

        # This range would lead to underflow to -inf without the stabilization
        test_logit_p = np.array(
            [[-10.0, -10.0, 999.0], [999.0, 990.0, -10.0]], dtype=config.floatX
        )
        np.testing.assert_allclose(
            fn(logit_p=test_logit_p),
            scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1)],
        )

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_grad_rewrite(self, axis):
@@ -46,7 +78,7 @@ def test_local_logsoftmax_grad_rewrite(self, axis):
        """

        m = config.mode
        m = get_mode(m)
        m = get_mode(m).including("stabilize")
        m.check_isfinite = False
        # some inputs that are large to make the gradient explode in the non
        # rewritten case
@@ -72,7 +104,7 @@ def test_logsoftmax_grad_true_div_elemwise(self):
        """

        x = matrix("x")
        y = log(softmax(x))
        y = log(softmax(x, axis=-1))
        g = pytensor.tensor.grad(y.sum(), x)

        softmax_grad_node = g.owner
@@ -91,29 +123,6 @@ def test_logsoftmax_grad_true_div_elemwise(self):
        assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]


def test_log_softmax_stabilization():
    mode = pytensor.compile.mode.get_default_mode()
    mode = mode.including("local_log_softmax", "specialize")

    x = matrix()
    y = softmax(x)
    z = log(y)

    fgraph = FunctionGraph([x], [z])
    _fast_run_rewrites(fgraph)
    assert check_stack_trace(fgraph, ops_to_check="all")

    # Check that the softmax has been rewritten
    for node in fgraph.toposort():
        assert not isinstance(node.op, Softmax)

    # Call the function so debug mode can verify the rewritten version matches
    # the un-rewritten version
    f = pytensor.function([x], z, mode=mode)
    rng = np.random.default_rng(utt.fetch_seed())
    f(np.cast[config.floatX](rng.random((2, 3))))


def test_softmax_graph():
"""Make sure that sotfmax expressions are turned into
a softmax Op.
4 changes: 3 additions & 1 deletion tests/test_rop.py
@@ -272,7 +272,9 @@ def test_sum(self):
        self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))

    def test_softmax(self):
        self.check_rop_lop(pytensor.tensor.special.softmax(self.x), self.in_shape)
        self.check_rop_lop(
            pytensor.tensor.special.softmax(self.x, axis=-1), self.in_shape
        )

    def test_alloc(self):
        # Alloc of the sum of x into a vector