diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index 30d544694fba5..7b4fcd34af890 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -10,3 +10,4 @@
 from .mxnet import from_mxnet
 from .keras import from_keras
 from .onnx import from_onnx
+from .tflite import from_tflite
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
new file mode 100644
index 0000000000000..3c926d915fdd4
--- /dev/null
+++ b/python/tvm/relay/frontend/tflite.py
@@ -0,0 +1,629 @@
+# pylint: disable=invalid-name, unused-argument
+"""TensorFlow Lite frontend."""
+from __future__ import absolute_import as _abs
+import math
+import numpy as np
+from .. import ir_pass
+from .. import expr as _expr
+from .. import op as _op
+from ... import nd as _nd
+from .common import ExprTable
+
+__all__ = ['from_tflite']
+
+class TensorWrapper(object):
+    """Tensor wrapper for a TFLite tensor"""
+    def __init__(self, tensor_idx, tensor, buffer):
+        self.tensor_idx = tensor_idx
+        self.tensor = tensor
+        self.buffer = buffer
+
+class OperatorConverter(object):
+    """Operator converter for converting TFLite ops to Relay ops"""
+    def __init__(self, model, subgraph, exp_tab):
+        try:
+            from tflite.BuiltinOperator import BuiltinOperator
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ActivationFunctionType import ActivationFunctionType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        self.model = model
+        self.subgraph = subgraph
+        self.exp_tab = exp_tab
+        self.builtin_op_code = build_str_map(BuiltinOperator())
+        self.activation_fn_type = build_str_map(ActivationFunctionType())
+        self.builtin_options = build_str_map(BuiltinOptions())
+        self.convert_map = {
+            'CONV_2D': self.convert_conv2d,
+            'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
+            'AVERAGE_POOL_2D': self.convert_average_pool2d,
+            'RESHAPE': self.convert_reshape,
+            'SOFTMAX': self.convert_softmax,
+            'SQUEEZE': self.convert_squeeze,
+            'MAX_POOL_2D': self.convert_max_pool2d,
+            # Add more operators
+        }
+
+    def check_unsupported_ops(self):
+        """Check for TFLite ops that this converter does not support."""
+        unsupported_ops_set = set()
+
+        for op_idx in range(self.subgraph.OperatorsLength()):
+            op = self.subgraph.Operators(op_idx)
+            op_code_str = self.get_op_code_str(op)
+            if op_code_str not in self.convert_map:
+                unsupported_ops_set.add(op_code_str)
+
+        if unsupported_ops_set:
+            raise NotImplementedError("Unsupported ops: %s" % (
+                ','.join(unsupported_ops_set)))
+
+    def convert_op_to_relay(self):
+        """Convert TFLite ops to Relay ops"""
+        for op_idx in range(self.subgraph.OperatorsLength()):
+            op = self.subgraph.Operators(op_idx)
+            op_code_str = self.get_op_code_str(op)
+            output_tensors = self.get_output_tensors(op)
+
+            ret = self.convert_map[op_code_str](op)
+
+            if len(output_tensors) == 1:
+                tensor_idx = output_tensors[0].tensor_idx
+                self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
+            else:
+                for idx, output_tensor in enumerate(output_tensors):
+                    self.exp_tab.set_expr(get_tensor_name(self.subgraph, output_tensor.tensor_idx),
+                                          ret[idx])
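+
+    # For example (with a hypothetical tensor name): an op whose single output
+    # is tensor 5, named "conv/Relu6", gets its converted Relay expression
+    # stored in the expression table under "conv/Relu6"; downstream ops then
+    # look up their inputs by the same tensor names.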
+
+    def get_op_code_str(self, op):
+        """Get the string representation of a TFLite op"""
+        try:
+            from tflite.BuiltinOperator import BuiltinOperator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        op_code_list_idx = op.OpcodeIndex()
+        op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode()
+        op_code_str = self.builtin_op_code[op_code_id]
+        if op_code_id == BuiltinOperator.CUSTOM:
+            # Custom operator
+            raise NotImplementedError("Custom operators are not supported yet")
+        return op_code_str
+
+    def get_input_tensors(self, op):
+        """Get tensor wrappers for the op's input tensors."""
+        operator_inputs = op.InputsAsNumpy()
+        return self.get_tensors(operator_inputs)
+
+    def get_output_tensors(self, op):
+        """Get tensor wrappers for the op's output tensors."""
+        operator_outputs = op.OutputsAsNumpy()
+        return self.get_tensors(operator_outputs)
+
+    def get_tensors(self, tensors_idx_list):
+        """Get tensor wrapper list from given TFLite tensor index list"""
+        return_list = list()
+        for tensor_idx in tensors_idx_list:
+            if tensor_idx < 0:
+                return_list.append(TensorWrapper(tensor_idx, 0, 0))
+                continue
+
+            tensor = self.subgraph.Tensors(tensor_idx)
+            buffer_idx = tensor.Buffer()
+            buffer = self.model.Buffers(buffer_idx)
+            return_list.append(TensorWrapper(tensor_idx, tensor, buffer))
+        return return_list
+
+    def get_tensor_value(self, tensor_wrapper):
+        """Get tensor buffer value from given tensor wrapper"""
+        assert isinstance(tensor_wrapper, TensorWrapper)
+
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        if tensor_wrapper.tensor.Type() == TensorType.UINT8:
+            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
+                tensor_wrapper.tensor.ShapeAsNumpy())
+        elif tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
+            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
+                tensor_wrapper.tensor.ShapeAsNumpy())
+        elif tensor_wrapper.tensor.Type() == TensorType.INT32:
+            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
+                tensor_wrapper.tensor.ShapeAsNumpy())
+        else:
+            raise NotImplementedError("Tensor type {} is not supported"
+                                      .format(str(tensor_wrapper.tensor.Type())))
+
+    def get_tensor_type_str(self, tensor_type):
+        """Get the string representation of a TFLite tensor type"""
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        if tensor_type == TensorType.UINT8:
+            return "uint8"
+        elif tensor_type == TensorType.FLOAT32:
+            return "float32"
+        elif tensor_type == TensorType.INT32:
+            return "int32"
+        else:
+            raise NotImplementedError("Tensor type {} is not supported".format(str(tensor_type)))
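+
+    # For example (hypothetical shapes): a FLOAT32 weight tensor of shape
+    # (32, 3, 3, 3) comes back from get_tensor_value as a numpy float32 array
+    # of that shape, and get_tensor_type_str(TensorType.FLOAT32) returns
+    # "float32".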
+
+    def convert_conv2d(self, op):
+        """Convert TFLite conv2d"""
+        return self.convert_conv(op, "conv2d")
+
+    def convert_depthwise_conv2d(self, op):
+        """Convert TFLite depthwise conv2d"""
+        return self.convert_conv(op, "depthwise")
+
+    def convert_average_pool2d(self, op):
+        """Convert TFLite average pool2d"""
+        return self.convert_pool2d(op, "average")
+
+    def convert_max_pool2d(self, op):
+        """Convert TFLite max pool2d"""
+        return self.convert_pool2d(op, "max")
+
+    def convert_reshape(self, op):
+        """Convert TFLite reshape"""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.Operator import Operator
+            from tflite.ReshapeOptions import ReshapeOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
+        op_options = op.BuiltinOptions()
+        reshape_options = ReshapeOptions()
+        reshape_options.Init(op_options.Bytes, op_options.Pos)
+        target_shape = reshape_options.NewShapeAsNumpy()
+        input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
+
+        in_expr = self.get_expr(input_tensor_idx)
+
+        # TFLite is N H W C, our layout is N C H W
+        if input_shape_length == 1 or input_shape_length == 2:
+            # The rule is channel first (after N but before H, W).
+            # length of 1 means N*H*W*C, do nothing.
+            # length of 2 means N*H*W, C, do nothing.
+            pass
+        elif input_shape_length == 3:
+            # convert N C H*W to N H*W C
+            in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
+        elif input_shape_length == 4:
+            # convert input to N H W C, then reshape to target shape,
+            # finally convert back if necessary
+            in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
+        else:
+            raise NotImplementedError("Input shape length {} is not supported for reshape"
+                                      .format(str(input_shape_length)))
+
+        out = _op.reshape(in_expr, newshape=tuple(target_shape))
+
+        # The rule is channel first.
+        # 1: N*H*W*C
+        # 2: N*H*W, C
+        # 3: N H W C, reshape to N H*W C, transpose to N C H*W
+        # 4: N H W C, transpose to N C H W
+        # add more if we need target shapes in future
+        if len(target_shape) == 1 or len(target_shape) == 2:
+            pass
+        elif len(target_shape) == 3:
+            out = _op.transpose(out, axes=(0, 2, 1))
+        elif len(target_shape) == 4:
+            out = _op.transpose(out, axes=(0, 3, 1, 2))
+        else:
+            raise NotImplementedError("Reshape to a target shape of length {} is not supported"
+                                      .format(str(len(target_shape))))
+
+        return out
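+
+    # A worked example of the rule above (hypothetical shapes): a TFLite
+    # tensor of shape (1, 2, 2, 4) (N H W C) is held internally as
+    # (1, 4, 2, 2) (N C H W); it is transposed back to (1, 2, 2, 4), reshaped
+    # to a target of, say, (1, 16), and because the target length is 2 no
+    # trailing transpose is required.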
+
+    def convert_softmax(self, op):
+        """Convert TFLite softmax"""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+        params = {'axis': 1}  # axis 1 is the channel dimension in N C H W
+        in_expr = self.get_expr(input_tensor_idx)
+        out = _op.nn.softmax(in_expr, **params)
+
+        return out
+
+    def convert_squeeze(self, op):
+        """Convert TFLite squeeze"""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.Operator import Operator
+            from tflite.SqueezeOptions import SqueezeOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        output_tensors = self.get_output_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.SqueezeOptions
+        op_options = op.BuiltinOptions()
+        squeeze_options = SqueezeOptions()
+        squeeze_options.Init(op_options.Bytes, op_options.Pos)
+        squeeze_axis = squeeze_options.SqueezeDimsAsNumpy()
+        input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
+        output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy())
+
+        in_expr = self.get_expr(input_tensor_idx)
+
+        # TFLite is N H W C, our layout is N C H W
+        if input_shape_length == 1 or input_shape_length == 2:
+            # The rule is channel first (after N but before H, W).
+            # length of 1 means N*H*W*C, do nothing.
+            # length of 2 means N*H*W, C, do nothing.
+            pass
+        elif input_shape_length == 3:
+            # convert N C H*W to N H*W C
+            in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
+        elif input_shape_length == 4:
+            # convert input to N H W C, then squeeze,
+            # finally convert back if necessary
+            in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
+        else:
+            raise NotImplementedError("Input shape length {} is not supported for squeeze"
+                                      .format(str(input_shape_length)))
+
+        out = _op.squeeze(in_expr, axis=tuple(squeeze_axis))
+
+        # The rule is channel first.
+        # 1: N*H*W*C
+        # 2: N*H*W, C
+        # 3: N H W C, transpose to N C H*W
+        # 4: N H W C, transpose to N C H W
+        # add more if we need target shapes in future
+        if output_shape_length == 1 or output_shape_length == 2:
+            pass
+        elif output_shape_length == 3:
+            out = _op.transpose(out, axes=(0, 2, 1))
+        elif output_shape_length == 4:
+            out = _op.transpose(out, axes=(0, 3, 1, 2))
+        else:
+            raise NotImplementedError("Output shape length {} is not supported for squeeze"
+                                      .format(str(output_shape_length)))
+
+        return out
+
+    def convert_fused_activation_function(self, in_expr, fused_activation_fn):
+        """Convert TFLite fused activation function"""
+        try:
+            from tflite.ActivationFunctionType import ActivationFunctionType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+        assert fused_activation_fn != ActivationFunctionType.NONE
+        if fused_activation_fn == ActivationFunctionType.RELU6:
+            return _op.clip(in_expr, a_min=0, a_max=6)
+        elif fused_activation_fn == ActivationFunctionType.RELU:
+            return _op.nn.relu(in_expr)
+        elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
+            return _op.clip(in_expr, a_min=-1, a_max=1)
+        elif fused_activation_fn == ActivationFunctionType.TANH:
+            return _op.tanh(in_expr)
+        else:
+            fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
+            raise NotImplementedError("Fused activation {} is not supported"
+                                      .format(fused_activation_fn_str))
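+
+    # For instance, a CONV_2D op carrying a fused RELU6 is lowered to the
+    # convolution followed by clip(x, 0, 6), and RELU_N1_TO_1 becomes
+    # clip(x, -1, 1).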
+
+    def convert_conv(self, op, conv_type):
+        """convolution implementation."""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ActivationFunctionType import ActivationFunctionType
+            from tflite.TensorType import TensorType
+            from tflite.Operator import Operator
+            from tflite.Conv2DOptions import Conv2DOptions
+            from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
+            from tflite.Padding import Padding
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) >= 2, "input tensors length should be >= 2"
+
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+        weight_tensor = input_tensors[1]
+
+        is_depthwise_conv = False
+        if conv_type == 'conv2d':
+            assert op.BuiltinOptionsType() == BuiltinOptions.Conv2DOptions
+            op_options = op.BuiltinOptions()
+            conv_options = Conv2DOptions()
+            conv_options.Init(op_options.Bytes, op_options.Pos)
+        elif conv_type == 'depthwise':
+            is_depthwise_conv = True
+            assert op.BuiltinOptionsType() == BuiltinOptions.DepthwiseConv2DOptions
+            op_options = op.BuiltinOptions()
+            conv_options = DepthwiseConv2DOptions()
+            conv_options.Init(op_options.Bytes, op_options.Pos)
+            depth_multiplier = conv_options.DepthMultiplier()
+            assert depth_multiplier == 1, "The TF frontend transforms the depth " \
+                                          "multiplier to 1 regardless of its " \
+                                          "original value (0.25, 0.5, ...)"
+        else:
+            raise ValueError("Conv type {} is not supported".format(conv_type))
+
+        stride_h = conv_options.StrideH()
+        stride_w = conv_options.StrideW()
+        dilation_h = conv_options.DilationHFactor()
+        dilation_w = conv_options.DilationWFactor()
+        padding = conv_options.Padding()
+        fused_activation_fn = conv_options.FusedActivationFunction()
+
+        _, input_h, input_w, _ = input_tensor.tensor.ShapeAsNumpy()
+
+        if is_depthwise_conv:
+            multiplier, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
+            assert multiplier == depth_multiplier
+        else:
+            output_channels, kernel_h, kernel_w, _ = weight_tensor.tensor.ShapeAsNumpy()
+
+        dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
+        dilated_kernel_w = dilation_w * (kernel_w - 1) + 1
+
+        params = {'kernel_size': [kernel_h, kernel_w],
+                  'strides': [stride_h, stride_w],
+                  'dilation': [dilation_h, dilation_w],
+                  'padding': [0, 0]}
+
+        if is_depthwise_conv:
+            params['channels'] = int(in_channels * multiplier)
+            params['groups'] = int(in_channels)
+        else:
+            params['channels'] = int(output_channels)
+
+        # weight tensor type should be UINT8 (quantization) or FLOAT32
+        weight_tensor_type = weight_tensor.tensor.Type()
+        assert weight_tensor_type == TensorType.UINT8 or weight_tensor_type == TensorType.FLOAT32
+        weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
+
+        in_expr = self.get_expr(input_tensor_idx)
+        weight_value = self.get_tensor_value(weight_tensor)
+
+        if is_depthwise_conv:
+            # TFLite is M KH KW IC, we require IC M KH KW
+            weight_value = weight_value.transpose((3, 0, 1, 2))
+        else:
+            # TFLite is OC KH KW IC, we require OC IC KH KW
+            weight_value = weight_value.transpose((0, 3, 1, 2))
+
+        weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
+
+        if padding == Padding.VALID:
+            pass
+        elif padding == Padding.SAME:
+            pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
+            pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
+            in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0),
+                                                          (pad_top, pad_bottom),
+                                                          (pad_left, pad_right)))
+        else:
+            raise NotImplementedError("Padding format {} is not supported".format(padding))
+
+        out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params)
+
+        # if we have bias
+        if len(input_tensors) == 3:
+            bias_tensor = input_tensors[2]
+            bias_tensor_type = bias_tensor.tensor.Type()
+            # bias tensor type should be INT32 (quantization) or FLOAT32
+            assert bias_tensor_type == TensorType.INT32 or bias_tensor_type == TensorType.FLOAT32
+            bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
+            bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
+                                               dtype=bias_tensor_type_str)
+            out = _op.nn.bias_add(out, bias_expr)
+
+        # If we have fused activations
+        if fused_activation_fn != ActivationFunctionType.NONE:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
+
+        return out
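+
+    # For example (hypothetical shapes): a 3x3 depthwise kernel over 32
+    # channels is stored by TFLite as (1, 3, 3, 32) (M KH KW IC) and becomes
+    # (32, 1, 3, 3) (IC M KH KW) above, while an ordinary conv kernel of
+    # shape (64, 3, 3, 32) (OC KH KW IC) becomes (64, 32, 3, 3) (OC IC KH KW).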
+
+    def convert_pool2d(self, op, pool_type):
+        """pool2d implementation."""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ActivationFunctionType import ActivationFunctionType
+            from tflite.Operator import Operator
+            from tflite.Pool2DOptions import Pool2DOptions
+            from tflite.Padding import Padding
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.Pool2DOptions
+        op_options = op.BuiltinOptions()
+        pool2d_options = Pool2DOptions()
+        pool2d_options.Init(op_options.Bytes, op_options.Pos)
+        stride_h = pool2d_options.StrideH()
+        stride_w = pool2d_options.StrideW()
+        padding = pool2d_options.Padding()
+        filter_h = pool2d_options.FilterHeight()
+        filter_w = pool2d_options.FilterWidth()
+        fused_activation_fn = pool2d_options.FusedActivationFunction()
+
+        params = {'pool_size': (filter_h, filter_w),
+                  'strides': (stride_h, stride_w),
+                  'padding': [0, 0]}
+
+        in_expr = self.get_expr(input_tensor_idx)
+
+        _, input_h, input_w, _ = input_tensor.tensor.ShapeAsNumpy()
+        if padding == Padding.VALID:
+            pass
+        elif padding == Padding.SAME:
+            pad_top, pad_bottom = get_pad_value(input_h, filter_h, stride_h)
+            pad_left, pad_right = get_pad_value(input_w, filter_w, stride_w)
+            params['padding'] = [pad_top, pad_left, pad_bottom, pad_right]
+        else:
+            raise NotImplementedError("Padding format {} is not supported".format(padding))
+
+        if pool_type == "average":
+            out = _op.nn.avg_pool2d(in_expr, **params)
+        elif pool_type == "max":
+            out = _op.nn.max_pool2d(in_expr, **params)
+        else:
+            raise ValueError("Pool type {} is not supported".format(pool_type))
+
+        # If we have fused activations
+        if fused_activation_fn != ActivationFunctionType.NONE:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
+
+        return out
+
+    def get_expr(self, input_tensor_idx):
+        """Look up the Relay expression for a tensor by its name."""
+        return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
+
+def build_str_map(obj):
+    """Build a map from a TFLite enum's int values to their string names
+
+    Parameters
+    ----------
+    obj:
+        A TFLite enum class instance, such as BuiltinOptions
+
+    Returns
+    -------
+    ret : dict
+        Map from the enum's int values to their string names
+    """
+    ret = {}
+    for field_name in dir(obj):
+        if not field_name.startswith('_'):
+            field_value = getattr(obj, field_name)
+            if isinstance(field_value, int):
+                ret[field_value] = field_name
+    return ret
+
+# SAME padding: https://www.tensorflow.org/api_guides/python/nn
+def get_pad_value(data, kernel, stride):
+    """Get the pad tuple of value for SAME padding
+
+    Parameters
+    ----------
+    data:
+        Input dimension size (height or width)
+
+    kernel:
+        Kernel (filter) size along the same dimension
+
+    stride:
+        Stride along the same dimension
+
+    Returns
+    -------
+    (pad_before, pad_after) : tuple of int
+        Padding before and after the dimension
+    """
+
+    out = math.ceil(float(data) / float(stride))
+    pad = max(0, (out - 1) * stride + kernel - data)
+    pad_before = pad // 2
+    pad_after = pad - pad_before
+    return pad_before, pad_after
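+
+# For example, SAME padding for data=17, kernel=3, stride=2:
+#   out = ceil(17 / 2) = 9
+#   pad = max(0, (9 - 1) * 2 + 3 - 17) = 2
+#   => pad_before = 1, pad_after = 1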
+
+
+def get_tensor_name(subgraph, tensor_idx):
+    """Get the tensor name.
+
+    Parameters
+    ----------
+    subgraph:
+        tflite.Subgraph.Subgraph
+
+    tensor_idx:
+        Tensor index in the subgraph
+
+    Returns
+    -------
+    Tensor name in UTF-8 encoding
+    """
+    return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
+
+
+def from_tflite(model, shape_dict, dtype_dict):
+    """Convert a TFLite model into an equivalent Relay Function.
+
+    Parameters
+    ----------
+    model:
+        tflite.Model.Model
+
+    shape_dict : dict of str to int list/tuple
+        Input shapes of the model.
+
+    dtype_dict : dict of str to str
+        Input types of the model.
+
+    Returns
+    -------
+    func : tvm.relay.Function
+        Compatible relay Function
+
+    params : dict of str to tvm.NDArray
+        The parameter dict to be used by relay
+    """
+    try:
+        import tflite.Model
+        import tflite.SubGraph
+        import tflite.BuiltinOperator
+    except ImportError:
+        raise ImportError("The tflite package must be installed")
+    assert isinstance(model, tflite.Model.Model)
+
+    # keep the same as tflite
+    assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
+    subgraph = model.Subgraphs(0)
+
+    # model inputs / outputs
+    model_inputs = subgraph.InputsAsNumpy()
+    model_outputs = subgraph.OutputsAsNumpy()
+
+    exp_tab = ExprTable()
+    for model_input in model_inputs:
+        model_input_name = get_tensor_name(subgraph, model_input)
+        shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
+        dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
+        exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))
+
+    # op code in model
+    op_converter = OperatorConverter(model, subgraph, exp_tab)
+    op_converter.check_unsupported_ops()
+    op_converter.convert_op_to_relay()
+
+    # params and outputs
+    params = {k: _nd.array(np.array(v)) for k, v in exp_tab.params.items()}
+    outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
+    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
+    func = _expr.Function(ir_pass.free_vars(outputs), outputs)
+    return func, params
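+
+# A minimal usage sketch (the file name and input name are hypothetical;
+# shapes are given in the frontend's N C H W layout, matching how
+# test_forward.py builds shape_dict):
+#
+#   import tflite.Model
+#   buf = open("model.tflite", "rb").read()
+#   model = tflite.Model.Model.GetRootAsModel(buf, 0)
+#   func, params = from_tflite(model,
+#                              shape_dict={"input": (1, 3, 224, 224)},
+#                              dtype_dict={"input": "float32"})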
+""" +from __future__ import print_function +import numpy as np +import tvm +from tvm import relay +from tvm.contrib import util +import tensorflow as tf +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper + +import nnvm.testing.tf + +####################################################################### +# Generic run functions for TVM & TFLite +# -------------------------------------- +def convert_to_list(x): + if not isinstance(x, list): + x = [x] + return x + +def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', + out_names=None): + """ Generic function to compile on relay and execute on tvm """ + try: + import tflite.Model + except ImportError: + raise ImportError("The tflite package must be installed") + + # get TFLite model from buffer + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + + input_data = convert_to_list(input_data) + input_node = convert_to_list(input_node) + + shape_dict = {} + dtype_dict = {} + for i, e in enumerate(input_node): + shape_dict[e] = input_data[i].shape + dtype_dict[e] = input_data[i].dtype.name + + func, params = relay.frontend.from_tflite(tflite_model, + shape_dict=shape_dict, + dtype_dict=dtype_dict) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(func, target, params=params) + + ctx = tvm.context(target, 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + # set inputs + for i, e in enumerate(input_node): + m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + + m.set_input(**params) + # execute + m.run() + # get outputs + assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( + out_names, num_output) + tvm_output_list = [] + for i in range(0, num_output): + tvm_output = m.get_output(i) + tvm_output_list.append(tvm_output.asnumpy()) + return tvm_output_list + + +def run_tflite_graph(tflite_model_buf, input_data): + """ Generic function to execute TFLite """ + input_data = convert_to_list(input_data) + + interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # set input + assert len(input_data) == len(input_details) + for i in range(len(input_details)): + interpreter.set_tensor(input_details[i]['index'], input_data[i]) + + # Run + interpreter.invoke() + + # get output + tflite_output = list() + for i in range(len(output_details)): + tflite_output.append(interpreter.get_tensor(output_details[i]['index'])) + + return tflite_output + + +def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, + output_tensors, output_need_transpose_nchw=False, + init_global_variables=False): + """Generic function to generate and compare TFLite and TVM output""" + tflite_in_data = convert_to_list(tflite_in_data) + tvm_in_data = convert_to_list(tvm_in_data) + in_name = convert_to_list(in_name) + in_node = [0] * len(in_name) + for i in range(len(in_name)): + in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i] + + with tf.Session() as sess: + if init_global_variables: + sess.run(variables.global_variables_initializer()) + # convert to tflite model + converter = 
+
+
+#######################################################################
+# Pooling
+# -------
+def _test_pooling_iteration(input_shape, **kwargs):
+    """ One iteration of pool operation with given shapes and attributes """
+
+    x = -np.arange(
+        np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
+    tvm_data = np.transpose(x, axes=(0, 3, 1, 2))
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
+        out = nn_ops.pool(in_data, **kwargs)
+
+        compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out],
+                                output_need_transpose_nchw=True)
+
+
+def _test_pooling(input_shape, **kwargs):
+    _test_pooling_iteration(input_shape, **kwargs)
+
+
+def test_forward_pooling():
+    """ Pooling """
+
+    for pool_type in ['AVG', 'MAX']:
+        _test_pooling(input_shape=[2, 9, 10, 2],
+                      window_shape=[1, 1],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[1, 1])
+
+        _test_pooling(input_shape=[2, 10, 9, 2],
+                      window_shape=[1, 1],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[1, 1])
+
+        _test_pooling(input_shape=[2, 9, 10, 2],
+                      window_shape=[2, 1],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[1, 1])
+
+        _test_pooling(input_shape=[2, 10, 9, 2],
+                      window_shape=[2, 3],
+                      padding='SAME',
+                      pooling_type=pool_type,
+                      dilation_rate=[1, 1],
+                      strides=[2, 1])
+
+
+#######################################################################
+# Convolution
+# -----------
+
+def _test_convolution(tensor_in_sizes, filter_in_sizes,
+                      dilations, strides, padding, data_format,
+                      is_depthwise=False):
+    """ One iteration of convolution with given shapes and attributes """
+
+    total_size_1 = 1
+    total_size_2 = 1
+    for s in tensor_in_sizes:
+        total_size_1 *= s
+    for s in filter_in_sizes:
+        total_size_2 *= s
+    # Initializes the input tensor with array containing incrementing
+    # numbers from 1.
+    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
+    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
+        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
+        strides = [1] + strides + [1]
+        dilations = [1] + dilations + [1]
+
+        if is_depthwise:
+            out = nn_ops.depthwise_conv2d_native(in_data,
+                                                 in_filter,
+                                                 strides=strides,
+                                                 padding=padding,
+                                                 data_format=data_format)
+        else:
+            out = nn_ops.conv2d(in_data,
+                                in_filter,
+                                strides=strides,
+                                padding=padding,
+                                data_format=data_format)
+        # TFLite is NHWC, TVM is NCHW
+        tflite_data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
+        tvm_data_array = np.transpose(tflite_data_array, axes=(0, 3, 1, 2))
+        # TFLite output is NHWC, TVM is NCHW, we need transpose
+        compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
+                                'Placeholder:0', [in_data], [out],
+                                output_need_transpose_nchw=True)
+
+
+def test_forward_convolution():
+    _test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
+    _test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
+    _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
+    _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
+
+    # depthwise convolution
+    _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
+    _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
+    _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
+    _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
+
+
+#######################################################################
+# Reshape
+# -------
+
+def _test_reshape(data, out_shape):
+    """ One iteration of reshape operation with given data and out shape """
+    # see relay/frontend/tflite.py convert_reshape for more detail of the
+    # channel first rule
+    if len(data.shape) == 1 or len(data.shape) == 2:
+        tvm_data = data
+    elif len(data.shape) == 3:
+        tvm_data = np.transpose(data, axes=(0, 2, 1))
+    elif len(data.shape) == 4:
+        tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
+    else:
+        raise NotImplementedError("Input shape length {} is not supported for reshape".
+                                  format(str(len(data.shape))))
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = array_ops.reshape(in_data, out_shape)
+
+        compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_reshape():
+    _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3])
+    _test_reshape(np.arange(6), [-1, 2])
+    _test_reshape(np.arange(6), [3, -1])
+    _test_reshape(np.arange(6), [-1])
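+
+# For example, _test_reshape(np.arange(6), [-1, 2]) feeds the same 1-D array
+# to both TFLite and TVM (rank 1 needs no layout transpose) and both should
+# return the data reshaped to (3, 2).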
+
+
+#######################################################################
+# Squeeze
+# -------
+
+def _test_squeeze(data, squeeze_dims=None):
+    """ One iteration of squeeze """
+
+    if squeeze_dims is None:
+        squeeze_dims = []
+
+    # see relay/frontend/tflite.py convert_squeeze for more detail of the
+    # channel first rule
+    if len(data.shape) == 1 or len(data.shape) == 2:
+        tvm_data = data
+    elif len(data.shape) == 3:
+        tvm_data = np.transpose(data, axes=(0, 2, 1))
+    elif len(data.shape) == 4:
+        tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
+    else:
+        raise NotImplementedError("Input shape length {} is not supported for squeeze".
+                                  format(str(len(data.shape))))
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+
+        if squeeze_dims:
+            out = array_ops.squeeze(in_data, squeeze_dims)
+        else:
+            out = array_ops.squeeze(in_data)
+
+        compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_squeeze():
+    """ Squeeze """
+    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3)), [0, 2])
+    _test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3])
+
+#######################################################################
+# Softmax
+# -------
+
+def _test_softmax(data):
+    """ One iteration of softmax """
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = nn_ops.softmax(in_data)
+        compare_tflite_with_tvm(data, data, 'Placeholder:0', [in_data], [out])
+
+def test_forward_softmax():
+    """ Softmax """
+    _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
+
+#######################################################################
+# Mobilenet
+# ---------
+def test_forward_mobilenet():
+    '''test mobilenet v1 tflite model'''
+    # MobilenetV1
+    temp = util.tempdir()
+    tflite_model_file = nnvm.testing.tf.get_workload_official(
+        "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
+        "mobilenet_v1_1.0_224.tflite", temp)
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-5, atol=1e-5)
+
+#######################################################################
+# Main
+# ----
+if __name__ == '__main__':
+    # Transforms
+    test_forward_reshape()
+    test_forward_squeeze()
+
+    # NN
+    test_forward_convolution()
+    test_forward_pooling()
+    test_forward_softmax()
+
+    # End to End
+    test_forward_mobilenet()
diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh
index 1f207760fb723..beff2d47f4643 100755
--- a/tests/scripts/task_python_frontend.sh
+++ b/tests/scripts/task_python_frontend.sh
@@ -38,3 +38,6 @@ python3 -m nose -v tests/python/frontend/onnx || exit -1
 
 echo "Running nnvm to relay frontend test..."
 python3 -m nose -v tests/python/frontend/nnvm_to_relay || exit -1
+
+echo "Running relay TFLite frontend test..."
+python3 -m nose -v tests/python/frontend/tflite || exit -1