diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 5faaa32295d9..aafc301be555 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -441,7 +441,10 @@ def autopad(
     # pad N and C with zeros
     pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0)
 
-    return _op.nn.pad(data, fold_constant(pad), _op.const(pad_value), pad_type)
+    if isinstance(pad_value, (float, int)):
+        pad_value = _op.const(pad_value)
+
+    return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type)
 
 
 class Conv(OnnxOpConverter):
@@ -3202,6 +3205,79 @@ def get_scalar(x, dtype="float32"):
         return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
 
 
+class ConvInteger(OnnxOpConverter):
+    """Operator converter for ConvInteger."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        data = inputs[0]
+        weight = inputs[1]
+        data_zp = inputs[2]
+        weight_zp = inputs[3]
+        if data_zp is None:
+            data_zp = _expr.const(0, "int32")
+        if weight_zp is None:
+            weight_zp = _expr.const(0, "int32")
+
+        input_type = infer_type(data)
+        input_shape = get_const_tuple(input_type.checked_type.shape)
+
+        ndim = len(input_shape)
+        kernel_type = infer_type(weight)
+        kernel_shape = get_const_tuple(kernel_type.checked_type.shape)
+        if "kernel_shape" not in attr:
+            attr["kernel_shape"] = kernel_shape[2:]
+
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
+                # Warning: Convolution does not yet support dynamic shapes,
+                # one will need to run dynamic_to_static on this model after import
+                data = autopad(
+                    data,
+                    attr.get("strides", [1] * (ndim - 2)),
+                    attr["kernel_shape"],
+                    attr.get("dilations", [1] * (ndim - 2)),
+                    ndim,
+                    pad_value=data_zp,
+                    mode=attr["auto_pad"],
+                )
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = tuple([0 for i in range(ndim - 2)])
+            elif attr["auto_pad"] == "NOTSET":
+                pass
+            else:
+                msg = 'Value {} in attribute "auto_pad" of operator ConvInteger is invalid.'
+                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
+            attr.pop("auto_pad")
+
+        out_channels = kernel_shape[0]
+        dilation = attr.get("dilations", [1] * (ndim - 2))
+        strides = attr.get("strides", [1] * (ndim - 2))
+        padding = attr["pads"] if "pads" in attr else 0
+        groups = attr["group"] if "group" in attr else 1
+
+        if ndim != 4:
+            raise tvm.error.OpAttributeInvalid(
+                "Only 2D kernels are supported for operator ConvInteger."
+            )
+
+        return _qnn.op.conv2d(
+            data,
+            weight,
+            _op.cast(data_zp, "int32"),
+            _op.cast(weight_zp, "int32"),
+            _expr.const(1.0, "float32"),
+            _expr.const(1.0, "float32"),
+            kernel_size=attr["kernel_shape"],
+            channels=out_channels,
+            strides=strides,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+
 class BitShift(OnnxOpConverter):
     """Operator converter for NonZero"""
 
@@ -3430,6 +3506,7 @@ def _get_convert_map(opset):
         "ReverseSequence": ReverseSequence.get_converter(opset),
         "QLinearConv": QLinearConv.get_converter(opset),
         "QLinearAdd": QLinearAdd.get_converter(opset),
+        "ConvInteger": ConvInteger.get_converter(opset),
     }
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 2d9666706a5d..049ca1e0cfe0 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4420,7 +4420,6 @@ def verify_eyelike(indata):
 onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/"))
 
 unsupported_onnx_tests = [
-    "test_basic_convinteger/",
     "test_cast_DOUBLE_to_FLOAT16/",
     "test_cast_FLOAT_to_STRING/",
     "test_cast_STRING_to_FLOAT/",
@@ -4428,7 +4427,6 @@ def verify_eyelike(indata):
     "test_compress_1/",
     "test_compress_default_axis/",
     "test_compress_negative_axis/",
-    "test_convinteger_with_padding/",
     "test_convtranspose_dilations/",
     "test_convtranspose_output_shape/",
     "test_cumsum_1d/",
@@ -4872,6 +4870,161 @@ def test_qlinearadd():
     verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7])
 
 
+def verify_convinteger(
+    x_shape,
+    w_shape,
+    y_shape,
+    padding,
+    kernel_shape,
+    strides,
+    dilations,
+    auto_pad="NOTSET",
+    dtype="uint8",
+):
+
+    x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype)
+    w_array = np.random.randint(low=0, high=255, size=w_shape).astype(dtype)
+    x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)
+    w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)
+
+    ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
+    input_nodes = [
+        helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)),
+        helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)),
+        helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []),
+        helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []),
+    ]
+    input_names = [
+        "x",
+        "w",
+        "x_zero_point",
+        "w_zero_point",
+    ]
+    input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array]
+
+    if padding is None:
+        # autopadding, with default-valued attributes left unset
+        kwargs = {}
+        if not all([s == 1 for s in strides]):
+            kwargs["strides"] = strides
+        if not all([d == 1 for d in dilations]):
+            kwargs["dilations"] = dilations
+
+        node = helper.make_node(
+            "ConvInteger",
+            inputs=input_names,
+            outputs=["y"],
+            # Default values for other attributes:
+            auto_pad=auto_pad,
+            **kwargs,
+        )
+    else:
+        node = helper.make_node(
+            "ConvInteger",
+            inputs=input_names,
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            # Default values for other attributes:
+            strides=strides,
+            dilations=dilations,
+            # groups=1
+            pads=padding,
+        )
+
+    graph = helper.make_graph(
+        [node],
+        "convinteger_test",
+        inputs=input_nodes,
+        outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))],
+    )
+    model = helper.make_model(graph, producer_name="convinteger_test")
+    # opt_level=1 will cause an error, so use opt_level=2
+    verify_with_ort_with_inputs(model, input_values, opt_level=2)
+
+
+def test_convinteger():
+    def repeat(N, D):
+        return tuple([N for _ in range(D)])
+
+    # Only 2D ConvInteger is supported for now, because we only support qnn.conv2d.
+    D = 2
+
+    # Convolution with padding
+    verify_convinteger(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(5, D),
+        2 * repeat(1, D),
+        repeat(3, D),
+        repeat(1, D),
+        repeat(1, D),
+    )
+
+    # Convolution with asymmetric padding
+    verify_convinteger(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(4, D),
+        repeat(0, D) + repeat(1, D),
+        repeat(3, D),
+        repeat(1, D),
+        repeat(1, D),
+    )
+    # Convolution without padding
+    verify_convinteger(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(3, D),
+        2 * repeat(0, D),
+        repeat(3, D),
+        repeat(1, D),
+        repeat(1, D),
+    )
+    # Convolution with autopadding
+    verify_convinteger(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(5, D),
+        None,
+        repeat(3, D),
+        repeat(1, D),
+        repeat(1, D),
+        auto_pad="SAME_UPPER",
+    )
+    # Convolution with valid autopadding
+    verify_convinteger(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(3, D),
+        None,
+        repeat(3, D),
+        repeat(1, D),
+        repeat(1, D),
+        auto_pad="VALID",
+    )
+    # Convolution with non-uniform stride
+    verify_convinteger(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(3, D),
+        None,
+        repeat(3, D),
+        repeat(2, D),
+        repeat(1, D),
+        auto_pad="SAME_UPPER",
+    )
+    # Convolution with dilation
+    verify_convinteger(
+        (1, 1) + repeat(5, D),
+        (1, 1) + repeat(3, D),
+        (1, 1) + repeat(5, D),
+        2 * repeat(2, D),
+        repeat(3, D),
+        repeat(1, D),
+        repeat(2, D),
+    )
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
@@ -4955,4 +5108,5 @@ def test_qlinearadd():
     test_reverse_sequence()
     test_eyelike()
     test_qlinearconv()
+    test_convinteger()
     test_batch_matmul()
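
For reviewers who want to exercise the new converter outside the test harness, here is a minimal sketch. It assumes a TVM build with this patch applied plus the `onnx` package; the graph and model names (`convinteger_minimal`) are illustrative, not part of the patch:

```python
from onnx import helper, TensorProto
from tvm import relay

# Minimal ConvInteger model: a 1x1x5x5 uint8 input convolved with a
# 1x1x3x3 uint8 kernel. The optional zero-point inputs are omitted,
# so the converter falls back to constant zeros.
node = helper.make_node("ConvInteger", inputs=["x", "w"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "convinteger_minimal",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.UINT8, [1, 1, 5, 5]),
        helper.make_tensor_value_info("w", TensorProto.UINT8, [1, 1, 3, 3]),
    ],
    # No pads/auto_pad and stride 1: a 3x3 kernel over 5x5 yields 3x3,
    # accumulated in int32 per the ONNX ConvInteger spec.
    outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, [1, 1, 3, 3])],
)
model = helper.make_model(graph, producer_name="convinteger_minimal")

# Shapes are static, so no dynamic_to_static pass is needed here.
mod, params = relay.frontend.from_onnx(
    model, shape={"x": (1, 1, 5, 5), "w": (1, 1, 3, 3)}
)
print(mod)
```

The printed module should contain a single `qnn.conv2d` call with unit scales and int32 zero points, producing an int32 tensor of shape (1, 1, 3, 3).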