-from pytensor import scalar as aes
+from numpy.core.numeric import normalize_axis_index  # type: ignore
+
+from pytensor import Variable
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.basic import expand_dims
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.extra_ops import squeeze
+from pytensor.tensor.math import Sum, exp, log, logsumexp
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
-from pytensor.tensor.rewriting.math import local_mul_canonizer
+from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
+from pytensor.tensor.rewriting.math import local_log_sum_exp, local_mul_canonizer
+from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+    indices_from_subtensor,
+    is_basic_idx,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )
+from pytensor.tensor.type_other import NoneTypeT
+
+
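+# Subtensor variants matched by the log(softmax(x)[idx]) rewrite below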
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)


 # This is not registered in stabilize, as it causes some crossentropy
 # optimization to not be inserted.
 @register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
+    if node.inputs[0].owner is not None and isinstance(
+        node.inputs[0].owner.op, Softmax
     ):
         inVars = node.inputs[0].owner.inputs[0]
         new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
@@ -39,6 +55,92 @@ def local_logsoftmax(fgraph, node):
         return [ret]


+@register_stabilize
+@node_rewriter([log])
+def local_log_subtensor_softmax(fgraph, node):
+    """Replace log(softmax(x, axis)[idx]) -> x[idx] - logsumexp(x, axis).
+
+    This can only be done when indexing happens over axis dims.
+    There can be non-indexed axis dims, but not non-axis indexed dims.
+    """
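+    # Example: log(softmax(x, axis=0)[i]) can be computed as
+    # x[i] - logsumexp(x, axis=0), avoiding the log of a softmax
+    # output that may have underflowed to zero.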
+    [subtensor_var] = node.inputs
+    subtensor_node = subtensor_var.owner
+
+    if subtensor_node is not None and isinstance(subtensor_node.op, subtensor_ops):
+        softmax_var, *idxs = subtensor_node.inputs
+        softmax_node = softmax_var.owner
+        if softmax_node is not None and isinstance(softmax_node.op, Softmax):
+            if isinstance(subtensor_node.op, Subtensor):
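+                # Basic Subtensor keeps its index template in op.idx_list;
+                # rebuild the concrete indices from it and the node inputs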
+                idxs = indices_from_subtensor(idxs, subtensor_node.op.idx_list)
+
+            # TODO: support expand_dims
+            if any(
+                (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT))
+                for idx in idxs
+            ):
+                return None
+
+            [x] = softmax_node.inputs
+            axis = softmax_node.op.axis
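+            # Normalize a possibly negative softmax axis to a positive dim index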
+            if axis is not None:
+                axis = normalize_axis_index(axis, ndim=x.type.ndim)
+
+            indexed_dims = [
+                dim for dim, idx in enumerate(idxs) if not is_full_slice(idx)
+            ]
+
+            # We can only apply the rewrite when the softmax is applied across all indexed dims
+            if axis is not None and {axis} != set(indexed_dims):
+                return None
+
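+            # Indexing may add or drop dimensions relative to x, so track how the
+            # logsumexp term has to be reshaped to broadcast against x[idxs]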
+            dims_to_expand = ()
+            dims_to_drop = ()
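+            # A scalar basic index drops its dimension from the output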
+            if isinstance(subtensor_node.op, Subtensor):
+                dims_to_drop = tuple(
+                    dim for dim, idx in enumerate(idxs) if getattr(idx, "ndim", -1) == 0
+                )
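+            # Advanced integer indices broadcast together and may change the
+            # number of output dimensions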
+            if isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
+                adv_dims_idxs = tuple(
+                    (dim, idx) for dim, idx in enumerate(idxs) if not is_basic_idx(idx)
+                )
+                adv_dims = tuple(dim for dim, idx in adv_dims_idxs)
+                adv_idxs = tuple(idx for dim, idx in adv_dims_idxs)
+
+                # Boolean indexing not supported
+                if any(idx.dtype == "bool" for idx in adv_idxs):
+                    return None
+
+                # Non-contiguous advanced indexing not supported
+                if tuple(range(adv_dims[0], adv_dims[-1] + 1)) != adv_dims:
+                    return None
+
+                ndim_adv_idx = max(idx.ndim for idx in adv_idxs)
+                n_new_dims = ndim_adv_idx - len(adv_idxs)
+                # Advanced indexing introduces new dims
+                if n_new_dims > 0:
+                    dims_to_expand = tuple(range(adv_dims[0], adv_dims[0] + n_new_dims))
+                # It reduces number of dims
+                elif n_new_dims < 0:
+                    dims_to_drop = tuple(
+                        range(adv_dims[0], adv_dims[0] + abs(n_new_dims))
+                    )
+
+            # Rewrite stable form of logsumexp immediately
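+            # by applying local_log_sum_exp by hand on the freshly built
+            # log(sum(exp(x))) graph instead of waiting for a later rewrite pass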
+            [x_logsumexp] = local_log_sum_exp.transform(
+                None, logsumexp(x, axis=axis, keepdims=True).owner
+            )
+
+            assert not (dims_to_drop and dims_to_expand)
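+            # Align the logsumexp dims with x[idxs] so the subtraction broadcasts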
+            if dims_to_expand:
+                x_logsumexp = expand_dims(x_logsumexp, dims_to_expand)
+            elif dims_to_drop:
+                x_logsumexp = squeeze(x_logsumexp, axis=dims_to_drop)
+            ret = x[tuple(idxs)] - x_logsumexp
+
+            copy_stack_trace(node.outputs, ret)
+            return [ret]
+
+
 # This is not registered in stabilize, as it causes some crossentropy
 # optimization to not be inserted.
 @register_specialize("stabilize", "fast_compile")