[TFLite] QNN support for TFLite 2.1.0 quantized models #5848

Merged: 9 commits, Jul 3, 2020
195 changes: 150 additions & 45 deletions python/tvm/relay/frontend/tflite.py
@@ -244,42 +244,92 @@ def get_tensors(self, tensors_idx_list):
qnn_params = None
tflite_qnn_params = tensor.Quantization()
if tflite_qnn_params is not None:
scale = float(tflite_qnn_params.ScaleAsNumpy())
zero_point = int(tflite_qnn_params.ZeroPointAsNumpy())
# TFLite supports both per-tensor and per-axis (aka channel) quantization. For
# per-tensor quantization, scale and zero points are scalar values. For per-axis
# quantization, scale and zero points for the weights are tensors (activations are
# per-tensor quantized). However, the TFLite quantization spec puts restrictions on
# zero points for per-axis quantization. Specifically, the zero point is a tensor
# but all values are 0. More information can be found here -
# https://www.tensorflow.org/lite/performance/quantization_spec

tflite_scale = tflite_qnn_params.ScaleAsNumpy()
tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy()
is_qnn_params_valid = True

# Handle Per-axis and per-tensor cases
if isinstance(tflite_scale, np.ndarray):
assert isinstance(tflite_zero_point, np.ndarray)

# Tensor - Per-axis quantization
if tflite_scale.size != 1 and tflite_zero_point.size != 1:
scale = tflite_scale
# Ensure that all zero points are zeros
zero_point = tflite_zero_point
if not np.all(zero_point == 0):
raise tvm.error.OpAttributeInvalid(\
"TFLite per-axis quantization restricts all zero points to be"
+ " 0, but a non-zero value is observed")
zero_point = int(zero_point[0])

# Scalar - Per-tensor quantization
elif tflite_scale.size == 1 and tflite_zero_point.size == 1:
scale = float(tflite_scale[0])
zero_point = int(tflite_zero_point[0])

anijain2305 marked this conversation as resolved.
else:
raise NotImplementedError(\
"Quantized type {} (scale) and {} (zero point) not supported"
.format(type(tflite_scale), type(tflite_zero_point)))
elif tflite_scale == 0 and tflite_zero_point == 0:
# Handle corner case for ops like quantized reshape whose second operand (shape)
# has a scale of zero and a zero point of zero. These quantization params are unused.
is_qnn_params_valid = False
else:
raise NotImplementedError("Quantized type {} not supported"
.format(type(tflite_scale)))

# Check that the scale and zero points are valid.
if scale != 0 or zero_point != 0:
if is_qnn_params_valid:
qnn_params = dict()
qnn_params['scale'] = relay.const(scale, 'float32')
qnn_params['zero_point'] = relay.const(zero_point, 'int32')
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
return return_list
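
The per-axis/per-tensor dispatch above can be exercised on its own. A minimal numpy sketch (illustrative only, not the TVM API), assuming scale and zero_point arrive shaped the way ScaleAsNumpy()/ZeroPointAsNumpy() return them:

import numpy as np

def classify_quantization(scale, zero_point):
    # Mirrors the dispatch in get_tensors: per-axis, per-tensor, or unused.
    if isinstance(scale, np.ndarray):
        if scale.size != 1 and zero_point.size != 1:
            # Per-axis: the TFLite quantization spec pins every zero point to 0.
            if not np.all(zero_point == 0):
                raise ValueError("per-axis zero points must all be 0")
            return "per-axis", scale, int(zero_point[0])
        if scale.size == 1 and zero_point.size == 1:
            return "per-tensor", float(scale[0]), int(zero_point[0])
        raise NotImplementedError("mismatched scale/zero point sizes")
    if scale == 0 and zero_point == 0:
        # e.g. the shape operand of a quantized reshape; the params are unused
        return "unquantized", None, None
    raise NotImplementedError("Quantized type {} not supported".format(type(scale)))

print(classify_quantization(np.array([0.5, 0.25]), np.array([0, 0])))  # per-axis
print(classify_quantization(np.array([0.1]), np.array([128])))         # per-tensor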

def get_tensor_value(self, tensor_wrapper):
"""Get tensor buffer value from given tensor wrapper"""

def get_tensor_type_as_numpy(self, tensor_wrapper):
"""Returns np.dtype out of TensorType"""
assert isinstance(tensor_wrapper, TensorWrapper)

try:
from tflite.TensorType import TensorType
return {TensorType.UINT8: np.uint8,
TensorType.INT8: np.int8,
TensorType.FLOAT32: np.float32,
TensorType.INT32: np.int32,
TensorType.INT64: np.int64,
TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()]
except ImportError:
raise ImportError("The tflite package must be installed")
except KeyError:
raise NotImplementedError("Tensor type '{}' currently not supported"
.format(tensor_wrapper.tensor.Type()))


def get_tensor_value(self, tensor_wrapper):
"""Get tensor buffer value from given tensor wrapper"""
assert isinstance(tensor_wrapper, TensorWrapper)

dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
data = tensor_wrapper.buffer.DataAsNumpy()

if tensor_wrapper.tensor.ShapeLength() != 0:
shape = tensor_wrapper.tensor.ShapeAsNumpy()
else:
shape = []

return np.frombuffer(data, dtype=dtype).reshape(shape)

if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.INT64:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.BOOL:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type())))
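
The refactor above replaces one frombuffer/reshape branch per tensor type with a single dtype table. A standalone sketch of the pattern (enum codes assumed from the TFLite schema, e.g. FLOAT32=0, INT32=2, UINT8=3):

import numpy as np

# Dtype table keyed by TFLite TensorType enum code (schema values assumed:
# FLOAT32=0, INT32=2, UINT8=3, INT64=4, BOOL=6, INT8=9).
DTYPE_TABLE = {0: np.float32, 2: np.int32, 3: np.uint8, 4: np.int64, 6: np.bool_, 9: np.int8}

def decode_tensor(buf, type_code, shape):
    try:
        dtype = DTYPE_TABLE[type_code]
    except KeyError:
        raise NotImplementedError("Tensor type '{}' currently not supported".format(type_code))
    return np.frombuffer(buf, dtype=dtype).reshape(shape)

print(decode_tensor(np.arange(6, dtype=np.int32).tobytes(), 2, [2, 3]))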

def get_tensor_type_str(self, tensor_type):
"""Get tensor type string representation when given TFLite tensor type"""
@@ -651,12 +701,43 @@ def convert_shape(self, op):

def convert_relu(self, op):
"""Convert TFLite ReLU"""
try:
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

Comment on lines +704 to +708

Contributor: I think this is unnecessary given the import of ActivationFunctionType in the constructor here.

Contributor Author: Tried this, but it failed: the scope of imports is limited to the functions in which they are imported (see the sketch after this function).

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = _op.nn.relu(in_expr)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

if input_tensor.qnn_params:
# Quantize a float value to a quantized integer value
scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])

output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
out = self.convert_qnn_fused_activation_function(\
expr=in_expr,
fused_activation_fn=ActivationFunctionType.RELU,
scale=scale_val,
zero_point=zero_point_val,
dtype=output_tensor_type_str)
else:
out = _op.nn.relu(in_expr)

if output_tensor.qnn_params:
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
out = _qnn.op.requantize(out,
input_scale=input_tensor.qnn_params['scale'],
input_zero_point=input_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

return out
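
On the import-scope point in the thread above, here is a minimal, self-contained sketch (using a standard-library module purely for illustration) of why an import made inside __init__ is invisible to other methods unless stashed on the instance:

class Demo:
    def __init__(self):
        import math  # this binding is local to __init__'s namespace
        self.math = math  # stash it to make the module reachable elsewhere

    def area(self, r):
        # The bare name `math` would raise NameError here: function-level
        # imports never create class- or module-level names, which is why
        # tflite.py re-imports ActivationFunctionType in each convert_* method.
        return self.math.pi * r * r

print(Demo().area(1.0))  # 3.141592653589793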

@@ -692,6 +773,11 @@ def _hard_swish(data):

def convert_relu6(self, op):
"""Convert TFLite ReLU6"""
try:
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

Contributor: Same as relu, I think this is unnecessary given the import of ActivationFunctionType in the constructor here.

Contributor Author: Same as before.

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
@@ -705,17 +791,14 @@ def convert_relu6(self, op):
# Quantize a float value to a quantized integer value
scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)

# Get min/max of the input dtype. This will be used to ensure that
# clip a_min/a_max are not beyond the dtype range.
input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value)
qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value)

out = _op.clip(in_expr,
a_min=max(qmin, quantize(0)),
a_max=min(qmax, quantize(6.0)))
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
out = self.convert_qnn_fused_activation_function(\
expr=in_expr,
fused_activation_fn=ActivationFunctionType.RELU6,
scale=scale_val,
zero_point=zero_point_val,
dtype=output_tensor_type_str)
else:
out = _op.clip(in_expr, a_min=0, a_max=6)
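
For reference, the clip bounds the shared helper computes for a quantized ReLU6 follow the same scheme as the removed inline code. A minimal sketch, with illustrative scale/zero-point values:

import numpy as np

def quantized_relu6_bounds(scale, zero_point, dtype=np.uint8):
    # Map the float interval [0, 6] into the quantized domain, then clamp
    # to the dtype's representable range (mirrors the removed qmin/qmax logic).
    info = np.iinfo(dtype)
    quantize = lambda x: int(round(x / scale)) + zero_point
    return max(info.min, quantize(0.0)), min(info.max, quantize(6.0))

print(quantized_relu6_bounds(6.0 / 255, 0))  # (0, 255)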

@@ -1604,9 +1687,9 @@ def convert_fully_connected(self, op):
fully_connected_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = fully_connected_options.FusedActivationFunction()

# weight tensor type should be UINT8 (quantization) or FLOAT32
# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

if self.has_expr(weight_tensor.tensor_idx):
@@ -1795,9 +1878,9 @@ def convert_conv(self, op, conv_type):
params['channels'] = int(output_channels)
params['kernel_layout'] = 'HWIO'

# weight tensor type should be UINT8 (quantization) or FLOAT32
# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

in_expr = self.get_expr(input_tensor_idx)
@@ -1856,9 +1939,15 @@ def convert_conv(self, op, conv_type):
if output_tensor.qnn_params:
# Calculate the intermediate scale and zero point of the int32 output.
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_scalar_from_constant(weight_scale)

weight_scale = weight_tensor.qnn_params['scale']
# If the weight scale is a scalar, it is per-tensor quantization;
# otherwise it is per-axis (per-channel) quantization
if isinstance(weight_scale, float):
weight_scale_val = get_scalar_from_constant(weight_scale)
else:
weight_scale_val = get_tensor_from_constant(weight_scale)

new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')
@@ -1869,7 +1958,8 @@
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
out_dtype=output_tensor_type_str,
axis=3)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
@@ -1882,7 +1972,6 @@
dtype=output_tensor_type_str)
else:
out = self.convert_fused_activation_function(out, fused_activation_fn)

return out
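
In the per-channel branch above, a scalar input scale multiplies a vector of weight scales, producing one requantize scale per output channel (hence axis=3 for NHWC output). A small numpy sketch with illustrative values:

import numpy as np

data_scale_val = 0.5                                           # per-tensor input scale
weight_scale_val = np.array([0.1, 0.2, 0.4], dtype="float32")  # per-channel weight scales

# Broadcasting the scalar over the vector yields the int32 accumulator's
# per-channel scale, which qnn.requantize consumes along the channel axis.
new_input_scale_val = data_scale_val * weight_scale_val
print(new_input_scale_val)  # [0.05 0.1  0.2 ]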

def convert_split(self, op):
@@ -2594,17 +2683,27 @@ def convert_quantize(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())

# The output must be quantized
assert output_tensor.qnn_params
# Quantize the input
out = self.quantize(in_expr, output_tensor)

# TFLite Quantize op can also act as Requantize op
if input_tensor_type_str == "float32":
out = self.quantize(in_expr, output_tensor)
else:
out = _qnn.op.requantize(in_expr,
input_scale=input_tensor.qnn_params['scale'],
input_zero_point=input_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
return out
Comment on lines 2683 to 2707

Contributor: This looks like it could go in as a separate PR in its own right, but it needs a unit test change in tflite/test_forward.py.

Contributor Author: You are right. I will add a test case in this PR. This will enable us to keep those 5 end-to-end tests as well.

Contributor Author: [screenshot] Above test case added to force both types of quantize nodes.
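
The requantize path above performs the standard rescaling q_out = round((q_in - zp_in) * s_in / s_out) + zp_out. A minimal per-tensor sketch of that arithmetic (values illustrative):

import numpy as np

def requantize(q_in, s_in, zp_in, s_out, zp_out, dtype=np.int8):
    # Recover the real value, re-express it on the output scale, clamp to dtype.
    real = (q_in.astype(np.float32) - zp_in) * s_in
    q = np.round(real / s_out).astype(np.int64) + zp_out
    info = np.iinfo(dtype)
    return np.clip(q, info.min, info.max).astype(dtype)

x = np.array([-128, 0, 127], dtype=np.int8)
print(requantize(x, s_in=0.05, zp_in=0, s_out=0.1, zp_out=10))  # [-54  10  74]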


def convert_dequantize(self, op):
@@ -2734,7 +2833,6 @@ def get_tensor_expr(self, tensor):
else:
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

return expr


@@ -2747,6 +2845,13 @@ def get_scalar_from_constant(expr):
"value must be float32/int32"
return np.asscalar(value)

def get_tensor_from_constant(expr):
""" Returns tensor of values from Relay constant node. """
assert isinstance(expr, _expr.Constant)
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return value
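
A quick usage sketch of what the two helpers extract, assuming a TVM build of this era where Constant.data.asnumpy() is available (np.asscalar is deprecated in newer numpy but is what the code above uses):

import numpy as np
from tvm import relay

scalar_const = relay.const(0.5, 'float32')                       # per-tensor scale
vector_const = relay.const(np.array([0.1, 0.2], dtype='float32'))  # per-channel scales

print(np.asscalar(scalar_const.data.asnumpy()))  # 0.5, get_scalar_from_constant's path
print(vector_const.data.asnumpy())               # [0.1 0.2], get_tensor_from_constant's path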

def build_str_map(obj):
"""Build string map of TFLite enum int value
16 changes: 13 additions & 3 deletions src/relay/qnn/op/convolution.cc
@@ -61,9 +61,19 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
// Kernel scale can be a vector of length output_channels or a scalar.
size_t axis = param->kernel_layout.find('O');
CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale
if (param->groups == 1) {
size_t axis = param->kernel_layout.find('O');
CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale
} else {
// Here, the total number of output channels depends on the depth multiplier.
size_t o_axis = param->kernel_layout.find('O');
size_t i_axis = param->kernel_layout.find('I');
CHECK(o_axis != std::string::npos || i_axis != std::string::npos)
<< "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
reporter); // kernel scale
}

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// Conv2D infer type function.
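
On the C++ change above: for grouped (depthwise) conv2d, the kernel-scale vector must cover in_channels * depth_multiplier output channels, i.e. the product of the kernel's I and O extents. A small Python sketch of the expected-size computation (illustrative shapes, assuming the HWOI layout the TFLite frontend uses for depthwise weights):

def expected_kernel_scale_size(weight_shape, kernel_layout, groups):
    # Number of per-channel kernel scales qnn.conv2d type inference expects.
    o_axis = kernel_layout.find('O')
    i_axis = kernel_layout.find('I')
    assert o_axis != -1 and i_axis != -1, "Kernel layout attribute is not defined"
    if groups == 1:
        return weight_shape[o_axis]
    # Depthwise: total output channels = in_channels * depth_multiplier,
    # the product of the kernel's I and O extents.
    return weight_shape[i_axis] * weight_shape[o_axis]

# 3x3 depthwise kernel, 32 input channels, depth multiplier 1, HWOI layout:
print(expected_kernel_scale_size([3, 3, 1, 32], "HWOI", groups=32))  # 32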