POC refactor tflite frontend #5528

Closed · wants to merge 5 commits

Changes from all commits
119 changes: 87 additions & 32 deletions python/tvm/relay/frontend/tflite.py
@@ -18,6 +18,7 @@
"""Tensorflow lite frontend."""
import math
import itertools
import importlib
import numpy as np
import tvm
from tvm.ir import IRModule
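
The newly added `importlib` import enables the one dynamic trick in this refactor: the flatbuffers-generated `tflite` package keeps each options class in a module of the same name (for example `tflite.L2NormOptions.L2NormOptions`), so an options class can be resolved from a plain string. A minimal sketch of that lookup, assuming the `tflite` package is installed:

```python
import importlib

def load_options_class(options_class_str):
    # 'L2NormOptions' -> tflite.L2NormOptions.L2NormOptions
    options_module = importlib.import_module('tflite.{}'.format(options_class_str))
    return getattr(options_module, options_class_str)

L2NormOptions = load_options_class('L2NormOptions')
```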
@@ -43,6 +44,87 @@ def __init__(self, tensor_idx, tensor, buffer, qnn_params=None):
        self.buffer = buffer
        self.qnn_params = qnn_params


def convert_wrapper(name,
                    num_inputs=None,
                    num_outputs=None,
                    options_class_str=None,
                    quantized_check=None,
                    do_fuse_activation=None):
    """Decorator factoring out the boilerplate shared by TFLite op converters.

    :param name: Name of the operator
    :param num_inputs: Number of inputs expected. If None is provided, checks will not be done
    :param num_outputs: Number of outputs expected. If None is provided, checks will not be done
    :param options_class_str: Name of the operator's options class, loaded dynamically.
        The API to access the class is constructed as 'tflite.<options_class>.<options_class>'
    :param quantized_check: True/False. Whether to reject quantized operators as unsupported
    :param do_fuse_activation: True/False. Whether to fuse the activation function into the output
    """

    def wrap(func):
        def wrapped_f(*args):
            op_converter = args[0]  # OperatorConverter instance
            op = args[1]
            new_kwargs = {}

            if num_inputs is not None:
                input_tensors = op_converter.get_input_tensors(op)
                assert len(input_tensors) == num_inputs, \
                    "input tensors length should be {}".format(num_inputs)
                new_kwargs.update({"input_tensors": input_tensors})

            if num_outputs is not None:
                output_tensors = op_converter.get_output_tensors(op)
                assert len(output_tensors) == num_outputs, \
                    "output tensors length should be {}".format(num_outputs)
                new_kwargs.update({"output_tensors": output_tensors})

            options_class = None
            try:
                from tflite.BuiltinOptions import BuiltinOptions
                from tflite.ActivationFunctionType import ActivationFunctionType

                if options_class_str is not None:
                    options_module = importlib.import_module(
                        'tflite.{}'.format(options_class_str))
                    options_class = getattr(options_module, options_class_str)
            except ImportError:
                raise ImportError("The tflite package must be installed")

            if quantized_check and op_converter.is_quantized(op):
                raise tvm.error.OpNotImplemented(
                    'TFLite quantized {} operator is not supported yet.'.format(name))

            fused_activation_fn = None
            if options_class is not None:
                assert op.BuiltinOptionsType() == getattr(BuiltinOptions, options_class_str)
                op_options = op.BuiltinOptions()
                options = options_class()
                options.Init(op_options.Bytes, op_options.Pos)
                fused_activation_fn = options.FusedActivationFunction()

                new_kwargs.update({"options": options})

            out = func(*args, **new_kwargs)

            # Guarding on fused_activation_fn (rather than options_class) avoids a
            # NameError when no options class was requested.
            if do_fuse_activation and fused_activation_fn is not None and \
                    fused_activation_fn != ActivationFunctionType.NONE:
                # Assumes a single output tensor
                output_tensor = op_converter.get_output_tensors(op)[0]
                if not output_tensor.qnn_params:
                    out = op_converter.convert_fused_activation_function(
                        out, fused_activation_fn)
                else:
                    raise tvm.error.OpNotImplemented(
                        'TFLite quantized {} operator with fused activation '
                        'function is not supported yet.'.format(name))

            return out
        return wrapped_f

    return wrap
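
For context on the `quantized_check` branch above: `OperatorConverter.is_quantized` lives elsewhere in this file and, roughly, checks whether the first input tensor carries `qnn_params`. A paraphrased sketch, not part of this diff:

```python
def is_quantized(self, op):
    """Check if the first input tensor of an op is quantized (paraphrased sketch)."""
    first_tensor = self.get_input_tensors(op)[0]
    return first_tensor.qnn_params is not None
```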


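To show how another converter would opt in, here is a hypothetical migration of CONCATENATION, whose `ConcatenationOptions` also carries a fused activation. This sketch is illustrative only (it is not one of this PR's five commits) and skips quantized handling for brevity:

```python
# Hypothetical sketch -- not part of this PR.
# CONCATENATION takes a variable number of inputs, so num_inputs stays None
# and the input tensors are fetched manually inside the body.
@convert_wrapper("CONCATENATION", num_outputs=1,
                 options_class_str='ConcatenationOptions',
                 do_fuse_activation=True)
def convert_concatenation(self, op, output_tensors=None, options=None):
    """Convert TFLite CONCATENATION (float models only, for brevity)"""
    input_tensors = self.get_input_tensors(op)
    in_exprs = [self.get_tensor_expr(t) for t in input_tensors]
    return _op.concatenate(in_exprs, axis=options.Axis())
```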
class OperatorConverter(object):
    """Operator Converter for converting TFLite ops to Relay ops"""
    def __init__(self, model, subgraph, exp_tab):
@@ -466,45 +548,18 @@ def convert_resize_nearest_neighbor(self, op):
"""Convert TFLite RESIZE_NEAREST_NEIGHBOR"""
return self._convert_resize("nearest_neighbor", op)

def convert_l2_normalization(self, op):
@convert_wrapper("L2_NORMALIZATION", num_inputs=1, num_outputs=1,
options_class_str='L2NormOptions', quantized_check=True,
do_fuse_activation=True)
def convert_l2_normalization(self, op, input_tensors=None, output_tensors=None, options=None):
"""Convert TFLite L2_NORMALIZATION """
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.L2NormOptions import L2NormOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

assert op.BuiltinOptionsType() == BuiltinOptions.L2NormOptions
op_options = op.BuiltinOptions()
l2_norm_options = L2NormOptions()
l2_norm_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = l2_norm_options.FusedActivationFunction()

in_expr = self.get_tensor_expr(input_tensor)
# TFLite supports normalization only over the last dim
input_tensor_rank = len(input_tensor.tensor.ShapeAsNumpy())

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFLite quantized L2_NORMALIZATION operator is not supported yet.')

# TFL uses only the default epsilon value
out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1])

# if we have fused activation fn
if output_tensor.qnn_params:
raise tvm.error.OpNotImplemented(
'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
out = self.convert_fused_activation_function(out, fused_activation_fn)

return out

def convert_lrn(self, op):
Expand Down
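
Finally, a quick way to exercise the refactored L2_NORMALIZATION path end to end is through the public frontend entry point. A minimal sketch, assuming a float32 model containing an L2_NORMALIZATION op; the file name, input name, and shape below are illustrative:

```python
import tflite  # flatbuffers-generated package (pip install tflite)
from tvm import relay

# Load a TFLite flatbuffer and hand it to the Relay frontend.
buf = open("model_with_l2_norm.tflite", "rb").read()
model = tflite.Model.GetRootAsModel(buf, 0)
mod, params = relay.frontend.from_tflite(
    model,
    shape_dict={"input": (1, 224, 224, 3)},
    dtype_dict={"input": "float32"})
print(mod)  # the printed Relay module should contain nn.l2_normalize
```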