Skip to content

Commit

Permalink
[TFLite runtime] Allow to set number of threads to TFLite interpreter (
Browse files Browse the repository at this point in the history
…apache#6901)

* Support for setting thread count in TFLite runtime,

Co-authored-by: FrozenGene <zhaowu@apache.org>

* fix lint

Co-authored-by: FrozenGene <zhaowu@apache.org>
  • Loading branch information
2 people authored and Trevor Morris committed Dec 4, 2020
1 parent 9674d49 commit e09edaa
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/tvm/contrib/tflite_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, module):
self._set_input = module["set_input"]
self._invoke = module["invoke"]
self._get_output = module["get_output"]
self._set_num_threads = module["set_num_threads"]

def set_input(self, index, value):
"""Set inputs to the module via kwargs
Expand Down Expand Up @@ -109,3 +110,12 @@ def get_output(self, index):
The output index
"""
return self._get_output(index)

def set_num_threads(self, num_threads):
"""Set the number of threads via kwargs
Parameters
----------
num_threads : int
The number of threads
"""
self._set_num_threads(num_threads)
8 changes: 8 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
});
}

void TFLiteRuntime::SetNumThreads(int num_threads) { interpreter_->SetNumThreads(num_threads); }

NDArray TFLiteRuntime::GetOutput(int index) const {
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[index]);
DataType dtype = TfLiteDType2TVMDType(output->type);
Expand Down Expand Up @@ -163,6 +165,12 @@ PackedFunc TFLiteRuntime::GetFunction(const std::string& name,
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); });
} else if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); });
} else if (name == "set_num_threads") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int num_threads = args[0];
CHECK_GE(num_threads, 1);
this->SetNumThreads(num_threads);
});
} else {
return PackedFunc();
}
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ class TFLiteRuntime : public ModuleNode {
* \return NDArray corresponding to given output node index.
*/
NDArray GetOutput(int index) const;
/*!
* \brief Set the number of threads available to the interpreter.
* \param num_threads The number of threads to be set.
*/
void SetNumThreads(int num_threads);

// Buffer backing the interpreter's model
std::unique_ptr<char[]> flatBuffersBuffer_;
Expand Down

0 comments on commit e09edaa

Please sign in to comment.