Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Sep 25, 2023
1 parent 09ce432 commit 027b4ee
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorLis
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("NumbersList is null when applying build-in post process function!");
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";

public static final String DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("texts", inputs));
Expand All @@ -29,7 +29,7 @@ private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPr
static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
}

public static boolean contains(String functionName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
throw new IllegalArgumentException("no predict action found");
}
String preProcessFunction = predictAction.get().getPreProcessFunction();
preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.DEFAULT_EMBEDDING_INPUT : preProcessFunction;
preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction;
if (MLPreProcessFunction.contains(preProcessFunction)) {
Map<String, Object> buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs);
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build();
Expand Down

0 comments on commit 027b4ee

Please sign in to comment.