Transpose optimization for Softmax for opset>=13 (fixes onnx#1716)
In lower opsets, Softmax always coerces its input to a 2D tensor, making Transpose operations necessary whenever the permutation moves axes across the coerced batch/feature boundary.
While one could find and optimize away Transposes that only permute axes in a way that preserves the batch/feature split, I would not consider that a common use case; optimizing only for opset >= 13 seems good enough for now.
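
For context, the two Softmax semantics can be sketched in numpy (an illustrative stand-in, not the actual ONNX runtime code): before opset 13 the input is flattened to 2D at the axis attribute, while from opset 13 on the reduction runs along that single axis alone.

import numpy as np

def softmax_opset13(x, axis):
    # opset >= 13: softmax along a single, arbitrary axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_opset12(x, axis):
    # opset < 13: coerce to 2D [prod(dims[:axis]), prod(dims[axis:])]
    # and normalize over the entire flattened feature block
    batch = int(np.prod(x.shape[:axis]))
    return softmax_opset13(x.reshape(batch, -1), axis=1).reshape(x.shape)

Pushing a Transpose through the opset-12 variant changes the contents of the flattened blocks unless the permutation preserves the batch/feature split, which is why the new handler bails out on older opsets.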
Felix Thielke committed Sep 23, 2021
1 parent 446494e commit aafcbf4
Showing 2 changed files with 26 additions and 0 deletions.
tests/test_optimizers.py (18 additions, 0 deletions)
@@ -1302,6 +1302,24 @@ def test_transpose_argmax(self):
        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                                   model_proto, remaining_transpose_num=0)

    @check_opset_min_version(13, "Softmax can only be optimized since opset 13")
    def test_transpose_softmax(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=3, name="softmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-softmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                                   model_proto, remaining_transpose_num=0)

    def test_transpose_tile(self):
        input_shape = [1, 2, 3, 4]

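The new test above builds the pattern the optimizer should collapse: a Transpose, a Softmax along the last axis, and the inverse Transpose. A quick numpy sanity check of the graph identity the test relies on (not part of the commit; softmax here is a hand-written stand-in for the ONNX op):

import numpy as np

def softmax(a, axis):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.randn(4, 4, 4, 4).astype(np.float32)
perm = [0, 2, 3, 1]       # trans_1
inv_perm = [0, 3, 1, 2]   # trans_2, the inverse of perm

# Original pattern: Transpose -> Softmax(axis=3) -> Transpose
before = softmax(x.transpose(perm), axis=3).transpose(inv_perm)
# Optimized pattern: Softmax alone, axis remapped to perm[3] == 1
after = softmax(x, axis=1)
assert np.allclose(before, after, atol=1e-6)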
tf2onnx/optimizer/transpose_optimizer.py (8 additions, 0 deletions)
@@ -222,6 +222,7 @@ def _initialize_handlers(self):
"Relu": self._simple_through_handler,
"Shape": self._shape_handler,
"Sigmoid": self._simple_through_handler,
"Softmax": self._softmax_handler,
"Sum": self._sum_handler,
"Slice": self._slice_handler,
"Split": self._split_handler,
@@ -822,6 +823,13 @@ def permute_pads(pads):

    def _prelu_handler(self, trans, node):
        return self._handle_node_having_branches(trans, node)

    def _softmax_handler(self, trans, node):
        # Softmax only operates on an arbitrary axis since opset 13
        if self._g.opset >= 13:
            return self._arg_min_max_handler(trans, node)
        else:
            return False

    def _arg_min_max_handler(self, trans, node):
        axis = node.get_attr_value("axis", 0)
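The handler can delegate to _arg_min_max_handler because, with opset >= 13 semantics, moving a Transpose with permutation perm past an op that works on a single axis only requires rewriting that axis. A minimal sketch of the remapping (assumed behavior; the real handler also rewires the graph nodes):

def remap_axis(perm, axis, rank):
    # An axis in the transposed tensor's coordinates corresponds to
    # perm[axis] in the untransposed tensor's coordinates; the modulo
    # normalizes negative axis values.
    return perm[axis % rank]

# e.g. perm=[0, 2, 3, 1] with Softmax axis=3 yields axis 1 on the input
assert remap_axis([0, 2, 3, 1], 3, 4) == 1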
