Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add export and detection for TensorFlow saved_model, graph_def and TFLite #959

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions android/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
*.iml
.gradle
/local.properties
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
.DS_Store
/build
/captures
.externalNativeBuild

/.gradle/
/.idea/
2 changes: 2 additions & 0 deletions android/app/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/build
/build/
55 changes: 55 additions & 0 deletions android/app/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
apply plugin: 'com.android.application'
apply plugin: 'de.undercouch.download'

android {
compileSdkVersion 28
buildToolsVersion '28.0.3'
defaultConfig {
applicationId "org.tensorflow.lite.examples.detection"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"

ndk {
abiFilters 'armeabi-v7a', 'arm64-v8a', 'x86'
}
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
}
}
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
lintOptions {
abortOnError false
}
}

// import DownloadModels task
project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'

// Download default models; if you wish to use your own models then
// place them in the "assets" directory and comment out this line.
//apply from: "download_model.gradle"

dependencies {
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'androidx.coordinatorlayout:coordinatorlayout:1.1.0'
implementation 'com.google.android.material:material:1.1.0'
implementation 'org.tensorflow:tensorflow-lite:2.3.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0'
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
implementation 'com.google.code.gson:gson:2.8.6'
androidTestImplementation 'androidx.test.ext:junit:1.1.1'
androidTestImplementation 'com.android.support.test:rules:1.0.2'
androidTestImplementation 'com.google.truth:truth:1.0.1'
}
26 changes: 26 additions & 0 deletions android/app/download_model.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

task downloadZipFile(type: Download) {
src 'http://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip'
dest new File(buildDir, 'zips/')
overwrite false
}


task downloadAndUnzipFile(dependsOn: downloadZipFile, type: Copy) {
from zipTree(downloadZipFile.dest)
into project.ext.ASSET_DIR
}


task extractModels(type: Copy) {
dependsOn downloadAndUnzipFile
}

tasks.whenTaskAdded { task ->
if (task.name == 'assembleDebug') {
task.dependsOn 'extractModels'
}
if (task.name == 'assembleRelease') {
task.dependsOn 'extractModels'
}
}
21 changes: 21 additions & 0 deletions android/app/proguard-rules.pro
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
4 changes: 4 additions & 0 deletions android/app/src/androidTest/assets/table_results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
dining_table 27.492085 97.94615 623.1435 444.8627 0.48828125
knife 342.53433 243.71082 583.89185 416.34595 0.4765625
cup 68.025925 197.5857 202.02031 374.2206 0.4375
book 185.43098 139.64153 244.51149 203.37737 0.3125
5 changes: 5 additions & 0 deletions android/app/src/androidTest/java/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.examples.detection">
<uses-sdk />
</manifest>
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tensorflow.lite.examples.detection;

import static com.google.common.truth.Truth.assertThat;
import static java.lang.Math.abs;
import static java.lang.Math.max;
import static java.lang.Math.min;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.BitmapFactory;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.RectF;
import android.util.Size;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.platform.app.InstrumentationRegistry;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.tensorflow.lite.examples.detection.env.ImageUtils;
import org.tensorflow.lite.examples.detection.tflite.Classifier;
import org.tensorflow.lite.examples.detection.tflite.Classifier.Recognition;
//import org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel;
import org.tensorflow.lite.examples.detection.tflite.YoloV5Classifier;

/** Golden test for Object Detection Reference app. */
@RunWith(AndroidJUnit4.class)
public class DetectorTest {

private static final int MODEL_INPUT_SIZE = 300;
private static final boolean IS_MODEL_QUANTIZED = true;
private static final String MODEL_FILE = "detect.tflite";
private static final String LABELS_FILE = "file:///android_asset/coco.txt";
private static final Size IMAGE_SIZE = new Size(640, 480);

private Classifier detector;
private Bitmap croppedBitmap;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;

@Before
public void setUp() throws IOException {
AssetManager assetManager =
InstrumentationRegistry.getInstrumentation().getContext().getAssets();
detector =
YoloV5Classifier.create(
assetManager,
MODEL_FILE,
LABELS_FILE,
IS_MODEL_QUANTIZED,
MODEL_INPUT_SIZE);
int cropSize = MODEL_INPUT_SIZE;
int previewWidth = IMAGE_SIZE.getWidth();
int previewHeight = IMAGE_SIZE.getHeight();
int sensorOrientation = 0;
croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);

frameToCropTransform =
ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
cropSize, cropSize,
sensorOrientation, false);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
}

@Test
public void detectionResultsShouldNotChange() throws Exception {
Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(loadImage("table.jpg"), frameToCropTransform, null);
final List<Recognition> results = detector.recognizeImage(croppedBitmap);
final List<Recognition> expected = loadRecognitions("table_results.txt");

for (Recognition target : expected) {
// Find a matching result in results
boolean matched = false;
for (Recognition item : results) {
RectF bbox = new RectF();
cropToFrameTransform.mapRect(bbox, item.getLocation());
if (item.getTitle().equals(target.getTitle())
&& matchBoundingBoxes(bbox, target.getLocation())
&& matchConfidence(item.getConfidence(), target.getConfidence())) {
matched = true;
break;
}
}
assertThat(matched).isTrue();
}
}

// Confidence tolerance: absolute 1%
private static boolean matchConfidence(float a, float b) {
return abs(a - b) < 0.01;
}

// Bounding Box tolerance: overlapped area > 95% of each one
private static boolean matchBoundingBoxes(RectF a, RectF b) {
float areaA = a.width() * a.height();
float areaB = b.width() * b.height();
RectF overlapped =
new RectF(
max(a.left, b.left), max(a.top, b.top), min(a.right, b.right), min(a.bottom, b.bottom));
float overlappedArea = overlapped.width() * overlapped.height();
return overlappedArea > 0.95 * areaA && overlappedArea > 0.95 * areaB;
}

private static Bitmap loadImage(String fileName) throws Exception {
AssetManager assetManager =
InstrumentationRegistry.getInstrumentation().getContext().getAssets();
InputStream inputStream = assetManager.open(fileName);
return BitmapFactory.decodeStream(inputStream);
}

// The format of result:
// category bbox.left bbox.top bbox.right bbox.bottom confidence
// ...
// Example:
// Apple 99 25 30 75 80 0.99
// Banana 25 90 75 200 0.98
// ...
private static List<Recognition> loadRecognitions(String fileName) throws Exception {
AssetManager assetManager =
InstrumentationRegistry.getInstrumentation().getContext().getAssets();
InputStream inputStream = assetManager.open(fileName);
Scanner scanner = new Scanner(inputStream);
List<Recognition> result = new ArrayList<>();
while (scanner.hasNext()) {
String category = scanner.next();
category = category.replace('_', ' ');
if (!scanner.hasNextFloat()) {
break;
}
float left = scanner.nextFloat();
float top = scanner.nextFloat();
float right = scanner.nextFloat();
float bottom = scanner.nextFloat();
RectF boundingBox = new RectF(left, top, right, bottom);
float confidence = scanner.nextFloat();
Recognition recognition = new Recognition(null, category, confidence, boundingBox);
result.add(recognition);
}
return result;
}
}
38 changes: 38 additions & 0 deletions android/app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.examples.detection">
<!-- Tell the system this app requires OpenGL ES 3.1. -->
<uses-feature android:glEsVersion="0x00030001" android:required="true" />

<uses-sdk />

<uses-permission android:name="android.permission.CAMERA" />

<uses-feature android:name="android.hardware.camera" />
<uses-feature android:name="android.hardware.camera.autofocus" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.INTERNET"/>

<application
android:allowBackup="false"
android:icon="@mipmap/ic_launcher"
android:label="@string/tfe_od_app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/AppTheme.ObjectDetection"
android:hardwareAccelerated="true"
android:debuggable="true"
android:installLocation="internalOnly">

<activity
android:name=".DetectorActivity"
android:label="@string/tfe_od_app_name"
android:screenOrientation="portrait">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>

</application>
</manifest>
Loading