diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 34a0a4971ed0..4021060bab0c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -31,7 +31,7 @@ from ... import nd as _nd from .common import ExprTable from .common import infer_shape as _infer_shape -from .tflite_flexbuffer import FlexBufferDecode +from .tflite_flexbuffer import FlexBufferDecoder __all__ = ['from_tflite'] @@ -343,22 +343,22 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, # suitable clip off points based on these scale and zero point. if fused_activation_fn == ActivationFunctionType.NONE: return expr - elif fused_activation_fn == ActivationFunctionType.RELU6: + if fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(expr, a_min=max(qmin, quantize(0)), a_max=min(qmax, quantize(6.0))) - elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return _op.clip(expr, a_min=max(qmin, quantize(-1.0)), a_max=min(qmax, quantize(1.0))) - elif fused_activation_fn == ActivationFunctionType.RELU: + if fused_activation_fn == ActivationFunctionType.RELU: return _op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Quantized activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str)) def convert_conv2d(self, op): """Convert TFLite conv2d""" @@ -468,7 +468,6 @@ def convert_l2_normalization(self, op): try: from tflite.BuiltinOptions import BuiltinOptions from tflite.L2NormOptions import L2NormOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -501,8 +500,7 @@ def convert_l2_normalization(self, op): if output_tensor.qnn_params: raise tvm.error.OpNotImplemented( 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') - else: - out = self.convert_fused_activation_function(out, fused_activation_fn) + out = self.convert_fused_activation_function(out, fused_activation_fn) return out @@ -647,7 +645,6 @@ def convert_concatenation(self, op): try: from tflite.ConcatenationOptions import ConcatenationOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -835,7 +832,6 @@ def _convert_elemwise(self, relay_op, op): from tflite.MulOptions import MulOptions from tflite.DivOptions import DivOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -1361,7 +1357,6 @@ def convert_fully_connected(self, op): from tflite.FullyConnectedOptions import FullyConnectedOptions from tflite.BuiltinOptions import BuiltinOptions from tflite.TensorType import TensorType - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -1496,23 +1491,22 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): if fused_activation_fn == ActivationFunctionType.NONE: return in_expr - elif fused_activation_fn == ActivationFunctionType.RELU6: + if fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(in_expr, a_min=0, a_max=6) - elif fused_activation_fn == ActivationFunctionType.RELU: + if fused_activation_fn == ActivationFunctionType.RELU: return _op.nn.relu(in_expr) - elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return _op.clip(in_expr, a_min=-1, a_max=1) - elif fused_activation_fn == ActivationFunctionType.TANH: + if fused_activation_fn == ActivationFunctionType.TANH: return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Fused activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Fused activation {} is not supported yet.'.format(fused_activation_fn_str)) def convert_conv(self, op, conv_type): """convolution implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.TensorType import TensorType from tflite.Conv2DOptions import Conv2DOptions from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions @@ -1837,7 +1831,6 @@ def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.Pool2DOptions import Pool2DOptions from tflite.Padding import Padding except ImportError: @@ -2314,7 +2307,7 @@ def convert_transpose_conv(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" flexbuffer = op.CustomOptionsAsNumpy().tobytes() - custom_options = FlexBufferDecode(flexbuffer).decode() + custom_options = FlexBufferDecoder(flexbuffer).decode() if "use_regular_nms" in custom_options: if custom_options["use_regular_nms"]: diff --git a/python/tvm/relay/frontend/tflite_flexbuffer.py b/python/tvm/relay/frontend/tflite_flexbuffer.py index 6b8606af8889..e3427ab76e51 100644 --- a/python/tvm/relay/frontend/tflite_flexbuffer.py +++ b/python/tvm/relay/frontend/tflite_flexbuffer.py @@ -60,7 +60,7 @@ class FlexBufferType(IntEnum): FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type -class FlexBufferDecode(object): +class FlexBufferDecoder(object): """ This implements partial flexbuffer deserialization to be able to read custom options. It is not intended to be a general @@ -129,9 +129,6 @@ def decode_map(self, end, byte_width, parent_byte_width): # Find keys keys_offset = mid_loc - byte_width * 3 keys_end = self.indirect_jump(keys_offset, byte_width) - keys_byte_width = struct.unpack(\ - "> 2); - byte_width = 1 << BitWidth(root_packed_type & 3); + root_type = FlexBufferType(root_packed_type >> 2) + byte_width = 1 << BitWidth(root_packed_type & 3) if root_type == FlexBufferType.FBT_MAP: return self.decode_map(root_end, byte_width, root_byte_width)