diff --git a/python/tvm/relax/frontend/onnx_frontend.py b/python/tvm/relax/frontend/onnx_frontend.py index 23d2aca5a685..1081237c96d5 100644 --- a/python/tvm/relax/frontend/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx_frontend.py @@ -17,17 +17,18 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel """ONNX: Open Neural Network Exchange frontend for Relax.""" -import copy import math import warnings -from typing import Optional +from typing import Union, Optional -import numpy as np +import numpy as _np import tvm from tvm import relax, topi from tvm.ir import IRModule from tvm.relax import testing +from tvm._ffi import base as _base +from tvm.runtime import ndarray as _nd def new_var(var_name, shape, dtype="float32"): @@ -284,7 +285,7 @@ def _impl_v13(cls, bb, inputs, attr): if -1 in new_shape: breakpoint() data_shape = [dim.value for dim in data.shape.values] - total_elements = np.prod(data_shape) + total_elements = _np.prod(data_shape) new_product = 1 for dim in new_shape: if dim > 0: @@ -452,17 +453,108 @@ class CumSum(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - assert getattr(attr, "reverse", 0) == 0, "reverse is not supported yet" + data = inputs[0] if len(inputs) > 1: axis = int(inputs[1].data.numpy()) else: axis = None - return bb.emit_te( + if getattr(attr, "reverse", 0) != 0: + data = bb.emit_te(topi.flip, data, axis=axis if axis else 0) + data = bb.emit_te( topi.cumsum, - data=inputs[0], + data=data, axis=axis, exclusive=attr.get("exclusive", None), ) + if getattr(attr, "reverse", 0) != 0: + data = bb.emit_te(topi.flip, data, axis=axis if axis else 0) + return data + + +class Squeeze(OnnxOpConverter): + """Converts an onnx Squeeze node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + if len(inputs) > 1: + axis = [int(x) for x in inputs[1].data.numpy()] + else: + axis = None + return bb.emit_te(topi.squeeze, inputs[0], axis=axis) + + +class Constant(OnnxOpConverter): + """Converts an onnx Constant node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + def const( + value: Union[bool, int, float, _np.ndarray, tvm.nd.NDArray], + dtype: Optional[str] = None, + span: Optional[relax.Span] = None, + ): + """Create a constant value. + + Parameters + ---------- + value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray] + The constant value. + + dtype: str, optional + The data type of the resulting constant. + + span: Optional[relax.Span] + Span that points to original source code. + + Note + ---- + When dtype is None, we use the following rule: + + - int maps to "int32" + - float maps to "float32" + - bool maps to "bool" + - other using the same default rule as numpy. + """ + if isinstance(value, (_base.numeric_types, (bool, list))): + value = _np.array(value, dtype=dtype) + + if not dtype: + # when dtype is None: int maps to "int32", float maps to "float32" + dtype = {_np.dtype("int64"): _np.int32, _np.dtype("float64"): _np.float32}.get( + value.dtype, None + ) + + if isinstance(value, (_np.ndarray, _np.generic)): + if dtype is not None: + value = value.astype(dtype) + value = _nd.array(value) + + if not isinstance(value, _nd.NDArray): + raise ValueError("value has to be scalar or NDArray") + + return relax.Constant(value, span) + + if "value" not in attr: + raise ValueError("no value in Constant") + value = attr.pop("value") + # Constants may rarely have string types. These are likely exported + # from other frameworks and not actually used in TVM. We'll just use + # a zero valued constant for compatibility. + if isinstance(value, bytes): + np_value = _np.asarray([0]).astype("int64") + else: + np_value = get_numpy(value) + dtype = np_value.dtype.name + value = const(np_value, dtype) + return value + + +class Sub(OnnxOpConverter): + """Converts an onnx Sub node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + return bb.emit_te(topi.subtract, inputs[0], inputs[1]) def _get_convert_map(opset): @@ -494,6 +586,9 @@ def _get_convert_map(opset): "Pow": Pow.get_converter(opset), "Erf": Erf.get_converter(opset), "CumSum": CumSum.get_converter(opset), + "Squeeze": Squeeze.get_converter(opset), + "Constant": Constant.get_converter(opset), + "Sub": Sub.get_converter(opset), } diff --git a/tests/python/relax/frontend/test_onnx_frontend.py b/tests/python/relax/frontend/test_onnx_frontend.py index d35355bde78a..46442e68429f 100644 --- a/tests/python/relax/frontend/test_onnx_frontend.py +++ b/tests/python/relax/frontend/test_onnx_frontend.py @@ -568,6 +568,61 @@ def test_cumsum(): check_correctness(model) +def test_squeeze(): + squeeze_node = helper.make_node("Squeeze", ["x", "axis"], ["y"]) + shape = [1, 32, 1, 32] + graph = helper.make_graph( + [squeeze_node], + "squeeze_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + initializer=[helper.make_tensor("axis", TensorProto.INT64, [2], [0, 2])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_test") + check_correctness(model) + + +def test_const(): + shape = [32, 32] + const_node = helper.make_node( + "Constant", + [], + ["y"], + value=helper.make_tensor( + "value", TensorProto.FLOAT, shape, np.random.rand(*shape).astype(np.float32).flatten() + ), + ) + graph = helper.make_graph( + [const_node], + "const_test", + inputs=[], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="const_test") + check_correctness(model) + + +def test_sub(): + sub_node = helper.make_node("Sub", ["x", "y"], ["z"]) + shape = [32, 16] + graph = helper.make_graph( + [sub_node], + "sub_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("y", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="sub_test") + check_correctness(model) + + if __name__ == "__main__": test_matmul() test_concat() @@ -586,6 +641,9 @@ def test_cumsum(): test_pow() test_erf() test_cumsum() + test_squeeze() + test_const() + test_sub() # TODO, still has issues # test_reshape()