[serving] Refactor ModelDefinition class #67

Merged · 3 commits · Mar 11, 2022
serving/src/main/java/ai/djl/serving/ModelServer.java (59 changes: 40 additions & 19 deletions)
@@ -176,6 +176,11 @@ public List<ChannelFuture> start()
             futures.add(initializeServer(managementConnector, serverGroup, workerGroup));
         }

+        if (stopped.get()) {
+            // check if model load failed in wait loading model case
+            stop();
+        }
+
         return futures;
     }

@@ -190,10 +195,7 @@ public boolean isRunning() {

     /** Stops the model server. */
     public void stop() {
-        if (stopped.get()) {
-            return;
-        }
-
         logger.info("Stopping model server.");
         stopped.set(true);
         for (ChannelFuture future : futures) {
             future.channel().close();
@@ -261,7 +263,6 @@ private ChannelFuture initializeServer(
     }

     private void initModelStore() throws IOException {
-        ModelManager.init(configManager);
         Set<String> startupModels = ModelManager.getInstance().getStartupModels();

         String loadModels = configManager.getLoadModels();
@@ -311,10 +312,10 @@ private void initModelStore() throws IOException {
             String version = null;
             String engine = null;
             String[] devices = {"-1"};
-            String modelName;
+            String workflowName;
             if (endpoint != null) {
                 String[] tokens = endpoint.split(":", -1);
-                modelName = tokens[0];
+                workflowName = tokens[0];
                 if (tokens.length > 1) {
                     version = tokens[1].isEmpty() ? null : tokens[1];
                 }
@@ -336,20 +337,20 @@ private void initModelStore() throws IOException {
                                         .mapToObj(i -> "nc" + i)
                                         .toArray(String[]::new);
                         }
-
                     } else if (!tokens[3].isEmpty()) {
                         devices = tokens[3].split(";");
                     }
                 }
             } else {
-                modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
+                workflowName = ModelInfo.inferModelNameFromUrl(modelUrl);
             }
             if (engine == null) {
                 engine = inferEngineFromUrl(modelUrl);
             }

             for (int i = 0; i < devices.length; ++i) {
                 String modelVersion;
+                String device = devices[i];
                 if (devices.length > 1) {
                     if (version == null) {
                         modelVersion = "v" + i;
@@ -359,20 +360,40 @@ private void initModelStore() throws IOException {
                 } else {
                     modelVersion = version;
                 }
-                CompletableFuture<Workflow> future =
-                        modelManager.registerWorkflow(
-                                modelName,
-                                modelVersion,
+                ModelInfo modelInfo =
+                        new ModelInfo(
+                                workflowName,
                                 modelUrl,
+                                modelVersion,
                                 engine,
                                 devices[i],
-                                configManager.getBatchSize(),
+                                configManager.getJobQueueSize(),
+                                configManager.getMaxIdleTime(),
                                 configManager.getMaxBatchDelay(),
-                                configManager.getMaxIdleTime());
-                Workflow workflow = future.join();
-                modelManager.scaleWorkers(workflow, devices[i], 1, -1);
+                                configManager.getBatchSize());
+                Workflow workflow = new Workflow(modelInfo);
+
+                CompletableFuture<Void> f =
+                        modelManager
+                                .registerWorkflow(workflow, device)
+                                .thenAccept(v -> modelManager.scaleWorkers(workflow, device, 1, -1))
+                                .exceptionally(
+                                        t -> {
+                                            logger.error("Failed register workflow", t);
+                                            // delay 3 seconds, allows REST API to send PING
+                                            // response (health check)
+                                            try {
+                                                Thread.sleep(3000);
+                                            } catch (InterruptedException ignore) {
+                                                // ignore
+                                            }
+                                            stop();
+                                            return null;
+                                        });
+                if (configManager.waitModelLoading()) {
+                    f.join();
+                }
             }
-            startupModels.add(modelName);
+            startupModels.add(workflowName);
         }
     }
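
Note: taken together, the ModelServer hunks replace the old argument-list registerWorkflow(...) overload with an explicit ModelInfo wrapped in a Workflow, and hang scaling and failure handling off the returned future. Below is a minimal sketch of the new registration flow, using only the API surface visible in this diff; the model name, URL, and configuration values are made-up placeholders.

import ai.djl.serving.models.ModelManager;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.workflow.Workflow;

import java.util.concurrent.CompletableFuture;

public final class RegistrationSketch {

    public static void main(String[] args) {
        // Constructor argument order as it appears in this diff: name, url,
        // version, engine, device, jobQueueSize, maxIdleTime, maxBatchDelay,
        // batchSize. All values here are illustrative placeholders.
        ModelInfo modelInfo =
                new ModelInfo(
                        "mlp",
                        "https://example.com/mlp.zip",
                        "v1",
                        "PyTorch",
                        "-1",
                        1000,
                        60,
                        100,
                        1);

        // A Workflow now wraps the model(s) and is the unit of registration.
        Workflow workflow = new Workflow(modelInfo);

        ModelManager modelManager = ModelManager.getInstance();
        CompletableFuture<Void> f =
                modelManager
                        .registerWorkflow(workflow, "-1")
                        .thenAccept(v -> modelManager.scaleWorkers(workflow, "-1", 1, -1));

        // Block until loading completes, as initModelStore() does when
        // configManager.waitModelLoading() is set.
        f.join();
    }
}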

@@ -20,8 +20,8 @@
 import ai.djl.serving.models.ModelManager;
 import ai.djl.serving.util.ConfigManager;
 import ai.djl.serving.util.NettyUtils;
-import ai.djl.serving.wlm.util.WlmCapacityException;
-import ai.djl.serving.wlm.util.WlmShutdownException;
+import ai.djl.serving.wlm.ModelInfo;
+import ai.djl.serving.wlm.util.WlmException;
 import ai.djl.serving.workflow.Workflow;
 import ai.djl.translate.TranslateException;
 import io.netty.channel.ChannelHandlerContext;
@@ -73,16 +73,9 @@ protected void handleRequest(
             throws ModelException {
         switch (segments[1]) {
             case "ping":
-                // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy"
-                // and "Unhealthy"
                 ModelManager.getInstance()
                         .workerStatus()
-                        .thenAccept(
-                                response ->
-                                        NettyUtils.sendJsonResponse(
-                                                ctx,
-                                                new StatusResponse(response),
-                                                HttpResponseStatus.OK));
+                        .thenAccept(r -> NettyUtils.sendHttpResponse(ctx, r, true));
                 break;
             case "invocations":
                 handleInvocations(ctx, req, decoder);
@@ -171,18 +164,21 @@ private void predict(
             String deviceName = input.getProperty("device", "-1");

             logger.info("Loading model {} from: {}", workflowName, modelUrl);

-            modelManager
-                    .registerWorkflow(
+            ModelInfo modelInfo =
+                    new ModelInfo(
                             workflowName,
-                            version,
                             modelUrl,
+                            version,
                             engineName,
                             deviceName,
-                            config.getBatchSize(),
+                            config.getJobQueueSize(),
+                            config.getMaxIdleTime(),
                             config.getMaxBatchDelay(),
-                            config.getMaxIdleTime())
-                    .thenApply(p -> modelManager.scaleWorkers(p, deviceName, 1, -1))
+                            config.getBatchSize());
+            Workflow wf = new Workflow(modelInfo);
+
+            modelManager
+                    .registerWorkflow(wf, deviceName)
+                    .thenApply(p -> modelManager.scaleWorkers(wf, deviceName, 1, -1))
                     .thenAccept(p -> runJob(modelManager, ctx, p, input));
             return;
         }
@@ -243,11 +239,8 @@ void onException(Throwable t, ChannelHandlerContext ctx) {
         HttpResponseStatus status;
         if (t instanceof TranslateException) {
             status = HttpResponseStatus.BAD_REQUEST;
-        } else if (t instanceof WlmShutdownException) {
-            logger.info(t.getMessage());
-            status = HttpResponseStatus.SERVICE_UNAVAILABLE;
-        } else if (t instanceof WlmCapacityException) {
-            logger.warn(t.getMessage());
+        } else if (t instanceof WlmException) {
+            logger.warn(t.getMessage(), t);
             status = HttpResponseStatus.SERVICE_UNAVAILABLE;
         } else {
             logger.warn("Unexpected error", t);
@@ -59,9 +59,10 @@ public List<ModelItem> getModels() {
      * @param modelName the model name
      * @param version the model version
      * @param modelUrl the model url
+     * @param status the model loading status
      */
-    public void addModel(String modelName, String version, String modelUrl) {
-        models.add(new ModelItem(modelName, version, modelUrl));
+    public void addModel(String modelName, String version, String modelUrl, String status) {
+        models.add(new ModelItem(modelName, version, modelUrl, status));
     }

     /** A class that holds model name and url. */
@@ -70,6 +71,7 @@ public static final class ModelItem {

         private String modelName;
         private String version;
         private String modelUrl;
+        private String status;

         /** Constructs a new {@code ModelItem} instance. */
         public ModelItem() {}
@@ -80,11 +82,13 @@ public ModelItem() {}
          * @param modelName the model name
          * @param version the model version
          * @param modelUrl the model url
+         * @param status the model loading status
          */
-        public ModelItem(String modelName, String version, String modelUrl) {
+        public ModelItem(String modelName, String version, String modelUrl, String status) {
             this.modelName = modelName;
             this.version = version;
             this.modelUrl = modelUrl;
+            this.status = status;
         }

         /**
@@ -113,5 +117,14 @@ public String getVersion() {
         public String getModelUrl() {
             return modelUrl;
         }
+
+        /**
+         * Returns the model loading status.
+         *
+         * @return the model loading status
+         */
+        public String getStatus() {
+            return status;
+        }
     }
 }
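
Note: for API clients, each listed entry now reports a loading status next to the name, version, and URL. A short sketch of the new shape; the enclosing response class's name is not visible in this view, so ListModelsResponse below is an assumption, and the values are made up.

// ListModelsResponse is an assumed name for the response class shown above.
ListModelsResponse list = new ListModelsResponse();
list.addModel("mlp", "v1", "https://example.com/mlp.zip", "READY");

for (ListModelsResponse.ModelItem item : list.getModels()) {
    // prints: https://example.com/mlp.zip -> READY
    System.out.println(item.getModelUrl() + " -> " + item.getStatus());
}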
@@ -16,6 +16,7 @@
 import ai.djl.repository.zoo.ModelNotFoundException;
 import ai.djl.serving.models.Endpoint;
 import ai.djl.serving.models.ModelManager;
+import ai.djl.serving.util.ConfigManager;
 import ai.djl.serving.util.NettyUtils;
 import ai.djl.serving.wlm.ModelInfo;
 import ai.djl.serving.wlm.WorkLoadManager.WorkerPool;
@@ -137,9 +138,19 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco
         }

         for (int i = pageToken; i < last; ++i) {
-            String modelName = keys.get(i);
-            for (Workflow m : endpoints.get(modelName).getWorkflows()) {
-                list.addModel(modelName, m.getVersion(), m.getUrl());
+            String workflowName = keys.get(i);
+            for (Workflow workflow : endpoints.get(workflowName).getWorkflows()) {
+                for (ModelInfo m : workflow.getModels()) {
+                    String status = m.getStatus().toString();
+                    String id = m.getModelId();
+                    String modelName;
+                    if (workflowName.equals(id)) {
+                        modelName = workflowName;
+                    } else {
+                        modelName = workflowName + ':' + id;
+                    }
+                    list.addModel(modelName, workflow.getVersion(), m.getModelUrl(), status);
+                }
             }
         }
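
Note: since a workflow can now hold several models, the listing emits one entry per model and qualifies the displayed name with the model id when the two differ. The naming rule from the loop above, pulled out into a hypothetical helper for illustration:

// Hypothetical helper; mirrors the if/else in the hunk above.
static String displayName(String workflowName, String modelId) {
    return workflowName.equals(modelId) ? workflowName : workflowName + ':' + modelId;
}

// displayName("resnet", "resnet")     -> "resnet"
// displayName("resnet", "preprocess") -> "resnet:preprocess"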

@@ -185,40 +196,42 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
                         Boolean.parseBoolean(
                                 NettyUtils.getParameter(decoder, SYNCHRONOUS_PARAMETER, "true"));

-        final ModelManager modelManager = ModelManager.getInstance();
-        CompletableFuture<Workflow> future =
-                modelManager.registerWorkflow(
+        ModelInfo modelInfo =
+                new ModelInfo(
                         modelName,
-                        version,
                         modelUrl,
+                        version,
                         engineName,
                         deviceName,
-                        batchSize,
+                        ConfigManager.getInstance().getJobQueueSize(),
+                        maxIdleTime,
                         maxBatchDelay,
-                        maxIdleTime);
+                        batchSize);
+        Workflow workflow = new Workflow(modelInfo);
+        final ModelManager modelManager = ModelManager.getInstance();
         CompletableFuture<Void> f =
-                future.thenAccept(
-                        p -> {
-                            for (ModelInfo m : p.getModels()) {
-                                m.configurePool(maxIdleTime)
-                                        .configureModelBatch(batchSize, maxBatchDelay);
-                                modelManager.scaleWorkers(m, deviceName, minWorkers, maxWorkers);
-                            }
-                        });
-
+                modelManager
+                        .registerWorkflow(workflow, deviceName)
+                        .thenAccept(
+                                v -> {
+                                    for (ModelInfo m : workflow.getModels()) {
+                                        m.configurePool(maxIdleTime)
+                                                .configureModelBatch(batchSize, maxBatchDelay);
+                                        modelManager.scaleWorkers(
+                                                m, deviceName, minWorkers, maxWorkers);
+                                    }
+                                })
+                        .exceptionally(
+                                t -> {
+                                    NettyUtils.sendError(ctx, t.getCause());
+                                    return null;
+                                });
         if (synchronous) {
             final String msg = "Model \"" + modelName + "\" registered.";
-            f = f.thenAccept(m -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)));
+            f.thenAccept(v -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)));
         } else {
             String msg = "Model \"" + modelName + "\" registration scheduled.";
             NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg), HttpResponseStatus.ACCEPTED);
         }
-
-        f.exceptionally(
-                t -> {
-                    NettyUtils.sendError(ctx, t.getCause());
-                    return null;
-                });
     }

     private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName, String version)
@@ -240,6 +253,14 @@ private void handleScaleModel(
         if (workflow == null) {
             throw new ModelNotFoundException("Model not found: " + modelName);
         }
+
+        // make sure all models are loaded and ready
+        for (ModelInfo modelInfo : workflow.getModels()) {
+            if (modelInfo.getStatus() != ModelInfo.Status.READY) {
+                throw new ServiceUnavailableException("Model is not ready: " + modelName);
+            }
+        }
+
         List<String> msgs = new ArrayList<>();
         for (ModelInfo modelInfo : workflow.getModels()) {
             WorkerPool pool =