[Torch, QNN] Support dynamic quantization flow to enable importing quantized transformer models (apache#6782)

* add stub and test

* per channel quantize

* calculate qparam correctly

* import qbert working

* support batched qdense

* test batched input

* fix mkl offloading of batch matmul

* reduce_range becomes True in torch 1.6

* fix for 1.6

* Revert "fix mkl offloading of batch matmul"

This reverts commit cd90aa7.

* fix merge

* fix

* lint fix

* fix black

* more black fix

* fix version check for 1.5.1

* disable assert on v1.4 (strange pytorch issue)

* minor fix

* use dequantize

Co-authored-by: masa <masa@pop-os.localdomain>
2 people authored and trevor-m committed Dec 4, 2020
1 parent 25a508d commit f373d78
Showing 6 changed files with 172 additions and 39 deletions.
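
For context, here is a minimal sketch of the end-to-end flow this commit enables, adapted from the test_quantize_dynamic test added below. The wrapper module, input name, and target are illustrative, and the build/run calls assume the graph runtime API of this TVM version:

import torch
import torch.nn as nn
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Illustrative wrapper; quantize_dynamic swaps nn.Linear for its dynamically
# quantized counterpart, which appears as quantized::linear_dynamic in the
# traced graph and is now handled by the frontend.
class LinearWrapper(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, hidden_dim)

    def forward(self, inp):
        return self.linear(inp)

qmod = torch.quantization.quantize_dynamic(
    LinearWrapper(16, 32), {nn.Linear}, dtype=torch.qint8
)

inp = torch.randn(10, 16, 16)
script_module = torch.jit.trace(qmod, inp).eval()

# The frontend detects quantized::linear_dynamic and emits qnn.quantize /
# qnn.dense / qnn.dequantize with runtime-computed input qparams.
mod, params = relay.frontend.from_pytorch(script_module, [("input", inp.shape)])

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

runtime = graph_runtime.GraphModule(lib["default"](tvm.cpu(0)))
runtime.set_input("input", inp.numpy())
runtime.run()
tvm_result = runtime.get_output(0).asnumpy()
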
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
@@ -2741,7 +2741,7 @@ def _run_jit_passes(graph):
# pylint: disable=c-extension-no-member
import torch

if is_version_greater_than("1.5.0"):
if is_version_greater_than("1.5.1"):
# This is required for torchvision detection models from 1.6 above
# It is the same as _jit_pass_inline, except that it has some special
# case behaviors for some ops such as aten::__interpolate()
@@ -3383,7 +3383,8 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
ret_name = _get_input_names(graph.return_node())

# For quantized models
if "aten::quantize_per_tensor" in op_names:
quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
if len(quantized_ops.intersection(set(op_names))) > 0:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
81 changes: 73 additions & 8 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -173,7 +173,7 @@ def _get_quant_param_for_input(input_value):
# 6th and 7th arg are output scale and zp respectively.

# PyTorch 1.6 changed qconv API
if is_version_greater_than("1.5.0"):
if is_version_greater_than("1.5.1"):
qconv_indices = (2, 3)
else:
qconv_indices = (6, 7)
@@ -575,13 +575,7 @@ def _impl(inputs, _):
)

return _do_bias_and_requantize(
conv_out,
bias,
input_scale,
weight_scale,
output_scale,
output_zero_point,
with_relu,
conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
)

return _impl
@@ -826,6 +820,76 @@ def _impl(inputs, _):
return _impl


def _linear_dynamic():
def _calculate_qparam(inp):
# reference ATen/native/quantized/cpu/qlinear_dynamic.cpp
# ChooseQuantizationParams function
mn = _op.min(inp)
mx = _op.max(inp)

# Ensure that the interval contains 0
mn = _op.minimum(mn, _op.const(0.0, dtype="float32"))
mx = _op.maximum(mx, _op.const(0.0, dtype="float32"))

qmax = 255

# reduce_range became True in v1.6
if is_version_greater_than("1.5.1"):
qmax = 127

scale = (mx - mn) / _expr.const(qmax, dtype="float32")

zero_point_from_min = -(mn / scale)
zero_point = _op.cast(_op.round(_op.clip(zero_point_from_min, 0.0, qmax)), "int32")

return scale, zero_point

def _impl(inputs, _):
weight = inputs[1][0]
weight_scale = inputs[1][1]
weight_zero_point = inputs[1][2]

inp = inputs[0]

input_scale, input_zero_point = _calculate_qparam(inp)
qinp = relay.qnn.op.quantize(inp, input_scale, input_zero_point, out_dtype="uint8")

data_shape = infer_shape(inp)

if len(data_shape) > 2:
qinp = _op.reverse_reshape(qinp, [-1, 0])

weight_shape = infer_shape(weight)
units = weight_shape[0]
dense = relay.qnn.op.dense(
qinp,
weight,
input_zero_point,
weight_zero_point,
input_scale,
weight_scale,
units=units,
)
bias_var = inputs[1][3]

dequant_scale = input_scale * weight_scale
dense_out = relay.qnn.op.dequantize(
dense, dequant_scale, input_zero_point=relay.const(0, "int32"), axis=1
)

if len(data_shape) > 2:
new_shape = list(data_shape[:-1])
new_shape.append(units)
dense_out = _op.reshape(dense_out, new_shape)

if bias_var is not None:
return dense_out + bias_var

return dense_out

return _impl


convert_map = {
"aten::quantize_per_tensor": _quantize_per_tensor(),
"quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
@@ -841,4 +905,5 @@ def _impl(inputs, _):
"quantized::add_scalar": _add_scalar(),
"quantized::mul_scalar": _mul_scalar(),
"quantized::relu6": _relu6(),
"quantized::linear_dynamic": _linear_dynamic(),
}
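
The _calculate_qparam helper above mirrors ATen's ChooseQuantizationParams for dynamic quantization. In plain NumPy the arithmetic looks roughly like this (a simplified sketch that ignores the degenerate-range handling ATen performs):

import numpy as np

def choose_qparams(x, qmax=255):
    # The representable interval must contain zero; with reduce_range
    # (PyTorch 1.6 and newer) qmax drops to 127, as in the Relay code above.
    mn = min(float(x.min()), 0.0)
    mx = max(float(x.max()), 0.0)
    scale = (mx - mn) / qmax
    zero_point = int(np.round(np.clip(-mn / scale, 0.0, qmax)))
    return scale, zero_point

# Example: activations in [-1, 2] quantized to uint8
scale, zp = choose_qparams(np.array([-1.0, 0.5, 2.0], dtype="float32"))
# scale = 3 / 255 ~= 0.01176, zero_point = round(1.0 / scale) = 85
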
35 changes: 27 additions & 8 deletions src/relay/qnn/op/dense.cc
@@ -99,6 +99,18 @@ Expr DenseFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int re
return MakeConstantScalar(DataType::Int(32), scalar_term);
}

Expr DenseFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point,
int reduction_dim_size) {
auto reduction_dim = MakeConstantScalar(DataType::Int(32), reduction_dim_size);
return Multiply(Multiply(input_zero_point, kernel_zero_point), reduction_dim);
}

Expr DenseCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, const Expr& term4) {
auto data_term = Subtract(term1, term2);
// Putting constant terms together, so that constant folding can fold it.
auto const_term = Subtract(term4, term3);
return Add(data_term, const_term);
}
/*
* \brief Forward rewrite the qnn dense op.
* \param attrs The QNN dense attrs.
@@ -144,14 +156,24 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,

const auto* qnn_dense_attrs = attrs.as<DenseAttrs>();

auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);

// Extract the integer zero points.
auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);

if (!IsConstScalar(input_zero_point)) {
if (kernel_zero_point_int == 0) {
return Subtract(term1, term3);
}
auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size);
return DenseCombineTerms(term1, term2, term3, term4);
}

auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);

// Get all the terms as described in the comments.
auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);
auto term4 = DenseFourthTerm(input_zero_point_int, kernel_zero_point_int, reduction_dim_size);

// Combine those 4 terms depending on the zero points to get the best lowering.
@@ -165,10 +187,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// term 2 and term 4 become zero.
return Subtract(term1, term3);
} else {
auto data_term = Subtract(term1, term2);
// Putting constant terms together, so that constant folding can fold it.
auto const_term = Subtract(term4, term3);
return Add(data_term, const_term);
return DenseCombineTerms(term1, term2, term3, term4);
}
}
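
For reference, the four terms being combined come from expanding the quantized dense computation; with quantized data q_x, quantized kernel q_w, zero points z_x and z_w, and reduction size K (notation mine, following the existing code comments):

$$
\sum_{k=1}^{K} (q_x(i,k) - z_x)(q_w(j,k) - z_w)
= \underbrace{\sum_k q_x(i,k)\, q_w(j,k)}_{\text{term1}}
- \underbrace{z_w \sum_k q_x(i,k)}_{\text{term2}}
- \underbrace{z_x \sum_k q_w(j,k)}_{\text{term3}}
+ \underbrace{K\, z_x z_w}_{\text{term4}}
$$

DenseCombineTerms groups this as (term1 - term2) + (term4 - term3) so that the constant pieces can be folded. The new DenseFourthTerm overload is needed because, with dynamic quantization, the input zero point is a runtime expression and term4 can no longer be precomputed as a scalar; when the kernel zero point is zero, term2 and term4 vanish and only term1 - term3 remains.
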

23 changes: 11 additions & 12 deletions src/relay/qnn/op/quantize.cc
@@ -83,20 +83,27 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis
}

Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
const Expr& output_zero_point, const Array<IndexExpr>& input_shape,
const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
const QuantizeAttrs* attrs) {
ICHECK_EQ(types.size(), 4);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
ICHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;

const auto out_dtype = attrs->out_dtype;
const auto axis = attrs->axis;

size_t n_dim = input_shape.size();

auto expanded_output_scale = output_scale;
if (!IsConstScalar(output_scale)) {
if (!IsConstScalar(output_scale) && !IsScalarType(types[1])) {
expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis});
}

auto expanded_output_zero_point = output_zero_point;
if (!IsConstScalar(output_zero_point)) {
if (!IsConstScalar(output_zero_point) && !IsScalarType(types[2])) {
expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis});
}

@@ -120,15 +127,7 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
ICHECK(quantize_attrs != nullptr);

// Find input shape.
ICHECK_EQ(types.size(), 4);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
ICHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;

return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs);
return QuantizeLower(data, output_scale, output_zero_point, types, quantize_attrs);
}

RELAY_REGISTER_OP("qnn.quantize")
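
A short sketch of what this change permits: qnn.quantize can now take a scale and zero point that are rank-0 expressions computed at runtime, which is exactly what the dynamic linear conversion above produces, rather than only relay constants or per-channel tensors. The variable names below are illustrative:

import tvm
from tvm import relay

x = relay.var("x", shape=(10, 16), dtype="float32")

# Runtime-computed scalar qparams (not relay constants); previously the
# lowering would have tried to expand these along the channel axis.
scale = (relay.op.max(x) - relay.op.min(x)) / relay.const(255.0)
zero_point = relay.const(0, "int32")

q = relay.qnn.op.quantize(x, scale, zero_point, out_dtype="uint8")
mod = tvm.IRModule.from_expr(q)
mod = relay.transform.InferType()(mod)
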
12 changes: 12 additions & 0 deletions src/relay/qnn/utils.h
@@ -179,6 +179,18 @@ static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
return true;
}

/*
* \brief Checks whether an expr type is scalar.
* \param expr_type The type of expr to be checked.
* \return True if the type is a scalar
*/
static inline bool IsScalarType(const Type& expr_type) {
const auto* tensor_type = expr_type.as<TensorTypeNode>();
CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got"
<< AsText(expr_type, false);
return tensor_type->shape.size() == 0;
}

/*
* \brief Checks and assigns types to scale and zero points.
* \param expr_type The type of expr to be checked.
55 changes: 46 additions & 9 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -27,7 +27,9 @@
from torch.quantization import fuse_modules, QuantWrapper

import tvm
import tvm.testing
from tvm import relay
from tvm.relay.frontend.pytorch_utils import is_version_greater_than
from tvm.contrib.download import download_testdata


@@ -197,9 +199,7 @@ def fuse_model(self):

# test on quantized::mul_scalar with negative scale
class MulScalarNegative(nn.Module):
def __init__(
self,
):
def __init__(self):
super().__init__()
self.float_op = nn.quantized.FloatFunctional()
self.quant = QuantStub()
@@ -337,12 +337,7 @@ def get_transform():

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
return transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
)

def get_real_image(im_height, im_width):
Expand Down Expand Up @@ -508,3 +503,45 @@ def test_serialized_modules():
num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2)
match_ratio = num_identical / float(np.prod(tvm_result.shape))
assert match_ratio > 0.90


def test_quantize_dynamic():
# A wrapper is required for quantize_dynamic to work correctly
class LinearWrapper(nn.Module):
def __init__(self, in_dim, hidden_dim):
super().__init__()
self.linear = nn.Linear(in_dim, hidden_dim)

def forward(self, inp):
return self.linear(inp)

torch.manual_seed(0)
mod = LinearWrapper(16, 32)

for qconfig in [
torch.quantization.per_channel_dynamic_qconfig,
torch.quantization.default_dynamic_qconfig,
]:
for ishape in [(16, 16), (10, 16, 16)]:
qspec = {nn.Linear: qconfig}
qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8)

inp = torch.randn(*ishape)
script_module = torch.jit.trace(qmod, inp).eval()

with torch.no_grad():
pt_result = script_module(inp.clone()).numpy()

input_name = "input"
runtime = get_tvm_runtime(script_module, "input", inp.shape)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
tvm_result = runtime.get_output(0).asnumpy()

# Only compare with the PyTorch result for v1.6 or newer.
# We have seen a strange accuracy problem with PyTorch 1.4 and 1.5:
# even with the manual random seed set, the same PyTorch
# version can output slightly different results depending on the environment.
# Outputs from v1.6 seem reliable. TVM's outputs are always the same.
if is_version_greater_than("1.5.1"):
tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)
