[serving] Detect engine when loading a model from URL #63

Merged: 1 commit, Feb 15, 2022
56 changes: 31 additions & 25 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -344,6 +344,9 @@ private void initModelStore() throws IOException {
     } else {
         modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
     }
+    if (engine == null) {
+        engine = inferEngineFromUrl(modelUrl);
+    }
 
     for (int i = 0; i < devices.length; ++i) {
         String modelVersion;
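For context: the three added lines give model-store entries a fallback — when no engine is configured for a URL, the server now derives one from the model itself. A minimal, self-contained sketch of that fallback logic; detectEngineFromUrl is a hypothetical stand-in for the server's private inferEngineFromUrl, and the URL is a placeholder.

public final class EngineFallbackDemo {

    public static void main(String[] args) {
        String configuredEngine = null; // no engine given for this model-store entry
        String modelUrl = "https://example.com/models/resnet18.tar.gz"; // placeholder URL

        // A configured engine wins; otherwise fall back to URL-based detection.
        String engine = configuredEngine != null ? configuredEngine : detectEngineFromUrl(modelUrl);
        System.out.println("Engine used for registration: " + engine);
    }

    // Hypothetical stand-in: the real inferEngineFromUrl downloads the archive
    // and probes the extracted files (see the method added further down).
    private static String detectEngineFromUrl(String modelUrl) {
        return modelUrl.endsWith(".tar.gz") ? "PyTorch" : null; // toy rule for this demo only
    }
}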
@@ -392,19 +395,9 @@ String mapModelUrl(Path path) {
     String modelName = ModelInfo.inferModelNameFromUrl(url);
     String engine;
     if (Files.isDirectory(path)) {
-        engine = inferEngine(path);
+        engine = inferEngine(path, path.toFile().getName());
     } else {
-        try {
-            Repository repository = Repository.newInstance("modelStore", url);
-            List<MRL> mrls = repository.getResources();
-            Artifact artifact = mrls.get(0).getDefaultArtifact();
-            repository.prepare(artifact);
-            Path modelDir = repository.getResourceDirectory(artifact);
-            engine = inferEngine(modelDir);
-        } catch (IOException e) {
-            logger.warn("Failed to extract model: " + path, e);
-            return null;
-        }
+        engine = inferEngineFromUrl(url);
     }
     if (engine == null) {
         return null;
@@ -418,7 +411,21 @@ String mapModelUrl(Path path) {
     }
 }
 
-private String inferEngine(Path modelDir) {
+private String inferEngineFromUrl(String modelUrl) {
+    try {
+        Repository repository = Repository.newInstance("modelStore", modelUrl);
+        List<MRL> mrls = repository.getResources();
+        Artifact artifact = mrls.get(0).getDefaultArtifact();
+        repository.prepare(artifact);
+        Path modelDir = repository.getResourceDirectory(artifact);
+        return inferEngine(modelDir, artifact.getName());
+    } catch (IOException e) {
+        logger.warn("Failed to extract model: " + modelUrl, e);
+        return null;
+    }
+}
+
+private String inferEngine(Path modelDir, String modelName) {
     Path file = modelDir.resolve("serving.properties");
     if (Files.isRegularFile(file)) {
         Properties prop = new Properties();
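The new inferEngineFromUrl reuses DJL's Repository abstraction to download and unpack the model archive before probing it. Below is a standalone sketch of that same flow, using only the calls visible in the diff; the URL is a placeholder and errors are simply rethrown rather than logged.

import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;

import java.io.IOException;
import java.nio.file.Path;
import java.util.List;

public final class RepositoryProbeDemo {

    public static void main(String[] args) throws IOException {
        String modelUrl = "https://example.com/models/resnet18.tar.gz"; // placeholder URL

        Repository repository = Repository.newInstance("modelStore", modelUrl);
        List<MRL> mrls = repository.getResources();
        Artifact artifact = mrls.get(0).getDefaultArtifact();
        repository.prepare(artifact); // downloads and extracts the archive into the local cache
        Path modelDir = repository.getResourceDirectory(artifact);

        System.out.println("Model extracted to: " + modelDir);
        // ModelServer.inferEngine(modelDir, artifact.getName()) then probes this directory.
    }
}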
@@ -428,38 +435,37 @@ private String inferEngine(Path modelDir) {
             if (engine != null) {
                 return engine;
             }
+            modelName = prop.getProperty("modelName");
         } catch (IOException e) {
             logger.warn("Failed read serving.properties file", e);
         }
     }
 
-    String dirName = modelDir.toFile().getName();
     if (Files.isDirectory(modelDir.resolve("MAR-INF"))
             || Files.isRegularFile(modelDir.resolve("model.py"))
-            || Files.isRegularFile(modelDir.resolve(dirName + ".py"))) {
+            || Files.isRegularFile(modelDir.resolve(modelName + ".py"))) {
         // MMS/TorchServe
         return "Python";
-    } else if (Files.isRegularFile(modelDir.resolve(dirName + ".pt"))) {
+    } else if (Files.isRegularFile(modelDir.resolve(modelName + ".pt"))) {
         return "PyTorch";
     } else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
         return "TensorFlow";
-    } else if (Files.isRegularFile(modelDir.resolve(dirName + "-symbol.json"))) {
+    } else if (Files.isRegularFile(modelDir.resolve(modelName + "-symbol.json"))) {
         return "MXNet";
-    } else if (Files.isRegularFile(modelDir.resolve(dirName + ".onnx"))) {
+    } else if (Files.isRegularFile(modelDir.resolve(modelName + ".onnx"))) {
        return "OnnxRuntime";
-    } else if (Files.isRegularFile(modelDir.resolve(dirName + ".trt"))
-            || Files.isRegularFile(modelDir.resolve(dirName + ".uff"))) {
+    } else if (Files.isRegularFile(modelDir.resolve(modelName + ".trt"))
+            || Files.isRegularFile(modelDir.resolve(modelName + ".uff"))) {
         return "TensorRT";
-    } else if (Files.isRegularFile(modelDir.resolve(dirName + ".tflite"))) {
+    } else if (Files.isRegularFile(modelDir.resolve(modelName + ".tflite"))) {
         return "TFLite";
     } else if (Files.isRegularFile(modelDir.resolve("model"))
             || Files.isRegularFile(modelDir.resolve("__model__"))
             || Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) {
         return "PaddlePaddle";
-    } else if (Files.isRegularFile(modelDir.resolve(dirName + ".json"))) {
+    } else if (Files.isRegularFile(modelDir.resolve(modelName + ".json"))) {
         return "XGBoost";
-    } else if (Files.isRegularFile(modelDir.resolve(dirName + ".dylib"))
-            || Files.isRegularFile(modelDir.resolve(dirName + ".so"))) {
+    } else if (Files.isRegularFile(modelDir.resolve(modelName + ".dylib"))
+            || Files.isRegularFile(modelDir.resolve(modelName + ".so"))) {
         return "DLR";
     }
     logger.warn("Failed to detect engine of the model: " + modelDir);
@@ -14,6 +14,7 @@

 import ai.djl.Device;
 import ai.djl.ModelException;
+import ai.djl.engine.Engine;
 import ai.djl.modality.Input;
 import ai.djl.modality.Output;
 import ai.djl.repository.zoo.Criteria;
@@ -113,7 +114,13 @@ public CompletableFuture<Workflow> registerWorkflow(
                 .optModelUrls(modelUrl)
                 .optEngine(engineName);
     if ("-1".equals(deviceName)) {
-        logger.info("Loading model {} on {}.", modelName, Device.cpu());
+        Device device;
+        if (engineName == null) {
+            device = Device.cpu();
+        } else {
+            device = Engine.getEngine(engineName).defaultDevice();
+        }
+        logger.info("Loading model {} on {}.", modelName, device);
     } else if (deviceName.startsWith("nc")) {
         logger.info("Loading model {} on {}.", modelName, deviceName);
         String ncs = deviceName.substring(2);
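In the hunk above, only the logged device for deviceName "-1" changes: instead of always reporting Device.cpu(), the log now reports the configured engine's default device when the engine is known. A sketch of that resolution, using the same calls shown in the diff; the engine name is a placeholder and assumes the corresponding engine is on the classpath.

import ai.djl.Device;
import ai.djl.engine.Engine;

public final class DefaultDeviceDemo {

    public static void main(String[] args) {
        String engineName = "PyTorch"; // placeholder; may be null when no engine was configured

        Device device;
        if (engineName == null) {
            device = Device.cpu();
        } else {
            // Use the engine's notion of a default device, e.g. gpu(0) on a GPU host.
            device = Engine.getEngine(engineName).defaultDevice();
        }
        System.out.println("Would load the model on " + device);
    }
}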