Skip to content

Commit

Permalink
add onnx reverse sequence op
Browse files Browse the repository at this point in the history
  • Loading branch information
xp224797 authored and alter-xp committed Apr 19, 2021
1 parent 1fb32b0 commit a24d77b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,16 @@ def _impl_v9(cls, inputs, attr, params):
return _op.transpose(output, axes=(1, 0))



class ReverseSequence(OnnxOpConverter):
"""Operator converter for ReverseSequence"""

@classmethod
def _impl_v10(cls, inputs, attr, params):

return _op.reverse_sequence(inputs[0], inputs[1], attr["time_axis"], attr["batch_axis"])


class TopK(OnnxOpConverter):
"""Operator converter for TopK"""

Expand Down Expand Up @@ -2801,6 +2811,7 @@ def _get_convert_map(opset):
"QuantizeLinear": QuantizeLinear.get_converter(opset),
"DequantizeLinear": DequantizeLinear.get_converter(opset),
"DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
"ReverseSequence": ReverseSequence.get_converter(opset),
}


Expand Down
49 changes: 49 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4328,6 +4328,51 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
verify_embedding_bag(32, 2, [3, 3])


def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
node = onnx.helper.make_node(
"ReverseSequence",
inputs=["x", "sequence_lens"],
outputs=["y"],
time_axis=time_axis,
batch_axis=batch_axis,
)

graph = helper.make_graph(
[node],
"reverse_sequence_test",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)),
helper.make_tensor_value_info(
"sequence_lens", TensorProto.INT64, list(sequence_lens.shape)
),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))],
)

model = helper.make_model(graph, producer_name="reverse_sequence_test")
verify_with_ort_with_inputs(model, [x, sequence_lens], list(x.shape))


@tvm.testing.uses_gpu
def test_reverse_sequence():
x = np.array(
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],
dtype=np.float32,
)
sequence_lens = np.array([1, 2, 3, 4], dtype=np.int64)
y = np.array(
[[0, 5, 10, 15], [4, 1, 6, 11], [8, 9, 2, 7], [12, 13, 14, 3]],
dtype=np.float32,
)
verify_reverse_sequence(x, sequence_lens, 0, 1)

y = np.array(
[[0, 1, 2, 3], [5, 4, 6, 7], [10, 9, 8, 11], [15, 14, 13, 12]],
dtype=np.float32,
)
verify_reverse_sequence(x, sequence_lens, 1, 0)


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -4407,4 +4452,8 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
test_softplus()
test_cumsum()
test_wrong_input()
<<<<<<< HEAD
test_aten()
=======
test_reverse_sequence()
>>>>>>> 726b946b9... add onnx reverse sequence op

0 comments on commit a24d77b

Please sign in to comment.