From 5d9324c8d0e684ab486fc1f530d44a4fffafb1f6 Mon Sep 17 00:00:00 2001 From: Jackie Han Date: Mon, 23 Oct 2023 15:35:07 -0700 Subject: [PATCH] Expose execute api for MLClient (#1540) * Expose execute api for MLClient Signed-off-by: Jackie Han * unit test change Signed-off-by: Jackie Han --------- Signed-off-by: Jackie Han --- .../ml/client/MachineLearningClient.java | 22 +++++ .../ml/client/MachineLearningNodeClient.java | 17 ++++ .../ml/client/MachineLearningClientTest.java | 66 +++++++++++++ .../client/MachineLearningNodeClientTest.java | 93 +++++++++++++++++++ 4 files changed, 198 insertions(+) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index a14db534c7..55c3178ad8 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -13,13 +13,16 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -299,4 +302,23 @@ default ActionFuture registerModelGroup(MLRegister * @param listener a listener to be notified of the result */ void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener listener); + + /** + * Execute an algorithm + * @param name algorithm function name + * @param input input + * @return the result future + */ + default ActionFuture execute(FunctionName name, Input input) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + execute(name, input, actionFuture); + return actionFuture; + } + + /** + * Execute an algorithm + * @param input an algorithm input + * @param listener a listener to be notified of the result + */ + void execute(FunctionName name, Input input, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 953c9350c5..e1fd6445a2 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -26,6 +26,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; @@ -37,6 +38,9 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -182,6 +186,19 @@ public void registerModelGroup( client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener); } + /** + * Execute an algorithm + * + * @param name function name + * @param input an algorithm input + * @param listener a listener to be notified of the result + */ + @Override + public void execute(FunctionName name, Input input, ActionListener listener) { + MLExecuteTaskRequest mlExecuteTaskRequest = new MLExecuteTaskRequest(name, input); + client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, listener); + } + @Override public void getTask(String taskId, ActionListener listener) { MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build(); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 0032f4bf60..2dda397016 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.common.input.Constants.KMEANS; import static org.opensearch.ml.common.input.Constants.TRAIN; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -32,7 +33,9 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; @@ -42,6 +45,7 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -81,6 +85,9 @@ public class MachineLearningClientTest { @Mock MLRegisterModelGroupResponse registerModelGroupResponse; + @Mock + MLExecuteTaskResponse mlExecuteTaskResponse; + private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; @@ -161,6 +168,11 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio listener.onResponse(createConnectorResponse); } + @Override + public void execute(FunctionName name, Input input, ActionListener listener) { + listener.onResponse(mlExecuteTaskResponse); + } + public void registerModelGroup( MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener listener @@ -354,4 +366,58 @@ public void createConnector() { assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet()); } + + @Test + public void executeMetricsCorrelation() { + List inputData = new ArrayList<>( + Arrays + .asList( + new float[] { + 0.89451003f, + 4.2006273f, + 0.3697659f, + 2.2458954f, + -4.671612f, + -1.5076426f, + 1.635445f, + -1.1394824f, + -0.7503817f, + 0.98424894f, + -0.38896716f, + 1.0328646f, + 1.9543738f, + -0.5236269f, + 0.14298044f, + 3.2963762f, + 8.1641035f, + 5.717064f, + 7.4869685f, + 2.5987444f, + 11.018798f, + 9.151356f, + 5.7354255f, + 6.862203f, + 3.0524514f, + 4.431755f, + 5.1481285f, + 7.9548607f, + 7.4519925f, + 6.09533f, + 7.634116f, + 8.898271f, + 3.898491f, + 9.447067f, + 8.197385f, + 5.8284273f, + 5.804283f, + 7.089733f, + 9.140584f } + ) + ); + Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build(); + assertEquals( + mlExecuteTaskResponse, + machineLearningClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput).actionGet() + ); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index ccdf812195..591797b758 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.client; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -22,6 +23,7 @@ import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -58,13 +60,19 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.ml.common.output.Output; +import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor; +import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; +import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -73,6 +81,9 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -152,6 +163,9 @@ public class MachineLearningNodeClientTest { @Mock ActionListener registerModelGroupResponseActionListener; + @Mock + ActionListener executeTaskResponseActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -676,6 +690,85 @@ public void createConnector() { } + @Test + public void executeMetricsCorrelation() { + Output metricsCorrelationOutput; + List outputs = new ArrayList<>(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); + List mlModelTensors = Arrays.asList(mCorrModelTensor); + MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); + outputs.add(modelTensors); + metricsCorrelationOutput = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLExecuteTaskResponse output = new MLExecuteTaskResponse(FunctionName.METRICS_CORRELATION, metricsCorrelationOutput); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskResponse.class); + + List inputData = new ArrayList<>( + Arrays + .asList( + new float[] { + 0.89451003f, + 4.2006273f, + 0.3697659f, + 2.2458954f, + -4.671612f, + -1.5076426f, + 1.635445f, + -1.1394824f, + -0.7503817f, + 0.98424894f, + -0.38896716f, + 1.0328646f, + 1.9543738f, + -0.5236269f, + 0.14298044f, + 3.2963762f, + 8.1641035f, + 5.717064f, + 7.4869685f, + 2.5987444f, + 11.018798f, + 9.151356f, + 5.7354255f, + 6.862203f, + 3.0524514f, + 4.431755f, + 5.1481285f, + 7.9548607f, + 7.4519925f, + 6.09533f, + 7.634116f, + 8.898271f, + 3.898491f, + 9.447067f, + 8.197385f, + 5.8284273f, + 5.804283f, + 7.089733f, + 9.140584f } + ) + ); + Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build(); + + machineLearningNodeClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput, executeTaskResponseActionListener); + + verify(client).execute(eq(MLExecuteTaskAction.INSTANCE), isA(MLExecuteTaskRequest.class), any()); + verify(executeTaskResponseActionListener).onResponse(argumentCaptor.capture()); + assertEquals(FunctionName.METRICS_CORRELATION, argumentCaptor.getValue().getFunctionName()); + assertTrue(argumentCaptor.getValue().getOutput() instanceof MetricsCorrelationOutput); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);