From f9561185a0168b5a660fc6b5ca219e7d3016a735 Mon Sep 17 00:00:00 2001
From: Animesh Jain <anijain@umich.edu>
Date: Wed, 10 Jul 2019 21:51:02 +0000
Subject: [PATCH] Requantize operator implementation.

Requantize converts one quantized tensor representation to another quantized
representation. The PR has the following implementation features:

- Requantize operator defined in qnn namespace - relay.qnn.requantize
- Lowering of the requantize to existing Relay operators
- Integer fixed point implementation of requantize (sketch below)
    - Two rounding modes - FE_UPWARD (round towards positive infinity) and
    FE_AWAY_FROM_ZERO (std::round behavior)
- Floating point implementation as well, which can act as a reference or can be
used on devices where FP32 computation for requantize is not expensive.
- Unit test cases

Relevant Issue - https://github.com/dmlc/tvm/issues/2351

Credit to TFLite and GemmLowp for providing reference implementations.
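
A minimal NumPy sketch of the integer path described above (assuming int32
input, no zero points, no clipping, and FE_AWAY_FROM_ZERO rounding; the
function name is illustrative and not part of the patch):

    import math

    import numpy as np

    def requantize_int_sketch(data, input_scale, output_scale, idtype_bits=32):
        # Represent input_scale / output_scale as fixed_point_multiplier * 2^shift,
        # where the multiplier encodes a value in [0.5, 1) using idtype_bits bits.
        # (The edge case where q rounds up to exactly 1.0 is ignored here.)
        q, shift = math.frexp(input_scale / output_scale)
        fixed_point_multiplier = int(round(q * (1 << (idtype_bits - 1))))

        left_shift = shift if shift > 0 else 0
        right_shift = -shift if shift < 0 else 0

        # Multiply in higher precision (int64 for int32 inputs).
        acc = data.astype(np.int64)
        if left_shift:
            acc = acc * (1 << left_shift)
        acc = acc * fixed_point_multiplier

        # Add the rounding constant (FE_AWAY_FROM_ZERO shown), then shift back down.
        total_right_shift = right_shift + idtype_bits - 1
        pos = 1 << (total_right_shift - 1)
        rounder = np.where(acc >= 0, pos, pos - 1)
        return (acc + rounder) >> total_right_shift

For example, requantize_int_sketch(np.arange(0, -32, -1, dtype=np.int32), 1.0,
16.0) reproduces the FE_AWAY_FROM_ZERO golden output of the downscale unit test.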
---
 include/tvm/relay/attrs/qnn.h               |  13 +-
 python/tvm/relay/op/qnn/qnn.py              |  13 +-
 src/relay/op/nn/requantize.cc               |   4 +-
 src/relay/pass/quantize_rewrite.cc          | 231 +++++++++---------
 tests/python/unittest/test_quantized_ops.py | 257 ++++++++++++++++++++
 5 files changed, 390 insertions(+), 128 deletions(-)
 create mode 100644 tests/python/unittest/test_quantized_ops.py

diff --git a/include/tvm/relay/attrs/qnn.h b/include/tvm/relay/attrs/qnn.h
index 12afe19d26b3e..cf69fa759c1c0 100644
--- a/include/tvm/relay/attrs/qnn.h
+++ b/include/tvm/relay/attrs/qnn.h
@@ -37,6 +37,7 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
   double output_scale;
   int32_t output_zero_point;
   bool use_int_compute;
+  std::string rounding_mode;
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
@@ -48,14 +49,22 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
         .describe("The scale of the input tensor.");
     TVM_ATTR_FIELD(output_scale)
         .describe("The scale of the output tensor.");
-    TVM_ATTR_FIELD(use_int_compute).set_default(false)
-        .describe("When true, the integer computation is used to handle output scale");
+    TVM_ATTR_FIELD(use_int_compute).set_default(true)
+      .describe("When true, the integer computation is used to handle output scale."
+                "The float compuation can be used as reference implementation or in"
+                "cases where FP32 computation for requantize is not expensive");
     TVM_ATTR_FIELD(out_dtype)
         .set_default(NullValue<DataType>())
         .describe("Output data type, set to explicit type under mixed precision setting");
+    TVM_ATTR_FIELD(rounding_mode).set_default("FE_UPWARD")
+        .describe("Defines the rounding direction when the value is midway between"
+                  "two representable values. There are two supported modes - FE_UPWARD"
+                  "or FE_AWAY_FROM_ZERO. More context can be found at"
+                  "https://www.gnu.org/software/libc/manual/html_node/Rounding.html");
   }
 };
 
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_QUANTIZE_H_
diff --git a/python/tvm/relay/op/qnn/qnn.py b/python/tvm/relay/op/qnn/qnn.py
index 18be68cd9cfc4..484b3864f22fb 100644
--- a/python/tvm/relay/op/qnn/qnn.py
+++ b/python/tvm/relay/op/qnn/qnn.py
@@ -19,9 +19,9 @@
 from __future__ import absolute_import as _abs
 from . import _make
 
-
 def requantize(input_data, input_zero_point, input_scale, output_zero_point,
-        output_scale, out_dtype="int32", use_int_compute=False):
+        output_scale, out_dtype="int32", use_int_compute=False,
+        rounding_mode="FE_UPWARD"):
     r"""Requantized operator.
 
     The requantize operator converts one quantized tensor to another quantized
@@ -57,11 +57,18 @@ def requantize(input_data, input_zero_point, input_scale, output_zero_point,
     use_int_compute : bool, optional
         Use fully integer computation for requantizing.
 
+    rounding_mode : string, optional
+        Defines the rounding direction when the value is midway between two
+        representable values. Supported modes are "FE_UPWARD" (-0.5 rounds to
+        0) and "FE_AWAY_FROM_ZERO" (-0.5 rounds to -1).
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
+    assert rounding_mode in ("FE_UPWARD", "FE_AWAY_FROM_ZERO"),\
+            "Unsupported rounding mode"
+
     return _make.requantize(input_data, input_zero_point, input_scale,
                             output_zero_point, output_scale, out_dtype,
-                            use_int_compute)
\ No newline at end of file
+                            use_int_compute, rounding_mode)
diff --git a/src/relay/op/nn/requantize.cc b/src/relay/op/nn/requantize.cc
index 80f2bde4ad472..285528993f6f8 100644
--- a/src/relay/op/nn/requantize.cc
+++ b/src/relay/op/nn/requantize.cc
@@ -59,7 +59,8 @@ Expr MakeRequantize(Expr data,
                     int32_t output_zero_point,
                     double output_scale,
                     DataType out_dtype,
-                    bool use_int_compute) {
+                    bool use_int_compute,
+                    std::string rounding_mode) {
   auto attrs = make_node<RequantizeAttrs>();
   attrs->out_dtype = std::move(out_dtype);
   attrs->input_zero_point = std::move(input_zero_point);
@@ -67,6 +68,7 @@ Expr MakeRequantize(Expr data,
   attrs->input_scale = std::move(input_scale);
   attrs->output_scale = std::move(output_scale);
   attrs->use_int_compute = std::move(use_int_compute);
+  attrs->rounding_mode = std::move(rounding_mode);
   static const Op& op = Op::Get("qnn.requantize");
   return CallNode::make(op, {data}, Attrs(attrs), {});
 }
diff --git a/src/relay/pass/quantize_rewrite.cc b/src/relay/pass/quantize_rewrite.cc
index 55f8c43fd49fc..645b20c0730e1 100644
--- a/src/relay/pass/quantize_rewrite.cc
+++ b/src/relay/pass/quantize_rewrite.cc
@@ -33,13 +33,27 @@
 namespace tvm {
 namespace relay {
 
-
 // Lowering of qnn.requantize op
+
+/*
+ * Converts a floating point number so that it can be represented by integers.
+ * The representation is
+ *      float_number = (fixed_point_multiplier) * 2^(shift)
+ *
+ * The fixed_point_multiplier is a number between 0.5 and 1, represented as a
+ * fixed point integer. For example, if it is int32, the decimal point sits
+ * between bit 31 and bit 30 counting from the LSB (i.e., between the first and
+ * second bit from the left).
+ *
+ * Some examples are
+ *           0.25 = (0.5) * 2^(-1)
+ *           0.125 = (0.5) * 2^(-2)
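+ *
+ * As a concrete int32 example (illustrative values), 0.3 = 0.6 * 2^(-1),
+ * which gives shift = -1 and
+ * fixed_point_multiplier = round(0.6 * 2^31) = 1288490189.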
+ */
 void GetFixedPointMultiplierShift(double double_multiplier,
     int32_t* fixed_point_multiplier, int* shift,
     const DataType& idtype) {
 
-  int acc_dtype_bits = idtype.bits();
+  int idtype_bits = idtype.bits();
 
   if (double_multiplier == 0.) {
     *fixed_point_multiplier = 0;
@@ -47,9 +61,9 @@ void GetFixedPointMultiplierShift(double double_multiplier,
     return;
   }
   const double q = std::frexp(double_multiplier, shift);
-  auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << (acc_dtype_bits - 1))));
-  CHECK_LE(q_fixed, (1ll << (acc_dtype_bits - 1)));
-  if (q_fixed == (1ll << (acc_dtype_bits - 1))) {
+  auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << (idtype_bits - 1))));
+  CHECK_LE(q_fixed, (1ll << (idtype_bits - 1)));
+  if (q_fixed == (1ll << (idtype_bits - 1))) {
     q_fixed /= 2;
     ++*shift;
   }
@@ -57,85 +71,6 @@ void GetFixedPointMultiplierShift(double double_multiplier,
   *fixed_point_multiplier = static_cast<int32_t>(q_fixed);
 }
 
-Expr MultiplyByIntegerMuliplier(const Expr& convolved_tensor,
-    const int32_t fixed_point_multiplier, const int left_shift,
-    const RequantizeAttrs*& param, const DataType& idtype,
-    const Array<IndexExpr>& out_shape) {
-  // TODO (janimesh) - How to add the overflow checks here. TFLite code snippet is
-  // bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
-  // return overflow ? std::numeric_limits<std::int32_t>::max() : .....;/
-
-  // The calculations are done in upcast of idtype to retain precision.
-  int acc_dtype_bits = idtype.bits();
-  DataType up_idtype = Int(2 * acc_dtype_bits);
-
-  auto tensor = convolved_tensor;
-  // Typically the left_shift will be 0 if the original scale is > 0.5.
-  if (left_shift != 0) {
-    tensor = Multiply(tensor, MakeConstantScalar(idtype, 1 << left_shift));
-  }
-
-  // Upcast the computation to Int64 and multiply the multiplier.
-  Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier);
-  auto multiplied_t = Multiply(Cast(tensor, up_idtype), scalar);
-
-  // Since, we are performing fixed point computation. We are only interested in
-  // higher 16/32 bits. But before that, we also need to perform rounding.
-  // This is fixed point rounding. So, the rounder add scalar depends if the
-  // input is positive.
-  auto zero = MakeConstantScalar(up_idtype, 0);
-  auto pos_threshold = MakeConstantScalar(up_idtype,
-          1ll << (acc_dtype_bits - 2));
-  auto neg_threshold = MakeConstantScalar(up_idtype,
-          (1 - (1ll << (acc_dtype_bits - 2))));
-  auto pos_rounder = Full(pos_threshold, out_shape, up_idtype);
-  auto neg_rounder = Full(neg_threshold, out_shape, up_idtype);
-  auto rounding_scalar = Where(GreaterEqual(multiplied_t, zero), pos_rounder, neg_rounder);
-  auto rounded_tensor = Add(multiplied_t, rounding_scalar);
-
-  // Perform right shift to get the first 16/32 bits.
-  // The result is first doubled and the first 15/31 bits are obtained. This is
-  // done by just right shifting the result by 15/31 bits.
-  auto right_shift_scalar = MakeConstantScalar(up_idtype, (acc_dtype_bits - 1));
-  auto scaled_t = RightShift(rounded_tensor, right_shift_scalar);
-  auto q_imin = get_qmin(idtype);
-  auto q_imax = get_qmax(idtype);
-  auto integer_multiplied_t = Cast(Clip(scaled_t, q_imin, q_imax),
-          idtype);
-  return integer_multiplied_t;
-}
-
-Expr ShiftByIntegerShift(const Expr& multiplied_t,
-    const int& exponent, const RequantizeAttrs*& param,
-    const DataType& idtype, const Array<IndexExpr>& out_shape) {
-  CHECK_GE(exponent, 0);
-  int acc_dtype_bits = idtype.bits();
-  CHECK_LE(exponent, (acc_dtype_bits - 1));
-
-  // We need to perform rounding. The rounding here is closest to the power
-  // of 2. The exponent basically represents the decimal point. We need to round
-  // at the decimal point.
-  auto tensor = multiplied_t;
-  if (exponent != 0) {
-    auto pos_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1)));
-    auto neg_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1)) - 1);
-    auto pos_rounder_t = Full(pos_rounder, out_shape, idtype);
-    auto neg_rounder_t = Full(neg_rounder, out_shape, idtype);
-
-    auto zero = MakeConstantScalar(idtype, 0);
-    auto zero_t = Full(zero, out_shape, idtype);
-    auto round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t,
-            neg_rounder_t);
-    tensor = Add(tensor, round_scalar);
-  }
-
-  // Right shift by exponent to approximate the division.
-  auto scaled_t = RightShift(tensor,
-          MakeConstantScalar(idtype, exponent));
-  return scaled_t;
-}
-
-
 /*
  * Requantization using only integer computation. Here, the computation is
  * converted to a fixed point computation by computing output multiplier and
@@ -147,59 +82,123 @@ Expr ShiftByIntegerShift(const Expr& multiplied_t,
  * multiplication with an int value and then right shifting the result. This
  * approximates the floating point computation with a fixed point computation.
  *
- * The whole computaition this can be broken down into following steps 
+ * The whole computation can be broken down into the following steps:
  * 1) Calculate the integer multiplier and integer shift.
- * 2) Multiply the integer multiplier with quantized tensor.
- * 3) Right shift the result.
+ * 2) Subtract the input zero point.
+ * 3) Multiply the integer fixed point multiplier with the quantized tensor.
+ * 4) Round the result.
+ * 5) Right shift the result.
+ * 6) Add the output_zero_point.
+ * 7) Cast to the out_dtype.
  *
- * The only thing complicating the above computations is the tedious approach of
- * handling rounding.
  */
-Expr RequantizeInt(const Expr& convolved_tensor,
+Expr RequantizeInt(const Expr& input_tensor,
     const RequantizeAttrs*& param, const DataType& idtype,
     const Array<IndexExpr>& out_shape) {
 
   double double_multiplier = param->input_scale/param->output_scale;
+
+  // The multiplication will be performed in higher precision. Find the dtype.
+  int idtype_bits = idtype.bits();
+  DataType up_idtype = Int(2 * idtype_bits);
+
   // 1) Calculating the integer multiplier and integer shift
   int32_t fixed_point_multiplier;
   int shift;
   GetFixedPointMultiplierShift(double_multiplier, &fixed_point_multiplier,
           &shift, idtype);
-
-  // 2) Multiply the integer multiplier
   int left_shift = shift > 0 ? shift : 0;
   int right_shift = shift > 0 ? 0 : -shift;
-  auto multiplied_t = MultiplyByIntegerMuliplier(convolved_tensor,
-          fixed_point_multiplier, left_shift, param, idtype, out_shape);
 
-  // 3) Divide by the denominator or right shift the result.
-  auto scaled_int32_t = ShiftByIntegerShift(multiplied_t,
-          right_shift, param, idtype, out_shape);
+  // 2) Subtract the input_zero_point
+  auto tensor = input_tensor;
+  tensor = Cast(tensor, up_idtype);
+  if (param->input_zero_point != 0) {
+    auto input_zp = MakeConstantScalar(up_idtype, param->input_zero_point);
+    tensor = Subtract(tensor, input_zp);
+  }
 
-  // 4) Clip to the out_dtype min/max.
+  // 3) Multiply the integer multiplier
+  if (left_shift != 0) {
+    tensor = Multiply(tensor, MakeConstantScalar(up_idtype, 1 << left_shift));
+  }
+  // Perform the multiplication in higher precision.
+  // If idtype is Int(32), the scalar is a fixed point value of int32 where the
+  // decimal point is between bits 31 and 30. After multiplying with
+  // input_tensor, the result is in int64 where the decimal point is sitting
+  // between bits 31 and 30 (from the right, rightmost bit is bit 0).
+  Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier);
+  auto multiplied_t = Multiply(tensor, scalar);
+
+  // 4) Find the rounding scalar. This depends on where the final decimal point
+  // sits. As we will be right shifting the multiplied_t, we need to first
+  // calculate the total_right_shift.
+  int total_right_shift = right_shift + idtype_bits - 1;
+
+  tensor = multiplied_t;
+  Expr round_scalar;
+  if (param->rounding_mode == "FE_UPWARD") {
+    auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)));
+    round_scalar = pos_rounder;
+  } else if (param->rounding_mode == "FE_AWAY_FROM_ZERO") {
+    auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)));
+    auto neg_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)) - 1);
+    auto pos_rounder_t = Full(pos_rounder, out_shape, up_idtype);
+    auto neg_rounder_t = Full(neg_rounder, out_shape, up_idtype);
+
+    auto zero = MakeConstantScalar(up_idtype, 0);
+    auto zero_t = Full(zero, out_shape, up_idtype);
+    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t,
+            neg_rounder_t);
+  }
+  // Add the rounding scalar.
+  tensor = Add(tensor, round_scalar);
+
+  // 5) Simply right shift the result to get the final output.
+  auto scaled_int64_t = RightShift(tensor,
+          MakeConstantScalar(up_idtype, total_right_shift));
+
+  // 6) Add the output zero point.
+  auto output_zp = MakeConstantScalar(up_idtype, param->output_zero_point);
+  auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+
+  // 7) Clip to the out_dtype min/max.
+  // Find the right clip min/maxes. While clipping, it is necessary that
+  // clip_min and clip_max are within the dtype range of the input tensor to the
+  // clip operator. For example, if the input to clip operator is int8, but the
+  // out_dtype is uint8, we will get incorrect results if we set max as 255.
   auto q_min = std::max(get_qmin(param->out_dtype), get_qmin(idtype));
   auto q_max = std::min(get_qmax(param->out_dtype), get_qmax(idtype));
-  auto clipped_t = Clip(scaled_int32_t, q_min, q_max);
+  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
   auto requantized_output = Cast(clipped_t, param->out_dtype);
   return requantized_output;
 }
 
-/* 
+
+/*
  * Requantization using floating computation. Here we can multiply the scale to
- * the convolved_tensor, round to nearest integer and then cast back to int32.
+ * the input_tensor, round to nearest integer and then cast back to int32.
  */
-Expr RequantizeFloat(const Expr& convolved_tensor,
+Expr RequantizeFloat(const Expr& input_tensor,
     const RequantizeAttrs*& param, const DataType& idtype,
     const Array<IndexExpr>& out_shape) {
   double double_multiplier = param->input_scale/param->output_scale;
   auto scalar_multiplier = MakeConstantScalar(Float(32), double_multiplier);
-
-  // Multiply the convolved tensor with the new scale.
-  auto casted_t = Cast(convolved_tensor, Float(32));
-  auto multiplied_t = Round(Multiply(casted_t, scalar_multiplier));
+  auto input_zp = MakeConstantScalar(idtype, param->input_zero_point);
+  auto output_zp = MakeConstantScalar(Float(32), param->output_zero_point);
+
+  // Multiply the tensor with the new scale.
+  auto shifted_input_t = Subtract(input_tensor, input_zp);
+  auto casted_t = Cast(shifted_input_t, Float(32));
+  auto multiplied_t = Multiply(casted_t, scalar_multiplier);
+  auto shifted_multiplied_t = Add(output_zp, multiplied_t);
+  auto rounded_t = Round(shifted_multiplied_t);
   auto q_imin = get_qmin(idtype);
   auto q_imax = get_qmax(idtype);
-  auto scaled_int32_t = Cast(Clip(multiplied_t, q_imin, q_imax),
+  auto scaled_int32_t = Cast(Clip(rounded_t, q_imin, q_imax),
           idtype);
 
   // Clip to the out_dtype min/max.
@@ -243,14 +242,6 @@ Expr RequantizeForwardRewrite(const Call& ref_call,
       << " Please run infer_type pass.";
   const auto input_dtype = input_tt->dtype;
 
-  // Check for current quantization support.
-  CHECK_EQ(param->input_zero_point, 0)
-      << "Encountered non-zero zero point."
-      << " Only symmetric quantization supported for now.";
-  CHECK_EQ(param->output_zero_point, 0)
-      << "Encountered non-zero zero point."
-      << " Only symmetric quantization supported for now.";
-
   if (param->use_int_compute) {
     return RequantizeInt(quantized_data, param, input_dtype, out_shape);
   } else {
@@ -258,18 +249,14 @@ Expr RequantizeForwardRewrite(const Call& ref_call,
   }
 }
 
-
 RELAY_REGISTER_OP("qnn.requantize")
 .set_attr<FForwardRewrite>("FQuantizeForwardRewrite", RequantizeForwardRewrite);
 
-
-
 TVM_REGISTER_API("relay._quantize.rewrite")
 .set_body_typed<Expr(Expr)>([](const Expr& e) {
-  Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr);
-  return ret;
-});
-
+  Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr);
+  return ret;
+});
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/unittest/test_quantized_ops.py b/tests/python/unittest/test_quantized_ops.py
new file mode 100644
index 0000000000000..e70ea09252313
--- /dev/null
+++ b/tests/python/unittest/test_quantized_ops.py
@@ -0,0 +1,257 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay.testing import create_workload
+from tvm.contrib import graph_runtime
+
+rounding_modes = ["FE_UPWARD", "FE_AWAY_FROM_ZERO"]
+
+def run_infer_type(expr):
+    mod = relay.Module.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def test_requantize():
+    def verify(func, goldens):
+        with relay.build_config(opt_level=0):
+            graph, lib, params = relay.build(func, "llvm", params=None)
+            golden_data, golden_output = goldens
+            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            mod.set_input("quantized_data", golden_data)
+            mod.set_input(**params)
+            mod.run()
+            res = mod.get_output(0).asnumpy()
+            np.testing.assert_equal(res, golden_output)
+
+    def get_func(data_shape, data_dtype, out_dtype, use_int_compute,
+            rounding_mode, input_scale, output_scale, input_zero_point=0,
+            output_zero_point=0):
+        quantized_data = relay.var("quantized_data", shape=data_shape,
+                dtype=data_dtype)
+        func = relay.op.qnn.requantize(
+                quantized_data,
+                input_zero_point=input_zero_point,
+                output_zero_point=output_zero_point,
+                input_scale=input_scale,
+                output_scale=output_scale,
+                rounding_mode=rounding_mode,
+                out_dtype=out_dtype,
+                use_int_compute=use_int_compute)
+
+        func = relay.Function(relay.analysis.free_vars(func),
+                func)
+        func = run_infer_type(func)
+        func = relay.quantize.rewrite(func)
+        return func
+
+
+    def run_tests():
+        def same_scale_test():
+            # Have same scales, everything within range
+            golden_data = np.arange(-100, 100, 1).astype('int32')
+            golden_output = golden_data
+
+            for rounding_mode in rounding_modes:
+                for use_int_compute in [True, False]:
+                    func = get_func(data_shape=(200, ),
+                                    data_dtype='int32',
+                                    out_dtype="int8",
+                                    use_int_compute=use_int_compute,
+                                    rounding_mode=rounding_mode,
+                                    input_scale=0.5,
+                                    output_scale=0.5)
+                    verify(func, (golden_data, golden_output))
+
+        def downscale_test():
+            for rounding_mode in rounding_modes:
+                for use_int_compute in [True, False]:
+                    func = get_func(data_shape=(32, ),
+                                    data_dtype='int32',
+                                    out_dtype="int32",
+                                    use_int_compute=use_int_compute,
+                                    rounding_mode=rounding_mode,
+                                    input_scale=1,
+                                    output_scale=16)
+
+                    # Try positive values
+                    # 8 corresponds to 0.5, resulting in 1
+                    golden_data = np.arange(0, 32, 1).astype('int32')
+                    golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+                    verify(func, (golden_data, golden_output))
+
+                    # Try negative values
+                    # -8 corresponds to -0.5. For FE_UPWARD, this is 0
+                    golden_data = np.arange(0, -32, -1).astype('int32')
+                    if use_int_compute and rounding_mode == "FE_UPWARD":
+                        golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+                    else:
+                        golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+                    verify(func, (golden_data, golden_output))
+
+                # Try a different scale
+                for use_int_compute in [True, False]:
+                    func = get_func(data_shape=(32, ),
+                                    data_dtype='int32',
+                                    out_dtype="int8",
+                                    use_int_compute=use_int_compute,
+                                    rounding_mode=rounding_mode,
+                                    input_scale=1,
+                                    output_scale=4)
+
+                    # Try positive values
+                    # 2 corresponds to 0.5, resulting in 1
+                    golden_data = np.arange(0, 32, 1).astype('int32')
+                    golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
+                                              [2, 4, 4, 4, 4, 4, 4, 4, 2])
+                    verify(func, (golden_data, golden_output))
+
+                    # Try negative values
+                    # -2 corresponds to -0.5. For FE_UPWARD, this is 0
+                    golden_data = np.arange(0, -32, -1).astype('int32')
+                    if use_int_compute and rounding_mode == "FE_UPWARD":
+                        golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                                  [3, 4, 4, 4, 4, 4, 4, 4, 1])
+                    else:
+                        golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                                  [2, 4, 4, 4, 4, 4, 4, 4, 2])
+                    verify(func, (golden_data, golden_output))
+
+        def upscale_test():
+            for rounding_mode in rounding_modes:
+                for use_int_compute in [True, False]:
+                    func = get_func(data_shape=(32, ),
+                                    data_dtype='int32',
+                                    out_dtype="int8",
+                                    use_int_compute=use_int_compute,
+                                    rounding_mode=rounding_mode,
+                                    input_scale=2,
+                                    output_scale=1)
+
+                    # Try positive values
+                    # Multiplier is 2, so the output is simply twice the input
+                    golden_data = np.arange(0, 32, 1).astype('int32')
+                    golden_output = np.multiply(2, golden_data)
+                    verify(func, (golden_data, golden_output))
+
+                    # Try negative values
+                    # Again the output is exactly twice the input; no rounding is involved
+                    golden_data = np.arange(0, -32, -1).astype('int32')
+                    golden_output = np.multiply(2, golden_data)
+                    verify(func, (golden_data, golden_output))
+
+        def saturation_test():
+            for rounding_mode in rounding_modes:
+                for use_int_compute in [True, False]:
+                    func = get_func(data_shape=(16, ),
+                                    data_dtype='int32',
+                                    out_dtype="int8",
+                                    use_int_compute=use_int_compute,
+                                    rounding_mode=rounding_mode,
+                                    input_scale=0.5,
+                                    output_scale=0.5)
+                    golden_data = np.arange(0, 16, 1).astype('int32')
+                    golden_data = np.add(120, golden_data)
+                    output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
+                                       127, 127, 127, 127, 127, 127, 127, 127])
+                    golden_output = output
+                    verify(func, (golden_data, golden_output))
+
+                    # Try negative numbers
+                    golden_data = np.arange(0, -16, -1).astype('int32')
+                    golden_data = np.add(-120, golden_data)
+                    output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
+                                       -128, -128, -128, -128, -128, -128, -128, -128])
+                    golden_output = output
+                    verify(func, (golden_data, golden_output))
+
+        def zero_point_test():
+            # Output zero point
+            for rounding_mode in rounding_modes:
+                for use_int_compute in [True, False]:
+                    func = get_func(data_shape=(32, ),
+                                    data_dtype='int32',
+                                    out_dtype="int32",
+                                    use_int_compute=use_int_compute,
+                                    rounding_mode=rounding_mode,
+                                    input_scale=1,
+                                    output_scale=16,
+                                    output_zero_point=1)
+
+                    # Try positive values
+                    # 8 corresponds to 0.5, resulting in 1
+                    golden_data = np.arange(0, 32, 1).astype('int32')
+                    golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+                    golden_output = np.add(1, golden_output)
+                    verify(func, (golden_data, golden_output))
+
+                    # Try negative values
+                    # -40 corresponds to -2.5. For FE_UPWARD, this rounds to -2
+                    golden_data = np.arange(-32, -64, -1).astype('int32')
+                    if use_int_compute and rounding_mode == "FE_UPWARD":
+                        golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+                    else:
+                        golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+                    golden_output = np.add(1, golden_output)
+                    verify(func, (golden_data, golden_output))
+
+            # Input zero point
+            for rounding_mode in rounding_modes:
+                for use_int_compute in [True, False]:
+                    func = get_func(data_shape=(32, ),
+                                    data_dtype='int32',
+                                    out_dtype="int32",
+                                    use_int_compute=use_int_compute,
+                                    rounding_mode=rounding_mode,
+                                    input_scale=1,
+                                    output_scale=16,
+                                    input_zero_point=16)
+
+                    # Try positive values
+                    golden_data = np.arange(32, 64, 1).astype('int32')
+                    golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+                    golden_output = np.subtract(golden_output, 1)
+                    verify(func, (golden_data, golden_output))
+
+                    # Try negative values
+                    golden_data = np.arange(-32, -64, -1).astype('int32')
+                    if use_int_compute and rounding_mode == "FE_UPWARD":
+                        golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+                    else:
+                        golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+                    golden_output = np.subtract(golden_output, 1)
+                    verify(func, (golden_data, golden_output))
+
+
+        same_scale_test()
+        downscale_test()
+        upscale_test()
+        saturation_test()
+        zero_point_test()
+
+    run_tests()
+
+if __name__ == "__main__":
+    test_requantize()