Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jun 6, 2024
1 parent e6427bf commit aee2a0b
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -36,7 +36,7 @@ public class MLPreProcessFunction {
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
MultiModalEmbeddingPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalEmbeddingPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@

import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod;

/**
* This abstract class represents a pre-processing function for a connector.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link MLInput}, and the pre-processing function can be customized by implementing the {@link #validate(MLInput)} and {@link #process(MLInput)} methods.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true.
*/
@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {

Expand All @@ -45,10 +52,11 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {

public void validateTextDocsInput(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName()));
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
}
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 1 && docs.get(0) == null) {
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) {
throw new IllegalArgumentException("No input text or image provided");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

public class MultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
/**
* This class provides a pre-processing function for multi-modal input data.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input.
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction {

public MultiModalEmbeddingPreProcessFunction() {
public MultiModalConnectorPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

Expand All @@ -26,7 +33,12 @@ public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
}

// The input will must have inputText even it's null, input image is optional.
/**
* @param mlInput The input data to be processed.
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
* The inputText will always show up in the first document, even it's null.
*/
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.w3c.dom.Text;

import java.rmi.Remote;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;

public class MultiModalEmbeddingPreProcessFunctionTest {
public class MultiModalConnectorPreProcessFunctionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

MultiModalEmbeddingPreProcessFunction function;
MultiModalConnectorPreProcessFunction function;

TextSimilarityInputDataSet textSimilarityInputDataSet;
TextDocsInputDataSet textDocsInputDataSet;
Expand All @@ -40,7 +38,7 @@ public class MultiModalEmbeddingPreProcessFunctionTest {

@Before
public void setUp() {
function = new MultiModalEmbeddingPreProcessFunction();
function = new MultiModalConnectorPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build();
Expand All @@ -51,21 +49,21 @@ public void setUp() {
}

@Test
public void process_NullInput() {
public void testProcess_whenNullInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Preprocess function input can't be null");
function.apply(null);
}

@Test
public void process_WrongInput() {
public void testProcess_whenWrongInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
function.apply(textSimilarityInput);
}

@Test
public void process_input_text_image() {
public void testProcess_whenCorrectInput_expectCorrectOutput() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(2, dataSet.getParameters().size());
Expand All @@ -74,7 +72,7 @@ public void process_input_text_image() {
}

@Test
public void process_input_text_only() {
public void testProcess_whenInputTextOnly_expectInputTextShowUp() {
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
Expand All @@ -83,7 +81,7 @@ public void process_input_text_only() {
}

@Test
public void process_input_text_null() {
public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No input text or image provided");
List<String> docs = new ArrayList<>();
Expand All @@ -94,7 +92,7 @@ public void process_input_text_null() {
}

@Test
public void process_RemoteInferenceInput() {
public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}
Expand Down

0 comments on commit aee2a0b

Please sign in to comment.