Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Make tflite frontend more data driven / improve errors. #5519

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def __init__(self, model, subgraph, exp_tab):
self.activation_fn_type = build_str_map(ActivationFunctionType())
self.builtin_options = build_str_map(BuiltinOptions())

# Op-> tuple(op_parser, number_of_input_tensors,
# number_of_output_tensors)
# Adding new operators involves an entry in this table with
# the number of input and output tensors. It is simpler to
# assume that the check for the length of the input and output
# tensors needs to be equality for now as this covers the vast
# majority of cases. Thus -1 indicates that checks that
# require other relational expressions are handled in the
# op_parser method.

self.convert_map_data_driven = {
'DEPTH_TO_SPACE': (self.convert_depth_to_space, 1, 1),
'SOFTMAX': (self.convert_softmax, 1, 1),
}
# Add more operators
self.convert_map = {
'ABS': self.convert_abs,
Expand All @@ -71,7 +85,6 @@ def __init__(self, model, subgraph, exp_tab):
'CONCATENATION': self.convert_concatenation,
'CONV_2D': self.convert_conv2d,
'COS': self.convert_cos,
'DEPTH_TO_SPACE': self.convert_depth_to_space,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'DIV': self.convert_div,
Expand Down Expand Up @@ -121,7 +134,6 @@ def __init__(self, model, subgraph, exp_tab):
'RSQRT': self.convert_rsqrt,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'SPACE_TO_DEPTH': self.convert_space_to_depth,
'SPLIT': self.convert_split,
Expand Down Expand Up @@ -150,7 +162,8 @@ def check_unsupported_ops(self):
for op_idx in range(self.subgraph.OperatorsLength()):
op = self.subgraph.Operators(op_idx)
op_code_str = self.get_op_code_str(op)
if op_code_str not in self.convert_map:
if (op_code_str not in self.convert_map_data_driven)\
and (op_code_str not in self.convert_map):
unsupported_ops_set.add(op_code_str)

if unsupported_ops_set:
Expand All @@ -171,7 +184,21 @@ def convert_op_to_relay(self):
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
ret = self.convert_map[op_code_str](op)
input_tensors = self.get_input_tensors(op)

try:
(func, num_inputs, num_outputs) = self.convert_map_data_driven[op_code_str]
if num_inputs != -1:
print("test")
assert (len(input_tensors) == num_inputs)\
, "input tensors should be %d" % num_inputs
if num_outputs != -1:
assert (len(output_tensors) == num_outputs)\
, "output tensors should be %d" % num_outputs
ret = func(op, input_tensors, output_tensors)
except KeyError:
func = self.convert_map[op_code_str]
ret = func(op)

if len(output_tensors) == 1:
tensor_idx = output_tensors[0].tensor_idx
Expand Down Expand Up @@ -522,16 +549,11 @@ def convert_logistic(self, op):

return out

def convert_softmax(self, op):
def convert_softmax(self, op, input_tensors, output_tensors):
"""Convert TFLite softmax"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

params = {'axis': 1} # 1 is channel
Expand Down Expand Up @@ -2101,17 +2123,14 @@ def convert_space_to_batch_nd(self, op):

return reshaped_permuted_reshaped_padded

def convert_depth_to_space(self, op):
def convert_depth_to_space(self, op, input_tensors, output_tensors):
"""Convert TFLite DEPTH_TO_SPACE"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.DepthToSpaceOptions import DepthToSpaceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

Expand Down Expand Up @@ -2201,8 +2220,10 @@ def convert_transpose_conv(self, op):
padding = deconv_options.Padding()
stride_h = deconv_options.StrideH()
stride_w = deconv_options.StrideW()
assert padding in (Padding.VALID, Padding.SAME), \
'Padding format {} is not supported for operator TRANSPOSE_CONV'.format(padding)

if padding not in (Padding.VALID, Padding.SAME):
raise tvm.error.OpAttributeUnImplemented('Padding format {} is not supported'\
'for operator TRANSPOSE_CONV'.format(padding))

# Data
in_expr = self.get_expr(input_tensor.tensor_idx)
Expand Down