TFLite Refactor
maheshambule committed May 14, 2020
1 parent ca20675, commit c12571b
Showing 1 changed file with 8 additions and 11 deletions: python/tvm/relay/frontend/tflite.py
@@ -60,23 +60,22 @@ def convert_wrapper(name,
     API to access the class is constructed as 'tflite.<options_class>.<options_class>'
     :param quantized_check: True/False. Whether to do a quantized check or not
     :param do_fuse_activation: True/False. Whether to fuse activation function to output
     :return:
     """
 
-    def wrap(f):
+    def wrap(func):
         def wrapped_f(*args):
-            self = args[0]
+            op_converter = args[0] #op_converter object
             op = args[1]
             new_kwargs = {}
 
             if num_inputs is not None:
-                input_tensors = self.get_input_tensors(op)
+                input_tensors = op_converter.get_input_tensors(op)
                 assert len(input_tensors) == num_inputs, \
                     "input tensors length should be {}".format(num_inputs)
                 new_kwargs.update({"input_tensors" : input_tensors})
 
             if num_outputs is not None:
-                output_tensors = self.get_output_tensors(op)
+                output_tensors = op_converter.get_output_tensors(op)
                 assert len(output_tensors) == num_outputs, \
                     "output tensors length should be {}".format(num_outputs)
                 new_kwargs.update({"output_tensors": output_tensors})
@@ -93,7 +92,7 @@ def wrapped_f(*args):
                 raise ImportError("The tflite package must be installed")
 
 
-            if quantized_check is not None and self.is_quantized(op):
+            if quantized_check is not None and op_converter.is_quantized(op):
                 raise tvm.error.OpNotImplemented(
                     'TFLite quantized {} operator is not supported yet.'.format(name))
@@ -106,14 +105,14 @@ def wrapped_f(*args):
 
                 new_kwargs.update({"options": options})
 
-            out = f(*args, **new_kwargs)
+            out = func(*args, **new_kwargs)
 
             if options_class is not None and do_fuse_activation: # is this redundant
                 if fused_activation_fn != ActivationFunctionType.NONE:
                     # Assumes single output tensor
-                    output_tensor = self.get_output_tensors(op)[0]
+                    output_tensor = op_converter.get_output_tensors(op)[0]
                     if not output_tensor.qnn_params:
-                        out = self.convert_fused_activation_function(out, fused_activation_fn)
+                        out = op_converter.convert_fused_activation_function(out, fused_activation_fn)
                     else:
                         raise tvm.error.OpNotImplemented(
                             'TFLite quantized {} operator\
@@ -229,8 +228,6 @@ def __init__(self, model, subgraph, exp_tab):
             'ZEROS_LIKE': self.convert_zeros_like,
         }
 
-
-
     def check_unsupported_ops(self):
         """Check unsupported TFLite ops in our converter."""
         unsupported_ops_set = set()
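Below is a minimal, runnable sketch of the decorator pattern this commit refactors, illustrating why the rename from self/f to op_converter/func is clearer: the decorated converter is a plain function that receives the converter instance as its first positional argument, so binding it to a name called self inside the wrapper was misleading. Only convert_wrapper, num_inputs, and get_input_tensors mirror names from the diff above; DummyConverter, DummyOp, and convert_add are hypothetical stand-ins, not TVM's actual classes.

# Minimal sketch (not TVM's actual code). The wrapper fetches the input
# tensors on behalf of the decorated converter and passes them in as a
# keyword argument, exactly in the shape shown in the diff.

def convert_wrapper(name, num_inputs=None):
    def wrap(func):
        def wrapped_f(*args):
            op_converter = args[0]  # the converter instance, not "self"
            op = args[1]
            new_kwargs = {}
            if num_inputs is not None:
                input_tensors = op_converter.get_input_tensors(op)
                assert len(input_tensors) == num_inputs, \
                    "input tensors length should be {}".format(num_inputs)
                new_kwargs.update({"input_tensors": input_tensors})
            # name is used for error messages in the real code; unused here
            return func(*args, **new_kwargs)
        return wrapped_f
    return wrap

class DummyConverter:
    """Hypothetical stand-in for the TFLite OperatorConverter."""
    def get_input_tensors(self, op):
        return op.inputs

    @convert_wrapper("ADD", num_inputs=2)
    def convert_add(self, op, input_tensors=None):
        return "add({}, {})".format(*input_tensors)

class DummyOp:
    def __init__(self, inputs):
        self.inputs = inputs

print(DummyConverter().convert_add(DummyOp(["lhs", "rhs"])))  # -> add(lhs, rhs)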
