diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index dbc8ded17e8e..7ef0b10541e5 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -33,10 +33,15 @@ namespace qnn { /*! \brief Attribute for requantize operator */ struct RequantizeAttrs : public tvm::AttrsNode { + int axis; std::string rounding; DataType out_dtype; TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { + TVM_ATTR_FIELD(axis) + .describe("The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); TVM_ATTR_FIELD(rounding).set_default("UPWARD") .describe("Defines the rounding direction when the value is midway between" "two representable values. There are two supported modes - UPWARD" @@ -56,10 +61,15 @@ struct RequantizeAttrs : public tvm::AttrsNode { /*! \brief Attribute for quantize operator */ struct QuantizeAttrs : public tvm::AttrsNode { DataType out_dtype; + int axis; TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") { TVM_ATTR_FIELD(out_dtype) .describe("Output data type, can be one of [int8 or uint8]."); + TVM_ATTR_FIELD(axis) + .describe("The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); } }; diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 2d950ba43bf5..739ee97dade7 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -26,6 +26,7 @@ def requantize(data, input_zero_point, output_scale, output_zero_point, + axis=-1, rounding="UPWARD", out_dtype="int8"): r"""Requantized operator. @@ -53,6 +54,9 @@ def requantize(data, output_zero_point: tvm.relay.Expr The zero point of the output tensor. + axis : int + The channel axis for quantization. Default value is -1 which corresponds to the last axis. + rounding : string, optional Defines the rounding direction when the value is midway between two representable values. @@ -71,6 +75,7 @@ def requantize(data, input_zero_point, output_scale, output_zero_point, + axis, rounding, out_dtype) @@ -78,6 +83,7 @@ def requantize(data, def quantize(data, output_scale, output_zero_point, + axis=-1, out_dtype='int8'): r""" Quantize op This operator takes float32 as input and produces quantized int8 or unit8 as output. @@ -95,6 +101,8 @@ def quantize(data, The output zero_point. output_scale : tvm.relay.Expr The output scale. + axis : int + The channel axis for quantization. Default value is -1 which corresponds to the last axis. out_dtype : str, optional The data type of the input tensor. Can be [int8, uint8] Returns @@ -106,6 +114,7 @@ def quantize(data, return _make.quantize(data, output_scale, output_zero_point, + axis, out_dtype) diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 16356b625588..5b2c3ae791ad 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -35,6 +35,7 @@ #include #include #include +#include #include @@ -221,6 +222,19 @@ inline bool IsScalar(const Expr& expr) { return true; } +/*! + * \brief Check if expr is a const scalar. + * \param expr The expr. + * \return True if const scalar. + */ +inline bool IsConstScalar(const Expr& expr) { + const auto* const_expr = expr.as(); + if (const_expr) { + return const_expr->is_scalar(); + } + return false; +} + /*! * \brief Create a Constant with a scalar * @@ -228,7 +242,7 @@ inline bool IsScalar(const Expr& expr) { * \param value The value of the scalar. * \return A Constant. */ -template +template inline Constant MakeConstantScalar(DataType dtype, T value) { runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0}); TVM_DTYPE_DISPATCH(dtype, DType, { @@ -236,7 +250,7 @@ inline Constant MakeConstantScalar(DataType dtype, T value) { // convert to float16 // storage is uint16_t *static_cast(arr->data) = - __truncXfYf2__(static_cast(value)); + __truncXfYf2__(static_cast(value)); } else { *static_cast(arr->data) = value; } @@ -244,6 +258,34 @@ inline Constant MakeConstantScalar(DataType dtype, T value) { return ConstantNode::make(arr); } +/*! + * \brief Create a Constant with a tensor. + * + * \param dtype The data type. + * \param value The vector of the tensor values. + * \return A Constant. + */ +template +static inline Constant MakeConstantTensor(DataType dtype, std::vector shape, + std::vector value) { + runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0}); + TVM_DTYPE_DISPATCH(dtype, DType, { + for (size_t i = 0; i < value.size(); i++) { + if (dtype == DataType::Float(16)) { + // convert to float16 + // storage is uint16_t + // Similar handling as that in MakeConstantScalar + *(static_cast(arr->data) + i) = + __truncXfYf2__( + static_cast(value[i])); + } else { + *(static_cast(arr->data) + i) = value[i]; + } + } + }) + return ConstantNode::make(arr); +} + /*! * \brief Check if two expressions are equal scalars. * \param a The expression to be checked. @@ -523,6 +565,8 @@ static inline Expr Tile(Expr data, Array reps) { return CallNode::make(op, {data}, Attrs(attrs), {}); } +Expr MakeBroadCastTo(Expr data, Array shape); + Expr MakeConcatenate(Expr data, int axis); Expr MakeRepeat(Expr data, int repeats, int axis); diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 27b3c0f94ca7..f6133b8736de 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -45,11 +45,18 @@ bool QuantizeRel(const Array& types, CHECK(input_dtype == DataType::Float(32)) << "Input type should be one of float32 but was " << input_dtype; - // Check the types of scale and zero points. - CHECK(IsScalarType(types[1], DataType::Float(32))); // output_scale - CHECK(IsScalarType(types[2], DataType::Int(32))); // output_zero_point - const auto* quantize_attrs = attrs.as(); + int axis = quantize_attrs->axis; + axis = (axis == -1) ? data->shape.size() - 1: axis; + CHECK_LT(axis, static_cast(data->shape.size())) + << "axis " << quantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) + << "axis " << quantize_attrs->axis << " is out of range"; + + // Check and assign types for scale and zero points. + AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale + AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || @@ -60,8 +67,10 @@ bool QuantizeRel(const Array& types, return true; } -Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType out_dtype) { +Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis, + DataType out_dtype) { auto attrs = make_object(); + attrs->axis = axis; attrs->out_dtype = std::move(out_dtype); // result_quantized_value = result_zero_point + result_real_value / result_scale. // A more detailed explanation can be found here - @@ -71,13 +80,29 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType } Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, - const Expr& output_zero_point, const QuantizeAttrs* attrs) { + const Expr& output_zero_point, const Array& input_shape, + const QuantizeAttrs* attrs) { const auto out_dtype = attrs->out_dtype; + const auto axis = attrs->axis; + + size_t n_dim = input_shape.size(); + + auto expanded_output_scale = output_scale; + if (!IsConstScalar(output_scale)) { + expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis}); + } + + auto expanded_output_zero_point = output_zero_point; + if (!IsConstScalar(output_zero_point)) { + expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis}); + } + const int32_t min_val = GetQmin(out_dtype); const int32_t max_val = GetQmax(out_dtype); - auto scale_data = Divide(input_tensor, output_scale); + auto scale_data = Divide(input_tensor, expanded_output_scale); auto add_zero_point = - Cast(Round(Add(scale_data, Cast(output_zero_point, DataType::Float(32)))), DataType::Int(32)); + Cast(Round(Add(scale_data, Cast(expanded_output_zero_point, DataType::Float(32)))), + DataType::Int(32)); auto clamped_output = Clip(add_zero_point, min_val, max_val); auto clamp_out_dtype = Cast(clamped_output, out_dtype); return clamp_out_dtype; @@ -92,8 +117,15 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, const auto* quantize_attrs = attrs.as(); CHECK(quantize_attrs != nullptr); + // Find input shape. CHECK_EQ(types.size(), 4); - return QuantizeLower(data, output_scale, output_zero_point, quantize_attrs); + auto in_type = types[0]; + auto in_tensor_type = in_type.as(); + CHECK(in_tensor_type != nullptr) << "Type information missing." + << " Please run infer_type pass."; + Array input_shape = in_tensor_type->shape; + + return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs); } RELAY_REGISTER_OP("qnn.quantize") diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index fc6a9b442d47..a8a919589839 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -58,11 +58,6 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& input_zero_point, const Expr& output_scale, const Expr& output_zero_point, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype) { - float input_scale_float = GetScalarFromConstant(input_scale); - float output_scale_float = GetScalarFromConstant(output_scale); - double double_multiplier = - static_cast(input_scale_float) / static_cast(output_scale_float); - DataType hp_dtype = DataType::Int(64); auto tensor = Cast(input_tensor, hp_dtype); @@ -72,11 +67,34 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype)); } - // 2) If the input and output scales are same, we can skip the fixed point multiplication. + // 2) If the input and output scales are same, we can skip the fixed point multiplication. Check + // if the input scale is per-tensor or per-channel. If it is per-tensor, there is single scale for + // the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input + // tensor. Depending on the quantization type, the fixed point multiplication routing is called. auto scaled_int64_t = tensor; - if (!IsEqualScalar(input_scale, output_scale)) { - scaled_int64_t = - FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding); + float output_scale_float = GetScalarFromConstant(output_scale); + if (IsConstScalar(input_scale)) { + // This is per-tensor quantization. Single scale. + float input_scale_float = GetScalarFromConstant(input_scale); + double double_multiplier = + static_cast(input_scale_float) / static_cast(output_scale_float); + // Skip if input and output scales are same. + if (!IsEqualScalar(input_scale, output_scale)) { + scaled_int64_t = + FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding); + } + } else { + // This is per-channel (per=axis) quantization. + std::vector double_multipliers; + auto input_axis_scales = GetFloatVectorFromConstant(input_scale); + for (auto input_axis_scale : input_axis_scales) { + double_multipliers.push_back(static_cast(input_axis_scale) / + static_cast(output_scale_float)); + } + int axis = param->axis; + axis = (axis == -1) ? input_shape.size() - 1 : axis; + scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape, + axis, param->rounding); } // 3) Add the output zero point. @@ -157,16 +175,24 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, in_dtype == DataType::Int(32)) << "Input type should be one of [int8, uint8, int32] but was " << in_dtype; - // Check the types of scale and zero points. - CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale - CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point + const RequantizeAttrs* requantize_attrs = attrs.as(); + int axis = requantize_attrs->axis; + axis = (axis == -1) ? data->shape.size() - 1: axis; + CHECK_LT(axis, static_cast(data->shape.size())) + << "axis " << requantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) + << "axis " << requantize_attrs->axis << " is out of range"; + + // Check and assign types for scale and zero points. + AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale + AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // input_zero_pt + // For now, requantize output tensor is limited to full tensor uniform quantization. CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point const Array oshape = data->shape; // assign output type - const RequantizeAttrs* param = attrs.as(); - auto out_dtype = param->out_dtype; + auto out_dtype = requantize_attrs->out_dtype; CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || out_dtype == DataType::Int(32)) @@ -178,8 +204,9 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create qnn requantize operator // used by frontend FFI. Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale, - Expr output_zero_point, std::string rounding, DataType out_dtype) { + Expr output_zero_point, int axis, std::string rounding, DataType out_dtype) { auto attrs = make_object(); + attrs->axis = axis; attrs->rounding = std::move(rounding); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.requantize"); diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 98d07eb098de..cd0e68824129 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -75,8 +75,8 @@ std::pair GetFixedPointMultiplierShift( return std::make_pair(significand, exponent); } -Expr FixedPointMultiply(Expr tensor, double multiplier, - const Array& input_shape, const std::string& rounding) { +Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& input_shape, + const std::string& rounding) { // Choose high precision datatype to be int64. This is for avoiding overflow // in multiplication of two int32 values. DataType hp_dtype = DataType::Int(64); @@ -133,6 +133,90 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, return tensor; } +Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, + const Array& input_shape, int channel_axis, + const std::string& rounding) { + // Get the n dim. This will be used to expand the multiplier to match the axis. + size_t n_dim = input_shape.size(); + + // Get the num of channels/axis along which the tensor was quantized. + int64_t n_channels = (int64_t)multipliers.size(); + + // Choose high precision datatype to be int64. This is for avoiding overflow + // in multiplication of two int32 values. + DataType hp_dtype = DataType::Int(64); + + // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per + // channel. + std::vector fixed_pt_multipliers, lshifts, rshifts; + for (auto multiplier : multipliers) { + int32_t fixed_pt_multiplier, shift; + std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier); + int lshift = shift > 0 ? shift : 0; + int rshift = shift > 0 ? 0 : -shift; + fixed_pt_multipliers.push_back(fixed_pt_multiplier); + lshifts.push_back(lshift); + rshifts.push_back(rshift); + } + + // 2) Multiply the integer multiplier. Convert lefts shifts into expr and multiply. + auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts); + auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis}); + tensor = LeftShift(tensor, exp_lshift_expr); + + // 3) Perform the multiplication in higher precision. + // The scalar is a fixed point value of int32 where the decimal point is + // between bits 31 and 30. After multiplying with input_tensor, the result + // is in int64 where the decimal point is sitting between bits 31 and 30 + // (from the right, rightmost bit is bit 0). The computation is performed in + // higher precision to avoid overflow in multiplying two int32 values. + auto fixed_pt_multiplier_expr = MakeConstantTensor(hp_dtype, {n_channels}, fixed_pt_multipliers); + auto exp_fixed_pt_multiplier_expr = + ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {channel_axis}); + tensor = Multiply(tensor, exp_fixed_pt_multiplier_expr); + + // 4) Find the rounding scalar. This depends on where the final decimal point sits. As we will be + // right shifting the multiplied_t, we need to first calculate the total_rshift. Further, we can + // calculate the pos and neg rounding offset. + std::vector pos_rounding_values, neg_rounding_values, total_rshifts; + for (auto rshift : rshifts) { + int total_rshift = rshift + 31; + total_rshifts.push_back(total_rshift); + pos_rounding_values.push_back((1ll << (total_rshift - 1))); + neg_rounding_values.push_back((1ll << (total_rshift - 1)) - 1); + } + // Make a Relay expr from positive and negative rounding offset values. + auto pos_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, pos_rounding_values); + auto exp_pos_rounding_value_expr = + ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis}); + auto neg_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, neg_rounding_values); + auto exp_neg_rounding_value_expr = + ExpandBiasToMatchAxis(neg_rounding_value_expr, n_dim, {channel_axis}); + + Expr round_scalar; + if (rounding == "UPWARD") { + round_scalar = exp_pos_rounding_value_expr; + } else if (rounding == "TONEAREST") { + // To satisfy where op shape requirements, the rounding values are broadcasted. + auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape); + auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape); + + auto zero_t = Zeros(input_shape, hp_dtype); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder); + } else { + LOG(FATAL) << "Rounding mode " << rounding << " not supported."; + } + // Add the rounding scalar. + tensor = Add(tensor, round_scalar); + + // 5) Simply right shift the result to get the final output. + auto total_rshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, total_rshifts); + auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis}); + tensor = RightShift(tensor, exp_total_rshift_expr); + + return tensor; +} + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 67b934e1c4ec..e4b5cf1134d7 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -30,6 +30,7 @@ #include #include #include +#include #include namespace tvm { @@ -124,6 +125,33 @@ static inline int64_t get_const_int(const tvm::Expr& x) { Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& input_shape, const std::string& rounding); +/* + * \brief Fixed point multiplication between integer tensor with floating point + scalar where the input tensor is per-axis/per-channel quantized.. + * \param tensor The quantized input tensor of dtype int64. + * \param multiplier The scalar multiplier. + * \param input_shape Shape of the input tensor. + * \param channel_axis The channel_axis along which the input tensor is quantized. Default value is + -1 which corresponds to the last channel_axis. + * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value + is midway between" "two representable values. + * \return The sequence of Relay ops for fixed point multiplication. + + * \note Original compuation is scale_fp32 * quantized_tensor. To convert into + * integer computation, the multiplication with fp32 vector can be + * replaced by multiplication with an int vector and then right shifting + * the result. This approximates the floating point computation with a + * fixed point computation. + * + * Computation of fixed point multiplication is consist of following + steps: + * 1) Multiply the fixed point multiplier with quantized tensor. + * 2) Round the result. + * 3) Right shift the result + */ +Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multiplier, + const Array& input_shape, int channel_axis, + const std::string& rounding); /* * \brief Checks whether an expr type is scalar of a given data type. * \param expr_type The type of expr to be checked. @@ -131,12 +159,45 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& * \return True if the type is a scalar of given dtype */ static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) { - const auto* scale = expr_type.as(); - CHECK_EQ(scale->shape.size(), 0); - CHECK(scale->dtype == dtype) << "Expected " << dtype << " but got " << scale->dtype; + const auto* tensor_type = expr_type.as(); + CHECK_EQ(tensor_type->shape.size(), 0); + CHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype; return true; } +/* + * \brief Checks and assigns types to scale and zero points. + * \param expr_type The type of expr to be checked. + * \param dtype The expected dtype. + * \param shape The shape at C dim of original tensor. + * \param reporter The type reported of original InferType call. + */ +static inline void AssignType(const Type& expr_type, const DataType& dtype, const IndexExpr& shape, + const TypeReporter& reporter) { + // Scale/Zero_points can be either const scalar or a vector with C axis num elems. + const auto* tensor_type = expr_type.as(); + const auto tensor_dtype = tensor_type->dtype; + CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype; + if (tensor_type->shape.size() != 0) { + reporter->Assign(expr_type, TensorTypeNode::make({shape}, tensor_type->dtype)); + } +} + +static inline std::vector GetFloatVectorFromConstant(const Expr& expr) { + const auto* n = expr.as(); + std::vector vals; + CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false); + int64_t num_elems = 1; + auto shape = n->data.Shape(); + for (size_t i = 0; i < shape.size(); i++) { + num_elems *= shape[i]; + } + for (int64_t i = 0; i < num_elems; i++) { + vals.push_back(static_cast(n->data->data)[i]); + } + return vals; +} + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 68718886a0a2..45caedaf4a44 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -20,13 +20,15 @@ from tvm import relay from tvm.contrib import graph_runtime -def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output_data): +def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_output_data): shape = in_data.shape input_data = relay.var("input_data", shape=shape, dtype=in_dtype) - output_zero_point = relay.const(quant_args['out_zero_point'], 'int32') - output_scale = relay.const(quant_args['out_scale'], 'float32') + output_zero_point = relay.const(quant_args['out_zero_point']) + output_scale = relay.const(quant_args['out_scale']) quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale, - output_zero_point=output_zero_point,out_dtype=out_dtype) + output_zero_point=output_zero_point, + axis=axis, + out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = relay.Module.from_expr(mod) with relay.build_config(opt_level=3): @@ -46,9 +48,9 @@ def test_float32_to_uint8(): output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \ .astype('uint8') \ .reshape((2,5)) - quant_args = {"out_zero_point":127, "out_scale":0.5} - quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='uint8', in_data=data, - verify_output_data=output) + quant_args = {"out_zero_point":np.int32(127), "out_scale": np.float32(0.5)} + quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='uint8', + in_data=data, verify_output_data=output) def test_float32_to_int8(): data = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \ @@ -57,10 +59,37 @@ def test_float32_to_int8(): output = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \ .astype('int8') \ .reshape((2,5)) - quant_args = {"out_zero_point":-1, "out_scale":0.5} - quantize_test_driver(in_dtype='float32', quant_args=quant_args, out_dtype='int8', in_data=data, - verify_output_data=output) + quant_args = {"out_zero_point":np.int32(-1), "out_scale":np.float32(0.5)} + quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=-1, out_dtype='int8', + in_data=data, verify_output_data=output) + +def test_channelwise_axis_0(): + data = np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \ + .astype('float32') \ + .reshape((2,5)) + output = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \ + .astype('uint8') \ + .reshape((2,5)) + quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'), + "out_scale" : np.array([0.5, 0.25]).astype('float32')} + + quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=0, out_dtype='uint8', + in_data=data, verify_output_data=output) + +def test_channelwise_axis_1(): + data = np.transpose(np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \ + .astype('float32').reshape((2,5))) + output = np.transpose(np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \ + .astype('uint8').reshape((2,5))) + quant_args = {"out_zero_point" : np.array([127, 123]).astype('int32'), + "out_scale" : np.array([0.5, 0.25]).astype('float32')} + + quantize_test_driver(in_dtype='float32', quant_args=quant_args, axis=1, out_dtype='uint8', + in_data=data, verify_output_data=output) + if __name__ == "__main__": test_float32_to_uint8() test_float32_to_int8() + test_channelwise_axis_0() + test_channelwise_axis_1() diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index 6d6b9b448548..bd37cb989f46 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -34,15 +34,27 @@ def verify(mod, goldens): np.testing.assert_equal(res, golden_output) def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale, - input_zero_point=0, output_zero_point=0, rounding="TONEAREST"): + input_zero_point=0, output_zero_point=0, rounding="TONEAREST", + axis=0): quantized_data = relay.var("quantized_data", shape=data_shape, dtype=data_dtype) + if isinstance(input_scale, float): + input_scale_expr = relay.const(input_scale, 'float32') + else: + input_scale_expr = relay.const(np.array(input_scale).astype('float32')) + + if isinstance(input_zero_point, float): + input_zero_point_expr = relay.const(input_zero_point, 'int32') + else: + input_zero_point_expr = relay.const(np.array(input_zero_point).astype('int32')) + mod = relay.qnn.op.requantize( quantized_data, - input_scale=relay.const(input_scale, 'float32'), - input_zero_point=relay.const(input_zero_point, 'int32'), + input_scale=input_scale_expr, + input_zero_point=input_zero_point_expr, output_scale=relay.const(output_scale, 'float32'), output_zero_point=relay.const(output_zero_point, 'int32'), + axis=axis, rounding=rounding, out_dtype=out_dtype) @@ -240,9 +252,70 @@ def test_zero_point(): golden_output = np.subtract(golden_output, 1) verify(mod, (golden_data, golden_output)) +def test_per_channel_same_scale(): + # Have same scales, everything within range + golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2)) + golden_output = golden_data + + for rounding in roundings: + mod = get_mod(data_shape=(5, 2), + data_dtype='int32', + out_dtype="int8", + input_scale=[0.5, 0.5], + output_scale=0.5, + axis=1, + rounding=rounding) + verify(mod, (golden_data, golden_output)) + + # Change axis + golden_data = np.arange(-10, 10, 1).astype('int32').reshape((2,2,5)) + golden_output = golden_data + + for rounding in roundings: + mod = get_mod(data_shape=(2, 2, 5), + data_dtype='int32', + out_dtype="int8", + input_scale=[0.5, 0.5], + output_scale=0.5, + axis=1, + rounding=rounding) + verify(mod, (golden_data, golden_output)) + +def test_per_channel_different_scale(): + # Have same scales, everything within range + golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2)) + golden_output = np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2)) + + for rounding in roundings: + mod = get_mod(data_shape=(5, 2), + data_dtype='int32', + out_dtype="int8", + input_scale=[0.5, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding) + verify(mod, (golden_data, golden_output)) + + # Change axis + golden_data = np.arange(-20, 20, 2).astype('int32').reshape((2,2,5)) + golden_output = np.array([-20, -18, -16, -14, -12, -5, -4, -3, -2, -1, 0, 2, 4, 6, 8, 5, 6, 7, + 8, 9]).reshape((2, 2, 5)) + + for rounding in roundings: + mod = get_mod(data_shape=(2, 2, 5), + data_dtype='int32', + out_dtype="int8", + input_scale=[0.5, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding) + verify(mod, (golden_data, golden_output)) + if __name__ == "__main__": test_same_scale() test_downscale() test_upscale() test_saturation() test_zero_point() + test_per_channel_same_scale() + test_per_channel_different_scale()