Skip to content

Commit

Permalink
Removing extra clip.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Apr 9, 2020
1 parent 1dce423 commit 8a0f8dc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
6 changes: 3 additions & 3 deletions python/tvm/_ffi/_pyversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
#----------------------------
# Python3 version.
#----------------------------
if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 6):
PY3STATEMENT = "The minimal Python requirement is Python 3.6"
raise Exception(PY3STATEMENT)
# if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 6):
# PY3STATEMENT = "The minimal Python requirement is Python 3.6"
# raise Exception(PY3STATEMENT)
7 changes: 6 additions & 1 deletion src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {

auto tensor = Cast(input_tensor, DataType::Int(32));
// 1) Subtract the input_zero_point
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
Expand Down Expand Up @@ -177,6 +176,12 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
}

// 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
// multiplication keeps the value in int32 range.
if (out_dtype == DataType::Int(32)) {
return shifted_int32_t;
}

auto q_min = GetQmin(out_dtype);
auto q_max = GetQmax(out_dtype);
auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
tensor =
RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));

// The fixed point multiplication keeps the value in int32 range. Casting back to int32.
// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
return Cast(tensor, DataType::Int(32));
}

Expand Down

0 comments on commit 8a0f8dc

Please sign in to comment.