[Relay][Frontend][ONNX] Add ConvInteger support. (apache#8456)

* Add ConvInteger support and fix some ConvTranspose padding bugs.

* Simplify pads check.

* Fix style.

* Remove changes to conv_transpose.
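
For context, ONNX's ConvInteger computes an ordinary convolution over zero-point-shifted integer inputs, accumulating in int32 — the computation this commit lowers to qnn.conv2d with unit scales. A minimal NumPy sketch of that reference semantics (2D, stride 1, no padding, groups=1; the helper name is ours, not part of the commit):

import numpy as np

def convinteger_ref(x, w, x_zp=0, w_zp=0):
    # Shift by the zero points and widen to int32 before accumulating.
    x = x.astype("int32") - np.int32(x_zp)
    w = w.astype("int32") - np.int32(w_zp)
    n, c, h, wid = x.shape
    oc, _, kh, kw = w.shape
    out = np.zeros((n, oc, h - kh + 1, wid - kw + 1), dtype="int32")
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            # Contract each (c, kh, kw) patch against every output channel's kernel.
            out[:, :, i, j] = np.tensordot(
                x[:, :, i : i + kh, j : j + kw], w, axes=([1, 2, 3], [1, 2, 3])
            )
    return out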
Josh Fromm authored and ylc committed Sep 29, 2021
1 parent e3d3bcc commit 7c92860
Showing 2 changed files with 234 additions and 3 deletions.
79 changes: 78 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -441,7 +441,10 @@ def autopad(
     # pad N and C with zeros
     pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0)

-    return _op.nn.pad(data, fold_constant(pad), _op.const(pad_value), pad_type)
+    if isinstance(pad_value, (float, int)):
+        pad_value = _op.const(pad_value)
+
+    return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type)


class Conv(OnnxOpConverter):
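
The autopad tweak above is what lets ConvInteger (added below) pad with its data zero point: pad_value may now be a relay expression as well as a Python scalar. A hedged illustration of the two call forms (argument values are made up):

padded = autopad(data, [1, 1], [3, 3], [1, 1], 4, pad_value=0.0)      # scalar: wrapped in _op.const
padded = autopad(data, [1, 1], [3, 3], [1, 1], 4, pad_value=data_zp)  # relay expression: used as-is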
@@ -3202,6 +3205,79 @@ def get_scalar(x, dtype="float32"):
        return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)


class ConvInteger(OnnxOpConverter):
    """Operator converter for ConvInteger."""

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        data = inputs[0]
        weight = inputs[1]
        data_zp = inputs[2]
        weight_zp = inputs[3]
        if data_zp is None:
            data_zp = _expr.const(0, "int32")
        if weight_zp is None:
            weight_zp = _expr.const(0, "int32")

        input_type = infer_type(data)
        input_shape = get_const_tuple(input_type.checked_type.shape)

        ndim = len(input_shape)
        kernel_type = infer_type(weight)
        kernel_shape = get_const_tuple(kernel_type.checked_type.shape)
        if "kernel_shape" not in attr:
            attr["kernel_shape"] = kernel_shape[2:]

        if "auto_pad" in attr:
            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                # Warning: Convolution does not yet support dynamic shapes,
                # one will need to run dynamic_to_static on this model after import
                data = autopad(
                    data,
                    attr.get("strides", [1] * (ndim - 2)),
                    attr["kernel_shape"],
                    attr.get("dilations", [1] * (ndim - 2)),
                    ndim,
                    pad_value=data_zp,
                    mode=attr["auto_pad"],
                )
            elif attr["auto_pad"] == "VALID":
                attr["pads"] = tuple([0 for i in range(ndim - 2)])
            elif attr["auto_pad"] == "NOTSET":
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator ConvInteger is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
            attr.pop("auto_pad")

        out_channels = kernel_shape[0]
        dilation = attr.get("dilations", [1] * (ndim - 2))
        strides = attr.get("strides", [1] * (ndim - 2))
        padding = attr["pads"] if "pads" in attr else 0
        groups = attr["group"] if "group" in attr else 1

        if ndim != 4:
            raise tvm.error.OpAttributeInvalid(
                "Only 2D kernels are supported for operator ConvInteger."
            )

        return _qnn.op.conv2d(
            data,
            weight,
            _op.cast(data_zp, "int32"),
            _op.cast(weight_zp, "int32"),
            _expr.const(1.0, "float32"),
            _expr.const(1.0, "float32"),
            kernel_size=attr["kernel_shape"],
            channels=out_channels,
            strides=strides,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
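
As the warning in the converter notes, SAME_UPPER/SAME_LOWER autopadding emits shape-dependent ops, so a model imported this way should be run through the DynamicToStatic pass. A minimal sketch of that flow (file name and input shape are hypothetical):

import onnx
from tvm import relay

onnx_model = onnx.load("quantized_model.onnx")  # hypothetical model containing ConvInteger
mod, params = relay.frontend.from_onnx(onnx_model, shape={"x": (1, 1, 5, 5)})
mod = relay.transform.DynamicToStatic()(mod)  # fold the dynamic padding introduced by autopad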


class BitShift(OnnxOpConverter):
    """Operator converter for BitShift."""

@@ -3430,6 +3506,7 @@ def _get_convert_map(opset):
"ReverseSequence": ReverseSequence.get_converter(opset),
"QLinearConv": QLinearConv.get_converter(opset),
"QLinearAdd": QLinearAdd.get_converter(opset),
"ConvInteger": ConvInteger.get_converter(opset),
}
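
Registering the converter here is what routes ONNX ConvInteger nodes through the class above; get_converter then selects the newest _impl_vN whose version does not exceed the model's opset. A simplified sketch of that dispatch idea (not the actual TVM code):

versions = [10]  # ConvInteger only defines _impl_v10
version = max(v for v in versions if v <= opset)  # e.g. opset 13 still resolves to v10
convert_fn = getattr(ConvInteger, "_impl_v{}".format(version))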


158 changes: 156 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
@@ -4420,15 +4420,13 @@ def verify_eyelike(indata):
onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/"))

unsupported_onnx_tests = [
-    "test_basic_convinteger/",
    "test_cast_DOUBLE_to_FLOAT16/",
    "test_cast_FLOAT_to_STRING/",
    "test_cast_STRING_to_FLOAT/",
    "test_compress_0/",
    "test_compress_1/",
    "test_compress_default_axis/",
    "test_compress_negative_axis/",
-    "test_convinteger_with_padding/",
    "test_convtranspose_dilations/",
    "test_convtranspose_output_shape/",
    "test_cumsum_1d/",
@@ -4872,6 +4870,161 @@ def test_qlinearadd():
    verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7])


def verify_convinteger(
    x_shape,
    w_shape,
    y_shape,
    padding,
    kernel_shape,
    strides,
    dilations,
    auto_pad="NOTSET",
    dtype="uint8",
):
    x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype)
    w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype)
    x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)
    w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype)

    ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
    input_nodes = [
        helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)),
        helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)),
        helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []),
        helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []),
    ]
    input_names = [
        "x",
        "w",
        "x_zero_point",
        "w_zero_point",
    ]
    input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array]

    if padding is None:
        # autopadding with unset default attributes
        kwargs = {}
        if not all([s == 1 for s in strides]):
            kwargs["strides"] = strides
        if not all([d == 1 for d in dilations]):
            kwargs["dilations"] = dilations

        node = helper.make_node(
            "ConvInteger",
            inputs=input_names,
            outputs=["y"],
            # Default values for other attributes:
            auto_pad=auto_pad,
            **kwargs,
        )
    else:
        node = helper.make_node(
            "ConvInteger",
            inputs=input_names,
            outputs=["y"],
            kernel_shape=kernel_shape,
            # Default values for other attributes:
            strides=strides,
            dilations=dilations,
            # groups=1
            pads=padding,
        )

    graph = helper.make_graph(
        [node],
        "convinteger_test",
        inputs=input_nodes,
        outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))],
    )
    model = helper.make_model(graph, producer_name="convinteger_test")
    # Building at opt_level=1 causes an error, so use opt_level=2.
    verify_with_ort_with_inputs(model, input_values, opt_level=2)
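
The y_shape values passed in the tests below follow the standard convolution output arithmetic; a small helper (ours, not part of the test file) that reproduces them for the explicit-padding cases:

def conv_out_dim(in_dim, kernel, stride, dilation, pad_total):
    # floor((in + total_padding - dilation * (kernel - 1) - 1) / stride) + 1
    return (in_dim + pad_total - dilation * (kernel - 1) - 1) // stride + 1

# e.g. the padded case below: conv_out_dim(5, 3, 1, 1, 2) == 5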


def test_convinteger():
    def repeat(N, D):
        return tuple([N for _ in range(D)])

    # Only 2D ConvInteger is supported because only qnn.conv2d is supported for now.
    D = 2

    # Convolution with padding
    verify_convinteger(
        (1, 1) + repeat(5, D),
        (1, 1) + repeat(3, D),
        (1, 1) + repeat(5, D),
        2 * repeat(1, D),
        repeat(3, D),
        repeat(1, D),
        repeat(1, D),
    )

    # Convolution with asymmetric padding
    verify_convinteger(
        (1, 1) + repeat(5, D),
        (1, 1) + repeat(3, D),
        (1, 1) + repeat(4, D),
        repeat(0, D) + repeat(1, D),
        repeat(3, D),
        repeat(1, D),
        repeat(1, D),
    )
    # Convolution without padding
    verify_convinteger(
        (1, 1) + repeat(5, D),
        (1, 1) + repeat(3, D),
        (1, 1) + repeat(3, D),
        2 * repeat(0, D),
        repeat(3, D),
        repeat(1, D),
        repeat(1, D),
    )
    # Convolution with autopadding
    verify_convinteger(
        (1, 1) + repeat(5, D),
        (1, 1) + repeat(3, D),
        (1, 1) + repeat(5, D),
        None,
        repeat(3, D),
        repeat(1, D),
        repeat(1, D),
        auto_pad="SAME_UPPER",
    )
    # Convolution with valid autopadding
    verify_convinteger(
        (1, 1) + repeat(5, D),
        (1, 1) + repeat(3, D),
        (1, 1) + repeat(3, D),
        None,
        repeat(3, D),
        repeat(1, D),
        repeat(1, D),
        auto_pad="VALID",
    )
    # Convolution with non-uniform stride
    verify_convinteger(
        (1, 1) + repeat(5, D),
        (1, 1) + repeat(3, D),
        (1, 1) + repeat(3, D),
        None,
        repeat(3, D),
        repeat(2, D),
        repeat(1, D),
        auto_pad="SAME_UPPER",
    )
    # Convolution with dilation
    verify_convinteger(
        (1, 1) + repeat(5, D),
        (1, 1) + repeat(3, D),
        (1, 1) + repeat(5, D),
        2 * repeat(2, D),
        repeat(3, D),
        repeat(1, D),
        repeat(2, D),
    )


if __name__ == "__main__":
    test_flatten()
    test_reshape()
@@ -4955,4 +5108,5 @@ def test_qlinearadd():
    test_reverse_sequence()
    test_eyelike()
    test_qlinearconv()
    test_convinteger()
    test_batch_matmul()
