Skip to content

Commit

Permalink
[serving] Fix unregister model regression (#1101)
Browse files Browse the repository at this point in the history
1. Update register model API
2. Improve error message for invalid input data
3. Fix bug in unregister model
4. Update default number of workers

Change-Id: I5c293d26f830f5fe37708ee039ad9918c4fa5c2e
  • Loading branch information
frankfliu authored Jul 15, 2021
1 parent 308e149 commit 1763a2d
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 35 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/inference/Predictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
return ret;
} catch (EngineException e) {
throw new TranslateException(e);
} catch (RuntimeException e) {
} catch (TranslateException | RuntimeException e) {
throw e;
} catch (Exception e) {
throw new TranslateException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ public Batchifier getBatchifier() {

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) {
public NDList processInput(TranslatorContext ctx, Input input) throws TranslateException {
ctx.setAttachment("input", input);
PairList<String, byte[]> inputs = input.getContent();
byte[] data = inputs.get("data");
Expand All @@ -294,7 +294,11 @@ public NDList processInput(TranslatorContext ctx, Input input) {
data = input.getContent().valueAt(0);
}
NDManager manager = ctx.getNDManager();
return NDList.decode(manager, data);
try {
return NDList.decode(manager, data);
} catch (IllegalArgumentException e) {
throw new TranslateException("Input is not a NDList data type", e);
}
}

/** {@inheritDoc} */
Expand Down
4 changes: 2 additions & 2 deletions serving/serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public static void main(String[] args) {
CommandLine cmd = parser.parse(options, args, null, false);
Arguments arguments = new Arguments(cmd);
if (arguments.hasHelp()) {
printHelp("model-server [OPTIONS]", options);
printHelp("djl-serving [OPTIONS]", options);
return;
}

Expand Down Expand Up @@ -340,7 +340,7 @@ private void initModelStore() throws IOException {
configManager.getMaxBatchDelay(),
configManager.getMaxIdleTime());
ModelInfo modelInfo = future.join();
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(workers, workers));
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, workers));
startupModels.add(modelName);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

import ai.djl.ModelException;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Endpoint;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.util.ArrayList;
import java.util.Collections;
Expand All @@ -38,8 +40,6 @@ public class ManagementRequestHandler extends HttpRequestHandler {

/** HTTP Parameter "synchronous". */
private static final String SYNCHRONOUS_PARAMETER = "synchronous";
/** HTTP Parameter "initial_workers". */
private static final String INITIAL_WORKERS_PARAMETER = "initial_workers";
/** HTTP Parameter "url". */
private static final String URL_PARAMETER = "url";
/** HTTP Parameter "batch_size". */
Expand All @@ -48,8 +48,8 @@ public class ManagementRequestHandler extends HttpRequestHandler {
private static final String MODEL_NAME_PARAMETER = "model_name";
/** HTTP Parameter "model_version". */
private static final String MODEL_VERSION_PARAMETER = "model_version";
/** HTTP Parameter "engine_name". */
private static final String ENGINE_NAME_PARAMETER = "engine_name";
/** HTTP Parameter "engine". */
private static final String ENGINE_NAME_PARAMETER = "engine";
/** HTTP Parameter "gpu_id". */
private static final String GPU_ID_PARAMETER = "gpu_id";
/** HTTP Parameter "max_batch_delay". */
Expand Down Expand Up @@ -167,8 +167,9 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1);
int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100);
int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME__PARAMETER, 60);
final int initialWorkers =
NettyUtils.getIntParameter(decoder, INITIAL_WORKERS_PARAMETER, 1);
int minWorkers = NettyUtils.getIntParameter(decoder, MIN_WORKER_PARAMETER, 1);
int defaultWorkers = ConfigManager.getInstance().getDefaultWorkers();
int maxWorkers = NettyUtils.getIntParameter(decoder, MAX_WORKER_PARAMETER, defaultWorkers);
boolean synchronous =
Boolean.parseBoolean(
NettyUtils.getParameter(decoder, SYNCHRONOUS_PARAMETER, "true"));
Expand All @@ -188,16 +189,16 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
future.thenAccept(
m ->
modelManager.triggerModelUpdated(
m.scaleWorkers(initialWorkers, initialWorkers)
.configurePool(maxIdleTime, maxBatchDelay)
.configureModelBatch(batchSize)));
m.scaleWorkers(minWorkers, maxWorkers)
.configurePool(maxIdleTime)
.configureModelBatch(batchSize, maxBatchDelay)));

if (synchronous) {
final String msg = "Model \"" + modelName + "\" registered.";
f = f.thenAccept(m -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)));
} else {
String msg = "Model \"" + modelName + "\" registration scheduled.";
NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg), HttpResponseStatus.ACCEPTED);
}

f.exceptionally(
Expand Down Expand Up @@ -239,14 +240,17 @@ private void handleScaleModel(
int maxIdleTime =
NettyUtils.getIntParameter(
decoder, MAX_IDLE_TIME__PARAMETER, modelInfo.getMaxIdleTime());
int batchSize =
NettyUtils.getIntParameter(
decoder, BATCH_SIZE_PARAMETER, modelInfo.getBatchSize());
int maxBatchDelay =
NettyUtils.getIntParameter(
decoder, MAX_BATCH_DELAY_PARAMETER, modelInfo.getMaxBatchDelay());

modelInfo =
modelInfo
.scaleWorkers(minWorkers, maxWorkers)
.configurePool(maxIdleTime, maxBatchDelay);
modelInfo
.scaleWorkers(minWorkers, maxWorkers)
.configurePool(maxIdleTime)
.configureModelBatch(batchSize, maxBatchDelay);
modelManager.triggerModelUpdated(modelInfo);

String msg =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,15 @@ public void sendResponse(List<Output> outputs) {
jobs.clear();
}

/** Sends an internal server error. */
public void sendError() {
/**
* Sends an error response to client.
*
* @param status the HTTP status
* @param error the exception
*/
public void sendError(HttpResponseStatus status, Throwable error) {
for (Job job : jobs) {
job.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, "Internal server error");
job.sendError(status, error);
}
jobs.clear();
}
Expand Down
7 changes: 3 additions & 4 deletions serving/serving/src/main/java/ai/djl/serving/wlm/Job.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.http.InternalServerException;
import ai.djl.serving.util.NettyUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
Expand Down Expand Up @@ -118,17 +117,17 @@ public void sendOutput(Output output) {
* Sends error to the client.
*
* @param status the HTTP status
* @param error the error message
* @param error the exception
*/
public void sendError(HttpResponseStatus status, String error) {
public void sendError(HttpResponseStatus status, Throwable error) {
/*
* We can load the models based on the configuration file.Since this Job is
* not driven by the external connections, we could have a empty context for
* this job. We shouldn't try to send a response to ctx if this is not triggered
* by external clients.
*/
if (ctx != null) {
NettyUtils.sendError(ctx, status, new InternalServerException(error));
NettyUtils.sendError(ctx, status, error);
}

logger.debug(
Expand Down
10 changes: 5 additions & 5 deletions serving/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ public ModelInfo(
* triggerUpdates in the {@code ModelManager} using this new model.
*
* @param batchSize the batchSize to set
* @param maxBatchDelay maximum time to wait for a free space in worker queue after scaling up
* workers before giving up to offer the job to the queue.
* @return new configured ModelInfo.
*/
public ModelInfo configureModelBatch(int batchSize) {
public ModelInfo configureModelBatch(int batchSize, int maxBatchDelay) {
this.batchSize = batchSize;
this.maxBatchDelay = maxBatchDelay;
return this;
}

Expand All @@ -103,13 +106,10 @@ public ModelInfo scaleWorkers(int minWorkers, int maxWorkers) {
* model.
*
* @param maxIdleTime time a WorkerThread can be idle before scaling down this worker.
* @param maxBatchDelay maximum time to wait for a free space in worker queue after scaling up
* workers before giving up to offer the job to the queue.
* @return new configured ModelInfo.
*/
public ModelInfo configurePool(int maxIdleTime, int maxBatchDelay) {
public ModelInfo configurePool(int maxIdleTime) {
this.maxIdleTime = maxIdleTime;
this.maxBatchDelay = maxBatchDelay;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ public boolean unregisterModel(String modelName, String version) {
startupModels.remove(modelName);
m.close();
}
endpoint.getModels().clear();
logger.info("Model {} unregistered.", modelName);
} else {
ModelInfo model = endpoint.remove(version);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.serving.http.InternalServerException;
import ai.djl.translate.TranslateException;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -74,12 +76,11 @@ public void run() {
aggregator.sendResponse(reply);
} catch (TranslateException e) {
logger.warn("Failed to predict", e);
aggregator.sendError();
aggregator.sendError(HttpResponseStatus.BAD_REQUEST, e);
}
}
req = null;
}

} catch (InterruptedException e) {
logger.debug("Shutting down the thread .. Scaling down.");
} catch (Throwable t) {
Expand All @@ -89,7 +90,8 @@ public void run() {
currentThread.set(null);
shutdown(WorkerState.WORKER_STOPPED);
if (req != null) {
aggregator.sendError();
Exception e = new InternalServerException("Server shutting down");
aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
}
}
}
Expand Down Expand Up @@ -120,7 +122,8 @@ public void shutdown(WorkerState state) {
Thread thread = currentThread.getAndSet(null);
if (thread != null) {
thread.interrupt();
aggregator.sendError();
Exception e = new InternalServerException("Server shutting down");
aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
}
predictor.close();
}
Expand Down

0 comments on commit 1763a2d

Please sign in to comment.