Skip to content

Commit

Permalink
[ONNX]Support Opset 13 split IFF the split is a constant (apache#9643)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and baoxinqi committed Dec 27, 2021
1 parent e175ea4 commit dfa2fad
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
25 changes: 25 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,31 @@ def _impl_v1(cls, inputs, attr, params):
output = output[0]
return output

@classmethod
def _impl_v13(cls, inputs, attr, params):
splits = inputs[1]
splits_rank = None
if splits is not None:
splits_rank = len(infer_shape(splits))
if splits is not None and splits_rank > 0:
if isinstance(splits, _expr.Constant):
splits = splits.data.asnumpy()
indices = []
index = 0
for i in splits[:-1]:
index += i
indices.append(index)
else:
raise ValueError("Dynamic Split not yet supported")
# When splits isnt specified divide evenly over axis.
else:
indices = attr["tvm_custom"]["num_outputs"]
output = _op.split(inputs[0], indices, attr.get("axis", 0))
# If the output of split is a single value, unpack if from the TupleWrapper
if len(output) == 1:
output = output[0]
return output


class Slice(OnnxOpConverter):
"""Operator converter for Slice."""
Expand Down
18 changes: 15 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def get_tvm_output_with_vm(
if not isinstance(input_data, list):
input_data = [input_data]
_, shape_dict = get_input_data_shape_dict(graph_def, input_data)

mod, params = relay.frontend.from_onnx(
graph_def,
shape_dict,
Expand Down Expand Up @@ -167,7 +166,6 @@ def verify_with_ort_with_inputs(
model.opset_import[0].version = opset

ort_out = get_onnxruntime_output(model, inputs)

if use_vm:
tvm_out = get_tvm_output_with_vm(
model,
Expand Down Expand Up @@ -1954,7 +1952,9 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
inputs.append(
helper.make_tensor_value_info("split", TensorProto.INT64, list(np_split.shape))
)
indata = [indata, np_split]
# TODO(mbrookhart): Support dynamic split, edit this test case to remove split from
# the initializer and add it back to the input data
indata = [indata] # , np_split]
initializer.append(
helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split)
)
Expand Down Expand Up @@ -1989,6 +1989,8 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
opset=opset,
target=target,
dev=dev,
use_vm=True,
freeze_params=(opset >= 13),
)

# 1D
Expand All @@ -1997,13 +1999,23 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0, False
)
verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], [2, 1, 3], 0)
verify_split(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], [2, 1, 3], 0, opset=13
)
# 2D
verify_split(
[[1.0, 2.0, 3.0, 4.0], [7.0, 8.0, 9.0, 10.0]],
[[[1.0, 2.0], [7.0, 8.0]], [[3.0, 4.0], [9.0, 10.0]]],
[2, 2],
1,
)
verify_split(
[[1.0, 2.0, 3.0, 4.0], [7.0, 8.0, 9.0, 10.0]],
[[[1.0, 2.0], [7.0, 8.0]], [[3.0, 4.0], [9.0, 10.0]]],
[2, 2],
1,
opset=13,
)
# Split evenly (unstack)
verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False)
# Split a single value to a single value
Expand Down

0 comments on commit dfa2fad

Please sign in to comment.