From 3327be2c7da186d425e06fa13b91611338b54a92 Mon Sep 17 00:00:00 2001 From: Aaqib Ansari Date: Wed, 23 Mar 2022 15:49:34 -0700 Subject: [PATCH] Fix workflow thread issue #1511 --- .../main/java/org/pytorch/serve/ensemble/DagExecutor.java | 7 +++++-- .../java/org/pytorch/serve/workflow/WorkflowManager.java | 7 ++++++- .../workflow/api/http/WorkflowInferenceRequestHandler.java | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java b/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java index 5dcb72c52ed..140e4d88651 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java @@ -41,8 +41,9 @@ public DagExecutor(Dag dag) { public ArrayList execute(RequestInput input, ArrayList topoSortedList) { CompletionService executorCompletionService = null; + ExecutorService executorService = null; if (topoSortedList == null) { - ExecutorService executorService = Executors.newFixedThreadPool(4); + executorService = Executors.newFixedThreadPool(4); executorCompletionService = new ExecutorCompletionService<>(executorService); } @@ -90,6 +91,8 @@ public ArrayList execute(RequestInput input, ArrayList topoS logger.error(e.getMessage()); String[] error = e.getMessage().split(":"); throw new InternalServerException(error[error.length - 1]); // NOPMD + } finally { + executorService.shutdownNow(); } } else { for (String name : readyToExecute) { @@ -150,7 +153,7 @@ private NodeOutput invokeModel( InterruptedException { try { - logger.error(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt)); + logger.info(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt)); CompletableFuture respFuture = new CompletableFuture<>(); RestJob job = ApiUtils.addRESTInferenceJob(null, workflowModel.getName(), null, input); job.setResponsePromise(respFuture); diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java index a391d54b332..708eaca9923 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java @@ -366,7 +366,7 @@ public WorkFlow getWorkflow(String workflowName) { } public void predict(ChannelHandlerContext ctx, String wfName, RequestInput input) - throws WorkflowNotFoundException { + throws WorkflowNotFoundException, WorkflowException { WorkFlow wf = workflowMap.get(wfName); if (wf != null) { DagExecutor dagExecutor = new DagExecutor(wf.getDag()); @@ -420,6 +420,11 @@ public void predict(ChannelHandlerContext ctx, String wfName, RequestInput input error[error.length - 1].strip())); return null; }); + try { + predictionFuture.get(); + } catch (ExecutionException | InterruptedException e) { + throw new WorkflowException("Workflow failed ", e); + } } else { throw new WorkflowNotFoundException("Workflow not found: " + wfName); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java index 1c67f4d2b82..59effb0b4be 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java @@ -58,7 +58,7 @@ public void handleRequest( private void handlePredictions( ChannelHandlerContext ctx, FullHttpRequest req, String[] segments) - throws WorkflowNotFoundException { + throws WorkflowNotFoundException, WorkflowException { RequestInput input = parseRequest(ctx, req); logger.info(input.toString()); String wfName = segments[2];