diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a62e505b287a..e4a6885efeb7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2269,6 +2269,15 @@ def _impl_v9(cls, inputs, attr, params): return _op.transpose(output, axes=(1, 0)) +class ReverseSequence(OnnxOpConverter): + """Operator converter for ReverseSequence""" + + @classmethod + def _impl_v10(cls, inputs, attr, params): + + return _op.reverse_sequence(inputs[0], inputs[1], attr["time_axis"], attr["batch_axis"]) + + class TopK(OnnxOpConverter): """Operator converter for TopK""" @@ -3007,6 +3016,7 @@ def _get_convert_map(opset): "QuantizeLinear": QuantizeLinear.get_converter(opset), "DequantizeLinear": DequantizeLinear.get_converter(opset), "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset), + "ReverseSequence": ReverseSequence.get_converter(opset), } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f878fa939fe2..89655840da2a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4225,9 +4225,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_resize_upsample_sizes_nearest_ceil_half_pixel/", "test_resize_upsample_sizes_nearest_floor_align_corners/", "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", - # ---- - "test_reversesequence_batch/", - "test_reversesequence_time/", "test_rnn_seq_length/", "test_round/", "test_scan9_sum/", @@ -4350,6 +4347,44 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None verify_embedding_bag(32, 2, [3, 3]) +def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis): + node = onnx.helper.make_node( + "ReverseSequence", + inputs=["x", "sequence_lens"], + outputs=["y"], + time_axis=time_axis, + batch_axis=batch_axis, + ) + + graph = helper.make_graph( + [node], + "reverse_sequence_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)), + helper.make_tensor_value_info( + "sequence_lens", TensorProto.INT64, list(sequence_lens.shape) + ), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))], + ) + + model = helper.make_model(graph, producer_name="reverse_sequence_test") + verify_with_ort_with_inputs(model, [x, sequence_lens], [x.shape]) + + +@tvm.testing.uses_gpu +def test_reverse_sequence(): + x = np.array( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], + dtype=np.float32, + ) + sequence_lens = np.array([1, 2, 3, 4], dtype=np.int64) + verify_reverse_sequence(x, sequence_lens, 0, 1) + + sequence_lens = np.array([4, 3, 2, 1], dtype=np.int64) + verify_reverse_sequence(x, sequence_lens, 1, 0) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4430,3 +4465,4 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None test_cumsum() test_wrong_input() test_aten() + test_reverse_sequence()