diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 7981f08175..06be017913 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -17,6 +17,7 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; + public static final String COHERE_V2_EMBEDDING = "connector.post_process.cohere_v2.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn"; @@ -35,6 +36,7 @@ public class MLPostProcessFunction { CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); + JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING, "$.embeddings.float"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$"); @@ -42,6 +44,7 @@ public class MLPostProcessFunction { JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 4ddc3bf2d4..b6d1fc2a85 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -321,7 +321,7 @@ public void test_bedrockEmbeddingTypeSupportedModel_withDifferentResponseFilters if (testCaseName.equals("response_filter_to_embedding_concrete_type")) { assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); } else { - assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); } } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json index bee9d60135..9082077200 100644 --- a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingTypeSupportedConnectorBodies.json @@ -100,7 +100,7 @@ "parameters": { "region": "%s", "service_name": "bedrock", - "model_name": "amazon.titan-embed-text-v1" + "model_name": "amazon.titan-embed-text-v2:0" }, "credential": { "access_key": "%s",