diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java index 75a1ded750..dd3772f245 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java @@ -18,24 +18,49 @@ package org.photonvision.common.configuration; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; +import java.net.URL; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; import java.util.List; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; import org.photonvision.rknn.RknnJNI; +/** + * Manages the loading of neural network models. + * + *

Models are loaded from the filesystem at the modelsFolder location. PhotonVision + * also supports shipping pre-trained models as resources in the JAR. If the model is not found on + * the filesystem, it will be extracted from the JAR to the filesystem. + * + *

Each model must have a corresponding labels file. The labels file format is + * simply a list of string names per label, one label per line. The labels file must have the same + * name as the model file, but with the suffix -labels.txt instead of .rknn + * . + * + *

Note: PhotonVision currently only supports YOLOv5 and YOLOv8 models in the .rknn + * format. + */ public class NeuralNetworkModelManager { + /** Singleton instance of the NeuralNetworkModelManager */ private static NeuralNetworkModelManager INSTANCE; - private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config); - private final String MODEL_NAME = "note-640-640-yolov5s.rknn"; - private final RknnJNI.ModelVersion modelVersion = RknnJNI.ModelVersion.YOLO_V5; - private File defaultModelFile; - private List labels; + /** + * Private constructor to prevent instantiation + * + * @return The NeuralNetworkModelManager instance + */ + private NeuralNetworkModelManager() {} + /** + * Returns the singleton instance of the NeuralNetworkModelManager + * + * @return The singleton instance + */ public static NeuralNetworkModelManager getInstance() { if (INSTANCE == null) { INSTANCE = new NeuralNetworkModelManager(); @@ -43,62 +68,200 @@ public static NeuralNetworkModelManager getInstance() { return INSTANCE; } + /** Logger for the NeuralNetworkModelManager */ + private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config); + + /** + * Determines the model version based on the model's filename. + * + *

"yolov5" -> "YOLO_V5" + * + *

"yolov8" -> "YOLO_V8" + * + * @param modelName The model's filename + * @return The model version + */ + private static RknnJNI.ModelVersion getModelVersion(String modelName) + throws IllegalArgumentException { + if (modelName.contains("yolov5")) { + return RknnJNI.ModelVersion.YOLO_V5; + } else if (modelName.contains("yolov8")) { + return RknnJNI.ModelVersion.YOLO_V8; + } else { + throw new IllegalArgumentException("Unknown model version for model " + modelName); + } + } + + /** This class represents a model that can be loaded by the RknnJNI. */ + public class Model { + public final File modelFile; + public final RknnJNI.ModelVersion version; + public final List labels; + + public Model(String model, String labels) throws IllegalArgumentException { + this.version = getModelVersion(model); + this.modelFile = new File(model); + try { + this.labels = Files.readAllLines(Paths.get(labels)); + } catch (IOException e) { + throw new IllegalArgumentException("Error reading labels file " + labels, e); + } + } + + public String getPath() { + return modelFile.getAbsolutePath(); + } + } + + /** + * Stores model information, such as the model file, labels, and version. + * + *

The first model in the list is the default model. + */ + private List models; + /** - * Perform initial setup and extract default model from JAR to the filesystem + * Returns the default rknn model. This is simply the first model in the list. * - * @param modelsFolder Where models live + * @return The default model */ - public void initialize(File modelsFolder) { - var modelResourcePath = "/models/" + MODEL_NAME; - this.defaultModelFile = new File(modelsFolder, MODEL_NAME); - extractResource(modelResourcePath, defaultModelFile); + public Model getDefaultRknnModel() { + return models.get(0); + } - File labelsFile = new File(modelsFolder, "labels_v5.txt"); - var labelResourcePath = "/models/" + labelsFile.getName(); - extractResource(labelResourcePath, labelsFile); + /** + * Enumerates the names of all models. + * + * @return A list of model names + */ + public List getModels() { + return models.stream().map(model -> model.modelFile.getName()).toList(); + } + + /** + * Returns the model with the given name. + * + *

TODO: Java 17 This should return an Optional instead of null. + * + * @param modelName The model name + * @return The model + */ + public Model getModel(String modelName) { + Model m = + models.stream() + .filter(model -> model.modelFile.getName().equals(modelName)) + .findFirst() + .orElse(null); + + if (m == null) { + logger.error("Model " + modelName + " not found."); + } + + return m; + } + + /** + * Loads models from the specified folder. + * + * @param modelsFolder The folder where the models are stored + */ + public void loadModels(File modelsFolder) { + if (!modelsFolder.exists()) { + logger.error("Models folder " + modelsFolder.getAbsolutePath() + " does not exist."); + return; + } + + if (models == null) { + models = new ArrayList<>(); + } try { - labels = Files.readAllLines(Paths.get(labelsFile.getPath())); + Files.walk(modelsFolder.toPath()) + .filter(Files::isRegularFile) + .filter(path -> path.toString().endsWith(".rknn")) + .forEach( + modelPath -> { + String model = modelPath.toString(); + String labels = model.replace(".rknn", "-labels.txt"); + + try { + models.add(new Model(model, labels)); + } catch (IllegalArgumentException e) { + logger.error("Failed to load model " + model, e); + } + }); } catch (IOException e) { - logger.error("Error reading labels.txt", e); + logger.error("Failed to load models from " + modelsFolder.getAbsolutePath(), e); + } + + // Log the loaded models + StringBuilder sb = new StringBuilder(); + sb.append("Loaded models: "); + for (Model model : models) { + sb.append(model.modelFile.getName()).append(", "); } + sb.setLength(sb.length() - 2); + logger.info(sb.toString()); } - private void extractResource(String resourcePath, File outputFile) { - try (var in = NeuralNetworkModelManager.class.getResourceAsStream(resourcePath)) { - if (in == null) { + /** + * Extracts models from a JAR resource and copies them to the specified folder. + * + * @param modelsFolder the folder where the models will be copied to + */ + public void extractModels(File modelsFolder) { + if (!modelsFolder.exists()) { + modelsFolder.mkdirs(); + } + + String resourcePath = "models"; // Adjust path if necessary + try { + URL resourceURL = NeuralNetworkModelManager.class.getClassLoader().getResource(resourcePath); + if (resourceURL == null) { logger.error("Failed to find jar resource at " + resourcePath); return; } - if (!outputFile.exists()) { - try (FileOutputStream fos = new FileOutputStream(outputFile)) { - int read = -1; - byte[] buffer = new byte[1024]; - while ((read = in.read(buffer)) != -1) { - fos.write(buffer, 0, read); + Path resourcePathResolved = Paths.get(resourceURL.toURI()); + Files.walk(resourcePathResolved) + .forEach(sourcePath -> copyResource(sourcePath, resourcePathResolved, modelsFolder)); + } catch (Exception e) { + logger.error("Failed to extract models from JAR", e); + } + } + + /** + * Copies a resource from the source path to the target path. + * + * @param sourcePath The path of the resource to be copied. + * @param resourcePathResolved The resolved path of the resource. + * @param modelsFolder The folder where the resource will be copied to. + */ + private void copyResource(Path sourcePath, Path resourcePathResolved, File modelsFolder) { + Path targetPath = + Paths.get( + modelsFolder.getAbsolutePath(), resourcePathResolved.relativize(sourcePath).toString()); + try { + if (Files.isDirectory(sourcePath)) { + Files.createDirectories(targetPath); + } else { + Path parentDir = targetPath.getParent(); + if (parentDir != null && !Files.exists(parentDir)) { + Files.createDirectories(parentDir); + } + + if (!Files.exists(targetPath)) { + Files.copy(sourcePath, targetPath); + } else { + long sourceSize = Files.size(sourcePath); + long targetSize = Files.size(targetPath); + if (sourceSize != targetSize) { + Files.copy(sourcePath, targetPath, StandardCopyOption.REPLACE_EXISTING); } - } catch (IOException e) { - logger.error("Error extracting resource to " + outputFile.toPath().toString(), e); } - } else { - logger.info( - "File " + outputFile.toPath().toString() + " already exists. Skipping extraction."); } } catch (IOException e) { - logger.error("Error finding jar resource " + resourcePath, e); + logger.error("Failed to copy " + sourcePath + " to " + targetPath, e); } } - - public File getDefaultRknnModel() { - return defaultModelFile; - } - - public List getLabels() { - return labels; - } - - public RknnJNI.ModelVersion getModelVersion() { - return modelVersion; - } } diff --git a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java index c282ec7a79..a247818052 100644 --- a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java +++ b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java @@ -23,6 +23,7 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.stream.Collectors; import org.opencv.core.Mat; +import org.photonvision.common.configuration.NeuralNetworkModelManager; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; import org.photonvision.common.util.TestUtils; @@ -70,6 +71,10 @@ public static class RknnObjectDetector { static volatile boolean hook = false; + public RknnObjectDetector(NeuralNetworkModelManager.Model model) { + this(model.getPath(), model.labels, model.version); + } + public RknnObjectDetector(String modelPath, List labels, RknnJNI.ModelVersion version) { synchronized (lock) { objPointer = RknnJNI.create(modelPath, labels.size(), version.ordinal(), -1); diff --git a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java index 9ce2f348f4..70d02e649a 100644 --- a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java +++ b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java @@ -40,14 +40,10 @@ public class RknnDetectionPipe public RknnDetectionPipe() { // For now this is hard-coded to defaults. Should be refactored into set pipe - // params, though. - // And ideally a little wrapper helper for only changing native stuff on content + // params, though. And ideally a little wrapper helper for only changing native stuff on content // change created. this.detector = - new RknnObjectDetector( - NeuralNetworkModelManager.getInstance().getDefaultRknnModel().getAbsolutePath(), - NeuralNetworkModelManager.getInstance().getLabels(), - NeuralNetworkModelManager.getInstance().getModelVersion()); + new RknnObjectDetector(NeuralNetworkModelManager.getInstance().getDefaultRknnModel()); } private static class Letterbox { diff --git a/photon-server/src/main/java/org/photonvision/Main.java b/photon-server/src/main/java/org/photonvision/Main.java index 9dd5e7ad70..ab3b15aadc 100644 --- a/photon-server/src/main/java/org/photonvision/Main.java +++ b/photon-server/src/main/java/org/photonvision/Main.java @@ -435,8 +435,9 @@ public static void main(String[] args) { .setConfig(ConfigManager.getInstance().getConfig().getNetworkConfig()); logger.info("Loading ML models"); - NeuralNetworkModelManager.getInstance() - .initialize(ConfigManager.getInstance().getModelsDirectory()); + var modelManager = NeuralNetworkModelManager.getInstance(); + modelManager.extractModels(ConfigManager.getInstance().getModelsDirectory()); + modelManager.loadModels(ConfigManager.getInstance().getModelsDirectory()); if (isSmoketest) { logger.info("PhotonVision base functionality loaded -- smoketest complete"); diff --git a/photon-server/src/main/resources/models/labels_v5.txt b/photon-server/src/main/resources/models/note-640-640-yolov5s-labels.txt similarity index 100% rename from photon-server/src/main/resources/models/labels_v5.txt rename to photon-server/src/main/resources/models/note-640-640-yolov5s-labels.txt