Skip to content

Commit

Permalink
[TFLite] QNN support for TFLite 2.1.0 quantized models (apache#5848)
Browse files Browse the repository at this point in the history
* [TFLite] TFLite 2.x parser quantization support.

* Address comments. Fix a bug for depthwise conv

* Added tests for relu, conv, quantize. Address comments.

* Using web-data. Minor refactoring.

* Removing TF hub package

* Trigger CI.

* Handle TFLite input layer naming.

* Addressing reviews.

* Retrigger CI.
  • Loading branch information
anijain2305 authored and Trevor Morris committed Jul 14, 2020
1 parent 37ca901 commit 715402e
Show file tree
Hide file tree
Showing 3 changed files with 419 additions and 97 deletions.
195 changes: 150 additions & 45 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,42 +244,92 @@ def get_tensors(self, tensors_idx_list):
qnn_params = None
tflite_qnn_params = tensor.Quantization()
if tflite_qnn_params is not None:
scale = float(tflite_qnn_params.ScaleAsNumpy())
zero_point = int(tflite_qnn_params.ZeroPointAsNumpy())
# TFLite supports both per-tensor and per-axis (aka channel) quantization. For
# per-tensor quantization, scale and zero points are scalar values. For per-axis
# quantization, scale and zero points for the weights are tensors (activations are
# per-tensor quantized). However, the TFLite quantization spec puts restrictions on
# zero points for per-axis quantization. Specifically, the zero point is a tensor
# but all values are 0. More information can be found here -
# https://www.tensorflow.org/lite/performance/quantization_spec

tflite_scale = tflite_qnn_params.ScaleAsNumpy()
tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy()
is_qnn_params_valid = True

# Handle Per-axis and per-tensor cases
if isinstance(tflite_scale, np.ndarray):
assert isinstance(tflite_zero_point, np.ndarray)

# Tensor - Per-axis quantization
if tflite_scale.size != 1 and tflite_zero_point.size != 1:
scale = tflite_scale
# Ensure that all zero points are zeros
zero_point = tflite_zero_point
if not np.all(zero_point == 0):
raise tvm.error.OpAttributeInvalid(\
"TFLite per-axis quantization restricts all zero points to be"
+ " 0, but a non-zero value is observed")
zero_point = int(zero_point[0])

# Scalar - Per-tensor quantization
elif tflite_scale.size == 1 and tflite_zero_point.size == 1:
scale = float(tflite_scale[0])
zero_point = int(tflite_zero_point[0])

else:
raise NotImplementedError(\
"Quantized type {} (scale) and {} (zero point) not supported"
.format(type(tflite_scale), type(tflite_zero_point)))
elif tflite_scale == 0 and tflite_zero_point == 0:
# Handle corner case for ops like quantized reshape whose second operand (shape)
# has zero scale and zero zero point. This is not used.
is_qnn_params_valid = False
else:
raise NotImplementedError("Quantized type {} not supported"
.format(type(tflite_scale)))

# Check that the scale and zero points are valid.
if scale != 0 or zero_point != 0:
if is_qnn_params_valid:
qnn_params = dict()
qnn_params['scale'] = relay.const(scale, 'float32')
qnn_params['zero_point'] = relay.const(zero_point, 'int32')
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
return return_list

def get_tensor_value(self, tensor_wrapper):
"""Get tensor buffer value from given tensor wrapper"""

def get_tensor_type_as_numpy(self, tensor_wrapper):
"""Returns np.dtype out of TensorType"""
assert isinstance(tensor_wrapper, TensorWrapper)

try:
from tflite.TensorType import TensorType
return {TensorType.UINT8: np.uint8,
TensorType.INT8: np.int8,
TensorType.FLOAT32: np.float32,
TensorType.INT32: np.int32,
TensorType.INT64: np.int64,
TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()]
except ImportError:
raise ImportError("The tflite package must be installed")
except KeyError:
raise NotImplementedError("Tensor type '{}' currently not supported"
.format(tensor_wrapper.tensor.Type()))


def get_tensor_value(self, tensor_wrapper):
"""Get tensor buffer value from given tensor wrapper"""
assert isinstance(tensor_wrapper, TensorWrapper)

dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
data = tensor_wrapper.buffer.DataAsNumpy()

if tensor_wrapper.tensor.ShapeLength() != 0:
shape = tensor_wrapper.tensor.ShapeAsNumpy()
else:
shape = []

return np.frombuffer(data, dtype=dtype).reshape(shape)

if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.INT64:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.BOOL:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type())))

def get_tensor_type_str(self, tensor_type):
"""Get tensor type string representation when given TFLite tensor type"""
Expand Down Expand Up @@ -651,12 +701,43 @@ def convert_shape(self, op):

def convert_relu(self, op):
"""Convert TFLite ReLU"""
try:
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = _op.nn.relu(in_expr)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

if input_tensor.qnn_params:
# Quantize a float value to an quantized integer value
scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])

output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
out = self.convert_qnn_fused_activation_function(\
expr=in_expr,
fused_activation_fn=ActivationFunctionType.RELU,
scale=scale_val,
zero_point=zero_point_val,
dtype=output_tensor_type_str)
else:
out = _op.nn.relu(in_expr)

if output_tensor.qnn_params:
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
out = _qnn.op.requantize(out,
input_scale=input_tensor.qnn_params['scale'],
input_zero_point=input_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

return out

Expand Down Expand Up @@ -692,6 +773,11 @@ def _hard_swish(data):

def convert_relu6(self, op):
"""Convert TFLite ReLU6"""
try:
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
Expand All @@ -705,17 +791,14 @@ def convert_relu6(self, op):
# Quantize a float value to an quantized integer value
scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)

# Get min/max of the input dtype. This will be used to ensure that
# clip a_min/a_max are not beyond the dtype range.
input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value)
qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value)

out = _op.clip(in_expr,
a_min=max(qmin, quantize(0)),
a_max=min(qmax, quantize(6.0)))
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
out = self.convert_qnn_fused_activation_function(\
expr=in_expr,
fused_activation_fn=ActivationFunctionType.RELU6,
scale=scale_val,
zero_point=zero_point_val,
dtype=output_tensor_type_str)
else:
out = _op.clip(in_expr, a_min=0, a_max=6)

Expand Down Expand Up @@ -1604,9 +1687,9 @@ def convert_fully_connected(self, op):
fully_connected_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = fully_connected_options.FusedActivationFunction()

# weight tensor type should be UINT8 (quantization) or FLOAT32
# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

if self.has_expr(weight_tensor.tensor_idx):
Expand Down Expand Up @@ -1795,9 +1878,9 @@ def convert_conv(self, op, conv_type):
params['channels'] = int(output_channels)
params['kernel_layout'] = 'HWIO'

# weight tensor type should be UINT8 (quantization) or FLOAT32
# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

in_expr = self.get_expr(input_tensor_idx)
Expand Down Expand Up @@ -1856,9 +1939,15 @@ def convert_conv(self, op, conv_type):
if output_tensor.qnn_params:
# Calculate the intermediate scale and zero point of the int32 output.
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_scalar_from_constant(weight_scale)

weight_scale = weight_tensor.qnn_params['scale']
# If weight scale is scalar, it is per-tensor quantization
if isinstance(weight_scale, float):
weight_scale_val = get_scalar_from_constant(weight_scale)
else:
weight_scale_val = get_tensor_from_constant(weight_scale)

new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')
Expand All @@ -1869,7 +1958,8 @@ def convert_conv(self, op, conv_type):
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
out_dtype=output_tensor_type_str,
axis=3)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
Expand All @@ -1882,7 +1972,6 @@ def convert_conv(self, op, conv_type):
dtype=output_tensor_type_str)
else:
out = self.convert_fused_activation_function(out, fused_activation_fn)

return out

def convert_split(self, op):
Expand Down Expand Up @@ -2594,17 +2683,27 @@ def convert_quantize(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())

# The output must be quantized
assert output_tensor.qnn_params
# Quantize the input
out = self.quantize(in_expr, output_tensor)

# TFLite Quantize op can also act as Requantize op
if input_tensor_type_str == "float32":
out = self.quantize(in_expr, output_tensor)
else:
out = _qnn.op.requantize(in_expr,
input_scale=input_tensor.qnn_params['scale'],
input_zero_point=input_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
return out

def convert_dequantize(self, op):
Expand Down Expand Up @@ -2734,7 +2833,6 @@ def get_tensor_expr(self, tensor):
else:
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

return expr


Expand All @@ -2747,6 +2845,13 @@ def get_scalar_from_constant(expr):
"value must be float32/int32"
return np.asscalar(value)

def get_tensor_from_constant(expr):
""" Returns tensor of values from Relay constant node. """
assert isinstance(expr, _expr.Constant)
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return value

def build_str_map(obj):
"""Build string map of TFLite enum int value
Expand Down
16 changes: 13 additions & 3 deletions src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,19 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
// Kernel scale can be a vector of length output_channels or a scalar.
size_t axis = param->kernel_layout.find('O');
CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale
if (param->groups == 1) {
size_t axis = param->kernel_layout.find('O');
CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale
} else {
// Here, total number of output channels depend on depth multiplier.
size_t o_axis = param->kernel_layout.find('O');
size_t i_axis = param->kernel_layout.find('I');
CHECK(o_axis != std::string::npos || i_axis != std::string::npos)
<< "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
reporter); // kernel scale
}

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// Conv2D infer type function.
Expand Down
Loading

0 comments on commit 715402e

Please sign in to comment.