From 555608028f3aa6fc067852c730f321b38682b1d1 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 8 Oct 2024 13:17:09 -0500 Subject: [PATCH] enhance batch job task management by adding default action types Signed-off-by: Bhavana Ramaram --- .../common/connector/AbstractConnector.java | 5 ++ .../ml/common/connector/Connector.java | 2 + .../ml/common/connector/ConnectorAction.java | 2 + .../ml/common/output/MLPredictionOutput.java | 24 +++++++ .../tasks/CancelBatchJobTransportAction.java | 66 ++++++++++++++++++ .../action/tasks/GetTaskTransportAction.java | 67 +++++++++++++++++++ .../ml/task/MLPredictTaskRunner.java | 10 +-- 7 files changed, 171 insertions(+), 5 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 4849f79c93..7eeaf25670 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -125,6 +125,11 @@ public Optional findAction(String action) { return Optional.empty(); } + @Override + public void setAction(ConnectorAction action) { + actions.add(action); + } + @Override public void removeCredential() { this.credential = null; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 0a37641144..5c33ea6898 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -65,6 +65,8 @@ public interface Connector extends ToXContentObject, Writeable { List getActions(); + void setAction(ConnectorAction action); + ConnectorClientConfig getConnectorClientConfig(); String getActionEndpoint(String action, Map parameters); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 4a7555d69b..e60f1bd33e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -23,6 +23,7 @@ import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.Setter; @Getter @EqualsAndHashCode @@ -36,6 +37,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function"; public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function"; + private ActionType actionType; private String method; private String url; diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java index 5675dab409..1b8eb8bd6c 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.output; import java.io.IOException; +import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -29,9 +30,11 @@ public class MLPredictionOutput extends MLOutput { public static final String TASK_ID_FIELD = "task_id"; public static final String STATUS_FIELD = "status"; public static final String PREDICTION_RESULT_FIELD = "prediction_result"; + public static final String REMOTE_JOB_FIELD = "remote_job"; String taskId; String status; + Map remoteJob; @ToString.Exclude DataFrame predictionResult; @@ -44,6 +47,14 @@ public MLPredictionOutput(String taskId, String status, DataFrame predictionResu this.predictionResult = predictionResult; } + @Builder + public MLPredictionOutput(String taskId, String status, Map remoteJob) { + super(OUTPUT_TYPE); + this.taskId = taskId; + this.status = status; + this.remoteJob = remoteJob; + } + public MLPredictionOutput(StreamInput in) throws IOException { super(OUTPUT_TYPE); this.taskId = in.readOptionalString(); @@ -56,6 +67,9 @@ public MLPredictionOutput(StreamInput in) throws IOException { break; } } + if (in.readBoolean()) { + this.remoteJob = in.readMap(s -> s.readString(), s -> s.readGenericValue()); + } } @Override @@ -69,6 +83,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (remoteJob != null) { + out.writeBoolean(true); + out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue); + } else { + out.writeBoolean(false); + } } @Override @@ -87,6 +107,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); } + if (remoteJob != null) { + builder.field(REMOTE_JOB_FIELD, remoteJob); + } + builder.endObject(); return builder; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 95e43ca929..6da202cc0a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -8,15 +8,18 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; import org.apache.hc.core5.http.HttpStatus; +import org.apache.commons.text.StringSubstitutor; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; @@ -38,6 +41,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; @@ -210,6 +214,11 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener actionListener) { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + Optional cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name()); + if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) { + ConnectorAction connectorAction = createConnectorAction(connector); + connector.setAction(connectorAction); + } connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); @@ -245,4 +254,61 @@ private void processTaskResponse(MLTaskResponse taskResponse, ActionListener batchPredictAction = connector.findAction(BATCH_PREDICT.name()); + + Map headers = batchPredictAction.get().getHeaders(); + + String predictEndpoint = batchPredictAction.get().getUrl(); + Map parameters = connector.getParameters() != null + ? new HashMap<>(connector.getParameters()) + : Collections.emptyMap(); + + if (!parameters.isEmpty()) { + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + predictEndpoint = substitutor.replace(predictEndpoint); + } + + String url = ""; + String requestBody = ""; + String method = "POST"; // Default method + + switch (getEndpointType(predictEndpoint)) { + case "sagemaker": + url = predictEndpoint.replace("CreateTransformJob", "StopTransformJob"); + requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}"; + break; + case "openai": + case "cohere": + url = predictEndpoint + "/${parameters.id}/cancel"; + break; + case "bedrock": + url = predictEndpoint + "/${parameters.processedJobArn}/stop"; + break; + } + + return ConnectorAction + .builder() + .actionType(CANCEL_BATCH_PREDICT) + .method(method) + .url(url) + .requestBody(requestBody) + .headers(headers) + .build(); + + } + + private String getEndpointType(String url) { + if (url.contains("sagemaker")) + return "sagemaker"; + if (url.contains("openai")) + return "openai"; + if (url.contains("bedrock")) + return "bedrock"; + if (url.contains("cohere")) + return "cohere"; + return ""; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index e2e9109cf2..a79432f977 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.common.MLTaskState.CANCELLING; import static org.opensearch.ml.common.MLTaskState.COMPLETED; import static org.opensearch.ml.common.MLTaskState.EXPIRED; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX; @@ -24,6 +25,7 @@ import static org.opensearch.ml.utils.MLExceptionUtils.logException; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,6 +34,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.apache.commons.text.StringSubstitutor; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; @@ -55,6 +58,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; @@ -279,6 +283,11 @@ private void executeConnector( ActionListener actionListener ) { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + Optional batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name()); + if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) { + ConnectorAction connectorAction = createConnectorAction(connector); + connector.setAction(connectorAction); + } connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential)); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); @@ -362,4 +371,62 @@ private boolean matchesPattern(Pattern pattern, String input) { Matcher matcher = pattern.matcher(input); return matcher.find(); } + + // TODO: move this method to connector utils class + private ConnectorAction createConnectorAction(Connector connector) { + Optional batchPredictAction = connector.findAction(BATCH_PREDICT.name()); + + Map headers = batchPredictAction.get().getHeaders(); + + String predictEndpoint = batchPredictAction.get().getUrl(); + Map parameters = connector.getParameters() != null + ? new HashMap<>(connector.getParameters()) + : Collections.emptyMap(); + + // Apply parameter substitution only if needed + if (!parameters.isEmpty()) { + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + predictEndpoint = substitutor.replace(predictEndpoint); + } + + String url = ""; + String requestBody = ""; + String method = "GET"; + + switch (getEndpointType(predictEndpoint)) { + case "sagemaker": + url = predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob"); + requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}"; + method = "POST"; + break; + case "openai": + case "cohere": + url = predictEndpoint + "/${parameters.id}"; + break; + case "bedrock": + url = predictEndpoint + "/${parameters.processedJobArn}"; + break; + } + return ConnectorAction + .builder() + .actionType(BATCH_PREDICT_STATUS) + .method(method) + .url(url) + .requestBody(requestBody) + .headers(headers) + .build(); + + } + + private String getEndpointType(String url) { + if (url.contains("sagemaker")) + return "sagemaker"; + if (url.contains("openai")) + return "openai"; + if (url.contains("bedrock")) + return "bedrock"; + if (url.contains("cohere")) + return "cohere"; + return ""; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 525ae12a88..a59d7bbe2b 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -371,11 +371,11 @@ private void runPredict( mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - MLPredictionOutput outputBuilder = MLPredictionOutput - .builder() - .taskId(taskId) - .status(MLTaskState.CREATED.name()) - .build(); + MLPredictionOutput outputBuilder = new MLPredictionOutput( + taskId, + MLTaskState.CREATED.name(), + remoteJob + ); MLTaskResponse predictOutput = MLTaskResponse.builder().output(outputBuilder).build(); internalListener.onResponse(predictOutput);