Skip to content

Commit

Permalink
Fix issue with importing models using Tensorflow Lite 2.4.x schema (a…
Browse files Browse the repository at this point in the history
…pache#8375)

Tensorflow Lite has changed the opcode for BuiltinOperators
to be represented as 32 bit integers instead of 8 bit integers
in the schema.

This is an attempt to fix this in a way that is clean to handle
multiple versions of tensorflow lite in the frontend.
  • Loading branch information
Ramana Radhakrishnan authored and lygztq committed Jul 1, 2021
1 parent 4b9fffc commit 493682d
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,30 @@ def get_op_code_str(self, op):
raise ImportError("The tflite package must be installed")

op_code_list_idx = op.OpcodeIndex()
op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode()

op_c = self.model.OperatorCodes(op_code_list_idx)
# In TFlite 2.4.x there was a change where the type of the field that contained
# the builtin code changed from int8 to int32 in the flat buffer representation.
# However to retain support for old flat buffers that were created, they retained
# the original 8 bit encoding for the operator but in a new field accessed by the
# DeprecatedBuiltinCode method.
# This means that the API function BuiltinCode() is used on an operator
# which was originally encoded as an 8 bit quantity it would look for the
# code in the new int32 field in the schema and this creates the need
# for the check for the magic number of 127 which is indicated by
# BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES
# Remember however that this value came into existence only after Tensorflow
# lite 2.4.x and hence encase it in a try -except block.
# Phew !
try:
if op_c.BuiltinCode() < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES:
opc = op_c.DeprecatedBuiltinCode()
else:
opc = op_c.BuiltinCode()
except AttributeError:
opc = op_c.BuiltinCode()

op_code_id = opc
try:
op_code_str = self.builtin_op_code[op_code_id]
except KeyError:
Expand Down

0 comments on commit 493682d

Please sign in to comment.