[TFLite] add support for float16 (#7093)
* [TFLite] add support for float16

* add test case

* add test case

* add comments
euntaik authored Dec 21, 2020
1 parent 38273ee commit 9914685
Showing 2 changed files with 79 additions and 17 deletions.
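For context before reading the diff: TFLite's float16 post-training quantization stores weights as float16 and inserts Dequantize ops that restore float32 at runtime, which is exactly the pattern the frontend changes below handle. A minimal standalone sketch of producing such a model (the frozen-graph path and tensor names are placeholders; the new test at the bottom of this commit uses the same API on a real MobileNet V1):

    import tensorflow as tf

    # Sketch: float16 post-training quantization with the TF1 frozen-graph
    # converter. "model.pb", "input", and "output" are placeholder names.
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        "model.pb", ["input"], ["output"]
    )
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Limiting supported types to float16 makes the converter store weights
    # as float16 and insert Dequantize ops back to float32.
    converter.target_spec.supported_types = [tf.float16]
    tflite_model_buf = converter.convert()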
61 changes: 46 additions & 15 deletions python/tvm/relay/frontend/tflite.py
@@ -325,6 +325,7 @@ def get_tensor_type_as_numpy(self, tensor_wrapper):
         return {
             TensorType.UINT8: np.uint8,
             TensorType.INT8: np.int8,
+            TensorType.FLOAT16: np.float16,
             TensorType.FLOAT32: np.float32,
             TensorType.INT32: np.int32,
             TensorType.INT64: np.int64,
@@ -362,6 +363,8 @@ def get_tensor_type_str(self, tensor_type):
             return "int8"
         if tensor_type == TensorType.UINT8:
             return "uint8"
+        if tensor_type == TensorType.FLOAT16:
+            return "float16"
         if tensor_type == TensorType.FLOAT32:
             return "float32"
         if tensor_type == TensorType.INT32:
@@ -1991,20 +1994,33 @@ def convert_conv(self, op, conv_type):
         weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
 
         in_expr = self.get_expr(input_tensor_idx)
-        weight_value = self.get_tensor_value(weight_tensor)
 
-        # TFLite kernel layout:
-        # convolution:
-        # OC KH KW IC, we require KH KW IC OC (HWIO)
-        # depthwise convolution:
-        # 1 KH KW C(input_c * depth_multiplier), we require
-        # KH KW IC M (depth_multiplier) (HWOI)
-        if is_depthwise_conv:
-            weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
+        # TFLite converts float32 models to float16 models by introducing
+        # a Dequantize op for every op that contains float32 values
+        # (weights, biases, constants, etc.), so a conv op may receive its
+        # weight and bias as expressions instead of values.
+        if self.has_expr(weight_tensor.tensor_idx):
+            weight_expr = self.get_expr(weight_tensor.tensor_idx)
+            if is_depthwise_conv:
+                weight_expr = _op.reshape(
+                    weight_expr, (kernel_h, kernel_w, input_c, depth_multiplier)
+                )
+            else:
+                weight_expr = _op.transpose(weight_expr, axes=(1, 2, 3, 0))
         else:
-            weight_value = weight_value.transpose((1, 2, 3, 0))
-
-        weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
+            weight_value = self.get_tensor_value(weight_tensor)
+            # TFLite kernel layout:
+            # convolution:
+            # OC KH KW IC, we require KH KW IC OC (HWIO)
+            # depthwise convolution:
+            # 1 KH KW C (input_c * depth_multiplier), we require
+            # KH KW IC M (depth_multiplier) (HWOI)
+            if is_depthwise_conv:
+                weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
+            else:
+                weight_value = weight_value.transpose((1, 2, 3, 0))
+
+            weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
 
         if padding == Padding.VALID:
             pass
@@ -2039,9 +2055,12 @@ def convert_conv(self, op, conv_type):
             # bias tensor type should be INT32 (quantization) or FLOAT32
             assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
-            bias_expr = self.exp_tab.new_const(
-                self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
-            )
+            if self.has_expr(bias_tensor.tensor_idx):
+                bias_expr = self.get_expr(bias_tensor.tensor_idx)
+            else:
+                bias_expr = self.exp_tab.new_const(
+                    self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
+                )
             channel_axis = 3
             out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
 
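As an aside, the kernel-layout comments in the hunk above are terse; here is a NumPy sketch of the two weight transforms the frontend applies (the shapes are illustrative assumptions, not from the commit):

    import numpy as np

    # Regular conv: TFLite stores kernels as (OC, KH, KW, IC); Relay's NHWC
    # conv wants HWIO, i.e. (KH, KW, IC, OC) -- hence transpose((1, 2, 3, 0)).
    oc, kh, kw, ic = 16, 3, 3, 8
    kernel = np.zeros((oc, kh, kw, ic), dtype="float32")
    assert kernel.transpose((1, 2, 3, 0)).shape == (kh, kw, ic, oc)

    # Depthwise conv: TFLite stores (1, KH, KW, IC * depth_multiplier); the
    # frontend reshapes this to (KH, KW, IC, depth_multiplier), i.e. HWOI.
    depth_multiplier = 2
    dw = np.zeros((1, kh, kw, ic * depth_multiplier), dtype="float32")
    assert dw.reshape(kh, kw, ic, depth_multiplier).shape == (kh, kw, ic, depth_multiplier)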
@@ -2870,10 +2889,22 @@ def convert_quantize(self, op):
 
     def convert_dequantize(self, op):
         """Convert TFLite Dequantize"""
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
         input_tensor = input_tensors[0]
+
+        if input_tensor.tensor.Type() == TensorType.FLOAT16:
+            dtype = self.get_tensor_type_str(input_tensor.tensor.Type())
+            input_value = self.get_tensor_value(input_tensor)
+            in_expr = self.exp_tab.new_const(input_value, dtype=dtype)
+            out = relay.cast(in_expr, dtype="float32")
+            return out
+
         in_expr = self.get_expr(input_tensor.tensor_idx)
 
         # The input must be quantized
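The float16 branch added to convert_dequantize above boils down to a constant-plus-cast pattern in Relay. A minimal sketch of the IR it emits (random float16 payload; the shape is assumed for illustration):

    import numpy as np
    from tvm import relay

    # The float16 tensor becomes a Relay constant, and a cast recovers
    # float32 for downstream ops -- mirroring relay.cast(in_expr, "float32")
    # in the hunk above.
    w16 = np.random.uniform(-1, 1, size=(3, 3, 8, 16)).astype("float16")
    const = relay.const(w16, dtype="float16")
    dequantized = relay.cast(const, dtype="float32")
    print(dequantized)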
35 changes: 33 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
@@ -71,13 +71,13 @@ def convert_to_list(x):
 #######################################################################
 # Get a real image for e2e testing
 # --------------------------------
-def get_real_image(im_height, im_width):
+def get_real_image(im_height, im_width, quantized=True):
     repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
     img_name = "elephant-299.jpg"
     image_url = os.path.join(repo_base, img_name)
     img_path = download_testdata(image_url, img_name, module="data")
     image = Image.open(img_path).resize((im_height, im_width))
-    x = np.array(image).astype("uint8")
+    x = np.array(image).astype("uint8") if quantized else np.array(image).astype("float32")
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data

@@ -3792,6 +3792,35 @@ def test_forward_tflite2_qnn_mobilenet_v2():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 
+def test_forward_tflite_float16():
+    """Test float16 quantized model"""
+    # MobileNet V1
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz",
+        "mobilenet_v1_0.25_128_frozen.pb",
+    )
+
+    converter = tf.lite.TFLiteConverter.from_frozen_graph(
+        tflite_model_file, ["input"], ["MobilenetV1/Predictions/Reshape_1"]
+    )
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_types = [tf.float16]
+    tflite_model_buf = converter.convert()
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image instead of random inputs.
+    data = get_real_image(128, 128, quantized=False)
+
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tflite_predictions = np.squeeze(tflite_output)
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm_predictions = np.squeeze(tvm_output)
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
 #######################################################################
 # Quantized SSD Mobilenet
 # -----------------------
@@ -4057,3 +4086,5 @@ def test_forward_mediapipe_hand_landmark():
     test_forward_tflite2_qnn_resnet50()
     test_forward_tflite2_qnn_inception_v1()
     test_forward_tflite2_qnn_mobilenet_v2()
+
+    test_forward_tflite_float16()