diff --git a/ci/Dockerfile.python3.6 b/ci/Dockerfile.python3.6
index 76c45a314..f77a07e7d 100644
--- a/ci/Dockerfile.python3.6
+++ b/ci/Dockerfile.python3.6
@@ -190,6 +190,7 @@ RUN set -ex \
     && pip install retrying \
     && pip install mock \
     && pip install pytest -U \
+    && pip install pytest-mock \
     && pip install pylint

 # Install protobuf
diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/BatchAggregator.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/BatchAggregator.java
index 0fac8cbd0..dba002889 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/BatchAggregator.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/BatchAggregator.java
@@ -21,55 +21,60 @@ import io.netty.handler.codec.http.HttpResponseStatus;
 import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 public class BatchAggregator {
-
-    private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);
-
-    private Model model;
-    private Map<String, Job> jobs;
+    static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);
+    Model model;
+    ArrayBlockingQueue<BaseModelRequest> reqQ;
+    ArrayBlockingQueue<Map<String, Job>> jobQ;
+    String threadName;
+    private ExecutorService batchHandlerService;

     public BatchAggregator(Model model) {
         this.model = model;
-        jobs = new LinkedHashMap<>();
+        // this.lastJobs = new LinkedHashMap<>();
+        reqQ = new ArrayBlockingQueue<>(1);
+        jobQ = new ArrayBlockingQueue<>(2);
     }

-    public BaseModelRequest getRequest(String threadName, WorkerState state)
-            throws InterruptedException {
-        jobs.clear();
+    public void setThreadName(String threadName) {
+        logger.info("set threadName=" + threadName);
+        this.threadName = threadName;
+    }

-        ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());
+    public String getThreadName() {
+        return threadName;
+    }

-        model.pollBatch(
-                threadName, (state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE, jobs);
+    public void startBatchHandlerService(String threadName) {
+        setThreadName(threadName);
+        if (batchHandlerService == null) {
+            batchHandlerService = Executors.newSingleThreadExecutor();
+            batchHandlerService.execute(new BatchHandler());
+        }
+    }

-        for (Job j : jobs.values()) {
-            if (j.isControlCmd()) {
-                if (jobs.size() > 1) {
-                    throw new IllegalStateException(
-                            "Received more than 1 control command. "
-                                    + "Control messages should be processed/retrieved one at a time.");
-                }
-                RequestInput input = j.getPayload();
-                int gpuId = -1;
-                String gpu = input.getStringParameter("gpu");
-                if (gpu != null) {
-                    gpuId = Integer.parseInt(gpu);
-                }
-                return new ModelLoadModelRequest(model, gpuId, threadName);
-            } else {
-                j.setScheduled();
-                req.addRequest(j.getPayload());
-            }
+    public void stopBatchHandlerService() {
+        if (batchHandlerService != null) {
+            batchHandlerService.shutdown();
         }
-        return req;
+        batchHandlerService = null;
+    }
+
+    public BaseModelRequest getRequest(WorkerState state) throws InterruptedException {
+        return reqQ.take();
+        // lastJobs = jobQ.peek();
+        // return req;
     }

     public void sendResponse(ModelWorkerResponse message) {
         // TODO: Handle prediction level code
-
+        Map<String, Job> jobs = jobQ.poll();
         if (message.getCode() == 200) {
             if (jobs.isEmpty()) {
                 // this is from initial load.
@@ -109,6 +114,7 @@ public void sendError(BaseModelRequest message, String error, HttpResponseStatus
             return;
         }

+        Map<String, Job> jobs = jobQ.poll();
         if (message != null) {
             ModelInferenceRequest msg = (ModelInferenceRequest) message;
             for (RequestInput req : msg.getRequestBatch()) {
@@ -126,16 +132,66 @@
             }
         } else {
             // Send the error message to all the jobs
-            for (Map.Entry<String, Job> j : jobs.entrySet()) {
-                String jobsId = j.getValue().getJobId();
-                Job job = jobs.remove(jobsId);
+            if (jobs != null) {
+                for (Map.Entry<String, Job> j : jobs.entrySet()) {
+                    String jobsId = j.getValue().getJobId();
+                    Job job = jobs.remove(jobsId);

-                if (job.isControlCmd()) {
-                    job.sendError(status, error);
-                } else {
-                    // Data message can be handled by other workers.
-                    // If batch has gone past its batch max delay timer?
-                    model.addFirst(job);
+                    if (job.isControlCmd()) {
+                        job.sendError(status, error);
+                    } else {
+                        // Data message can be handled by other workers.
+                        // If batch has gone past its batch max delay timer?
+                        model.addFirst(job);
+                    }
+                }
             }
         }
     }
+
+    private class BatchHandler implements Runnable {
+        @Override
+        public void run() {
+            while (true) {
+                Map<String, Job> jobs = new LinkedHashMap<>();
+                ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());
+                boolean loadModelJob = false;
+
+                try {
+                    model.pollBatch(threadName, jobs);
+                    if (!jobs.isEmpty()) {
+                        jobQ.put(jobs);
+
+                        for (Job j : jobs.values()) {
+                            if (j.isControlCmd()) {
+                                if (jobs.size() > 1) {
+                                    throw new IllegalStateException(
+                                            "Received more than 1 control command. "
+                                                    + "Control messages should be processed/retrieved one at a time.");
+                                }
+                                RequestInput input = j.getPayload();
+                                int gpuId = -1;
+                                String gpu = input.getStringParameter("gpu");
+                                if (gpu != null) {
+                                    gpuId = Integer.parseInt(gpu);
+                                }
+                                reqQ.put(new ModelLoadModelRequest(model, gpuId, threadName));
+                                loadModelJob = true;
+                                break;
+                            } else {
+                                j.setScheduled();
+                                req.addRequest(j.getPayload());
+                            }
+                        }
+                        if (!loadModelJob) {
+                            reqQ.put(req);
+                        }
+                    }
+                } catch (InterruptedException e) {
+                    logger.debug("Aggregator for " + threadName + " got interrupted.", e);
+                    break;
+                } catch (IllegalArgumentException e) {
+                    logger.debug("Aggregator for " + threadName + " got illegal argument.", e);
+                }
+            }
+        }
+    }
 }
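Reviewer note: the BatchAggregator change above swaps the synchronous poll-and-build path for a producer-consumer handoff — a dedicated BatchHandler thread assembles batches and publishes them through the bounded reqQ/jobQ queues, while the worker thread simply blocks on reqQ.take(). A minimal sketch of that handoff pattern, with hypothetical names (HandoffSketch is illustrative, not part of this patch):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Minimal sketch of the producer-consumer handoff used above.
public class HandoffSketch {

    // Capacity 1 mirrors reqQ: the producer can stage at most one
    // ready-to-send request ahead of the consumer.
    private static final BlockingQueue<List<String>> reqQ = new ArrayBlockingQueue<>(1);

    public static void main(String[] args) throws InterruptedException {
        ExecutorService producer = Executors.newSingleThreadExecutor();
        producer.execute(() -> {
            try {
                for (int batch = 0; batch < 3; batch++) {
                    List<String> jobs = new ArrayList<>();
                    jobs.add("job-" + batch);       // stand-in for Model.pollBatch()
                    reqQ.put(jobs);                 // blocks while the queue is full
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // exit quietly, like BatchHandler
            }
        });

        for (int i = 0; i < 3; i++) {
            List<String> batch = reqQ.take();       // blocks until a batch is ready
            System.out.println("worker got " + batch);
        }
        producer.shutdown();
    }
}
```

The capacity-1 reqQ means the handler can stage at most one request ahead of the worker, keeping batch assembly and network writes loosely pipelined without unbounded buffering.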
diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Model.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Model.java
index bd21c3dcd..81f679453 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Model.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Model.java
@@ -22,7 +22,6 @@ import java.util.concurrent.LinkedBlockingDeque;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.locks.ReentrantLock;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -38,7 +37,6 @@ public class Model {
     private int maxBatchDelay;
     private String preloadModel;
     private AtomicInteger port; // Port on which the model server is running
-    private ReentrantLock lock;
     private int responseTimeout;
     private WorkerThread serverThread;
     // Total number of subsequent inference request failures
@@ -57,7 +55,6 @@ public Model(ModelArchive modelArchive, int queueSize, String preloadModel) {
         jobsDb.putIfAbsent(DEFAULT_DATA_QUEUE, new LinkedBlockingDeque<>(queueSize));
         failedInfReqs = new AtomicInteger(0);
         port = new AtomicInteger(-1);
-        lock = new ReentrantLock();
     }

     public String getModelName() {
@@ -131,8 +128,7 @@ public void addFirst(Job job) {
         jobsDb.get(DEFAULT_DATA_QUEUE).addFirst(job);
     }

-    public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
-            throws InterruptedException {
+    public void pollBatch(String threadId, Map<String, Job> jobsRepo) throws InterruptedException {
         if (jobsRepo == null || threadId == null || threadId.isEmpty()) {
             throw new IllegalArgumentException("Invalid input given provided");
         }
@@ -144,7 +140,7 @@ public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
         LinkedBlockingDeque<Job> jobsQueue = jobsDb.get(threadId);
         if (jobsQueue != null && !jobsQueue.isEmpty()) {
-            Job j = jobsQueue.poll(waitTime, TimeUnit.MILLISECONDS);
+            Job j = jobsQueue.poll();
             if (j != null) {
                 jobsRepo.put(j.getJobId(), j);
                 return;
@@ -152,7 +148,6 @@ public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
             }
         }
         try {
-            lock.lockInterruptibly();
             long maxDelay = maxBatchDelay;

             jobsQueue = jobsDb.get(DEFAULT_DATA_QUEUE);
@@ -176,9 +171,7 @@ public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
             }
             logger.trace("sending jobs, size: {}", jobsRepo.size());
         } finally {
-            if (lock.isHeldByCurrentThread()) {
-                lock.unlock();
-            }
+            logger.debug("done pollBatch");
         }
     }
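Reviewer note: with the ReentrantLock and the waitTime parameter gone, pollBatch now does a non-blocking poll() on the per-thread queue, and the aggregator's error path reads jobQ with poll(), which returns null on an empty queue — hence the new if (jobs != null) guard in sendError. A small sketch of the three BlockingQueue retrieval modes this patch moves between (PollSemantics is illustrative only):

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;

// Sketch of the three BlockingQueue retrieval modes touched by this patch.
public class PollSemantics {
    public static void main(String[] args) throws InterruptedException {
        BlockingQueue<String> q = new ArrayBlockingQueue<>(1);

        // Non-blocking poll(): returns null immediately when empty, which is
        // why the sendError() path above must null-check jobQ.poll().
        System.out.println(q.poll());                           // -> null

        q.put("job");

        // Timed poll(): the old pollBatch() behavior, waiting up to waitTime.
        System.out.println(q.poll(10, TimeUnit.MILLISECONDS));  // -> "job"

        q.put("job2");

        // take(): blocks until an element arrives; the new getRequest()
        // relies on this instead of a wait-time argument.
        System.out.println(q.take());                           // -> "job2"
    }
}
```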
diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerThread.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerThread.java
index 5ebb8bb70..a21f0c3fa 100644
--- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerThread.java
+++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerThread.java
@@ -141,7 +141,7 @@ private void runWorker()
             throws WorkerInitializationException, InterruptedException, FileNotFoundException {
         int responseTimeout = model.getResponseTimeout();
         while (isRunning()) {
-            req = aggregator.getRequest(backendChannel.id().asLongText(), state);
+            req = aggregator.getRequest(state);
             backendChannel.writeAndFlush(req).sync();

             long begin = System.currentTimeMillis();
             // TODO: Change this to configurable param
@@ -208,6 +208,7 @@ public void run() {
         try {
             if (!serverThread) {
                 connect();
+                aggregator.startBatchHandlerService(backendChannel.id().asLongText());
                 runWorker();
             } else {
                 // TODO: Move this logic to a seperate ServerThread class
@@ -406,6 +407,7 @@ public void shutdown() {
             aggregator.sendError(
                     null, "Worker scaled down.", HttpResponseStatus.INTERNAL_SERVER_ERROR);
         }
+        aggregator.stopBatchHandlerService();
     }

     public boolean isServerThread() {
diff --git a/mms/tests/unit_tests/test_beckend_metric.py b/mms/tests/unit_tests/test_beckend_metric.py
index 4ef2e6ad0..6732973bb 100644
--- a/mms/tests/unit_tests/test_beckend_metric.py
+++ b/mms/tests/unit_tests/test_beckend_metric.py
@@ -29,6 +29,7 @@ def test_metrics(caplog):
     Test if metric classes methods behave as expected
     Also checks global metric service methods
     """
+    caplog.set_level(logging.INFO)
     # Create a batch of request ids
     request_ids = {0: 'abcd', 1: "xyz", 2: "qwerty", 3: "hjshfj"}
     all_req_ids = ','.join(request_ids.values())
diff --git a/mms/tests/unit_tests/test_worker_service.py b/mms/tests/unit_tests/test_worker_service.py
index 575f13d73..bafba308c 100644
--- a/mms/tests/unit_tests/test_worker_service.py
+++ b/mms/tests/unit_tests/test_worker_service.py
@@ -50,6 +50,7 @@ def test_valid_req(self, service):

 class TestEmitMetrics:

     def test_emit_metrics(self, caplog):
+        caplog.set_level(logging.INFO)
         metrics = {'test_emit_metrics': True}
         emit_metrics(metrics)

         assert "[METRICS]" in caplog.text
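Reviewer note: one detail worth flagging in the shutdown path — ExecutorService.shutdown() only stops new submissions and does not interrupt a task that is already blocked, so a BatchHandler parked inside pollBatch() or reqQ.take() may outlive stopBatchHandlerService(); shutdownNow() is the variant that delivers the interrupt the handler's InterruptedException branch relies on. A minimal sketch (LifecycleSketch is illustrative only):

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Sketch of the executor lifecycle behind start/stopBatchHandlerService().
public class LifecycleSketch {
    public static void main(String[] args) throws InterruptedException {
        ExecutorService service = Executors.newSingleThreadExecutor();
        service.execute(() -> {
            try {
                TimeUnit.SECONDS.sleep(60);         // stands in for a blocking take()/pollBatch()
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // the handler's exit path
            }
        });

        // shutdown() would let the blocked task keep waiting; shutdownNow()
        // interrupts it, which is what lets BatchHandler's
        // InterruptedException catch block break out of its loop.
        service.shutdownNow();
        System.out.println(service.awaitTermination(1, TimeUnit.SECONDS)); // true
    }
}
```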