Commit

Add memory consumption estimation for models in profile API. (opensearch-project#853)

Signed-off-by: Jing Zhang <jngz@amazon.com>
jngz-es authored and rbhavna committed Jun 16, 2023
1 parent bdd1388 commit a084d5a
Showing 6 changed files with 110 additions and 2 deletions.
@@ -36,6 +36,8 @@ public class MLModelCache {
private final Set<String> workerNodes;
private final Queue<Double> modelInferenceDurationQueue;
private final Queue<Double> predictRequestDurationQueue;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationGPU;

public MLModelCache() {
targetWorkerNodes = ConcurrentHashMap.newKeySet();
@@ -90,6 +92,8 @@ public void clear() {
if (predictor != null) {
predictor.close();
}
memSizeEstimationCPU = 0L;
memSizeEstimationGPU = 0L;
if (executor != null) {
executor.close();
}
@@ -19,6 +19,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
@@ -66,6 +67,59 @@ public synchronized void setModelState(String modelId, MLModelState state) {
getExistingModelCache(modelId).setModelState(state);
}

/**
* Set memory size estimation CPU/GPU
* @param modelId model id
* @param format model format like onnx
* @param size memory size
*/
public synchronized void setMemSizeEstimation(String modelId, MLModelFormat format, Long size) {
Long memSize = getMemSizeEstimation(format, size);
log.debug("Updating memSizeEstimation of Model {} to {}", modelId, memSize);
getExistingModelCache(modelId).setMemSizeEstimationCPU(memSize);
getExistingModelCache(modelId).setMemSizeEstimationGPU(memSize);
}

private Long getMemSizeEstimation(MLModelFormat format, Long size) {
Double scale = 1.0;
switch (format) {
case ONNX:
scale = 1.5;
break;
case TORCH_SCRIPT:
scale = 1.2;
break;
}
Long memSize = Double.valueOf(scale * size).longValue();
return memSize;
}

/**
* Get CPU memory estimation.
* @param modelId model id
* @return Long
*/
public Long getMemEstCPU(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
if (modelCache == null) {
return null;
}
return modelCache.getMemSizeEstimationCPU();
}

/**
* Get GPU memory estimation.
* @param modelId model id
* @return Long
*/
public Long getMemEstGPU(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
if (modelCache == null) {
return null;
}
return modelCache.getMemSizeEstimationGPU();
}

/**
* Check if model deployed on node.
* @param modelId model id
@@ -293,6 +347,8 @@ public MLModelProfile getModelProfile(String modelId) {
}
builder.modelInferenceStats(modelCache.getInferenceStats(true));
builder.predictRequestStats(modelCache.getInferenceStats(false));
builder.memSizeEstimationCPU(modelCache.getMemSizeEstimationCPU());
builder.memSizeEstimationGPU(modelCache.getMemSizeEstimationGPU());
return builder.build();
}
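
Taken together, the helper methods above amount to a simple format-based scaling heuristic: the raw model content size is multiplied by a per-format factor, and the same estimate is currently stored for both CPU and GPU. The following self-contained sketch (class and enum names are illustrative, not part of the plugin) restates the heuristic outside the diff:

// Illustrative sketch only: mirrors the logic of getMemSizeEstimation above.
public class MemSizeEstimationSketch {

    enum Format { ONNX, TORCH_SCRIPT }

    // Memory footprint is estimated as a format-dependent multiple of the model file size.
    static long estimate(Format format, long modelContentSizeInBytes) {
        double scale = 1.0;
        switch (format) {
            case ONNX:
                scale = 1.5;
                break;
            case TORCH_SCRIPT:
                scale = 1.2;
                break;
        }
        return (long) (scale * modelContentSizeInBytes);
    }

    public static void main(String[] args) {
        // Matches the unit tests at the bottom of this commit:
        System.out.println(estimate(Format.TORCH_SCRIPT, 1000L)); // 1200
        System.out.println(estimate(Format.ONNX, 1000L));         // 1500
    }
}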

@@ -565,6 +565,7 @@ public void deployModel(
modelCacheHelper.setPredictor(modelId, predictable);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
modelCacheHelper.setMemSizeEstimation(modelId, mlModel.getModelFormat(), mlModel.getModelContentSizeInBytes());
listener.onResponse("successful");
} catch (Exception e) {
log.error("Failed to add predictor to cache", e);
@@ -28,6 +28,8 @@ public class MLModelProfile implements ToXContentFragment, Writeable {
private final String[] workerNodes;
private final MLPredictRequestStats modelInferenceStats;
private final MLPredictRequestStats predictRequestStats;
private final Long memSizeEstimationCPU;
private final Long memSizeEstimationGPU;

@Builder
public MLModelProfile(
@@ -36,14 +38,18 @@ public MLModelProfile(
String[] targetWorkerNodes,
String[] workerNodes,
MLPredictRequestStats modelInferenceStats,
MLPredictRequestStats predictRequestStats
MLPredictRequestStats predictRequestStats,
Long memSizeEstimationCPU,
Long memSizeEstimationGPU
) {
this.modelState = modelState;
this.predictor = predictor;
this.targetWorkerNodes = targetWorkerNodes;
this.workerNodes = workerNodes;
this.modelInferenceStats = modelInferenceStats;
this.predictRequestStats = predictRequestStats;
this.memSizeEstimationCPU = memSizeEstimationCPU;
this.memSizeEstimationGPU = memSizeEstimationGPU;
}

@Override
@@ -67,6 +73,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (predictRequestStats != null) {
builder.field("predict_request_stats", predictRequestStats);
}
if (memSizeEstimationCPU != null) {
builder.field("mem_size_estimation_cpu", memSizeEstimationCPU);
}
if (memSizeEstimationGPU != null) {
builder.field("mem_size_estimation_gpu", memSizeEstimationGPU);
}
builder.endObject();
return builder;
}
@@ -90,6 +102,8 @@ public MLModelProfile(StreamInput in) throws IOException {
} else {
this.predictRequestStats = null;
}
this.memSizeEstimationCPU = in.readOptionalLong();
this.memSizeEstimationGPU = in.readOptionalLong();
}

@Override
@@ -115,5 +129,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalLong(memSizeEstimationCPU);
out.writeOptionalLong(memSizeEstimationGPU);
}
}
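
With the two optional fields wired through the builder, toXContent, and stream serialization, a model's profile entry can now carry the estimates end to end. A small hedged sketch of how the new fields might be exercised (the wrapper class and values are illustrative; the builder methods, getters, and JSON key names are taken from the diff above):

import org.opensearch.ml.profile.MLModelProfile;

// Hedged sketch, not part of this commit.
public class MLModelProfileMemEstimationSketch {
    public static void main(String[] args) {
        MLModelProfile profile = MLModelProfile.builder()
            .memSizeEstimationCPU(1200L)
            .memSizeEstimationGPU(1200L)
            .build();

        // When serialized via toXContent, this profile additionally emits:
        //   "mem_size_estimation_cpu": 1200,
        //   "mem_size_estimation_gpu": 1200
        // and writeTo/StreamInput round-trip the two values as optional longs.
        System.out.println(profile.getMemSizeEstimationCPU()); // 1200
        System.out.println(profile.getMemSizeEstimationGPU()); // 1200
    }
}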
@@ -161,7 +161,9 @@ private Map<String, MLProfileModelResponse> buildModelCentricResult(List<MLProfi
null,
null,
entry.getValue().getModelInferenceStats(),
entry.getValue().getPredictRequestStats()
entry.getValue().getPredictRequestStats(),
entry.getValue().getMemSizeEstimationCPU(),
entry.getValue().getMemSizeEstimationGPU()
);
mlProfileModelResponse.getMlModelProfileMap().putAll(ImmutableMap.of(nodeId, modelProfile));
}
@@ -27,6 +27,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel;
import org.opensearch.ml.profile.MLModelProfile;
@@ -78,6 +79,34 @@ public void testModelState() {
assertEquals(FunctionName.TEXT_EMBEDDING, cacheHelper.getFunctionName(modelId));
}

public void testMemSizeEstimationCPU() {
cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
assertTrue(cacheHelper.getMemEstCPU(modelId) == null);
cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.TORCH_SCRIPT, 1000L);
assertTrue(cacheHelper.getMemEstCPU(modelId) == 1200L);
}

public void testMemSizeEstimationCPUONNX() {
cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
assertTrue(cacheHelper.getMemEstCPU(modelId) == null);
cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.ONNX, 1000L);
assertTrue(cacheHelper.getMemEstCPU(modelId) == 1500L);
}

public void testMemSizeEstimationGPU() {
cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
assertTrue(cacheHelper.getMemEstGPU(modelId) == null);
cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.TORCH_SCRIPT, 1000L);
assertTrue(cacheHelper.getMemEstGPU(modelId) == 1200L);
}

public void testMemSizeEstimationGPUONNX() {
cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
assertTrue(cacheHelper.getMemEstGPU(modelId) == null);
cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.ONNX, 1000L);
assertTrue(cacheHelper.getMemEstGPU(modelId) == 1500L);
}

public void testModelState_DuplicateError() {
expectedEx.expect(MLLimitExceededException.class);
expectedEx.expectMessage("Duplicate deploy model task");