From 1c46eeb989789b10c25b92bc1277fcc06c696981 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 30 Jan 2023 08:39:16 +0300 Subject: [PATCH 1/4] SequenceErase was implemented in ONNX front-end --- python/tvm/relay/frontend/onnx.py | 41 +++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6e0c7cc2dd3f..b8d8ac0441eb 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -6148,13 +6148,35 @@ def _impl_v11(cls, inputs, attr, params): return _expr.Tuple(inputs) -class SequenceLength(OnnxOpConverter): - """Operator converter for sequence length op.""" +class SequenceErase(OnnxOpConverter): + """Operator converter for sequence erase op.""" @classmethod def _impl_v11(cls, inputs, attr, params): - # Get length of input sequence - return _expr.const(len(inputs[0]), dtype="int64") + # Erase tensor from sequence on specified position + input_sequence = inputs[0] + + if len(inputs) == 2: + position = inputs[1] + # Non constant position is not supported. + if isinstance(position, _expr.Constant): + position = position.data.numpy() + elif position.name_hint in params: + position = params[position.name_hint].numpy() + else: + raise NotImplementedError("Position must be a constant.") + else: + position = -1 + + if position < 0: + position = len(input_sequence) + position + 1 + # Convert sequence to a list, insert tensors before erased, and repackage as Tuple. + tensor_list = [input_sequence[i] for i in range(position)] + # Insert tensors tail after erased one. + for i in range(position + 1, len(input_sequence)): + tensor_list.append(input_sequence[i]) + # Create new tuple and return. + return _expr.Tuple(tensor_list) class SequenceInsert(OnnxOpConverter): @@ -6188,6 +6210,14 @@ def _impl_v11(cls, inputs, attr, params): return _expr.Tuple(tensor_list) +class SequenceLength(OnnxOpConverter): + """Operator converter for sequence length op.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + # Get length of input sequence + return _expr.const(len(inputs[0]), dtype="int64") + class ConcatFromSequence(OnnxOpConverter): """Operator converter for sequence concatenation op.""" @@ -6492,8 +6522,9 @@ def _get_convert_map(opset): "LinearRegressor": LinearRegressor.get_converter(opset), # Sequence operators "SequenceConstruct": SequenceConstruct.get_converter(opset), - "SequenceLength": SequenceLength.get_converter(opset), + "SequenceErase": SequenceErase.get_converter(opset), "SequenceInsert": SequenceInsert.get_converter(opset), + "SequenceLength": SequenceLength.get_converter(opset), "ConcatFromSequence": ConcatFromSequence.get_converter(opset), "SplitToSequence": SplitToSequence.get_converter(opset), "SequenceAt": SequenceAt.get_converter(opset), From 86e4aa62ea259bf7cfc8a32eb3bde845bd10ac3c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 30 Jan 2023 08:54:06 +0300 Subject: [PATCH 2/4] add SequenceErase node to Sequence test --- tests/python/frontend/onnx/test_forward.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6a780a632fb7..3e1af4086784 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7747,10 +7747,17 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis= outputs=["inserted_sequence"], ) + # Test sequence erase. + erase_node = helper.make_node( + "SequenceErase", + inputs=["inserted_sequence", "position"], + outputs=["erased_sequence"], + ) + # Test sequence concatenation. concat_node = helper.make_node( "ConcatFromSequence", - inputs=["inserted_sequence"], + inputs=["erased_sequence"], outputs=["concat_sequence"], axis=axis, ) @@ -7796,6 +7803,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis= position_node, construct_node, insert_node, + erase_node, concat_node, split_node, at_node, From 7d5eb286910dc309a3502d422df73075b3be5406 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 30 Jan 2023 13:31:45 +0300 Subject: [PATCH 3/4] remark from reviewer. fix negative position recalculation --- python/tvm/relay/frontend/onnx.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b8d8ac0441eb..9eaa9ff58cc0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -6169,12 +6169,9 @@ def _impl_v11(cls, inputs, attr, params): position = -1 if position < 0: - position = len(input_sequence) + position + 1 + position = len(input_sequence) + position # Convert sequence to a list, insert tensors before erased, and repackage as Tuple. - tensor_list = [input_sequence[i] for i in range(position)] - # Insert tensors tail after erased one. - for i in range(position + 1, len(input_sequence)): - tensor_list.append(input_sequence[i]) + tensor_list = [input_sequence[i] for i in range(len(input_sequence)) if i != position] # Create new tuple and return. return _expr.Tuple(tensor_list) @@ -6218,6 +6215,7 @@ def _impl_v11(cls, inputs, attr, params): # Get length of input sequence return _expr.const(len(inputs[0]), dtype="int64") + class ConcatFromSequence(OnnxOpConverter): """Operator converter for sequence concatenation op.""" From 21030947a064b70cce4324cf3787f24b792eb481 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 30 Jan 2023 13:52:13 +0300 Subject: [PATCH 4/4] add assert --- python/tvm/relay/frontend/onnx.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9eaa9ff58cc0..93429a863889 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -6168,10 +6168,13 @@ def _impl_v11(cls, inputs, attr, params): else: position = -1 + seq_len = len(input_sequence) + assert -seq_len <= position < seq_len, "Position is out of bounds" + if position < 0: - position = len(input_sequence) + position + position = seq_len + position # Convert sequence to a list, insert tensors before erased, and repackage as Tuple. - tensor_list = [input_sequence[i] for i in range(len(input_sequence)) if i != position] + tensor_list = [input_sequence[i] for i in range(seq_len) if i != position] # Create new tuple and return. return _expr.Tuple(tensor_list)