From 2b753f8b674f186cd31d4ffe43cfd6720e87a265 Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 24 Oct 2020 18:09:35 +0900 Subject: [PATCH 01/19] add stub and test --- python/tvm/relay/frontend/pytorch.py | 4 ++- python/tvm/relay/frontend/qnn_torch.py | 34 ++++++++++++++++++++ tests/python/frontend/pytorch/qnn_test.py | 38 +++++++++++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 52761647d15b..b08e4a540f19 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3373,6 +3373,7 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt is_module = isinstance(script_module, torch.jit.ScriptModule) params = script_module.state_dict() if is_module else {} + outputs = _get_relay_input_vars( graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module ) @@ -3383,7 +3384,8 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt ret_name = _get_input_names(graph.return_node()) # For quantized models - if "aten::quantize_per_tensor" in op_names: + quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"]) + if len(quantized_ops.intersection(set(op_names))) > 0: weight_quant_params = qnn_torch.get_weight_quant_params(script_module) qnn_torch.add_input_quant_params_to_op_inputs(graph) qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 3f8d495511dd..802f0aed1e73 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -826,6 +826,39 @@ def _impl(inputs, _): return _impl +def _linear_dynamic(): + def _impl(inputs, _): + weight = inputs[1][0] + weight_scale = inputs[1][1] + weight_zero_point = inputs[1][2] + + input_scale = _expr.const(1.0) + input_zero_point = _expr.const(0) + + qinp = relay.qnn.op.quantize(inputs[0], input_scale, input_zero_point) + + weight_shape = infer_shape(weight) + dense = relay.qnn.op.dense( + qinp, + weight, + input_zero_point, + weight_zero_point, + input_scale, + weight_scale, + units=weight_shape[0], + ) + bias_var = inputs[1][3] + + dense_out = _op.cast(dense, "float32") + + if bias_var is not None: + return _op.nn.bias_add(dense_out, bias_var) + + return dense_out + + return _impl + + convert_map = { "aten::quantize_per_tensor": _quantize_per_tensor(), "quantized::conv2d_relu": _quantized_conv2d(with_relu=True), @@ -841,4 +874,5 @@ def _impl(inputs, _): "quantized::add_scalar": _add_scalar(), "quantized::mul_scalar": _mul_scalar(), "quantized::relu6": _relu6(), + "quantized::linear_dynamic": _linear_dynamic(), } diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 706f15b9d9d9..5ac4e1c7b174 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -41,6 +41,7 @@ def get_tvm_runtime(script_module, input_name, ishape): input_shapes = [(input_name, ishape)] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + print(mod["main"]) with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda @@ -508,3 +509,40 @@ def test_serialized_modules(): num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2) match_ratio = num_identical / float(np.prod(tvm_result.shape)) assert match_ratio > 0.90 + + +def test_quantize_dynamic(): + # A wrapper is required for quantize_dynamic to work correctly + class LinearWrapper(nn.Module): + def __init__(self, in_dim, hidden_dim): + super().__init__() + self.linear = nn.Linear(in_dim, hidden_dim) + + def forward(self, inp): + return self.linear(inp) + + mod = LinearWrapper(16, 32) + + qmod = torch.quantization.quantize_dynamic(mod, {nn.Linear}, dtype=torch.qint8) + + inp = torch.randn(16, 16) + script_module = torch.jit.trace(qmod, inp).eval() + + with torch.no_grad(): + pt_result = script_module(inp.clone()).numpy() + + input_name = "input" + runtime = get_tvm_runtime(script_module, "input", inp.shape) + runtime.set_input(input_name, inp.numpy().copy()) + runtime.run() + tvm_result = runtime.get_output(0).asnumpy() + + max_abs_diff = np.max(np.abs(tvm_result - pt_result)) + mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) + num_identical = np.sum(tvm_result == pt_result) + match_ratio = num_identical / float(np.prod(tvm_result.shape)) + + print(max_abs_diff, mean_abs_diff, match_ratio) + + +test_quantize_dynamic() From e13037bc97cc02402ce4b2a36787a58ebad10507 Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 24 Oct 2020 18:17:16 +0900 Subject: [PATCH 02/19] per channel quantize --- tests/python/frontend/pytorch/qnn_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 5ac4e1c7b174..c1aba6ef643c 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -523,7 +523,10 @@ def forward(self, inp): mod = LinearWrapper(16, 32) - qmod = torch.quantization.quantize_dynamic(mod, {nn.Linear}, dtype=torch.qint8) + qspec = {nn.Linear: torch.quantization.per_channel_dynamic_qconfig} + qmod = torch.quantization.quantize_dynamic( + mod, qconfig_spec=qspec, dtype=torch.qint8 + ) inp = torch.randn(16, 16) script_module = torch.jit.trace(qmod, inp).eval() From e3bd310c496e15f93edc305cbb3742e04bed2c94 Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 24 Oct 2020 19:31:43 +0900 Subject: [PATCH 03/19] calculate qparam correctly --- python/tvm/relay/frontend/qnn_torch.py | 20 ++++++++++++++++---- src/relay/qnn/op/quantize.cc | 23 +++++++++++------------ src/relay/qnn/utils.h | 12 ++++++++++++ tests/python/frontend/pytorch/qnn_test.py | 15 +++------------ 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 802f0aed1e73..4a81f1f98479 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -827,15 +827,26 @@ def _impl(inputs, _): def _linear_dynamic(): + def _calculate_qparam(inp): + mx = _op.max(inp) + mn = _op.min(inp) + + scale = (mx - mn) / _expr.const(255.0) + + zero_point_from_min = -(mn / scale) + zero_point = _op.cast(_op.round(_op.clip(zero_point_from_min, 0.0, 255.0)), "int32") + + return scale, zero_point + def _impl(inputs, _): weight = inputs[1][0] weight_scale = inputs[1][1] weight_zero_point = inputs[1][2] - input_scale = _expr.const(1.0) - input_zero_point = _expr.const(0) + inp = inputs[0] - qinp = relay.qnn.op.quantize(inputs[0], input_scale, input_zero_point) + input_scale, input_zero_point = _calculate_qparam(inp) + qinp = relay.qnn.op.quantize(inp, input_scale, input_zero_point, out_dtype="uint8") weight_shape = infer_shape(weight) dense = relay.qnn.op.dense( @@ -849,7 +860,8 @@ def _impl(inputs, _): ) bias_var = inputs[1][3] - dense_out = _op.cast(dense, "float32") + dequant_scale = input_scale * weight_scale + dense_out = _op.cast(dense, "float32") * dequant_scale if bias_var is not None: return _op.nn.bias_add(dense_out, bias_var) diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 0622c96f04a6..29d59b311994 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -83,20 +83,27 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis } Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, - const Expr& output_zero_point, const Array& input_shape, + const Expr& output_zero_point, const Array& types, const QuantizeAttrs* attrs) { + ICHECK_EQ(types.size(), 4); + auto in_type = types[0]; + auto in_tensor_type = in_type.as(); + ICHECK(in_tensor_type != nullptr) << "Type information missing." + << " Please run infer_type pass."; + Array input_shape = in_tensor_type->shape; + const auto out_dtype = attrs->out_dtype; const auto axis = attrs->axis; size_t n_dim = input_shape.size(); auto expanded_output_scale = output_scale; - if (!IsConstScalar(output_scale)) { + if (!IsConstScalar(output_scale) && !IsScalarType(types[1])) { expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis}); } auto expanded_output_zero_point = output_zero_point; - if (!IsConstScalar(output_zero_point)) { + if (!IsConstScalar(output_zero_point) && !IsScalarType(types[2])) { expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis}); } @@ -120,15 +127,7 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, const auto* quantize_attrs = attrs.as(); ICHECK(quantize_attrs != nullptr); - // Find input shape. - ICHECK_EQ(types.size(), 4); - auto in_type = types[0]; - auto in_tensor_type = in_type.as(); - ICHECK(in_tensor_type != nullptr) << "Type information missing." - << " Please run infer_type pass."; - Array input_shape = in_tensor_type->shape; - - return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs); + return QuantizeLower(data, output_scale, output_zero_point, types, quantize_attrs); } RELAY_REGISTER_OP("qnn.quantize") diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index ab5c9a4fbbe2..23759a52ec41 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -179,6 +179,18 @@ static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) { return true; } +/* + * \brief Checks whether an expr type is scalar. + * \param expr_type The type of expr to be checked. + * \return True if the type is a scalar + */ +static inline bool IsScalarType(const Type& expr_type) { + const auto* tensor_type = expr_type.as(); + CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got" + << AsText(expr_type, false); + return tensor_type->shape.size() == 0; +} + /* * \brief Checks and assigns types to scale and zero points. * \param expr_type The type of expr to be checked. diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index c1aba6ef643c..962b24f81cbd 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -27,6 +27,7 @@ from torch.quantization import fuse_modules, QuantWrapper import tvm +import tvm.testing from tvm import relay from tvm.contrib.download import download_testdata @@ -524,9 +525,7 @@ def forward(self, inp): mod = LinearWrapper(16, 32) qspec = {nn.Linear: torch.quantization.per_channel_dynamic_qconfig} - qmod = torch.quantization.quantize_dynamic( - mod, qconfig_spec=qspec, dtype=torch.qint8 - ) + qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8) inp = torch.randn(16, 16) script_module = torch.jit.trace(qmod, inp).eval() @@ -540,12 +539,4 @@ def forward(self, inp): runtime.run() tvm_result = runtime.get_output(0).asnumpy() - max_abs_diff = np.max(np.abs(tvm_result - pt_result)) - mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) - num_identical = np.sum(tvm_result == pt_result) - match_ratio = num_identical / float(np.prod(tvm_result.shape)) - - print(max_abs_diff, mean_abs_diff, match_ratio) - - -test_quantize_dynamic() + tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-5, atol=1e-5) From 9b27ea51dbb89bc191d6a1bf2b72d41911f08b98 Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 25 Oct 2020 05:42:52 +0900 Subject: [PATCH 04/19] import qbert working --- python/tvm/relay/frontend/qnn_torch.py | 2 +- src/relay/qnn/op/dense.cc | 37 ++++++++++++++++++++------ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 4a81f1f98479..fc8967344a8f 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -864,7 +864,7 @@ def _impl(inputs, _): dense_out = _op.cast(dense, "float32") * dequant_scale if bias_var is not None: - return _op.nn.bias_add(dense_out, bias_var) + return dense_out + bias_var return dense_out diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 62988c8cc52f..be951afd0da6 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -99,6 +99,18 @@ Expr DenseFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int re return MakeConstantScalar(DataType::Int(32), scalar_term); } +Expr DenseFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, + int reduction_dim_size) { + auto reduction_dim = MakeConstantScalar(DataType::Int(32), reduction_dim_size); + return Multiply(Multiply(input_zero_point, kernel_zero_point), reduction_dim); +} + +Expr DenseCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, const Expr& term4) { + auto data_term = Subtract(term1, term2); + // Putting constant terms together, so that constant folding can fold it. + auto const_term = Subtract(term4, term3); + return Add(data_term, const_term); +} /* * \brief Forward rewrite the qnn dense op. * \param attrs The QNN dense attrs. @@ -144,14 +156,26 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, const auto* qnn_dense_attrs = attrs.as(); + auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs); + auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point); + auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point); + // Extract the integer zero points. - auto input_zero_point_int = GetScalarFromConstant(input_zero_point); auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); + if (!IsConstScalar(input_zero_point)) { + if (kernel_zero_point_int == 0) { + LOG(INFO) << "kernel zp is zero"; + return Subtract(term1, term3); + } + LOG(INFO) << "kernel zp is non zero"; + auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size); + return DenseCombineTerms(term1, term2, term3, term4); + } + + auto input_zero_point_int = GetScalarFromConstant(input_zero_point); + // Get all the terms as described in the comments. - auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs); - auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point); - auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point); auto term4 = DenseFourthTerm(input_zero_point_int, kernel_zero_point_int, reduction_dim_size); // Combine those 4 terms depending on the zero points to get the best lowering. @@ -165,10 +189,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, // term 2 and term 4 become zero. return Subtract(term1, term3); } else { - auto data_term = Subtract(term1, term2); - // Putting constant terms together, so that constant folding can fold it. - auto const_term = Subtract(term4, term3); - return Add(data_term, const_term); + return DenseCombineTerms(term1, term2, term3, term4); } } From bada149a906bbe513b1f561e35248c7fb780d0f6 Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 25 Oct 2020 06:26:23 +0900 Subject: [PATCH 05/19] support batched qdense --- python/tvm/relay/frontend/qnn_torch.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index fc8967344a8f..711f0a74f504 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -848,7 +848,13 @@ def _impl(inputs, _): input_scale, input_zero_point = _calculate_qparam(inp) qinp = relay.qnn.op.quantize(inp, input_scale, input_zero_point, out_dtype="uint8") + data_shape = infer_shape(inp) + + if len(data_shape) > 2: + qinp = _op.reverse_reshape(qinp, [-1, 0]) + weight_shape = infer_shape(weight) + units = weight_shape[0] dense = relay.qnn.op.dense( qinp, weight, @@ -856,13 +862,18 @@ def _impl(inputs, _): weight_zero_point, input_scale, weight_scale, - units=weight_shape[0], + units=units, ) bias_var = inputs[1][3] dequant_scale = input_scale * weight_scale dense_out = _op.cast(dense, "float32") * dequant_scale + if len(data_shape) > 2: + new_shape = list(data_shape[:-1]) + new_shape.append(units) + dense_out = _op.reshape(dense_out, new_shape) + if bias_var is not None: return dense_out + bias_var From c47e5f693c0a7f2d12047b262226720c68b776d7 Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 25 Oct 2020 08:53:14 +0900 Subject: [PATCH 06/19] test batched input --- src/relay/qnn/op/dense.cc | 2 -- tests/python/frontend/pytorch/qnn_test.py | 43 +++++++++++++++-------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index be951afd0da6..3602995b8f16 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -165,10 +165,8 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, if (!IsConstScalar(input_zero_point)) { if (kernel_zero_point_int == 0) { - LOG(INFO) << "kernel zp is zero"; return Subtract(term1, term3); } - LOG(INFO) << "kernel zp is non zero"; auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size); return DenseCombineTerms(term1, term2, term3, term4); } diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 962b24f81cbd..6e12cb4a8963 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -42,12 +42,14 @@ def get_tvm_runtime(script_module, input_name, ishape): input_shapes = [(input_name, ishape)] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) - print(mod["main"]) with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda # also not to make CI too slow - lib = relay.build(mod, target="llvm", params=params) + # opt_mod, opt_params = relay.optimize(mod, target="llvm -mcpu=cascadelake -libs=mkl", params=params) + # print(opt_mod["main"]) + # lib = relay.build(mod, target="llvm -mcpu=cascadelake -libs=mkl", params=params) + lib = relay.build(mod, target="llvm -mcpu=cascadelake", params=params) runtime = tvm.contrib.graph_runtime.GraphModule(lib["default"](tvm.cpu(0))) return runtime @@ -524,19 +526,32 @@ def forward(self, inp): mod = LinearWrapper(16, 32) - qspec = {nn.Linear: torch.quantization.per_channel_dynamic_qconfig} - qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8) + for qconfig in [torch.quantization.per_channel_dynamic_qconfig, + torch.quantization.default_dynamic_qconfig]: + for ishape in [(16, 16), (10, 16, 16)]: + qspec = {nn.Linear: qconfig} + qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8) - inp = torch.randn(16, 16) - script_module = torch.jit.trace(qmod, inp).eval() + inp = torch.randn(*ishape) + script_module = torch.jit.trace(qmod, inp).eval() - with torch.no_grad(): - pt_result = script_module(inp.clone()).numpy() + with torch.no_grad(): + pt_result = script_module(inp.clone()).numpy() + + input_name = "input" + runtime = get_tvm_runtime(script_module, "input", inp.shape) + runtime.set_input(input_name, inp.numpy().copy()) + runtime.run() + tvm_result = runtime.get_output(0).asnumpy() + + max_abs_diff = np.max(np.abs(tvm_result - pt_result)) + mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) + num_identical = np.sum(tvm_result == pt_result) + match_ratio = num_identical / float(np.prod(tvm_result.shape)) + + print(max_abs_diff, mean_abs_diff, match_ratio) + + tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4) - input_name = "input" - runtime = get_tvm_runtime(script_module, "input", inp.shape) - runtime.set_input(input_name, inp.numpy().copy()) - runtime.run() - tvm_result = runtime.get_output(0).asnumpy() - tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-5, atol=1e-5) +test_quantize_dynamic() From 2eda1d098c84c316a4343d60143ffcefd64c4f99 Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 25 Oct 2020 09:44:16 +0900 Subject: [PATCH 07/19] fix mkl offloading of batch matmul --- python/tvm/topi/x86/batch_matmul.py | 16 ++++++---------- src/relay/op/type_relations.h | 2 ++ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 100bdf205165..4abd955a6741 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -137,9 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([M // y_bn, y_bn]) -def batch_matmul_blas_common(cfg, x, y, out_shape, lib): +def batch_matmul_common(cfg, x, y, out_shape, lib): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch, using one of BLAS libraries. + data in batch. Parameters ---------- @@ -151,8 +151,8 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): 3-D with shape [batch, N, K] out_shape : tuple or None Shape of the output - lib : A contrib module which implements batch_matmul funtion - cblas and mkl are supported + lib : A contrib module + cblas or mkl are supported Returns ------- @@ -174,23 +174,19 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): @autotvm.register_topi_compute("batch_matmul_cblas.x86") def batch_matmul_cblas(cfg, x, y, out_shape=None): - """Compute batch_matmul using cblas""" - return batch_matmul_blas_common(cfg, x, y, out_shape, cblas) + return batch_matmul_common(cfg, x, y, out_shape, cblas) @autotvm.register_topi_schedule("batch_matmul_cblas.x86") def schedule_batch_matmul_cblas(_, outs): - """Create schedule for batch_matmul_cblas""" return generic.schedule_extern(outs) @autotvm.register_topi_compute("batch_matmul_mkl.x86") def batch_matmul_mkl(cfg, x, y, out_shape=None): - """Compute batch_matmul using mkl""" - return batch_matmul_blas_common(cfg, x, y, out_shape, mkl) + return batch_matmul_common(cfg, x, y, out_shape, mkl) @autotvm.register_topi_schedule("batch_matmul_mkl.x86") def schedule_batch_matmul_mkl(_, outs): - """Create schedule for batch_matmul_mul""" return generic.schedule_extern(outs) diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 6d6d5f70c0c2..8726e05f2f05 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -57,6 +57,8 @@ bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); +TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype); + /*! * \brief Determine the broadcasted shape from two input shapes * \param t1 One of two Tensortype whose shapes are broadcasted From 789507d2845bcce65b47bbbe3ce51c21517e2f5d Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 26 Oct 2020 11:31:42 +0900 Subject: [PATCH 08/19] reduce range become True in torch 1.6 --- python/tvm/relay/frontend/qnn_torch.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 711f0a74f504..e1cafba2ff7b 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -831,10 +831,16 @@ def _calculate_qparam(inp): mx = _op.max(inp) mn = _op.min(inp) - scale = (mx - mn) / _expr.const(255.0) + qmax = 255 + + # reduce_range became True in v1.6 + if is_version_greater_than("1.5.0"): + qmax /= 2 + + scale = (mx - mn) / _expr.const(qmax, dtype="float32") zero_point_from_min = -(mn / scale) - zero_point = _op.cast(_op.round(_op.clip(zero_point_from_min, 0.0, 255.0)), "int32") + zero_point = _op.cast(_op.round(_op.clip(zero_point_from_min, 0.0, qmax)), "int32") return scale, zero_point From e92e02abb1244cfc5cc5fc021eb43bd6f35c50ff Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 26 Oct 2020 20:25:48 +0900 Subject: [PATCH 09/19] fix for 1.6 --- python/tvm/relay/frontend/qnn_torch.py | 10 ++++++++-- tests/python/frontend/pytorch/qnn_test.py | 8 +------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index e1cafba2ff7b..e4204ae2d7e1 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -828,14 +828,20 @@ def _impl(inputs, _): def _linear_dynamic(): def _calculate_qparam(inp): - mx = _op.max(inp) + # reference ATen/native/quantized/cpu/qlinear_dynamic.cpp + # ChooseQuantizationParams function mn = _op.min(inp) + mx = _op.max(inp) + + # Ensure that the interval contains 0 + mn = _op.minimum(mn, _op.const(0., dtype="float32")) + mx = _op.maximum(mx, _op.const(0., dtype="float32")) qmax = 255 # reduce_range became True in v1.6 if is_version_greater_than("1.5.0"): - qmax /= 2 + qmax = 127 scale = (mx - mn) / _expr.const(qmax, dtype="float32") diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 6e12cb4a8963..7056ebfb4a90 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -524,6 +524,7 @@ def __init__(self, in_dim, hidden_dim): def forward(self, inp): return self.linear(inp) + torch.manual_seed(0) mod = LinearWrapper(16, 32) for qconfig in [torch.quantization.per_channel_dynamic_qconfig, @@ -544,13 +545,6 @@ def forward(self, inp): runtime.run() tvm_result = runtime.get_output(0).asnumpy() - max_abs_diff = np.max(np.abs(tvm_result - pt_result)) - mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) - num_identical = np.sum(tvm_result == pt_result) - match_ratio = num_identical / float(np.prod(tvm_result.shape)) - - print(max_abs_diff, mean_abs_diff, match_ratio) - tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4) From 1c408897294cea55e4609bb2e9f2a41939e406af Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Oct 2020 10:44:53 +0900 Subject: [PATCH 10/19] Revert "fix mkl offloading of batch matmul" This reverts commit cd90aa783688c68e1b12633eea4d2690d9e3a5a5. --- python/tvm/relay/op/strategy/x86.py | 7 ------- python/tvm/topi/x86/batch_matmul.py | 24 ++++-------------------- src/relay/op/type_relations.cc | 2 +- src/relay/op/type_relations.h | 2 -- 4 files changed, 5 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3c5735b17aa5..e2a82d396b22 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -377,13 +377,6 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): name="batch_matmul_cblas.x86", plevel=15, ) - if "mkl" in target.libs: - strategy.add_implementation( - wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl), - wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl), - name="batch_matmul_mkl.x86", - plevel=15, - ) return strategy diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 4abd955a6741..4e5f6efc815a 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -19,7 +19,7 @@ from tvm import te from tvm import autotvm from tvm.autotvm.task.space import SplitEntity -from tvm.contrib import cblas, mkl +from tvm.contrib import cblas from .. import generic from ..util import traverse_inline, get_const_tuple, get_max_power2_factor @@ -137,7 +137,8 @@ def _default_batch_matmul_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([M // y_bn, y_bn]) -def batch_matmul_common(cfg, x, y, out_shape, lib): +@autotvm.register_topi_compute("batch_matmul_cblas.x86") +def batch_matmul_cblas(cfg, x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -151,8 +152,6 @@ def batch_matmul_common(cfg, x, y, out_shape, lib): 3-D with shape [batch, N, K] out_shape : tuple or None Shape of the output - lib : A contrib module - cblas or mkl are supported Returns ------- @@ -169,24 +168,9 @@ def batch_matmul_common(cfg, x, y, out_shape, lib): assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) - return lib.batch_matmul(x, y, False, True) - - -@autotvm.register_topi_compute("batch_matmul_cblas.x86") -def batch_matmul_cblas(cfg, x, y, out_shape=None): - return batch_matmul_common(cfg, x, y, out_shape, cblas) + return cblas.batch_matmul(x, y, False, True) @autotvm.register_topi_schedule("batch_matmul_cblas.x86") def schedule_batch_matmul_cblas(_, outs): return generic.schedule_extern(outs) - - -@autotvm.register_topi_compute("batch_matmul_mkl.x86") -def batch_matmul_mkl(cfg, x, y, out_shape=None): - return batch_matmul_common(cfg, x, y, out_shape, mkl) - - -@autotvm.register_topi_schedule("batch_matmul_mkl.x86") -def schedule_batch_matmul_mkl(_, outs): - return generic.schedule_extern(outs) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 7a3bfcb21ce6..3dc33c5022e0 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) { return false; } -TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { +Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { std::vector oshape; size_t ndim1 = t1->shape.size(); size_t ndim2 = t2->shape.size(); diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 8726e05f2f05..6d6d5f70c0c2 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -57,8 +57,6 @@ bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); -TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype); - /*! * \brief Determine the broadcasted shape from two input shapes * \param t1 One of two Tensortype whose shapes are broadcasted From 1fd5b42616040cb772cbc30cee380f394311165c Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Oct 2020 10:56:07 +0900 Subject: [PATCH 11/19] fix merge --- python/tvm/relay/op/strategy/x86.py | 7 +++++++ python/tvm/topi/x86/batch_matmul.py | 30 ++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index e2a82d396b22..3c5735b17aa5 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -377,6 +377,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): name="batch_matmul_cblas.x86", plevel=15, ) + if "mkl" in target.libs: + strategy.add_implementation( + wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl), + wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl), + name="batch_matmul_mkl.x86", + plevel=15, + ) return strategy diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 4e5f6efc815a..100bdf205165 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -19,7 +19,7 @@ from tvm import te from tvm import autotvm from tvm.autotvm.task.space import SplitEntity -from tvm.contrib import cblas +from tvm.contrib import cblas, mkl from .. import generic from ..util import traverse_inline, get_const_tuple, get_max_power2_factor @@ -137,10 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([M // y_bn, y_bn]) -@autotvm.register_topi_compute("batch_matmul_cblas.x86") -def batch_matmul_cblas(cfg, x, y, out_shape=None): +def batch_matmul_blas_common(cfg, x, y, out_shape, lib): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. + data in batch, using one of BLAS libraries. Parameters ---------- @@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None): 3-D with shape [batch, N, K] out_shape : tuple or None Shape of the output + lib : A contrib module which implements batch_matmul funtion + cblas and mkl are supported Returns ------- @@ -168,9 +169,28 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None): assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) - return cblas.batch_matmul(x, y, False, True) + return lib.batch_matmul(x, y, False, True) + + +@autotvm.register_topi_compute("batch_matmul_cblas.x86") +def batch_matmul_cblas(cfg, x, y, out_shape=None): + """Compute batch_matmul using cblas""" + return batch_matmul_blas_common(cfg, x, y, out_shape, cblas) @autotvm.register_topi_schedule("batch_matmul_cblas.x86") def schedule_batch_matmul_cblas(_, outs): + """Create schedule for batch_matmul_cblas""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("batch_matmul_mkl.x86") +def batch_matmul_mkl(cfg, x, y, out_shape=None): + """Compute batch_matmul using mkl""" + return batch_matmul_blas_common(cfg, x, y, out_shape, mkl) + + +@autotvm.register_topi_schedule("batch_matmul_mkl.x86") +def schedule_batch_matmul_mkl(_, outs): + """Create schedule for batch_matmul_mul""" return generic.schedule_extern(outs) From 44a30cbd00a7764508d8247b3179b376bab0ae53 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Oct 2020 12:54:26 +0900 Subject: [PATCH 12/19] fix --- python/tvm/relay/frontend/pytorch.py | 1 - src/relay/op/type_relations.cc | 2 +- tests/python/frontend/pytorch/qnn_test.py | 8 +------- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b08e4a540f19..07ef71a7b9a5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3373,7 +3373,6 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt is_module = isinstance(script_module, torch.jit.ScriptModule) params = script_module.state_dict() if is_module else {} - outputs = _get_relay_input_vars( graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module ) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 3dc33c5022e0..7a3bfcb21ce6 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) { return false; } -Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { +TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { std::vector oshape; size_t ndim1 = t1->shape.size(); size_t ndim2 = t2->shape.size(); diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 7056ebfb4a90..723219f91da0 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -46,10 +46,7 @@ def get_tvm_runtime(script_module, input_name, ishape): with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda # also not to make CI too slow - # opt_mod, opt_params = relay.optimize(mod, target="llvm -mcpu=cascadelake -libs=mkl", params=params) - # print(opt_mod["main"]) - # lib = relay.build(mod, target="llvm -mcpu=cascadelake -libs=mkl", params=params) - lib = relay.build(mod, target="llvm -mcpu=cascadelake", params=params) + lib = relay.build(mod, target="llvm", params=params) runtime = tvm.contrib.graph_runtime.GraphModule(lib["default"](tvm.cpu(0))) return runtime @@ -546,6 +543,3 @@ def forward(self, inp): tvm_result = runtime.get_output(0).asnumpy() tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4) - - -test_quantize_dynamic() From 6de0c4a8a679de1d01b1181c157a2f0a9212911d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Oct 2020 13:59:40 +0900 Subject: [PATCH 13/19] lint fix --- src/relay/qnn/op/quantize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 29d59b311994..9829834f43a3 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -89,7 +89,7 @@ Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, auto in_type = types[0]; auto in_tensor_type = in_type.as(); ICHECK(in_tensor_type != nullptr) << "Type information missing." - << " Please run infer_type pass."; + << " Please run infer_type pass."; Array input_shape = in_tensor_type->shape; const auto out_dtype = attrs->out_dtype; From 12447d7270d8f8e90c490aee692bf142242bfd1f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Oct 2020 17:04:43 +0900 Subject: [PATCH 14/19] fix black --- python/tvm/relay/frontend/qnn_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index e4204ae2d7e1..c01e34d9037b 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -834,8 +834,8 @@ def _calculate_qparam(inp): mx = _op.max(inp) # Ensure that the interval contains 0 - mn = _op.minimum(mn, _op.const(0., dtype="float32")) - mx = _op.maximum(mx, _op.const(0., dtype="float32")) + mn = _op.minimum(mn, _op.const(0.0, dtype="float32")) + mx = _op.maximum(mx, _op.const(0.0, dtype="float32")) qmax = 255 From 7858c805c55c7df5a59eda1b9d77cf85c0fa4c7c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Oct 2020 17:08:25 +0900 Subject: [PATCH 15/19] more black fix --- tests/python/frontend/pytorch/qnn_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 723219f91da0..756e75ffc099 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -524,8 +524,10 @@ def forward(self, inp): torch.manual_seed(0) mod = LinearWrapper(16, 32) - for qconfig in [torch.quantization.per_channel_dynamic_qconfig, - torch.quantization.default_dynamic_qconfig]: + for qconfig in [ + torch.quantization.per_channel_dynamic_qconfig, + torch.quantization.default_dynamic_qconfig, + ]: for ishape in [(16, 16), (10, 16, 16)]: qspec = {nn.Linear: qconfig} qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8) From bc78179eecdac5aef0b5d3e26cf5312ed69dcffb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Oct 2020 22:27:04 +0900 Subject: [PATCH 16/19] fix version check for 1.5.1 --- python/tvm/relay/frontend/pytorch.py | 2 +- python/tvm/relay/frontend/qnn_torch.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 07ef71a7b9a5..8d164314ecc8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2741,7 +2741,7 @@ def _run_jit_passes(graph): # pylint: disable=c-extension-no-member import torch - if is_version_greater_than("1.5.0"): + if is_version_greater_than("1.5.1"): # This is required for torchvision detection models from 1.6 above # It is the same as _jit_pass_inline, except that it has some special # case behaviors for some ops such as aten::__interpolate() diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index c01e34d9037b..6af6f768c62d 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -173,7 +173,7 @@ def _get_quant_param_for_input(input_value): # 6th and 7th arg are output scale and zp respectively. # PyTorch 1.6 changed qconv API - if is_version_greater_than("1.5.0"): + if is_version_greater_than("1.5.1"): qconv_indices = (2, 3) else: qconv_indices = (6, 7) @@ -840,7 +840,7 @@ def _calculate_qparam(inp): qmax = 255 # reduce_range became True in v1.6 - if is_version_greater_than("1.5.0"): + if is_version_greater_than("1.5.1"): qmax = 127 scale = (mx - mn) / _expr.const(qmax, dtype="float32") From f05fac69d55a179f334a5f2bac05a705556f1763 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Oct 2020 23:14:42 +0900 Subject: [PATCH 17/19] disable assert on v1.4 (strange pytorch issue) --- tests/python/frontend/pytorch/qnn_test.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 756e75ffc099..67fa9a284e5f 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -29,6 +29,7 @@ import tvm import tvm.testing from tvm import relay +from tvm.relay.frontend.pytorch_utils import is_version_greater_than from tvm.contrib.download import download_testdata @@ -198,9 +199,7 @@ def fuse_model(self): # test on quantized::mul_scalar with negative scale class MulScalarNegative(nn.Module): - def __init__( - self, - ): + def __init__(self,): super().__init__() self.float_op = nn.quantized.FloatFunctional() self.quant = QuantStub() @@ -338,12 +337,7 @@ def get_transform(): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) return transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ] + [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize,] ) def get_real_image(im_height, im_width): @@ -544,4 +538,13 @@ def forward(self, inp): runtime.run() tvm_result = runtime.get_output(0).asnumpy() - tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4) + # Only compare with the PyTorch result for version v1.6 or newer + # Have seen a strange accuracy problem from PyTorch 1.4 and 1.5 + # Even with the manual random seed set, the same PyTorch + # version can outputs slightly different results depending on an environment. + # Outputs from v1.6 seem reliable. TVM's outputs are always the same + if is_version_greater_than("1.5.1"): + tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4) + + +test_quantize_dynamic() From b9f1eb4317f03b2b08969c5d1ebc32863284f15e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Oct 2020 23:22:56 +0900 Subject: [PATCH 18/19] minor fix --- tests/python/frontend/pytorch/qnn_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 67fa9a284e5f..1851e31e817f 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -199,7 +199,7 @@ def fuse_model(self): # test on quantized::mul_scalar with negative scale class MulScalarNegative(nn.Module): - def __init__(self,): + def __init__(self): super().__init__() self.float_op = nn.quantized.FloatFunctional() self.quant = QuantStub() @@ -337,7 +337,7 @@ def get_transform(): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) return transforms.Compose( - [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize,] + [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize] ) def get_real_image(im_height, im_width): @@ -545,6 +545,3 @@ def forward(self, inp): # Outputs from v1.6 seem reliable. TVM's outputs are always the same if is_version_greater_than("1.5.1"): tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4) - - -test_quantize_dynamic() From 246b11fa03add187f0daff1f2b93a0cbbdd75c36 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Oct 2020 07:06:20 +0900 Subject: [PATCH 19/19] use dequantize --- python/tvm/relay/frontend/qnn_torch.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 6af6f768c62d..e3431043bc86 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -575,13 +575,7 @@ def _impl(inputs, _): ) return _do_bias_and_requantize( - conv_out, - bias, - input_scale, - weight_scale, - output_scale, - output_zero_point, - with_relu, + conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu ) return _impl @@ -879,7 +873,9 @@ def _impl(inputs, _): bias_var = inputs[1][3] dequant_scale = input_scale * weight_scale - dense_out = _op.cast(dense, "float32") * dequant_scale + dense_out = relay.qnn.op.dequantize( + dense, dequant_scale, input_zero_point=relay.const(0, "int32"), axis=1 + ) if len(data_shape) > 2: new_shape = list(data_shape[:-1])