Skip to content

Commit

Permalink
enhance batch job task management by adding default action types
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Oct 8, 2024
1 parent 74c211e commit 5556080
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ public Optional<ConnectorAction> findAction(String action) {
return Optional.empty();
}

@Override
public void setAction(ConnectorAction action) {
actions.add(action);
}

@Override
public void removeCredential() {
this.credential = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public interface Connector extends ToXContentObject, Writeable {

List<ConnectorAction> getActions();

void setAction(ConnectorAction action);

ConnectorClientConfig getConnectorClientConfig();

String getActionEndpoint(String action, Map<String, String> parameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;

@Getter
@EqualsAndHashCode
Expand All @@ -36,6 +37,7 @@ public class ConnectorAction implements ToXContentObject, Writeable {
public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function";
public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function";


private ActionType actionType;
private String method;
private String url;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.output;

import java.io.IOException;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -29,9 +30,11 @@ public class MLPredictionOutput extends MLOutput {
public static final String TASK_ID_FIELD = "task_id";
public static final String STATUS_FIELD = "status";
public static final String PREDICTION_RESULT_FIELD = "prediction_result";
public static final String REMOTE_JOB_FIELD = "remote_job";

String taskId;
String status;
Map<String, Object> remoteJob;

@ToString.Exclude
DataFrame predictionResult;
Expand All @@ -44,6 +47,14 @@ public MLPredictionOutput(String taskId, String status, DataFrame predictionResu
this.predictionResult = predictionResult;
}

@Builder
public MLPredictionOutput(String taskId, String status, Map<String, Object> remoteJob) {
super(OUTPUT_TYPE);
this.taskId = taskId;
this.status = status;
this.remoteJob = remoteJob;
}

public MLPredictionOutput(StreamInput in) throws IOException {
super(OUTPUT_TYPE);
this.taskId = in.readOptionalString();
Expand All @@ -56,6 +67,9 @@ public MLPredictionOutput(StreamInput in) throws IOException {
break;
}
}
if (in.readBoolean()) {
this.remoteJob = in.readMap(s -> s.readString(), s -> s.readGenericValue());
}
}

@Override
Expand All @@ -69,6 +83,12 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (remoteJob != null) {
out.writeBoolean(true);
out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue);
} else {
out.writeBoolean(false);
}
}

@Override
Expand All @@ -87,6 +107,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
}

if (remoteJob != null) {
builder.field(REMOTE_JOB_FIELD, remoteJob);
}

builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import org.apache.hc.core5.http.HttpStatus;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
Expand All @@ -38,6 +41,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
Expand Down Expand Up @@ -210,6 +214,11 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel

private void executeConnector(Connector connector, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = createConnectorAction(connector);
connector.setAction(connectorAction);
}
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
Expand Down Expand Up @@ -245,4 +254,61 @@ private void processTaskResponse(MLTaskResponse taskResponse, ActionListener<MLC
log.error("Unable to fetch status for ml task ", e);
}
}

// TODO: move this method to connector utils class
private ConnectorAction createConnectorAction(Connector connector) {
Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name());

Map<String, String> headers = batchPredictAction.get().getHeaders();

String predictEndpoint = batchPredictAction.get().getUrl();
Map<String, String> parameters = connector.getParameters() != null
? new HashMap<>(connector.getParameters())
: Collections.emptyMap();

if (!parameters.isEmpty()) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
}

String url = "";
String requestBody = "";
String method = "POST"; // Default method

switch (getEndpointType(predictEndpoint)) {
case "sagemaker":
url = predictEndpoint.replace("CreateTransformJob", "StopTransformJob");
requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}";
break;
case "openai":
case "cohere":
url = predictEndpoint + "/${parameters.id}/cancel";
break;
case "bedrock":
url = predictEndpoint + "/${parameters.processedJobArn}/stop";
break;
}

return ConnectorAction
.builder()
.actionType(CANCEL_BATCH_PREDICT)
.method(method)
.url(url)
.requestBody(requestBody)
.headers(headers)
.build();

}

private String getEndpointType(String url) {
if (url.contains("sagemaker"))
return "sagemaker";
if (url.contains("openai"))
return "openai";
if (url.contains("bedrock"))
return "bedrock";
if (url.contains("cohere"))
return "cohere";
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.ml.common.MLTaskState.CANCELLING;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.EXPIRED;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
Expand All @@ -24,6 +25,7 @@
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -32,6 +34,7 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
Expand All @@ -55,6 +58,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
Expand Down Expand Up @@ -279,6 +283,11 @@ private void executeConnector(
ActionListener<MLTaskGetResponse> actionListener
) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = createConnectorAction(connector);
connector.setAction(connectorAction);
}
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
Expand Down Expand Up @@ -362,4 +371,62 @@ private boolean matchesPattern(Pattern pattern, String input) {
Matcher matcher = pattern.matcher(input);
return matcher.find();
}

// TODO: move this method to connector utils class
private ConnectorAction createConnectorAction(Connector connector) {
Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name());

Map<String, String> headers = batchPredictAction.get().getHeaders();

String predictEndpoint = batchPredictAction.get().getUrl();
Map<String, String> parameters = connector.getParameters() != null
? new HashMap<>(connector.getParameters())
: Collections.emptyMap();

// Apply parameter substitution only if needed
if (!parameters.isEmpty()) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
}

String url = "";
String requestBody = "";
String method = "GET";

switch (getEndpointType(predictEndpoint)) {
case "sagemaker":
url = predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob");
requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}";
method = "POST";
break;
case "openai":
case "cohere":
url = predictEndpoint + "/${parameters.id}";
break;
case "bedrock":
url = predictEndpoint + "/${parameters.processedJobArn}";
break;
}
return ConnectorAction
.builder()
.actionType(BATCH_PREDICT_STATUS)
.method(method)
.url(url)
.requestBody(requestBody)
.headers(headers)
.build();

}

private String getEndpointType(String url) {
if (url.contains("sagemaker"))
return "sagemaker";
if (url.contains("openai"))
return "openai";
if (url.contains("bedrock"))
return "bedrock";
if (url.contains("cohere"))
return "cohere";
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,11 @@ private void runPredict(
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
MLPredictionOutput outputBuilder = MLPredictionOutput
.builder()
.taskId(taskId)
.status(MLTaskState.CREATED.name())
.build();
MLPredictionOutput outputBuilder = new MLPredictionOutput(
taskId,
MLTaskState.CREATED.name(),
remoteJob
);

MLTaskResponse predictOutput = MLTaskResponse.builder().output(outputBuilder).build();
internalListener.onResponse(predictOutput);
Expand Down

0 comments on commit 5556080

Please sign in to comment.