-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add dummy fct for predict for location * Basic class for the geo classifier * Create geo classifier for prediction from location * Copy asset in example app * Add function skeleton * Add WIP implementation for normalization and encoding of location The math should be fine, only the expected return is not correct. * Add imports * Add caching threshold * Update createTaxonomy.js * Replace old with new small models * Rename xcode project * Fix comment * Update GeoClassifier.java * Small geomodel needs this tensorflow version * Remove comment * Update GeoClassifier.java * Android updates to use model with lowercase letter * Run inference on normalized and encoded inputs * Add steps to add the geomodel thresholds to the taxonomy csv file * Parse the spatial thresholds from the taxonomy csv file * Use a model version that reads geomodel thresholds * Add geomodel predictions to the returned object * Add expectedNearbyFromClassification function following the Obj-C code * Update VisionCameraPluginInatVisionModule.m * Refactor function to return float array directly * Add time elapsed to results * Update VisionCameraPluginInatVisionModule.java * Destructure geomodel params * Update App.tsx * Update VisionCameraPluginInatVision.m * Rename function * Rename function in frame processor * Recycle cropped bitmap as well * Predict with geomodel in frame processor * Set geomodel scores to use in ImageClassifier * Small script to check which taxa are new or removed comparing two taxonomy files * Function to combine vision and geo scores by using simple multiplication in a loop * Revert "Small geomodel needs this tensorflow version" This reverts commit 2d49f0f. * Revert "Rename xcode project" This reverts commit 96f3564. * Revert commit, Merge conflict * Revert "Update createTaxonomy.js" This reverts commit bb8adde. * Refactor Node to assemble data from csv based on header presence * Set geomodel scores to null if switching back to not using geomodel * Refactor predict function to use float array as param instead of Map * Combine vision and geo scores * Normalize scores again after combining geo and vision * Add caching of results for similar locations * Use geomodel for predicting from file * Mnimal refactoring * Throw error is location object is not full
- Loading branch information
Showing
11 changed files
with
510 additions
and
31 deletions.
There are no files selected for viewing
139 changes: 139 additions & 0 deletions
139
android/src/main/java/com/visioncameraplugininatvision/GeoClassifier.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
package com.visioncameraplugininatvision; | ||
|
||
import static java.lang.Math.PI; | ||
import static java.lang.Math.cos; | ||
import static java.lang.Math.sin; | ||
|
||
import android.util.Log; | ||
|
||
import org.tensorflow.lite.Interpreter; | ||
|
||
import java.io.FileInputStream; | ||
import java.io.IOException; | ||
import java.nio.MappedByteBuffer; | ||
import java.nio.channels.FileChannel; | ||
import java.util.List; | ||
|
||
import timber.log.Timber; | ||
|
||
/** Classifies locations (latitude, longitude, elevation) with Tensorflow Lite. */ | ||
public class GeoClassifier { | ||
|
||
/** Tag for the {@link Log}. */ | ||
private static final String TAG = "GeoClassifier"; | ||
|
||
private final Taxonomy mTaxonomy; | ||
private final String mModelFilename; | ||
private final String mTaxonomyFilename; | ||
private final String mModelVersion; | ||
private int mModelSize; | ||
|
||
/** An instance of the driver class to run model inference with Tensorflow Lite. */ | ||
private Interpreter mTFlite; | ||
|
||
/** Instance variables to cache the geomodel results */ | ||
private final float mLocationChangeThreshold = 0.001f; | ||
private float[][] mCachedGeoResult; | ||
private double mCachedLatitude; | ||
private double mCachedLongitude; | ||
private double mCachedElevation; | ||
|
||
/** Initializes a {@code GeoClassifier}. */ | ||
public GeoClassifier(String modelPath, String taxonomyPath, String version) throws IOException { | ||
mModelFilename = modelPath; | ||
mTaxonomyFilename = taxonomyPath; | ||
mModelVersion = version; | ||
mTFlite = new Interpreter(loadModelFile()); | ||
Timber.tag(TAG).d("Created a Tensorflow Lite Geomodel Classifier."); | ||
|
||
mTaxonomy = new Taxonomy(new FileInputStream(mTaxonomyFilename), mModelVersion); | ||
mModelSize = mTaxonomy.getModelSize(); | ||
} | ||
|
||
/* | ||
* iNat geomodel input normalization documented here: | ||
* https://github.com/inaturalist/inatGeoModelTraining/tree/main#input-normalization | ||
*/ | ||
public float[] normAndEncodeLocation(double latitude, double longitude, double elevation) { | ||
double normLat = latitude / 90.0; | ||
double normLng = longitude / 180.0; | ||
double normElev; | ||
if (elevation > 0) { | ||
normElev = elevation / 5705.63; | ||
} else { | ||
normElev = elevation / 32768.0; | ||
} | ||
double a = sin(PI * normLng); | ||
double b = sin(PI * normLat); | ||
double c = cos(PI * normLng); | ||
double d = cos(PI * normLat); | ||
return new float[] { (float) a, (float) b, (float) c, (float) d, (float) normElev }; | ||
} | ||
|
||
public float[][] predictionsForLocation(double latitude, double longitude, double elevation) { | ||
if (mCachedGeoResult == null || | ||
Math.abs(latitude - mCachedLatitude) > mLocationChangeThreshold || | ||
Math.abs(longitude - mCachedLongitude) > mLocationChangeThreshold || | ||
Math.abs(elevation - mCachedElevation) > mLocationChangeThreshold) | ||
{ | ||
float[][] results = classify(latitude, longitude, elevation); | ||
if (results != null && results.length > 0) { | ||
mCachedGeoResult = results; | ||
mCachedLatitude = latitude; | ||
mCachedLongitude = longitude; | ||
mCachedElevation = elevation; | ||
} | ||
return results; | ||
} | ||
|
||
return mCachedGeoResult; | ||
} | ||
|
||
public List<Prediction> expectedNearby(double latitude, double longitude, double elevation) { | ||
float[][] scores = predictionsForLocation(latitude, longitude, elevation); | ||
return mTaxonomy.expectedNearbyFromClassification(scores); | ||
} | ||
|
||
public float[][] classify(double latitude, double longitude, double elevation) { | ||
if (mTFlite == null) { | ||
Timber.tag(TAG).e("Geomodel classifier has not been initialized; Skipped."); | ||
return null; | ||
} | ||
|
||
// Get normalized inputs | ||
float[] normalizedInputs = normAndEncodeLocation(latitude, longitude, elevation); | ||
|
||
// Create input array with shape [1][5] | ||
float[][] inputArray = new float[1][5]; | ||
inputArray[0] = normalizedInputs; | ||
|
||
// Create output array | ||
float[][] outputArray = new float[1][mModelSize]; | ||
|
||
// Run inference | ||
try { | ||
mTFlite.run(inputArray, outputArray); | ||
return outputArray; | ||
} catch (Exception exc) { | ||
exc.printStackTrace(); | ||
return new float[1][]; | ||
} catch (OutOfMemoryError exc) { | ||
exc.printStackTrace(); | ||
return new float[1][]; | ||
} | ||
} | ||
|
||
/** Closes tflite to release resources. */ | ||
public void close() { | ||
mTFlite.close(); | ||
mTFlite = null; | ||
} | ||
|
||
/** Memory-map the model file in Assets. */ | ||
private MappedByteBuffer loadModelFile() throws IOException { | ||
FileInputStream inputStream = new FileInputStream(mModelFilename); | ||
FileChannel fileChannel = inputStream.getChannel(); | ||
return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, inputStream.available()); | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.