Skip to content

Commit

Permalink
Fix workflow thread issue pytorch#1511
Browse files: browse the repository at this point in the history
  • Loading branch information
maaquib committed Apr 6, 2022
1 parent 375f004 commit 85a8fac
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.ThreadFactory;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.http.InternalServerException;
Expand All @@ -41,8 +43,11 @@ public DagExecutor(Dag dag) {
public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoSortedList) {

CompletionService<NodeOutput> executorCompletionService = null;
ExecutorService executorService = null;
if (topoSortedList == null) {
ExecutorService executorService = Executors.newFixedThreadPool(4);
ThreadFactory namedThreadFactory =
new ThreadFactoryBuilder().setNameFormat("wf-debug-thread-%d").build();
executorService = Executors.newFixedThreadPool(4, namedThreadFactory);
executorCompletionService = new ExecutorCompletionService<>(executorService);
}

Expand Down Expand Up @@ -140,6 +145,9 @@ public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoS
}
}
}
if (executorService != null) {
executorService.shutdown();
}

return leafOutputs;
}
Expand All @@ -150,7 +158,7 @@ private NodeOutput invokeModel(
InterruptedException {
try {

logger.error(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt));
logger.info(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt));
CompletableFuture<byte[]> respFuture = new CompletableFuture<>();
RestJob job = ApiUtils.addRESTInferenceJob(null, workflowModel.getName(), null, input);
job.setResponsePromise(respFuture);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ public WorkFlow getWorkflow(String workflowName) {
}

public void predict(ChannelHandlerContext ctx, String wfName, RequestInput input)
throws WorkflowNotFoundException {
throws WorkflowNotFoundException, WorkflowException {
WorkFlow wf = workflowMap.get(wfName);
if (wf != null) {
DagExecutor dagExecutor = new DagExecutor(wf.getDag());
Expand Down Expand Up @@ -420,6 +420,11 @@ public void predict(ChannelHandlerContext ctx, String wfName, RequestInput input
error[error.length - 1].strip()));
return null;
});
try {
predictionFuture.get();
} catch (ExecutionException | InterruptedException e) {
throw new WorkflowException("Workflow failed ", e);
}
} else {
throw new WorkflowNotFoundException("Workflow not found: " + wfName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void handleRequest(

private void handlePredictions(
ChannelHandlerContext ctx, FullHttpRequest req, String[] segments)
throws WorkflowNotFoundException {
throws WorkflowNotFoundException, WorkflowException {
RequestInput input = parseRequest(ctx, req);
logger.info(input.toString());
String wfName = segments[2];
Expand Down

0 comments on commit 85a8fac

Please sign in to comment.