diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 525008be38d88..0921e13095474 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -503,6 +503,8 @@ static inline Expr Tile(Expr data, Array reps) { Expr MakeConcatenate(Expr data, int axis); +Expr MakeRepeat(Expr data, int repeats, int axis); + Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); Expr MakeStack(Expr data, int axis); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 9cbb415e61313..88a369c9731bb 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -39,9 +39,7 @@ namespace qnn { // relay.op.qnn.conv2d TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs); -bool QnnConv2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -50,17 +48,22 @@ bool QnnConv2DRel(const Array& types, const auto* param = attrs.as(); CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr."; CHECK(data->dtype == Int(8) || data->dtype == UInt(8)) - << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype; + << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype; CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8)) - << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype; + << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype; CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32)) - << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype; + << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype; CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; return Conv2DRel(types, num_inputs, attrs, reporter); } -// Workload - 
batch_size, in_channels, out_channels, kernel_h, kernel_w -using WorkloadType = std::tuple; +bool is_depthwise(const QnnConv2DAttrs* param) { + return param->channels.defined() && tvm::ir::Equal(param->channels, param->groups) && + param->groups != 1; +} + +// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier +using WorkloadType = std::tuple; /* * \brief Get the conv parameters like batch_size, kernel_height etc. @@ -84,26 +87,38 @@ WorkloadType GetWorkload(const Array& arg_types, const QnnConv const auto kernel_shape = get_shape(arg_types[1]); int out_channels, kernel_h, kernel_w; + int channel_multiplier = -1; if (param->kernel_layout == "OIHW") { out_channels = get_const_int(kernel_shape[0]); kernel_h = get_const_int(kernel_shape[2]); kernel_w = get_const_int(kernel_shape[3]); + if (is_depthwise(param)) { + channel_multiplier = get_const_int(kernel_shape[1]); + } } else if (param->kernel_layout == "HWIO") { kernel_h = get_const_int(kernel_shape[0]); kernel_w = get_const_int(kernel_shape[1]); out_channels = get_const_int(kernel_shape[3]); + if (is_depthwise(param)) { + channel_multiplier = get_const_int(kernel_shape[2]); + } } else if (param->kernel_layout == "HWOI") { kernel_h = get_const_int(kernel_shape[0]); kernel_w = get_const_int(kernel_shape[1]); out_channels = get_const_int(kernel_shape[2]); + if (is_depthwise(param)) { + channel_multiplier = get_const_int(kernel_shape[3]); + } } else { LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout"; } - return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w); + + return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w, + channel_multiplier); } /* - * \brief Fallback to simpler lowering for dilation or depthwise conv. + * \brief Fallback to simpler lowering for dilation or grouped conv. * \param data The input expr. * \param weight The weight expr. * \param param The qnn conv2d attributes. 
@@ -166,6 +181,130 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) { return padded_data; } +/* + * \brief Calculates the second term in the qnn.conv2d depthwise lowering sequence. + * \param padded_data The padded data expr. + * \param param The qnn conv2d attributes. + * \param kernel_h The height of kernel. + * \param kernel_w The width of kernel. + * \param channel_multiplier The channel/depth multiplier. + * \return The sequence of Relay operators for term2. + * \note The term2 looks like this + * + * Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s) + * + * Second term is not directly representable by one Relay operator. + * However, deeper analysis shows that we can reduce r,s using avg_pool2d, + * followed by repeat on the C axis by cm times. + */ +Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h, + int kernel_w, int channel_multiplier) { + // Constant Expr for the kernel zero point. + auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point); + + auto casted_t2 = Cast(padded_data, Int(32)); + + // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum. + // Since this is integer division (floor), we can first multiply the data by the pool_size and + // then perform avg_pool2d. Reversing this causes inaccuracy due to floor division. + auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w)); + Array padding({0, 0}); + + // If the pool_size is 1x1, we don't need avg_pool2d. + auto reduced_t2 = scaled_hw_t2; + if (kernel_h * kernel_w != 1) { + reduced_t2 = + AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout, + false, // ceil_mode + false); // count_include_pad + } + + auto multiplied_t2 = reduced_t2; + if (param->kernel_zero_point != 1) { + multiplied_t2 = Multiply(zp_kernel, reduced_t2); + } + + // Repeat along the C dimension. Find the axis. 
+ int axis_t2 = 0; + if (param->data_layout == "NCHW") { + axis_t2 = 1; + } else if (param->data_layout == "NHWC") { + axis_t2 = 3; + } else { + LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout"; + } + auto repeated_t2 = multiplied_t2; + if (channel_multiplier != 1) { + repeated_t2 = MakeRepeat(multiplied_t2, channel_multiplier, axis_t2); + } + return repeated_t2; +} + +/* + * \brief Calculates the third term in the qnn.conv2d depthwise lowering sequence. + * \param weight The weight expr. + * \param param The qnn conv2d attributes. + * \param out_channels The number of output channels. + * \param channel_multiplier The channel/depth multiplier. + * \return The sequence of Relay operators for term3. + * \note The term3 looks like this + * + * Sigma(r, s) zp_a * Qw(oc/m, oc%m, r, s) + * + * This can be achieved by calling reduce on r and s axis. The tensor can then be reshaped to + * (1, oc, 1, 1) as (oc/m, oc%m) are just contiguous memory locations. + */ +Expr DepthwiseConv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels, + int channel_multiplier) { + // Constant expr for input zero point. + auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point); + + // Find which dimensions are R, S. + Array axes_t3; + if (param->kernel_layout == "OIHW") { + // For OIHW kernel layout, HW are reduce axis + axes_t3 = {2, 3}; + } else if (param->kernel_layout == "HWIO") { + axes_t3 = {0, 1}; + } else if (param->kernel_layout == "HWOI") { + axes_t3 = {0, 1}; + } else { + LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout"; + } + auto reduced_t3 = Sum(Cast(weight, Int(32)), axes_t3, false, false); + + // Find the newshape depending on NCHW/NHWC layout. 
+ Array newshape; + if (param->data_layout == "NCHW") { + newshape = {1, out_channels * channel_multiplier, 1, 1}; + } else if (param->data_layout == "NHWC") { + newshape = {1, 1, 1, out_channels * channel_multiplier}; + } else { + LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout"; + } + auto reshaped_t3 = Reshape(reduced_t3, newshape); + + if (param->input_zero_point == 1) { + return reshaped_t3; + } + return Multiply(zp_data, reshaped_t3); +} + +/* + * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence. + * \param param The qnn conv2d attributes. + * \param kernel_h The height of kernel. + * \param kernel_w The width of kernel. + * \return The sequence of Relay operators for term4. + * \note The term4 looks like this + * + * Sigma(r, s) zp_a * zp_w + */ +Expr DepthwiseConv2DFourthTerm(const QnnConv2DAttrs* param, int kernel_h, int kernel_w) { + int scalar_term4 = param->input_zero_point * param->kernel_zero_point * kernel_h * kernel_w; + return MakeConstantScalar(Int(32), scalar_term4); +} + /* * \brief Calculates the first term in the qnn.conv2d lowering sequence. * \param data The input expr. @@ -245,7 +384,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int * \brief Calculates the third term in the qnn.conv2d lowering sequence. * \param weight The weight expr. * \param param The qnn conv2d attributes. - * \param batch_size The batch size. * \param out_channels The number of output channels. * \return The sequence of Relay operatos for term3. * \note The term3 looks like this @@ -256,8 +394,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int * a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW * format. 
*/ -Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size, - int out_channels) { +Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels) { // Constant expr for input zero point. auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point); @@ -278,9 +415,9 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_ // Find the newshape depending on NCHW/NHWC layout. Array newshape; if (param->data_layout == "NCHW") { - newshape = {batch_size, out_channels, 1, 1}; + newshape = {1, out_channels, 1, 1}; } else if (param->data_layout == "NHWC") { - newshape = {batch_size, 1, 1, out_channels}; + newshape = {1, 1, 1, out_channels}; } else { LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout"; } @@ -295,7 +432,6 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_ /* * \brief Calculates the fourth term in the qnn.conv2d lowering sequence. * \param param The qnn conv2d attributes. - * \param batch_size The batch size. * \param in_channels The number of input channels. * \param kernel_h The height of kernel. * \param kernel_w The width of kernel. @@ -305,8 +441,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_ * Sigma(c,r,s) zp_a * zp_w * */ -Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int batch_size, int in_channels, int kernel_h, - int kernel_w) { +Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int in_channels, int kernel_h, int kernel_w) { int scalar_term4 = param->input_zero_point * param->kernel_zero_point * in_channels * kernel_h * kernel_w; return MakeConstantScalar(Int(32), scalar_term4); @@ -391,7 +526,20 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * gives an opportunity to reuse alter_op_layout infrastructure. * 3) For dilated conv, in current lowering, we need dilated pool. 
So as * a workaround, we fall back to simpler lowering using int32 conv if - * the conv is dilated. We fallback also in case of depthwise conv. + * the conv is dilated. We fallback also in case of grouped conv. + * + * For depthwise, we can similarly unroll the computation. The initial compute is as follows + * where cm = channel_multiplier + * + * Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/cm, oc%cm, r, s) - zp_w) + * * (Qa(n, oc/cm, oh + r, ow + s) - zp_a) + * + * This can be written as + * + * Sigma(r, s) Qw(oc/cm, oc%cm, r, s) * Qa(n, oc/cm, oh + r, ow + s) + * - Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s) + * - Sigma(r, s) zp_a * Qw(oc/cm, oc%cm, r, s) + * - Sigma(r, s) zp_a * zp_w * * The whole process can be broken down into following steps * * Assertion checks for existing support, fallback if necessary @@ -417,23 +565,33 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, param->kernel_layout == "HWOI") << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout."; - int batch_size, in_channels, out_channels, kernel_h, kernel_w; - std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) = + int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier; + std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) = GetWorkload(arg_types, param); - // Fallback to int32 conv if there is dilation or depthwise conv2d + // Fallback to int32 conv if there is dilation or grouped conv2d + CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation"; auto dilation_h = get_const_int(param->dilation[0]); auto dilation_w = get_const_int(param->dilation[1]); - if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) { + if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) { return Conv2DFallBack(data, weight, param); + } else if (is_depthwise(param)) { + CHECK_NE(channel_multiplier, -1); + auto padded_data = Conv2DPadInput(data, 
param); + auto term1 = Conv2DFirstTerm(padded_data, weight, param); + auto term2 = + DepthwiseConv2DSecondTerm(padded_data, param, kernel_h, kernel_w, channel_multiplier); + auto term3 = DepthwiseConv2DThirdTerm(weight, param, out_channels, channel_multiplier); + auto term4 = DepthwiseConv2DFourthTerm(param, kernel_h, kernel_w); + return Conv2DCombineTerms(term1, term2, term3, term4, param); } auto padded_data = Conv2DPadInput(data, param); auto term1 = Conv2DFirstTerm(padded_data, weight, param); auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels); - auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels); - auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w); + auto term3 = Conv2DThirdTerm(weight, param, out_channels); + auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w); return Conv2DCombineTerms(term1, term2, term3, term4, param); } diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 40fc993635ef2..4c03b7424c12d 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -42,7 +42,9 @@ def get_ref_func(data, dilation, data_layout, kernel_layout, - out_dtype): + out_dtype, + groups, + channels=None): casted_data = relay.op.cast(data, "int32") casted_kernel = relay.op.cast(kernel, "int32") shifted_data = relay.op.subtract(casted_data, @@ -54,6 +56,8 @@ def get_ref_func(data, padding=padding, strides=strides, dilation=dilation, + groups=groups, + channels=channels, kernel_size=kernel_size, out_dtype=out_dtype, data_layout=data_layout, @@ -74,7 +78,9 @@ def get_qnn_func(data, dilation, data_layout, kernel_layout, - out_dtype): + out_dtype, + groups, + channels=None): func = relay.qnn.op.conv2d( data, kernel, input_zero_point=input_zero_point, @@ -86,6 +92,8 @@ def get_qnn_func(data, dilation=dilation, padding=padding, out_dtype=out_dtype, + groups=groups, + channels=channels, 
data_layout=data_layout, kernel_layout=kernel_layout) @@ -107,7 +115,9 @@ def get_funcs(data_shape, dilation, data_layout, kernel_layout, - out_dtype): + out_dtype, + groups=1, + channels=None): data = relay.var("data", shape=data_shape, dtype=data_dtype) kernel = relay.var("kernel", shape=kernel_shape, @@ -124,7 +134,9 @@ def get_funcs(data_shape, dilation, data_layout, kernel_layout, - out_dtype) + out_dtype, + groups, + channels) ref_func = run_infer_type(ref_func) qnn_func = get_qnn_func(data, kernel, @@ -138,7 +150,9 @@ def get_funcs(data_shape, dilation, data_layout, kernel_layout, - out_dtype) + out_dtype, + groups, + channels) return (ref_func, qnn_func) def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, @@ -529,7 +543,8 @@ def test_const_folding(): dilation=(1, 1), data_layout="NCHW", kernel_layout="OIHW", - out_dtype="int32") + out_dtype="int32", + groups=1) folded_mod = transform.FoldConstant()(qnn_func) folded_func = folded_mod["main"] assert "reshape" not in folded_func.astext() @@ -724,6 +739,112 @@ def test_broadcast_layout(): with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") +def test_depthwise_depth_multiplier(): + with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): + + # uint8 input, NCHW and OIHW + # Depthwise multiplier = 1 + data_shape = (2, 4, 16, 16) + data_dtype = 'uint8' + kernel_shape = (4, 1, 3, 3) + kernel_dtype = 'uint8' + ref_func, qnn_func = get_funcs(data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(3, 3), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + groups=4, + channels=4) + verify(ref_func, qnn_func, data_shape, data_dtype, + kernel_shape, kernel_dtype) + + + # Depthwise multiplier = 2 + data_shape = (10, 4, 16, 
16) + data_dtype = 'uint8' + kernel_shape = (4, 2, 3, 3) + kernel_dtype = 'uint8' + ref_func, qnn_func = get_funcs(data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(3, 3), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + groups=8, + channels=8) + verify(ref_func, qnn_func, data_shape, data_dtype, + kernel_shape, kernel_dtype) + + # uint8 input, NHWC and HWOI + # Depthwise multiplier = 1 + data_shape = (2, 16, 16, 4) + data_dtype = 'uint8' + kernel_shape = (3, 3, 4, 1) + kernel_dtype = 'uint8' + ref_func, qnn_func = get_funcs(data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(3, 3), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWOI", + out_dtype="int32", + groups=4, + channels=4) + verify(ref_func, qnn_func, data_shape, data_dtype, + kernel_shape, kernel_dtype) + + # Depthwise multiplier = 2 + data_shape = (2, 16, 16, 4) + data_dtype = 'uint8' + kernel_shape = (3, 3, 4, 2) + kernel_dtype = 'uint8' + ref_func, qnn_func = get_funcs(data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(3, 3), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWOI", + out_dtype="int32", + groups=8, + channels=8) + verify(ref_func, qnn_func, data_shape, data_dtype, + kernel_shape, kernel_dtype) + if __name__ == "__main__": test_no_zero_point() test_input_zero_point() @@ -738,3 +859,4 @@ def test_broadcast_layout(): test_broadcast_layout() 
test_tflite_output_multiplier_greater_than_one() test_tflite_anistropic_strides() + test_depthwise_depth_multiplier() diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index f596bc0eb503d..a02f919cba0eb 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -197,6 +197,11 @@ def _conv2d_legalize(attrs, inputs, arg_types): if not (dilation[0] == 1 and dilation[1] == 1): return None + # No legalization for depthwise convolutions yet. + groups = attrs.get_int("groups") + if groups != 1: + return None + # Collect the input tensors. data_tensor, kernel_tensor = arg_types[0], arg_types[1] data_dtype = data_tensor.dtype