diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java
index da432ba53..d7cb90091 100644
--- a/serving/src/main/java/ai/djl/serving/ModelServer.java
+++ b/serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -344,6 +344,9 @@ private void initModelStore() throws IOException {
             } else {
                 modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
             }
+            if (engine == null) {
+                engine = inferEngineFromUrl(modelUrl);
+            }
 
             for (int i = 0; i < devices.length; ++i) {
                 String modelVersion;
@@ -392,19 +395,9 @@ String mapModelUrl(Path path) {
             String modelName = ModelInfo.inferModelNameFromUrl(url);
             String engine;
             if (Files.isDirectory(path)) {
-                engine = inferEngine(path);
+                engine = inferEngine(path, path.toFile().getName());
             } else {
-                try {
-                    Repository repository = Repository.newInstance("modelStore", url);
-                    List<MRL> mrls = repository.getResources();
-                    Artifact artifact = mrls.get(0).getDefaultArtifact();
-                    repository.prepare(artifact);
-                    Path modelDir = repository.getResourceDirectory(artifact);
-                    engine = inferEngine(modelDir);
-                } catch (IOException e) {
-                    logger.warn("Failed to extract model: " + path, e);
-                    return null;
-                }
+                engine = inferEngineFromUrl(url);
             }
             if (engine == null) {
                 return null;
@@ -418,7 +411,21 @@ String mapModelUrl(Path path) {
         }
     }
 
-    private String inferEngine(Path modelDir) {
+    private String inferEngineFromUrl(String modelUrl) {
+        try {
+            Repository repository = Repository.newInstance("modelStore", modelUrl);
+            List<MRL> mrls = repository.getResources();
+            Artifact artifact = mrls.get(0).getDefaultArtifact();
+            repository.prepare(artifact);
+            Path modelDir = repository.getResourceDirectory(artifact);
+            return inferEngine(modelDir, artifact.getName());
+        } catch (IOException e) {
+            logger.warn("Failed to extract model: " + modelUrl, e);
+            return null;
+        }
+    }
+
+    private String inferEngine(Path modelDir, String modelName) {
         Path file = modelDir.resolve("serving.properties");
         if (Files.isRegularFile(file)) {
             Properties prop = new Properties();
@@ -428,38 +435,38 @@ private String inferEngine(Path modelDir) {
                 if (engine != null) {
                     return engine;
                 }
+                // Fall back to the caller-supplied name when the property is not set.
+                modelName = prop.getProperty("modelName", modelName);
             } catch (IOException e) {
                 logger.warn("Failed read serving.properties file", e);
             }
         }
-
-        String dirName = modelDir.toFile().getName();
         if (Files.isDirectory(modelDir.resolve("MAR-INF"))
                 || Files.isRegularFile(modelDir.resolve("model.py"))
-                || Files.isRegularFile(modelDir.resolve(dirName + ".py"))) {
+                || Files.isRegularFile(modelDir.resolve(modelName + ".py"))) {
             // MMS/TorchServe
             return "Python";
-        } else if (Files.isRegularFile(modelDir.resolve(dirName + ".pt"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".pt"))) {
             return "PyTorch";
         } else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
             return "TensorFlow";
-        } else if (Files.isRegularFile(modelDir.resolve(dirName + "-symbol.json"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + "-symbol.json"))) {
             return "MXNet";
-        } else if (Files.isRegularFile(modelDir.resolve(dirName + ".onnx"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".onnx"))) {
             return "OnnxRuntime";
-        } else if (Files.isRegularFile(modelDir.resolve(dirName + ".trt"))
-                || Files.isRegularFile(modelDir.resolve(dirName + ".uff"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".trt"))
+                || Files.isRegularFile(modelDir.resolve(modelName + ".uff"))) {
             return "TensorRT";
-        } else if (Files.isRegularFile(modelDir.resolve(dirName + ".tflite"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".tflite"))) {
             return "TFLite";
         } else if (Files.isRegularFile(modelDir.resolve("model"))
                 || Files.isRegularFile(modelDir.resolve("__model__"))
                 || Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) {
             return "PaddlePaddle";
-        } else if (Files.isRegularFile(modelDir.resolve(dirName + ".json"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".json"))) {
             return "XGBoost";
-        } else if (Files.isRegularFile(modelDir.resolve(dirName + ".dylib"))
-                || Files.isRegularFile(modelDir.resolve(dirName + ".so"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".dylib"))
+                || Files.isRegularFile(modelDir.resolve(modelName + ".so"))) {
             return "DLR";
         }
         logger.warn("Failed to detect engine of the model: " + modelDir);
diff --git a/serving/src/main/java/ai/djl/serving/models/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java
index 510cfabfa..35f54fbce 100644
--- a/serving/src/main/java/ai/djl/serving/models/ModelManager.java
+++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java
@@ -14,6 +14,7 @@
 
 import ai.djl.Device;
 import ai.djl.ModelException;
+import ai.djl.engine.Engine;
 import ai.djl.modality.Input;
 import ai.djl.modality.Output;
 import ai.djl.repository.zoo.Criteria;
@@ -113,7 +114,13 @@ public CompletableFuture registerWorkflow(
                                         .optModelUrls(modelUrl)
                                         .optEngine(engineName);
                         if ("-1".equals(deviceName)) {
-                            logger.info("Loading model {} on {}.", modelName, Device.cpu());
+                            Device device;
+                            if (engineName == null) {
+                                device = Device.cpu();
+                            } else {
+                                device = Engine.getEngine(engineName).defaultDevice();
+                            }
+                            logger.info("Loading model {} on {}.", modelName, device);
                         } else if (deviceName.startsWith("nc")) {
                             logger.info("Loading model {} on {}.", modelName, deviceName);
                             String ncs = deviceName.substring(2);