Skip to content

Commit

Permalink
Add Range op to ONNX, make tvm arange op support negative steps (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Oct 19, 2020
1 parent f691355 commit ca395a7
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
14 changes: 14 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1893,6 +1893,19 @@ def _impl_v1(cls, inputs, attr, params):
return _op.topk(inputs[0], inputs[1], axis=axis)


class Range(OnnxOpConverter):
"""Operator converter for Range"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
if len(inputs) != 3:
raise ValueError("Expect 3 input only")

return _op.arange(
inputs[0], inputs[1], inputs[2], dtype=infer_type(inputs[0]).checked_type.dtype
)


class MaxRoiPool(OnnxOpConverter):
"""Operator converter for MaxRoiPool."""

Expand Down Expand Up @@ -2115,6 +2128,7 @@ def _get_convert_map(opset):
"Or": Or.get_converter(opset),
"Resize": Resize.get_converter(opset),
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
}


Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def compute_scatter_add(attrs, inputs, output_type):
@script
def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64")
out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
if step[0] < 0:
out[0] = int64(ceil_div((int64(start[0]) - int64(stop[0])), int64(-step[0])))
else:
out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
return out


Expand Down
36 changes: 36 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,42 @@ def test_power():
_test_power_iteration((2, 3), (1, 3))


def verify_range(start, limit, delta, dtype):
dtype_map = {
"float32": TensorProto.FLOAT,
"int32": TensorProto.INT32,
"int64": TensorProto.INT64,
}
dtype_onnx = dtype_map[dtype]
y = helper.make_node("Range", ["start", "limit", "delta"], ["output"])
graph = helper.make_graph(
[y],
"range_test",
inputs=[
helper.make_tensor_value_info("start", dtype_onnx, []),
helper.make_tensor_value_info("limit", dtype_onnx, []),
helper.make_tensor_value_info("delta", dtype_onnx, []),
],
outputs=[
helper.make_tensor_value_info(
"output", dtype_onnx, np.arange(start, limit, delta).shape
)
],
)
model = helper.make_model(graph, producer_name="range_test")
inputs = [np.array(x).astype(dtype) for x in [start, limit, delta]]
verify_with_ort_with_inputs(model, inputs, use_vm=True)


@tvm.testing.uses_gpu
def test_range():
for t in ["float32", "int32", "int64"]:
verify_range(0, 10, 1, t)
verify_range(2, 8, 2, t)
verify_range(-3, 6, 4, t)
verify_range(-2, -7, -1, t)


@tvm.testing.uses_gpu
def test_squeeze():
in_shape = (1, 3, 1, 3, 1, 1)
Expand Down

0 comments on commit ca395a7

Please sign in to comment.