[CONTRIB] TFLite Runtime (#4439)
ZihengJiang authored and tqchen committed Dec 4, 2019
1 parent f214364 commit 24713bd
Showing 7 changed files with 567 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -63,6 +63,8 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when using TFLite" none)

# include directories
include_directories(${CMAKE_INCLUDE_PATH})
@@ -257,6 +259,7 @@ include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)

if(NOT MSVC)
include(CheckCXXCompilerFlag)
9 changes: 9 additions & 0 deletions cmake/config.cmake
@@ -145,6 +145,15 @@ set(USE_RANDOM OFF)
# Whether to use NNPack
set(USE_NNPACK OFF)

# Possible values:
# - ON: enable TFLite; locate the library via CMake's find_library search
# - OFF: disable TFLite
# - /path/to/libtensorflow-lite.a: use a specific path to the TensorFlow Lite static library
set(USE_TFLITE OFF)

# /path/to/tensorflow: TensorFlow root path when using the TFLite library
set(USE_TENSORFLOW_PATH none)

# Whether to use CuDNN
set(USE_CUDNN OFF)

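For reference, a minimal sketch of the two settings in config.cmake that turn this on; the checkout path below is a hypothetical placeholder:

    # Enable the TFLite runtime; libtensorflow-lite.a is located automatically
    # under the TensorFlow tree when USE_TFLITE is ON.
    set(USE_TFLITE ON)
    # Hypothetical path to a TensorFlow source checkout.
    set(USE_TENSORFLOW_PATH /home/user/tensorflow)

    # Alternatively, set USE_TFLITE to the location of a prebuilt
    # libtensorflow-lite.a instead of ON.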
35 changes: 35 additions & 0 deletions cmake/modules/contrib/TFLite.cmake
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(NOT USE_TFLITE STREQUAL "OFF")
  message(STATUS "Build with contrib.tflite")
  if (USE_TENSORFLOW_PATH STREQUAL "none")
    set(USE_TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow)
  endif()

  file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc)
  list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC})
  include_directories(${USE_TENSORFLOW_PATH})

  if (USE_TFLITE STREQUAL "ON")
    set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib)
  endif()
  find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE})

  list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_CONTRIB_LIB})
  list(APPEND TVM_RUNTIME_LINKER_LIBS rt dl flatbuffers)
endif()
108 changes: 108 additions & 0 deletions python/tvm/contrib/tflite_runtime.py
@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TFLite runtime that load and run tflite models."""
from .._ffi.function import get_global_func
from ..rpc import base as rpc_base

def create(tflite_model_bytes, ctx):
    """Create a runtime executor module given a TFLite model and context.

    Parameters
    ----------
    tflite_model_bytes : bytes
        The TFLite model to be deployed as a byte string.

    ctx : TVMContext
        The context to deploy the module. It can be local or remote when
        there is only one TVMContext.

    Returns
    -------
    tflite_runtime : TFLiteModule
        Runtime TFLite module that can be used to execute the TFLite model.
    """
    device_type = ctx.device_type
    if device_type >= rpc_base.RPC_SESS_MASK:
        fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create")
        return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
    fcreate = get_global_func("tvm.tflite_runtime.create")
    return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))


class TFLiteModule(object):
    """Wrapper runtime module.

    This is a thin wrapper of the underlying TVM module.
    You can also directly call set_input, invoke, and get_output
    on the underlying module functions.

    Parameters
    ----------
    module : Module
        The internal TVM module that holds the actual TFLite functions.

    Attributes
    ----------
    module : Module
        The internal TVM module that holds the actual TFLite functions.
    """

    def __init__(self, module):
        self.module = module
        self._set_input = module["set_input"]
        self._invoke = module["invoke"]
        self._get_output = module["get_output"]
        self._allocate_tensors = module["allocate_tensors"]

    def set_input(self, index, value):
        """Set the index-th input to the module.

        Parameters
        ----------
        index : int
            The input index.

        value : NDArray
            The input value.
        """
        self._set_input(index, value)

    def invoke(self):
        """Invoke forward execution of the model."""
        self._invoke()

    def allocate_tensors(self):
        """Allocate space for all tensors."""
        self._allocate_tensors()

    def get_output(self, index):
        """Get the index-th output.

        Parameters
        ----------
        index : int
            The output index.
        """
        return self._get_output(index)
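A minimal usage sketch of the module above, assuming a compiled model file model.tflite and a local CPU context; the file name and the 1x224x224x3 float32 input shape are hypothetical placeholders:

    # Sketch: run a TFLite model through the contrib runtime.
    import numpy as np
    import tvm
    from tvm.contrib import tflite_runtime

    with open("model.tflite", "rb") as f:  # hypothetical model file
        tflite_model_buf = f.read()

    runtime = tflite_runtime.create(tflite_model_buf, tvm.cpu(0))
    runtime.allocate_tensors()  # allocate interpreter buffers before setting inputs
    data = np.random.rand(1, 224, 224, 3).astype("float32")  # hypothetical shape
    runtime.set_input(0, tvm.nd.array(data))
    runtime.invoke()
    out = runtime.get_output(0)
    print(out.asnumpy())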
191 changes: 191 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.cc
@@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tflite_runtime.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/dtype.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>


#include "tflite_runtime.h"

namespace tvm {
namespace runtime {

#define TVM_DTYPE_DISPATCH(type, DType, ...)     \
  if (type == Float(64)) {                       \
    typedef double DType;                        \
    {__VA_ARGS__}                                \
  } else if (type == Float(32)) {                \
    typedef float DType;                         \
    {__VA_ARGS__}                                \
  } else if (type == Float(16)) {                \
    typedef uint16_t DType;                      \
    {__VA_ARGS__}                                \
  } else if (type == Int(64)) {                  \
    typedef int64_t DType;                       \
    {__VA_ARGS__}                                \
  } else if (type == Int(32)) {                  \
    typedef int32_t DType;                       \
    {__VA_ARGS__}                                \
  } else if (type == Int(16)) {                  \
    typedef int16_t DType;                       \
    {__VA_ARGS__}                                \
  } else if (type == Int(8)) {                   \
    typedef int8_t DType;                        \
    {__VA_ARGS__}                                \
  } else if (type == UInt(64)) {                 \
    typedef uint64_t DType;                      \
    {__VA_ARGS__}                                \
  } else if (type == UInt(32)) {                 \
    typedef uint32_t DType;                      \
    {__VA_ARGS__}                                \
  } else if (type == UInt(16)) {                 \
    typedef uint16_t DType;                      \
    {__VA_ARGS__}                                \
  } else if (type == UInt(8)) {                  \
    typedef uint8_t DType;                       \
    {__VA_ARGS__}                                \
  } else {                                       \
    LOG(FATAL) << "unknown data type " << type;  \
  }

DataType TfLiteDType2TVMDType(TfLiteType dtype) {
  switch (dtype) {
    case kTfLiteFloat32:
      return Float(32);
    case kTfLiteInt32:
      return Int(32);
    case kTfLiteInt64:
      return Int(64);
    case kTfLiteInt16:
      return Int(16);
    case kTfLiteInt8:
      return Int(8);
    case kTfLiteUInt8:
      return UInt(8);
    case kTfLiteFloat16:
      return Float(16);
    default:
      LOG(FATAL) << "TFLite data type not supported yet: " << dtype;
      return Float(32);
  }
}


void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
                         TVMContext ctx) {
  const char* buffer = tflite_model_bytes.c_str();
  size_t buffer_size = tflite_model_bytes.size();
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
  ctx_ = ctx;
}

void TFLiteRuntime::AllocateTensors() {
  interpreter_->AllocateTensors();
}

void TFLiteRuntime::Invoke() {
  interpreter_->Invoke();
}

void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
  DataType dtype(data_in->dtype);
  TVM_DTYPE_DISPATCH(dtype, DType, {
    DType* dest = interpreter_->typed_input_tensor<DType>(index);
    DType* src = static_cast<DType*>(data_in->data);
    // Only compact (stride-free) tensors are supported.
    CHECK(data_in->strides == NULL);
    int64_t size = 1;
    for (int64_t i = 0; i < data_in->ndim; ++i) {
      size *= data_in->shape[i];
    }
    // Copy the flat input buffer into the interpreter's input tensor.
    for (int64_t i = 0; i < size; ++i) {
      dest[i] = src[i];
    }
  });
}

NDArray TFLiteRuntime::GetOutput(int index) const {
  TfLiteTensor* output = interpreter_->output_tensor(index);
  DataType dtype = TfLiteDType2TVMDType(output->type);
  TfLiteIntArray* dims = output->dims;
  int64_t size = 1;
  std::vector<int64_t> shape;
  for (int i = 0; i < dims->size; ++i) {
    shape.push_back(dims->data[i]);
    size *= dims->data[i];
  }
  NDArray ret = NDArray::Empty(shape, dtype, ctx_);
  TVM_DTYPE_DISPATCH(dtype, DType, {
    DType* dest = static_cast<DType*>(ret->data);
    DType* src = interpreter_->typed_output_tensor<DType>(index);
    for (int64_t i = 0; i < size; ++i) {
      dest[i] = src[i];
    }
  });
  return ret;
}

PackedFunc TFLiteRuntime::GetFunction(
    const std::string& name,
    const ObjectPtr<Object>& sptr_to_self) {
  // Return member functions during query.
  if (name == "set_input") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      int in_idx = args[0];
      CHECK_GE(in_idx, 0);
      this->SetInput(in_idx, args[1]);
    });
  } else if (name == "get_output") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->GetOutput(args[0]);
    });
  } else if (name == "invoke") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      this->Invoke();
    });
  } else if (name == "allocate_tensors") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      this->AllocateTensors();
    });
  } else {
    return PackedFunc();
  }
}

Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes,
                           TVMContext ctx) {
  auto exec = make_object<TFLiteRuntime>();
  exec->Init(tflite_model_bytes, ctx);
  return Module(exec);
}

TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = TFLiteRuntimeCreate(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
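Because create() in python/tvm/contrib/tflite_runtime.py dispatches to this same tvm.tflite_runtime.create global when the context belongs to an RPC session, the runtime can also be driven on a remote device. A sketch, assuming a running TVM RPC server; the host, port, and model file are hypothetical placeholders:

    # Sketch: remote execution over TVM RPC.
    from tvm import rpc
    from tvm.contrib import tflite_runtime

    with open("model.tflite", "rb") as f:  # hypothetical model file
        tflite_model_buf = f.read()

    sess = rpc.connect("192.168.1.10", 9090)  # hypothetical host/port
    ctx = sess.cpu(0)
    runtime = tflite_runtime.create(tflite_model_buf, ctx)
    runtime.allocate_tensors()
    # set_input / invoke / get_output then work as in the local case.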