diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 92501f950c56..3b0e268e2a44 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -73,6 +73,7 @@ def __init__(self, module): self._set_input = module["set_input"] self._invoke = module["invoke"] self._get_output = module["get_output"] + self._set_num_threads = module["set_num_threads"] def set_input(self, index, value): """Set inputs to the module via kwargs @@ -109,3 +110,12 @@ def get_output(self, index): The output index """ return self._get_output(index) + + def set_num_threads(self, num_threads): + """Set the number of threads via kwargs + Parameters + ---------- + num_threads : int + The number of threads + """ + self._set_num_threads(num_threads) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index f56e62ec1a40..9a434fde2955 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -128,6 +128,8 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { }); } +void TFLiteRuntime::SetNumThreads(int num_threads) { interpreter_->SetNumThreads(num_threads); } + NDArray TFLiteRuntime::GetOutput(int index) const { TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[index]); DataType dtype = TfLiteDType2TVMDType(output->type); @@ -163,6 +165,12 @@ PackedFunc TFLiteRuntime::GetFunction(const std::string& name, [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); } else if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); + } else if (name == "set_num_threads") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + int num_threads = args[0]; + CHECK_GE(num_threads, 1); + this->SetNumThreads(num_threads); + }); } else { return PackedFunc(); } diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index ff0e6ab0db56..3311f10975be 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -93,6 +93,11 @@ class TFLiteRuntime : public ModuleNode { * \return NDArray corresponding to given output node index. */ NDArray GetOutput(int index) const; + /*! + * \brief Set the number of threads available to the interpreter. + * \param num_threads The number of threads to be set. + */ + void SetNumThreads(int num_threads); // Buffer backing the interpreter's model std::unique_ptr flatBuffersBuffer_;