From 47f06a3ccf33e9e075ce457c1e2a2b54d586effb Mon Sep 17 00:00:00 2001
From: shoubhik
Date: Thu, 17 Oct 2019 16:17:48 -0700
Subject: [PATCH 1/6] [QNN] Improving Dense lowering.

---
 include/tvm/relay/qnn/attrs.h                 |   2 +-
 python/tvm/relay/qnn/op/op_attrs.py           |   4 +
 src/relay/qnn/op/dense.cc                     | 131 ++++++++++++------
 ...test_qnn_dense.py => test_op_qnn_dense.py} |   9 --
 4 files changed, 97 insertions(+), 49 deletions(-)
 rename tests/python/relay/{test_qnn_dense.py => test_op_qnn_dense.py} (95%)

diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
index e5f4ba94e12e..58707f4da19c 100644
--- a/include/tvm/relay/qnn/attrs.h
+++ b/include/tvm/relay/qnn/attrs.h
@@ -213,7 +213,7 @@ struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
   int32_t input_zero_point;
   int32_t kernel_zero_point;
 
-  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.qnn.QnnDenseAttrs") {
+  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") {
     TVM_ATTR_FIELD(units)
       .describe("Number of hidden units of the dense transformation.");
     TVM_ATTR_FIELD(out_dtype)
diff --git a/python/tvm/relay/qnn/op/op_attrs.py b/python/tvm/relay/qnn/op/op_attrs.py
index e3fe1c8924f6..24ca3b47cc3a 100644
--- a/python/tvm/relay/qnn/op/op_attrs.py
+++ b/python/tvm/relay/qnn/op/op_attrs.py
@@ -22,3 +22,7 @@
 @register_relay_attr_node
 class QnnConv2DAttrs(Attrs):
     """Attributes for qnn.conv2d"""
+
+@register_relay_attr_node
+class QnnDenseAttrs(Attrs):
+    """Attributes for qnn.dense"""
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index c708cfa3dc63..d0ffbf41bcb5 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -29,6 +29,7 @@
 #include <tvm/relay/qnn/attrs.h>
 #include "../../op/nn/nn.h"
 #include "../../pass/pattern_util.h"
+#include "../util.h"
 
 namespace tvm {
 namespace relay {
@@ -37,33 +38,27 @@ namespace qnn {
 // relay.op.qnn.dense
 TVM_REGISTER_NODE_TYPE(QnnDenseAttrs);
 
-bool QnnDenseRel(const Array<Type>& types,
-                 int num_inputs,
-                 const Attrs& attrs,
+bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<QnnDenseAttrs>();
-  CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
+  CHECK(param != nullptr) << "QnnDenseAttrs cannot be nullptr.";
   CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
-    << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
+      << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
   CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
-    << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
+      << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
   CHECK(param->out_dtype == Int(32))
-    << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
+      << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
   return DenseRel<QnnDenseAttrs>(types, num_inputs, attrs, reporter);
 }
 
 // Positional relay function to create quantized dense operator used by frontend FFI.
-Expr MakeQuantizedDense(Expr data,
-                        Expr weight,
-                        IndexExpr units,
-                        int32_t input_zero_point,
-                        int32_t kernel_zero_point,
-                        DataType out_dtype) {
+Expr MakeQuantizedDense(Expr data, Expr weight, IndexExpr units, int32_t input_zero_point,
+                        int32_t kernel_zero_point, DataType out_dtype) {
   auto attrs = make_node<QnnDenseAttrs>();
   attrs->units = std::move(units);
   attrs->out_dtype = out_dtype;
@@ -73,40 +68,98 @@ Expr MakeQuantizedDense(Expr data,
   return CallNode::make(op, {data, weight}, Attrs(attrs), {});
 }
 
-/**
- * \brief Lowers Qnn convolution in terms of core operators in relay.
- * Mathematically it is equals to -
- * Dense((quantized_input - input_zero_point;int32), (quantized_kernel - kernel_zero_point; int32))
- *
- * \param attrs QnnDenseAttrs for Qnn Dense layer.
+Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel,
+                    const QnnDenseAttrs* attrs) {
+  return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype);
+}
+
+Expr DenseSecondTerm(const Expr& quantized_data, const Expr& zp_kernel) {
+  Array<Integer> axes = {1};
+  return Multiply(zp_kernel, Sum(Cast(quantized_data, Int(32)), axes, true, false));
+}
+
+Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& zp_data) {
+  Array<Integer> axes = {1};
+  return Multiply(zp_data, Sum(Cast(quantized_kernel, Int(32)), axes, false, false));
+}
+
+Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
+  int64_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * common_dim;
+  return MakeConstantScalar(Int(32), (int32_t)scalar_term);
+}
+
+/*
+ * \brief Forward rewrite the qnn dense op.
+ * \param attrs The QNN dense attrs.
  * \param new_args The new mutated args to the call node.
- * \param arg_types The data types of input and output.
- * \reutrn The sequence of Relay ops for qnn cov2d op.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for qnn dense op.
+ * \node Lowering of the qnn.dense operator
+ * A quantized tensor is represented in following manner
+ *        A = scale_a x (QA - zp_A)
+ * where QA is quantized tensor, scale_a and zp_A are quantizations
+ * params.
+ *
+ * Quantized dense multiplies two quantized tensors and returns a
+ * quantized tensor of default dtype of int32, with scale equaling to the
+ * product of scales of input tensors, and a zero point of zero.
+ *
+ * The lowering for asymmetric quantized dense looks as follows. More details at
+ * https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh
+ * The computation gets unrolled into following 4 terms
+ *      C(m, n) = Sigma(k) A(m, k) * W(n, k)
+ *
+ * RHS becomes
+ *      Sigma(k) [QA(m, k) - zp_a] * [QW(n, k) - zp_w]
+ *
+ * Unrolling leads to following sequence
+ *      Sigma(k) QA(m, k) * QW(n, k)        // Term1
+ *    - Sigma(k) zp_w * QA(m, k)            // Term2
+ *    - Sigma(k) zp_a * QW(n, k)            // Term3
+ *    + Sigma(k) zp_a * zp_w                // Term4
+ *
+ * Term3 and Term4 can be computed at compile time.
  */
-Expr QnnDenseCanonicalize(const Attrs& attrs,
-                          const Array<Expr>& new_args,
+Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                           const Array<tvm::relay::Type>& arg_types) {
   CHECK_EQ(new_args.size(), 2);
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
+  auto get_shape = [](const Type& type) {
+    auto input_tt = type.as<TensorTypeNode>();
+    CHECK(input_tt != nullptr) << "Type information missing."
+                                << " Please run infer_type pass.";
+    return input_tt->shape;
+  };
+  const auto in_shape = get_shape(arg_types[0]);
+  const int common_dim = get_const_int(in_shape[1]);
+
   const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
-  Expr quantized_data_int32 = Cast(quantized_data, Int(32));
-  if (qnn_dense_attrs->input_zero_point != 0) {
-    quantized_data_int32 = Subtract(quantized_data_int32,
-                                    MakeConstantScalar(Int(32),
-                                                       qnn_dense_attrs->input_zero_point));
-  }
-  Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32));
-  if (qnn_dense_attrs->kernel_zero_point != 0) {
-    quantized_kernel_int32 = Subtract(quantized_kernel_int32,
-                                      MakeConstantScalar(Int(32),
-                                                         qnn_dense_attrs->kernel_zero_point));
+  auto zp_kernel = MakeConstantScalar(Int(32), qnn_dense_attrs->kernel_zero_point);
+  auto zp_data = MakeConstantScalar(Int(32), qnn_dense_attrs->input_zero_point);
+
+  // Get all the terms as described in the comments.
+  auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
+  auto term2 = DenseSecondTerm(quantized_data, zp_kernel);
+  auto term3 = DenseThirdTerm(quantized_kernel, zp_data);
+  auto term4 = DenseFourthTerm(qnn_dense_attrs, common_dim);
+
+  // Combine those 4 terms depending on the zero points to get the best lowering.
+  if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+    // term 2, 3 and 4 become zero.
+    return term1;
+  } else if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point != 0) {
+    // term 3 and term 4 become zero.
+    return Subtract(term1, term2);
+  } else if (qnn_dense_attrs->input_zero_point != 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+    // term 2 and term 4 become zero.
+    return Subtract(term1, term3);
+  } else {
+    auto data_term = Subtract(term1, term2);
+    // Putting constant terms together, so that constant folding can fold it.
+    auto const_term = Subtract(term4, term3);
+    return Add(data_term, const_term);
+  }
-  Expr int32_dense = Dense(quantized_data_int32,
-                           quantized_kernel_int32,
-                           qnn_dense_attrs->units,
-                           qnn_dense_attrs->out_dtype);
-  return int32_dense;
 }
 
 RELAY_REGISTER_OP("qnn.dense")
diff --git a/tests/python/relay/test_qnn_dense.py b/tests/python/relay/test_op_qnn_dense.py
similarity index 95%
rename from tests/python/relay/test_qnn_dense.py
rename to tests/python/relay/test_op_qnn_dense.py
index f1e0767aff2e..885fe37d29bb 100644
--- a/tests/python/relay/test_qnn_dense.py
+++ b/tests/python/relay/test_op_qnn_dense.py
@@ -193,29 +193,20 @@ def qnn_dense_driver(test_configuration):
 
 
 def test_qnn_dense_without_bias():
-    uint32_output_without_bias_paramas = \
-        make_uint_configuration(use_bias=False)
     int32_output_without_bias_params = \
         make_int_configuration(use_bias=False)
-    qnn_dense_driver(uint32_output_without_bias_paramas)
     qnn_dense_driver(int32_output_without_bias_params)
 
 
 def test_qnn_dense_with_bias():
-    uint32_output_with_bias_params = \
-        make_uint_configuration(use_bias=True)
     int32_output_with_bias_params = \
         make_int_configuration(use_bias=True)
-    qnn_dense_driver(uint32_output_with_bias_params)
     qnn_dense_driver(int32_output_with_bias_params)
 
 
 def test_qnn_dense_with_requantized_output():
-    uint8_requantized_output_with_bias_params = \
-        make_uint_configuration(use_bias=True, requantize_output=True)
     int8_requantized_output_with_bias_params = \
         make_int_configuration(use_bias=True, requantize_output=True)
-    qnn_dense_driver(uint8_requantized_output_with_bias_params)
     qnn_dense_driver(int8_requantized_output_with_bias_params)
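
[Review note, not part of the patch series] The four-term expansion this patch introduces is easy to sanity-check numerically. A minimal NumPy sketch, with made-up shapes and zero points, mirroring DenseFirstTerm through DenseFourthTerm:

    import numpy as np

    m, k, n = 4, 16, 8     # batch, reduction dim, output units (illustrative)
    zp_a, zp_w = 3, 5      # example zero points
    qa = np.random.randint(0, 256, size=(m, k)).astype(np.int32)  # QA
    qw = np.random.randint(0, 256, size=(n, k)).astype(np.int32)  # QW

    # Reference lowering: subtract zero points first, then dense (matmul).
    ref = (qa - zp_a) @ (qw - zp_w).T

    term1 = qa @ qw.T                             # Dense(QA, QW)
    term2 = zp_w * qa.sum(axis=1, keepdims=True)  # cf. DenseSecondTerm
    term3 = zp_a * qw.sum(axis=1)                 # cf. DenseThirdTerm
    term4 = zp_a * zp_w * k                       # cf. DenseFourthTerm
    assert np.array_equal(ref, term1 - term2 - term3 + term4)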
From 9d63e297fa1a44a358d5a712063bf7f0bb15e5b2 Mon Sep 17 00:00:00 2001
From: shoubhik
Date: Mon, 28 Oct 2019 15:13:39 -0700
Subject: [PATCH 2/6] - Moving get_shape method to util - Finalizing the test
 cases and the code structure for optimized dense computation.

---
 src/relay/qnn/op/convolution.cc |  7 -------
 src/relay/qnn/op/dense.cc       | 14 +++++---------
 src/relay/qnn/util.h            |  7 +++++++
 3 files changed, 12 insertions(+), 16 deletions(-)

diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index d17a18589d75..1103620829d1 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -70,13 +70,6 @@ using WorkloadType = std::tuple<int, int, int, int, int>;
  */
 WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv2DAttrs* param) {
   // Get conv parameters.
-  auto get_shape = [](const Type& type) {
-    auto input_tt = type.as<TensorTypeNode>();
-    CHECK(input_tt != nullptr) << "Type information missing."
-                               << " Please run infer_type pass.";
-    return input_tt->shape;
-  };
-
   const auto in_shape = get_shape(arg_types[0]);
   int batch_size, in_channels;
   if (param->data_layout == "NCHW") {
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index d0ffbf41bcb5..d529dc4fc0c6 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -97,7 +97,7 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
  * \node Lowering of the qnn.dense operator
  * A quantized tensor is represented in following manner
  *        A = scale_a x (QA - zp_A)
- * where QA is quantized tensor, scale_a and zp_A are quantizations
+ * where QA is quantized tensor, scale_a and zp_A are quantization
  * params.
  *
  * Quantized dense multiplies two quantized tensors and returns a
@@ -107,10 +107,10 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
  * quantized tensor of default dtype of int32, with scale equaling to the
  * product of scales of input tensors, and a zero point of zero.
 *
 * The lowering for asymmetric quantized dense looks as follows. More details at
 * https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh
 * The computation gets unrolled into following 4 terms
- *      C(m, n) = Sigma(k) A(m, k) * W(n, k)
+ *      C(m, n) = Sigma(k) (A(m, k) * W(n, k))
 *
 * RHS becomes
- *      Sigma(k) [QA(m, k) - zp_a] * [QW(n, k) - zp_w]
+ *      Sigma(k) ([QA(m, k) - zp_a] * [QW(n, k) - zp_w])
 *
@@ -119,18 +119,14 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
 * Unrolling leads to following sequence
 *      Sigma(k) QA(m, k) * QW(n, k)        // Term1
 *    - Sigma(k) zp_w * QA(m, k)            // Term2
 *    - Sigma(k) zp_a * QW(n, k)            // Term3
 *    + Sigma(k) zp_a * zp_w                // Term4
 *
 * Term3 and Term4 can be computed at compile time.
+ * In case of symmetric quantization terms term2, term3 and term4 become 0.
 */
 Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                           const Array<tvm::relay::Type>& arg_types) {
   CHECK_EQ(new_args.size(), 2);
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
-  auto get_shape = [](const Type& type) {
-    auto input_tt = type.as<TensorTypeNode>();
-    CHECK(input_tt != nullptr) << "Type information missing."
-                               << " Please run infer_type pass.";
-    return input_tt->shape;
-  };
+
   const auto in_shape = get_shape(arg_types[0]);
   const int common_dim = get_const_int(in_shape[1]);
 
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
index f94860d28cf9..f695c5221832 100644
--- a/src/relay/qnn/util.h
+++ b/src/relay/qnn/util.h
@@ -36,6 +36,13 @@ namespace tvm {
 namespace relay {
 namespace qnn {
 
+static inline Array<IndexExpr> get_shape(const Type& type) {
+  auto input_tt = type.as<TensorTypeNode>();
+  CHECK(input_tt != nullptr) << "Type information missing."
+                             << " Please run infer_type pass.";
+  return input_tt->shape;
+};
+
 static inline const int32_t GetQmin(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)
       << "QNN ops support int32 or lower precision";
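
[Review note, not part of the patch series] The symmetric-quantization remark added above is exactly what QnnDenseCanonicalize exploits when combining the four terms. A hedged Python rendering of that branch selection (illustrative names, reusing the NumPy terms from the note after PATCH 1):

    def combine_terms(term1, term2, term3, term4, zp_data, zp_kernel):
        # Mirrors the if/else ladder in QnnDenseCanonicalize.
        if zp_data == 0 and zp_kernel == 0:
            return term1                  # symmetric: terms 2-4 vanish
        if zp_data == 0:
            return term1 - term2          # only the kernel zero point is nonzero
        if zp_kernel == 0:
            return term1 - term3          # only the input zero point is nonzero
        # General case: group the weight-only terms so constant folding
        # can collapse (term4 - term3) at compile time.
        return (term1 - term2) + (term4 - term3)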
From 9115416a208c63d4dc5b302e3b4545c7a3058caa Mon Sep 17 00:00:00 2001
From: shoubhik
Date: Mon, 28 Oct 2019 15:59:23 -0700
Subject: [PATCH 3/6] - Fixing cpplint.

---
 src/relay/qnn/util.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
index f695c5221832..482f6d8265e4 100644
--- a/src/relay/qnn/util.h
+++ b/src/relay/qnn/util.h
@@ -41,7 +41,7 @@ static inline Array<IndexExpr> get_shape(const Type& type) {
   CHECK(input_tt != nullptr) << "Type information missing."
                              << " Please run infer_type pass.";
   return input_tt->shape;
-};
+}
 
 static inline const int32_t GetQmin(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)

From 099a41fac12537babc8a5774866079b1d4d0c1fb Mon Sep 17 00:00:00 2001
From: shoubhik
Date: Tue, 29 Oct 2019 11:22:59 -0700
Subject: [PATCH 4/6] - Addressing review comments.

---
 src/relay/qnn/op/dense.cc | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index d529dc4fc0c6..d826271d2b01 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -84,8 +84,8 @@ Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& zp_data) {
 }
 
 Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
-  int64_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * common_dim;
-  return MakeConstantScalar(Int(32), (int32_t)scalar_term);
+  int32_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * common_dim;
+  return MakeConstantScalar(Int(32), scalar_term);
 }
 
 /*
@@ -94,7 +94,7 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
  * \param new_args The new mutated args to the call node.
  * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for qnn dense op.
- * \node Lowering of the qnn.dense operator
+ * \note Lowering of the qnn.dense operator
 * A quantized tensor is represented in following manner
 *        A = scale_a x (QA - zp_A)
 * where QA is quantized tensor, scale_a and zp_A are quantization
@@ -105,7 +105,7 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
 * product of scales of input tensors, and a zero point of zero.
 *
 * The lowering for asymmetric quantized dense looks as follows. More details at
- * https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh
+ * https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8
 * The computation gets unrolled into following 4 terms
 *      C(m, n) = Sigma(k) (A(m, k) * W(n, k))
 *
@@ -119,7 +119,6 @@ Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
 *    - Sigma(k) zp_a * QW(n, k)            // Term3
 *    + Sigma(k) zp_a * zp_w                // Term4
 *
 * Term3 and Term4 can be computed at compile time.
- * In case of symmetric quantization terms term2, term3 and term4 become 0.
 */
 Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                           const Array<tvm::relay::Type>& arg_types) {
@@ -127,8 +126,8 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
 
-  const auto in_shape = get_shape(arg_types[0]);
-  const int common_dim = get_const_int(in_shape[1]);
+  const auto reduction_dim_size = get_shape(arg_types[0]);
+  const int common_dim = get_const_int(reduction_dim_size[1]);
   const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
 
   auto zp_kernel = MakeConstantScalar(Int(32), qnn_dense_attrs->kernel_zero_point);
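
[Review note, not part of the patch series] On the int64_t -> int32_t change in DenseFourthTerm above: with 8-bit inputs the zero points are bounded by 255, so the constant zp_a * zp_w * k only overflows int32 once the reduction dimension k exceeds roughly 33,000. Quick illustrative arithmetic (values assumed, not from the patch):

    zp_max = 255                           # largest possible 8-bit zero point
    int32_max = 2**31 - 1
    print(int32_max // (zp_max * zp_max))  # 33025: largest safe reduction dim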
From 387b212be1a028034616e7ae1747ac261f7ebaaa Mon Sep 17 00:00:00 2001
From: shoubhik
Date: Tue, 29 Oct 2019 11:30:45 -0700
Subject: [PATCH 5/6] - Renaming the variables correctly.

---
 src/relay/qnn/op/dense.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index d826271d2b01..017dbfe26a69 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -126,8 +126,8 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
 
-  const auto reduction_dim_size = get_shape(arg_types[0]);
-  const int common_dim = get_const_int(reduction_dim_size[1]);
+  const auto in_shape = get_shape(arg_types[0]);
+  const int reduction_dim_size = get_const_int(in_shape[1]);
   const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
 
   auto zp_kernel = MakeConstantScalar(Int(32), qnn_dense_attrs->kernel_zero_point);
@@ -137,7 +137,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
   auto term2 = DenseSecondTerm(quantized_data, zp_kernel);
   auto term3 = DenseThirdTerm(quantized_kernel, zp_data);
-  auto term4 = DenseFourthTerm(qnn_dense_attrs, common_dim);
+  auto term4 = DenseFourthTerm(qnn_dense_attrs, reduction_dim_size);
 
   // Combine those 4 terms depending on the zero points to get the best lowering.
   if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point == 0) {
     // term 2, 3 and 4 become zero.

From 8ace52627f7e4813e6a3ff5b66b33ba4fb8fc442 Mon Sep 17 00:00:00 2001
From: shoubhik
Date: Tue, 29 Oct 2019 12:20:05 -0700
Subject: [PATCH 6/6] - Renaming the variables correctly.
---
 src/relay/qnn/op/dense.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index 017dbfe26a69..9ac01212dc37 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -83,8 +83,8 @@ Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& zp_data) {
   return Multiply(zp_data, Sum(Cast(quantized_kernel, Int(32)), axes, false, false));
 }
 
-Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int common_dim) {
-  int32_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * common_dim;
+Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int reduction_dim_size) {
+  int32_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * reduction_dim_size;
   return MakeConstantScalar(Int(32), scalar_term);
 }
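
[Review note, not part of the patch series] For completeness, a minimal sketch of driving this lowering from Python, in the spirit of test_op_qnn_dense.py. It assumes the relay.qnn.op.dense wrapper at this revision takes the zero points as plain integers, and that relay.qnn.transform.CanonicalizeOps is the pass that invokes QnnDenseCanonicalize; treat both as assumptions about the surrounding codebase rather than something these patches establish:

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(4, 16), dtype="uint8")
    weight = relay.var("weight", shape=(8, 16), dtype="uint8")
    # Argument order mirrors MakeQuantizedDense; values are illustrative.
    mm = relay.qnn.op.dense(data, weight,
                            input_zero_point=3,
                            kernel_zero_point=5,
                            units=8,
                            out_dtype="int32")
    mod = relay.Module.from_expr(relay.Function([data, weight], mm))
    # After canonicalization, qnn.dense becomes nn.dense plus the
    # zero-point correction terms derived above.
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    print(mod)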