Skip to content

Commit

Permalink
feat(tflite): support bias use int64 dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
dingshaohua960303 committed Aug 25, 2022
1 parent 6e19d20 commit 9c1feac
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 1 deletion.
10 changes: 9 additions & 1 deletion bin/convert
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ complete_str = """_mgeconvert(){
return
;;
tracedmodule_to_tflite)
words="-i --input -o --output --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --quantize_file_path --graph_name --mtk --end_point --outspec --remove_relu --prefer_same_pad_mode"
words="-i --input -o --output --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --quantize_file_path --graph_name --mtk --end_point --outspec --remove_relu --prefer_same_pad_mode --use_int64_bias"
COMPREPLY=( $(compgen -W "$words" -- $word) )
return
;;
Expand Down Expand Up @@ -397,6 +397,7 @@ def init(subparsers):
outspec=args.outspec,
remove_relu=args.remove_relu,
prefer_same_pad_mode=args.prefer_same_pad_mode,
use_int64_bias=args.use_int64_bias,
)
else:
mgeconvert.mge_to_tflite(
Expand All @@ -406,6 +407,7 @@ def init(subparsers):
mtk=args.mtk,
outspec=args.outspec,
prefer_same_pad_mode=args.prefer_same_pad_mode,
use_int64_bias=args.use_int64_bias,
)

def tflite_parser(subparsers):
Expand Down Expand Up @@ -488,6 +490,12 @@ def init(subparsers):
help="whether prefer to use SAME pad mode for conv op",
)

p.add_argument(
"--use_int64_bias",
action="store_true",
help="whether use int64 as dtype of bias",
)

tflite_parser(subparsers)


Expand Down
1 change: 1 addition & 0 deletions mgeconvert/backend/ir_to_tflite/tflite_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def get_shape_param(
"int8": TensorType.INT8,
"int16": TensorType.INT16,
"int32": TensorType.INT32,
"int64": TensorType.INT64,
"qint8_narrow": TensorType.INT8,
}

Expand Down
28 changes: 28 additions & 0 deletions mgeconvert/converter_ir/ir_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
TanHOpr,
TransposeOpr,
TrueDivOpr,
_ConvOpr,
_PoolOpr,
)
from .ir_tensor import AxisOrder, IRTensor
Expand Down Expand Up @@ -121,6 +122,7 @@ class TransformerRule(Enum):
TRANSPOSE_LINEAR_WEIGHT_TO_NHWC = 133
# force fc with no trans for megengine
FC_NO_TRANS = 134
BIAS_ASTYPE_INT64 = 135


def cmp_rules(a, b):
Expand Down Expand Up @@ -1537,3 +1539,29 @@ def trans_tensor(tensor):
if opr.transpose_b and tensor_b.owner_opr == None:
opr.transpose_b = False
trans_tensor(tensor_b)


@_register_tranformation_rule(TransformerRule.BIAS_ASTYPE_INT64)
def _bias_astype_int64(net: IRGraph):
for opr in net.all_oprs:
if not isinstance(opr, (MatMulOpr, _ConvOpr)):
continue
bias = None
if isinstance(opr, MatMulOpr) and len(opr.inp_tensors) == 3:
bias = opr.inp_tensors[2]
elif isinstance(opr, Deconv2dOpr) and len(opr.inp_tensors) > 3:
if (
opr.inp_tensors[0].shape == [4] and len(opr.inp_tensors) == 4
): # shape as input
bias = opr.inp_tensors[-1]
if len(opr.inp_tensors) == 4:
bias = opr.inp_tensors[-1]
elif isinstance(opr, (Conv2dOpr, ConvRelu2dOpr)) and len(opr.inp_tensors) == 3:
bias = opr.inp_tensors[-1]
if bias is not None and bias.scale is not None:
bias.set_qparams(
scale=bias.scale,
zero_point=bias.zero_point,
q_dtype="int64",
np_dtype="int64",
)
4 changes: 4 additions & 0 deletions mgeconvert/converters/mge_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def mge_to_tflite(
disable_nhwc=False,
outspec=None,
prefer_same_pad_mode=False,
use_int64_bias=False,
):
"""
Convert megengine model to TFLite,
Expand Down Expand Up @@ -70,6 +71,9 @@ def mge_to_tflite(
transformer_options.append(TransformerRule.DECONV_ADD_ZERO_BIAS,)
transformer_options.append(TransformerRule.FUSE_FOR_DECONV_BIAS,)

if use_int64_bias:
transformer_options.append(TransformerRule.BIAS_ASTYPE_INT64)

transformer = IRTransform(transformer_options)
transformed_irgraph = transformer.transform(irgraph)

Expand Down
3 changes: 3 additions & 0 deletions mgeconvert/converters/tm_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def tracedmodule_to_tflite(
remove_relu=False,
prefer_same_pad_mode=False,
disable_nhwc=False,
use_int64_bias=False,
):
"""
Convert traced model to TFLite,
Expand Down Expand Up @@ -89,6 +90,8 @@ def tracedmodule_to_tflite(
transformer_options.append(TransformerRule.DECONV_ADD_ZERO_BIAS,)
if remove_relu:
transformer_options.append(TransformerRule.REMOVE_TFLITE_RELU,)
if use_int64_bias:
transformer_options.append(TransformerRule.BIAS_ASTYPE_INT64)

transformer = IRTransform(transformer_options)
transformed_irgraph = transformer.transform(irgraph)
Expand Down

0 comments on commit 9c1feac

Please sign in to comment.