diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a66fc4736a98..3688ff5ff4e5 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -66,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab): self.activation_fn_type = build_str_map(ActivationFunctionType()) self.builtin_options = build_str_map(BuiltinOptions()) self.prefetched_nodes = {} + self.allow_custom_ops = False # Add more operators self.convert_map = { @@ -287,6 +288,10 @@ def get_op_code_str(self, op): if op_code_id == BuiltinOperator.CUSTOM: # Custom operator custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode() + + if self.allow_custom_ops: + return "CUSTOM" + if custom_op_code_str == b"TFLite_Detection_PostProcess": return "DETECTION_POSTPROCESS" @@ -3695,7 +3700,7 @@ def _input_type(model): return shape_dict, dtype_dict -def from_tflite(model, shape_dict=None, dtype_dict=None): +def from_tflite(model, shape_dict=None, dtype_dict=None, op_converter=OperatorConverter): """Convert from tflite model into compatible relay Function. Parameters @@ -3755,7 +3760,7 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model - op_converter = OperatorConverter(model, subgraph, exp_tab) + op_converter = op_converter(model, subgraph, exp_tab) op_converter.check_unsupported_ops() op_converter.convert_op_to_relay() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 4a6f88417b9c..754976ca8c13 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -161,6 +161,7 @@ def run_tvm_graph( target="llvm", out_names=None, mode="graph_executor", + op_converter=relay.frontend.tflite.OperatorConverter, ): """Generic function to compile on relay and execute on tvm""" # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 @@ -185,7 +186,7 @@ def run_tvm_graph( dtype_dict[e] = input_data[i].dtype.name mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=op_converter ) if mode in ["debug", "vm"]: @@ -3996,6 +3997,72 @@ def test_detection_postprocess(): ) +####################################################################### +# Custom Converter +# ---------------- + + +def test_custom_op_converter(): + """Test case for user-defined operator converter in TFLite frontend""" + + class DummyOperatorConverter(relay.frontend.tflite.OperatorConverter): + """Operator Converter for converting TFLite ops to relay ops""" + + def __init__(self, model, subgraph, exp_tab): + super(DummyOperatorConverter, self).__init__(model, subgraph, exp_tab) + self.allow_custom_ops = True + + convert_map_overwrite = {"SUB": self.convert_sub_dummy} + + self.convert_map.update(convert_map_overwrite) + + def convert_sub_dummy(self, op): + """Convert TFLite SUB""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + lhs_tensor = input_tensors[0] + rhs_tensor = input_tensors[1] + + lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + + temp_expr = relay.op.negative(rhs_expr) + out = relay.op.add(lhs_expr, temp_expr) + + return out + + with tf.Graph().as_default(): + # Generate TFLite model for single addition + data = [ + np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), + np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)), + ] + in_data = [ + array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"), + array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1"), + ] + out = math_ops.subtract(in_data[0], in_data[1]) + in_name = [x[1] for x in zip(in_data, ("in_0:0", "in_1:0"))] + input_tensors = [x for x in in_data] + output_tensors = [out] + in_node = [0] * len(in_name) + for i in range(len(in_name)): + in_node[i] = in_name[i].split(":")[0] if ":" in in_name[i] else in_name[i] + + with tf.Session() as sess: + converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors) + tflite_model_buf = converter.convert() + in_data = [x[1] for x in zip(in_data, data)] + tvm_output_orig = run_tvm_graph(tflite_model_buf, in_data, in_node) + tvm_output_dummy = run_tvm_graph( + tflite_model_buf, in_data, in_node, op_converter=DummyOperatorConverter + ) + tvm.testing.assert_allclose( + np.squeeze(tvm_output_orig[0]), np.squeeze(tvm_output_dummy[0]), rtol=1e-5, atol=1e-5 + ) + + ####################################################################### # Mobilenet # --------- @@ -4621,6 +4688,9 @@ def test_prevent_tensorflow_dynamic_range(): # Detection_PostProcess test_detection_postprocess() + # Overwrite Converter + test_custom_op_converter() + # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2()