Skip to content

Commit

Permalink
Fix user defined preprocess function missing prediction issue
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed May 8, 2024
1 parent 9478fd4 commit b6fa75f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -22,6 +23,7 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.multimodal.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";
Expand All @@ -33,11 +35,13 @@ public class MLPreProcessFunction {
CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction();
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
MultiModalEmbeddingPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

public class MultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public MultiModalEmbeddingPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
}

@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0), "inputImage", inputData.getDocs().get(1)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -27,6 +28,8 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand Down Expand Up @@ -99,6 +102,12 @@ private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocs
return Tuple.tuple(textDocsLength / stepSize + 1, stepSize);
}
} else {
Optional<ConnectorAction> predictAction = getConnector().findPredictAction();
String preProcessFunction = predictAction.get().getPreProcessFunction();
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
// user defined preprocess script, this case, the chunk size is always equals to text docs length.
return Tuple.tuple(textDocsLength, 1);
}
// consider as batch.
return Tuple.tuple(1, textDocsLength);
}
Expand Down

0 comments on commit b6fa75f

Please sign in to comment.