Skip to content

Commit

Permalink
feat: Add GPU support for Android (#77)
Browse files Browse the repository at this point in the history
* feat: add gpu support for android

* docs: add basic docs on android gpu usage

* fix: Update Podfile.lock

* chore: Format C++

---------

Co-authored-by: Marc Rousavy <me@mrousavy.com>
  • Loading branch information
TkTioNG and mrousavy authored Jul 13, 2024
1 parent 0c524af commit b8cd552
Show file tree
Hide file tree
Showing 11 changed files with 382 additions and 307 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ We follow the [conventional commits specification](https://www.conventionalcommi
- `fix`: bug fixes, e.g. fix crash due to deprecated method.
- `feat`: new features, e.g. add new method to the module.
- `refactor`: code refactor, e.g. migrate from class components to hooks.
- `docs`: changes into documentation, e.g. add usage example for the module..
- `docs`: changes into documentation, e.g. add usage example for the module.
- `test`: adding or updating tests, e.g. add integration tests using detox.
- `chore`: tooling changes, e.g. change CI config.

Expand Down
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,34 @@ If you are on bare React Native, you need to include the CoreML/Metal code in yo
> [!NOTE]
> Since some operations aren't supported on the CoreML delegate, make sure your Model is able to use the CoreML GPU delegate.
#### Android GPU/NNAPI (Android)
To enable the GPU or NNAPI delegate on Android, you **may** need to include the `OpenCL` library with `uses-native-library` in the `application` scope of AndroidManifest.xml, starting from Android 12.
```xml
<!-- Like this -->
<uses-native-library android:name="libOpenCL.so" android:required="false" />
<!-- You may need one or all of the following, depending on your target devices -->
<uses-native-library android:name="libOpenCL-pixel.so" android:required="false" />
<uses-native-library android:name="libGLES_mali.so" android:required="false" />
<uses-native-library android:name="libPVROCL.so" android:required="false" />
```
Then, you can just use it:
```ts
const model = await loadTensorflowModel(require('assets/my-model.tflite'), 'android-gpu')
// or
const model = await loadTensorflowModel(require('assets/my-model.tflite'), 'nnapi')
```
> [!WARNING]
> NNAPI is deprecated as of Android 15; hence, it is not recommended for future projects.
> Both have similar performance, but the GPU delegate has a better initial loading time.
> [!NOTE]
> Android does not officially provide support for OpenCL; however, most GPU vendors do provide support for it.
## Community Discord
[Join the Margelo Community Discord](https://discord.gg/6CSHz2qAvA) to chat about react-native-fast-tflite or other Margelo libraries.
Expand Down
19 changes: 14 additions & 5 deletions android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@ find_package(ReactAndroid REQUIRED CONFIG)
find_package(fbjni REQUIRED CONFIG)

find_library(
TFLITE
tensorflowlite_jni
PATHS "./src/main/cpp/lib/tensorflow/jni/${ANDROID_ABI}"
NO_DEFAULT_PATH
NO_CMAKE_FIND_ROOT_PATH
TFLITE
tensorflowlite_jni
PATHS "./src/main/cpp/lib/tensorflow/jni/${ANDROID_ABI}"
NO_DEFAULT_PATH
NO_CMAKE_FIND_ROOT_PATH
)

find_library(
TFLITE_GPU
tensorflowlite_gpu_jni
PATHS "./src/main/cpp/lib/tensorflow/jni/${ANDROID_ABI}"
NO_DEFAULT_PATH
NO_CMAKE_FIND_ROOT_PATH
)

string(APPEND CMAKE_CXX_FLAGS " -DANDROID")
Expand Down Expand Up @@ -49,4 +57,5 @@ target_link_libraries(
ReactAndroid::reactnativejni # <-- CallInvokerImpl
fbjni::fbjni # <-- fbjni.h
${TFLITE}
${TFLITE_GPU}
)
5 changes: 5 additions & 0 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ dependencies {
implementation "org.tensorflow:tensorflow-lite:2.12.0"
extractHeaders("org.tensorflow:tensorflow-lite:2.12.0")
extractSO("org.tensorflow:tensorflow-lite:2.12.0")

// Tensorflow Lite GPU delegate
implementation "org.tensorflow:tensorflow-lite-gpu:2.12.0"
extractHeaders("org.tensorflow:tensorflow-lite-gpu:2.12.0")
extractSO("org.tensorflow:tensorflow-lite-gpu:2.12.0")
}

task extractAARHeaders {
Expand Down
152 changes: 93 additions & 59 deletions cpp/TensorflowPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#ifdef ANDROID
#include <tensorflow/lite/c/c_api.h>
#include <tensorflow/lite/delegates/gpu/delegate.h>
#include <tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h>
#else
#include <TensorFlowLiteC/TensorFlowLiteC.h>

Expand Down Expand Up @@ -56,83 +58,111 @@ void TensorflowPlugin::installToRuntime(jsi::Runtime& runtime,
delegateType = Delegate::CoreML;
} else if (delegate == "metal") {
delegateType = Delegate::Metal;
} else if (delegate == "nnapi") {
delegateType = Delegate::NnApi;
} else if (delegate == "android-gpu") {
delegateType = Delegate::AndroidGPU;
} else {
delegateType = Delegate::Default;
}
}

auto promise =
Promise::createPromise(runtime, [=, &runtime](std::shared_ptr<Promise> promise) {
// Launch async thread
std::async(std::launch::async, [=, &runtime]() {
try {
// Fetch model from URL (JS bundle)
Buffer buffer = fetchURL(modelPath);

// Load Model into Tensorflow
auto model = TfLiteModelCreate(buffer.data, buffer.size);
if (model == nullptr) {
callInvoker->invokeAsync([=]() {
promise->reject("Failed to load model from \"" + modelPath + "\"!");
});
return;
}

// Create TensorFlow Interpreter
auto options = TfLiteInterpreterOptionsCreate();

switch (delegateType) {
case Delegate::CoreML: {
auto promise = Promise::createPromise(runtime, [=, &runtime](
std::shared_ptr<Promise> promise) {
// Launch async thread
std::async(std::launch::async, [=, &runtime]() {
try {
// Fetch model from URL (JS bundle)
Buffer buffer = fetchURL(modelPath);

// Load Model into Tensorflow
auto model = TfLiteModelCreate(buffer.data, buffer.size);
if (model == nullptr) {
callInvoker->invokeAsync(
[=]() { promise->reject("Failed to load model from \"" + modelPath + "\"!"); });
return;
}

// Create TensorFlow Interpreter
auto options = TfLiteInterpreterOptionsCreate();

switch (delegateType) {
case Delegate::CoreML: {
#if FAST_TFLITE_ENABLE_CORE_ML
TfLiteCoreMlDelegateOptions delegateOptions;
auto delegate = TfLiteCoreMlDelegateCreate(&delegateOptions);
TfLiteInterpreterOptionsAddDelegate(options, delegate);
break;
TfLiteCoreMlDelegateOptions delegateOptions;
auto delegate = TfLiteCoreMlDelegateCreate(&delegateOptions);
TfLiteInterpreterOptionsAddDelegate(options, delegate);
break;
#else
callInvoker->invokeAsync([=]() {
promise->reject("CoreML Delegate is not enabled! Set $EnableCoreMLDelegate to true in Podfile and rebuild.");
});
return;
#endif
}
case Delegate::Metal: {
callInvoker->invokeAsync(
[=]() { promise->reject("Metal Delegate is not supported!"); });
return;
}
#ifdef ANDROID
case Delegate::NnApi: {
TfLiteNnapiDelegateOptions delegateOptions = TfLiteNnapiDelegateOptionsDefault();
auto delegate = TfLiteNnapiDelegateCreate(&delegateOptions);
TfLiteInterpreterOptionsAddDelegate(options, delegate);
break;
}
case Delegate::AndroidGPU: {
TfLiteGpuDelegateOptionsV2 delegateOptions = TfLiteGpuDelegateOptionsV2Default();
auto delegate = TfLiteGpuDelegateV2Create(&delegateOptions);
TfLiteInterpreterOptionsAddDelegate(options, delegate);
break;
}
#else
case Delegate::NnApi: {
callInvoker->invokeAsync([=]() {
promise->reject("Nnapi Delegate is only supported on Android!");
});
}
case Delegate::Metal: {
callInvoker->invokeAsync(
[=]() { promise->reject("Metal Delegate is not supported!"); });
return;
}
default: {
// use default CPU delegate.
case Delegate::AndroidGPU: {
callInvoker->invokeAsync([=]() {
promise->reject("Android-Gpu Delegate is not supported on Android!");
});
}
}
#endif
default: {
// use default CPU delegate.
}
}

auto interpreter = TfLiteInterpreterCreate(model, options);
auto interpreter = TfLiteInterpreterCreate(model, options);

if (interpreter == nullptr) {
callInvoker->invokeAsync([=]() {
promise->reject("Failed to create TFLite interpreter from model \"" +
modelPath + "\"!");
});
return;
}
if (interpreter == nullptr) {
callInvoker->invokeAsync([=]() {
promise->reject("Failed to create TFLite interpreter from model \"" + modelPath +
"\"!");
});
return;
}

// Initialize Model and allocate memory buffers
auto plugin = std::make_shared<TensorflowPlugin>(interpreter, buffer,
delegateType, callInvoker);

callInvoker->invokeAsync([=, &runtime]() {
auto result = jsi::Object::createFromHostObject(runtime, plugin);
promise->resolve(std::move(result));
});

auto end = std::chrono::steady_clock::now();
log("Successfully loaded Tensorflow Model in %i ms!",
std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
} catch (std::exception& error) {
std::string message = error.what();
callInvoker->invokeAsync([=]() { promise->reject(message); });
}
// Initialize Model and allocate memory buffers
auto plugin = std::make_shared<TensorflowPlugin>(interpreter, buffer, delegateType,
callInvoker);

callInvoker->invokeAsync([=, &runtime]() {
auto result = jsi::Object::createFromHostObject(runtime, plugin);
promise->resolve(std::move(result));
});
});

auto end = std::chrono::steady_clock::now();
log("Successfully loaded Tensorflow Model in %i ms!",
std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
} catch (std::exception& error) {
std::string message = error.what();
callInvoker->invokeAsync([=]() { promise->reject(message); });
}
});
});
return promise;
});

Expand Down Expand Up @@ -339,6 +369,10 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
return jsi::String::createFromUtf8(runtime, "core-ml");
case Delegate::Metal:
return jsi::String::createFromUtf8(runtime, "metal");
case Delegate::NnApi:
return jsi::String::createFromUtf8(runtime, "nnapi");
case Delegate::AndroidGPU:
return jsi::String::createFromUtf8(runtime, "android-gpu");
}
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/TensorflowPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ typedef std::function<Buffer(std::string)> FetchURLFunc;
class TensorflowPlugin : public jsi::HostObject {
public:
// TFL Delegate Type
enum Delegate { Default, Metal, CoreML };
enum Delegate { Default, Metal, CoreML, NnApi, AndroidGPU };

public:
explicit TensorflowPlugin(TfLiteInterpreter* interpreter, Buffer model, Delegate delegate,
Expand Down
1 change: 1 addition & 0 deletions example/android/app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
android:roundIcon="@mipmap/ic_launcher_round"
android:allowBackup="false"
android:theme="@style/AppTheme">
<uses-native-library android:name="libOpenCL.so" android:required="false" />
<activity
android:name=".MainActivity"
android:label="@string/app_name"
Expand Down
Loading

0 comments on commit b8cd552

Please sign in to comment.