Skip to content

Commit

Permalink
Add cohere v2 default post process function
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Dec 9, 2024
1 parent 58dd965 commit 3adeaf6
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
public class MLPostProcessFunction {

public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String COHERE_V2_EMBEDDING = "connector.post_process.cohere_v2.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
Expand All @@ -35,13 +36,15 @@ public class MLPostProcessFunction {
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING, "$.embeddings.float");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ public void test_bedrockEmbeddingTypeSupportedModel_withDifferentResponseFilters
if (testCaseName.equals("response_filter_to_embedding_concrete_type")) {
assertEquals(errorMsg, 1024, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
} else {
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
assertEquals(errorMsg, 1024, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v1"
"model_name": "amazon.titan-embed-text-v2:0"
},
"credential": {
"access_key": "%s",
Expand Down

0 comments on commit 3adeaf6

Please sign in to comment.