[TFLite] Add option to overwrite OperatorConverter class in relay.frontend.from_tflite (#9256)

* [TFLite] Relay Frontend: Add option to overwrite OperatorConverter class

This makes it possible to overwrite the mapping from TFLite operators to TVM Relay operators from external Python scripts. This has the following advantages:
- Adds support for unsupported builtin or even custom operators through a hand-written convert function
- Enables overwriting existing convert functions for supported operators with alternative implementations (useful for currently unsupported edge cases)

Example Usage:

```
class CustomOperatorConverter(relay.frontend.tflite.OperatorConverter):

    def __init__(self, model, subgraph, exp_tab):
        super(CustomOperatorConverter, self).__init__(model, subgraph, exp_tab)
        convert_map_overwrite = {"SUB": self.convert_sub_custom}
        self.convert_map.update(convert_map_overwrite)

    def convert_sub_custom(self, op):
        ...
...
relay_mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=CustomOperatorConverter
)
```
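
The body of convert_sub_custom is elided above. A convert function typically follows the same pattern as the built-in converters (and as the convert_sub_dummy added in the test case below): fetch the op's input tensors, look up their Relay expressions, and return the replacement expression. A minimal sketch of such a body:

```
    def convert_sub_custom(self, op):
        """Example: lower TFLite SUB as a + (-b) instead of a - b."""
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        # Look up the Relay expressions feeding this op and build a replacement.
        lhs_expr = self.get_expr(input_tensors[0].tensor_idx)
        rhs_expr = self.get_expr(input_tensors[1].tensor_idx)
        return relay.op.add(lhs_expr, relay.op.negative(rhs_expr))
```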

* [TFLite] Make sure that even the DETECTION_POSTPROCESS op can be overwritten

This is desirable because the current implementation of this CUSTOM op is incompatible with MicroTVM targets.
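
With the new allow_custom_ops flag set, get_op_code_str returns the generic "CUSTOM" code instead of special-casing TFLite_Detection_PostProcess, so a subclass can dispatch every custom op to its own handler registered in convert_map. A minimal sketch (the CustomOpConverter class and its convert_custom handler are illustrative, not part of this change):

```
class CustomOpConverter(relay.frontend.tflite.OperatorConverter):

    def __init__(self, model, subgraph, exp_tab):
        super(CustomOpConverter, self).__init__(model, subgraph, exp_tab)
        # Report all custom ops as "CUSTOM" instead of using the built-in
        # DETECTION_POSTPROCESS lowering, and route them to a hand-written handler.
        self.allow_custom_ops = True
        self.convert_map.update({"CUSTOM": self.convert_custom})

    def convert_custom(self, op):
        # Illustrative placeholder: inspect the op's custom code and options here
        # and return the corresponding Relay expression(s).
        raise NotImplementedError("custom op lowering goes here")
```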

* Tests: added test case for overwriting op_converter in TFLite relay frontend

The test is kept as simple as possible by comparing only two different
implementations of the TFLite SUB operator:

1. Original: c = a - b
2. Dummy: c = a + (-b)

A comparison with the TFLite reference output is not necessary because this is
already covered by other test cases. Instead, the outputs of the two TVM
models are compared against each other.
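
In Relay terms, the two implementations compared by the test correspond to, for example (a free-standing illustration, not part of the test itself):

```
from tvm import relay

a = relay.var("a", shape=(2, 1, 1, 3), dtype="float32")
b = relay.var("b", shape=(2, 1, 1, 3), dtype="float32")

original = relay.subtract(a, b)          # 1. Original: c = a - b
dummy = relay.add(a, relay.negative(b))  # 2. Dummy:    c = a + (-b)
```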
PhilippvK authored Oct 14, 2021
1 parent 08018ea commit f4db899
Showing 2 changed files with 78 additions and 3 deletions.
9 changes: 7 additions & 2 deletions python/tvm/relay/frontend/tflite.py
@@ -66,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab):
         self.activation_fn_type = build_str_map(ActivationFunctionType())
         self.builtin_options = build_str_map(BuiltinOptions())
         self.prefetched_nodes = {}
+        self.allow_custom_ops = False

         # Add more operators
         self.convert_map = {
@@ -287,6 +288,10 @@ def get_op_code_str(self, op):
         if op_code_id == BuiltinOperator.CUSTOM:
             # Custom operator
             custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode()
+
+            if self.allow_custom_ops:
+                return "CUSTOM"
+
             if custom_op_code_str == b"TFLite_Detection_PostProcess":
                 return "DETECTION_POSTPROCESS"

@@ -3695,7 +3700,7 @@ def _input_type(model):
     return shape_dict, dtype_dict


-def from_tflite(model, shape_dict=None, dtype_dict=None):
+def from_tflite(model, shape_dict=None, dtype_dict=None, op_converter=OperatorConverter):
     """Convert from tflite model into compatible relay Function.

     Parameters
@@ -3755,7 +3760,7 @@ def from_tflite(model, shape_dict=None, dtype_dict=None):
         exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))

     # op code in model
-    op_converter = OperatorConverter(model, subgraph, exp_tab)
+    op_converter = op_converter(model, subgraph, exp_tab)
     op_converter.check_unsupported_ops()
     op_converter.convert_op_to_relay()
72 changes: 71 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
@@ -161,6 +161,7 @@ def run_tvm_graph(
     target="llvm",
     out_names=None,
     mode="graph_executor",
+    op_converter=relay.frontend.tflite.OperatorConverter,
 ):
     """Generic function to compile on relay and execute on tvm"""
     # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
@@ -185,7 +186,7 @@ def run_tvm_graph(
         dtype_dict[e] = input_data[i].dtype.name

     mod, params = relay.frontend.from_tflite(
-        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=op_converter
     )

     if mode in ["debug", "vm"]:
@@ -3996,6 +3997,72 @@ def test_detection_postprocess():
     )


+#######################################################################
+# Custom Converter
+# ----------------
+
+
+def test_custom_op_converter():
+    """Test case for user-defined operator converter in TFLite frontend"""
+
+    class DummyOperatorConverter(relay.frontend.tflite.OperatorConverter):
+        """Operator Converter for converting TFLite ops to relay ops"""
+
+        def __init__(self, model, subgraph, exp_tab):
+            super(DummyOperatorConverter, self).__init__(model, subgraph, exp_tab)
+            self.allow_custom_ops = True
+
+            convert_map_overwrite = {"SUB": self.convert_sub_dummy}
+
+            self.convert_map.update(convert_map_overwrite)
+
+        def convert_sub_dummy(self, op):
+            """Convert TFLite SUB"""
+            input_tensors = self.get_input_tensors(op)
+            assert len(input_tensors) == 2, "input tensors length should be 2"
+
+            lhs_tensor = input_tensors[0]
+            rhs_tensor = input_tensors[1]
+
+            lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
+            rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
+
+            temp_expr = relay.op.negative(rhs_expr)
+            out = relay.op.add(lhs_expr, temp_expr)
+
+            return out
+
+    with tf.Graph().as_default():
+        # Generate TFLite model for single addition
+        data = [
+            np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
+        ]
+        in_data = [
+            array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"),
+            array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1"),
+        ]
+        out = math_ops.subtract(in_data[0], in_data[1])
+        in_name = [x[1] for x in zip(in_data, ("in_0:0", "in_1:0"))]
+        input_tensors = [x for x in in_data]
+        output_tensors = [out]
+        in_node = [0] * len(in_name)
+        for i in range(len(in_name)):
+            in_node[i] = in_name[i].split(":")[0] if ":" in in_name[i] else in_name[i]
+
+        with tf.Session() as sess:
+            converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+            tflite_model_buf = converter.convert()
+    in_data = [x[1] for x in zip(in_data, data)]
+    tvm_output_orig = run_tvm_graph(tflite_model_buf, in_data, in_node)
+    tvm_output_dummy = run_tvm_graph(
+        tflite_model_buf, in_data, in_node, op_converter=DummyOperatorConverter
+    )
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output_orig[0]), np.squeeze(tvm_output_dummy[0]), rtol=1e-5, atol=1e-5
+    )
+
+
 #######################################################################
 # Mobilenet
 # ---------
@@ -4621,6 +4688,9 @@ def test_prevent_tensorflow_dynamic_range():
     # Detection_PostProcess
     test_detection_postprocess()

+    # Overwrite Converter
+    test_custom_op_converter()
+
     # End to End
     test_forward_mobilenet_v1()
     test_forward_mobilenet_v2()
