From 6e1681eedaa605e4c325b6be02899f3e849d9086 Mon Sep 17 00:00:00 2001
From: masahi
Date: Thu, 29 Oct 2020 13:52:59 +0900
Subject: [PATCH] [Torch, QNN] Support dynamic quantization flow to enable
 importing quantized transformer models (#6782)

* add stub and test

* per channel quantize

* calculate qparam correctly

* import qbert working

* support batched qdense

* test batched input

* fix mkl offloading of batch matmul

* reduce range become True in torch 1.6

* fix for 1.6

* Revert "fix mkl offloading of batch matmul"

This reverts commit cd90aa783688c68e1b12633eea4d2690d9e3a5a5.

* fix merge

* fix

* lint fix

* fix black

* more black fix

* fix version check for 1.5.1

* disable assert on v1.4 (strange pytorch issue)

* minor fix

* use dequantize

Co-authored-by: masa
---
 python/tvm/relay/frontend/pytorch.py      |  5 +-
 python/tvm/relay/frontend/qnn_torch.py    | 81 ++++++++++++++++++++---
 src/relay/qnn/op/dense.cc                 | 35 +++++++---
 src/relay/qnn/op/quantize.cc              | 23 +++----
 src/relay/qnn/utils.h                     | 12 ++++
 tests/python/frontend/pytorch/qnn_test.py | 55 ++++++++++++---
 6 files changed, 172 insertions(+), 39 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 52761647d15bb..8d164314ecc80 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2741,7 +2741,7 @@ def _run_jit_passes(graph):
     # pylint: disable=c-extension-no-member
     import torch

-    if is_version_greater_than("1.5.0"):
+    if is_version_greater_than("1.5.1"):
         # This is required for torchvision detection models from 1.6 above
         # It is the same as _jit_pass_inline, except that it has some special
         # case behaviors for some ops such as aten::__interpolate()
@@ -3383,7 +3383,8 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
     ret_name = _get_input_names(graph.return_node())

     # For quantized models
-    if "aten::quantize_per_tensor" in op_names:
+    quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
+    if len(quantized_ops.intersection(set(op_names))) > 0:
         weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
         qnn_torch.add_input_quant_params_to_op_inputs(graph)
         qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 3f8d495511ddc..e3431043bc864 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -173,7 +173,7 @@ def _get_quant_param_for_input(input_value):
     # 6th and 7th arg are output scale and zp respectively.

     # PyTorch 1.6 changed qconv API
-    if is_version_greater_than("1.5.0"):
+    if is_version_greater_than("1.5.1"):
         qconv_indices = (2, 3)
     else:
         qconv_indices = (6, 7)
@@ -575,13 +575,7 @@ def _impl(inputs, _):
             )

         return _do_bias_and_requantize(
-            conv_out,
-            bias,
-            input_scale,
-            weight_scale,
-            output_scale,
-            output_zero_point,
-            with_relu,
+            conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
         )

     return _impl
@@ -826,6 +820,76 @@ def _impl(inputs, _):
     return _impl

+
+def _linear_dynamic():
+    def _calculate_qparam(inp):
+        # reference ATen/native/quantized/cpu/qlinear_dynamic.cpp
+        # ChooseQuantizationParams function
+        mn = _op.min(inp)
+        mx = _op.max(inp)
+
+        # Ensure that the interval contains 0
+        mn = _op.minimum(mn, _op.const(0.0, dtype="float32"))
+        mx = _op.maximum(mx, _op.const(0.0, dtype="float32"))
+
+        qmax = 255
+
+        # reduce_range became True in v1.6
+        if is_version_greater_than("1.5.1"):
+            qmax = 127
+
+        scale = (mx - mn) / _expr.const(qmax, dtype="float32")
+
+        zero_point_from_min = -(mn / scale)
+        zero_point = _op.cast(_op.round(_op.clip(zero_point_from_min, 0.0, qmax)), "int32")
+
+        return scale, zero_point
+
+    def _impl(inputs, _):
+        weight = inputs[1][0]
+        weight_scale = inputs[1][1]
+        weight_zero_point = inputs[1][2]
+
+        inp = inputs[0]
+
+        input_scale, input_zero_point = _calculate_qparam(inp)
+        qinp = relay.qnn.op.quantize(inp, input_scale, input_zero_point, out_dtype="uint8")
+
+        data_shape = infer_shape(inp)
+
+        if len(data_shape) > 2:
+            qinp = _op.reverse_reshape(qinp, [-1, 0])
+
+        weight_shape = infer_shape(weight)
+        units = weight_shape[0]
+        dense = relay.qnn.op.dense(
+            qinp,
+            weight,
+            input_zero_point,
+            weight_zero_point,
+            input_scale,
+            weight_scale,
+            units=units,
+        )
+        bias_var = inputs[1][3]
+
+        dequant_scale = input_scale * weight_scale
+        dense_out = relay.qnn.op.dequantize(
+            dense, dequant_scale, input_zero_point=relay.const(0, "int32"), axis=1
+        )
+
+        if len(data_shape) > 2:
+            new_shape = list(data_shape[:-1])
+            new_shape.append(units)
+            dense_out = _op.reshape(dense_out, new_shape)
+
+        if bias_var is not None:
+            return dense_out + bias_var
+
+        return dense_out
+
+    return _impl
+

 convert_map = {
     "aten::quantize_per_tensor": _quantize_per_tensor(),
     "quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
@@ -841,4 +905,5 @@ def _impl(inputs, _):
     "quantized::add_scalar": _add_scalar(),
     "quantized::mul_scalar": _mul_scalar(),
     "quantized::relu6": _relu6(),
+    "quantized::linear_dynamic": _linear_dynamic(),
 }
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index 62988c8cc52fd..3602995b8f164 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -99,6 +99,18 @@ Expr DenseFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int re
   return MakeConstantScalar(DataType::Int(32), scalar_term);
 }

+Expr DenseFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point,
+                     int reduction_dim_size) {
+  auto reduction_dim = MakeConstantScalar(DataType::Int(32), reduction_dim_size);
+  return Multiply(Multiply(input_zero_point, kernel_zero_point), reduction_dim);
+}
+
+Expr DenseCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, const Expr& term4) {
+  auto data_term = Subtract(term1, term2);
+  // Putting constant terms together, so that constant folding can fold it.
+  auto const_term = Subtract(term4, term3);
+  return Add(data_term, const_term);
+}
 /*
  * \brief Forward rewrite the qnn dense op.
  * \param attrs The QNN dense attrs.
@@ -144,14 +156,24 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,

   const auto* qnn_dense_attrs = attrs.as<DenseAttrs>();

+  auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
+  auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
+  auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);
+
   // Extract the integer zero points.
-  auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
   auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);

+  if (!IsConstScalar(input_zero_point)) {
+    if (kernel_zero_point_int == 0) {
+      return Subtract(term1, term3);
+    }
+    auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size);
+    return DenseCombineTerms(term1, term2, term3, term4);
+  }
+
+  auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
+
   // Get all the terms as described in the comments.
-  auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
-  auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
-  auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);
   auto term4 = DenseFourthTerm(input_zero_point_int, kernel_zero_point_int, reduction_dim_size);

   // Combine those 4 terms depending on the zero points to get the best lowering.
@@ -165,10 +187,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
     // term 2 and term 4 become zero.
     return Subtract(term1, term3);
   } else {
-    auto data_term = Subtract(term1, term2);
-    // Putting constant terms together, so that constant folding can fold it.
-    auto const_term = Subtract(term4, term3);
-    return Add(data_term, const_term);
+    return DenseCombineTerms(term1, term2, term3, term4);
   }
 }

diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index 0622c96f04a66..9829834f43a3b 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -83,20 +83,27 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis
 }

 Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
-                   const Expr& output_zero_point, const Array<IndexExpr>& input_shape,
+                   const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
                    const QuantizeAttrs* attrs) {
+  ICHECK_EQ(types.size(), 4);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  ICHECK(in_tensor_type != nullptr) << "Type information missing."
+                                    << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
   const auto out_dtype = attrs->out_dtype;
   const auto axis = attrs->axis;

   size_t n_dim = input_shape.size();

   auto expanded_output_scale = output_scale;
-  if (!IsConstScalar(output_scale)) {
+  if (!IsConstScalar(output_scale) && !IsScalarType(types[1])) {
     expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis});
   }

   auto expanded_output_zero_point = output_zero_point;
-  if (!IsConstScalar(output_zero_point)) {
+  if (!IsConstScalar(output_zero_point) && !IsScalarType(types[2])) {
     expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis});
   }

@@ -120,15 +127,7 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   ICHECK(quantize_attrs != nullptr);

-  // Find input shape.
-  ICHECK_EQ(types.size(), 4);
-  auto in_type = types[0];
-  auto in_tensor_type = in_type.as<TensorTypeNode>();
-  ICHECK(in_tensor_type != nullptr) << "Type information missing."
- << " Please run infer_type pass."; - Array input_shape = in_tensor_type->shape; - - return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs); + return QuantizeLower(data, output_scale, output_zero_point, types, quantize_attrs); } RELAY_REGISTER_OP("qnn.quantize") diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index ab5c9a4fbbe20..23759a52ec41e 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -179,6 +179,18 @@ static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) { return true; } +/* + * \brief Checks whether an expr type is scalar. + * \param expr_type The type of expr to be checked. + * \return True if the type is a scalar + */ +static inline bool IsScalarType(const Type& expr_type) { + const auto* tensor_type = expr_type.as(); + CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got" + << AsText(expr_type, false); + return tensor_type->shape.size() == 0; +} + /* * \brief Checks and assigns types to scale and zero points. * \param expr_type The type of expr to be checked. diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 706f15b9d9d92..1851e31e817fc 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -27,7 +27,9 @@ from torch.quantization import fuse_modules, QuantWrapper import tvm +import tvm.testing from tvm import relay +from tvm.relay.frontend.pytorch_utils import is_version_greater_than from tvm.contrib.download import download_testdata @@ -197,9 +199,7 @@ def fuse_model(self): # test on quantized::mul_scalar with negative scale class MulScalarNegative(nn.Module): - def __init__( - self, - ): + def __init__(self): super().__init__() self.float_op = nn.quantized.FloatFunctional() self.quant = QuantStub() @@ -337,12 +337,7 @@ def get_transform(): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) return transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ] + [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize] ) def get_real_image(im_height, im_width): @@ -508,3 +503,45 @@ def test_serialized_modules(): num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2) match_ratio = num_identical / float(np.prod(tvm_result.shape)) assert match_ratio > 0.90 + + +def test_quantize_dynamic(): + # A wrapper is required for quantize_dynamic to work correctly + class LinearWrapper(nn.Module): + def __init__(self, in_dim, hidden_dim): + super().__init__() + self.linear = nn.Linear(in_dim, hidden_dim) + + def forward(self, inp): + return self.linear(inp) + + torch.manual_seed(0) + mod = LinearWrapper(16, 32) + + for qconfig in [ + torch.quantization.per_channel_dynamic_qconfig, + torch.quantization.default_dynamic_qconfig, + ]: + for ishape in [(16, 16), (10, 16, 16)]: + qspec = {nn.Linear: qconfig} + qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8) + + inp = torch.randn(*ishape) + script_module = torch.jit.trace(qmod, inp).eval() + + with torch.no_grad(): + pt_result = script_module(inp.clone()).numpy() + + input_name = "input" + runtime = get_tvm_runtime(script_module, "input", inp.shape) + runtime.set_input(input_name, inp.numpy().copy()) + runtime.run() + tvm_result = runtime.get_output(0).asnumpy() + + # Only compare with the PyTorch result for version v1.6 or newer + # Have seen a strange accuracy 
+            # Even with the manual random seed set, the same PyTorch
+            # version can output slightly different results depending on the environment.
+            # Outputs from v1.6 seem reliable. TVM's outputs are always the same.
+            if is_version_greater_than("1.5.1"):
+                tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)
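
Note (illustration only, not part of the patch): the dynamic quantization parameters that
_calculate_qparam builds with Relay ops above follow ATen's ChooseQuantizationParams. A minimal
NumPy sketch of that math is given below for reference. The helper name choose_qparams and the
sample values are hypothetical; qmax=127 corresponds to reduce_range=True (PyTorch 1.6 and newer,
255 for older versions), and the degenerate all-zero input case is ignored here.

    import numpy as np

    def choose_qparams(inp, qmax=127):
        # Clamp the observed range so that it always contains zero,
        # mirroring the min/max clamping done in _calculate_qparam.
        mn = min(float(inp.min()), 0.0)
        mx = max(float(inp.max()), 0.0)
        scale = (mx - mn) / qmax
        zero_point = int(np.clip(np.round(-mn / scale), 0, qmax))
        return scale, zero_point

    # Example: values spanning [-1, 3] give scale = 4 / 127 and zero_point = 32,
    # so -1.0 quantizes to 0 and 3.0 quantizes to 127 in uint8.
    scale, zp = choose_qparams(np.array([-1.0, 0.5, 3.0], dtype=np.float32))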