1- from pytensor import scalar as aes
1+ from numpy .core .numeric import normalize_axis_index # type: ignore
2+
23from pytensor .graph .rewriting .basic import copy_stack_trace , node_rewriter
3- from pytensor .tensor .elemwise import DimShuffle , Elemwise
4- from pytensor .tensor .math import Sum , exp
4+ from pytensor .tensor .elemwise import DimShuffle
5+ from pytensor .tensor .math import Sum , exp , log
56from pytensor .tensor .math import sum as at_sum
67from pytensor .tensor .math import true_div
7- from pytensor .tensor .rewriting .basic import register_specialize
8+ from pytensor .tensor .rewriting .basic import register_stabilize
89from pytensor .tensor .rewriting .math import local_mul_canonizer
9- from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
10- from pytensor .tensor .subtensor import AdvancedIncSubtensor
10+ from pytensor .tensor .special import Softmax , SoftmaxGrad , log_softmax
11+ from pytensor .tensor .subtensor import (
12+ AdvancedIncSubtensor ,
13+ AdvancedSubtensor ,
14+ AdvancedSubtensor1 ,
15+ Subtensor ,
16+ )
1117from pytensor .tensor .type import (
1218 values_eq_approx_remove_inf ,
1319 values_eq_approx_remove_nan ,
1420)
1521
1622
# Indexing ops that may sit between `log` and `Softmax` and can safely be
# lifted above the fused log_softmax (see `local_logsoftmax` below).
subtensor_ops = (
    Subtensor,
    AdvancedSubtensor,
    AdvancedSubtensor1,
)


@register_stabilize
@node_rewriter([log])
def local_logsoftmax(fgraph, node):
    """
    Detect Log(Softmax(x)) and replace it with LogSoftmax(x).

    This also lifts Subtensor or DimShuffle operations that could be in
    between log and softmax.

    Note: only forward pass is affected
    """

    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
        """Walk upstream from ``inp_node`` looking for a Softmax apply node.

        Any Subtensor-like or DimShuffle op encountered on the way is recorded
        in ``ops_to_lift`` (together with the non-tensor inputs needed to
        re-apply it) so it can be re-applied above the new log_softmax output.

        Returns the Softmax apply node, or None if a non-lifteable op (or a
        graph input) is reached first.
        """
        if inp_node is None:
            return None

        if isinstance(inp_node.op, Softmax):
            return inp_node

        if isinstance(inp_node.op, subtensor_ops):
            # inputs[0] is the indexed tensor; the rest are the index inputs
            # needed to re-apply the op later.
            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
            return find_softmax_under_lifteable_ops(
                inp_node.inputs[0].owner, ops_to_lift
            )

        if isinstance(inp_node.op, DimShuffle):
            # DimShuffle takes no extra inputs beyond the tensor itself.
            ops_to_lift.append((inp_node.op, ()))
            return find_softmax_under_lifteable_ops(
                inp_node.inputs[0].owner, ops_to_lift
            )

        return None

    ops_to_lift = []
    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)

    if softmax_node is None:
        return None

    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
    # log(softmax(x)) may underflow to -inf where log_softmax stays finite;
    # ignore inf mismatches when the rewrite's output is validated.
    ret.tag.values_eq_approx = values_eq_approx_remove_inf

    # Re-apply (lift) the ops that used to be between log and softmax,
    # innermost first, so the final graph indexes/transposes the log_softmax.
    for op_to_lift, parameters in reversed(ops_to_lift):
        ret = op_to_lift(ret, *parameters)

    copy_stack_trace(node.outputs, ret)
    return [ret]
4070
4171
42- # This is not registered in stabilize, as it cause some crossentropy
43- # optimization to not be inserted.
44- @register_specialize ("stabilize" , "fast_compile" )
72+ @register_stabilize
4573@node_rewriter ([SoftmaxGrad ])
4674def local_logsoftmax_grad (fgraph , node ):
4775 """
@@ -50,9 +78,7 @@ def local_logsoftmax_grad(fgraph, node):
5078 Note: only grad is affected
5179 """
5280 if (
53- isinstance (node .op , SoftmaxGrad )
54- and len (node .inputs ) == 2
55- and node .inputs [0 ].owner is not None
81+ node .inputs [0 ].owner is not None
5682 and node .inputs [0 ].owner .op == true_div
5783 and len (node .inputs [0 ].owner .inputs ) >= 2
5884 and node .inputs [0 ].owner .inputs [1 ].owner is not None
0 commit comments