From f4f64aba3daf42b132a876817c62d67672da840f Mon Sep 17 00:00:00 2001 From: Johannes Klein Date: Fri, 20 Dec 2024 15:42:15 +0100 Subject: [PATCH] Geomodel android (#48) * Add dummy fct for predict for location * Basic class for the geo classifier * Create geo classifier for prediction from location * Copy asset in example app * Add function skeleton * Add WIP implementation for normalization and encoding of location The math should be fine, only the expected return is not correct. * Add imports * Add caching threshold * Update createTaxonomy.js * Replace old with new small models * Rename xcode project * Fix comment * Update GeoClassifier.java * Small geomodel needs this tensorflow version * Remove comment * Update GeoClassifier.java * Android updates to use model with lowercase letter * Run inference on normalized and encoded inputs * Add steps to add the geomodel thresholds to the taxonomy csv file * Parse the spatial thresholds from the taxonomy csv file * Use a model version that reads geomodel thresholds * Add geomodel predictions to the returned object * Add expectedNearbyFromClassification function following the Obj-C code * Update VisionCameraPluginInatVisionModule.m * Refactor function to return float array directly * Add time elapsed to results * Update VisionCameraPluginInatVisionModule.java * Destructure geomodel params * Update App.tsx * Update VisionCameraPluginInatVision.m * Rename function * Rename function in frame processor * Recycle cropped bitmap as well * Predict with geomodel in frame processor * Set geomodel scores to use in ImageClassifier * Small script to check which taxa are new or removed comparing two taxonomy files * Function to combine vision and geo scores by using simple multiplication in a loop * Revert "Small geomodel needs this tensorflow version" This reverts commit 2d49f0fe629d088e244028209f663dc5cce526f0. * Revert "Rename xcode project" This reverts commit 96f3564320afee16fa2ac24e8a251d890bb4e00e. 
* Revert commit, Merge conflict * Revert "Update createTaxonomy.js" This reverts commit bb8adde620c6da6669c636a754c88b8f8833b080. * Refactor Node to assemble data from csv based on header presence * Set geomodel scores to null if switching back to not using geomodel * Refactor predict function to use float array as param instead of Map * Combine vision and geo scores * Normalize scores again after combining geo and vision * Add caching of results for similar locations * Use geomodel for predicting from file * Minimal refactoring * Throw error if location object is not full --- .../GeoClassifier.java | 139 ++++++++++++++++++ .../ImageClassifier.java | 42 +++++- .../visioncameraplugininatvision/Node.java | 58 ++++++-- .../Taxonomy.java | 47 ++++-- .../VisionCameraPluginInatVisionModule.java | 106 ++++++++++++- .../VisionCameraPluginInatVisionPlugin.java | 46 +++++- example/src/App.tsx | 4 + .../VisionCameraPluginInatVision.m | 2 +- .../VisionCameraPluginInatVisionModule.m | 1 - scripts/createTaxonomy.js | 14 ++ scripts/whichTaxaAreNew.js | 82 +++++++++++ 11 files changed, 510 insertions(+), 31 deletions(-) create mode 100644 android/src/main/java/com/visioncameraplugininatvision/GeoClassifier.java create mode 100644 scripts/whichTaxaAreNew.js diff --git a/android/src/main/java/com/visioncameraplugininatvision/GeoClassifier.java b/android/src/main/java/com/visioncameraplugininatvision/GeoClassifier.java new file mode 100644 index 0000000..41eeb39 --- /dev/null +++ b/android/src/main/java/com/visioncameraplugininatvision/GeoClassifier.java @@ -0,0 +1,139 @@ +package com.visioncameraplugininatvision; + +import static java.lang.Math.PI; +import static java.lang.Math.cos; +import static java.lang.Math.sin; + +import android.util.Log; + +import org.tensorflow.lite.Interpreter; + +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.List; + +import timber.log.Timber; + +/** 
Classifies locations (latitude, longitude, elevation) with Tensorflow Lite. */ +public class GeoClassifier { + + /** Tag for the {@link Log}. */ + private static final String TAG = "GeoClassifier"; + + private final Taxonomy mTaxonomy; + private final String mModelFilename; + private final String mTaxonomyFilename; + private final String mModelVersion; + private int mModelSize; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + private Interpreter mTFlite; + + /** Instance variables to cache the geomodel results */ + private final float mLocationChangeThreshold = 0.001f; + private float[][] mCachedGeoResult; + private double mCachedLatitude; + private double mCachedLongitude; + private double mCachedElevation; + + /** Initializes a {@code GeoClassifier}. */ + public GeoClassifier(String modelPath, String taxonomyPath, String version) throws IOException { + mModelFilename = modelPath; + mTaxonomyFilename = taxonomyPath; + mModelVersion = version; + mTFlite = new Interpreter(loadModelFile()); + Timber.tag(TAG).d("Created a Tensorflow Lite Geomodel Classifier."); + + mTaxonomy = new Taxonomy(new FileInputStream(mTaxonomyFilename), mModelVersion); + mModelSize = mTaxonomy.getModelSize(); + } + + /* + * iNat geomodel input normalization documented here: + * https://github.com/inaturalist/inatGeoModelTraining/tree/main#input-normalization + */ + public float[] normAndEncodeLocation(double latitude, double longitude, double elevation) { + double normLat = latitude / 90.0; + double normLng = longitude / 180.0; + double normElev; + if (elevation > 0) { + normElev = elevation / 5705.63; + } else { + normElev = elevation / 32768.0; + } + double a = sin(PI * normLng); + double b = sin(PI * normLat); + double c = cos(PI * normLng); + double d = cos(PI * normLat); + return new float[] { (float) a, (float) b, (float) c, (float) d, (float) normElev }; + } + + public float[][] predictionsForLocation(double latitude, double longitude, double 
elevation) { + if (mCachedGeoResult == null || + Math.abs(latitude - mCachedLatitude) > mLocationChangeThreshold || + Math.abs(longitude - mCachedLongitude) > mLocationChangeThreshold || + Math.abs(elevation - mCachedElevation) > mLocationChangeThreshold) + { + float[][] results = classify(latitude, longitude, elevation); + if (results != null && results.length > 0) { + mCachedGeoResult = results; + mCachedLatitude = latitude; + mCachedLongitude = longitude; + mCachedElevation = elevation; + } + return results; + } + + return mCachedGeoResult; + } + + public List expectedNearby(double latitude, double longitude, double elevation) { + float[][] scores = predictionsForLocation(latitude, longitude, elevation); + return mTaxonomy.expectedNearbyFromClassification(scores); + } + + public float[][] classify(double latitude, double longitude, double elevation) { + if (mTFlite == null) { + Timber.tag(TAG).e("Geomodel classifier has not been initialized; Skipped."); + return null; + } + + // Get normalized inputs + float[] normalizedInputs = normAndEncodeLocation(latitude, longitude, elevation); + + // Create input array with shape [1][5] + float[][] inputArray = new float[1][5]; + inputArray[0] = normalizedInputs; + + // Create output array + float[][] outputArray = new float[1][mModelSize]; + + // Run inference + try { + mTFlite.run(inputArray, outputArray); + return outputArray; + } catch (Exception exc) { + exc.printStackTrace(); + return new float[1][]; + } catch (OutOfMemoryError exc) { + exc.printStackTrace(); + return new float[1][]; + } + } + + /** Closes tflite to release resources. */ + public void close() { + mTFlite.close(); + mTFlite = null; + } + + /** Memory-map the model file in Assets. 
*/ + private MappedByteBuffer loadModelFile() throws IOException { + FileInputStream inputStream = new FileInputStream(mModelFilename); + FileChannel fileChannel = inputStream.getChannel(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, inputStream.available()); + } +} + diff --git a/android/src/main/java/com/visioncameraplugininatvision/ImageClassifier.java b/android/src/main/java/com/visioncameraplugininatvision/ImageClassifier.java index d5116cb..95e2afb 100644 --- a/android/src/main/java/com/visioncameraplugininatvision/ImageClassifier.java +++ b/android/src/main/java/com/visioncameraplugininatvision/ImageClassifier.java @@ -52,6 +52,8 @@ public class ImageClassifier { /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */ private ByteBuffer imgData; + private float[][] mGeomodelScores; + public void setFilterByTaxonId(Integer taxonId) { mTaxonomy.setFilterByTaxonId(taxonId); } @@ -68,6 +70,10 @@ public boolean getNegativeFilter() { return mTaxonomy.getNegativeFilter(); } + public void setGeomodelScores(float[][] scores) { + mGeomodelScores = scores; + } + /** Initializes an {@code ImageClassifier}. */ public ImageClassifier(String modelPath, String taxonomyPath, String version) throws IOException { mModelFilename = modelPath; @@ -85,7 +91,7 @@ public ImageClassifier(String modelPath, String taxonomyPath, String version) th } /** Classifies a frame from the preview stream. 
*/ - public List classifyFrame(Bitmap bitmap, Double taxonomyRollupCutoff) { + public List classifyBitmap(Bitmap bitmap, Double taxonomyRollupCutoff) { if (mTFlite == null) { Timber.tag(TAG).e("Image classifier has not been initialized; Skipped."); return null; @@ -108,7 +114,16 @@ public List classifyFrame(Bitmap bitmap, Double taxonomyRollupCutoff List predictions = null; try { mTFlite.runForMultipleInputsOutputs(input, expectedOutputs); - predictions = mTaxonomy.predict(expectedOutputs, taxonomyRollupCutoff); + // Get raw vision scores + float[] visionScores = ((float[][]) expectedOutputs.get(0))[0]; + float[] combinedScores = new float[visionScores.length]; + if (mGeomodelScores != null) { + // Combine vision and geo scores + combinedScores = combineVisionScores(visionScores, mGeomodelScores[0]); + } else { + combinedScores = visionScores; + } + predictions = mTaxonomy.predict(combinedScores, taxonomyRollupCutoff); } catch (Exception exc) { exc.printStackTrace(); return new ArrayList(); @@ -177,5 +192,28 @@ private void convertBitmapToByteBuffer(Bitmap bitmap) { } } + /** Combines vision and geo model scores */ + private float[] combineVisionScores(float[] visionScores, float[] geoScores) { + float[] combinedScores = new float[visionScores.length]; + float sum = 0.0f; + + // First multiply the scores + for (int i = 0; i < visionScores.length; i++) { + float visionScore = visionScores[i]; + float geoScore = geoScores[i]; + combinedScores[i] = visionScore * geoScore; + sum += combinedScores[i]; + } + + // Then normalize so they sum to 1.0 + if (sum > 0) { + for (int i = 0; i < combinedScores.length; i++) { + combinedScores[i] = combinedScores[i] / sum; + } + } + + return combinedScores; + } + } diff --git a/android/src/main/java/com/visioncameraplugininatvision/Node.java b/android/src/main/java/com/visioncameraplugininatvision/Node.java index 4cecedc..27403a6 100644 --- a/android/src/main/java/com/visioncameraplugininatvision/Node.java +++ 
b/android/src/main/java/com/visioncameraplugininatvision/Node.java @@ -1,6 +1,7 @@ package com.visioncameraplugininatvision; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; public class Node { @@ -18,6 +19,8 @@ public class Node { public String spatialId; + public String spatialThreshold; + public transient Node parent; public transient List children = new ArrayList<>(); @@ -31,19 +34,48 @@ public String toString() { // parent_taxon_id,taxon_id,rank_level,leaf_class_id,iconic_class_id,spatial_class_id,name // Seek model 1.0: // parent_taxon_id,taxon_id,rank_level,leaf_class_id,name - public Node(String line, String version) { - String[] parts = line.trim().split(",", 7); - - this.parentKey = parts[0]; - this.key = parts[1]; - this.rank = Float.parseFloat(parts[2]); - this.leafId = parts[3]; - if (version.equals("1.0")) { - this.name = parts[4]; - } else { - this.iconicId = parts[4]; - this.spatialId = parts[5]; - this.name = parts[6]; + public Node(String[] headers, String line) { + String[] parts = line.trim().split(",", headers.length); + List headerList = new ArrayList<>(Arrays.asList(headers)); + + if (headerList.contains("parent_taxon_id")) { + int parentTaxonIdIndex = headerList.indexOf("parent_taxon_id"); + this.parentKey = parts[parentTaxonIdIndex]; + } + + if (headerList.contains("taxon_id")) { + int taxonIdIndex = headerList.indexOf("taxon_id"); + this.key = parts[taxonIdIndex]; + } + + if (headerList.contains("rank_level")) { + int rankLevelIndex = headerList.indexOf("rank_level"); + this.rank = Float.parseFloat(parts[rankLevelIndex]); + } + + if (headerList.contains("leaf_class_id")) { + int leafClassIdIndex = headerList.indexOf("leaf_class_id"); + this.leafId = parts[leafClassIdIndex]; + } + + if (headerList.contains("iconic_class_id")) { + int iconicClassIdIndex = headerList.indexOf("iconic_class_id"); + this.iconicId = parts[iconicClassIdIndex]; + } + + if (headerList.contains("spatial_class_id")) { + int 
spatialClassIdIndex = headerList.indexOf("spatial_class_id"); + this.spatialId = parts[spatialClassIdIndex]; + } + + if (headerList.contains("spatial_threshold")) { + int spatialThresholdIndex = headerList.indexOf("spatial_threshold"); + this.spatialThreshold = parts[spatialThresholdIndex]; + } + + if (headerList.contains("name")) { + int nameIndex = headerList.indexOf("name"); + this.name = parts[nameIndex]; } } diff --git a/android/src/main/java/com/visioncameraplugininatvision/Taxonomy.java b/android/src/main/java/com/visioncameraplugininatvision/Taxonomy.java index c7aa944..079633d 100644 --- a/android/src/main/java/com/visioncameraplugininatvision/Taxonomy.java +++ b/android/src/main/java/com/visioncameraplugininatvision/Taxonomy.java @@ -101,12 +101,15 @@ public void setTaxonomyRollupCutoff(float taxonomyRollupCutoff) { // Read the taxonomy CSV file into a list of nodes BufferedReader reader = new BufferedReader(new InputStreamReader(is)); try { - reader.readLine(); // Skip the first line (header line) + String headerLine = reader.readLine(); + // Transform header line to array + String[] headers = headerLine.split(","); + mNodes = new ArrayList<>(); mLeaves = new ArrayList<>(); for (String line; (line = reader.readLine()) != null; ) { - Node node = new Node(line, mModelVersion); + Node node = new Node(headers, line); mNodes.add(node); if ((node.leafId != null) && (node.leafId.length() > 0)) { mLeaves.add(node); @@ -152,11 +155,9 @@ public int getModelSize() { return mLeaves.size(); } - public List predict(Map outputs, Double taxonomyRollupCutoff) { - // Get raw predictions - float[] results = ((float[][]) outputs.get(0))[0]; + public List predict(float[] scores, Double taxonomyRollupCutoff) { // Make a copy of results - float[] resultsCopy = results.clone(); + float[] resultsCopy = scores.clone(); // Make sure results is sorted by score Arrays.sort(resultsCopy); // Get result with the highest score @@ -170,13 +171,41 @@ public List predict(Map outputs, 
Double taxonomyRol } resultsCopy = null; - Map scores = aggregateAndNormalizeScores(results); - Timber.tag(TAG).d("Number of nodes in scores: " + scores.size()); - List bestBranch = buildBestBranchFromScores(scores); + Map aggregateScores = aggregateAndNormalizeScores(scores); + Timber.tag(TAG).d("Number of nodes in scores: " + aggregateScores.size()); + List bestBranch = buildBestBranchFromScores(aggregateScores); return bestBranch; } + public List expectedNearbyFromClassification(float[][] results) { + List scores = new ArrayList<>(); + List filteredOutScores = new ArrayList<>(); + + for (Node leaf : mLeaves) { + // We did not implement batch processing here, so we only have one result + float score = results[0][Integer.valueOf(leaf.leafId)]; + Prediction prediction = new Prediction(leaf, score); + + // If score is higher than spatialThreshold it means the taxon is "expected nearby" + if (leaf.spatialThreshold != null && !leaf.spatialThreshold.isEmpty()) { + float threshold = Float.parseFloat(leaf.spatialThreshold); + if (score >= threshold) { + scores.add(prediction); + } else { + filteredOutScores.add(prediction); + } + } else { + scores.add(prediction); + } + } + + // Log length of scores + Timber.tag(TAG).d("Length of scores: " + scores.size()); + Timber.tag(TAG).d("Length of filteredOutScores: " + filteredOutScores.size()); + + return scores; + } /** Aggregates scores for nodes, including non-leaf nodes (so each non-leaf node has a score of the sum of all its dependents) */ private Map aggregateAndNormalizeScores(float[] results) { diff --git a/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionModule.java b/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionModule.java index 69c2d47..daa9a66 100644 --- a/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionModule.java +++ 
b/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionModule.java @@ -71,6 +71,12 @@ public void removeListeners(Integer count) { public static final String OPTION_TAXONOMY_PATH = "taxonomyPath"; public static final String OPTION_CONFIDENCE_THRESHOLD = "confidenceThreshold"; public static final String OPTION_CROP_RATIO = "cropRatio"; + public static final String OPTION_USE_GEOMODEL = "useGeomodel"; + public static final String OPTION_GEOMODEL_PATH = "geomodelPath"; + public static final String OPTION_LOCATION = "location"; + public static final String LATITUDE = "latitude"; + public static final String LONGITUDE = "longitude"; + public static final String ELEVATION = "elevation"; public static final float DEFAULT_CONFIDENCE_THRESHOLD = 0.7f; private float mConfidenceThreshold = DEFAULT_CONFIDENCE_THRESHOLD; @@ -100,8 +106,49 @@ public void getPredictionsForImage(ReadableMap options, Promise promise) { } double cropRatio = options.hasKey(OPTION_CROP_RATIO) ? options.getDouble(OPTION_CROP_RATIO) : DEFAULT_CROP_RATIO; - ImageClassifier classifier = null; + // Destructure geomodel parameters. Those can be null + Boolean useGeomodel = options.hasKey(OPTION_USE_GEOMODEL) ? options.getBoolean(OPTION_USE_GEOMODEL) : null; + String geomodelPath = options.hasKey(OPTION_GEOMODEL_PATH) ? options.getString(OPTION_GEOMODEL_PATH) : null; + ReadableMap location = options.hasKey(OPTION_LOCATION) ? options.getMap(OPTION_LOCATION) : null; + + // Initialize and use geomodel if requested + GeoClassifier geoClassifier = null; + float[][] geomodelScores = null; + if (useGeomodel != null && useGeomodel) { + if (geomodelPath == null) { + throw new RuntimeException("Geomodel scoring requested but path is null"); + } + if (location == null) { + throw new RuntimeException("Geomodel scoring requested but location is null"); + } + Double latitude = location.hasKey(LATITUDE) ? location.getDouble(LATITUDE) : null; + Double longitude = location.hasKey(LONGITUDE) ? 
location.getDouble(LONGITUDE) : null; + Double elevation = location.hasKey(ELEVATION) ? location.getDouble(ELEVATION) : null; + if (latitude == null || longitude == null || elevation == null) { + throw new RuntimeException("Geomodel scoring requested but latitude, longitude, or elevation is null"); + } + // Geomodel classifier initialization with model and taxonomy files + Timber.tag(TAG).d("Initializing geo classifier: " + geomodelPath + " / " + taxonomyFilename); + try { + geoClassifier = new GeoClassifier(geomodelPath, taxonomyFilename, version); + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("Failed to initialize a geomodel classifier: " + e.getMessage()); + } catch (OutOfMemoryError e) { + e.printStackTrace(); + throw new RuntimeException("Out of memory"); + } catch (Exception e) { + e.printStackTrace(); + Timber.tag(TAG).w("Other type of exception - Device not supported - classifier failed to load - " + e); + throw new RuntimeException("Android version is too old - needs to be at least 6.0"); + } + geomodelScores = geoClassifier.predictionsForLocation(latitude, longitude, elevation); + } else { + Timber.tag(TAG).d("Not using geomodel."); + } + + ImageClassifier classifier = null; try { classifier = new ImageClassifier(modelFilename, taxonomyFilename, version); } catch (IOException e) { @@ -154,12 +201,12 @@ public void getPredictionsForImage(ReadableMap options, Promise promise) { return; } + classifier.setGeomodelScores(geomodelScores); // Override the built-in taxonomy cutoff for predictions from file Double taxonomyRollupCutoff = 0.0; - List predictions = classifier.classifyFrame(bitmap, taxonomyRollupCutoff); + List predictions = classifier.classifyBitmap(bitmap, taxonomyRollupCutoff); bitmap.recycle(); - WritableArray cleanedPredictions = Arguments.createArray(); for (Prediction prediction : predictions) { // only KPCOFGS ranks qualify as "top" predictions @@ -176,7 +223,7 @@ public void getPredictionsForImage(ReadableMap 
options, Promise promise) { } } - + long endTime = SystemClock.uptimeMillis(); WritableMap resultMap = Arguments.createMap(); resultMap.putArray("predictions", cleanedPredictions); @@ -185,4 +232,55 @@ public void getPredictionsForImage(ReadableMap options, Promise promise) { resultMap.putDouble("timeElapsed", (endTime - startTime) / 1000.0); promise.resolve(resultMap); } + + @ReactMethod + public void getPredictionsForLocation(ReadableMap options, Promise promise) { + long startTime = SystemClock.uptimeMillis(); + // Destructure the model path from the options map + String geomodelPath = options.getString(OPTION_GEOMODEL_PATH); + String taxonomyPath = options.getString(OPTION_TAXONOMY_PATH); + ReadableMap location = options.getMap(OPTION_LOCATION); + + double latitude = location.getDouble(LATITUDE); + double longitude = location.getDouble(LONGITUDE); + double elevation = location.getDouble(ELEVATION); + + GeoClassifier classifier = null; + try { + classifier = new GeoClassifier(geomodelPath, taxonomyPath, "2.13"); + } catch (IOException e) { + e.printStackTrace(); + promise.reject("E_CLASSIFIER", "Failed to initialize a geomodel mClassifier: " + e.getMessage()); + return; + } catch (OutOfMemoryError e) { + e.printStackTrace(); + Timber.tag(TAG).w("Out of memory - Device not supported - classifier failed to load - " + e); + promise.reject("E_OUT_OF_MEMORY", "Out of memory"); + return; + } catch (Exception e) { + e.printStackTrace(); + Timber.tag(TAG).w("Other type of exception - Device not supported - classifier failed to load - " + e); + promise.reject("E_UNSUPPORTED_DEVICE", "Android version is too old - needs to be at least 6.0"); + return; + } + + List predictions = classifier.expectedNearby(latitude, longitude, elevation); + + WritableArray cleanedPredictions = Arguments.createArray(); + for (Prediction prediction : predictions) { + Map map = Taxonomy.nodeToMap(prediction); + if (map == null) continue; + // Transform the Map to a ReadableMap + ReadableMap 
readableMap = Arguments.makeNativeMap(map); + cleanedPredictions.pushMap(readableMap); + } + + long endTime = SystemClock.uptimeMillis(); + WritableMap resultMap = Arguments.createMap(); + resultMap.putArray("predictions", cleanedPredictions); + resultMap.putMap("options", options); + // Time elapsed on the native side; in seconds + resultMap.putDouble("timeElapsed", (endTime - startTime) / 1000.0); + promise.resolve(resultMap); + } } diff --git a/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionPlugin.java b/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionPlugin.java index cfc20e8..f9fd90d 100644 --- a/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionPlugin.java +++ b/android/src/main/java/com/visioncameraplugininatvision/VisionCameraPluginInatVisionPlugin.java @@ -30,6 +30,7 @@ public class VisionCameraPluginInatVisionPlugin extends FrameProcessorPlugin { private final static String TAG = "VisionCameraPluginInatVisionPlugin"; private ImageClassifier mImageClassifier = null; + private GeoClassifier mGeoClassifier = null; private Integer mFilterByTaxonId = null; // If null -> no filter by taxon ID defined public void setFilterByTaxonId(Integer taxonId) { @@ -106,6 +107,46 @@ public Object callback(@NonNull Frame frame, @Nullable Map argum setCropRatio(cropRatio); } + // Destructure geomodel parameters. 
Those can be null + Boolean useGeomodel = (Boolean)arguments.get("useGeomodel"); + String geomodelPath = (String)arguments.get("geomodelPath"); + Map location = (Map)arguments.get("location"); + + // Initialize and use geomodel if requested + float[][] geomodelScores = null; + if (useGeomodel != null && useGeomodel) { + if (geomodelPath == null) { + throw new RuntimeException("Geomodel scoring requested but path is null"); + } + if (location == null) { + throw new RuntimeException("Geomodel scoring requested but location is null"); + } + Double latitude = location.get("latitude"); + Double longitude = location.get("longitude"); + Double elevation = location.get("elevation"); + + // Geomodel classifier initialization with model and taxonomy files + if (mGeoClassifier == null) { + Timber.tag(TAG).d("Initializing geo classifier: " + geomodelPath + " / " + taxonomyPath); + try { + mGeoClassifier = new GeoClassifier(geomodelPath, taxonomyPath, version); + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("Failed to initialize a geomodel classifier: " + e.getMessage()); + } catch (OutOfMemoryError e) { + e.printStackTrace(); + throw new RuntimeException("Out of memory"); + } catch (Exception e) { + e.printStackTrace(); + Timber.tag(TAG).w("Other type of exception - Device not supported - classifier failed to load - " + e); + throw new RuntimeException("Android version is too old - needs to be at least 6.0"); + } + } + geomodelScores = mGeoClassifier.predictionsForLocation(latitude, longitude, elevation); + } else { + Timber.tag(TAG).d("Not using geomodel for this frame."); + } + // Image classifier initialization with model and taxonomy files if (mImageClassifier == null) { Timber.tag(TAG).d("Initializing classifier: " + modelPath + " / " + taxonomyPath); @@ -130,6 +171,7 @@ public Object callback(@NonNull Frame frame, @Nullable Map argum List cleanedPredictions = new ArrayList<>(); if (mImageClassifier != null) { + 
mImageClassifier.setGeomodelScores(geomodelScores); Bitmap bmp = BitmapUtils.getBitmap(image, patchedOrientationAndroid); Log.d(TAG, "originalBitmap: " + bmp + ": " + bmp.getWidth() + " x " + bmp.getHeight()); // Crop the center square of the frame @@ -148,8 +190,10 @@ public Object callback(@NonNull Frame frame, @Nullable Map argum bmp.recycle(); bmp = rescaledBitmap; Log.d(TAG, "rescaledBitmap: " + bmp + ": " + bmp.getWidth() + " x " + bmp.getHeight()); - List predictions = mImageClassifier.classifyFrame(bmp, taxonomyRollupCutoff); + List predictions = mImageClassifier.classifyBitmap(bmp, taxonomyRollupCutoff); bmp.recycle(); + croppedBitmap.recycle(); + Log.d(TAG, "Predictions: " + predictions.size()); for (Prediction prediction : predictions) { diff --git a/example/src/App.tsx b/example/src/App.tsx index 4381d0e..6401143 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -169,6 +169,10 @@ export default function App(): React.JSX.Element { taxonomyFilenameAndroid, `${RNFS.DocumentDirectoryPath}/${taxonomyFilenameAndroid}` ); + await RNFS.copyFileAssets( + geomodelFilenameAndroid, + `${RNFS.DocumentDirectoryPath}/${geomodelFilenameAndroid}` + ); })(); } }, []); diff --git a/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVision.m b/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVision.m index a918dab..cd8c765 100644 --- a/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVision.m +++ b/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVision.m @@ -154,7 +154,7 @@ - (id)callback:(Frame*)frame withArguments:(NSDictionary*)arguments { lng:longitude.floatValue elevation:elevation.floatValue]; } else { - NSLog(@"not doing anything geo related."); + NSLog(@"Not using geomodel for this frame."); } CMSampleBufferRef buffer = frame.buffer; diff --git a/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVisionModule.m b/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVisionModule.m index dc99ce0..e33c5e8 100644 --- 
a/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVisionModule.m +++ b/ios/VisionCameraPluginInatVision/VisionCameraPluginInatVisionModule.m @@ -323,7 +323,6 @@ - (MLMultiArray *)normalizeMultiArray:(MLMultiArray *)mlArray error:(NSError **) NSArray *leafScores = [taxonomy expectedNearbyFromClassification:geomodelPreds]; - // convert the VCPPredictions in the bestRecentBranch into dicts NSMutableArray *predictions = [NSMutableArray array]; for (VCPPrediction *prediction in leafScores) { [predictions addObject:[prediction asDict]]; diff --git a/scripts/createTaxonomy.js b/scripts/createTaxonomy.js index 0233c20..6904b87 100644 --- a/scripts/createTaxonomy.js +++ b/scripts/createTaxonomy.js @@ -51,10 +51,24 @@ fs.createReadStream(filePathTaxonomy) entry.spatial_threshold = thresholdDict[entry.taxon_id] ? parseFloat(thresholdDict[entry.taxon_id]) : null; + // Delete and add the name so that it is last when written to file + const name = entry.name; + delete entry.name; + entry.name = name; return entry; }); // Write json to file const json = JSON.stringify(combinedEntries, null, 2); fs.writeFileSync('taxonomy.json', json); + + // Also write a new .csv with the threshold data appended to the original rows + const csvHeader = Object.keys(combinedEntries[0]).join(','); + const csvRows = combinedEntries.map((entry) => + Object.values(entry).join(',') + ); + let csvData = [csvHeader, ...csvRows].join('\n'); + // Replace all NaN with empty string + csvData = csvData.replace(/NaN/g, ''); + fs.writeFileSync('taxonomy_with_thresholds.csv', csvData); }); }); diff --git a/scripts/whichTaxaAreNew.js b/scripts/whichTaxaAreNew.js new file mode 100644 index 0000000..eef6291 --- /dev/null +++ b/scripts/whichTaxaAreNew.js @@ -0,0 +1,82 @@ +const fs = require('fs'); +const path = require('path'); +const csv = require('csv-parser'); + +// Step 1: Define the file path +const filePathTaxonomy = path.join(__dirname, 'taxonomy_v1.csv'); +const filePathTaxonomy2 = path.join(__dirname, 
'taxonomy_v2_13.csv'); + +// Step 2: Read the .csv files and Step 3: Parse the CSV data +let entriesTaxonomy1 = []; +fs.createReadStream(filePathTaxonomy) + .pipe(csv()) + .on('data', (row) => { + // Extract the row + entriesTaxonomy1.push(row); + }) + .on('end', () => { + // Read taxonomy data + console.log('entriesTaxonomy1.length', entriesTaxonomy1.length); + const overridenEntries = entriesTaxonomy1.map((entry) => { + entry.parent_taxon_id = parseInt(entry.parent_taxon_id, 10); + entry.taxon_id = parseInt(entry.taxon_id, 10); + entry.rank_level = parseInt(entry.rank_level, 10); + entry.leaf_class_id = parseInt(entry.leaf_class_id, 10); + entry.iconic_class_id = parseInt(entry.iconic_class_id, 10); + entry.spatial_class_id = parseInt(entry.spatial_class_id, 10); + return entry; + }); + entriesTaxonomy1 = overridenEntries; + // Read second taxonomy data + let entriesTaxonomy2 = []; + fs.createReadStream(filePathTaxonomy2) + .pipe(csv()) + .on('data', (row) => { + // Extract the row + entriesTaxonomy2.push(row); + }) + .on('end', () => { + // Read taxonomy data + console.log('entriesTaxonomy2.length', entriesTaxonomy2.length); + const overridenEntries2 = entriesTaxonomy2.map((entry) => { + entry.parent_taxon_id = parseInt(entry.parent_taxon_id, 10); + entry.taxon_id = parseInt(entry.taxon_id, 10); + entry.rank_level = parseInt(entry.rank_level, 10); + entry.leaf_class_id = parseInt(entry.leaf_class_id, 10); + entry.iconic_class_id = parseInt(entry.iconic_class_id, 10); + entry.spatial_class_id = parseInt(entry.spatial_class_id, 10); + return entry; + }); + entriesTaxonomy2 = overridenEntries2; + // Find new taxa + const newTaxa = entriesTaxonomy2.filter( + (entry2) => + !entriesTaxonomy1.some( + (entry1) => entry1.taxon_id === entry2.taxon_id + ) + ); + console.log('newTaxa', newTaxa.length); + // Write json to file + const json = JSON.stringify(newTaxa, null, 2); + fs.writeFileSync('newTaxa.json', json); + // Find taxa no longer present + const removedTaxa = 
entriesTaxonomy1.filter( + (entry1) => + !entriesTaxonomy2.some( + (entry2) => entry1.taxon_id === entry2.taxon_id + ) + ); + console.log('removedTaxa', removedTaxa.length); + // Write json to file + const json2 = JSON.stringify(removedTaxa, null, 2); + fs.writeFileSync('removedTaxa.json', json2); + // Combine all taxon_ids of removed taxa into a comma separated long string + const removedSpecies = removedTaxa.filter( + (entry) => entry.rank_level === 10 + ); + console.log('removedSpecies.length', removedSpecies.length); + const removedTaxaIds = removedSpecies.map((entry) => entry.taxon_id); + const removedTaxaIdsString = removedTaxaIds.join(','); + console.log('removedTaxaIdsString', removedTaxaIdsString); + }); + });