Skip to content

Commit

Permalink
Requantize operator implementation.
Browse files Browse the repository at this point in the history
Requantize converts one quantized tensor representation to another quantized
representation. The PR has following implementation features

- Requantize operator defined in qnn namespace - relay.qnn.requantize
- Lowering of the requantize to exisiting Relay operators
- Integer fixed point implementation of requantize
    - Two rounding modes - FE_UPWARDS (round towards infinity) and
    FE_AWAY_FROM_ZERO (std::round behavior)
- Floating point implementation as well, that can act as reference or can be
used for devices when FP32 computation is not used.
- Unit test cases

Relevant Issue - apache#2351

Credit to TFLite and GemmLowp to provide reference implementations.
  • Loading branch information
anijain2305 committed Aug 8, 2019
1 parent ed11cd7 commit 91b58a5
Show file tree
Hide file tree
Showing 5 changed files with 390 additions and 128 deletions.
13 changes: 11 additions & 2 deletions include/tvm/relay/attrs/qnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
double output_scale;
int32_t output_zero_point;
bool use_int_compute;
std::string rounding_mode;
DataType out_dtype;

TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
Expand All @@ -48,14 +49,22 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
.describe("The scale of the input tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale of the output tensor.");
TVM_ATTR_FIELD(use_int_compute).set_default(false)
.describe("When true, the integer computation is used to handle output scale");
TVM_ATTR_FIELD(use_int_compute).set_default(true)
.describe("When true, the integer computation is used to handle output scale."
"The float compuation can be used as reference implementation or in"
"cases where FP32 computation for requantize is not expensive");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
TVM_ATTR_FIELD(rounding_mode).set_default("FE_UPWARD")
.describe("Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - FE_UPWARD"
"or FE_AWAY_FROM_ZERO. More context can be found at"
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html");
}
};


} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_QUANTIZE_H_
13 changes: 10 additions & 3 deletions python/tvm/relay/op/qnn/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from __future__ import absolute_import as _abs
from . import _make


def requantize(input_data, input_zero_point, input_scale, output_zero_point,
output_scale, out_dtype="int32", use_int_compute=False):
output_scale, out_dtype="int32", use_int_compute=False,
rounding_mode="FE_UPWARD"):
r"""Requantized operator.
The requantize operator converts one quantized tensor to another quantized
Expand Down Expand Up @@ -57,11 +57,18 @@ def requantize(input_data, input_zero_point, input_scale, output_zero_point,
use_int_compute : bool, optional
Use fully integer computation for requantizing.
rounding_mode : string, optional
Defines the rounding direction when the value is midway between two
representable values.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
assert rounding_mode in ("FE_UPWARD", "FE_AWAY_FROM_ZERO"),\
"Unsupported rounding mode"

return _make.requantize(input_data, input_zero_point, input_scale,
output_zero_point, output_scale, out_dtype,
use_int_compute)
use_int_compute, rounding_mode)
4 changes: 3 additions & 1 deletion src/relay/op/nn/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,16 @@ Expr MakeRequantize(Expr data,
int32_t output_zero_point,
double output_scale,
DataType out_dtype,
bool use_int_compute) {
bool use_int_compute,
std::string rounding_mode) {
auto attrs = make_node<RequantizeAttrs>();
attrs->out_dtype = std::move(out_dtype);
attrs->input_zero_point = std::move(input_zero_point);
attrs->output_zero_point = std::move(output_zero_point);
attrs->input_scale = std::move(input_scale);
attrs->output_scale = std::move(output_scale);
attrs->use_int_compute = std::move(use_int_compute);
attrs->rounding_mode = std::move(rounding_mode);
static const Op& op = Op::Get("qnn.requantize");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
Expand Down
231 changes: 109 additions & 122 deletions src/relay/pass/quantize_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,109 +33,44 @@
namespace tvm {
namespace relay {


// Lowering of qnn.requantize op

/*
* Converts a floating point number so that it can be represented by integers.
* The representation is
* float_number = (fixed_point_multiplier) * 2^(shift)
*
* The fixed_point_multiplier is a number between 0.5 and 1. This is represented
* by an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit from
* the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*/
void GetFixedPointMultiplierShift(double double_multiplier,
int32_t* fixed_point_multiplier, int* shift,
const DataType& idtype) {

int acc_dtype_bits = idtype.bits();
int idtype_bits = idtype.bits();

if (double_multiplier == 0.) {
*fixed_point_multiplier = 0;
*shift = 0;
return;
}
const double q = std::frexp(double_multiplier, shift);
auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << (acc_dtype_bits - 1))));
CHECK_LE(q_fixed, (1ll << (acc_dtype_bits - 1)));
if (q_fixed == (1ll << (acc_dtype_bits - 1))) {
auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << (idtype_bits - 1))));
CHECK_LE(q_fixed, (1ll << (idtype_bits - 1)));
if (q_fixed == (1ll << (idtype_bits - 1))) {
q_fixed /= 2;
++*shift;
}
CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
*fixed_point_multiplier = static_cast<int32_t>(q_fixed);
}

Expr MultiplyByIntegerMuliplier(const Expr& convolved_tensor,
const int32_t fixed_point_multiplier, const int left_shift,
const RequantizeAttrs*& param, const DataType& idtype,
const Array<IndexExpr>& out_shape) {
// TODO (janimesh) - How to add the overflow checks here. TFLite code snippet is
// bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
// return overflow ? std::numeric_limits<std::int32_t>::max() : .....;/

// The calculations are done in upcast of idtype to retain precision.
int acc_dtype_bits = idtype.bits();
DataType up_idtype = Int(2 * acc_dtype_bits);

auto tensor = convolved_tensor;
// Typically the left_shift will be 0 if the original scale is > 0.5.
if (left_shift != 0) {
tensor = Multiply(tensor, MakeConstantScalar(idtype, 1 << left_shift));
}

// Upcast the computation to Int64 and multiply the multiplier.
Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier);
auto multiplied_t = Multiply(Cast(tensor, up_idtype), scalar);

// Since, we are performing fixed point computation. We are only interested in
// higher 16/32 bits. But before that, we also need to perform rounding.
// This is fixed point rounding. So, the rounder add scalar depends if the
// input is positive.
auto zero = MakeConstantScalar(up_idtype, 0);
auto pos_threshold = MakeConstantScalar(up_idtype,
1ll << (acc_dtype_bits - 2));
auto neg_threshold = MakeConstantScalar(up_idtype,
(1 - (1ll << (acc_dtype_bits - 2))));
auto pos_rounder = Full(pos_threshold, out_shape, up_idtype);
auto neg_rounder = Full(neg_threshold, out_shape, up_idtype);
auto rounding_scalar = Where(GreaterEqual(multiplied_t, zero), pos_rounder, neg_rounder);
auto rounded_tensor = Add(multiplied_t, rounding_scalar);

// Perform right shift to get the first 16/32 bits.
// The result is first doubled and the first 15/31 bits are obtained. This is
// done by just right shifting the result by 15/31 bits.
auto right_shift_scalar = MakeConstantScalar(up_idtype, (acc_dtype_bits - 1));
auto scaled_t = RightShift(rounded_tensor, right_shift_scalar);
auto q_imin = get_qmin(idtype);
auto q_imax = get_qmax(idtype);
auto integer_multiplied_t = Cast(Clip(scaled_t, q_imin, q_imax),
idtype);
return integer_multiplied_t;
}

Expr ShiftByIntegerShift(const Expr& multiplied_t,
const int& exponent, const RequantizeAttrs*& param,
const DataType& idtype, const Array<IndexExpr>& out_shape) {
CHECK_GE(exponent, 0);
int acc_dtype_bits = idtype.bits();
CHECK_LE(exponent, (acc_dtype_bits - 1));

// We need to perform rounding. The rounding here is closest to the power
// of 2. The exponent basically represents the decimal point. We need to round
// at the decimal point.
auto tensor = multiplied_t;
if (exponent != 0) {
auto pos_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1)));
auto neg_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1)) - 1);
auto pos_rounder_t = Full(pos_rounder, out_shape, idtype);
auto neg_rounder_t = Full(neg_rounder, out_shape, idtype);

auto zero = MakeConstantScalar(idtype, 0);
auto zero_t = Full(zero, out_shape, idtype);
auto round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t,
neg_rounder_t);
tensor = Add(tensor, round_scalar);
}

// Right shift by exponent to approximate the division.
auto scaled_t = RightShift(tensor,
MakeConstantScalar(idtype, exponent));
return scaled_t;
}


/*
* Requantization using only integer computation. Here, the computation is
* converted to a fixed point computation by computing output multiplier and
Expand All @@ -147,59 +82,123 @@ Expr ShiftByIntegerShift(const Expr& multiplied_t,
* multiplication with an int value and then right shifting the result. This
* approximates the floating point computation with a fixed point computation.
*
* The whole computaition this can be broken down into following steps
* The whole computation this can be broken down into following steps
* 1) Calculate the integer multiplier and integer shift.
* 2) Multiply the integer multiplier with quantized tensor.
* 3) Right shift the result.
* 2) Subtract the input integer point.
* 2) Multiply the integer fixed point multiplier with quantized tensor.
* 3) Round the result.
* 4) Right shift the result.
* 5) Add the output_zero_point.
* 6) Cast to the out_dtype.
*
* The only thing complicating the above computations is the tedious approach of
* handling rounding.
*/
Expr RequantizeInt(const Expr& convolved_tensor,
Expr RequantizeInt(const Expr& input_tensor,
const RequantizeAttrs*& param, const DataType& idtype,
const Array<IndexExpr>& out_shape) {

double double_multiplier = param->input_scale/param->output_scale;

// The multiplication will be performed in higher precision. Find the dtype.
int idtype_bits = idtype.bits();
DataType up_idtype = Int(2 * idtype_bits);

// 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier;
int shift;
GetFixedPointMultiplierShift(double_multiplier, &fixed_point_multiplier,
&shift, idtype);

// 2) Multiply the integer multiplier
int left_shift = shift > 0 ? shift : 0;
int right_shift = shift > 0 ? 0 : -shift;
auto multiplied_t = MultiplyByIntegerMuliplier(convolved_tensor,
fixed_point_multiplier, left_shift, param, idtype, out_shape);

// 3) Divide by the denominator or right shift the result.
auto scaled_int32_t = ShiftByIntegerShift(multiplied_t,
right_shift, param, idtype, out_shape);
// 2) Subtract the input_zero_point
auto tensor = input_tensor;
tensor = Cast(tensor, up_idtype);
if (param->input_zero_point != 0) {
auto input_zp = MakeConstantScalar(up_idtype, param->input_zero_point);
tensor = Subtract(tensor, input_zp);
}

// 4) Clip to the out_dtype min/max.


// 3) Multiply the integer multiplier
if (left_shift != 0) {
tensor = Multiply(tensor, MakeConstantScalar(up_idtype, 1 << left_shift));
}
// Perform the multiplication in higher precision.
// If idtype is Int(32), the scalar is a fixed point value of int32 where the
// decimal point is between bits 31 and 30. After multiplying with
// input_tensor, the result in int64 where the decimal point is sitting
// between bits 31 and 30 (from the right, rightmost bit is bit 0).
Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier);
auto multiplied_t = Multiply(tensor, scalar);


// 4) Find the rounding scalar. This depends on where the final decimal point
// sits. As we will be right shifting the multiplied_t, we need to first
// calculate the totol_right_shift.
int total_right_shift = right_shift + idtype_bits - 1;

tensor = multiplied_t;
Expr round_scalar;
if (param->rounding_mode == "FE_UPWARD") {
auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)));
round_scalar = pos_rounder;
} else if (param->rounding_mode == "FE_AWAY_FROM_ZERO") {
auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)));
auto neg_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)) - 1);
auto pos_rounder_t = Full(pos_rounder, out_shape, up_idtype);
auto neg_rounder_t = Full(neg_rounder, out_shape, up_idtype);

auto zero = MakeConstantScalar(up_idtype, 0);
auto zero_t = Full(zero, out_shape, up_idtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t,
neg_rounder_t);
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);

// 5) Simply right shift the result to get the final output.
auto scaled_int64_t = RightShift(tensor,
MakeConstantScalar(up_idtype, total_right_shift));

// 6) Add the output zero point.
auto output_zp = MakeConstantScalar(up_idtype, param->output_zero_point);
auto shifted_int64_t = Add(output_zp, scaled_int64_t);

// 7) Clip to the out_dtype min/max.
// Find the right clip min/maxes. While clipping, it is necessary that
// clip_min and clip_max are within the dtype range of the input tensor to the
// clip operator. For example, if the input to clip operator is int8, but the
// out_dtype is uint8, we will get incorrect results, if we set max as 255.
auto q_min = std::max(get_qmin(param->out_dtype), get_qmin(idtype));
auto q_max = std::min(get_qmax(param->out_dtype), get_qmax(idtype));
auto clipped_t = Clip(scaled_int32_t, q_min, q_max);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
auto requantized_output = Cast(clipped_t, param->out_dtype);
return requantized_output;
}

/*

/*
* Requantization using floating computation. Here we can multiply the scale to
* the convolved_tensor, round to nearest integer and then cast back to int32.
* the input_tensor, round to nearest integer and then cast back to int32.
*/
Expr RequantizeFloat(const Expr& convolved_tensor,
Expr RequantizeFloat(const Expr& input_tensor,
const RequantizeAttrs*& param, const DataType& idtype,
const Array<IndexExpr>& out_shape) {
double double_multiplier = param->input_scale/param->output_scale;
auto scalar_multiplier = MakeConstantScalar(Float(32), double_multiplier);

// Multiply the convolved tensor with the new scale.
auto casted_t = Cast(convolved_tensor, Float(32));
auto multiplied_t = Round(Multiply(casted_t, scalar_multiplier));
auto input_zp = MakeConstantScalar(idtype, param->input_zero_point);
auto output_zp = MakeConstantScalar(Float(32), param->output_zero_point);

// Multiply the tensor with the new scale.
auto shifted_input_t = Subtract(input_tensor, input_zp);
auto casted_t = Cast(shifted_input_t, Float(32));
auto multiplied_t = Multiply(casted_t, scalar_multiplier);
auto shifted_multiplied_t = Add(output_zp, multiplied_t);
auto rounded_t = Round(shifted_multiplied_t);
auto q_imin = get_qmin(idtype);
auto q_imax = get_qmax(idtype);
auto scaled_int32_t = Cast(Clip(multiplied_t, q_imin, q_imax),
auto scaled_int32_t = Cast(Clip(rounded_t, q_imin, q_imax),
idtype);

// Clip to the out_dtype min/max.
Expand Down Expand Up @@ -243,33 +242,21 @@ Expr RequantizeForwardRewrite(const Call& ref_call,
<< " Please run infer_type pass.";
const auto input_dtype = input_tt->dtype;

// Check for current quantization support.
CHECK_EQ(param->input_zero_point, 0)
<< "Encountered non-zero zero point."
<< " Only symmetric quantization supported for now.";
CHECK_EQ(param->output_zero_point, 0)
<< "Encountered non-zero zero point."
<< " Only symmetric quantization supported for now.";

if (param->use_int_compute) {
return RequantizeInt(quantized_data, param, input_dtype, out_shape);
} else {
return RequantizeFloat(quantized_data, param, input_dtype, out_shape);
}
}


RELAY_REGISTER_OP("qnn.requantize")
.set_attr<FForwardRewrite>("FQuantizeForwardRewrite", RequantizeForwardRewrite);



TVM_REGISTER_API("relay._quantize.rewrite")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr);
return ret;
});

Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr);
return ret;
});

} // namespace relay
} // namespace tvm
Loading

0 comments on commit 91b58a5

Please sign in to comment.