Skip to content

Commit

Permalink
add SequenceErase node to Sequence test
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Jan 30, 2023
1 parent ddb8f69 commit f65f9f1
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7747,10 +7747,17 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
outputs=["inserted_sequence"],
)

# Test sequence erase.
erase_node = helper.make_node(
"SequenceErase",
inputs=["inserted_sequence", "position"],
outputs=["erased_sequence"],
)

# Test sequence concatenation.
concat_node = helper.make_node(
"ConcatFromSequence",
inputs=["inserted_sequence"],
inputs=["erased_sequence"],
outputs=["concat_sequence"],
axis=axis,
)
Expand Down Expand Up @@ -7783,7 +7790,15 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
output_shape[axis] = (num_tensors + 1) * output_shape[axis]
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)]

graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]
graph_nodes = [
position_node,
construct_node,
insert_node,
erase_node,
concat_node,
split_node,
at_node,
]

graph = helper.make_graph(
graph_nodes,
Expand Down

0 comments on commit f65f9f1

Please sign in to comment.