diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index c34cd0b42..fe201abae 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -11,6 +11,7 @@ import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Supplier; +import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.commons.lang3.StringUtils; @@ -173,13 +174,28 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType( if (processorKey == null || sourceAndMetadataMap == null) return; if (processorKey instanceof Map) { Map next = new LinkedHashMap<>(); - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); + if (sourceAndMetadataMap.get(parentKey) instanceof Map) { + for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { + buildMapWithProcessorKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next + ); + } + } else if (sourceAndMetadataMap.get(parentKey) instanceof List) { + for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { + List> list = (List>) sourceAndMetadataMap.get(parentKey); + List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); + Map map = new LinkedHashMap<>(); + map.put(nestedFieldMapEntry.getKey(), listOfStrings); + buildMapWithProcessorKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + map, + next + ); + } } treeRes.put(parentKey, next); } else { @@ -212,7 +228,7 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); + validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() .stream() @@ -226,9 +242,11 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl } @SuppressWarnings({ "rawtypes" }) - private void validateListTypeValue(String sourceKey, Object sourceValue) { + private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { for (Object value : (List) sourceValue) { - if (value == null) { + if (value instanceof Map) { + validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1); + } else if (value == null) { throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); } else if (!(value instanceof String)) { throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); @@ -275,13 +293,20 @@ private void putNLPResultToSourceMapForMapType( if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; if (sourceValue instanceof Map) { for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - putNLPResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - results, - indexWrapper, - (Map) sourceAndMetadataMap.get(processorKey) - ); + if (sourceAndMetadataMap.get(processorKey) instanceof List) { + // build nlp output for list of nested objects + for (Map nestedElement : (List>) sourceAndMetadataMap.get(processorKey)) { + nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); + } + } else { + putNLPResultToSourceMapForMapType( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + results, + indexWrapper, + (Map) sourceAndMetadataMap.get(processorKey) + ); + } } } else if (sourceValue instanceof String) { sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 63b652fae..c5ae672d3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -70,7 +70,15 @@ private void ingestDocument() throws Exception { + " \"favorites\": {\n" + " \"game\": \"overwatch\",\n" + " \"movie\": null\n" - + " }\n" + + " },\n" + + " \"nested_passages\": [\n" + + " {\n" + + " \"text\": \"hello\"\n" + + " },\n" + + " {\n" + + " \"text\": \"world\"\n" + + " }\n" + + " ]\n" + "}\n"; Response response = makeRequest( client(), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 8c2f1c1be..60408d820 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -20,6 +20,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Arrays; import java.util.function.BiConsumer; import java.util.function.Supplier; @@ -404,6 +405,35 @@ public void testBuildVectorOutput_withNestedMap_successful() { assertNotNull(actionGamesKnn); } + public void testBuildVectorOutput_withNestedList_successful() { + Map config = createNestedListConfiguration(); + IngestDocument ingestDocument = createNestedListIngestDocument(); + TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List> modelTensorList = createMockVectorResult(); + textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); + assertTrue(nestedObj.get(0).containsKey("vectorField")); + assertTrue(nestedObj.get(1).containsKey("vectorField")); + assertNotNull(nestedObj.get(0).get("vectorField")); + assertNotNull(nestedObj.get(1).get("vectorField")); + } + + public void testBuildVectorOutput_withNestedList_Level2_successful() { + Map config = createNestedList2LevelConfiguration(); + IngestDocument ingestDocument = create2LevelNestedListIngestDocument(); + TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List> modelTensorList = createMockVectorResult(); + textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + Map nestedLevel1 = (Map) ingestDocument.getSourceAndMetadata().get("nestedField"); + List> nestedObj = (List>) nestedLevel1.get("nestedField"); + assertTrue(nestedObj.get(0).containsKey("vectorField")); + assertTrue(nestedObj.get(1).containsKey("vectorField")); + assertNotNull(nestedObj.get(0).get("vectorField")); + assertNotNull(nestedObj.get(1).get("vectorField")); + } + public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); @@ -520,4 +550,44 @@ private IngestDocument createNestedMapIngestDocument() { result.put("favorites", favorite); return new IngestDocument(result, new HashMap<>()); } + + private Map createNestedListConfiguration() { + Map nestedConfig = new HashMap<>(); + nestedConfig.put("textField", "vectorField"); + Map result = new HashMap<>(); + result.put("nestedField", nestedConfig); + return result; + } + + private Map createNestedList2LevelConfiguration() { + Map nestedConfig = new HashMap<>(); + nestedConfig.put("textField", "vectorField"); + Map nestConfigLevel1 = new HashMap<>(); + nestConfigLevel1.put("nestedField", nestedConfig); + Map result = new HashMap<>(); + result.put("nestedField", nestConfigLevel1); + return result; + } + + private IngestDocument createNestedListIngestDocument() { + HashMap nestedObj1 = new HashMap<>(); + nestedObj1.put("textField", "This is a text field"); + HashMap nestedObj2 = new HashMap<>(); + nestedObj2.put("textField", "This is another text field"); + HashMap nestedList = new HashMap<>(); + nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + return new IngestDocument(nestedList, new HashMap<>()); + } + + private IngestDocument create2LevelNestedListIngestDocument() { + HashMap nestedObj1 = new HashMap<>(); + nestedObj1.put("textField", "This is a text field"); + HashMap nestedObj2 = new HashMap<>(); + nestedObj2.put("textField", "This is another text field"); + HashMap nestedList = new HashMap<>(); + nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + HashMap nestedList1 = new HashMap<>(); + nestedList1.put("nestedField", nestedList); + return new IngestDocument(nestedList1, new HashMap<>()); + } } diff --git a/src/test/resources/processor/IndexMappings.json b/src/test/resources/processor/IndexMappings.json index 13faad6c4..ffa5cea64 100644 --- a/src/test/resources/processor/IndexMappings.json +++ b/src/test/resources/processor/IndexMappings.json @@ -83,6 +83,27 @@ }, "passage_text": { "type": "text" + }, + "nested_passages": { + "type": "nested", + "properties": { + "text": { + "type": "text" + }, + "embedding": { + "type": "knn_vector", + "dimension": 768, + "method": { + "name": "hnsw", + "space_type": "l2", + "engine": "lucene", + "parameters": { + "ef_construction": 128, + "m": 24 + } + } + } + } } } } diff --git a/src/test/resources/processor/PipelineConfiguration.json b/src/test/resources/processor/PipelineConfiguration.json index 471f6a432..d833576a0 100644 --- a/src/test/resources/processor/PipelineConfiguration.json +++ b/src/test/resources/processor/PipelineConfiguration.json @@ -10,6 +10,9 @@ "favorites": { "game": "game_knn", "movie": "movie_knn" + }, + "nested_passages": { + "text": "embedding" } } }