Skip to content

Commit

Permalink
Add process function for bedrock
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jan 26, 2024
1 parent 6a709ac commit f90e8f3
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,62 @@ public class MLPostProcessFunction {

public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";

public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<List<List<Float>>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();


static {
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildMultipleResultModelTensor());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildMultipleResultModelTensor());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildMultipleResultModelTensor());
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildSingleResultModelTensor());
}

public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
public static Function<List<?>, List<ModelTensor>> buildSingleResultModelTensor() {
return embedding -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
if (embedding == null) {
throw new IllegalArgumentException("The embedding is null when using the built-in post-processing function.");
}
embeddings.forEach(embedding -> modelTensors.add(
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
);
return modelTensors;
};
}

public static Function<List<?>, List<ModelTensor>> buildMultipleResultModelTensor() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
embeddings.forEach(embedding -> {
List<Number> eachEmbedding = (List<Number>) embedding;
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{eachEmbedding.size()})
.data(eachEmbedding.toArray(new Number[0]))
.build()
);
});
return modelTensors;
};
}
Expand All @@ -58,7 +82,7 @@ public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static Function<List<List<Float>>, List<ModelTensor>> get(String postProcessFunction) {
public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class MLPreProcessFunction {
private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";

public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
Expand All @@ -26,17 +26,22 @@ private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPr
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0)));
}

static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
}

public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<List<String>, Map<String, Object>> get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,28 @@ public void test_getResponseFilter() {
}

@Test
public void test_buildModelTensorList() {
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList());
public void test_buildMultipleResultModelTensorList() {
Assert.assertNotNull(MLPostProcessFunction.buildMultipleResultModelTensor());
List<List<Float>> numbersList = new ArrayList<>();
numbersList.add(Collections.singletonList(1.0f));
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList));
Assert.assertNotNull(MLPostProcessFunction.buildMultipleResultModelTensor().apply(numbersList));
}

@Test
public void test_buildModelTensorList_exception() {
public void test_buildMultipleResultModelTensorList_exception() {
exceptionRule.expect(IllegalArgumentException.class);
MLPostProcessFunction.buildModelTensorList().apply(null);
MLPostProcessFunction.buildMultipleResultModelTensor().apply(null);
}

@Test
public void test_buildSingleResultModelTensorList() {
Assert.assertNotNull(MLPostProcessFunction.buildSingleResultModelTensor());
Assert.assertNotNull(MLPostProcessFunction.buildSingleResultModelTensor().apply(Collections.singletonList(1.0f)));
}

@Test
public void test_buildSingleResultModelTensorList_exception() {
exceptionRule.expect(IllegalArgumentException.class);
MLPostProcessFunction.buildSingleResultModelTensor().apply(null);
}
}

0 comments on commit f90e8f3

Please sign in to comment.