Implement rewrite to stabilize log(softmax[idx])
ricardoV94 committed Jun 20, 2023
1 parent df4183d commit 7c54695
Showing 2 changed files with 172 additions and 14 deletions.
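For context, the instability this rewrite targets can be seen with a minimal NumPy/SciPy sketch (illustrative only, not part of the commit): computing the softmax first and taking the log afterwards underflows to -inf for extreme logits, while the algebraically equivalent form x[idx] - logsumexp(x) stays finite.

import numpy as np
from scipy.special import logsumexp, softmax

# Logits with one dominant entry: softmax underflows to 0.0 for the small
# entries, so log(softmax(x)) produces -inf there.
x = np.array([0.0, 1.0, 2.0, 999.0])

naive = np.log(softmax(x))   # [-inf, -inf, -inf, 0.]
stable = x - logsumexp(x)    # [-999., -998., -997., 0.]

The rewrite introduced below applies the same identity to indexed softmax outputs: log(softmax(x, axis)[idx]) -> x[idx] - logsumexp(x, axis).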
128 changes: 115 additions & 13 deletions pytensor/tensor/rewriting/special.py
@@ -1,35 +1,51 @@
-from pytensor import scalar as aes
+from numpy.core.numeric import normalize_axis_index  # type: ignore
+
+from pytensor import Variable
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.basic import expand_dims
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.extra_ops import squeeze
+from pytensor.tensor.math import Sum, exp, log, logsumexp
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
-from pytensor.tensor.rewriting.math import local_mul_canonizer
+from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
+from pytensor.tensor.rewriting.math import local_log_sum_exp, local_mul_canonizer
+from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+    indices_from_subtensor,
+    is_basic_idx,
+)
from pytensor.tensor.type import (
    values_eq_approx_remove_inf,
    values_eq_approx_remove_nan,
)
+from pytensor.tensor.type_other import NoneTypeT


subtensor_ops = (
Subtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
)


# This is not registered in stabilize, as it causes some crossentropy
# optimizations to not be inserted.
@register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+@node_rewriter([log])
def local_logsoftmax(fgraph, node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
Note: only forward pass is affected
"""
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
+    if node.inputs[0].owner is not None and isinstance(
+        node.inputs[0].owner.op, Softmax
    ):
inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
@@ -39,6 +55,92 @@ def local_logsoftmax(fgraph, node):
return [ret]


@register_stabilize
@node_rewriter([log])
def local_log_subtensor_softmax(fgraph, node):
"""Replace log(softmax(x, axis)[idx]) -> x[idx] - logsumexp(x, axis).
This can only be done when indexing happens over axis dims.
There can be non-indexed axis dims, but not non-axis indexed dims.
"""
[subtensor_var] = node.inputs
subtensor_node = subtensor_var.owner

if subtensor_node is not None and isinstance(subtensor_node.op, subtensor_ops):
softmax_var, *idxs = subtensor_node.inputs
softmax_node = softmax_var.owner
if softmax_node is not None and isinstance(softmax_node.op, Softmax):
if isinstance(subtensor_node.op, Subtensor):
idxs = indices_from_subtensor(idxs, subtensor_node.op.idx_list)

# TODO: support expand_dims
if any(
(isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT))
for idx in idxs
):
return None

[x] = softmax_node.inputs
axis = softmax_node.op.axis
if axis is not None:
axis = normalize_axis_index(axis, ndim=x.type.ndim)

indexed_dims = [
dim for dim, idx in enumerate(idxs) if not is_full_slice(idx)
]

# We can only apply the rewrite when the softmax is applied across all indexed dims
if axis is not None and {axis} != set(indexed_dims):
return None

dims_to_expand = ()
dims_to_drop = ()
if isinstance(subtensor_node.op, Subtensor):
dims_to_drop = tuple(
dim for dim, idx in enumerate(idxs) if getattr(idx, "ndim", -1) == 0
)
if isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
adv_dims_idxs = tuple(
(dim, idx) for dim, idx in enumerate(idxs) if not is_basic_idx(idx)
)
adv_dims = tuple(dim for dim, idx in adv_dims_idxs)
adv_idxs = tuple(idx for dim, idx in adv_dims_idxs)

# Boolean indexing not supported
if any(idx.dtype == "bool" for idx in adv_idxs):
return None

# Non-contiguous advanced indexing not supported
if tuple(range(adv_dims[0], adv_dims[-1] + 1)) != adv_dims:
return None

ndim_adv_idx = max(idx.ndim for idx in adv_idxs)
n_new_dims = ndim_adv_idx - len(adv_idxs)
# Advanced indexing introduces new dims
if n_new_dims > 0:
dims_to_expand = tuple(range(adv_dims[0], adv_dims[0] + n_new_dims))
# It reduces number of dims
elif n_new_dims < 0:
dims_to_drop = tuple(
range(adv_dims[0], adv_dims[0] + abs(n_new_dims))
)

# Rewrite stable form of logsumexp immediately
[x_logsumexp] = local_log_sum_exp.transform(
None, logsumexp(x, axis=axis, keepdims=True).owner
)

assert not (dims_to_drop and dims_to_expand)
if dims_to_expand:
x_logsumexp = expand_dims(x_logsumexp, dims_to_expand)
elif dims_to_drop:
x_logsumexp = squeeze(x_logsumexp, axis=dims_to_drop)
ret = x[tuple(idxs)] - x_logsumexp

copy_stack_trace(node.outputs, ret)
return [ret]


# This is not registered in stabilize, as it causes some crossentropy
# optimizations to not be inserted.
@register_specialize("stabilize", "fast_compile")
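As a usage sketch of the new local_log_subtensor_softmax rewrite (illustrative only, not part of the commit; the variable name x and its shape are arbitrary, but the calls mirror the test file below): compiling log(softmax(x, axis=0)[0]) with the indexed dim matching the softmax axis should leave no Softmax node in the optimized graph, because the stabilize-tagged rewrite replaces it with x[0] - logsumexp(x, axis=0).

import pytensor
from pytensor.tensor.math import log
from pytensor.tensor.special import Softmax, softmax
from pytensor.tensor.type import tensor

x = tensor("x", shape=(4, 3))
# Index over the same dim the softmax is computed on, so the rewrite applies
out = log(softmax(x, axis=0)[0])

# Skip C compilation, as the test below does
with pytensor.config.change_flags(cxx=""):
    fn = pytensor.function([x], out)

# The Softmax node should have been rewritten away
assert not any(
    isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes
)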
58 changes: 57 additions & 1 deletion tests/tensor/rewriting/test_special.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import scipy.special

import pytensor
from pytensor import shared
@@ -11,7 +12,8 @@
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.math import add, exp, log, true_div
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
-from pytensor.tensor.type import matrix
+from pytensor.tensor.subtensor import AdvancedSubtensor, group_indices
+from pytensor.tensor.type import matrix, tensor
from tests import unittest_tools as utt


@@ -130,3 +132,57 @@ def f(inputs):
return pytensor.grad(None, x, known_grads={y: inputs})

utt.verify_grad(f, [rng.random((3, 4))])


def _is_non_contiguous_adv_indexing(index_var):
if not isinstance(index_var.owner.op, AdvancedSubtensor):
return False
idx_groups = group_indices(index_var.owner.inputs[1:])
return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])


@pytest.mark.parametrize("axis", [None, 0, 1, 2])
@pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None), None, [0, 1, 1, -1]])
@pytest.mark.parametrize("idx1", [0, slice(1, None), slice(None), None, [0, 1, 1, -1]])
@pytest.mark.parametrize(
"idx2", [0, slice(1, None), slice(None), None, [[0, 1, 1, -1], [-1, 1, 1, 0]]]
)
def test_log_subtensor_softmax(axis, idx0, idx1, idx2):
logit_p = tensor("logit_p", shape=(4, 3, 5))
p = softmax(logit_p, axis=axis)
p_indexed = p[(idx0, idx1, idx2)]
out = log(p_indexed)

# Don't waste time with C compilation
with config.change_flags(cxx=""):
fn = pytensor.function([logit_p], out)

rewrite_applies = True
if _is_non_contiguous_adv_indexing(p_indexed):
rewrite_applies = False
else:
if idx0 is None or idx1 is None or idx2 is None:
# Not yet implemented!
rewrite_applies = False
elif axis is not None:
indexed_dims = {
dim for dim, idx in enumerate((idx0, idx1, idx2)) if idx != slice(None)
}
# If no indexed dims, the rewrite doesn't actually apply
# but the log_softmax stabilization kicks-in and the output is also stable
if indexed_dims:
rewrite_applies = {axis} == indexed_dims

assert any(isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes) != rewrite_applies

if not rewrite_applies:
return

# This range would lead to underflow to -inf without the stabilization
logit_ps = np.array([0.0, 1.0, 2.0, 3.0, 999.0])
rng = np.random.default_rng(156)
test_logit_p = rng.choice(logit_ps, size=(4, 3, 5))
np.testing.assert_allclose(
fn(logit_p=test_logit_p),
scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1, idx2)],
)
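A note on the _is_non_contiguous_adv_indexing helper used above: it mirrors two NumPy behaviours. Advanced indices separated by a basic index have their broadcast dimensions moved to the front of the result (a case the rewrite declines to handle), and a multi-dimensional advanced index adds dimensions (which the rewrite compensates for with expand_dims on the logsumexp term). A small NumPy sketch (illustrative only, not part of the commit):

import numpy as np

a = np.arange(4 * 3 * 5).reshape(4, 3, 5)

# Contiguous advanced indices: the broadcast index dims stay in place
print(a[[0, 1], [0, 1], :].shape)       # (2, 5)

# Non-contiguous advanced indices (separated by a basic slice): the broadcast
# index dims move to the front of the result; the rewrite skips this case
print(a[[0, 1], :, [0, 1]].shape)       # (2, 3)

# A 2D advanced index adds a dimension, handled via expand_dims in the rewrite
print(a[[[0, 1], [1, 0]], 0, :].shape)  # (2, 2, 5)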
