diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 7e21739432658..a47fdf0141b51 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -251,7 +251,30 @@ def get_op_code_str(self, op): raise ImportError("The tflite package must be installed") op_code_list_idx = op.OpcodeIndex() - op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode() + + op_c = self.model.OperatorCodes(op_code_list_idx) + # In TFlite 2.4.x there was a change where the type of the field that contained + # the builtin code changed from int8 to int32 in the flat buffer representation. + # However to retain support for old flat buffers that were created, they retained + # the original 8 bit encoding for the operator but in a new field accessed by the + # DeprecatedBuiltinCode method. + # This means that the API function BuiltinCode() is used on an operator + # which was originally encoded as an 8 bit quantity it would look for the + # code in the new int32 field in the schema and this creates the need + # for the check for the magic number of 127 which is indicated by + # BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES + # Remember however that this value came into existence only after Tensorflow + # lite 2.4.x and hence encase it in a try -except block. + # Phew ! + try: + if op_c.BuiltinCode() < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES: + opc = op_c.DeprecatedBuiltinCode() + else: + opc = op_c.BuiltinCode() + except AttributeError: + opc = op_c.BuiltinCode() + + op_code_id = opc try: op_code_str = self.builtin_op_code[op_code_id] except KeyError: