remove lock #948 (Open)

Wants to merge 4 commits into base: master.
1 change: 1 addition & 0 deletions ci/Dockerfile.python3.6
@@ -190,6 +190,7 @@ RUN set -ex \
     && pip install retrying \
     && pip install mock \
     && pip install pytest -U \
+    && pip install pytest-mock \
     && pip install pylint

 # Install protobuf
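(pytest-mock supplies the mocker fixture, a thin pytest wrapper around the standard unittest.mock. None of the hunks rendered in this view use it directly, so it is presumably needed by test changes outside the visible diff.)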
frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/BatchAggregator.java
@@ -21,55 +21,60 @@
 import io.netty.handler.codec.http.HttpResponseStatus;
 import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 public class BatchAggregator {

-    private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);
-
-    private Model model;
-    private Map<String, Job> jobs;
+    static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);
+    Model model;
+    ArrayBlockingQueue<BaseModelRequest> reqQ;
+    ArrayBlockingQueue<Map<String, Job>> jobQ;
+    String threadName;
+    private ExecutorService batchHandlerService;

     public BatchAggregator(Model model) {
         this.model = model;
-        jobs = new LinkedHashMap<>();
+        // this.lastJobs = new LinkedHashMap<>();
+        reqQ = new ArrayBlockingQueue<>(1);
+        jobQ = new ArrayBlockingQueue<>(2);
     }

-    public BaseModelRequest getRequest(String threadName, WorkerState state)
-            throws InterruptedException {
-        jobs.clear();
+    public void setThreadName(String threadName) {
+        logger.info("set threadName=" + threadName);
+        this.threadName = threadName;
+    }

-        ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());
+    public String getThreadName() {
+        return threadName;
+    }

-        model.pollBatch(
-                threadName, (state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE, jobs);
+    public void startBatchHandlerService(String threadName) {
+        setThreadName(threadName);
+        if (batchHandlerService == null) {
+            batchHandlerService = Executors.newSingleThreadExecutor();
+            batchHandlerService.execute(new BatchHandler());
+        }
+    }

-        for (Job j : jobs.values()) {
-            if (j.isControlCmd()) {
-                if (jobs.size() > 1) {
-                    throw new IllegalStateException(
-                            "Received more than 1 control command. "
-                                    + "Control messages should be processed/retrieved one at a time.");
-                }
-                RequestInput input = j.getPayload();
-                int gpuId = -1;
-                String gpu = input.getStringParameter("gpu");
-                if (gpu != null) {
-                    gpuId = Integer.parseInt(gpu);
-                }
-                return new ModelLoadModelRequest(model, gpuId, threadName);
-            } else {
-                j.setScheduled();
-                req.addRequest(j.getPayload());
-            }
+    public void stopBatchHandlerService() {
+        if (batchHandlerService != null) {
+            batchHandlerService.shutdown();
         }
-        return req;
+        batchHandlerService = null;
     }

+    public BaseModelRequest getRequest(WorkerState state) throws InterruptedException {
+        return reqQ.take();
+        // lastJobs = jobQ.peek();
+        // return req;
+    }
+
     public void sendResponse(ModelWorkerResponse message) {
         // TODO: Handle prediction level code

+        Map<String, Job> jobs = jobQ.poll();
         if (message.getCode() == 200) {
             if (jobs.isEmpty()) {
                 // this is from initial load.
@@ -109,6 +114,7 @@ public void sendError(BaseModelRequest message, String error, HttpResponseStatus
             return;
         }

+        Map<String, Job> jobs = jobQ.poll();
         if (message != null) {
             ModelInferenceRequest msg = (ModelInferenceRequest) message;
             for (RequestInput req : msg.getRequestBatch()) {
@@ -126,16 +132,66 @@ public void sendError(BaseModelRequest message, String error, HttpResponseStatus
             }
         } else {
             // Send the error message to all the jobs
-            for (Map.Entry<String, Job> j : jobs.entrySet()) {
-                String jobsId = j.getValue().getJobId();
-                Job job = jobs.remove(jobsId);
+            if (jobs != null) {
+                for (Map.Entry<String, Job> j : jobs.entrySet()) {
+                    String jobsId = j.getValue().getJobId();
+                    Job job = jobs.remove(jobsId);

-                if (job.isControlCmd()) {
-                    job.sendError(status, error);
-                } else {
-                    // Data message can be handled by other workers.
-                    // If batch has gone past its batch max delay timer?
-                    model.addFirst(job);
+                    if (job.isControlCmd()) {
+                        job.sendError(status, error);
+                    } else {
+                        // Data message can be handled by other workers.
+                        // If batch has gone past its batch max delay timer?
+                        model.addFirst(job);
+                    }
                 }
             }
         }
     }
+
+    private class BatchHandler implements Runnable {
+        @Override
+        public void run() {
+            while (true) {
+                Map<String, Job> jobs = new LinkedHashMap<>();
+                ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());
+                boolean loadModelJob = false;
+
+                try {
+                    model.pollBatch(threadName, jobs);
+                    if (!jobs.isEmpty()) {
+                        jobQ.put(jobs);
+
+                        for (Job j : jobs.values()) {
+                            if (j.isControlCmd()) {
+                                if (jobs.size() > 1) {
+                                    throw new IllegalStateException(
+                                            "Received more than 1 control command. "
+                                                    + "Control messages should be processed/retrieved one at a time.");
+                                }
+                                RequestInput input = j.getPayload();
+                                int gpuId = -1;
+                                String gpu = input.getStringParameter("gpu");
+                                if (gpu != null) {
+                                    gpuId = Integer.parseInt(gpu);
+                                }
+                                reqQ.put(new ModelLoadModelRequest(model, gpuId, threadName));
+                                loadModelJob = true;
+                                break;
+                            } else {
+                                j.setScheduled();
+                                req.addRequest(j.getPayload());
+                            }
+                        }
+                        if (!loadModelJob) {
+                            reqQ.put(req);
+                        }
+                    }
+                } catch (InterruptedException e) {
+                    logger.debug("Aggregator for " + threadName + " got interrupted.", e);
+                    break;
+                } catch (IllegalArgumentException e) {
+                    logger.debug("Aggregator for " + threadName + " got illegal argument.", e);
+                }
+            }
+        }
+    }
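The net effect of this file's changes: batch assembly moves off the worker's request path into a dedicated BatchHandler thread, with two bounded queues as the hand-off (reqQ, capacity 1, so at most one batch is pre-built ahead of the worker; jobQ carries the matching job maps for response and error routing). A minimal standalone sketch of that producer/consumer pattern, with illustrative names rather than the PR's actual classes:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;

// Illustrative stand-in for the new aggregator wiring: a single background
// thread drains an inbound job deque into batches and parks each batch on a
// capacity-1 bounded queue; the worker thread simply blocks on take().
public class HandoffSketch {
    private final LinkedBlockingDeque<String> inbound = new LinkedBlockingDeque<>(100);
    private final BlockingQueue<List<String>> batches = new ArrayBlockingQueue<>(1);
    private final ExecutorService handler = Executors.newSingleThreadExecutor();

    public void start() {
        handler.execute(() -> {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    List<String> batch = new ArrayList<>();
                    batch.add(inbound.take());  // block until the first job arrives
                    inbound.drainTo(batch, 7);  // then top up to a batch of 8
                    batches.put(batch);         // blocks while the previous batch is unconsumed
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // shutdown path
            }
        });
    }

    // Worker-side equivalent of BatchAggregator.getRequest(state).
    public List<String> nextBatch() throws InterruptedException {
        return batches.take();
    }

    public void submit(String job) throws InterruptedException {
        inbound.put(job);
    }

    public void stop() {
        handler.shutdownNow(); // interrupt a handler blocked in take()/put()
    }

    public static void main(String[] args) throws InterruptedException {
        HandoffSketch s = new HandoffSketch();
        s.start();
        for (int i = 0; i < 10; i++) {
            s.submit("job-" + i);
        }
        System.out.println(s.nextBatch()); // e.g. [job-0, ..., job-7]
        s.stop();
    }
}

The capacity-1 queue gives the same back-pressure as reqQ here: the handler can run at most one batch ahead of the thread writing to the backend socket.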
13 changes: 3 additions & 10 deletions frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Model.java
@@ -22,7 +22,6 @@
 import java.util.concurrent.LinkedBlockingDeque;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.locks.ReentrantLock;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -38,7 +37,6 @@ public class Model {
     private int maxBatchDelay;
     private String preloadModel;
     private AtomicInteger port; // Port on which the model server is running
-    private ReentrantLock lock;
     private int responseTimeout;
     private WorkerThread serverThread;
     // Total number of subsequent inference request failures
@@ -57,7 +55,6 @@ public Model(ModelArchive modelArchive, int queueSize, String preloadModel) {
         jobsDb.putIfAbsent(DEFAULT_DATA_QUEUE, new LinkedBlockingDeque<>(queueSize));
         failedInfReqs = new AtomicInteger(0);
         port = new AtomicInteger(-1);
-        lock = new ReentrantLock();
     }

     public String getModelName() {
@@ -131,8 +128,7 @@ public void addFirst(Job job) {
         jobsDb.get(DEFAULT_DATA_QUEUE).addFirst(job);
     }

-    public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
-            throws InterruptedException {
+    public void pollBatch(String threadId, Map<String, Job> jobsRepo) throws InterruptedException {
         if (jobsRepo == null || threadId == null || threadId.isEmpty()) {
             throw new IllegalArgumentException("Invalid input given provided");
         }
@@ -144,15 +140,14 @@ public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)

         LinkedBlockingDeque<Job> jobsQueue = jobsDb.get(threadId);
         if (jobsQueue != null && !jobsQueue.isEmpty()) {
-            Job j = jobsQueue.poll(waitTime, TimeUnit.MILLISECONDS);
+            Job j = jobsQueue.poll();
             if (j != null) {
                 jobsRepo.put(j.getJobId(), j);
                 return;
             }
         }

         try {
-            lock.lockInterruptibly();
             long maxDelay = maxBatchDelay;
             jobsQueue = jobsDb.get(DEFAULT_DATA_QUEUE);

@@ -176,9 +171,7 @@ public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
             }
             logger.trace("sending jobs, size: {}", jobsRepo.size());
         } finally {
-            if (lock.isHeldByCurrentThread()) {
-                lock.unlock();
-            }
+            logger.debug("done pollBatch");
         }
     }

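Removing the ReentrantLock leans on LinkedBlockingDeque already being thread-safe: each poll()/take() hands a given job to exactly one caller, so concurrent pollers can never duplicate work. The coarse lock only guaranteed that one worker finished assembling a whole batch before another started; without it, two handlers may at worst split what would have been a single batch. A sketch of lock-free batch assembly under those assumptions (pollBatch's collapsed middle section is hidden in this diff, so this is a reading of the pattern, not the method's exact body):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

// Lock-free batching against a thread-safe deque: block for the first job,
// then keep polling until the batch is full or the max batch delay expires.
public final class LockFreePollSketch {
    public static List<String> pollBatch(
            LinkedBlockingDeque<String> queue, int batchSize, long maxBatchDelayMillis)
            throws InterruptedException {
        List<String> batch = new ArrayList<>(batchSize);
        batch.add(queue.take()); // wait indefinitely for the first job
        long deadline = System.currentTimeMillis() + maxBatchDelayMillis;
        while (batch.size() < batchSize) {
            long remaining = deadline - System.currentTimeMillis();
            String job = queue.poll(remaining, TimeUnit.MILLISECONDS); // null once delay is spent
            if (job == null) {
                break;
            }
            batch.add(job);
        }
        return batch;
    }
}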
frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerThread.java
@@ -141,7 +141,7 @@ private void runWorker()
             throws WorkerInitializationException, InterruptedException, FileNotFoundException {
         int responseTimeout = model.getResponseTimeout();
         while (isRunning()) {
-            req = aggregator.getRequest(backendChannel.id().asLongText(), state);
+            req = aggregator.getRequest(state);
             backendChannel.writeAndFlush(req).sync();
             long begin = System.currentTimeMillis();
             // TODO: Change this to configurable param
@@ -208,6 +208,7 @@ public void run() {
         try {
             if (!serverThread) {
                 connect();
+                aggregator.startBatchHandlerService(backendChannel.id().asLongText());
                 runWorker();
             } else {
                 // TODO: Move this logic to a seperate ServerThread class
@@ -406,6 +407,7 @@ public void shutdown() {
             aggregator.sendError(
                     null, "Worker scaled down.", HttpResponseStatus.INTERNAL_SERVER_ERROR);
         }
+        aggregator.stopBatchHandlerService();
     }

     public boolean isServerThread() {
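The worker now owns the handler's lifecycle: startBatchHandlerService after connect(), keyed by the channel id that doubles as the thread name, and stopBatchHandlerService in shutdown(). One JDK detail worth keeping in mind when reading that teardown: ExecutorService.shutdown(), which stopBatchHandlerService calls, does not interrupt an already-running task, so a BatchHandler blocked inside pollBatch only unblocks via the interruption that shutdownNow() delivers. A minimal standalone demonstration of the difference (not the PR's code):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

// shutdown() lets a blocked take() keep waiting; shutdownNow() interrupts it,
// which is how a loop that exits on InterruptedException actually terminates.
public final class ShutdownSketch {
    public static void main(String[] args) throws InterruptedException {
        LinkedBlockingDeque<String> queue = new LinkedBlockingDeque<>();
        ExecutorService service = Executors.newSingleThreadExecutor();
        service.execute(() -> {
            try {
                queue.take(); // blocks forever; nothing is ever enqueued
            } catch (InterruptedException e) {
                System.out.println("interrupted, exiting");
            }
        });
        service.shutdown(); // graceful: the blocked take() is NOT interrupted
        System.out.println("after shutdown(), terminated = "
                + service.awaitTermination(1, TimeUnit.SECONDS)); // false
        service.shutdownNow(); // interrupts the worker thread
        System.out.println("after shutdownNow(), terminated = "
                + service.awaitTermination(1, TimeUnit.SECONDS)); // true
    }
}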
1 change: 1 addition & 0 deletions mms/tests/unit_tests/test_beckend_metric.py
@@ -29,6 +29,7 @@ def test_metrics(caplog):
     Test if metric classes methods behave as expected
     Also checks global metric service methods
     """
+    caplog.set_level(logging.INFO)
     # Create a batch of request ids
     request_ids = {0: 'abcd', 1: "xyz", 2: "qwerty", 3: "hjshfj"}
     all_req_ids = ','.join(request_ids.values())
1 change: 1 addition & 0 deletions mms/tests/unit_tests/test_worker_service.py
@@ -50,6 +50,7 @@ def test_valid_req(self, service):
 class TestEmitMetrics:

     def test_emit_metrics(self, caplog):
+        caplog.set_level(logging.INFO)
         metrics = {'test_emit_metrics': True}
         emit_metrics(metrics)
         assert "[METRICS]" in caplog.text
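Both test tweaks address the same pytest behavior: the caplog fixture only sees records that pass the logger's effective level, and the root logger defaults to WARNING, so INFO-level output such as the [METRICS] line never reaches caplog.text unless the test first calls caplog.set_level(logging.INFO).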