Skip to content

Commit

Permalink
[serving] Refactor ModelDefinition class (#67)
Browse files Browse the repository at this point in the history
* [serving] Refactor ModelDefinition class

* [serving] Add async model loading on startup

* [serving] Start HTTP listener while model is loading
  • Loading branch information
frankfliu authored Mar 11, 2022
1 parent 028aff4 commit 5c27dba
Show file tree
Hide file tree
Showing 18 changed files with 534 additions and 459 deletions.
59 changes: 40 additions & 19 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ public List<ChannelFuture> start()
futures.add(initializeServer(managementConnector, serverGroup, workerGroup));
}

if (stopped.get()) {
// check whether model loading failed while waiting for models to load
stop();
}

return futures;
}

Expand All @@ -190,10 +195,7 @@ public boolean isRunning() {

/** Stops the model server. */
public void stop() {
if (stopped.get()) {
return;
}

logger.info("Stopping model server.");
stopped.set(true);
for (ChannelFuture future : futures) {
future.channel().close();
Expand Down Expand Up @@ -261,7 +263,6 @@ private ChannelFuture initializeServer(
}

private void initModelStore() throws IOException {
ModelManager.init(configManager);
Set<String> startupModels = ModelManager.getInstance().getStartupModels();

String loadModels = configManager.getLoadModels();
Expand Down Expand Up @@ -311,10 +312,10 @@ private void initModelStore() throws IOException {
String version = null;
String engine = null;
String[] devices = {"-1"};
String modelName;
String workflowName;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
modelName = tokens[0];
workflowName = tokens[0];
if (tokens.length > 1) {
version = tokens[1].isEmpty() ? null : tokens[1];
}
Expand All @@ -336,20 +337,20 @@ private void initModelStore() throws IOException {
.mapToObj(i -> "nc" + i)
.toArray(String[]::new);
}

} else if (!tokens[3].isEmpty()) {
devices = tokens[3].split(";");
}
}
} else {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
workflowName = ModelInfo.inferModelNameFromUrl(modelUrl);
}
if (engine == null) {
engine = inferEngineFromUrl(modelUrl);
}

for (int i = 0; i < devices.length; ++i) {
String modelVersion;
String device = devices[i];
if (devices.length > 1) {
if (version == null) {
modelVersion = "v" + i;
Expand All @@ -359,20 +360,40 @@ private void initModelStore() throws IOException {
} else {
modelVersion = version;
}
CompletableFuture<Workflow> future =
modelManager.registerWorkflow(
modelName,
modelVersion,
ModelInfo modelInfo =
new ModelInfo(
workflowName,
modelUrl,
modelVersion,
engine,
devices[i],
configManager.getBatchSize(),
configManager.getJobQueueSize(),
configManager.getMaxIdleTime(),
configManager.getMaxBatchDelay(),
configManager.getMaxIdleTime());
Workflow workflow = future.join();
modelManager.scaleWorkers(workflow, devices[i], 1, -1);
configManager.getBatchSize());
Workflow workflow = new Workflow(modelInfo);

CompletableFuture<Void> f =
modelManager
.registerWorkflow(workflow, device)
.thenAccept(v -> modelManager.scaleWorkers(workflow, device, 1, -1))
.exceptionally(
t -> {
logger.error("Failed register workflow", t);
// delay 3 seconds, allows REST API to send PING
// response (health check)
try {
Thread.sleep(3000);
} catch (InterruptedException ignore) {
// ignore
}
stop();
return null;
});
if (configManager.waitModelLoading()) {
f.join();
}
}
startupModels.add(modelName);
startupModels.add(workflowName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.util.WlmCapacityException;
import ai.djl.serving.wlm.util.WlmShutdownException;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.util.WlmException;
import ai.djl.serving.workflow.Workflow;
import ai.djl.translate.TranslateException;
import io.netty.channel.ChannelHandlerContext;
Expand Down Expand Up @@ -73,16 +73,9 @@ protected void handleRequest(
throws ModelException {
switch (segments[1]) {
case "ping":
// TODO: Check if it's OK to send other 2xx errors to ALB for "Partial Healthy"
// and "Unhealthy"
ModelManager.getInstance()
.workerStatus()
.thenAccept(
response ->
NettyUtils.sendJsonResponse(
ctx,
new StatusResponse(response),
HttpResponseStatus.OK));
.thenAccept(r -> NettyUtils.sendHttpResponse(ctx, r, true));
break;
case "invocations":
handleInvocations(ctx, req, decoder);
Expand Down Expand Up @@ -171,18 +164,21 @@ private void predict(
String deviceName = input.getProperty("device", "-1");

logger.info("Loading model {} from: {}", workflowName, modelUrl);

modelManager
.registerWorkflow(
ModelInfo modelInfo =
new ModelInfo(
workflowName,
version,
modelUrl,
version,
engineName,
deviceName,
config.getBatchSize(),
config.getJobQueueSize(),
config.getMaxIdleTime(),
config.getMaxBatchDelay(),
config.getMaxIdleTime())
.thenApply(p -> modelManager.scaleWorkers(p, deviceName, 1, -1))
config.getBatchSize());
Workflow wf = new Workflow(modelInfo);

modelManager
.registerWorkflow(wf, deviceName)
.thenApply(p -> modelManager.scaleWorkers(wf, deviceName, 1, -1))
.thenAccept(p -> runJob(modelManager, ctx, p, input));
return;
}
Expand Down Expand Up @@ -243,11 +239,8 @@ void onException(Throwable t, ChannelHandlerContext ctx) {
HttpResponseStatus status;
if (t instanceof TranslateException) {
status = HttpResponseStatus.BAD_REQUEST;
} else if (t instanceof WlmShutdownException) {
logger.info(t.getMessage());
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
} else if (t instanceof WlmCapacityException) {
logger.warn(t.getMessage());
} else if (t instanceof WlmException) {
logger.warn(t.getMessage(), t);
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
} else {
logger.warn("Unexpected error", t);
Expand Down
19 changes: 16 additions & 3 deletions serving/src/main/java/ai/djl/serving/http/ListModelsResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ public List<ModelItem> getModels() {
* @param modelName the model name
* @param version the model version
* @param modelUrl the model url
* @param status the model loading status
*/
public void addModel(String modelName, String version, String modelUrl) {
models.add(new ModelItem(modelName, version, modelUrl));
public void addModel(String modelName, String version, String modelUrl, String status) {
models.add(new ModelItem(modelName, version, modelUrl, status));
}

/** A class that holds model name, version, url and loading status. */
Expand All @@ -70,6 +71,7 @@ public static final class ModelItem {
private String modelName;
private String version;
private String modelUrl;
private String status;

/**
 * Creates an empty {@code ModelItem}; no fields are assigned by this constructor.
 */
public ModelItem() {
}
Expand All @@ -80,11 +82,13 @@ public ModelItem() {}
* @param modelName the model name
* @param version the model version
* @param modelUrl the model url
* @param status the model loading status
*/
public ModelItem(String modelName, String version, String modelUrl) {
public ModelItem(String modelName, String version, String modelUrl, String status) {
this.modelName = modelName;
this.version = version;
this.modelUrl = modelUrl;
this.status = status;
}

/**
Expand Down Expand Up @@ -113,5 +117,14 @@ public String getVersion() {
public String getModelUrl() {
return modelUrl;
}

/**
 * Gets the loading status recorded for this model.
 *
 * @return the model loading status
 */
public String getStatus() {
    return this.status;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.models.Endpoint;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager.WorkerPool;
Expand Down Expand Up @@ -137,9 +138,19 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco
}

for (int i = pageToken; i < last; ++i) {
String modelName = keys.get(i);
for (Workflow m : endpoints.get(modelName).getWorkflows()) {
list.addModel(modelName, m.getVersion(), m.getUrl());
String workflowName = keys.get(i);
for (Workflow workflow : endpoints.get(workflowName).getWorkflows()) {
for (ModelInfo m : workflow.getModels()) {
String status = m.getStatus().toString();
String id = m.getModelId();
String modelName;
if (workflowName.equals(id)) {
modelName = workflowName;
} else {
modelName = workflowName + ':' + id;
}
list.addModel(modelName, workflow.getVersion(), m.getModelUrl(), status);
}
}
}

Expand Down Expand Up @@ -185,40 +196,42 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
Boolean.parseBoolean(
NettyUtils.getParameter(decoder, SYNCHRONOUS_PARAMETER, "true"));

final ModelManager modelManager = ModelManager.getInstance();
CompletableFuture<Workflow> future =
modelManager.registerWorkflow(
ModelInfo modelInfo =
new ModelInfo(
modelName,
version,
modelUrl,
version,
engineName,
deviceName,
batchSize,
ConfigManager.getInstance().getJobQueueSize(),
maxIdleTime,
maxBatchDelay,
maxIdleTime);
batchSize);
Workflow workflow = new Workflow(modelInfo);
final ModelManager modelManager = ModelManager.getInstance();
CompletableFuture<Void> f =
future.thenAccept(
p -> {
for (ModelInfo m : p.getModels()) {
m.configurePool(maxIdleTime)
.configureModelBatch(batchSize, maxBatchDelay);
modelManager.scaleWorkers(m, deviceName, minWorkers, maxWorkers);
}
});

modelManager
.registerWorkflow(workflow, deviceName)
.thenAccept(
v -> {
for (ModelInfo m : workflow.getModels()) {
m.configurePool(maxIdleTime)
.configureModelBatch(batchSize, maxBatchDelay);
modelManager.scaleWorkers(
m, deviceName, minWorkers, maxWorkers);
}
})
.exceptionally(
t -> {
NettyUtils.sendError(ctx, t.getCause());
return null;
});
if (synchronous) {
final String msg = "Model \"" + modelName + "\" registered.";
f = f.thenAccept(m -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)));
f.thenAccept(v -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)));
} else {
String msg = "Model \"" + modelName + "\" registration scheduled.";
NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg), HttpResponseStatus.ACCEPTED);
}

f.exceptionally(
t -> {
NettyUtils.sendError(ctx, t.getCause());
return null;
});
}

private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName, String version)
Expand All @@ -240,6 +253,14 @@ private void handleScaleModel(
if (workflow == null) {
throw new ModelNotFoundException("Model not found: " + modelName);
}

// make sure all models are loaded and ready
for (ModelInfo modelInfo : workflow.getModels()) {
if (modelInfo.getStatus() != ModelInfo.Status.READY) {
throw new ServiceUnavailableException("Model is not ready: " + modelName);
}
}

List<String> msgs = new ArrayList<>();
for (ModelInfo modelInfo : workflow.getModels()) {
WorkerPool pool =
Expand Down
Loading

0 comments on commit 5c27dba

Please sign in to comment.