Geomodel android (#48)
* Add dummy function for predict for location

* Basic class for the geo classifier

* Create geo classifier for prediction from location

* Copy asset in example app

* Add function skeleton

* Add WIP implementation for normalization and encoding of location

The math should be fine; only the expected return value is not correct.

* Add imports

* Add caching threshold

* Update createTaxonomy.js

* Replace old with new small models

* Rename xcode project

* Fix comment

* Update GeoClassifier.java

* Small geomodel needs this tensorflow version

* Remove comment

* Update GeoClassifier.java

* Android updates to use model with lowercase letter

* Run inference on normalized and encoded inputs

* Add steps to add the geomodel thresholds to the taxonomy csv file

* Parse the spatial thresholds from the taxonomy csv file

* Use a model version that reads geomodel thresholds

* Add geomodel predictions to the returned object

* Add expectedNearbyFromClassification function following the Obj-C code

* Update VisionCameraPluginInatVisionModule.m

* Refactor function to return float array directly

* Add time elapsed to results

* Update VisionCameraPluginInatVisionModule.java

* Destructure geomodel params

* Update App.tsx

* Update VisionCameraPluginInatVision.m

* Rename function

* Rename function in frame processor

* Recycle cropped bitmap as well

* Predict with geomodel in frame processor

* Set geomodel scores to use in ImageClassifier

* Small script to check which taxa are new or removed when comparing two taxonomy files

* Function to combine vision and geo scores by using simple multiplication in a loop

* Revert "Small geomodel needs this tensorflow version"

This reverts commit 2d49f0f.

* Revert "Rename xcode project"

This reverts commit 96f3564.

* Revert commit, Merge conflict

* Revert "Update createTaxonomy.js"

This reverts commit bb8adde.

* Refactor Node to assemble data from csv based on header presence

* Set geomodel scores to null if switching back to not using geomodel

* Refactor predict function to use float array as param instead of Map

* Combine vision and geo scores

* Normalize scores again after combining geo and vision

* Add caching of results for similar locations

* Use geomodel for predicting from file

* Minimal refactoring

* Throw error if location object is not complete
jtklein authored Dec 20, 2024
1 parent 55b5bbd commit f4f64ab
Showing 11 changed files with 510 additions and 31 deletions.
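For orientation before the per-file diffs, here is a minimal sketch of how the new pieces are meant to fit together on Android. This is not code from the commit: the file paths, model version string, coordinates, and the bitmap and taxonomyRollupCutoff variables are placeholders, and error handling for the IOException thrown by the constructors is omitted.

// Sketch only: placeholder paths, model version, coordinates, and cutoff; the bitmap
// is assumed to come from the frame processor.
GeoClassifier geoClassifier = new GeoClassifier(
        "/path/to/geomodel.tflite",   // placeholder geomodel file
        "/path/to/taxonomy.csv",      // placeholder taxonomy CSV with spatial thresholds
        "2.13");                      // placeholder model version
ImageClassifier imageClassifier = new ImageClassifier(
        "/path/to/vision.tflite",     // placeholder vision model file
        "/path/to/taxonomy.csv",
        "2.13");

// 1. Per-leaf geomodel scores for the current location; results are cached until the
//    location changes by more than the 0.001 threshold.
float[][] geoScores = geoClassifier.predictionsForLocation(52.52, 13.40, 34.0);

// 2. Hand the geo scores to the image classifier; they are multiplied into the vision
//    scores and renormalized inside classifyBitmap(). Pass null to go back to vision-only.
imageClassifier.setGeomodelScores(geoScores);

// 3. Classify a frame as before.
List<Prediction> predictions = imageClassifier.classifyBitmap(bitmap, taxonomyRollupCutoff);
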
android/src/main/java/com/visioncameraplugininatvision/GeoClassifier.java
@@ -0,0 +1,139 @@
package com.visioncameraplugininatvision;

import static java.lang.Math.PI;
import static java.lang.Math.cos;
import static java.lang.Math.sin;

import android.util.Log;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.List;

import timber.log.Timber;

/** Classifies locations (latitude, longitude, elevation) with TensorFlow Lite. */
public class GeoClassifier {

    /** Tag for the {@link Log}. */
    private static final String TAG = "GeoClassifier";

    private final Taxonomy mTaxonomy;
    private final String mModelFilename;
    private final String mTaxonomyFilename;
    private final String mModelVersion;
    private int mModelSize;

    /** An instance of the driver class to run model inference with TensorFlow Lite. */
    private Interpreter mTFlite;

    /** Instance variables to cache the geomodel results. */
    private final float mLocationChangeThreshold = 0.001f;
    private float[][] mCachedGeoResult;
    private double mCachedLatitude;
    private double mCachedLongitude;
    private double mCachedElevation;

    /** Initializes a {@code GeoClassifier}. */
    public GeoClassifier(String modelPath, String taxonomyPath, String version) throws IOException {
        mModelFilename = modelPath;
        mTaxonomyFilename = taxonomyPath;
        mModelVersion = version;
        mTFlite = new Interpreter(loadModelFile());
        Timber.tag(TAG).d("Created a Tensorflow Lite Geomodel Classifier.");

        mTaxonomy = new Taxonomy(new FileInputStream(mTaxonomyFilename), mModelVersion);
        mModelSize = mTaxonomy.getModelSize();
    }

    /*
     * iNat geomodel input normalization documented here:
     * https://github.com/inaturalist/inatGeoModelTraining/tree/main#input-normalization
     */
    public float[] normAndEncodeLocation(double latitude, double longitude, double elevation) {
        double normLat = latitude / 90.0;
        double normLng = longitude / 180.0;
        double normElev;
        if (elevation > 0) {
            normElev = elevation / 5705.63;
        } else {
            normElev = elevation / 32768.0;
        }
        double a = sin(PI * normLng);
        double b = sin(PI * normLat);
        double c = cos(PI * normLng);
        double d = cos(PI * normLat);
        return new float[] { (float) a, (float) b, (float) c, (float) d, (float) normElev };
    }

    public float[][] predictionsForLocation(double latitude, double longitude, double elevation) {
        if (mCachedGeoResult == null ||
                Math.abs(latitude - mCachedLatitude) > mLocationChangeThreshold ||
                Math.abs(longitude - mCachedLongitude) > mLocationChangeThreshold ||
                Math.abs(elevation - mCachedElevation) > mLocationChangeThreshold)
        {
            float[][] results = classify(latitude, longitude, elevation);
            if (results != null && results.length > 0) {
                mCachedGeoResult = results;
                mCachedLatitude = latitude;
                mCachedLongitude = longitude;
                mCachedElevation = elevation;
            }
            return results;
        }

        return mCachedGeoResult;
    }

    public List<Prediction> expectedNearby(double latitude, double longitude, double elevation) {
        float[][] scores = predictionsForLocation(latitude, longitude, elevation);
        return mTaxonomy.expectedNearbyFromClassification(scores);
    }

    public float[][] classify(double latitude, double longitude, double elevation) {
        if (mTFlite == null) {
            Timber.tag(TAG).e("Geomodel classifier has not been initialized; Skipped.");
            return null;
        }

        // Get normalized inputs
        float[] normalizedInputs = normAndEncodeLocation(latitude, longitude, elevation);

        // Create input array with shape [1][5]
        float[][] inputArray = new float[1][5];
        inputArray[0] = normalizedInputs;

        // Create output array
        float[][] outputArray = new float[1][mModelSize];

        // Run inference
        try {
            mTFlite.run(inputArray, outputArray);
            return outputArray;
        } catch (Exception exc) {
            exc.printStackTrace();
            return null;
        } catch (OutOfMemoryError exc) {
            exc.printStackTrace();
            return null;
        }
    }

    /** Closes tflite to release resources. */
    public void close() {
        mTFlite.close();
        mTFlite = null;
    }

    /** Memory-map the model file in Assets. */
    private MappedByteBuffer loadModelFile() throws IOException {
        FileInputStream inputStream = new FileInputStream(mModelFilename);
        FileChannel fileChannel = inputStream.getChannel();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, inputStream.available());
    }
}
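
A quick numerical check of normAndEncodeLocation, using made-up coordinates (not values from the commit) and a GeoClassifier instance like the one in the sketch above:

// Illustrative input: latitude 45°, longitude -60°, elevation 100 m.
float[] encoded = geoClassifier.normAndEncodeLocation(45.0, -60.0, 100.0);
// normLat = 45 / 90 = 0.5, normLng = -60 / 180 ≈ -0.333, normElev = 100 / 5705.63 ≈ 0.0175
// Returned order is [sin(π·normLng), sin(π·normLat), cos(π·normLng), cos(π·normLat), normElev]
// ≈ [-0.866, 1.0, 0.5, 0.0, 0.0175]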

android/src/main/java/com/visioncameraplugininatvision/ImageClassifier.java
@@ -52,6 +52,8 @@ public class ImageClassifier {
/** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as inputs. */
private ByteBuffer imgData;

private float[][] mGeomodelScores;

public void setFilterByTaxonId(Integer taxonId) {
mTaxonomy.setFilterByTaxonId(taxonId);
}
@@ -68,6 +70,10 @@ public boolean getNegativeFilter() {
return mTaxonomy.getNegativeFilter();
}

public void setGeomodelScores(float[][] scores) {
mGeomodelScores = scores;
}

/** Initializes an {@code ImageClassifier}. */
public ImageClassifier(String modelPath, String taxonomyPath, String version) throws IOException {
mModelFilename = modelPath;
@@ -85,7 +91,7 @@ public ImageClassifier(String modelPath, String taxonomyPath, String version) throws IOException {
}

/** Classifies a frame from the preview stream. */
public List<Prediction> classifyFrame(Bitmap bitmap, Double taxonomyRollupCutoff) {
public List<Prediction> classifyBitmap(Bitmap bitmap, Double taxonomyRollupCutoff) {
if (mTFlite == null) {
Timber.tag(TAG).e("Image classifier has not been initialized; Skipped.");
return null;
@@ -108,7 +114,16 @@ public List<Prediction> classifyFrame(Bitmap bitmap, Double taxonomyRollupCutoff) {
List<Prediction> predictions = null;
try {
mTFlite.runForMultipleInputsOutputs(input, expectedOutputs);
predictions = mTaxonomy.predict(expectedOutputs, taxonomyRollupCutoff);
// Get raw vision scores
float[] visionScores = ((float[][]) expectedOutputs.get(0))[0];
float[] combinedScores = new float[visionScores.length];
if (mGeomodelScores != null) {
// Combine vision and geo scores
combinedScores = combineVisionScores(visionScores, mGeomodelScores[0]);
} else {
combinedScores = visionScores;
}
predictions = mTaxonomy.predict(combinedScores, taxonomyRollupCutoff);
} catch (Exception exc) {
exc.printStackTrace();
return new ArrayList<Prediction>();
@@ -177,5 +192,28 @@ private void convertBitmapToByteBuffer(Bitmap bitmap) {
}
}

/** Combines vision and geo model scores */
private float[] combineVisionScores(float[] visionScores, float[] geoScores) {
float[] combinedScores = new float[visionScores.length];
float sum = 0.0f;

// First multiply the scores
for (int i = 0; i < visionScores.length; i++) {
float visionScore = visionScores[i];
float geoScore = geoScores[i];
combinedScores[i] = visionScore * geoScore;
sum += combinedScores[i];
}

// Then normalize so they sum to 1.0
if (sum > 0) {
for (int i = 0; i < combinedScores.length; i++) {
combinedScores[i] = combinedScores[i] / sum;
}
}

return combinedScores;
}

}
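
To make the combination step concrete, a worked example with made-up scores for a three-leaf model (combineVisionScores is a private helper, so the call is shown purely to illustrate the arithmetic):

float[] visionScores = { 0.60f, 0.30f, 0.10f };
float[] geoScores    = { 0.90f, 0.10f, 0.50f };
float[] combined = combineVisionScores(visionScores, geoScores);
// Element-wise products: [0.54, 0.03, 0.05], sum = 0.62
// After renormalization: ≈ [0.871, 0.048, 0.081], which sums to 1.0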

58 changes: 45 additions & 13 deletions android/src/main/java/com/visioncameraplugininatvision/Node.java
@@ -1,6 +1,7 @@
package com.visioncameraplugininatvision;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Node {
@@ -18,6 +19,8 @@ public class Node {

public String spatialId;

public String spatialThreshold;

public transient Node parent;

public transient List<Node> children = new ArrayList<>();
@@ -31,19 +34,48 @@ public String toString() {
// parent_taxon_id,taxon_id,rank_level,leaf_class_id,iconic_class_id,spatial_class_id,name
// Seek model 1.0:
// parent_taxon_id,taxon_id,rank_level,leaf_class_id,name
public Node(String line, String version) {
String[] parts = line.trim().split(",", 7);

this.parentKey = parts[0];
this.key = parts[1];
this.rank = Float.parseFloat(parts[2]);
this.leafId = parts[3];
if (version.equals("1.0")) {
this.name = parts[4];
} else {
this.iconicId = parts[4];
this.spatialId = parts[5];
this.name = parts[6];
public Node(String[] headers, String line) {
String[] parts = line.trim().split(",", headers.length);
List<String> headerList = new ArrayList<>(Arrays.asList(headers));

if (headerList.contains("parent_taxon_id")) {
int parentTaxonIdIndex = headerList.indexOf("parent_taxon_id");
this.parentKey = parts[parentTaxonIdIndex];
}

if (headerList.contains("taxon_id")) {
int taxonIdIndex = headerList.indexOf("taxon_id");
this.key = parts[taxonIdIndex];
}

if (headerList.contains("rank_level")) {
int rankLevelIndex = headerList.indexOf("rank_level");
this.rank = Float.parseFloat(parts[rankLevelIndex]);
}

if (headerList.contains("leaf_class_id")) {
int leafClassIdIndex = headerList.indexOf("leaf_class_id");
this.leafId = parts[leafClassIdIndex];
}

if (headerList.contains("iconic_class_id")) {
int iconicClassIdIndex = headerList.indexOf("iconic_class_id");
this.iconicId = parts[iconicClassIdIndex];
}

if (headerList.contains("spatial_class_id")) {
int spatialClassIdIndex = headerList.indexOf("spatial_class_id");
this.spatialId = parts[spatialClassIdIndex];
}

if (headerList.contains("spatial_threshold")) {
int spatialThresholdIndex = headerList.indexOf("spatial_threshold");
this.spatialThreshold = parts[spatialThresholdIndex];
}

if (headerList.contains("name")) {
int nameIndex = headerList.indexOf("name");
this.name = parts[nameIndex];
}
}

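Because the constructor now looks each field up by header name, the column order of the taxonomy CSV no longer matters. A minimal illustration with made-up IDs, threshold, and row values (not taken from the real taxonomy file):

String[] headers = ("parent_taxon_id,taxon_id,rank_level,leaf_class_id,"
        + "iconic_class_id,spatial_class_id,spatial_threshold,name").split(",");
Node node = new Node(headers, "1,2,10.0,3,4,3,0.0023,Apis mellifera");
// node.parentKey == "1", node.key == "2", node.rank == 10.0f, node.leafId == "3",
// node.iconicId == "4", node.spatialId == "3", node.spatialThreshold == "0.0023",
// node.name == "Apis mellifera"
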
android/src/main/java/com/visioncameraplugininatvision/Taxonomy.java
@@ -101,12 +101,15 @@ public void setTaxonomyRollupCutoff(float taxonomyRollupCutoff) {
// Read the taxonomy CSV file into a list of nodes
BufferedReader reader = new BufferedReader(new InputStreamReader(is));
try {
reader.readLine(); // Skip the first line (header line)
String headerLine = reader.readLine();
// Transform header line to array
String[] headers = headerLine.split(",");


mNodes = new ArrayList<>();
mLeaves = new ArrayList<>();
for (String line; (line = reader.readLine()) != null; ) {
Node node = new Node(line, mModelVersion);
Node node = new Node(headers, line);
mNodes.add(node);
if ((node.leafId != null) && (node.leafId.length() > 0)) {
mLeaves.add(node);
@@ -152,11 +155,9 @@ public int getModelSize() {
return mLeaves.size();
}

public List<Prediction> predict(Map<Integer, Object> outputs, Double taxonomyRollupCutoff) {
// Get raw predictions
float[] results = ((float[][]) outputs.get(0))[0];
public List<Prediction> predict(float[] scores, Double taxonomyRollupCutoff) {
// Make a copy of results
float[] resultsCopy = results.clone();
float[] resultsCopy = scores.clone();
// Make sure results is sorted by score
Arrays.sort(resultsCopy);
// Get result with the highest score
@@ -170,13 +171,41 @@ public List<Prediction> predict(Map<Integer, Object> outputs, Double taxonomyRollupCutoff) {
}
resultsCopy = null;

Map<String, Float> scores = aggregateAndNormalizeScores(results);
Timber.tag(TAG).d("Number of nodes in scores: " + scores.size());
List<Prediction> bestBranch = buildBestBranchFromScores(scores);
Map<String, Float> aggregateScores = aggregateAndNormalizeScores(scores);
Timber.tag(TAG).d("Number of nodes in scores: " + aggregateScores.size());
List<Prediction> bestBranch = buildBestBranchFromScores(aggregateScores);

return bestBranch;
}

public List<Prediction> expectedNearbyFromClassification(float[][] results) {
List<Prediction> scores = new ArrayList<>();
List<Prediction> filteredOutScores = new ArrayList<>();

for (Node leaf : mLeaves) {
// We did not implement batch processing here, so we only have one result
float score = results[0][Integer.valueOf(leaf.leafId)];
Prediction prediction = new Prediction(leaf, score);

// If score is higher than spatialThreshold it means the taxon is "expected nearby"
if (leaf.spatialThreshold != null && !leaf.spatialThreshold.isEmpty()) {
float threshold = Float.parseFloat(leaf.spatialThreshold);
if (score >= threshold) {
scores.add(prediction);
} else {
filteredOutScores.add(prediction);
}
} else {
scores.add(prediction);
}
}

// Log length of scores
Timber.tag(TAG).d("Length of scores: " + scores.size());
Timber.tag(TAG).d("Length of filteredOutScores: " + filteredOutScores.size());

return scores;
}

/** Aggregates scores for nodes, including non-leaf nodes (so each non-leaf node has a score of the sum of all its dependents) */
private Map<String, Float> aggregateAndNormalizeScores(float[] results) {
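
To make the spatial-threshold filter in expectedNearbyFromClassification concrete, a small example with made-up numbers (real thresholds come from the spatial_threshold column of the taxonomy CSV):

// One leaf's geomodel score and its parsed spatial threshold (hypothetical values).
float score = 0.005f;
String spatialThreshold = "0.002";
boolean expectedNearby = (spatialThreshold == null || spatialThreshold.isEmpty())
        || score >= Float.parseFloat(spatialThreshold);
// expectedNearby == true here; with score = 0.0005f it would be false, and a leaf with
// no threshold in the CSV is always treated as expected nearby.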
