From 5b7b0c5a2e568840e0f7f21bf3e93755a36db95a Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 27 Nov 2019 12:23:17 -0800 Subject: [PATCH 01/21] TFLite Runtime. --- CMakeLists.txt | 2 + cmake/config.cmake | 2 + cmake/modules/contrib/TfLite.cmake | 33 ++++ python/tvm/contrib/tflite_runtime.py | 144 ++++++++++++++ src/runtime/contrib/tflite/tflite_runtime.cc | 189 +++++++++++++++++++ src/runtime/contrib/tflite/tflite_runtime.h | 110 +++++++++++ 6 files changed, 480 insertions(+) create mode 100644 cmake/modules/contrib/TfLite.cmake create mode 100644 python/tvm/contrib/tflite_runtime.py create mode 100644 src/runtime/contrib/tflite/tflite_runtime.cc create mode 100644 src/runtime/contrib/tflite/tflite_runtime.h diff --git a/CMakeLists.txt b/CMakeLists.txt index bf18ffc9e856..1cf79f85ac56 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,7 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) +tvm_option(USE_TFLITE "Build with nnpack support" ON) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -257,6 +258,7 @@ include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/NNPack.cmake) include(cmake/modules/contrib/HybridDump.cmake) +include(cmake/modules/contrib/TfLite.cmake) if(NOT MSVC) include(CheckCXXCompilerFlag) diff --git a/cmake/config.cmake b/cmake/config.cmake index 1ef956c7ee18..e3f0b6225f1a 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -145,6 +145,8 @@ set(USE_RANDOM OFF) # Whether use NNPack set(USE_NNPACK OFF) +set(USE_TFLITE ON) + # Whether use CuDNN set(USE_CUDNN OFF) diff --git a/cmake/modules/contrib/TfLite.cmake b/cmake/modules/contrib/TfLite.cmake new file mode 100644 index 000000000000..c198d711eaac --- /dev/null +++ b/cmake/modules/contrib/TfLite.cmake @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_TFLITE) + message(STATUS "Build with contrib.tflite") + message("current path: ${CMAKE_CURRENT_SOURCE_DIR}") + # if (TENSORFLOW_PATH STREQUAL "") + set(TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) + # endif() + message("tfpath: ${TENSORFLOW_PATH}") + file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc) + list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC}) + include_directories(${TENSORFLOW_PATH}) + find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/linux_x86_64/lib) + message("tflite lib: ${TFLITE_CONTRIB_LIB}") + + list(APPEND TVM_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) + list(APPEND TVM_LINKER_LIBS rt dl flatbuffers) +endif(USE_TFLITE) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py new file mode 100644 index 000000000000..37b18c5c88e4 --- /dev/null +++ b/python/tvm/contrib/tflite_runtime.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Minimum graph runtime that executes graph containing TVM PackedFunc.""" +import numpy as np + +from .._ffi.base import string_types +from .._ffi.function import get_global_func +from .._ffi.runtime_ctypes import TVMContext +from ..rpc import base as rpc_base + +def create(tflite_fname, ctx): + """Create a runtime executor module given a graph and module. + Parameters + ---------- + graph_json_str : str or graph class + The graph to be deployed in json format output by nnvm graph. + The graph can only contain one operator(tvm_op) that + points to the name of PackedFunc in the libmod. + ctx : TVMContext or list of TVMContext + The context to deploy the module. It can be local or remote when there + is only one TVMContext. Otherwise, the first context in the list will + be used as this purpose. All context should be given for heterogeneous + execution. + Returns + ------- + graph_module : GraphModule + Runtime graph module that can be used to execute the graph. + """ + if not isinstance(tflite_fname, string_types): + except AttributeError: + raise ValueError("Type %s is not supported" % type(tflite_fname)) + + fcreate = get_global_func("tvm.tflite_runtime.create") + return TfliteModule(fcreate(tflite_fname, ctx)) + + +class TfliteModule(object): + """Wrapper runtime module. + + This is a thin wrapper of the underlying TVM module. + you can also directly call set_input, run, and get_output + of underlying module functions + + Parameters + ---------- + module : Module + The interal tvm module that holds the actual graph functions. + + Attributes + ---------- + module : Module + The interal tvm module that holds the actual graph functions. + """ + + def __init__(self, module): + self.module = module + self._set_input = module["set_input"] + self._invoke = module["invoke"] + self._get_output = module["get_output"] + self._get_input = module["get_input"] + + def set_input(self, key=None, value=None, **params): + """Set inputs to the module via kwargs + + Parameters + ---------- + key : int or str + The input key + + value : the input value. + The input key + + params : dict of str to NDArray + Additonal arguments + """ + if key is not None: + self._get_input(key).copyfrom(value) + + if params: + # upload big arrays first to avoid memory issue in rpc mode + keys = list(params.keys()) + keys.sort(key=lambda x: -np.prod(params[x].shape)) + for k in keys: + self._get_input(k).copyfrom(params[k]) + + def run(self, **input_dict): + """Run forward execution of the graph + + Parameters + ---------- + input_dict: dict of str to NDArray + List of input values to be feed to + """ + if input_dict: + self.set_input(**input_dict) + self._run() + + def get_input(self, index, out=None): + """Get index-th input to out + + Parameters + ---------- + index : int + The input index + + out : NDArray + The output array container + """ + if out: + self._get_input(index).copyto(out) + return out + + return self._get_input(index) + + def get_output(self, index, out=None): + """Get index-th output to out + + Parameters + ---------- + index : int + The output index + + out : NDArray + The output array container + """ + if out: + self._get_output(index, out) + return out + + return self._get_output(index) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc new file mode 100644 index 000000000000..7b66feb800cf --- /dev/null +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tflite_runtime.cc + */ +#include +#include +#include +#include +#include + + +#include "tflite_runtime.h" + +namespace tvm { +namespace runtime { + +#define TVM_DTYPE_DISPATCH(type, DType, ...) \ + if (type == Float(64)) { \ + typedef double DType; \ + {__VA_ARGS__} \ + } else if (type == Float(32)) { \ + typedef float DType; \ + {__VA_ARGS__} \ + } else if (type == Float(16)) { \ + typedef uint16_t DType; \ + {__VA_ARGS__} \ + } else if (type == Int(64)) { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } else if (type == Int(32)) { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } else if (type == Int(16)) { \ + typedef int16_t DType; \ + {__VA_ARGS__} \ + } else if (type == Int(8)) { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(64)) { \ + typedef uint64_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(32)) { \ + typedef uint32_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(16)) { \ + typedef uint16_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(8)) { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ + } + +DataType TfLiteDType2TVMDType(TfLiteType dtype) { + switch (dtype) { + case kTfLiteFloat32: + return Float(32); + case kTfLiteInt32: + return Int(32); + case kTfLiteInt64: + return Int(64); + case kTfLiteInt16: + return Int(16); + case kTfLiteInt8: + return Int(8); + case kTfLiteUInt8: + return UInt(8); + case kTfLiteFloat16: + return Float(16); + default: + LOG(FATAL) << "tflite data type not support yet: " << dtype; + } +} + + +void TfliteRuntime::Init(const std::string& tflite_fname, + TVMContext ctx) { + std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile(tflite_fname.c_str()); + tflite::ops::builtin::BuiltinOpResolver resolver; + tflite::InterpreterBuilder(*model, resolver)(&interpreter_); + interpreter_->AllocateTensors(); + + ctx_ = ctx; +} + +void TfliteRuntime::Invoke() { + interpreter_->Invoke(); +} + +void TfliteRuntime::SetInput(int index, DLTensor* data_in) { + DataType dtype(data_in->dtype); + TVM_DTYPE_DISPATCH(dtype, DType, { + DType* dest = interpreter_->typed_input_tensor(index); + DType* src = static_cast(data_in->data); + CHECK(data_in->strides == NULL); + int64_t size = 1; + for (int64_t i = 0; i < data_in->ndim; ++i) { + size *= data_in->shape[i]; + } + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); +} + +NDArray TfliteRuntime::GetOutput(int index) const { + TfLiteTensor* output = interpreter_->output_tensor(index); + DataType dtype = TfLiteDType2TVMDType(output->type); + TfLiteIntArray* dims = output->dims; + int64_t size = 1; + std::vector shape; + for (int i = 0; i < dims->size; ++i) { + shape.push_back(dims->data[i]); + size *= dims->data[i]; + } + + NDArray ret = NDArray::Empty(shape, dtype, ctx_); + TVM_DTYPE_DISPATCH(dtype, DType, { + DType* dest = static_cast(ret->data); + DType* src = interpreter_->typed_input_tensor(index); + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); + return ret; +} + +PackedFunc TfliteRuntime::GetFunction( + const std::string& name, + const ObjectPtr& sptr_to_self) { + // Return member functions during query. + if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + int in_idx = args[0]; + CHECK_GE(in_idx, 0); + this->SetInput(in_idx, args[1]); + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetOutput(args[0]); + }); + } else if (name == "get_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + int in_idx = args[0]; + CHECK_GE(in_idx, 0); + *rv = this->GetInput(in_idx); + }); + } else if (name == "invoke") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->Invoke(); + }); + } else { + return PackedFunc(); + } +} + +Module TfliteRuntimeCreate(const std::string& tflite_fname, + TVMContext ctx) { + auto exec = make_object(); + exec->Init(tflite_fname, ctx); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = TfliteRuntimeCreate(args[0], contexts); + }); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h new file mode 100644 index 000000000000..45e861eede09 --- /dev/null +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Tiny graph runtime that can run graph + * containing only tvm PackedFunc. + * \file graph_runtime.h + */ +#ifndef TVM_RUNTIME_TFLITE_TFLITE_RUNTIME_H_ +#define TVM_RUNTIME_TFLITE_TFLITE_RUNTIME_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + + +/*! + * \brief Tiny graph runtime. + * + * This runtime can be acccesibly in various language via + * TVM runtime PackedFunc API. + */ +class TfliteRuntime : public ModuleNode { + public: + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self); + + /*! + * \return The type key of the executor. + */ + const char* type_key() const final { + return "TfliteRuntime"; + } + void Invoke(); + + /*! + * \brief Initialize the graph executor with graph and context. + * \param graph_json The execution graph. + * \param module The module containing the compiled functions for the host + * processor. + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. + */ + + void Init(const std::string& tflite_fname, + TVMContext ctx); + /*! + * \brief set index-th input to the graph. + * \param index The input index. + * \param data_in The input data. + */ + void SetInput(int index, DLTensor* data_in); + /*! + * \brief Return NDArray for given input index. + * \param index The input index. + * + * \return NDArray corresponding to given input node index. + */ + NDArray GetInput(int index) const; + /*! + * \brief Return NDArray for given output index. + * \param index The output index. + * + * \return NDArray corresponding to given output node index. + */ + NDArray GetOutput(int index) const; + + private: + std::unique_ptr interpreter_; + TVMContext ctx_; + +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_TFLITE_TFLITE_RUNTIME_H_ From e42379717de6a66ddbbf541b1d5442758bf562ed Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 27 Nov 2019 15:27:47 -0800 Subject: [PATCH 02/21] Add test. --- python/tvm/contrib/tflite_runtime.py | 40 ++------- src/runtime/contrib/tflite/tflite_runtime.cc | 29 ++++-- src/runtime/contrib/tflite/tflite_runtime.h | 2 + tests/python/unittest/test_runtime_tflite.py | 95 ++++++++++++++++++++ 4 files changed, 125 insertions(+), 41 deletions(-) create mode 100644 tests/python/unittest/test_runtime_tflite.py diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 37b18c5c88e4..3c2b6ff78636 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -41,8 +41,7 @@ def create(tflite_fname, ctx): Runtime graph module that can be used to execute the graph. """ if not isinstance(tflite_fname, string_types): - except AttributeError: - raise ValueError("Type %s is not supported" % type(tflite_fname)) + raise ValueError("Type %s is not supported" % type(tflite_fname)) fcreate = get_global_func("tvm.tflite_runtime.create") return TfliteModule(fcreate(tflite_fname, ctx)) @@ -71,9 +70,9 @@ def __init__(self, module): self._set_input = module["set_input"] self._invoke = module["invoke"] self._get_output = module["get_output"] - self._get_input = module["get_input"] + self._allocate_tensors = module["allocate_tensors"] - def set_input(self, key=None, value=None, **params): + def set_input(self, index, value): """Set inputs to the module via kwargs Parameters @@ -87,17 +86,9 @@ def set_input(self, key=None, value=None, **params): params : dict of str to NDArray Additonal arguments """ - if key is not None: - self._get_input(key).copyfrom(value) + self._set_input(index, value) - if params: - # upload big arrays first to avoid memory issue in rpc mode - keys = list(params.keys()) - keys.sort(key=lambda x: -np.prod(params[x].shape)) - for k in keys: - self._get_input(k).copyfrom(params[k]) - - def run(self, **input_dict): + def invoke(self): """Run forward execution of the graph Parameters @@ -105,26 +96,11 @@ def run(self, **input_dict): input_dict: dict of str to NDArray List of input values to be feed to """ - if input_dict: - self.set_input(**input_dict) - self._run() - - def get_input(self, index, out=None): - """Get index-th input to out - - Parameters - ---------- - index : int - The input index + self._invoke() - out : NDArray - The output array container - """ - if out: - self._get_input(index).copyto(out) - return out + def allocate_tensors(self): + self._allocate_tensors() - return self._get_input(index) def get_output(self, index, out=None): """Get index-th output to out diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 7b66feb800cf..403b1476d67a 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -97,17 +97,29 @@ void TfliteRuntime::Init(const std::string& tflite_fname, std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile(tflite_fname.c_str()); tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*model, resolver)(&interpreter_); - interpreter_->AllocateTensors(); + LOG(INFO) << "Init TFLite Interpreter..."; + LOG(INFO) << "Number of inputs: " << interpreter_->inputs().size(); + LOG(INFO) << interpreter_->GetInputName(0); + LOG(INFO) << "Number of outputs: " << interpreter_->outputs().size(); + LOG(INFO) << interpreter_->GetOutputName(0); ctx_ = ctx; } +void TfliteRuntime::AllocateTensors() { + LOG(INFO) << "AllocateTensors"; + interpreter_->AllocateTensors(); +} + void TfliteRuntime::Invoke() { + LOG(INFO) << "Invoke"; interpreter_->Invoke(); } void TfliteRuntime::SetInput(int index, DLTensor* data_in) { + LOG(INFO) << "SetInput"; DataType dtype(data_in->dtype); + LOG(INFO) << "data type: " << dtype; TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = interpreter_->typed_input_tensor(index); DType* src = static_cast(data_in->data); @@ -123,6 +135,7 @@ void TfliteRuntime::SetInput(int index, DLTensor* data_in) { } NDArray TfliteRuntime::GetOutput(int index) const { + LOG(INFO) << "GetOutput"; TfLiteTensor* output = interpreter_->output_tensor(index); DataType dtype = TfLiteDType2TVMDType(output->type); TfLiteIntArray* dims = output->dims; @@ -136,7 +149,7 @@ NDArray TfliteRuntime::GetOutput(int index) const { NDArray ret = NDArray::Empty(shape, dtype, ctx_); TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = static_cast(ret->data); - DType* src = interpreter_->typed_input_tensor(index); + DType* src = interpreter_->typed_output_tensor(index); for (int64_t i = 0; i < size; ++i) { dest[i] = src[i]; } @@ -158,16 +171,14 @@ PackedFunc TfliteRuntime::GetFunction( return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); - } else if (name == "get_input") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - int in_idx = args[0]; - CHECK_GE(in_idx, 0); - *rv = this->GetInput(in_idx); - }); } else if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); + } else if (name == "allocate_tensors") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->AllocateTensors(); + }); } else { return PackedFunc(); } @@ -182,7 +193,7 @@ Module TfliteRuntimeCreate(const std::string& tflite_fname, TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = TfliteRuntimeCreate(args[0], contexts); + *rv = TfliteRuntimeCreate(args[0], args[1]); }); } // namespace runtime diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 45e861eede09..124e9e565a57 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -64,6 +64,8 @@ class TfliteRuntime : public ModuleNode { const char* type_key() const final { return "TfliteRuntime"; } + + void AllocateTensors(); void Invoke(); /*! diff --git a/tests/python/unittest/test_runtime_tflite.py b/tests/python/unittest/test_runtime_tflite.py new file mode 100644 index 000000000000..855eb11e8695 --- /dev/null +++ b/tests/python/unittest/test_runtime_tflite.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import numpy as np +from tvm import rpc +from tvm.contrib import util, tflite_runtime +import tensorflow as tf +import tflite_runtime.interpreter as tflite + + + + +def test_tflite_runtime(): + + def create_tflite_model(): + root = tf.Module() + root.const = tf.constant([1., 2.], tf.float32) + root.f = tf.function(lambda x: root.const * x) + + input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32) + concrete_func = root.f.get_concrete_function(input_signature) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + tflite_model = converter.convert() + return tflite_model + + + def check_verify(): + tflite_fname = "model.tflite" + tflite_model = create_tflite_model() + open('/tmp/model.tflite', 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tflite.Interpreter(model_path="/tmp/model.tflite") + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + print(tflite_input) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + + tflite_output = interpreter.get_tensor(output_details[0]['index']) + print(tflite_output) + + + # inference via tvm tflite runtime + runtime = tflite_runtime.create(tflite_fname, tvm.cpu(0)) + runtime.allocate_tensors() + runtime.set_input(0, tvm.nd.array(tflite_input)) + runtime.invoke() + out = runtime.get_output(0) + print(out) + np.testing.assert_equal(out.asnumpy(), tflite_output) + + + # def check_remote(): + # if not tvm.module.enabled("llvm"): + # print("Skip because llvm is not enabled") + # return + # server = rpc.Server("localhost") + # remote = rpc.connect(server.host, server.port) + # temp = util.tempdir() + # ctx = remote.cpu(0) + # path_dso = temp.relpath("dev_lib.so") + # mlib.export_library(path_dso) + # remote.upload(path_dso) + # mlib = remote.load_module("dev_lib.so") + # mod = graph_runtime.create(graph, mlib, remote.cpu(0)) + # a = np.random.uniform(size=(n,)).astype(A.dtype) + # mod.run(x=tvm.nd.array(a, ctx)) + # out = tvm.nd.empty((n,), ctx=ctx) + # out = mod.get_output(0, out) + # np.testing.assert_equal(out.asnumpy(), a + 1) + + check_verify() + # check_remote() + +if __name__ == "__main__": + test_tflite_runtime() From 0a63523d9aeddf91905ad78a5a172539cfbeee12 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 27 Nov 2019 16:08:40 -0800 Subject: [PATCH 03/21] Support for remote rpc. --- CMakeLists.txt | 2 +- cmake/modules/contrib/TfLite.cmake | 7 +-- python/tvm/contrib/tflite_runtime.py | 10 +++- src/runtime/contrib/tflite/tflite_runtime.cc | 13 +--- tests/python/unittest/test_runtime_tflite.py | 63 +++++++++++--------- 5 files changed, 49 insertions(+), 46 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1cf79f85ac56..54b122293749 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,7 +63,7 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) -tvm_option(USE_TFLITE "Build with nnpack support" ON) +tvm_option(USE_TFLITE "Build with tflite support" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) diff --git a/cmake/modules/contrib/TfLite.cmake b/cmake/modules/contrib/TfLite.cmake index c198d711eaac..43cb8e8e3695 100644 --- a/cmake/modules/contrib/TfLite.cmake +++ b/cmake/modules/contrib/TfLite.cmake @@ -17,16 +17,13 @@ if(USE_TFLITE) message(STATUS "Build with contrib.tflite") - message("current path: ${CMAKE_CURRENT_SOURCE_DIR}") - # if (TENSORFLOW_PATH STREQUAL "") + if (NOT DEFINED TENSORFLOW_PATH) set(TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) - # endif() - message("tfpath: ${TENSORFLOW_PATH}") + endif() file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc) list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC}) include_directories(${TENSORFLOW_PATH}) find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/linux_x86_64/lib) - message("tflite lib: ${TFLITE_CONTRIB_LIB}") list(APPEND TVM_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) list(APPEND TVM_LINKER_LIBS rt dl flatbuffers) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 3c2b6ff78636..8080b6e1b34d 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -18,6 +18,7 @@ import numpy as np from .._ffi.base import string_types +from .._ffi.ndarray import context from .._ffi.function import get_global_func from .._ffi.runtime_ctypes import TVMContext from ..rpc import base as rpc_base @@ -26,7 +27,7 @@ def create(tflite_fname, ctx): """Create a runtime executor module given a graph and module. Parameters ---------- - graph_json_str : str or graph class + tflite_fname : str The graph to be deployed in json format output by nnvm graph. The graph can only contain one operator(tvm_op) that points to the name of PackedFunc in the libmod. @@ -43,6 +44,13 @@ def create(tflite_fname, ctx): if not isinstance(tflite_fname, string_types): raise ValueError("Type %s is not supported" % type(tflite_fname)) + device_type = ctx.device_type + if device_type >= rpc_base.RPC_SESS_MASK: + device_type = ctx.device_type % rpc_base.RPC_SESS_MASK + device_id = ctx.device_id + remote_ctx = context(device_type, device_id) + fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create") + return TfliteModule(fcreate(tflite_fname, ctx)) fcreate = get_global_func("tvm.tflite_runtime.create") return TfliteModule(fcreate(tflite_fname, ctx)) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 403b1476d67a..9eb0a969b3b5 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -88,6 +88,7 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { return Float(16); default: LOG(FATAL) << "tflite data type not support yet: " << dtype; + return Float(32); } } @@ -97,29 +98,19 @@ void TfliteRuntime::Init(const std::string& tflite_fname, std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile(tflite_fname.c_str()); tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*model, resolver)(&interpreter_); - LOG(INFO) << "Init TFLite Interpreter..."; - - LOG(INFO) << "Number of inputs: " << interpreter_->inputs().size(); - LOG(INFO) << interpreter_->GetInputName(0); - LOG(INFO) << "Number of outputs: " << interpreter_->outputs().size(); - LOG(INFO) << interpreter_->GetOutputName(0); ctx_ = ctx; } void TfliteRuntime::AllocateTensors() { - LOG(INFO) << "AllocateTensors"; interpreter_->AllocateTensors(); } void TfliteRuntime::Invoke() { - LOG(INFO) << "Invoke"; interpreter_->Invoke(); } void TfliteRuntime::SetInput(int index, DLTensor* data_in) { - LOG(INFO) << "SetInput"; DataType dtype(data_in->dtype); - LOG(INFO) << "data type: " << dtype; TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = interpreter_->typed_input_tensor(index); DType* src = static_cast(data_in->data); @@ -135,7 +126,6 @@ void TfliteRuntime::SetInput(int index, DLTensor* data_in) { } NDArray TfliteRuntime::GetOutput(int index) const { - LOG(INFO) << "GetOutput"; TfLiteTensor* output = interpreter_->output_tensor(index); DataType dtype = TfLiteDType2TVMDType(output->type); TfLiteIntArray* dims = output->dims; @@ -195,6 +185,5 @@ TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = TfliteRuntimeCreate(args[0], args[1]); }); - } // namespace runtime } // namespace tvm diff --git a/tests/python/unittest/test_runtime_tflite.py b/tests/python/unittest/test_runtime_tflite.py index 855eb11e8695..729048cf0335 100644 --- a/tests/python/unittest/test_runtime_tflite.py +++ b/tests/python/unittest/test_runtime_tflite.py @@ -22,8 +22,6 @@ import tflite_runtime.interpreter as tflite - - def test_tflite_runtime(): def create_tflite_model(): @@ -41,7 +39,9 @@ def create_tflite_model(): def check_verify(): tflite_fname = "model.tflite" tflite_model = create_tflite_model() - open('/tmp/model.tflite', 'wb').write(tflite_model) + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) # inference via tflite interpreter python apis interpreter = tflite.Interpreter(model_path="/tmp/model.tflite") @@ -51,13 +51,9 @@ def check_verify(): input_shape = input_details[0]['shape'] tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) - print(tflite_input) interpreter.set_tensor(input_details[0]['index'], tflite_input) interpreter.invoke() - tflite_output = interpreter.get_tensor(output_details[0]['index']) - print(tflite_output) - # inference via tvm tflite runtime runtime = tflite_runtime.create(tflite_fname, tvm.cpu(0)) @@ -65,31 +61,44 @@ def check_verify(): runtime.set_input(0, tvm.nd.array(tflite_input)) runtime.invoke() out = runtime.get_output(0) - print(out) np.testing.assert_equal(out.asnumpy(), tflite_output) - # def check_remote(): - # if not tvm.module.enabled("llvm"): - # print("Skip because llvm is not enabled") - # return - # server = rpc.Server("localhost") - # remote = rpc.connect(server.host, server.port) - # temp = util.tempdir() - # ctx = remote.cpu(0) - # path_dso = temp.relpath("dev_lib.so") - # mlib.export_library(path_dso) - # remote.upload(path_dso) - # mlib = remote.load_module("dev_lib.so") - # mod = graph_runtime.create(graph, mlib, remote.cpu(0)) - # a = np.random.uniform(size=(n,)).astype(A.dtype) - # mod.run(x=tvm.nd.array(a, ctx)) - # out = tvm.nd.empty((n,), ctx=ctx) - # out = mod.get_output(0, out) - # np.testing.assert_equal(out.asnumpy(), a + 1) + def check_remote(): + tflite_fname = "model.tflite" + tflite_model = create_tflite_model() + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tflite.Interpreter(model_path="/tmp/model.tflite") + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + tflite_output = interpreter.get_tensor(output_details[0]['index']) + + # inference via remote tvm tflite runtime + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a = remote.upload(tflite_model_path) + + runtime = tflite_runtime.create(tflite_model_path, remote.cpu(0)) + runtime.allocate_tensors() + runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) + check_verify() - # check_remote() + check_remote() if __name__ == "__main__": test_tflite_runtime() From 223e8959e11c5f2b158057e87d7f7398ee240f39 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 27 Nov 2019 16:21:37 -0800 Subject: [PATCH 04/21] Update. --- python/tvm/contrib/tflite_runtime.py | 8 ++++---- .../test_tflite_runtime.py} | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) rename tests/python/{unittest/test_runtime_tflite.py => contrib/test_tflite_runtime.py} (92%) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 8080b6e1b34d..d7a91b23ef04 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Minimum graph runtime that executes graph containing TVM PackedFunc.""" +"""TFLite runtime that load and run tflite models.""" import numpy as np from .._ffi.base import string_types @@ -50,12 +50,12 @@ def create(tflite_fname, ctx): device_id = ctx.device_id remote_ctx = context(device_type, device_id) fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create") - return TfliteModule(fcreate(tflite_fname, ctx)) + return TFLiteModule(fcreate(tflite_fname, ctx)) fcreate = get_global_func("tvm.tflite_runtime.create") - return TfliteModule(fcreate(tflite_fname, ctx)) + return TFLiteModule(fcreate(tflite_fname, ctx)) -class TfliteModule(object): +class TFLiteModule(object): """Wrapper runtime module. This is a thin wrapper of the underlying TVM module. diff --git a/tests/python/unittest/test_runtime_tflite.py b/tests/python/contrib/test_tflite_runtime.py similarity index 92% rename from tests/python/unittest/test_runtime_tflite.py rename to tests/python/contrib/test_tflite_runtime.py index 729048cf0335..c79245af7af9 100644 --- a/tests/python/unittest/test_runtime_tflite.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -41,10 +41,12 @@ def check_verify(): tflite_model = create_tflite_model() temp = util.tempdir() tflite_model_path = temp.relpath(tflite_fname) + print(tflite_model_path) open(tflite_model_path, 'wb').write(tflite_model) # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path="/tmp/model.tflite") + print('interpreter') + interpreter = tflite.Interpreter(model_path=tflite_model_path) interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() @@ -55,8 +57,9 @@ def check_verify(): interpreter.invoke() tflite_output = interpreter.get_tensor(output_details[0]['index']) + print('tvm tflite runtime') # inference via tvm tflite runtime - runtime = tflite_runtime.create(tflite_fname, tvm.cpu(0)) + runtime = tflite_runtime.create(tflite_model_path, tvm.cpu(0)) runtime.allocate_tensors() runtime.set_input(0, tvm.nd.array(tflite_input)) runtime.invoke() @@ -72,7 +75,7 @@ def check_remote(): open(tflite_model_path, 'wb').write(tflite_model) # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path="/tmp/model.tflite") + interpreter = tflite.Interpreter(model_path=tflite_model_path) interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() From d190f5f99b9e18be125b902df2c86061c997f696 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 27 Nov 2019 16:31:30 -0800 Subject: [PATCH 05/21] Update. --- cmake/config.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index e3f0b6225f1a..42a860c8073e 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -145,7 +145,7 @@ set(USE_RANDOM OFF) # Whether use NNPack set(USE_NNPACK OFF) -set(USE_TFLITE ON) +set(USE_TFLITE OFF) # Whether use CuDNN set(USE_CUDNN OFF) From 031bbd84e75384640b9e688c70029466a0eccdd8 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 27 Nov 2019 16:38:32 -0800 Subject: [PATCH 06/21] Update. --- src/runtime/contrib/tflite/tflite_runtime.cc | 18 +++++++++--------- src/runtime/contrib/tflite/tflite_runtime.h | 9 ++------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 9eb0a969b3b5..6baf5cdc5746 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -93,7 +93,7 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { } -void TfliteRuntime::Init(const std::string& tflite_fname, +void TFLiteRuntime::Init(const std::string& tflite_fname, TVMContext ctx) { std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile(tflite_fname.c_str()); tflite::ops::builtin::BuiltinOpResolver resolver; @@ -101,15 +101,15 @@ void TfliteRuntime::Init(const std::string& tflite_fname, ctx_ = ctx; } -void TfliteRuntime::AllocateTensors() { +void TFLiteRuntime::AllocateTensors() { interpreter_->AllocateTensors(); } -void TfliteRuntime::Invoke() { +void TFLiteRuntime::Invoke() { interpreter_->Invoke(); } -void TfliteRuntime::SetInput(int index, DLTensor* data_in) { +void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { DataType dtype(data_in->dtype); TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = interpreter_->typed_input_tensor(index); @@ -125,7 +125,7 @@ void TfliteRuntime::SetInput(int index, DLTensor* data_in) { }); } -NDArray TfliteRuntime::GetOutput(int index) const { +NDArray TFLiteRuntime::GetOutput(int index) const { TfLiteTensor* output = interpreter_->output_tensor(index); DataType dtype = TfLiteDType2TVMDType(output->type); TfLiteIntArray* dims = output->dims; @@ -147,7 +147,7 @@ NDArray TfliteRuntime::GetOutput(int index) const { return ret; } -PackedFunc TfliteRuntime::GetFunction( +PackedFunc TFLiteRuntime::GetFunction( const std::string& name, const ObjectPtr& sptr_to_self) { // Return member functions during query. @@ -174,16 +174,16 @@ PackedFunc TfliteRuntime::GetFunction( } } -Module TfliteRuntimeCreate(const std::string& tflite_fname, +Module TFLiteRuntimeCreate(const std::string& tflite_fname, TVMContext ctx) { - auto exec = make_object(); + auto exec = make_object(); exec->Init(tflite_fname, ctx); return Module(exec); } TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = TfliteRuntimeCreate(args[0], args[1]); + *rv = TFLiteRuntimeCreate(args[0], args[1]); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 124e9e565a57..4eaa5e498681 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -26,14 +26,9 @@ #define TVM_RUNTIME_TFLITE_TFLITE_RUNTIME_H_ #include -#include -#include #include #include -#include -#include -#include #include #include @@ -47,7 +42,7 @@ namespace runtime { * This runtime can be acccesibly in various language via * TVM runtime PackedFunc API. */ -class TfliteRuntime : public ModuleNode { +class TFLiteRuntime : public ModuleNode { public: /*! * \brief Get member function to front-end @@ -62,7 +57,7 @@ class TfliteRuntime : public ModuleNode { * \return The type key of the executor. */ const char* type_key() const final { - return "TfliteRuntime"; + return "TFLiteRuntime"; } void AllocateTensors(); From 299930d696d4ca08d0369e86f7502de4de60f1b6 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 11:50:21 -0800 Subject: [PATCH 07/21] change model loading to byte string --- python/tvm/contrib/tflite_runtime.py | 9 +++---- src/runtime/contrib/tflite/tflite_runtime.cc | 8 +++--- src/runtime/contrib/tflite/tflite_runtime.h | 2 +- tests/python/contrib/test_tflite_runtime.py | 26 +++++++++++--------- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index d7a91b23ef04..ebe2e9118686 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -23,7 +23,7 @@ from .._ffi.runtime_ctypes import TVMContext from ..rpc import base as rpc_base -def create(tflite_fname, ctx): +def create(tflite_model_bytes, ctx): """Create a runtime executor module given a graph and module. Parameters ---------- @@ -41,18 +41,15 @@ def create(tflite_fname, ctx): graph_module : GraphModule Runtime graph module that can be used to execute the graph. """ - if not isinstance(tflite_fname, string_types): - raise ValueError("Type %s is not supported" % type(tflite_fname)) - device_type = ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: device_type = ctx.device_type % rpc_base.RPC_SESS_MASK device_id = ctx.device_id remote_ctx = context(device_type, device_id) fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create") - return TFLiteModule(fcreate(tflite_fname, ctx)) + return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx)) fcreate = get_global_func("tvm.tflite_runtime.create") - return TFLiteModule(fcreate(tflite_fname, ctx)) + return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx)) class TFLiteModule(object): diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 6baf5cdc5746..d88e06ae021f 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -93,9 +93,9 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { } -void TFLiteRuntime::Init(const std::string& tflite_fname, +void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { - std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile(tflite_fname.c_str()); + std::unique_ptr model = tflite::FlatBufferModel::BuildFromBuffer(tflite_model_bytes.c_str(), tflite_model_bytes.size()); tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*model, resolver)(&interpreter_); ctx_ = ctx; @@ -174,10 +174,10 @@ PackedFunc TFLiteRuntime::GetFunction( } } -Module TFLiteRuntimeCreate(const std::string& tflite_fname, +Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) { auto exec = make_object(); - exec->Init(tflite_fname, ctx); + exec->Init(tflite_model_bytes, ctx); return Module(exec); } diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 4eaa5e498681..b8a223a5e96c 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -72,7 +72,7 @@ class TFLiteRuntime : public ModuleNode { * executed on. */ - void Init(const std::string& tflite_fname, + void Init(const std::string& tflite_model_bytes, TVMContext ctx); /*! * \brief set index-th input to the graph. diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index c79245af7af9..861693ec138c 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -59,12 +59,13 @@ def check_verify(): print('tvm tflite runtime') # inference via tvm tflite runtime - runtime = tflite_runtime.create(tflite_model_path, tvm.cpu(0)) - runtime.allocate_tensors() - runtime.set_input(0, tvm.nd.array(tflite_input)) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) + with open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) + runtime.allocate_tensors() + runtime.set_input(0, tvm.nd.array(tflite_input)) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) def check_remote(): @@ -92,12 +93,13 @@ def check_remote(): ctx = remote.cpu(0) a = remote.upload(tflite_model_path) - runtime = tflite_runtime.create(tflite_model_path, remote.cpu(0)) - runtime.allocate_tensors() - runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) + with open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) + runtime.allocate_tensors() + runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) check_verify() From 25c1333a684d9d5f676e095a3125f015370d3cf1 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 11:56:42 -0800 Subject: [PATCH 08/21] Introduce TFLITE_LIB_PATH --- cmake/modules/contrib/TfLite.cmake | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cmake/modules/contrib/TfLite.cmake b/cmake/modules/contrib/TfLite.cmake index 43cb8e8e3695..ec601815732c 100644 --- a/cmake/modules/contrib/TfLite.cmake +++ b/cmake/modules/contrib/TfLite.cmake @@ -20,10 +20,15 @@ if(USE_TFLITE) if (NOT DEFINED TENSORFLOW_PATH) set(TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) endif() + file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc) list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC}) include_directories(${TENSORFLOW_PATH}) - find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/linux_x86_64/lib) + + if (NOT DEFINED TFLITE_LIB_PATH) + set(TFLITE_LIB_PATH ${TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib) + endif() + find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${TFLITE_LIB_PATH}) list(APPEND TVM_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) list(APPEND TVM_LINKER_LIBS rt dl flatbuffers) From bf3706ed608f1cebd58469d0b5b6bf4ba144f217 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 22:07:32 +0000 Subject: [PATCH 09/21] Update cmake --- CMakeLists.txt | 3 ++- cmake/config.cmake | 9 ++++++++- .../contrib/{TfLite.cmake => TFLite.cmake} | 19 +++++++++++-------- 3 files changed, 21 insertions(+), 10 deletions(-) rename cmake/modules/contrib/{TfLite.cmake => TFLite.cmake} (66%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 54b122293749..bb21db26bad6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,6 +64,7 @@ tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) +tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -258,7 +259,7 @@ include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/NNPack.cmake) include(cmake/modules/contrib/HybridDump.cmake) -include(cmake/modules/contrib/TfLite.cmake) +include(cmake/modules/contrib/TFLite.cmake) if(NOT MSVC) include(CheckCXXCompilerFlag) diff --git a/cmake/config.cmake b/cmake/config.cmake index 42a860c8073e..df7cbefa5b91 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -145,7 +145,14 @@ set(USE_RANDOM OFF) # Whether use NNPack set(USE_NNPACK OFF) -set(USE_TFLITE OFF) +# Possible values: +# - ON: enable tflite with cmake's find search +# - OFF: disable tflite +# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library +set(USE_TFLITE_PATH OFF) + +# /path/to/tensorflow: tensorflow root path when use tflite library +set(USE_TENSORFLOW_PATH none) # Whether use CuDNN set(USE_CUDNN OFF) diff --git a/cmake/modules/contrib/TfLite.cmake b/cmake/modules/contrib/TFLite.cmake similarity index 66% rename from cmake/modules/contrib/TfLite.cmake rename to cmake/modules/contrib/TFLite.cmake index ec601815732c..a62ff722ebbf 100644 --- a/cmake/modules/contrib/TfLite.cmake +++ b/cmake/modules/contrib/TFLite.cmake @@ -15,21 +15,24 @@ # specific language governing permissions and limitations # under the License. -if(USE_TFLITE) +if(NOT USE_TFLITE STREQUAL "OFF") message(STATUS "Build with contrib.tflite") - if (NOT DEFINED TENSORFLOW_PATH) - set(TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) + message("tensorflow path: ${USE_TENSORFLOW_PATH}") + if (NOT USE_TENSORFLOW_PATH) + set(USE_TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) endif() + message("tensorflow path: ${USE_TENSORFLOW_PATH}") file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc) list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC}) - include_directories(${TENSORFLOW_PATH}) + include_directories(${USE_TENSORFLOW_PATH}) - if (NOT DEFINED TFLITE_LIB_PATH) - set(TFLITE_LIB_PATH ${TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib) + if (USE_TFLITE STREQUAL "ON") + set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib) endif() - find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${TFLITE_LIB_PATH}) + find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE}) + message("tflite lib path: ${TFLITE_CONTRIB_LIB}") list(APPEND TVM_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) list(APPEND TVM_LINKER_LIBS rt dl flatbuffers) -endif(USE_TFLITE) +endif() From fed10d723d34e3f87857461e8e4dc666af7a8b1b Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 22:16:32 +0000 Subject: [PATCH 10/21] fix for runtime lib dependency --- cmake/modules/contrib/TFLite.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/modules/contrib/TFLite.cmake b/cmake/modules/contrib/TFLite.cmake index a62ff722ebbf..4684cd3d1146 100644 --- a/cmake/modules/contrib/TFLite.cmake +++ b/cmake/modules/contrib/TFLite.cmake @@ -33,6 +33,6 @@ if(NOT USE_TFLITE STREQUAL "OFF") find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE}) message("tflite lib path: ${TFLITE_CONTRIB_LIB}") - list(APPEND TVM_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) - list(APPEND TVM_LINKER_LIBS rt dl flatbuffers) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) + list(APPEND TVM_RUNTIME_LINKER_LIBS rt dl flatbuffers) endif() From 33f6c55d4e3e5ca733458d5236872d943eaf20e6 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 22:18:54 +0000 Subject: [PATCH 11/21] skip test --- tests/python/contrib/test_tflite_runtime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 861693ec138c..6cf277abf788 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -22,7 +22,7 @@ import tflite_runtime.interpreter as tflite -def test_tflite_runtime(): +def skip_test_tflite_runtime(): def create_tflite_model(): root = tf.Module() @@ -106,4 +106,4 @@ def check_remote(): check_remote() if __name__ == "__main__": - test_tflite_runtime() + pass From dc266605974619ab1808222edcb577954de9904a Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 14:29:23 -0800 Subject: [PATCH 12/21] Fix --- cmake/config.cmake | 2 +- cmake/modules/contrib/TFLite.cmake | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index df7cbefa5b91..25bf5516291b 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -149,7 +149,7 @@ set(USE_NNPACK OFF) # - ON: enable tflite with cmake's find search # - OFF: disable tflite # - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library -set(USE_TFLITE_PATH OFF) +set(USE_TFLITE OFF) # /path/to/tensorflow: tensorflow root path when use tflite library set(USE_TENSORFLOW_PATH none) diff --git a/cmake/modules/contrib/TFLite.cmake b/cmake/modules/contrib/TFLite.cmake index 4684cd3d1146..71f82acfd4bd 100644 --- a/cmake/modules/contrib/TFLite.cmake +++ b/cmake/modules/contrib/TFLite.cmake @@ -18,7 +18,7 @@ if(NOT USE_TFLITE STREQUAL "OFF") message(STATUS "Build with contrib.tflite") message("tensorflow path: ${USE_TENSORFLOW_PATH}") - if (NOT USE_TENSORFLOW_PATH) + if (USE_TENSORFLOW_PATH STREQUAL "none") set(USE_TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) endif() message("tensorflow path: ${USE_TENSORFLOW_PATH}") From 28052474c13fbff302f0249f30144a68e4084a46 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 14:33:38 -0800 Subject: [PATCH 13/21] Fix lint --- src/runtime/contrib/tflite/tflite_runtime.cc | 1 + src/runtime/contrib/tflite/tflite_runtime.h | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index d88e06ae021f..eaaeb848c6a7 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include "tflite_runtime.h" diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index b8a223a5e96c..4180897e1073 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -22,8 +22,8 @@ * containing only tvm PackedFunc. * \file graph_runtime.h */ -#ifndef TVM_RUNTIME_TFLITE_TFLITE_RUNTIME_H_ -#define TVM_RUNTIME_TFLITE_TFLITE_RUNTIME_H_ +#ifndef TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ +#define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ #include #include @@ -98,10 +98,9 @@ class TFLiteRuntime : public ModuleNode { private: std::unique_ptr interpreter_; TVMContext ctx_; - }; } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_TFLITE_TFLITE_RUNTIME_H_ +#endif // TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ From b309175da62d7f3742f8597636b12febd8a0721e Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 14:39:19 -0800 Subject: [PATCH 14/21] Fix lint --- src/runtime/contrib/tflite/tflite_runtime.cc | 7 ++++--- src/runtime/contrib/tflite/tflite_runtime.h | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index eaaeb848c6a7..a32669d5f635 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include "tflite_runtime.h" @@ -96,7 +95,10 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { - std::unique_ptr model = tflite::FlatBufferModel::BuildFromBuffer(tflite_model_bytes.c_str(), tflite_model_bytes.size()); + const char* buffer = tflite_model_bytes.c_str(); + size_t buffer_size = tflite_model_bytes.size(); + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*model, resolver)(&interpreter_); ctx_ = ctx; @@ -136,7 +138,6 @@ NDArray TFLiteRuntime::GetOutput(int index) const { shape.push_back(dims->data[i]); size *= dims->data[i]; } - NDArray ret = NDArray::Empty(shape, dtype, ctx_); TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = static_cast(ret->data); diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 4180897e1073..c30b65cc6a96 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -31,6 +31,7 @@ #include #include +#include namespace tvm { namespace runtime { From e7c0abcbe33a2bd1abf1a000eed4bec4d1c8e10a Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 14:50:29 -0800 Subject: [PATCH 15/21] Fix lint --- python/tvm/contrib/tflite_runtime.py | 7 ------ src/runtime/contrib/tflite/tflite_runtime.h | 26 ++++++++++++--------- tests/python/contrib/test_tflite_runtime.py | 4 ++-- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index ebe2e9118686..0b72610a11df 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -15,12 +15,8 @@ # specific language governing permissions and limitations # under the License. """TFLite runtime that load and run tflite models.""" -import numpy as np - -from .._ffi.base import string_types from .._ffi.ndarray import context from .._ffi.function import get_global_func -from .._ffi.runtime_ctypes import TVMContext from ..rpc import base as rpc_base def create(tflite_model_bytes, ctx): @@ -43,9 +39,6 @@ def create(tflite_model_bytes, ctx): """ device_type = ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: - device_type = ctx.device_type % rpc_base.RPC_SESS_MASK - device_id = ctx.device_id - remote_ctx = context(device_type, device_id) fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create") return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx)) fcreate = get_global_func("tvm.tflite_runtime.create") diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index c30b65cc6a96..4b08b97b6865 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -18,9 +18,9 @@ */ /*! - * \brief Tiny graph runtime that can run graph + * \brief Tflite runtime that can run tflite model * containing only tvm PackedFunc. - * \file graph_runtime.h + * \file tflite_runtime.h */ #ifndef TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ @@ -38,7 +38,7 @@ namespace runtime { /*! - * \brief Tiny graph runtime. + * \brief Tflite runtime. * * This runtime can be acccesibly in various language via * TVM runtime PackedFunc API. @@ -61,22 +61,26 @@ class TFLiteRuntime : public ModuleNode { return "TFLiteRuntime"; } + /*! + * \brief Update allocations for all tenssors. This is relatively expensive. + */ void AllocateTensors(); + /*! + * \brief Invoke the internal tflite interpreter and run the whole model in + * dependency order. + */ void Invoke(); /*! - * \brief Initialize the graph executor with graph and context. - * \param graph_json The execution graph. - * \param module The module containing the compiled functions for the host - * processor. - * \param ctxs The context of the host and devices where graph nodes will be - * executed on. + * \brief Initialize the tflite runtime with tflite model and context. + * \param tflite_model_bytes The tflite model. + * \param ctx The context where the tflite model will be executed on. */ - void Init(const std::string& tflite_model_bytes, TVMContext ctx); + /*! - * \brief set index-th input to the graph. + * \brief set index-th input to the model. * \param index The input index. * \param data_in The input data. */ diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 6cf277abf788..c22dfc2be8ba 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -22,7 +22,7 @@ import tflite_runtime.interpreter as tflite -def skip_test_tflite_runtime(): +def skipped_test_tflite_runtime(): def create_tflite_model(): root = tf.Module() @@ -106,4 +106,4 @@ def check_remote(): check_remote() if __name__ == "__main__": - pass + skipped_test_tflite_runtime() From 4030b7b5de6970e6819858a9efbddae78fe9107f Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 14:53:37 -0800 Subject: [PATCH 16/21] Update comments --- python/tvm/contrib/tflite_runtime.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 0b72610a11df..15e83ac14f8b 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -20,22 +20,18 @@ from ..rpc import base as rpc_base def create(tflite_model_bytes, ctx): - """Create a runtime executor module given a graph and module. + """Create a runtime executor module given a tflite model and context. Parameters ---------- - tflite_fname : str - The graph to be deployed in json format output by nnvm graph. - The graph can only contain one operator(tvm_op) that - points to the name of PackedFunc in the libmod. - ctx : TVMContext or list of TVMContext + tflite_model_byte : bytes + The tflite model to be deployed in bytes string format. + ctx : TVMContext The context to deploy the module. It can be local or remote when there - is only one TVMContext. Otherwise, the first context in the list will - be used as this purpose. All context should be given for heterogeneous - execution. + is only one TVMContext. Returns ------- - graph_module : GraphModule - Runtime graph module that can be used to execute the graph. + tflite_runtime : TFLiteModule + Runtime tflite module that can be used to execute the tflite model. """ device_type = ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: @@ -55,12 +51,12 @@ class TFLiteModule(object): Parameters ---------- module : Module - The interal tvm module that holds the actual graph functions. + The interal tvm module that holds the actual tflite functions. Attributes ---------- module : Module - The interal tvm module that holds the actual graph functions. + The interal tvm module that holds the actual tflite functions. """ def __init__(self, module): @@ -87,7 +83,7 @@ def set_input(self, index, value): self._set_input(index, value) def invoke(self): - """Run forward execution of the graph + """Invoke forward execution of the model Parameters ---------- @@ -97,6 +93,8 @@ def invoke(self): self._invoke() def allocate_tensors(self): + """Allocate space for all tensors. + """ self._allocate_tensors() From dcc05150ed809a79502579369c4bd9a6fa3d97f3 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 14:56:01 -0800 Subject: [PATCH 17/21] Update comments --- python/tvm/contrib/tflite_runtime.py | 1 - tests/python/contrib/test_tflite_runtime.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 15e83ac14f8b..1440af0aa609 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """TFLite runtime that load and run tflite models.""" -from .._ffi.ndarray import context from .._ffi.function import get_global_func from ..rpc import base as rpc_base diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index c22dfc2be8ba..a55d601baa46 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -106,4 +106,5 @@ def check_remote(): check_remote() if __name__ == "__main__": - skipped_test_tflite_runtime() + # skipped_test_tflite_runtime() + pass From 28df0fc292b400e26de471fc0634a3703edad0b4 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 14:59:24 -0800 Subject: [PATCH 18/21] Update comments --- python/tvm/contrib/tflite_runtime.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 1440af0aa609..89a547f48f96 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -97,19 +97,12 @@ def allocate_tensors(self): self._allocate_tensors() - def get_output(self, index, out=None): + def get_output(self, index): """Get index-th output to out Parameters ---------- index : int The output index - - out : NDArray - The output array container """ - if out: - self._get_output(index, out) - return out - return self._get_output(index) From 176b0d63b002a8df8563b85acdb41265cba7dc36 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 15:05:46 -0800 Subject: [PATCH 19/21] Remove message --- cmake/modules/contrib/TFLite.cmake | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmake/modules/contrib/TFLite.cmake b/cmake/modules/contrib/TFLite.cmake index 71f82acfd4bd..9074def9dc8e 100644 --- a/cmake/modules/contrib/TFLite.cmake +++ b/cmake/modules/contrib/TFLite.cmake @@ -17,11 +17,9 @@ if(NOT USE_TFLITE STREQUAL "OFF") message(STATUS "Build with contrib.tflite") - message("tensorflow path: ${USE_TENSORFLOW_PATH}") if (USE_TENSORFLOW_PATH STREQUAL "none") set(USE_TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) endif() - message("tensorflow path: ${USE_TENSORFLOW_PATH}") file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc) list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC}) @@ -31,7 +29,6 @@ if(NOT USE_TFLITE STREQUAL "OFF") set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib) endif() find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE}) - message("tflite lib path: ${TFLITE_CONTRIB_LIB}") list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) list(APPEND TVM_RUNTIME_LINKER_LIBS rt dl flatbuffers) From 9b4af76c2ecf267ee3a1571ccd37201915102963 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 15:16:39 -0800 Subject: [PATCH 20/21] Update test --- tests/python/contrib/test_tflite_runtime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index a55d601baa46..e8bc66300e1a 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -18,8 +18,8 @@ import numpy as np from tvm import rpc from tvm.contrib import util, tflite_runtime -import tensorflow as tf -import tflite_runtime.interpreter as tflite +# import tensorflow as tf +# import tflite_runtime.interpreter as tflite def skipped_test_tflite_runtime(): From 841cbe12a536cc2de15bdb0ef01591112e788104 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 3 Dec 2019 16:44:58 -0800 Subject: [PATCH 21/21] retrigeger