Skip to content

Commit

Permalink
Fix workflow thread issue pytorch#1511
Browse files Browse the repository at this point in the history
  • Loading branch information
maaquib committed Apr 6, 2022
1 parent 375f004 commit 3327be2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ public DagExecutor(Dag dag) {
public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoSortedList) {

CompletionService<NodeOutput> executorCompletionService = null;
ExecutorService executorService = null;
if (topoSortedList == null) {
ExecutorService executorService = Executors.newFixedThreadPool(4);
executorService = Executors.newFixedThreadPool(4);
executorCompletionService = new ExecutorCompletionService<>(executorService);
}

Expand Down Expand Up @@ -90,6 +91,8 @@ public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoS
logger.error(e.getMessage());
String[] error = e.getMessage().split(":");
throw new InternalServerException(error[error.length - 1]); // NOPMD
} finally {
executorService.shutdownNow();
}
} else {
for (String name : readyToExecute) {
Expand Down Expand Up @@ -150,7 +153,7 @@ private NodeOutput invokeModel(
InterruptedException {
try {

logger.error(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt));
logger.info(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt));
CompletableFuture<byte[]> respFuture = new CompletableFuture<>();
RestJob job = ApiUtils.addRESTInferenceJob(null, workflowModel.getName(), null, input);
job.setResponsePromise(respFuture);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ public WorkFlow getWorkflow(String workflowName) {
}

public void predict(ChannelHandlerContext ctx, String wfName, RequestInput input)
throws WorkflowNotFoundException {
throws WorkflowNotFoundException, WorkflowException {
WorkFlow wf = workflowMap.get(wfName);
if (wf != null) {
DagExecutor dagExecutor = new DagExecutor(wf.getDag());
Expand Down Expand Up @@ -420,6 +420,11 @@ public void predict(ChannelHandlerContext ctx, String wfName, RequestInput input
error[error.length - 1].strip()));
return null;
});
try {
predictionFuture.get();
} catch (ExecutionException | InterruptedException e) {
throw new WorkflowException("Workflow failed ", e);
}
} else {
throw new WorkflowNotFoundException("Workflow not found: " + wfName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void handleRequest(

private void handlePredictions(
ChannelHandlerContext ctx, FullHttpRequest req, String[] segments)
throws WorkflowNotFoundException {
throws WorkflowNotFoundException, WorkflowException {
RequestInput input = parseRequest(ctx, req);
logger.info(input.toString());
String wfName = segments[2];
Expand Down

0 comments on commit 3327be2

Please sign in to comment.