Flexbuffer parsing
anijain2305 committed May 2, 2020
1 parent 9684169 commit 22ac689
Showing 3 changed files with 225 additions and 144 deletions.
175 changes: 44 additions & 131 deletions python/tvm/relay/frontend/tflite.py
@@ -31,6 +31,7 @@
from ... import nd as _nd
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .tflite_flexbuffer import FlexBufferDecode

__all__ = ['from_tflite']

@@ -330,25 +331,30 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn,
except ImportError:
raise ImportError("The tflite package must be installed")

# Quantize a float value to an integer
quantize = lambda value : (value / scale) + zero_point
# Quantize a float value to a quantized integer value
quantize = lambda x: float(int(round(x / scale)) + zero_point)

# Get min/max of the output dtype. This is used to ensure that the clip a_min/a_max values do
# not go beyond the output dtype range.
qmin = float(tvm.tir.op.min_value(dtype).value)
qmax = float(tvm.tir.op.max_value(dtype).value)

# The input expr is a quantized tensor with its own scale and zero point. We calculate
# suitable clip-off points based on that scale and zero point.
if fused_activation_fn == ActivationFunctionType.NONE:
return expr
elif fused_activation_fn == ActivationFunctionType.RELU6:
return _op.clip(expr,
a_min=quantize(0),
a_max=quantize(6))
a_min=max(qmin, quantize(0)),
a_max=min(qmax, quantize(6.0)))
elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
return _op.clip(expr,
a_min=quantize(-1),
a_max=quantize(1))
a_min=max(qmin, quantize(-1.0)),
a_max=min(qmax, quantize(1.0)))
elif fused_activation_fn == ActivationFunctionType.RELU:
return _op.clip(expr,
a_min=quantize(0),
a_max=float(tvm.tir.op.min_value(dtype).value))
a_min=max(qmin, quantize(0.0)),
a_max=qmax)

fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
raise tvm.error.OpNotImplemented(
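
For illustration (not part of the diff), here is how the new clip bounds behave for a RELU6 fused activation, using hypothetical uint8 quantization parameters scale=0.1 and zero_point=128:

    # Hypothetical quantization parameters, for illustration only.
    scale, zero_point = 0.1, 128
    qmin, qmax = 0.0, 255.0  # uint8 dtype range
    quantize = lambda x: float(int(round(x / scale)) + zero_point)
    a_min = max(qmin, quantize(0.0))  # 128.0
    a_max = min(qmax, quantize(6.0))  # min(255.0, 188.0) = 188.0

With a smaller scale such as 0.02, quantize(6.0) would be 428.0, and the min() against qmax keeps a_max inside the uint8 range.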
@@ -1432,21 +1438,24 @@ def convert_fully_connected(self, op):
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')

# Call activation function
out = self.convert_qnn_fused_activation_function(\
expr=out,
fused_activation_fn=fused_activation_fn,
scale=new_input_scale_val,
zero_point=0,
dtype='int32')

# Requantize
out = _qnn.op.requantize(out,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
out = self.convert_qnn_fused_activation_function(\
expr=out,
fused_activation_fn=fused_activation_fn,
scale=output_scale_val,
zero_point=output_zero_point_val,
dtype=output_tensor_type_str)

else:
out = self.convert_fused_activation_function(out, fused_activation_fn)

@@ -1645,21 +1654,23 @@ def convert_conv(self, op, conv_type):
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')

# Call activation function
out = self.convert_qnn_fused_activation_function(\
expr=out,
fused_activation_fn=fused_activation_fn,
scale=new_input_scale_val,
zero_point=0,
dtype='int32')

# Finally requantize
out = _qnn.op.requantize(out,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
out = self.convert_qnn_fused_activation_function(\
expr=out,
fused_activation_fn=fused_activation_fn,
scale=output_scale_val,
zero_point=output_zero_point_val,
dtype=output_tensor_type_str)
else:
out = self.convert_fused_activation_function(out, fused_activation_fn)

@@ -2302,28 +2313,15 @@ def convert_transpose_conv(self, op):

def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
_option_names = [
"w_scale",
"max_detections",
"_output_quantized",
"detections_per_class",
"x_scale",
"nms_score_threshold",
"num_classes",
"max_classes_per_detection",
"use_regular_nms",
"y_scale",
"h_scale",
"_support_output_type_float_in_quantized_op",
"nms_iou_threshold"
]

custom_options = get_custom_options(op, _option_names)
if custom_options["use_regular_nms"]:
raise tvm.error.OpAttributeUnImplemented(
"use_regular_nms=True is not yet supported for operator {}."
.format("TFLite_Detection_PostProcess")
)
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecode(flexbuffer).decode()

if "use_regular_nms" in custom_options:
if custom_options["use_regular_nms"]:
raise tvm.error.OpAttributeUnImplemented(
"use_regular_nms=True is not yet supported for operator {}."
.format("TFLite_Detection_PostProcess")
)

inputs = self.get_input_tensors(op)
assert len(inputs) == 3, "inputs length should be 3"
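
For reference (not part of the diff), the decoded custom options for TFLite_Detection_PostProcess form a flat dict whose keys match the option names listed in the removed get_custom_options helper (plus a couple of internal flags such as _output_quantized); the values below are purely illustrative:

    # Illustrative decoded options; keys follow the TFLite op, values are made up.
    custom_options = {
        "max_detections": 10,
        "max_classes_per_detection": 1,
        "detections_per_class": 100,
        "num_classes": 90,
        "nms_score_threshold": 0.3,
        "nms_iou_threshold": 0.6,
        "use_regular_nms": False,
        "y_scale": 10.0,
        "x_scale": 10.0,
        "h_scale": 5.0,
        "w_scale": 5.0,
    }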
@@ -2494,91 +2492,6 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def get_custom_options(op, option_names):
"""Get the options of a custom operator.
This implements partial flexbuffer deserialization to be able
to read custom options. It is not intended to be a general
purpose flexbuffer deserializer and as such only supports a
limited number of types and assumes the data is a flat map.
Parameters
----------
op:
A custom TFlite operator.
option_names: list
A complete list of the custom option names.
Returns
-------
options: dict
A dictionary of the custom options.
"""
import struct
from enum import IntEnum

class _FlexBufferType(IntEnum):
"""Flexbuffer type schema from flexbuffers.h"""
FBT_NULL = 0
FBT_INT = 1
FBT_UINT = 2
FBT_FLOAT = 3
# Types above stored inline, types below store an offset.
FBT_KEY = 4
FBT_STRING = 5
FBT_INDIRECT_INT = 6
FBT_INDIRECT_UINT = 7
FBT_INDIRECT_FLOAT = 8
FBT_MAP = 9
FBT_VECTOR = 10 # Untyped.
FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
FBT_VECTOR_UINT = 12
FBT_VECTOR_FLOAT = 13
FBT_VECTOR_KEY = 14
FBT_VECTOR_STRING = 15
FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
FBT_VECTOR_UINT2 = 17
FBT_VECTOR_FLOAT2 = 18
FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
FBT_VECTOR_UINT3 = 20
FBT_VECTOR_FLOAT3 = 21
FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
FBT_VECTOR_UINT4 = 23
FBT_VECTOR_FLOAT4 = 24
FBT_BLOB = 25
FBT_BOOL = 26
FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type

buffer = op.CustomOptionsAsNumpy().tobytes()
value_vector_offset = buffer[-3]
buffer = buffer[:-3]
num_bytes = 4 # Assume all values are stored in 32 bit width
value_vector_size = struct.unpack(
"<i", buffer[-value_vector_offset - num_bytes:-value_vector_offset]
)[0]
type_offset = value_vector_size
types = buffer[-type_offset:]
values = []
for i, t in enumerate(types):
flex_type = _FlexBufferType(t >> 2)
value_offset = -value_vector_offset + i*num_bytes
value_bytes = buffer[value_offset:value_offset+num_bytes]
if flex_type == _FlexBufferType.FBT_BOOL:
value = bool(value_bytes[0])
if flex_type == _FlexBufferType.FBT_INT:
value = struct.unpack("<i", value_bytes)[0]
if flex_type == _FlexBufferType.FBT_UINT:
value = struct.unpack("<I", value_bytes)[0]
if flex_type == _FlexBufferType.FBT_FLOAT:
value = struct.unpack("<f", value_bytes)[0]

values.append(value)

custom_options = dict(zip(sorted(option_names), values))
return custom_options


def from_tflite(model, shape_dict, dtype_dict):
"""Convert from tflite model into compatible relay Function.
154 changes: 154 additions & 0 deletions python/tvm/relay/frontend/tflite_flexbuffer.py
@@ -0,0 +1,154 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend helper to parse custom options in Flexbuffer format."""

import struct
from enum import IntEnum

class BitWidth(IntEnum):
"""Flexbuffer bit width schema from flexbuffers.h"""
BIT_WIDTH_8 = 0
BIT_WIDTH_16 = 1
BIT_WIDTH_32 = 2
BIT_WIDTH_64 = 3

class FlexBufferType(IntEnum):
"""Flexbuffer type schema from flexbuffers.h"""
FBT_NULL = 0
FBT_INT = 1
FBT_UINT = 2
FBT_FLOAT = 3
# Types above stored inline, types below store an offset.
FBT_KEY = 4
FBT_STRING = 5
FBT_INDIRECT_INT = 6
FBT_INDIRECT_UINT = 7
FBT_INDIRECT_FLOAT = 8
FBT_MAP = 9
FBT_VECTOR = 10 # Untyped.
FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
FBT_VECTOR_UINT = 12
FBT_VECTOR_FLOAT = 13
FBT_VECTOR_KEY = 14
FBT_VECTOR_STRING = 15
FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
FBT_VECTOR_UINT2 = 17
FBT_VECTOR_FLOAT2 = 18
FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
FBT_VECTOR_UINT3 = 20
FBT_VECTOR_FLOAT3 = 21
FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
FBT_VECTOR_UINT4 = 23
FBT_VECTOR_FLOAT4 = 24
FBT_BLOB = 25
FBT_BOOL = 26
FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type


class FlexBufferDecode(object):
"""
This implements partial flexbuffer deserialization to be able
to read custom options. It is not intended to be a general
purpose flexbuffer deserializer and as such only supports a
limited number of types and assumes the data is a flat map.
"""

def __init__(self, buffer):
self.buffer = buffer

def indirect_jump(self, offset, byte_width):
""" Helper function to read the offset value and jump """
unpack_str = ""
if byte_width == 1:
unpack_str = "<B"
elif byte_width == 4:
unpack_str = "<i"
assert unpack_str != ""
back_jump = struct.unpack(unpack_str,
self.buffer[offset: offset + byte_width])[0]
return offset - back_jump

def decode_keys(self, end, size, byte_width):
""" Decodes the flexbuffer type vector. Map keys are stored in this form """
# Keys are strings here. The format is all strings seperated by null, followed by back
# offsets for each of the string. For example, (str1)\0(str1)\0(offset1)(offset2) The end
# pointer is pointing at the end of all strings
keys = list()
for i in range(0, size):
offset_pos = end + i * byte_width
start_index = self.indirect_jump(offset_pos, byte_width)
str_size = self.buffer[start_index:].find(b"\0")
assert str_size != -1
s = self.buffer[start_index: start_index + str_size].decode("utf-8")
keys.append(s)
return keys

def decode_vector(self, end, size, byte_width):
""" Decodes the flexbuffer vector """
# Each entry in the vector can have a different datatype, but every entry has a fixed length.
# The format is the sequence of all values followed by the sequence of their datatypes. For
# example - (4)(3.56)(int)(float). The end here points to the start of the values.
values = list()
for i in range(0, size):
value_type_pos = end + size * byte_width + i
value_type = FlexBufferType(self.buffer[value_type_pos] >> 2)
value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width]
if value_type == FlexBufferType.FBT_BOOL:
value = bool(value_bytes[0])
elif value_type == FlexBufferType.FBT_INT:
value = struct.unpack("<i", value_bytes)[0]
elif value_type == FlexBufferType.FBT_UINT:
value = struct.unpack("<I", value_bytes)[0]
elif value_type == FlexBufferType.FBT_FLOAT:
value = struct.unpack("<f", value_bytes)[0]
else:
raise ValueError("Unsupported flexbuffer value type: {}".format(value_type))
values.append(value)
return values

def decode_map(self, end, byte_width, parent_byte_width):
""" Decodes the flexbuffer map and returns a dict """
mid_loc = self.indirect_jump(end, parent_byte_width)
map_size = struct.unpack("<i", self.buffer[mid_loc - byte_width:mid_loc])[0]

# Find keys
keys_offset = mid_loc - byte_width * 3
keys_end = self.indirect_jump(keys_offset, byte_width)
keys_byte_width = struct.unpack(\
"<i",
self.buffer[keys_offset + byte_width:keys_offset + 2 * byte_width:])[0]
keys = self.decode_keys(keys_end, map_size, 1)

# Find values
values_end = self.indirect_jump(end, parent_byte_width)
values = self.decode_vector(values_end, map_size, byte_width)
return dict(zip(keys, values))

def decode(self):
root_end = len(self.buffer) - 1
root_byte_width = self.buffer[root_end]
root_end -= 1
root_packed_type = self.buffer[root_end]
root_end -= root_byte_width

root_type = FlexBufferType(root_packed_type >> 2)
byte_width = 1 << BitWidth(root_packed_type & 3)

if root_type == FlexBufferType.FBT_MAP:
return self.decode_map(root_end, byte_width, root_byte_width)
raise NotImplementedError("Flexbuffer decoding is only partially implemented.")
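
As a usage sketch (not part of the diff), the decoder is driven by the packed root type byte: the upper six bits select the FlexBufferType and the lower two bits select the BitWidth. A worked example, with a hypothetical packed byte:

    # 0x25 is a hypothetical packed type byte, shown only to illustrate the unpacking.
    packed_type = 0x25
    root_type = FlexBufferType(packed_type >> 2)  # 37 >> 2 = 9 -> FBT_MAP
    byte_width = 1 << BitWidth(packed_type & 3)   # 37 & 3 = 1 -> 2-byte wide values

    # Typical call site, mirroring the TFLite frontend change above:
    # custom_options = FlexBufferDecode(op.CustomOptionsAsNumpy().tobytes()).decode()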