From 3adeaf670b5e2c3fb2f946696601e8e9da7d3572 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 9 Dec 2024 11:58:18 +0800 Subject: [PATCH] Add cohere v2 default post process function Signed-off-by: zane-neo --- .../opensearch/ml/common/connector/MLPostProcessFunction.java | 3 +++ .../java/org/opensearch/ml/rest/RestBedRockInferenceIT.java | 2 +- .../BedRockEmbeddingTypeSupportedConnectorBodies.json | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 7981f08175..06be017913 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -17,6 +17,7 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; + public static final String COHERE_V2_EMBEDDING = "connector.post_process.cohere_v2.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn"; @@ -35,6 +36,7 @@ public class MLPostProcessFunction { CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); + JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING, "$.embeddings.float"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$"); @@ -42,6 +44,7 @@ public class MLPostProcessFunction { JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 4ddc3bf2d4..b6d1fc2a85 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -321,7 +321,7 @@ public void test_bedrockEmbeddingTypeSupportedModel_withDifferentResponseFilters if (testCaseName.equals("response_filter_to_embedding_concrete_type")) { assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); } else { - assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); } } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json index bee9d60135..9082077200 100644 --- a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json @@ -100,7 +100,7 @@ "parameters": { "region": "%s", "service_name": "bedrock", - "model_name": "amazon.titan-embed-text-v1" + "model_name": "amazon.titan-embed-text-v2:0" }, "credential": { "access_key": "%s",