Skip to content

Commit

Permalink
add onnx reverse sequence op (apache#7771)
Browse files Browse the repository at this point in the history
Co-authored-by: xp224797 <xp224797@alibaba-inc.com>
  • Loading branch information
2 people authored and trevor-m committed Jun 17, 2021
1 parent c5b0f08 commit 46d0cc0
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2269,6 +2269,15 @@ def _impl_v9(cls, inputs, attr, params):
return _op.transpose(output, axes=(1, 0))


class ReverseSequence(OnnxOpConverter):
"""Operator converter for ReverseSequence"""

@classmethod
def _impl_v10(cls, inputs, attr, params):

return _op.reverse_sequence(inputs[0], inputs[1], attr["time_axis"], attr["batch_axis"])


class TopK(OnnxOpConverter):
"""Operator converter for TopK"""

Expand Down Expand Up @@ -3007,6 +3016,7 @@ def _get_convert_map(opset):
"QuantizeLinear": QuantizeLinear.get_converter(opset),
"DequantizeLinear": DequantizeLinear.get_converter(opset),
"DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
"ReverseSequence": ReverseSequence.get_converter(opset),
}


Expand Down
42 changes: 39 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4225,9 +4225,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_resize_upsample_sizes_nearest_ceil_half_pixel/",
"test_resize_upsample_sizes_nearest_floor_align_corners/",
"test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/",
# ----
"test_reversesequence_batch/",
"test_reversesequence_time/",
"test_rnn_seq_length/",
"test_round/",
"test_scan9_sum/",
Expand Down Expand Up @@ -4350,6 +4347,44 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
verify_embedding_bag(32, 2, [3, 3])


def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
node = onnx.helper.make_node(
"ReverseSequence",
inputs=["x", "sequence_lens"],
outputs=["y"],
time_axis=time_axis,
batch_axis=batch_axis,
)

graph = helper.make_graph(
[node],
"reverse_sequence_test",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)),
helper.make_tensor_value_info(
"sequence_lens", TensorProto.INT64, list(sequence_lens.shape)
),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
)

model = helper.make_model(graph, producer_name="reverse_sequence_test")
verify_with_ort_with_inputs(model, [x, sequence_lens], [x.shape])


@tvm.testing.uses_gpu
def test_reverse_sequence():
x = np.array(
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],
dtype=np.float32,
)
sequence_lens = np.array([1, 2, 3, 4], dtype=np.int64)
verify_reverse_sequence(x, sequence_lens, 0, 1)

sequence_lens = np.array([4, 3, 2, 1], dtype=np.int64)
verify_reverse_sequence(x, sequence_lens, 1, 0)


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -4430,3 +4465,4 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
test_cumsum()
test_wrong_input()
test_aten()
test_reverse_sequence()

0 comments on commit 46d0cc0

Please sign in to comment.