[Requantize] Cleanup and Optimize Lowering #5286

Merged · 5 commits · Apr 12, 2020
36 changes: 13 additions & 23 deletions src/relay/qnn/op/requantize.cc
@@ -132,36 +132,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
-  DataType hp_dtype = DataType::Int(64);
-
-  auto tensor = Cast(input_tensor, hp_dtype);
+  auto tensor = Cast(input_tensor, DataType::Int(32));
   // 1) Subtract the input_zero_point
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
-    tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
+    tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32)));
   }
 
-  // Check if multiplier is greater than 1.
-  bool is_multiplier_gt_one = false;
-
   // 2) If the input and output scales are same, we can skip the fixed point multiplication. Check
   // if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single scale
   // for the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the
   // input tensor. Depending on the quantization type, the fixed point multiplication routine is
   // called.
-  auto scaled_int64_t = tensor;
+  auto scaled_int32_t = tensor;
   float output_scale_float = GetScalarFromConstant<float>(output_scale);
   if (IsConstScalar(input_scale)) {
     // This is per-tensor quantization. Single scale.
     float input_scale_float = GetScalarFromConstant<float>(input_scale);
     double double_multiplier =
         static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
-    if (double_multiplier > 1) {
-      is_multiplier_gt_one = true;
-    }
     // Skip if input and output scales are same.
     if (!IsEqualScalar(input_scale, output_scale)) {
-      scaled_int64_t =
-          FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
+      scaled_int32_t =
+          FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
     }
   } else {
     // This is per-channel (per-axis) quantization.
@@ -171,30 +163,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
       double multiplier =
           static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
       double_multipliers.push_back(multiplier);
-      if (multiplier > 1) {
-        is_multiplier_gt_one = true;
-      }
     }
     int axis = param->axis;
     axis = (axis == -1) ? input_shape.size() - 1 : axis;
-    scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape,
+    scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
                                                   axis, param->rounding);
   }
 
   // 3) Add the output zero point.
-  auto shifted_int64_t = scaled_int64_t;
+  auto shifted_int32_t = scaled_int32_t;
   if (!IsEqualScalar(output_zero_point, zero_scalar)) {
-    shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t);
+    shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
   }
 
   // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
-  // multiplication keeps the value in int32 range if the requantize scale is less than 1.
-  if (out_dtype == DataType::Int(32) && !is_multiplier_gt_one) {
-    return Cast(shifted_int64_t, out_dtype);
+  // multiplication keeps the value in int32 range.
+  if (out_dtype == DataType::Int(32)) {
+    return shifted_int32_t;
   }
 
   auto q_min = GetQmin(out_dtype);
   auto q_max = GetQmax(out_dtype);
-  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+  auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
   return Cast(clipped_t, out_dtype);
 }

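For reference, the four numbered steps in the lowered graph amount to the following elementwise arithmetic. A minimal scalar sketch in plain C++ (not TVM API; int8 output and round-to-nearest rounding are assumed, and a double multiply stands in for the integer-only fixed point multiply of step 2):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar sketch of RequantizeLower's four steps.
int8_t RequantizeRef(int32_t x, int32_t input_zp, float input_scale,
                     int32_t output_zp, float output_scale) {
  int32_t v = x - input_zp;                        // 1) subtract input zero point
  double multiplier = static_cast<double>(input_scale) /
                      static_cast<double>(output_scale);
  int64_t scaled = std::llround(v * multiplier);   // 2) rescale by input_scale / output_scale
  int64_t shifted = scaled + output_zp;            // 3) add output zero point
  shifted = std::max<int64_t>(-128,
            std::min<int64_t>(127, shifted));      // 4) clip to the out_dtype (int8) range
  return static_cast<int8_t>(shifted);
}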
8 changes: 6 additions & 2 deletions src/relay/qnn/util.cc
@@ -80,6 +80,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift
   int32_t fixed_point_multiplier, shift;
@@ -130,7 +131,8 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   tensor =
       RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+  return Cast(tensor, DataType::Int(32));
 }

@@ -145,6 +147,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
   // channel.
@@ -218,7 +221,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
   tensor = RightShift(tensor, exp_total_rshift_expr);
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+  return Cast(tensor, DataType::Int(32));
 }
 
 }  // namespace qnn
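Both helpers now own the int64 widening that callers previously did by hand, and both return an int32 result, which is why RequantizeLower can drop its is_multiplier_gt_one bookkeeping and return the int32 tensor directly. The widening matters because a Q31 fixed point multiplier times an int32 value does not fit in 32 bits. A rough scalar sketch of the decomposition and multiply, in the spirit of what FixedPointMultiply does internally (helper names and the multiplier < 1 assumption are illustrative, not TVM's):

#include <cmath>
#include <cstdint>

// Split a positive double multiplier into a Q31 fixed point multiplier and a
// power-of-two exponent, so that multiplier ~= q_fixed * 2^(exp - 31).
void QuantizeMultiplier(double multiplier, int32_t* q_fixed, int* exp) {
  double significand = std::frexp(multiplier, exp);  // significand in [0.5, 1)
  auto q = static_cast<int64_t>(std::round(significand * (1LL << 31)));
  if (q == (1LL << 31)) {  // rounding pushed the significand up to 1.0
    q /= 2;
    ++(*exp);
  }
  *q_fixed = static_cast<int32_t>(q);
}

// Apply the multiplier: widen to int64 so int32 * int32 cannot overflow,
// round to nearest, then shift back down into int32 range.
int32_t FixedPointMultiplyRef(int32_t x, double multiplier) {
  int32_t q_fixed;
  int exp;
  QuantizeMultiplier(multiplier, &q_fixed, &exp);
  int total_right_shift = 31 - exp;  // positive whenever multiplier < 1
  int64_t prod = static_cast<int64_t>(x) * q_fixed;
  int64_t round = 1LL << (total_right_shift - 1);  // nearest rounding
  return static_cast<int32_t>((prod + round) >> total_right_shift);
}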
3 changes: 1 addition & 2 deletions src/relay/quantize/realize.cc
@@ -117,8 +117,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    data = qnn::FixedPointMultiply(
-        Cast(data, DataType::Int(64)), factor, data_shape, cfg->rounding);
+    data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
     return Cast(data, dtype);
   }
 }
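With FixedPointMultiply handling the widening itself, this call site shrinks to a single line. For orientation, a scalar analogue of MulAndDiv's branch structure (illustrative only; the real function emits Relay expressions, and the pass-through case is inferred from the else-if chain shown above):

#include <cmath>
#include <cstdint>

// Scalar analogue of MulAndDiv applying factor = s1 / s2.
int32_t MulAndDivRef(int32_t data, float s1, float s2) {
  float factor = s1 / s2;
  if (factor == 1.0f) return data;             // scales cancel: pass through
  if (static_cast<int>(factor) == factor)      // exact integer factor: one multiply
    return data * static_cast<int32_t>(factor);
  // Anything else: rounded multiplication, standing in for the
  // fixed point multiply path.
  return static_cast<int32_t>(std::lround(data * static_cast<double>(factor)));
}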