Transpose optimization for Softmax and LogSoftmax (fixes #1716) #1964

Merged · 2 commits · Jun 11, 2022
124 changes: 124 additions & 0 deletions tests/test_optimizers.py
@@ -1369,6 +1369,130 @@ def test_transpose_argmax(self):
        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                                   model_proto, remaining_transpose_num=0)

    @check_opset_max_version(
        12, "Before opset 13, Softmax coerced its inputs to 2D and can thus only be optimized for certain permutations"
    )
    def test_transpose_softmax_valid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=1, name="softmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-softmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    @check_opset_max_version(
        12, "Before opset 13, Softmax coerced its inputs to 2D and can thus only be optimized for certain permutations"
    )
    def test_transpose_softmax_invalid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=3, name="softmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-softmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=2
        )

    @check_opset_min_version(13, "Softmax can be optimized for all permutations since opset 13")
    def test_transpose_softmax_13(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=3, name="softmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-softmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    @check_opset_max_version(
        12,
        "Before opset 13, LogSoftmax coerced its inputs to 2D and can thus only be optimized for certain permutations",
    )
    def test_transpose_logsoftmax_valid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=1, name="logsoftmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-logsoftmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    @check_opset_max_version(
        12,
        "Before opset 13, LogSoftmax coerced its inputs to 2D and can thus only be optimized for certain permutations",
    )
    def test_transpose_logsoftmax_invalid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=3, name="logsoftmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-logsoftmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=2
        )

    @check_opset_min_version(13, "LogSoftmax can be optimized for all permutations since opset 13")
    def test_transpose_logsoftmax_13(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=3, name="logsoftmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-logsoftmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    def test_transpose_tile(self):
        input_shape = [1, 2, 3, 4]

24 changes: 24 additions & 0 deletions tf2onnx/optimizer/transpose_optimizer.py
@@ -205,6 +205,7 @@ def _initialize_handlers(self):
"Identity": self._identity_handler,
"LeakyRelu": self._simple_through_handler,
"Log": self._simple_through_handler,
"LogSoftmax": self._softmax_handler,
"Max": self._maxmin_handler,
"Min": self._maxmin_handler,
"Mul": self._mul_handler,
@@ -223,6 +224,7 @@ def _initialize_handlers(self):
"Relu": self._simple_through_handler,
"Shape": self._shape_handler,
"Sigmoid": self._simple_through_handler,
"Softmax": self._softmax_handler,
"Sum": self._sum_handler,
"Slice": self._slice_handler,
"Split": self._split_handler,
@@ -827,6 +829,28 @@ def permute_pads(pads):
    def _prelu_handler(self, trans, node):
        return self._handle_node_having_branches(trans, node)

    def _softmax_handler(self, trans, node):
        trans_rank = get_transpose_rank(trans)
        perm = trans.get_attr("perm").ints

        if self._g.opset >= 13:
            # Softmax operates on an arbitrary axis since opset 13
            axis = node.get_attr_value("axis", -1)
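            # The swap puts the Softmax in front of the Transpose, so it sees the
            # pre-transpose layout: axis i of the transposed tensor is axis perm[i]
            # of the input (e.g. perm=[0, 2, 3, 1], axis=3 -> new_axis = perm[3] = 1).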
            new_axis = perm[axis + trans_rank if axis < 0 else axis]
            if not self._switch_transpose_and_node(node, trans):
                return False
            node.set_attr("axis", new_axis)
            return True

        # For older opsets, the "axis" attribute marks the boundary at which the input
        # tensor is coerced to 2D. We can safely switch the Transpose and this node only
        # if the permutation does not move any axis across that boundary.
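        # E.g. with perm=[0, 2, 3, 1] and axis=1 the 2D split is {0} | {1, 2, 3} and every
        # axis stays on its side, so the swap is valid; with axis=3 the split is
        # {0, 1, 2} | {3} and input axis 3 moves to output position 2, so it is not.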
        coercion_axis = node.get_attr_value("axis", 1)
        for from_axis, to_axis in enumerate(perm):
            if (from_axis < coercion_axis <= to_axis) or (from_axis >= coercion_axis > to_axis):
                return False

        return self._switch_transpose_and_node(node, trans)

    def _arg_min_max_handler(self, trans, node):
        axis = node.get_attr_value("axis", 0)
        node.set_attr("axes", [axis])
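As a sanity check on the opset-13 path (a sketch of my own, not part of the diff), the identity the handler relies on can be verified directly in NumPy; softmax here is a hand-rolled stand-in for the ONNX op:

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.randn(4, 4, 4, 4).astype(np.float32)
perm = [0, 2, 3, 1]
inv_perm = list(np.argsort(perm))  # [0, 3, 1, 2], the trans_2 permutation in the tests
axis = 3

# Transpose -> Softmax -> inverse Transpose, the pattern the tests build ...
before = softmax(x.transpose(perm), axis).transpose(inv_perm)
# ... collapses to a single Softmax on the remapped axis perm[axis] = 1
after = softmax(x, perm[axis])
assert np.allclose(before, after)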