diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java index 1f815e086a..681e7f15fb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java @@ -17,7 +17,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.stats.MLAlgoStats; +import org.opensearch.ml.stats.MLModelStats; import org.opensearch.ml.stats.MLNodeLevelStat; +import org.opensearch.ml.stats.MLStatsInput; public class MLStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { /** @@ -30,6 +32,12 @@ public class MLStatsNodeResponse extends BaseNodeResponse implements ToXContentF * Example: {kmeans: { train: { request_count: 1} }} */ private Map algorithmStats; + /** + * Model stats which includes model level stats. + * + * Example: {model_id: { predict: { request_count: 1} }} + */ + private Map modelStats; /** * Constructor @@ -45,6 +53,9 @@ public MLStatsNodeResponse(StreamInput in) throws IOException { if (in.readBoolean()) { this.algorithmStats = in.readMap(stream -> stream.readEnum(FunctionName.class), MLAlgoStats::new); } + if (in.readBoolean()) { + this.modelStats = in.readMap(stream -> stream.readOptionalString(), MLModelStats::new); + } } public MLStatsNodeResponse(DiscoveryNode node, Map nodeStats) { @@ -52,14 +63,20 @@ public MLStatsNodeResponse(DiscoveryNode node, Map node this.nodeStats = nodeStats; } - public MLStatsNodeResponse(DiscoveryNode node, Map nodeStats, Map algorithmStats) { + public MLStatsNodeResponse( + DiscoveryNode node, + Map nodeStats, + Map algorithmStats, + Map modelStats + ) { super(node); this.nodeStats = nodeStats; this.algorithmStats = algorithmStats; + this.modelStats = modelStats; } public boolean isEmpty() { - return getNodeLevelStatSize() == 0 && getAlgorithmStatSize() == 0; + return getNodeLevelStatSize() == 0 && getAlgorithmStatSize() == 0 && getModelStatSize() == 0; } /** @@ -88,6 +105,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (modelStats != null) { + out.writeBoolean(true); + out.writeMap(modelStats, (stream, v) -> stream.writeOptionalString(v), (stream, stats) -> stats.writeTo(stream)); + } else { + out.writeBoolean(false); + } } public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { @@ -97,7 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } if (algorithmStats != null) { - builder.startObject("algorithms"); + builder.startObject(MLStatsInput.ALGORITHMS); for (Map.Entry stat : algorithmStats.entrySet()) { builder.startObject(stat.getKey().name().toLowerCase(Locale.ROOT)); stat.getValue().toXContent(builder, params); @@ -105,6 +128,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endObject(); } + if (modelStats != null) { + builder.startObject(MLStatsInput.MODELS); + for (Map.Entry stat : modelStats.entrySet()) { + builder.startObject(stat.getKey()); + stat.getValue().toXContent(builder, params); + builder.endObject(); + } + builder.endObject(); + } return builder; } @@ -120,17 +152,35 @@ public int getAlgorithmStatSize() { return algorithmStats == null ? 0 : algorithmStats.size(); } + public int getModelStatSize() { + return modelStats == null ? 0 : modelStats.size(); + } + public boolean hasAlgorithmStats(FunctionName algorithm) { - return algorithmStats == null ? false : algorithmStats.containsKey(algorithm); + return algorithmStats != null && algorithmStats.containsKey(algorithm); + } + + public boolean hasModelStats(String modelId) { + return modelStats != null && modelStats.containsKey(modelId); } public MLAlgoStats getAlgorithmStats(FunctionName algorithm) { return algorithmStats == null ? null : algorithmStats.get(algorithm); } + public MLModelStats getModelStats(String modelId) { + return modelStats == null ? null : modelStats.get(modelId); + } + public void removeAlgorithmStats(FunctionName algorithm) { if (algorithmStats != null) { algorithmStats.remove(algorithm); } } + + public void removeModelStats(String modelId) { + if (modelStats != null) { + modelStats.remove(modelId); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java index 0ee31066b4..769964aee3 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java @@ -60,7 +60,7 @@ public List readNodesFrom(StreamInput in) throws IOExceptio public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { String nodeId; DiscoveryNode node; - builder.startObject("nodes"); + builder.startObject(NODES_KEY); for (MLStatsNodeResponse mlStats : getNodes()) { node = mlStats.getNode(); nodeId = node.getId(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java index f1809de9bb..f585a62edb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java @@ -21,6 +21,7 @@ import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionStats; import org.opensearch.ml.stats.MLAlgoStats; +import org.opensearch.ml.stats.MLModelStats; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStatLevel; import org.opensearch.ml.stats.MLStats; @@ -125,6 +126,22 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe } } - return new MLStatsNodeResponse(clusterService.localNode(), statValues, algorithmStats); + Map modelStats = new HashMap<>(); + // return model level stats + if (mlStatsInput.includeModelStats()) { + for (String modelId : mlStats.getAllModels()) { + if (mlStatsInput.retrieveStatsForModel(modelId)) { + Map actionStatsMap = new HashMap<>(); + for (Map.Entry entry : mlStats.getModelStats(modelId).entrySet()) { + if (mlStatsInput.retrieveStatsForAction(entry.getKey())) { + actionStatsMap.put(entry.getKey(), entry.getValue()); + } + } + modelStats.put(modelId, new MLModelStats(actionStatsMap)); + } + } + } + + return new MLStatsNodeResponse(clusterService.localNode(), statValues, algorithmStats, modelStats); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index ff002de87e..759b0cec9f 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -865,6 +865,7 @@ public void deployModel( mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { if (workerNodes != null && workerNodes.size() > 0) { @@ -1210,6 +1211,7 @@ public synchronized Map undeployModel(String[] modelIds) { mlStats .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT) .increment(); + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment(); } else { modelUndeployStatus.put(modelId, NOT_FOUND); } @@ -1221,6 +1223,7 @@ public synchronized Map undeployModel(String[] modelIds) { modelUndeployStatus.put(modelId, UNDEPLOYED); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement(); mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment(); removeModel(modelId); } } diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLModelStats.java b/plugin/src/main/java/org/opensearch/ml/stats/MLModelStats.java new file mode 100644 index 0000000000..df3ceccd6a --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLModelStats.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.stats; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +public class MLModelStats implements ToXContentFragment, Writeable { + + /** + * Model stats. + * Key: Model Id. + * Value: MLActionStats which contains action stat/value map. + * + * Example: {predict: { request_count: 1}} + */ + private Map modelStats; + + public MLModelStats(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.modelStats = in.readMap(stream -> stream.readEnum(ActionName.class), MLActionStats::new); + } + } + + public MLModelStats(Map modelStats) { + this.modelStats = modelStats; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (modelStats != null && modelStats.size() > 0) { + out.writeBoolean(true); + out.writeMap(modelStats, (stream, v) -> stream.writeEnum(v), (stream, stats) -> stats.writeTo(stream)); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (modelStats != null && modelStats.size() > 0) { + for (Map.Entry entry : modelStats.entrySet()) { + builder.startObject(entry.getKey().name().toLowerCase(Locale.ROOT)); + entry.getValue().toXContent(builder, params); + builder.endObject(); + } + } + return builder; + } + + public MLActionStats getActionStats(ActionName action) { + return modelStats == null ? null : modelStats.get(action); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLStatLevel.java b/plugin/src/main/java/org/opensearch/ml/stats/MLStatLevel.java index 4ded531c51..b2bff0e301 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLStatLevel.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLStatLevel.java @@ -9,6 +9,7 @@ public enum MLStatLevel { CLUSTER, NODE, ALGORITHM, + MODEL, ACTION; public static MLStatLevel from(String value) { diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java b/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java index b743c58434..b16d7c00b0 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java @@ -22,6 +22,7 @@ public class MLStats { @Getter private Map> stats; private Map>> algoStats;// {"kmeans":{"train":{"request_count":10}}} + private Map>> modelStats;// {"model_id":{"train":{"request_count":10}}} /** * Constructor @@ -31,6 +32,7 @@ public class MLStats { public MLStats(Map> stats) { this.stats = stats; this.algoStats = new ConcurrentHashMap<>(); + this.modelStats = new ConcurrentHashMap<>(); } /** @@ -62,6 +64,12 @@ public MLStat createCounterStatIfAbsent(FunctionName algoName, ActionName act return createAlgoStatIfAbsent(algoActionStats, stat, () -> new MLStat<>(false, new CounterSupplier())); } + public MLStat createModelCounterStatIfAbsent(String modelId, ActionName action, MLActionLevelStat stat) { + Map> actionStats = modelStats.computeIfAbsent(modelId, it -> new ConcurrentHashMap<>()); + Map algoActionStats = actionStats.computeIfAbsent(action, it -> new ConcurrentHashMap<>()); + return createAlgoStatIfAbsent(algoActionStats, stat, () -> new MLStat<>(false, new CounterSupplier())); + } + public synchronized MLStat createAlgoStatIfAbsent( Map algoActionStats, MLActionLevelStat key, @@ -130,7 +138,27 @@ public Map getAlgorithmStats(FunctionName algoName) { return algoActionStats; } + public Map getModelStats(String modelId) { + if (!modelStats.containsKey(modelId)) { + return null; + } + Map modelActionStats = new HashMap<>(); + + for (Map.Entry> entry : modelStats.get(modelId).entrySet()) { + Map statsMap = new HashMap<>(); + for (Map.Entry state : entry.getValue().entrySet()) { + statsMap.put(state.getKey(), state.getValue().getValue()); + } + modelActionStats.put(entry.getKey(), new MLActionStats(statsMap)); + } + return modelActionStats; + } + public FunctionName[] getAllAlgorithms() { return algoStats.keySet().toArray(new FunctionName[0]); } + + public String[] getAllModels() { + return modelStats.keySet().toArray(new String[0]); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLStatsInput.java b/plugin/src/main/java/org/opensearch/ml/stats/MLStatsInput.java index 01ee7f61cc..151f53658a 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLStatsInput.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLStatsInput.java @@ -35,6 +35,7 @@ public class MLStatsInput implements ToXContentObject, Writeable { public static final String ACTION_LEVEL_STATS = "action_level_stats"; public static final String NODE_IDS = "node_ids"; public static final String ALGORITHMS = "algorithms"; + public static final String MODELS = "models"; public static final String ACTIONS = "actions"; /** @@ -62,6 +63,11 @@ public class MLStatsInput implements ToXContentObject, Writeable { * Which algorithm's stats will be retrieved. */ private EnumSet algorithms; + /** + * Which model's stats will be retrieved. + */ + private Set models; + /** * Which action's stats will be retrieved. */ @@ -75,6 +81,7 @@ public class MLStatsInput implements ToXContentObject, Writeable { * @param actionLevelStats action level stats which will be retrieved * @param nodeIds retrieve stats on these nodes * @param algorithms retrieve stats for which algorithms + * @param models retrieve stats for which models * @param actions retrieve stats for which actions */ @Builder @@ -85,6 +92,7 @@ public MLStatsInput( EnumSet actionLevelStats, Set nodeIds, EnumSet algorithms, + Set models, EnumSet actions ) { this.targetStatLevels = targetStatLevels; @@ -93,6 +101,7 @@ public MLStatsInput( this.actionLevelStats = actionLevelStats; this.nodeIds = nodeIds; this.algorithms = algorithms; + this.models = models; this.actions = actions; } @@ -103,6 +112,7 @@ public MLStatsInput() { this.actionLevelStats = EnumSet.noneOf(MLActionLevelStat.class); this.nodeIds = new HashSet<>(); this.algorithms = EnumSet.noneOf(FunctionName.class); + this.models = new HashSet<>(); this.actions = EnumSet.noneOf(ActionName.class); } @@ -112,6 +122,7 @@ public MLStatsInput(StreamInput input) throws IOException { nodeLevelStats = input.readBoolean() ? input.readEnumSet(MLNodeLevelStat.class) : EnumSet.noneOf(MLNodeLevelStat.class); actionLevelStats = input.readBoolean() ? input.readEnumSet(MLActionLevelStat.class) : EnumSet.noneOf(MLActionLevelStat.class); nodeIds = input.readBoolean() ? new HashSet<>(input.readStringList()) : new HashSet<>(); + models = input.readBoolean() ? new HashSet<>(input.readStringList()) : new HashSet<>(); algorithms = input.readBoolean() ? input.readEnumSet(FunctionName.class) : EnumSet.noneOf(FunctionName.class); actions = input.readBoolean() ? input.readEnumSet(ActionName.class) : EnumSet.noneOf(ActionName.class); } @@ -123,6 +134,7 @@ public void writeTo(StreamOutput out) throws IOException { writeEnumSet(out, nodeLevelStats); writeEnumSet(out, actionLevelStats); out.writeOptionalStringCollection(nodeIds); + out.writeOptionalStringCollection(models); writeEnumSet(out, algorithms); writeEnumSet(out, actions); } @@ -142,6 +154,7 @@ public static MLStatsInput parse(XContentParser parser) throws IOException { EnumSet nodeLevelStats = EnumSet.noneOf(MLNodeLevelStat.class); EnumSet actionLevelStats = EnumSet.noneOf(MLActionLevelStat.class); Set nodeIds = new HashSet<>(); + Set models = new HashSet<>(); EnumSet algorithms = EnumSet.noneOf(FunctionName.class); EnumSet actions = EnumSet.noneOf(ActionName.class); @@ -184,6 +197,9 @@ public static MLStatsInput parse(XContentParser parser) throws IOException { case ALGORITHMS: parseField(parser, algorithms, input -> FunctionName.from(input.toUpperCase(Locale.ROOT)), FunctionName.class); break; + case MODELS: + parseArrayField(parser, models); + break; case ACTIONS: parseField(parser, actions, input -> ActionName.from(input.toUpperCase(Locale.ROOT)), ActionName.class); break; @@ -200,6 +216,7 @@ public static MLStatsInput parse(XContentParser parser) throws IOException { .actionLevelStats(actionLevelStats) .nodeIds(nodeIds) .algorithms(algorithms) + .models(models) .actions(actions) .build(); } @@ -225,6 +242,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (algorithms != null) { builder.field(ALGORITHMS, algorithms); } + if (models != null) { + builder.field(MODELS, models); + } if (actions != null) { builder.field(ACTIONS, actions); } @@ -252,10 +272,18 @@ public boolean retrieveStatsForAllAlgos() { return algorithms == null || algorithms.size() == 0; } + public boolean retrieveStatsForAllModels() { + return models == null || models.size() == 0; + } + public boolean retrieveStatsForAlgo(FunctionName algoName) { return retrieveStatsForAllAlgos() || algorithms.contains(algoName); } + public boolean retrieveStatsForModel(String modelId) { + return retrieveStatsForAllModels() || models.contains(modelId); + } + public boolean retrieveStatsForAction(ActionName actionName) { return retrieveStatsForAllActions() || actions.contains(actionName); } @@ -283,10 +311,15 @@ public boolean onlyRetrieveClusterLevelStats() { } return !targetStatLevels.contains(MLStatLevel.NODE) && !targetStatLevels.contains(MLStatLevel.ALGORITHM) + && !targetStatLevels.contains(MLStatLevel.MODEL) && !targetStatLevels.contains(MLStatLevel.ACTION); } public boolean includeAlgoStats() { return targetStatLevels.contains(MLStatLevel.ALGORITHM) || targetStatLevels.contains(MLStatLevel.ACTION); } + + public boolean includeModelStats() { + return targetStatLevels.contains(MLStatLevel.MODEL) || targetStatLevels.contains(MLStatLevel.ACTION); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 22fed1ff34..e6b6be2c62 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -206,6 +206,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); + if (modelId != null) { + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); + } mlTask.setState(MLTaskState.RUNNING); mlTaskManager.add(mlTask); @@ -232,7 +235,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe throw new IllegalArgumentException("Model not ready to be used: " + modelId); } } catch (Exception e) { - handlePredictFailure(mlTask, internalListener, e, false); + handlePredictFailure(mlTask, internalListener, e, false, modelId); } // search model by model id. @@ -258,7 +261,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe OpenSearchException e = new OpenSearchException( "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + modelId ); - handlePredictFailure(mlTask, internalListener, e, false); + handlePredictFailure(mlTask, internalListener, e, false, modelId); return; } // run predict @@ -279,18 +282,18 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe }, e -> { log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e); - handlePredictFailure(mlTask, internalListener, e, true); + handlePredictFailure(mlTask, internalListener, e, true, modelId); }); GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId()); client.get(getRequest, threadedActionListener(ActionListener.runBefore(getModelListener, () -> context.restore()))); } catch (Exception e) { log.error("Failed to get model " + mlTask.getModelId(), e); - handlePredictFailure(mlTask, internalListener, e, true); + handlePredictFailure(mlTask, internalListener, e, true, modelId); } } else { IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid"); log.error("ModelId is invalid", e); - handlePredictFailure(mlTask, internalListener, e, false); + handlePredictFailure(mlTask, internalListener, e, false, modelId); } } @@ -298,11 +301,18 @@ private ThreadedActionListener threadedActionListener(ActionListener l return new ThreadedActionListener<>(log, threadPool, PREDICT_THREAD_POOL, listener, false); } - private void handlePredictFailure(MLTask mlTask, ActionListener listener, Exception e, boolean trackFailure) { + private void handlePredictFailure( + MLTask mlTask, + ActionListener listener, + Exception e, + boolean trackFailure, + String modelId + ) { if (trackFailure) { mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT); mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index ca436e72f7..711a94171f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -119,6 +119,9 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener internalListener = ActionListener.wrap(res -> { String modelId = ((MLTrainingOutput) res.getOutput()).getModelId(); + mlStats + .createModelCounterStatIfAbsent(modelId, ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) + .increment(); log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId); mlTask.setModelId(modelId); handleAsyncMLTaskComplete(mlTask); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java index b475b033ff..6397198f0d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java @@ -23,6 +23,7 @@ import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLActionStats; import org.opensearch.ml.stats.MLAlgoStats; +import org.opensearch.ml.stats.MLModelStats; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.utils.TestHelper; import org.opensearch.test.OpenSearchTestCase; @@ -33,7 +34,9 @@ public class MLStatsNodeResponseTests extends OpenSearchTestCase { private MLStatsNodeResponse response; private DiscoveryNode node; - private long totalRequestCount = 100l; + private final long totalRequestCount = 100l; + + private final String modelId = "model_id"; @Before public void setup() { @@ -63,39 +66,48 @@ public void testToXContent_NodeLevelStats() throws IOException { assertEquals("{\"ml_request_count\":100}", taskContent); } - public void testToXContent_AlgorithmStats() throws IOException { + public void testToXContent_AlgorithmAndModelStats() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); builder.startObject(); - MLStatsNodeResponse response = createResponseWithDefaultAlgoStats(null); + MLStatsNodeResponse response = createResponseWithDefaultAlgoAndModelStats(null); response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":100}}}}", taskContent); + assertEquals( + "{\"algorithms\":{\"kmeans\":{\"predict\":{\"ml_action_request_count\":100}}},\"models\":{\"model_id\":{\"predict\":{\"ml_action_request_count\":100}}}}", + taskContent + ); } public void testWriteTo_AlgoStats() throws IOException { - MLStatsNodeResponse response = createResponseWithDefaultAlgoStats(null); + MLStatsNodeResponse response = createResponseWithDefaultAlgoAndModelStats(null); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLStatsNodeResponse newResponse = new MLStatsNodeResponse(output.bytes().streamInput()); assertEquals(0, newResponse.getNodeLevelStatSize()); assertEquals(1, newResponse.getAlgorithmStatSize()); assertTrue(newResponse.hasAlgorithmStats(FunctionName.KMEANS)); - MLActionStats stats = newResponse.getAlgorithmStats(FunctionName.KMEANS).getActionStats(ActionName.TRAIN); + assertTrue(newResponse.hasModelStats(modelId)); + MLActionStats stats = newResponse.getAlgorithmStats(FunctionName.KMEANS).getActionStats(ActionName.PREDICT); + MLActionStats mlStats = newResponse.getModelStats(modelId).getActionStats(ActionName.PREDICT); assertEquals(totalRequestCount, stats.getActionStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); + assertEquals(totalRequestCount, mlStats.getActionStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); } - private MLStatsNodeResponse createResponseWithDefaultAlgoStats(Map nodeStats) { + private MLStatsNodeResponse createResponseWithDefaultAlgoAndModelStats(Map nodeStats) { Map algoStats = new HashMap<>(); Map actionStats = ImmutableMap.of(MLActionLevelStat.ML_ACTION_REQUEST_COUNT, totalRequestCount); - Map stats = ImmutableMap.of(ActionName.TRAIN, new MLActionStats(actionStats)); + Map stats = ImmutableMap.of(ActionName.PREDICT, new MLActionStats(actionStats)); algoStats.put(FunctionName.KMEANS, new MLAlgoStats(stats)); - MLStatsNodeResponse response = new MLStatsNodeResponse(node, nodeStats, algoStats); + Map modelStats = new HashMap<>(); + modelStats.put(modelId, new MLModelStats(stats)); + + MLStatsNodeResponse response = new MLStatsNodeResponse(node, nodeStats, algoStats, modelStats); return response; } - public void testToXContent_WithAlgoStats() throws IOException { + public void testToXContent_WithAlgoAndModelStats() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); builder.startObject(); DiscoveryNode node = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); @@ -108,14 +120,23 @@ public void testToXContent_WithAlgoStats() throws IOException { algoActionStatMap.put(MLActionLevelStat.ML_ACTION_FAILURE_COUNT, 22); algoActionStats.put(ActionName.TRAIN, new MLActionStats(algoActionStatMap)); algoStats.put(FunctionName.KMEANS, new MLAlgoStats(algoActionStats)); - response = new MLStatsNodeResponse(node, statsToValues, algoStats); + + Map modelStats = new HashMap<>(); + Map modelActionStats = new HashMap<>(); + Map modelActionStatMap = new HashMap<>(); + modelActionStatMap.put(MLActionLevelStat.ML_ACTION_REQUEST_COUNT, 111); + modelActionStatMap.put(MLActionLevelStat.ML_ACTION_FAILURE_COUNT, 22); + modelActionStats.put(ActionName.PREDICT, new MLActionStats(modelActionStatMap)); + modelStats.put(modelId, new MLModelStats(modelActionStats)); + + response = new MLStatsNodeResponse(node, statsToValues, algoStats, modelStats); response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); Set validResult = ImmutableSet .of( - "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", - "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}},\"models\":{\"model_id\":{\"predict\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}},\"models\":{\"model_id\":{\"predict\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" ); assertTrue(validResult.contains(taskContent)); } @@ -129,12 +150,12 @@ public void testReadStats() throws IOException { } public void testIsEmpty_NullNodeStats() { - MLStatsNodeResponse response = createResponseWithDefaultAlgoStats(null); + MLStatsNodeResponse response = createResponseWithDefaultAlgoAndModelStats(null); assertFalse(response.isEmpty()); } public void testIsEmpty_EmptyNodeStats() { - MLStatsNodeResponse response = createResponseWithDefaultAlgoStats(ImmutableMap.of()); + MLStatsNodeResponse response = createResponseWithDefaultAlgoAndModelStats(ImmutableMap.of()); assertFalse(response.isEmpty()); } @@ -142,14 +163,18 @@ public void testIsEmpty_NullAlgoStats() { assertFalse(response.isEmpty()); } - public void testIsEmpty_EmptyAlgoStats() { - MLStatsNodeResponse response = createResponseWithDefaultAlgoStats(ImmutableMap.of()); + public void testIsEmpty_EmptyAlgoAndModelStats() { + MLStatsNodeResponse response = createResponseWithDefaultAlgoAndModelStats(ImmutableMap.of()); + assertEquals(1, response.getAlgorithmStatSize()); + assertEquals(1, response.getModelStatSize()); response.removeAlgorithmStats(FunctionName.KMEANS); - assertTrue(response.isEmpty()); + response.removeModelStats(modelId); + assertEquals(0, response.getAlgorithmStatSize()); + assertEquals(0, response.getModelStatSize()); } public void testIsEmpty_NonEmptyNodeAndAlgoStats() { - MLStatsNodeResponse response = createResponseWithDefaultAlgoStats( + MLStatsNodeResponse response = createResponseWithDefaultAlgoAndModelStats( ImmutableMap.of(MLNodeLevelStat.ML_REQUEST_COUNT, totalRequestCount) ); assertFalse(response.isEmpty()); @@ -176,13 +201,13 @@ public void testGetAlgorithmLevelStat_NullAlgoStats() { } public void testGetAlgorithmLevelStat_EmptyAlgoStats() { - MLStatsNodeResponse response = new MLStatsNodeResponse(node, null, ImmutableMap.of()); + MLStatsNodeResponse response = new MLStatsNodeResponse(node, null, ImmutableMap.of(), ImmutableMap.of()); assertNull(response.getAlgorithmStats(FunctionName.BATCH_RCF)); assertEquals(0, response.getNodeLevelStatSize()); } public void testGetAlgorithmLevelStat_NonExistingAlgo() { - MLStatsNodeResponse response = createResponseWithDefaultAlgoStats(null); + MLStatsNodeResponse response = createResponseWithDefaultAlgoAndModelStats(null); assertEquals(0, response.getNodeLevelStatSize()); assertEquals(1, response.getAlgorithmStatSize()); assertNotNull(response.getAlgorithmStats(FunctionName.KMEANS)); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java index b54c9ca83a..35249f8ecf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java @@ -30,6 +30,7 @@ import org.opensearch.ml.stats.MLActionStats; import org.opensearch.ml.stats.MLAlgoStats; import org.opensearch.ml.stats.MLClusterLevelStat; +import org.opensearch.ml.stats.MLModelStats; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStatLevel; @@ -50,6 +51,8 @@ public class MLStatsNodesTransportActionTests extends OpenSearchIntegTestCase { private MLNodeLevelStat nodeStatName1; private Environment environment; + private final String modelId = "model_id"; + @Override @Before public void setUp() throws Exception { @@ -100,6 +103,7 @@ public void testNewNodeResponse() throws IOException { StreamInput in = out.bytes().streamInput(); MLStatsNodeResponse newStatsNodeResponse = action.newNodeResponse(in); Assert.assertEquals(statsNodeResponse.getNodeLevelStatSize(), newStatsNodeResponse.getAlgorithmStatSize()); + Assert.assertEquals(statsNodeResponse.getNodeLevelStatSize(), newStatsNodeResponse.getModelStatSize()); } public void testNodeOperation() { @@ -131,7 +135,7 @@ public void testNodeOperationWithJvmHeapUsage() { public void testNodeOperation_NoNodeLevelStat() { String nodeId = clusterService().localNode().getId(); - MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM)).build(); + MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM, MLStatLevel.MODEL)).build(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, mlStatsInput); MLStatsNodeResponse response = action.nodeOperation(new MLStatsNodeRequest(mlStatsNodesRequest)); @@ -142,6 +146,7 @@ public void testNodeOperation_NoNodeLevelStat() { public void testNodeOperation_NoNodeLevelStat_AlgoStat() { MLStats mlStats = new MLStats(statsMap); mlStats.createCounterStatIfAbsent(FunctionName.KMEANS, ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); MLStatsNodesTransportAction action = new MLStatsNodesTransportAction( client().threadPool(), @@ -153,18 +158,25 @@ public void testNodeOperation_NoNodeLevelStat_AlgoStat() { ); String nodeId = clusterService().localNode().getId(); - MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM)).build(); + MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM, MLStatLevel.MODEL)).build(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, mlStatsInput); MLStatsNodeResponse response = action.nodeOperation(new MLStatsNodeRequest(mlStatsNodesRequest)); assertEquals(0, response.getNodeLevelStatSize()); assertEquals(1, response.getAlgorithmStatSize()); + assertEquals(1, response.getModelStatSize()); MLAlgoStats algorithmStats = response.getAlgorithmStats(FunctionName.KMEANS); assertNotNull(algorithmStats); MLActionStats actionStats = algorithmStats.getActionStats(ActionName.TRAIN); assertNotNull(actionStats); assertEquals(1l, actionStats.getActionStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); + + MLModelStats modelStats = response.getModelStats(modelId); + assertNotNull(modelStats); + actionStats = modelStats.getActionStats(ActionName.PREDICT); + assertNotNull(actionStats); + assertEquals(1l, actionStats.getActionStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index d67f84de25..525bc95eda 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -837,6 +837,8 @@ private String[] createTempChunkFiles() throws IOException { public void testRegisterModelMeta() { setupForModelMeta(); + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + mock_client_index(client, modelId); MLRegisterModelMetaInput registerModelMetaInput = prepareRequest(); modelManager.registerModelMeta(registerModelMetaInput, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index d0fae76be9..78c628cb19 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -61,6 +61,7 @@ import org.opensearch.ml.stats.MLActionStats; import org.opensearch.ml.stats.MLAlgoStats; import org.opensearch.ml.stats.MLClusterLevelStat; +import org.opensearch.ml.stats.MLModelStats; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStatLevel; @@ -108,6 +109,8 @@ public class RestMLStatsActionTests extends OpenSearchTestCase { long nodeTotalRequestCount = 100; long kmeansTrainRequestCount = 20; + String modelId = "model_id"; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -209,7 +212,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates() throws Exception { content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}},\"models\":{\"model_id\":{\"train\":{\"ml_action_request_count\":20}}}}}" ) ); } @@ -226,7 +229,12 @@ private void prepareResponse() { new MLActionStats(ImmutableMap.of(MLActionLevelStat.ML_ACTION_REQUEST_COUNT, kmeansTrainRequestCount)) ); algoStats.put(FunctionName.KMEANS, new MLAlgoStats(actionStats)); - MLStatsNodeResponse nodeResponse = new MLStatsNodeResponse(node, nodeStats, algoStats); + + Map modelStats = new HashMap<>(); + + modelStats.put(modelId, new MLModelStats(actionStats)); + + MLStatsNodeResponse nodeResponse = new MLStatsNodeResponse(node, nodeStats, algoStats, modelStats); nodes.add(nodeResponse); MLStatsNodesResponse statsResponse = new MLStatsNodesResponse(clusterName, nodes, ImmutableList.of()); actionListener.onResponse(statsResponse); @@ -299,7 +307,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_NoRequestContent() thro content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}},\"models\":{\"model_id\":{\"train\":{\"ml_action_request_count\":20}}}}}" ) ); } @@ -334,7 +342,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}},\"models\":{\"model_id\":{\"train\":{\"ml_action_request_count\":20}}}}}" ) ); } @@ -360,7 +368,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); assertEquals( - "{\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", + "{\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}},\"models\":{\"model_id\":{\"train\":{\"ml_action_request_count\":20}}}}}}", content.utf8ToString() ); } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLModelStatsTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLModelStatsTests.java new file mode 100644 index 0000000000..8be4b56f05 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLModelStatsTests.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.stats; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_FAILURE_COUNT; +import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableSet; + +public class MLModelStatsTests extends OpenSearchTestCase { + private MLModelStats mlModelStats; + private MLActionStats mlActionStats; + private long requestCount = 200; + private long failureCount = 100; + + @Before + public void setup() { + Map modelActionStats = new HashMap<>(); + modelActionStats.put(ML_ACTION_REQUEST_COUNT, requestCount); + modelActionStats.put(ML_ACTION_FAILURE_COUNT, failureCount); + mlActionStats = new MLActionStats(modelActionStats); + + Map modelStats = new HashMap<>(); + modelStats.put(ActionName.PREDICT, mlActionStats); + mlModelStats = new MLModelStats(modelStats); + } + + public void testSerializationDeserialization() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + mlModelStats.writeTo(output); + MLModelStats parsedMLModelStats = new MLModelStats(output.bytes().streamInput()); + MLActionStats parsedMLActionStats = parsedMLModelStats.getActionStats(ActionName.PREDICT); + assertEquals(2, parsedMLActionStats.getActionStatSize()); + assertEquals(requestCount, parsedMLActionStats.getActionStat(ML_ACTION_REQUEST_COUNT)); + assertEquals(failureCount, parsedMLActionStats.getActionStat(ML_ACTION_FAILURE_COUNT)); + } + + public void testEmptySerializationDeserialization() throws IOException { + + Map modelStats = new HashMap<>(); + MLModelStats mlModelEmptyStats = new MLModelStats(modelStats); + BytesStreamOutput output = new BytesStreamOutput(); + mlModelEmptyStats.writeTo(output); + MLModelStats parsedMLModelStats = new MLModelStats(output.bytes().streamInput()); + MLActionStats parsedMLActionStats = parsedMLModelStats.getActionStats(ActionName.PREDICT); + assertNull(parsedMLActionStats); + // assertEquals(0, output.bytes().length()); + } + + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + mlModelStats.toXContent(builder, EMPTY_PARAMS); + builder.endObject(); + String content = TestHelper.xContentBuilderToString(builder); + Set validContents = ImmutableSet + .of( + "{\"predict\":{\"ml_action_request_count\":200,\"ml_action_failure_count\":100}}", + "{\"predict\":{\"ml_action_failure_count\":100,\"ml_action_request_count\":200}}" + ); + assertTrue(validContents.contains(content)); + } + + public void testToXContent_EmptyStats() throws IOException { + Map statMap = new HashMap<>(); + MLAlgoStats stats = new MLAlgoStats(statMap); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + stats.toXContent(builder, EMPTY_PARAMS); + builder.endObject(); + String content = TestHelper.xContentBuilderToString(builder); + assertEquals("{}", content); + } + + public void testToXContent_NullStats() throws IOException { + Map statMap = null; + MLAlgoStats stats = new MLAlgoStats(statMap); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + stats.toXContent(builder, EMPTY_PARAMS); + builder.endObject(); + String content = TestHelper.xContentBuilderToString(builder); + assertEquals("{}", content); + } + + public void testGetActionStats() { + assertNotNull(mlModelStats.getActionStats(ActionName.PREDICT)); + + // null stats + Map statMap = null; + MLModelStats stats = new MLModelStats(statMap); + assertNull(stats.getActionStats(ActionName.PREDICT)); + + // empty stats + stats = new MLModelStats(new HashMap<>()); + assertNull(stats.getActionStats(ActionName.PREDICT)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java index ddad31f58d..8ac5494bc1 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java @@ -27,6 +27,8 @@ public class MLStatsInputTests extends OpenSearchTestCase { private String node1 = "node1"; private String node2 = "node2"; + private String modelId = "model_id"; + @Before public void setup() { mlStatsInput = MLStatsInput @@ -37,6 +39,7 @@ public void setup() { .actionLevelStats(EnumSet.allOf(MLActionLevelStat.class)) .nodeIds(ImmutableSet.of(node1, node2)) .algorithms(EnumSet.allOf(FunctionName.class)) + .models(ImmutableSet.of(modelId)) .actions(EnumSet.allOf(ActionName.class)) .build(); } @@ -59,6 +62,7 @@ public void testParseMLStatsInput() throws IOException { public void testRetrieveAll() { assertFalse(mlStatsInput.retrieveStatsForAllAlgos()); + assertFalse(mlStatsInput.retrieveStatsForAllModels()); assertFalse(mlStatsInput.retrieveAllNodeLevelStats()); assertFalse(mlStatsInput.retrieveStatsForAllActions()); assertFalse(mlStatsInput.retrieveAllClusterLevelStats()); @@ -66,6 +70,7 @@ public void testRetrieveAll() { assertFalse(mlStatsInput.retrieveAllActionLevelStats()); MLStatsInput mlStatsInput = MLStatsInput.builder().build(); + assertTrue(mlStatsInput.retrieveStatsForAllModels()); assertTrue(mlStatsInput.retrieveStatsForAllAlgos()); assertTrue(mlStatsInput.retrieveAllNodeLevelStats()); assertTrue(mlStatsInput.retrieveStatsForAllActions()); @@ -75,6 +80,7 @@ public void testRetrieveAll() { mlStatsInput = new MLStatsInput(); assertTrue(mlStatsInput.retrieveStatsForAllAlgos()); + assertTrue(mlStatsInput.retrieveStatsForAllModels()); assertTrue(mlStatsInput.retrieveAllNodeLevelStats()); assertTrue(mlStatsInput.retrieveStatsForAllActions()); assertTrue(mlStatsInput.retrieveAllClusterLevelStats()); @@ -123,6 +129,9 @@ public void testOnlyRetrieveClusterLevelStats() { mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM)).build(); assertFalse(mlStatsInput.onlyRetrieveClusterLevelStats()); + mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.MODEL)).build(); + assertFalse(mlStatsInput.onlyRetrieveClusterLevelStats()); + mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ACTION)).build(); assertFalse(mlStatsInput.onlyRetrieveClusterLevelStats()); } @@ -148,6 +157,7 @@ private void verifyParsedMLStatsInput(MLStatsInput parsedMLStatsInput) { mlStatsInput.getAlgorithms().toArray(new FunctionName[0]), parsedMLStatsInput.getAlgorithms().toArray(new FunctionName[0]) ); + assertArrayEquals(mlStatsInput.getModels().toArray(new String[0]), parsedMLStatsInput.getModels().toArray(new String[0])); assertArrayEquals(mlStatsInput.getActions().toArray(new ActionName[0]), parsedMLStatsInput.getActions().toArray(new ActionName[0])); assertEquals(2, parsedMLStatsInput.getNodeIds().size()); assertTrue(parsedMLStatsInput.getNodeIds().contains(node1)); diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java index 51bde6bd3a..10e05639dd 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java @@ -26,6 +26,8 @@ public class MLStatsTests extends OpenSearchTestCase { private MLClusterLevelStat clusterStatName1; private MLNodeLevelStat nodeStatName1; + private String modelID = "model_id"; + @Rule public ExpectedException expectedEx = ExpectedException.none(); @@ -113,6 +115,11 @@ public void testGetAlgorithmStats_Empty() { assertNull(algorithmStats); } + public void testGetModelStats_Empty() { + Map modelStats = mlStats.getModelStats(modelID); + assertNull(modelStats); + } + public void testGetAlgorithmStats() { MLStats stats = new MLStats(statsMap); MLStat statCounter = stats.createCounterStatIfAbsent(FunctionName.KMEANS, ActionName.TRAIN, ML_ACTION_REQUEST_COUNT); @@ -122,11 +129,25 @@ public void testGetAlgorithmStats() { assertEquals(1l, algorithmStats.get(ActionName.TRAIN).getActionStat(ML_ACTION_REQUEST_COUNT)); } + public void testGetModelStats() { + MLStats stats = new MLStats(statsMap); + MLStat statCounter = stats.createModelCounterStatIfAbsent(modelID, ActionName.TRAIN, ML_ACTION_REQUEST_COUNT); + statCounter.increment(); + Map modelStats = stats.getModelStats(modelID); + assertNotNull(modelStats); + assertEquals(1l, modelStats.get(ActionName.TRAIN).getActionStat(ML_ACTION_REQUEST_COUNT)); + } + public void testGetAllAlgorithms_Empty() { FunctionName[] allAlgorithms = mlStats.getAllAlgorithms(); assertEquals(0, allAlgorithms.length); } + public void testGetAllModels_Empty() { + String[] allModels = mlStats.getAllModels(); + assertEquals(0, allModels.length); + } + public void testGetAllAlgorithms() { MLStats stats = new MLStats(statsMap); MLStat statCounter = stats.createCounterStatIfAbsent(FunctionName.KMEANS, ActionName.TRAIN, ML_ACTION_REQUEST_COUNT); @@ -134,4 +155,12 @@ public void testGetAllAlgorithms() { FunctionName[] allAlgorithms = stats.getAllAlgorithms(); assertArrayEquals(new FunctionName[] { FunctionName.KMEANS }, allAlgorithms); } + + public void testGetAllModels() { + MLStats stats = new MLStats(statsMap); + MLStat statCounter = stats.createModelCounterStatIfAbsent(modelID, ActionName.PREDICT, ML_ACTION_REQUEST_COUNT); + statCounter.increment(); + String[] allModels = stats.getAllModels(); + assertArrayEquals(new String[] { modelID }, allModels); + } }