From 22ac6897a2e23f41479afa4e1e154d22584467bb Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Sat, 2 May 2020 00:42:12 +0000 Subject: [PATCH] Flexbuffer parsing --- python/tvm/relay/frontend/tflite.py | 175 +++++------------- .../tvm/relay/frontend/tflite_flexbuffer.py | 154 +++++++++++++++ tests/python/frontend/tflite/test_forward.py | 40 ++-- 3 files changed, 225 insertions(+), 144 deletions(-) create mode 100644 python/tvm/relay/frontend/tflite_flexbuffer.py diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 517ec37744f8..34a0a4971ed0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -31,6 +31,7 @@ from ... import nd as _nd from .common import ExprTable from .common import infer_shape as _infer_shape +from .tflite_flexbuffer import FlexBufferDecode __all__ = ['from_tflite'] @@ -330,8 +331,13 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, except ImportError: raise ImportError("The tflite package must be installed") - # Quantize a float value to an integer - quantize = lambda value : (value / scale) + zero_point + # Quantize a float value to an quantized integer value + quantize = lambda x: float(int(round(x / scale)) + zero_point) + + # Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not + # beyond the dtype range. + qmin = float(tvm.tir.op.min_value(dtype).value) + qmax = float(tvm.tir.op.max_value(dtype).value) # The input expr is a quantized tensor with its scale and zero point. We calculate the # suitable clip off points based on these scale and zero point. @@ -339,16 +345,16 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, return expr elif fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(expr, - a_min=quantize(0), - a_max=quantize(6)) + a_min=max(qmin, quantize(0)), + a_max=min(qmax, quantize(6.0))) elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return _op.clip(expr, - a_min=quantize(-1), - a_max=quantize(1)) + a_min=max(qmin, quantize(-1.0)), + a_max=min(qmax, quantize(1.0))) elif fused_activation_fn == ActivationFunctionType.RELU: return _op.clip(expr, - a_min=quantize(0), - a_max=float(tvm.tir.op.min_value(dtype).value)) + a_min=max(qmin, quantize(0.0)), + a_max=qmax) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( @@ -1432,14 +1438,6 @@ def convert_fully_connected(self, op): new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') - # Call activation function - out = self.convert_qnn_fused_activation_function(\ - expr=out, - fused_activation_fn=fused_activation_fn, - scale=new_input_scale_val, - zero_point=0, - dtype='int32') - # Requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, @@ -1447,6 +1445,17 @@ def convert_fully_connected(self, op): output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + else: out = self.convert_fused_activation_function(out, fused_activation_fn) @@ -1645,14 +1654,6 @@ def convert_conv(self, op, conv_type): new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') - # Call activation function - out = self.convert_qnn_fused_activation_function(\ - expr=out, - fused_activation_fn=fused_activation_fn, - scale=new_input_scale_val, - zero_point=0, - dtype='int32') - # Finally requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, @@ -1660,6 +1661,16 @@ def convert_conv(self, op, conv_type): output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) else: out = self.convert_fused_activation_function(out, fused_activation_fn) @@ -2302,28 +2313,15 @@ def convert_transpose_conv(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" - _option_names = [ - "w_scale", - "max_detections", - "_output_quantized", - "detections_per_class", - "x_scale", - "nms_score_threshold", - "num_classes", - "max_classes_per_detection", - "use_regular_nms", - "y_scale", - "h_scale", - "_support_output_type_float_in_quantized_op", - "nms_iou_threshold" - ] - - custom_options = get_custom_options(op, _option_names) - if custom_options["use_regular_nms"]: - raise tvm.error.OpAttributeUnImplemented( - "use_regular_nms=True is not yet supported for operator {}." - .format("TFLite_Detection_PostProcess") - ) + flexbuffer = op.CustomOptionsAsNumpy().tobytes() + custom_options = FlexBufferDecode(flexbuffer).decode() + + if "use_regular_nms" in custom_options: + if custom_options["use_regular_nms"]: + raise tvm.error.OpAttributeUnImplemented( + "use_regular_nms=True is not yet supported for operator {}." + .format("TFLite_Detection_PostProcess") + ) inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" @@ -2494,91 +2492,6 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def get_custom_options(op, option_names): - """Get the options of a custom operator. - - This implements partial flexbuffer deserialization to be able - to read custom options. It is not intended to be a general - purpose flexbuffer deserializer and as such only supports a - limited number of types and assumes the data is a flat map. - - Parameters - ---------- - op: - A custom TFlite operator. - option_names: list - A complete list of the custom option names. - - Returns - ------- - options: dict - A dictionary of the custom options. - - """ - import struct - from enum import IntEnum - - class _FlexBufferType(IntEnum): - """Flexbuffer type schema from flexbuffers.h""" - FBT_NULL = 0 - FBT_INT = 1 - FBT_UINT = 2 - FBT_FLOAT = 3 - # Types above stored inline, types below store an offset. - FBT_KEY = 4 - FBT_STRING = 5 - FBT_INDIRECT_INT = 6 - FBT_INDIRECT_UINT = 7 - FBT_INDIRECT_FLOAT = 8 - FBT_MAP = 9 - FBT_VECTOR = 10 # Untyped. - FBT_VECTOR_INT = 11 # Typed any size (stores no type table). - FBT_VECTOR_UINT = 12 - FBT_VECTOR_FLOAT = 13 - FBT_VECTOR_KEY = 14 - FBT_VECTOR_STRING = 15 - FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field). - FBT_VECTOR_UINT2 = 17 - FBT_VECTOR_FLOAT2 = 18 - FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field). - FBT_VECTOR_UINT3 = 20 - FBT_VECTOR_FLOAT3 = 21 - FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field). - FBT_VECTOR_UINT4 = 23 - FBT_VECTOR_FLOAT4 = 24 - FBT_BLOB = 25 - FBT_BOOL = 26 - FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type - - buffer = op.CustomOptionsAsNumpy().tobytes() - value_vector_offset = buffer[-3] - buffer = buffer[:-3] - num_bytes = 4 # Assume all values are stored in 32 bit width - value_vector_size = struct.unpack( - "> 2) - value_offset = -value_vector_offset + i*num_bytes - value_bytes = buffer[value_offset:value_offset+num_bytes] - if flex_type == _FlexBufferType.FBT_BOOL: - value = bool(value_bytes[0]) - if flex_type == _FlexBufferType.FBT_INT: - value = struct.unpack("> 2) + value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width] + if value_type == FlexBufferType.FBT_BOOL: + value = bool(value_bytes[0]) + elif value_type == FlexBufferType.FBT_INT: + value = struct.unpack("> 2); + byte_width = 1 << BitWidth(root_packed_type & 3); + + if root_type == FlexBufferType.FBT_MAP: + return self.decode_map(root_end, byte_width, root_byte_width) + raise NotImplementedError("Flexbuffer Decoding is partially imlpemented.") diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 16bc8f5fbe05..220b0664e59b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -73,6 +73,16 @@ def get_real_image(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data +def get_real_image_object_detection(im_height, im_width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/' + img_name = 'street_small.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = Image.open(img_path).resize((im_height, im_width)) + x = np.array(image).astype('uint8') + data = np.reshape(x, (1, im_height, im_width, 3)) + return data + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None): """ Generic function to compile on relay and execute on tvm """ @@ -98,6 +108,7 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) + with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) @@ -1952,7 +1963,10 @@ def test_forward_qnn_mobilenet_v3_net(): def test_forward_qnn_coco_ssd_mobilenet_v1(): """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" - pytest.skip("Unsupported op - use_regular_nms") + pytest.skip("LLVM bug - getExtendedVectorNumElements - " + + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a " + + "specific target, for example, llvm -mpcu=core-avx2") + tflite_model_file = tf_testing.get_workload_official( "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", "detect.tflite") @@ -1960,8 +1974,7 @@ def test_forward_qnn_coco_ssd_mobilenet_v1(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - np.random.seed(0) - data = np.random.uniform(size=(1, 300, 300, 3)).astype('uint8') + data = get_real_image_object_detection(300, 300) tflite_output = run_tflite_graph(tflite_model_buf, data) tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) @@ -1976,16 +1989,18 @@ def test_forward_qnn_coco_ssd_mobilenet_v1(): # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare # tflite and tvm tensors for only valid boxes. for i in range(0, valid_count): - # Check bounding box co-ords + # Check bounding box co-ords. The tolerances have to be adjusted because of differences between + # for requantiize operator in TFLite and TVM. tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), - rtol=1e-5, atol=1e-5) + rtol=1e-1, atol=1e-1) + # Check the class - tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), - rtol=1e-5, atol=1e-5) + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) + # Check the score tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), - rtol=1e-5, atol=1e-5) - + rtol=1e-2, atol=1e-2) ####################################################################### @@ -2021,13 +2036,11 @@ def test_forward_coco_ssd_mobilenet_v1(): tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), rtol=1e-5, atol=1e-5) # Check the class - tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), - rtol=1e-5, atol=1e-5) + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) + # Check the score tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), rtol=1e-5, atol=1e-5) ->>>>>>> Fix test - ####################################################################### # MediaPipe @@ -2135,3 +2148,4 @@ def test_forward_mediapipe_hand_landmark(): #This also fails with a segmentation fault in my run #with Tflite 1.15.2 test_forward_qnn_mobilenet_v3_net() + test_forward_qnn_coco_ssd_mobilenet_v1()