From a084d5a3124eb33c594c6c4ccef91d2877a04a20 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 17 Apr 2023 10:41:19 -0700 Subject: [PATCH] Add memory consumption estimation for models in profile API. (#853) Signed-off-by: Jing Zhang --- .../org/opensearch/ml/model/MLModelCache.java | 4 ++ .../ml/model/MLModelCacheHelper.java | 56 +++++++++++++++++++ .../opensearch/ml/model/MLModelManager.java | 1 + .../opensearch/ml/profile/MLModelProfile.java | 18 +++++- .../ml/rest/RestMLProfileAction.java | 4 +- .../ml/model/MLModelCacheHelperTests.java | 29 ++++++++++ 6 files changed, 110 insertions(+), 2 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 531ee45427..c496d8d428 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -36,6 +36,8 @@ public class MLModelCache { private final Set workerNodes; private final Queue modelInferenceDurationQueue; private final Queue predictRequestDurationQueue; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationGPU; public MLModelCache() { targetWorkerNodes = ConcurrentHashMap.newKeySet(); @@ -90,6 +92,8 @@ public void clear() { if (predictor != null) { predictor.close(); } + memSizeEstimationCPU = 0L; + memSizeEstimationGPU = 0L; if (executor != null) { executor.close(); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 7d40350649..e360be2e71 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -19,6 +19,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.exception.MLLimitExceededException; +import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.Predictable; @@ -66,6 +67,59 @@ public synchronized void setModelState(String modelId, MLModelState state) { getExistingModelCache(modelId).setModelState(state); } + /** + * Set memory size estimation CPU/GPU + * @param modelId model id + * @param format model format like onnx + * @param size memory size + */ + public synchronized void setMemSizeEstimation(String modelId, MLModelFormat format, Long size) { + Long memSize = getMemSizeEstimation(format, size); + log.debug("Updating memSizeEstimation of Model {} to {}", modelId, memSize); + getExistingModelCache(modelId).setMemSizeEstimationCPU(memSize); + getExistingModelCache(modelId).setMemSizeEstimationGPU(memSize); + } + + private Long getMemSizeEstimation(MLModelFormat format, Long size) { + Double scale = 1.0; + switch (format) { + case ONNX: + scale = 1.5; + break; + case TORCH_SCRIPT: + scale = 1.2; + break; + } + Long memSize = Double.valueOf(scale * size).longValue(); + return memSize; + } + + /** + * Get CPU memory estimation. + * @param modelId model id + * @return Long + */ + public Long getMemEstCPU(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getMemSizeEstimationCPU(); + } + + /** + * Get GPU memory estimation. + * @param modelId model id + * @return Long + */ + public Long getMemEstGPU(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getMemSizeEstimationGPU(); + } + /** * Check if model deployed on node. * @param modelId model id @@ -293,6 +347,8 @@ public MLModelProfile getModelProfile(String modelId) { } builder.modelInferenceStats(modelCache.getInferenceStats(true)); builder.predictRequestStats(modelCache.getInferenceStats(false)); + builder.memSizeEstimationCPU(modelCache.getMemSizeEstimationCPU()); + builder.memSizeEstimationGPU(modelCache.getMemSizeEstimationGPU()); return builder.build(); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 841104cddc..e8a55ecd30 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -565,6 +565,7 @@ public void deployModel( modelCacheHelper.setPredictor(modelId, predictable); mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.setMemSizeEstimation(modelId, mlModel.getModelFormat(), mlModel.getModelContentSizeInBytes()); listener.onResponse("successful"); } catch (Exception e) { log.error("Failed to add predictor to cache", e); diff --git a/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java b/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java index 686f1dee8a..4f18142bbd 100644 --- a/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java +++ b/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java @@ -28,6 +28,8 @@ public class MLModelProfile implements ToXContentFragment, Writeable { private final String[] workerNodes; private final MLPredictRequestStats modelInferenceStats; private final MLPredictRequestStats predictRequestStats; + private final Long memSizeEstimationCPU; + private final Long memSizeEstimationGPU; @Builder public MLModelProfile( @@ -36,7 +38,9 @@ public MLModelProfile( String[] targetWorkerNodes, String[] workerNodes, MLPredictRequestStats modelInferenceStats, - MLPredictRequestStats predictRequestStats + MLPredictRequestStats predictRequestStats, + Long memSizeEstimationCPU, + Long memSizeEstimationGPU ) { this.modelState = modelState; this.predictor = predictor; @@ -44,6 +48,8 @@ public MLModelProfile( this.workerNodes = workerNodes; this.modelInferenceStats = modelInferenceStats; this.predictRequestStats = predictRequestStats; + this.memSizeEstimationCPU = memSizeEstimationCPU; + this.memSizeEstimationGPU = memSizeEstimationGPU; } @Override @@ -67,6 +73,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (predictRequestStats != null) { builder.field("predict_request_stats", predictRequestStats); } + if (memSizeEstimationCPU != null) { + builder.field("mem_size_estimation_cpu", memSizeEstimationCPU); + } + if (memSizeEstimationGPU != null) { + builder.field("mem_size_estimation_gpu", memSizeEstimationGPU); + } builder.endObject(); return builder; } @@ -90,6 +102,8 @@ public MLModelProfile(StreamInput in) throws IOException { } else { this.predictRequestStats = null; } + this.memSizeEstimationCPU = in.readOptionalLong(); + this.memSizeEstimationGPU = in.readOptionalLong(); } @Override @@ -115,5 +129,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalLong(memSizeEstimationCPU); + out.writeOptionalLong(memSizeEstimationGPU); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java index 908ecf9968..fb19b0a141 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java @@ -161,7 +161,9 @@ private Map buildModelCentricResult(List