diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java
index 75a1ded750..dd3772f245 100644
--- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java
+++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java
@@ -18,24 +18,49 @@
package org.photonvision.common.configuration;
import java.io.File;
-import java.io.FileOutputStream;
import java.io.IOException;
+import java.net.URL;
import java.nio.file.Files;
+import java.nio.file.Path;
import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
+import java.util.ArrayList;
import java.util.List;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
import org.photonvision.rknn.RknnJNI;
+/**
+ * Manages the loading of neural network models.
+ *
+ * <p>Models are loaded from the filesystem at the {@code modelsFolder} location. PhotonVision
+ * also supports shipping pre-trained models as resources in the JAR. If the model is not found on
+ * the filesystem, it will be extracted from the JAR to the filesystem.
+ *
+ * <p>Each model must have a corresponding {@code labels} file. The labels file format is
+ * simply a list of string names per label, one label per line. The labels file must have the same
+ * name as the model file, but with the suffix {@code -labels.txt} instead of {@code .rknn}.
+ *
+ * <p>Note: PhotonVision currently only supports YOLOv5 and YOLOv8 models in the {@code .rknn}
+ * format.
+ */
public class NeuralNetworkModelManager {
+ /** Singleton instance of the NeuralNetworkModelManager */
private static NeuralNetworkModelManager INSTANCE;
- private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config);
- private final String MODEL_NAME = "note-640-640-yolov5s.rknn";
- private final RknnJNI.ModelVersion modelVersion = RknnJNI.ModelVersion.YOLO_V5;
- private File defaultModelFile;
- private List labels;
+ /**
+ * Private constructor to prevent outside instantiation.
+ *
+ * <p>Use {@link #getInstance()} to obtain the shared instance.
+ */
+ private NeuralNetworkModelManager() {}
+ /**
+ * Returns the singleton instance of the NeuralNetworkModelManager, creating it on the first call.
+ *
+ * <p>NOTE(review): no synchronization is visible around the null check below; confirm the first
+ * call always happens on a single thread.
+ *
+ * @return The singleton instance
+ */
public static NeuralNetworkModelManager getInstance() {
if (INSTANCE == null) {
INSTANCE = new NeuralNetworkModelManager();
@@ -43,62 +68,200 @@ public static NeuralNetworkModelManager getInstance() {
return INSTANCE;
}
+ /** Logger for the NeuralNetworkModelManager */
+ private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config);
+
+  /**
+   * Determines the model version based on the model's filename.
+   *
+   * <p>Only the last path component is inspected, so directory names never influence the result,
+   * and the match ignores case:
+   *
+   * <p>"yolov5" -> "YOLO_V5"
+   *
+   * <p>"yolov8" -> "YOLO_V8"
+   *
+   * @param modelName The model's filename, or a path ending in it
+   * @return The model version
+   * @throws IllegalArgumentException if the name is null or names neither supported version
+   */
+  private static RknnJNI.ModelVersion getModelVersion(String modelName)
+      throws IllegalArgumentException {
+    if (modelName == null) {
+      throw new IllegalArgumentException("Unknown model version for model " + modelName);
+    }
+
+    // Callers pass full paths; a parent folder named e.g. "yolov5" must not decide the version.
+    String fileName = new File(modelName).getName().toLowerCase(java.util.Locale.ROOT);
+
+    if (fileName.contains("yolov5")) {
+      return RknnJNI.ModelVersion.YOLO_V5;
+    } else if (fileName.contains("yolov8")) {
+      return RknnJNI.ModelVersion.YOLO_V8;
+    } else {
+      throw new IllegalArgumentException("Unknown model version for model " + modelName);
+    }
+  }
+
+  /** This class represents a model that can be loaded by the RknnJNI. */
+  public class Model {
+    /** The model file, built from the path given to the constructor. */
+    public final File modelFile;
+
+    /** The model version, as determined by {@code getModelVersion} from the model path. */
+    public final RknnJNI.ModelVersion version;
+
+    /** The lines of the labels file, in file order. */
+    public final List<String> labels;
+
+    /**
+     * Creates a model description and reads its labels from disk.
+     *
+     * @param model Path of the model file
+     * @param labels Path of the labels file
+     * @throws IllegalArgumentException if the model version cannot be determined or the labels
+     *     file cannot be read (the underlying IOException is kept as the cause)
+     */
+    public Model(String model, String labels) throws IllegalArgumentException {
+      this.version = getModelVersion(model);
+      this.modelFile = new File(model);
+      try {
+        this.labels = Files.readAllLines(Paths.get(labels));
+      } catch (IOException e) {
+        throw new IllegalArgumentException("Error reading labels file " + labels, e);
+      }
+    }
+
+    /**
+     * Returns the absolute path of the model file.
+     *
+     * @return The absolute path of {@link #modelFile}
+     */
+    public String getPath() {
+      return modelFile.getAbsolutePath();
+    }
+  }
+
+  /**
+   * Stores model information, such as the model file, labels, and version.
+   *
+   * <p>The first model in the list is the default model. The list starts out empty rather than
+   * null.
+   */
+  private List<Model> models = new ArrayList<>();
+
/**
- * Perform initial setup and extract default model from JAR to the filesystem
+   * Returns the default rknn model. This is simply the first model in the list.
*
- * @param modelsFolder Where models live
+   * @return The default model, or {@code null} (after logging an error) if no models are loaded
*/
- public void initialize(File modelsFolder) {
- var modelResourcePath = "/models/" + MODEL_NAME;
- this.defaultModelFile = new File(modelsFolder, MODEL_NAME);
- extractResource(modelResourcePath, defaultModelFile);
+  public Model getDefaultRknnModel() {
+    // Mirror getModel(): report the problem and return null instead of failing with an
+    // uninformative NullPointerException / IndexOutOfBoundsException.
+    if (models == null || models.isEmpty()) {
+      logger.error("No models are loaded; there is no default model.");
+      return null;
+    }
+    return models.get(0);
+  }
- File labelsFile = new File(modelsFolder, "labels_v5.txt");
- var labelResourcePath = "/models/" + labelsFile.getName();
- extractResource(labelResourcePath, labelsFile);
+  /**
+   * Enumerates the names of all models.
+   *
+   * @return An unmodifiable list of model file names; empty if no models have been loaded
+   */
+  public List<String> getModels() {
+    if (models == null) {
+      return List.of();
+    }
+    return models.stream().map(model -> model.modelFile.getName()).toList();
+  }
+
+  /**
+   * Looks up a loaded model by the name of its model file.
+   *
+   * <p>TODO: Java 17 This should return an Optional instead of null.
+   *
+   * @param modelName The model name
+   * @return The first model whose file name equals {@code modelName}, or null (after logging an
+   *     error) when there is none
+   */
+  public Model getModel(String modelName) {
+    for (Model candidate : models) {
+      if (candidate.modelFile.getName().equals(modelName)) {
+        return candidate;
+      }
+    }
+
+    logger.error("Model " + modelName + " not found.");
+    return null;
+  }
+
+  /**
+   * Loads models from the specified folder.
+   *
+   * <p>Every regular file below {@code modelsFolder} whose path ends in {@code .rknn} is loaded
+   * together with the {@code -labels.txt} file next to it. A model that fails to load is logged
+   * and skipped. Loaded models are appended to any that are already present.
+   *
+   * @param modelsFolder The folder where the models are stored
+   */
+  public void loadModels(File modelsFolder) {
+    // Create the list first so it is non-null even when the folder turns out to be missing.
+    if (models == null) {
+      models = new ArrayList<>();
+    }
+
+    if (!modelsFolder.exists()) {
+      logger.error("Models folder " + modelsFolder.getAbsolutePath() + " does not exist.");
+      return;
+    }
+
- try {
- labels = Files.readAllLines(Paths.get(labelsFile.getPath()));
+    // Files.walk keeps directory handles open until its stream is closed.
+    try (var paths = Files.walk(modelsFolder.toPath())) {
+      paths
+          .filter(Files::isRegularFile)
+          .filter(path -> path.toString().endsWith(".rknn"))
+          .forEach(
+              modelPath -> {
+                String model = modelPath.toString();
+                // Swap only the trailing extension; String.replace would also rewrite a
+                // ".rknn" occurring earlier in the path (e.g. in a directory name).
+                String labels =
+                    model.substring(0, model.length() - ".rknn".length()) + "-labels.txt";
+
+                try {
+                  models.add(new Model(model, labels));
+                } catch (IllegalArgumentException e) {
+                  logger.error("Failed to load model " + model, e);
+                }
+              });
} catch (IOException e) {
- logger.error("Error reading labels.txt", e);
+      logger.error("Failed to load models from " + modelsFolder.getAbsolutePath(), e);
}
+
+    // Log the loaded models. Joining the names (instead of trimming a trailing ", ") keeps the
+    // message intact when nothing was loaded.
+    logger.info("Loaded models: " + String.join(", ", getModels()));
}
- private void extractResource(String resourcePath, File outputFile) {
- try (var in = NeuralNetworkModelManager.class.getResourceAsStream(resourcePath)) {
- if (in == null) {
+  /**
+   * Extracts models from a JAR resource and copies them to the specified folder.
+   *
+   * <p>The {@code models} resource directory is located through the class loader. When it lives
+   * inside a JAR, the JAR is opened as a file system for the duration of the copy; when it is a
+   * plain directory on disk, it is walked directly. Failures are logged, not rethrown.
+   *
+   * @param modelsFolder the folder where the models will be copied to
+   */
+  public void extractModels(File modelsFolder) {
+    if (!modelsFolder.exists() && !modelsFolder.mkdirs()) {
+      logger.error("Failed to create models folder " + modelsFolder.getAbsolutePath());
+      return;
+    }
+
+    String resourcePath = "models"; // Adjust path if necessary
+    try {
+      URL resourceURL = NeuralNetworkModelManager.class.getClassLoader().getResource(resourcePath);
+      if (resourceURL == null) {
logger.error("Failed to find jar resource at " + resourcePath);
return;
}
- if (!outputFile.exists()) {
- try (FileOutputStream fos = new FileOutputStream(outputFile)) {
- int read = -1;
- byte[] buffer = new byte[1024];
- while ((read = in.read(buffer)) != -1) {
- fos.write(buffer, 0, read);
+
+      java.net.URI resourceUri = resourceURL.toURI();
+      if ("jar".equals(resourceUri.getScheme())) {
+        // Paths.get(URI) throws FileSystemNotFoundException for jar: URIs unless the JAR has
+        // been opened as a file system first, so open it for the duration of the copy.
+        try (java.nio.file.FileSystem jarFileSystem =
+            java.nio.file.FileSystems.newFileSystem(resourceUri, java.util.Map.of())) {
+          copyResourceTree(jarFileSystem.getPath("/" + resourcePath), modelsFolder);
+        }
+      } else {
+        copyResourceTree(Paths.get(resourceUri), modelsFolder);
+      }
+    } catch (Exception e) {
+      logger.error("Failed to extract models from JAR", e);
+    }
+  }
+
+  /**
+   * Copies every entry at or below {@code root} into {@code modelsFolder}, closing the directory
+   * stream when done.
+   *
+   * @param root The root of the resource tree to copy
+   * @param modelsFolder The folder the tree is mirrored into
+   * @throws IOException if the tree cannot be opened for walking
+   */
+  private void copyResourceTree(Path root, File modelsFolder) throws IOException {
+    try (var entries = Files.walk(root)) {
+      entries.forEach(sourcePath -> copyResource(sourcePath, root, modelsFolder));
+    }
+  }
+
+ /**
+ * Copies one entry of the resource tree to its mirrored location under the models folder.
+ *
+ * <p>Directories are created. A file is copied when the target does not exist, and overwritten
+ * when the source and target sizes differ. I/O failures are logged, not rethrown.
+ *
+ * @param sourcePath The path of the resource to be copied.
+ * @param resourcePathResolved The root that sourcePath is relativized against to obtain the
+ * target's location inside modelsFolder.
+ * @param modelsFolder The folder where the resource will be copied to.
+ */
+ private void copyResource(Path sourcePath, Path resourcePathResolved, File modelsFolder) {
+ Path targetPath =
+ Paths.get(
+ modelsFolder.getAbsolutePath(), resourcePathResolved.relativize(sourcePath).toString());
+ try {
+ if (Files.isDirectory(sourcePath)) {
+ Files.createDirectories(targetPath);
+ } else {
+ Path parentDir = targetPath.getParent();
+ if (parentDir != null && !Files.exists(parentDir)) {
+ Files.createDirectories(parentDir);
+ }
+
+ if (!Files.exists(targetPath)) {
+ Files.copy(sourcePath, targetPath);
+ } else {
+ // Only file sizes are compared, so a changed file of identical size is not replaced.
+ long sourceSize = Files.size(sourcePath);
+ long targetSize = Files.size(targetPath);
+ if (sourceSize != targetSize) {
+ Files.copy(sourcePath, targetPath, StandardCopyOption.REPLACE_EXISTING);
}
- } catch (IOException e) {
- logger.error("Error extracting resource to " + outputFile.toPath().toString(), e);
}
- } else {
- logger.info(
- "File " + outputFile.toPath().toString() + " already exists. Skipping extraction.");
}
} catch (IOException e) {
- logger.error("Error finding jar resource " + resourcePath, e);
+ logger.error("Failed to copy " + sourcePath + " to " + targetPath, e);
}
}
-
- public File getDefaultRknnModel() {
- return defaultModelFile;
- }
-
- public List getLabels() {
- return labels;
- }
-
- public RknnJNI.ModelVersion getModelVersion() {
- return modelVersion;
- }
}
diff --git a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java
index c282ec7a79..a247818052 100644
--- a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java
+++ b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java
@@ -23,6 +23,7 @@
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import org.opencv.core.Mat;
+import org.photonvision.common.configuration.NeuralNetworkModelManager;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
import org.photonvision.common.util.TestUtils;
@@ -70,6 +71,10 @@ public static class RknnObjectDetector {
static volatile boolean hook = false;
+ /**
+ * Creates a detector from a model description by delegating to the (path, labels, version)
+ * constructor.
+ *
+ * @param model Supplies the model path, labels, and version; it is dereferenced without a null
+ * check.
+ */
+ public RknnObjectDetector(NeuralNetworkModelManager.Model model) {
+ this(model.getPath(), model.labels, model.version);
+ }
+
public RknnObjectDetector(String modelPath, List labels, RknnJNI.ModelVersion version) {
synchronized (lock) {
objPointer = RknnJNI.create(modelPath, labels.size(), version.ordinal(), -1);
diff --git a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java
index 9ce2f348f4..70d02e649a 100644
--- a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java
+++ b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java
@@ -40,14 +40,10 @@ public class RknnDetectionPipe
public RknnDetectionPipe() {
// For now this is hard-coded to defaults. Should be refactored into set pipe
- // params, though.
- // And ideally a little wrapper helper for only changing native stuff on content
+ // params, though. And ideally a little wrapper helper for only changing native stuff on content
// change created.
+ // The detector is built from whatever the model manager reports as its default model.
this.detector =
- new RknnObjectDetector(
- NeuralNetworkModelManager.getInstance().getDefaultRknnModel().getAbsolutePath(),
- NeuralNetworkModelManager.getInstance().getLabels(),
- NeuralNetworkModelManager.getInstance().getModelVersion());
+ new RknnObjectDetector(NeuralNetworkModelManager.getInstance().getDefaultRknnModel());
}
private static class Letterbox {
diff --git a/photon-server/src/main/java/org/photonvision/Main.java b/photon-server/src/main/java/org/photonvision/Main.java
index 9dd5e7ad70..ab3b15aadc 100644
--- a/photon-server/src/main/java/org/photonvision/Main.java
+++ b/photon-server/src/main/java/org/photonvision/Main.java
@@ -435,8 +435,9 @@ public static void main(String[] args) {
.setConfig(ConfigManager.getInstance().getConfig().getNetworkConfig());
logger.info("Loading ML models");
- NeuralNetworkModelManager.getInstance()
- .initialize(ConfigManager.getInstance().getModelsDirectory());
+ var modelManager = NeuralNetworkModelManager.getInstance();
+ modelManager.extractModels(ConfigManager.getInstance().getModelsDirectory());
+ modelManager.loadModels(ConfigManager.getInstance().getModelsDirectory());
if (isSmoketest) {
logger.info("PhotonVision base functionality loaded -- smoketest complete");
diff --git a/photon-server/src/main/resources/models/labels_v5.txt b/photon-server/src/main/resources/models/note-640-640-yolov5s-labels.txt
similarity index 100%
rename from photon-server/src/main/resources/models/labels_v5.txt
rename to photon-server/src/main/resources/models/note-640-640-yolov5s-labels.txt