Lint
anijain2305 committed May 2, 2020
1 parent 22ac689 commit 7e3e04d
Showing 2 changed files with 16 additions and 25 deletions.
31 changes: 12 additions & 19 deletions python/tvm/relay/frontend/tflite.py
@@ -31,7 +31,7 @@
 from ... import nd as _nd
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
-from .tflite_flexbuffer import FlexBufferDecode
+from .tflite_flexbuffer import FlexBufferDecoder

 __all__ = ['from_tflite']

@@ -343,22 +343,22 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn,
         # suitable clip off points based on these scale and zero point.
         if fused_activation_fn == ActivationFunctionType.NONE:
             return expr
-        elif fused_activation_fn == ActivationFunctionType.RELU6:
+        if fused_activation_fn == ActivationFunctionType.RELU6:
             return _op.clip(expr,
                             a_min=max(qmin, quantize(0)),
                             a_max=min(qmax, quantize(6.0)))
-        elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
+        if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
             return _op.clip(expr,
                             a_min=max(qmin, quantize(-1.0)),
                             a_max=min(qmax, quantize(1.0)))
-        elif fused_activation_fn == ActivationFunctionType.RELU:
+        if fused_activation_fn == ActivationFunctionType.RELU:
             return _op.clip(expr,
                             a_min=max(qmin, quantize(0.0)),
                             a_max=qmax)

         fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
         raise tvm.error.OpNotImplemented(
-            'Quantized activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
+            'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str))

     def convert_conv2d(self, op):
         """Convert TFLite conv2d"""
@@ -468,7 +468,6 @@ def convert_l2_normalization(self, op):
         try:
             from tflite.BuiltinOptions import BuiltinOptions
             from tflite.L2NormOptions import L2NormOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")

@@ -501,8 +500,7 @@ def convert_l2_normalization(self, op):
         if output_tensor.qnn_params:
             raise tvm.error.OpNotImplemented(
                 'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
-        else:
-            out = self.convert_fused_activation_function(out, fused_activation_fn)
+        out = self.convert_fused_activation_function(out, fused_activation_fn)

         return out

@@ -647,7 +645,6 @@ def convert_concatenation(self, op):
         try:
             from tflite.ConcatenationOptions import ConcatenationOptions
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")

@@ -835,7 +832,6 @@ def _convert_elemwise(self, relay_op, op):
             from tflite.MulOptions import MulOptions
             from tflite.DivOptions import DivOptions
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")

@@ -1361,7 +1357,6 @@ def convert_fully_connected(self, op):
             from tflite.FullyConnectedOptions import FullyConnectedOptions
             from tflite.BuiltinOptions import BuiltinOptions
             from tflite.TensorType import TensorType
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")

@@ -1496,23 +1491,22 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn):

         if fused_activation_fn == ActivationFunctionType.NONE:
             return in_expr
-        elif fused_activation_fn == ActivationFunctionType.RELU6:
+        if fused_activation_fn == ActivationFunctionType.RELU6:
             return _op.clip(in_expr, a_min=0, a_max=6)
-        elif fused_activation_fn == ActivationFunctionType.RELU:
+        if fused_activation_fn == ActivationFunctionType.RELU:
             return _op.nn.relu(in_expr)
-        elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
+        if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
             return _op.clip(in_expr, a_min=-1, a_max=1)
-        elif fused_activation_fn == ActivationFunctionType.TANH:
+        if fused_activation_fn == ActivationFunctionType.TANH:
             return _op.tanh(in_expr)
         fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
         raise tvm.error.OpNotImplemented(
-            'Fused activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
+            'Fused activation {} is not supported yet.'.format(fused_activation_fn_str))

     def convert_conv(self, op, conv_type):
         """convolution implementation."""
         try:
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
             from tflite.TensorType import TensorType
             from tflite.Conv2DOptions import Conv2DOptions
             from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
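A side note on the elif-to-if edits in this hunk and in the earlier quantized one (a standalone sketch, not part of the patch): the commit message only says "Lint", but the pattern matches pylint-style no-else-return checks, which flag an elif or else branch that follows a return. A tiny illustration of why the rewrite is behavior-preserving:

def describe(fused_activation_fn):
    # Once a branch returns, the following "elif" can be flattened to a
    # plain "if" without changing behavior: control never reaches it when
    # the earlier condition was true.
    if fused_activation_fn == "NONE":
        return "identity"
    if fused_activation_fn == "RELU6":   # previously "elif"
        return "clip(0, 6)"
    return "unsupported"

assert describe("NONE") == "identity"
assert describe("RELU6") == "clip(0, 6)"
assert describe("TANH") == "unsupported"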
@@ -1837,7 +1831,6 @@ def convert_pool2d(self, op, pool_type):
         """pool2d implementation."""
         try:
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
             from tflite.Pool2DOptions import Pool2DOptions
             from tflite.Padding import Padding
         except ImportError:
@@ -2314,7 +2307,7 @@ def convert_transpose_conv(self, op):
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
-        custom_options = FlexBufferDecode(flexbuffer).decode()
+        custom_options = FlexBufferDecoder(flexbuffer).decode()

         if "use_regular_nms" in custom_options:
             if custom_options["use_regular_nms"]:
10 changes: 4 additions & 6 deletions python/tvm/relay/frontend/tflite_flexbuffer.py
@@ -60,7 +60,7 @@ class FlexBufferType(IntEnum):
     FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type


-class FlexBufferDecode(object):
+class FlexBufferDecoder(object):
     """
     This implements partial flexbuffer deserialization to be able
     to read custom options. It is not intended to be a general
@@ -129,9 +129,6 @@ def decode_map(self, end, byte_width, parent_byte_width):
         # Find keys
         keys_offset = mid_loc - byte_width * 3
         keys_end = self.indirect_jump(keys_offset, byte_width)
-        keys_byte_width = struct.unpack(\
-            "<i",
-            self.buffer[keys_offset + byte_width:keys_offset + 2 * byte_width:])[0]
         keys = self.decode_keys(keys_end, map_size, 1)

         # Find values
@@ -140,14 +137,15 @@ def decode_map(self, end, byte_width, parent_byte_width):
         return dict(zip(keys, values))

     def decode(self):
+        """ Decode the buffer. Decoding is partially implemented """
         root_end = len(self.buffer) - 1
         root_byte_width = self.buffer[root_end]
         root_end -= 1
         root_packed_type = self.buffer[root_end]
         root_end -= root_byte_width

-        root_type = FlexBufferType(root_packed_type >> 2);
-        byte_width = 1 << BitWidth(root_packed_type & 3);
+        root_type = FlexBufferType(root_packed_type >> 2)
+        byte_width = 1 << BitWidth(root_packed_type & 3)

         if root_type == FlexBufferType.FBT_MAP:
             return self.decode_map(root_end, byte_width, root_byte_width)
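To make the decode() arithmetic above concrete (a standalone sketch, not part of the patch): a FlexBuffer ends with the root byte width, preceded by a packed type byte whose upper six bits hold the FlexBufferType and whose lower two bits hold the BitWidth exponent. Assuming an example packed byte of 0x24:

# Example packed root type byte: 0b100100 -> type 9 (FBT_MAP in the
# FlexBuffers type enum), bit-width code 0.
packed_type = 0x24

root_type = packed_type >> 2     # upper 6 bits: FlexBufferType value -> 9
width_code = packed_type & 3     # lower 2 bits: BitWidth exponent -> 0
byte_width = 1 << width_code     # bytes per element: 1, 2, 4 or 8 -> 1

print(root_type, byte_width)     # 9 1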
