diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 5ebd9b951855..91739e63e3a6 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -130,17 +130,17 @@ WorkloadType GetWorkload(const Array& arg_types, const Conv2DA } /* - * \brief Fallback to simpler lowering for dilation or grouped conv. + * \brief Fallback to simpler lowering for dilation (when non-zero kernel point) or grouped conv. * \param data The input expr. * \param weight The weight expr. * \param input_zero_point The input zero point expr. * \param kernel_zero_point The kernel zero point expr. * \param param The qnn conv2d attributes. * \return The fallback lowered sequence of Relay expr. - * \note In case of dilation, normal lowering would require a dilated pool. - * Since, we don't have dilated pool, we fallback to a simpler sequence of - * Relay operations. This will potentially lead to performance degradation - * as the convolution is called on int32 tensors instead of int8 tensors. + * \note In case of dilation with non-zero kernel zero point, normal lowering would require a + * dilated pool. Since, we don't have dilated pool, we fallback to a simpler sequence of Relay + * operations. This will potentially lead to performance degradation as the convolution is called on + * int32 tensors instead of int8 tensors. */ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point, const Expr& kernel_zero_point, const Conv2DAttrs* param) { @@ -598,12 +598,16 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, auto input_zero_point_int = GetScalarFromConstant(input_zero_point); auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); - // Fallback to int32 conv if there is dilation or grouped conv2d + // Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d + // For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to + // traverse the elements in dilated manner. Currently, we do not have strided pool. So, in case of + // dilated conv with non-zero kernel point, we fall back to simpler but slow lowering. CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation"; auto dilation_h = get_const_int(param->dilation[0]); auto dilation_w = get_const_int(param->dilation[1]); - if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) { + if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) || + (param->groups != 1 && !is_depthwise(param))) { return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param); } else if (is_depthwise(param)) { CHECK_NE(channel_multiplier, -1); diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 9631ffc2faf3..ced12c843563 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -495,7 +495,7 @@ def test_padding(): def test_dilation(): with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): - # uint8 input + # Non-zero kernel point - fall back to simpler lowering. data_shape = (2, 4, 4, 4) data_dtype = 'uint8' kernel_shape = (3, 4, 2, 2) @@ -518,6 +518,29 @@ def test_dilation(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + # Zero kernel point + data_shape = (2, 4, 4, 4) + data_dtype = 'uint8' + kernel_shape = (3, 4, 2, 2) + kernel_dtype = 'uint8' + ref_func, qnn_func = get_funcs(data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=0, + kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(2, 2), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32") + verify(ref_func, qnn_func, data_shape, data_dtype, + kernel_shape, kernel_dtype) + def test_const_folding(): with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):