[Torch, QNN] Support dynamic quantization flow to enable importing quantized transformer models #6782

Merged: 19 commits, Oct 29, 2020
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
@@ -2741,7 +2741,7 @@ def _run_jit_passes(graph):
# pylint: disable=c-extension-no-member
import torch

if is_version_greater_than("1.5.0"):
if is_version_greater_than("1.5.1"):
# This is required for torchvision detection models from 1.6 above
# It is the same as _jit_pass_inline, except that it has some special
# case behaviors for some ops such as aten::__interpolate()
@@ -3383,7 +3383,8 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
ret_name = _get_input_names(graph.return_node())

# For quantized models
if "aten::quantize_per_tensor" in op_names:
quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
if len(quantized_ops.intersection(set(op_names))) > 0:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
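For orientation, here is a minimal usage sketch of the flow that this op-name check enables. The module and input names are illustrative and not part of this diff; the pattern mirrors the new test added below. Dynamic quantization replaces nn.Linear with its dynamically quantized counterpart, the traced graph then contains quantized::linear_dynamic, and from_pytorch takes the QNN conversion path.

import torch
import torch.nn as nn
from tvm import relay

# Illustrative wrapper module; quantize_dynamic swaps nn.Linear for a
# dynamically quantized linear whose traced graph uses quantized::linear_dynamic.
class LinearWrapper(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, hidden_dim)

    def forward(self, inp):
        return self.linear(inp)

qmod = torch.quantization.quantize_dynamic(LinearWrapper(16, 32), {nn.Linear}, dtype=torch.qint8)
inp = torch.randn(2, 16)
script_module = torch.jit.trace(qmod, inp).eval()

# The op-name intersection check above now matches quantized::linear_dynamic,
# so the quantized (qnn_torch) conversion path is used during import.
mod, params = relay.frontend.from_pytorch(script_module, [("input", (2, 16))])
print(mod["main"])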
81 changes: 73 additions & 8 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -173,7 +173,7 @@ def _get_quant_param_for_input(input_value):
# 6th and 7th arg are output scale and zp respectively.

# PyTorch 1.6 changed qconv API
if is_version_greater_than("1.5.0"):
if is_version_greater_than("1.5.1"):
qconv_indices = (2, 3)
else:
qconv_indices = (6, 7)
@@ -575,13 +575,7 @@ def _impl(inputs, _):
)

return _do_bias_and_requantize(
- conv_out,
- bias,
- input_scale,
- weight_scale,
- output_scale,
- output_zero_point,
- with_relu,
+ conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
)

return _impl
@@ -826,6 +820,76 @@ def _impl(inputs, _):
return _impl


def _linear_dynamic():
def _calculate_qparam(inp):
# reference ATen/native/quantized/cpu/qlinear_dynamic.cpp
# ChooseQuantizationParams function
mn = _op.min(inp)
mx = _op.max(inp)

# Ensure that the interval contains 0
mn = _op.minimum(mn, _op.const(0.0, dtype="float32"))
mx = _op.maximum(mx, _op.const(0.0, dtype="float32"))

qmax = 255

# reduce_range became True in v1.6
if is_version_greater_than("1.5.1"):
qmax = 127
[Review comment from a Contributor]
What's happening here? Should this be coupled with dtype? 127 should be for int8, while 255 is for uint8.

[Reply from the author (Member)]
This comes from https://github.com/pytorch/pytorch/blob/d642992877139671466d2a96663abede9e39ad55/aten/src/ATen/native/quantized/cpu/quant_utils.h#L64-L66

Here, they intentionally reduce the possible range of quantized values by half, i.e. [qmin, qmax] to [qmin/2, qmax/2]. Since PyTorch only uses uint8, this is fine.

It's not clear to me why they do this, but the following PR has some explanation: "The reduce_range option restricts the activation tensor to 7 bits instead of 8. This is necessary to enable per-channel quant for RNNs and LSTMs." (pytorch/pytorch#39041)


scale = (mx - mn) / _expr.const(qmax, dtype="float32")

zero_point_from_min = -(mn / scale)
zero_point = _op.cast(_op.round(_op.clip(zero_point_from_min, 0.0, qmax)), "int32")

return scale, zero_point

def _impl(inputs, _):
weight = inputs[1][0]
weight_scale = inputs[1][1]
weight_zero_point = inputs[1][2]

inp = inputs[0]

input_scale, input_zero_point = _calculate_qparam(inp)
qinp = relay.qnn.op.quantize(inp, input_scale, input_zero_point, out_dtype="uint8")

data_shape = infer_shape(inp)

if len(data_shape) > 2:
qinp = _op.reverse_reshape(qinp, [-1, 0])

weight_shape = infer_shape(weight)
units = weight_shape[0]
dense = relay.qnn.op.dense(
qinp,
weight,
input_zero_point,
weight_zero_point,
input_scale,
weight_scale,
units=units,
)
bias_var = inputs[1][3]

dequant_scale = input_scale * weight_scale
dense_out = relay.qnn.op.dequantize(
dense, dequant_scale, input_zero_point=relay.const(0, "int32"), axis=1
)

if len(data_shape) > 2:
new_shape = list(data_shape[:-1])
new_shape.append(units)
dense_out = _op.reshape(dense_out, new_shape)

if bias_var is not None:
return dense_out + bias_var

return dense_out

return _impl
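To make the arithmetic in _calculate_qparam above concrete, here is a plain-NumPy sketch of the same ChooseQuantizationParams math with reduce_range (qmax = 127). The helper name and sample values are ours, purely for illustration:

import numpy as np

def choose_qparams(x, qmax=127):
    # The [min, max] interval must contain 0 so that real 0.0 is exactly representable.
    mn = min(float(x.min()), 0.0)
    mx = max(float(x.max()), 0.0)
    scale = (mx - mn) / qmax
    zero_point = int(np.round(np.clip(-mn / scale, 0.0, qmax)))
    return scale, zero_point

x = np.array([-1.5, 0.25, 3.0], dtype="float32")
scale, zp = choose_qparams(x)
# scale = (3.0 - (-1.5)) / 127 ~= 0.0354, zero_point = round(1.5 / 0.0354) = 42
q = np.clip(np.round(x / scale) + zp, 0, 255).astype("uint8")  # what qnn.quantize then does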


convert_map = {
"aten::quantize_per_tensor": _quantize_per_tensor(),
"quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
@@ -841,4 +905,5 @@ def _impl(inputs, _):
"quantized::add_scalar": _add_scalar(),
"quantized::mul_scalar": _mul_scalar(),
"quantized::relu6": _relu6(),
"quantized::linear_dynamic": _linear_dynamic(),
}
35 changes: 27 additions & 8 deletions src/relay/qnn/op/dense.cc
@@ -99,6 +99,18 @@ Expr DenseFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int re
return MakeConstantScalar(DataType::Int(32), scalar_term);
}

Expr DenseFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point,
int reduction_dim_size) {
auto reduction_dim = MakeConstantScalar(DataType::Int(32), reduction_dim_size);
return Multiply(Multiply(input_zero_point, kernel_zero_point), reduction_dim);
}

Expr DenseCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, const Expr& term4) {
auto data_term = Subtract(term1, term2);
// Putting constant terms together, so that constant folding can fold it.
auto const_term = Subtract(term4, term3);
return Add(data_term, const_term);
}
/*
* \brief Forward rewrite the qnn dense op.
* \param attrs The QNN dense attrs.
@@ -144,14 +156,24 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,

const auto* qnn_dense_attrs = attrs.as<DenseAttrs>();

+ auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
+ auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
+ auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);

// Extract the integer zero points.
- auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);

+ if (!IsConstScalar(input_zero_point)) {
+   if (kernel_zero_point_int == 0) {
+     return Subtract(term1, term3);
+   }
+   auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size);
+   return DenseCombineTerms(term1, term2, term3, term4);
+ }
+
+ auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);

// Get all the terms as described in the comments.
- auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
- auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point);
- auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point);
auto term4 = DenseFourthTerm(input_zero_point_int, kernel_zero_point_int, reduction_dim_size);

// Combine those 4 terms depending on the zero points to get the best lowering.
@@ -165,10 +187,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// term 2 and term 4 become zero.
return Subtract(term1, term3);
} else {
- auto data_term = Subtract(term1, term2);
- // Putting constant terms together, so that constant folding can fold it.
- auto const_term = Subtract(term4, term3);
- return Add(data_term, const_term);
+ return DenseCombineTerms(term1, term2, term3, term4);
}
}
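The term1 … term4 naming used above comes from expanding the zero-point-offset product (Qx - zx)(Qw - zw)^T into a data-dependent part and constant parts, which is what DenseCombineTerms puts back together. A small NumPy check of that identity (illustrative only, not TVM code):

import numpy as np

# (Qx - zx) @ (Qw - zw).T
#   = Qx @ Qw.T              (term1)
#   - zw * rowsum(Qx)        (term2)
#   - zx * rowsum(Qw)        (term3)
#   + zx * zw * K            (term4)
rng = np.random.default_rng(0)
K = 8
qx = rng.integers(0, 256, size=(4, K)).astype(np.int32)  # quantized data
qw = rng.integers(0, 256, size=(5, K)).astype(np.int32)  # quantized kernel
zx, zw = 3, 7                                             # zero points

direct = (qx - zx) @ (qw - zw).T
term1 = qx @ qw.T
term2 = zw * qx.sum(axis=1, keepdims=True)
term3 = zx * qw.sum(axis=1, keepdims=True).T
term4 = zx * zw * K
assert np.array_equal(direct, term1 - term2 - term3 + term4)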

23 changes: 11 additions & 12 deletions src/relay/qnn/op/quantize.cc
@@ -83,20 +83,27 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis
}

Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
- const Expr& output_zero_point, const Array<IndexExpr>& input_shape,
+ const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
const QuantizeAttrs* attrs) {
ICHECK_EQ(types.size(), 4);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
ICHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;

const auto out_dtype = attrs->out_dtype;
const auto axis = attrs->axis;

size_t n_dim = input_shape.size();

auto expanded_output_scale = output_scale;
- if (!IsConstScalar(output_scale)) {
+ if (!IsConstScalar(output_scale) && !IsScalarType(types[1])) {
expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis});
}

auto expanded_output_zero_point = output_zero_point;
- if (!IsConstScalar(output_zero_point)) {
+ if (!IsConstScalar(output_zero_point) && !IsScalarType(types[2])) {
expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis});
}

@@ -120,15 +127,7 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
ICHECK(quantize_attrs != nullptr);

- // Find input shape.
- ICHECK_EQ(types.size(), 4);
- auto in_type = types[0];
- auto in_tensor_type = in_type.as<TensorTypeNode>();
- ICHECK(in_tensor_type != nullptr) << "Type information missing."
-                                   << " Please run infer_type pass.";
- Array<IndexExpr> input_shape = in_tensor_type->shape;
-
- return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs);
+ return QuantizeLower(data, output_scale, output_zero_point, types, quantize_attrs);
}

RELAY_REGISTER_OP("qnn.quantize")
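The reason QuantizeLower now receives types instead of a precomputed input shape is the dynamic flow above: the scale and zero point fed to qnn.quantize are runtime expressions rather than constants, but they are still rank-0 scalars, so they must not be expanded along a channel axis. A minimal Relay sketch of that situation, assuming only the public relay.qnn.op.quantize API:

import tvm
from tvm import relay

data = relay.var("data", shape=(4, 16), dtype="float32")
# Non-constant, rank-0 scale and zero point, as _calculate_qparam produces in the
# dynamic quantization path: IsConstScalar is false, but IsScalarType is true.
scale = relay.var("scale", shape=(), dtype="float32")
zero_point = relay.var("zero_point", shape=(), dtype="int32")

q = relay.qnn.op.quantize(data, scale, zero_point, out_dtype="uint8")
func = relay.Function([data, scale, zero_point], q)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
print(mod)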
12 changes: 12 additions & 0 deletions src/relay/qnn/utils.h
@@ -179,6 +179,18 @@ static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
return true;
}

/*
* \brief Checks whether an expr type is scalar.
* \param expr_type The type of expr to be checked.
* \return True if the type is a scalar
*/
static inline bool IsScalarType(const Type& expr_type) {
const auto* tensor_type = expr_type.as<TensorTypeNode>();
CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got "
<< AsText(expr_type, false);
return tensor_type->shape.size() == 0;
}

/*
* \brief Checks and assigns types to scale and zero points.
* \param expr_type The type of expr to be checked.
55 changes: 46 additions & 9 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -27,7 +27,9 @@
from torch.quantization import fuse_modules, QuantWrapper

import tvm
import tvm.testing
from tvm import relay
from tvm.relay.frontend.pytorch_utils import is_version_greater_than
from tvm.contrib.download import download_testdata


@@ -197,9 +199,7 @@ def fuse_model(self):

# test on quantized::mul_scalar with negative scale
class MulScalarNegative(nn.Module):
- def __init__(
-     self,
- ):
+ def __init__(self):
super().__init__()
self.float_op = nn.quantized.FloatFunctional()
self.quant = QuantStub()
@@ -337,12 +337,7 @@ def get_transform():

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
return transforms.Compose(
- [
-     transforms.Resize(256),
-     transforms.CenterCrop(224),
-     transforms.ToTensor(),
-     normalize,
- ]
+ [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
)

def get_real_image(im_height, im_width):
Expand Down Expand Up @@ -508,3 +503,45 @@ def test_serialized_modules():
num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2)
match_ratio = num_identical / float(np.prod(tvm_result.shape))
assert match_ratio > 0.90


def test_quantize_dynamic():
# A wrapper is required for quantize_dynamic to work correctly
class LinearWrapper(nn.Module):
def __init__(self, in_dim, hidden_dim):
super().__init__()
self.linear = nn.Linear(in_dim, hidden_dim)

def forward(self, inp):
return self.linear(inp)

torch.manual_seed(0)
mod = LinearWrapper(16, 32)

for qconfig in [
torch.quantization.per_channel_dynamic_qconfig,
torch.quantization.default_dynamic_qconfig,
]:
for ishape in [(16, 16), (10, 16, 16)]:
qspec = {nn.Linear: qconfig}
qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8)

inp = torch.randn(*ishape)
script_module = torch.jit.trace(qmod, inp).eval()

with torch.no_grad():
pt_result = script_module(inp.clone()).numpy()

input_name = "input"
runtime = get_tvm_runtime(script_module, "input", inp.shape)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
tvm_result = runtime.get_output(0).asnumpy()

# Only compare with the PyTorch result for version 1.6 or newer.
# We have seen a strange accuracy problem with PyTorch 1.4 and 1.5:
# even with the manual random seed set, the same PyTorch version can
# output slightly different results depending on the environment.
# Outputs from v1.6 seem reliable, while TVM's outputs are always the same.
if is_version_greater_than("1.5.1"):
tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)