diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 58707f4da19c..66029c8d80f9 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -135,6 +135,10 @@ struct QnnConv2DAttrs : public tvm::AttrsNode { // Quantization related attributes. int32_t input_zero_point; int32_t kernel_zero_point; + // The input tensor scale and kernel tensor scales are stored + // for easy access to this information. + double input_scale; + double kernel_scale; TVM_DECLARE_ATTRS(QnnConv2DAttrs, "relay.attrs.QnnConv2DAttrs") { TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) @@ -177,6 +181,10 @@ struct QnnConv2DAttrs : public tvm::AttrsNode { .describe("The zero point of the input tensor."); TVM_ATTR_FIELD(kernel_zero_point) .describe("The zero point of the kernel tensor."); + TVM_ATTR_FIELD(input_scale) + .describe("The quantization scale for the input tensor."); + TVM_ATTR_FIELD(kernel_scale) + .describe("The quantization scale for the weight tensor."); } }; @@ -212,6 +220,8 @@ struct QnnDenseAttrs : public tvm::AttrsNode { // Quantization related attributes. int32_t input_zero_point; int32_t kernel_zero_point; + double input_scale; + double kernel_scale; TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") { TVM_ATTR_FIELD(units) @@ -222,6 +232,10 @@ struct QnnDenseAttrs : public tvm::AttrsNode { .describe("The zero point of the input tensor."); TVM_ATTR_FIELD(kernel_zero_point) .describe("The zero point of the kernel tensor."); + TVM_ATTR_FIELD(input_scale) + .describe("The input tensor scale."); + TVM_ATTR_FIELD(kernel_scale) + .describe("The kernel tensor scale."); } }; diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 8966aa6b389e..3fb39e3d387c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -730,9 +730,13 @@ def convert_fully_connected(self, op): weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) if input_tensor.qnn_params: + input_scale = input_tensor.qnn_params['scale'] + kernel_scale = weight_tensor.qnn_params['scale'] out = _qnn.op.dense(in_expr, weight_expr, input_zero_point=input_tensor.qnn_params['zero_point'], kernel_zero_point=weight_tensor.qnn_params['zero_point'], + input_scale=input_scale, + kernel_scale=kernel_scale, out_dtype='int32') else: out = _op.nn.dense(in_expr, weight_expr) @@ -936,6 +940,8 @@ def convert_conv(self, op, conv_type): qnn_conv2d_params['input_zero_point'] = input_tensor.qnn_params['zero_point'] qnn_conv2d_params['kernel_zero_point'] = weight_tensor.qnn_params['zero_point'] qnn_conv2d_params['out_dtype'] = 'int32' + qnn_conv2d_params['input_scale'] = input_tensor.qnn_params['scale'] + qnn_conv2d_params['kernel_scale'] = weight_tensor.qnn_params['scale'] out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params) else: out = _op.nn.conv2d(in_expr, weight_expr, **params) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 3f94f98a03a5..1bb28d8a48c8 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -88,6 +88,8 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): new_attrs = {k : attrs[k] for k in attrs.keys()} del new_attrs['kernel_zero_point'] del new_attrs['input_zero_point'] + del new_attrs['input_scale'] + del new_attrs['kernel_scale'] return relay_op(shift_data, shift_kernel, **new_attrs) # Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting. diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 7faf62b4be14..c79c99c80d44 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -189,6 +189,8 @@ def conv2d(data, kernel, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, strides=(1, 1), padding=(0, 0), dilation=(1, 1), @@ -219,6 +221,16 @@ def conv2d(data, input_zero_point: int The zero point of the data distribution. + input_scale: float + The scale for the input tensor. The scale for the input tensor is + stored purely for convenience here. See more commentary below. + + kernel_scale: float + The scale for the weight tensor. The scale for the weight tensor is + stored for access to this during relay. This information is not + needed in the pass pipeline after qnn.conv2d is lowered to the + sequence of steps as in nn.conv2d. See also input_scale in Requantize. + kernel_zero_point: int The zero point of the quantized_kernel distribution. @@ -260,6 +272,7 @@ def conv2d(data, return _make.conv2d(data, kernel, input_zero_point, kernel_zero_point, + input_scale, kernel_scale, strides, padding, dilation, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) @@ -317,6 +330,8 @@ def dense(data, weight, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, units=None, out_dtype="int32"): """Qnn Dense operator. @@ -332,6 +347,17 @@ def dense(data, The quantized input data to the operator. weight : tvm.relay.Expr The quantized weight expressions. + input_zero_point: int + The input zero point. + kernel_zero_point: int + The kernel zero point. + input_scale: float + The scale for the input tensor. + kernel_scale: float + The scale for the weight tensor. The scale for the weight tensor is + stored for access to this during relay. This information is not + needed in the pass pipeline after qnn.conv2d is lowered to the + sequence of steps as in nn.conv2d. See also input_scale in Requantize. units : int, optional Number of hidden units of the dense transformation. out_dtype : str, optional @@ -345,9 +371,11 @@ def dense(data, return _make.dense(data, weight, - units, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, + units, out_dtype) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 75cac5d0f120..9cbb415e6131 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -440,7 +440,8 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, // Positional relay function to create quantized conv2d operator // used by frontend FFI. Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t kernel_zero_point, - Array strides, Array padding, Array dilation, + double input_scale, double kernel_scale, Array strides, + Array padding, Array dilation, int groups, IndexExpr channels, Array kernel_size, std::string data_layout, std::string kernel_layout, std::string out_layout, DataType out_dtype) { @@ -457,6 +458,8 @@ Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t ker attrs->out_dtype = std::move(out_dtype); attrs->input_zero_point = std::move(input_zero_point); attrs->kernel_zero_point = std::move(kernel_zero_point); + attrs->input_scale = std::move(input_scale); + attrs->kernel_scale = std::move(kernel_scale); static const Op& op = Op::Get("qnn.conv2d"); return CallNode::make(op, {data, weight}, Attrs(attrs), {}); } diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 9ac01212dc37..fdf50d2b9141 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -57,13 +57,17 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, } // Positional relay function to create quantized dense operator used by frontend FFI. -Expr MakeQuantizedDense(Expr data, Expr weight, IndexExpr units, int32_t input_zero_point, - int32_t kernel_zero_point, DataType out_dtype) { +Expr MakeQuantizedDense(Expr data, Expr weight, int32_t input_zero_point, + int32_t kernel_zero_point, double input_scale, + double kernel_scale, IndexExpr units, + DataType out_dtype) { auto attrs = make_node(); attrs->units = std::move(units); attrs->out_dtype = out_dtype; attrs->input_zero_point = input_zero_point; attrs->kernel_zero_point = kernel_zero_point; + attrs->input_scale = input_scale; + attrs->kernel_scale = kernel_scale; static const Op& op = Op::Get("qnn.dense"); return CallNode::make(op, {data, weight}, Attrs(attrs), {}); } diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 71368f84d023..165e0e720f6a 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -26,6 +26,8 @@ def get_ref_func(data, kernel, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, kernel_size, padding, strides, @@ -56,6 +58,8 @@ def get_qnn_func(data, kernel, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, kernel_size, padding, strides, @@ -67,6 +71,8 @@ def get_qnn_func(data, data, kernel, input_zero_point=input_zero_point, kernel_zero_point=kernel_zero_point, + input_scale=input_scale, + kernel_scale=kernel_scale, kernel_size=kernel_size, strides=strides, dilation=dilation, @@ -85,6 +91,8 @@ def get_funcs(data_shape, kernel_dtype, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, kernel_size, padding, strides, @@ -100,6 +108,8 @@ def get_funcs(data_shape, kernel, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, kernel_size, padding, strides, @@ -112,6 +122,8 @@ def get_funcs(data_shape, kernel, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, kernel_size, padding, strides, @@ -172,6 +184,8 @@ def test_no_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=0, kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -193,6 +207,8 @@ def test_no_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=0, kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -215,6 +231,8 @@ def test_kernel_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=0, kernel_zero_point=1, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -236,6 +254,8 @@ def test_kernel_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=0, kernel_zero_point=5, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -259,6 +279,8 @@ def test_input_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -280,6 +302,8 @@ def test_input_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -302,6 +326,8 @@ def test_both_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -323,6 +349,8 @@ def test_both_zero_point(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -345,6 +373,8 @@ def test_layout(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -366,6 +396,8 @@ def test_layout(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -390,6 +422,8 @@ def test_padding(): kernel_dtype=kernel_dtype, input_zero_point=8, kernel_zero_point=5, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(1, 1), strides=(1, 1), @@ -411,6 +445,8 @@ def test_padding(): kernel_dtype=kernel_dtype, input_zero_point=8, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(1, 1), strides=(1, 1), @@ -433,6 +469,8 @@ def test_dilation(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 1), @@ -460,6 +498,8 @@ def test_const_folding(): input_zero_point=8, kernel_zero_point=3, kernel_size=(2, 2), + input_scale=1.0, + kernel_scale=1.0, padding=(0, 0), strides=(1, 1), dilation=(1, 1), @@ -482,6 +522,8 @@ def test_kernel_size_1x1(): kernel_dtype=kernel_dtype, input_zero_point=5, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1), @@ -505,6 +547,8 @@ def test_tflite_large_irregular(): kernel_dtype=kernel_dtype, input_zero_point=127, kernel_zero_point=127, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1), @@ -536,6 +580,8 @@ def test_tflite_output_multiplier_greater_than_one(): data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, + input_scale=1.0, + kernel_scale=1.0, input_zero_point=128, kernel_zero_point=128, kernel_size=(2, 2), @@ -582,6 +628,8 @@ def test_tflite_anistropic_strides(): kernel_dtype=kernel_dtype, input_zero_point=127, kernel_zero_point=127, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(2, 2), padding=(0, 0), strides=(1, 3), @@ -619,6 +667,8 @@ def test_broadcast_layout(): kernel_dtype=kernel_dtype, input_zero_point=8, kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(7, 7), padding=(1, 1), strides=(1, 1), diff --git a/tests/python/relay/test_op_qnn_dense.py b/tests/python/relay/test_op_qnn_dense.py index 885fe37d29bb..103399d39f9d 100644 --- a/tests/python/relay/test_op_qnn_dense.py +++ b/tests/python/relay/test_op_qnn_dense.py @@ -31,13 +31,15 @@ def make_requantize_params(input_scale, output_scale, output_zero_point, out_dty return config -def make_configuration(quantized_data, +def make_configuration(quantized_data, quantized_kernel, dtype, input_shape, kernel_shape, input_zero_point, kernel_zero_point, + input_scale, + kernel_scale, units, output, out_dtype='int32', @@ -53,6 +55,8 @@ def make_configuration(quantized_data, 'kernel_shape': kernel_shape, 'input_zero_point': input_zero_point, 'kernel_zero_point': kernel_zero_point, + 'input_scale': input_scale, + 'kernel_scale': kernel_scale, 'units': units, 'output': output, 'out_dtype': out_dtype, @@ -65,6 +69,9 @@ def make_configuration(quantized_data, def make_uint_configuration(use_bias=False, requantize_output=False): input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) input_zero_point, kernel_zero_point = 127, 127 + input_scale = 0.5 + kernel_scale = 0.5 + output_scale = 1.0 in_dtype = 'uint8' out_dtype = 'int32' if not requantize_output else 'uint8' units = 3 @@ -78,7 +85,7 @@ def make_uint_configuration(use_bias=False, requantize_output=False): .astype(in_dtype) \ .reshape(kernel_shape) bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None - requant_params = make_requantize_params(0.25, 1.0, 127, 'uint8') if requantize_output else None + requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, 127, 'uint8') if requantize_output else None if requantize_output: assert use_bias @@ -95,6 +102,8 @@ def make_uint_configuration(use_bias=False, requantize_output=False): kernel_shape=kernel_shape, input_zero_point=input_zero_point, kernel_zero_point=kernel_zero_point, + input_scale=input_scale, + kernel_scale= kernel_scale, units=units, output=output, bias=bias, @@ -116,8 +125,11 @@ def make_int_configuration(use_bias=False, requantize_output=False): 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]) \ .astype(in_dtype) \ .reshape(kernel_shape) + input_scale = 0.5 + kernel_scale = 0.5 + output_scale = 1.0 bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None - requant_params = make_requantize_params(0.25, 1.0, -1, 'int8') if requantize_output else None + requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, -1, 'int8') if requantize_output else None if requantize_output: assert use_bias @@ -134,6 +146,8 @@ def make_int_configuration(use_bias=False, requantize_output=False): kernel_shape=kernel_shape, input_zero_point=input_zero_point, kernel_zero_point=kernel_zero_point, + input_scale=input_scale, + kernel_scale=kernel_scale, units=units, output=output, bias=bias, @@ -158,6 +172,8 @@ def qnn_dense_driver(test_configuration): quantized_kernel, test_configuration['input_zero_point'], test_configuration['kernel_zero_point'], + test_configuration['input_scale'], + test_configuration['kernel_scale'], test_configuration['units']) if test_configuration[bias_name] is not None: bias = relay.var(bias_name, diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 8ace7bc745ea..6e092fad6fc4 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -103,6 +103,8 @@ def _get_mod(data_dtype, kernel_dtype): data, kernel, input_zero_point=1, kernel_zero_point=1, + input_scale=1.0, + kernel_scale=1.0, kernel_size=(3, 3), strides=(1, 1), dilation=(1, 1), @@ -185,6 +187,8 @@ def _get_mod(data_dtype, kernel_dtype): data, kernel, input_zero_point=1, kernel_zero_point=1, + input_scale=1, + kernel_scale=1, out_dtype='int32') mod = relay.Function(relay.analysis.free_vars(func), func)