Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Support SequenceErase op #13865

Merged
merged 4 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6148,13 +6148,32 @@ def _impl_v11(cls, inputs, attr, params):
return _expr.Tuple(inputs)


class SequenceLength(OnnxOpConverter):
"""Operator converter for sequence length op."""
class SequenceErase(OnnxOpConverter):
"""Operator converter for sequence erase op."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
# Get length of input sequence
return _expr.const(len(inputs[0]), dtype="int64")
# Erase tensor from sequence on specified position
input_sequence = inputs[0]

if len(inputs) == 2:
position = inputs[1]
# Non constant position is not supported.
if isinstance(position, _expr.Constant):
position = position.data.numpy()
elif position.name_hint in params:
position = params[position.name_hint].numpy()
else:
raise NotImplementedError("Position must be a constant.")
else:
position = -1

if position < 0:
position = len(input_sequence) + position
# Convert sequence to a list, insert tensors before erased, and repackage as Tuple.
tensor_list = [input_sequence[i] for i in range(len(input_sequence)) if i != position]
# Create new tuple and return.
return _expr.Tuple(tensor_list)


class SequenceInsert(OnnxOpConverter):
Expand Down Expand Up @@ -6188,6 +6207,15 @@ def _impl_v11(cls, inputs, attr, params):
return _expr.Tuple(tensor_list)


class SequenceLength(OnnxOpConverter):
"""Operator converter for sequence length op."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
# Get length of input sequence
return _expr.const(len(inputs[0]), dtype="int64")


class ConcatFromSequence(OnnxOpConverter):
"""Operator converter for sequence concatenation op."""

Expand Down Expand Up @@ -6492,8 +6520,9 @@ def _get_convert_map(opset):
"LinearRegressor": LinearRegressor.get_converter(opset),
# Sequence operators
"SequenceConstruct": SequenceConstruct.get_converter(opset),
"SequenceLength": SequenceLength.get_converter(opset),
"SequenceErase": SequenceErase.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
"SequenceLength": SequenceLength.get_converter(opset),
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
"SplitToSequence": SplitToSequence.get_converter(opset),
"SequenceAt": SequenceAt.get_converter(opset),
Expand Down
10 changes: 9 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7747,10 +7747,17 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
outputs=["inserted_sequence"],
)

# Test sequence erase.
erase_node = helper.make_node(
"SequenceErase",
inputs=["inserted_sequence", "position"],
outputs=["erased_sequence"],
)

# Test sequence concatenation.
concat_node = helper.make_node(
"ConcatFromSequence",
inputs=["inserted_sequence"],
inputs=["erased_sequence"],
outputs=["concat_sequence"],
axis=axis,
)
Expand Down Expand Up @@ -7796,6 +7803,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
position_node,
construct_node,
insert_node,
erase_node,
concat_node,
split_node,
at_node,
Expand Down