From de551e2e801471064e590e33653c1b9577357f2b Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 21 Dec 2022 14:51:01 -0800 Subject: [PATCH] Add filter option for query type (#88) Adds filter option for query type. Filtering support was introduced in the k-NN plugin in 2.4. Breaks backwards compatibility with OpenSearch 2.4, however, given that 2.4 is experimental, this is okay. Backwards compatibility issues will only arise during mixed cluster upgrade. Signed-off-by: John Mazanec --- .../query/NeuralQueryBuilder.java | 19 +- .../common/BaseNeuralSearchIT.java | 13 ++ .../query/NeuralQueryBuilderTests.java | 181 +++++++++++++++++- .../neuralsearch/query/NeuralQueryIT.java | 74 +++++++ 4 files changed, 274 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index cdd248195..609a15ebf 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.query; +import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; import java.io.IOException; @@ -80,6 +81,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { @Getter(AccessLevel.PACKAGE) @Setter(AccessLevel.PACKAGE) private Supplier vectorSupplier; + private QueryBuilder filter; /** * Constructor from stream input @@ -93,6 +95,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { this.queryText = in.readString(); this.modelId = in.readString(); this.k = in.readVInt(); + this.filter = in.readOptionalNamedWriteable(QueryBuilder.class); } @Override @@ -101,6 +104,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(this.queryText); out.writeString(this.modelId); out.writeVInt(this.k); + out.writeOptionalNamedWriteable(this.filter); } @Override @@ -110,6 +114,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); xContentBuilder.field(K_FIELD.getPreferredName(), k); + if (filter != null) { + xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -125,7 +132,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * "model_id": "string", * "k": int, * "name": "string", (optional) - * "boost": float (optional) + * "boost": float (optional), + * "filter": map (optional) * } * } * @@ -184,6 +192,10 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n "[" + NAME + "] query does not support [" + currentFieldName + "]" ); } + } else if (token == XContentParser.Token.START_OBJECT) { + if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.filter(parseInnerQueryBuilder(parser)); + } } else { throw new ParsingException( parser.getTokenLocation(), @@ -205,7 +217,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { // create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just // return the current unmodified query builder. if (vectorSupplier() != null) { - return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k()); + return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k(), filter()); } SetOnce vectorSetOnce = new SetOnce<>(); @@ -215,7 +227,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { actionListener.onResponse(null); }, actionListener::onFailure))) ); - return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get); + return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter()); } @Override @@ -233,6 +245,7 @@ protected boolean doEquals(NeuralQueryBuilder obj) { equalsBuilder.append(queryText, obj.queryText); equalsBuilder.append(modelId, obj.modelId); equalsBuilder.append(k, obj.k); + equalsBuilder.append(filter, obj.filter); return equalsBuilder.isEquals(); } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 914acb345..054124904 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -330,6 +330,19 @@ protected Map getFirstInnerHit(Map searchRespons return (Map) hits2List.get(0); } + /** + * Parse the total number of hits from the search + * + * @param searchResponseAsMap Complete search response as a map + * @return number of hits from the search + */ + @SuppressWarnings("unchecked") + protected int getHitCount(Map searchResponseAsMap) { + Map hits1map = (Map) searchResponseAsMap.get("hits"); + List hits1List = (List) hits1map.get("hits"); + return hits1List.size(); + } + /** * Create a k-NN index from a list of KNNFieldConfigs * diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 64bab186a..88aaedc0a 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -8,8 +8,10 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; @@ -29,12 +31,20 @@ import org.opensearch.action.ActionListener; import org.opensearch.client.Client; +import org.opensearch.common.ParseField; import org.opensearch.common.ParsingException; +import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.FilterStreamInput; +import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -52,6 +62,8 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final String QUERY_NAME = "queryName"; private static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[10]; + private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); + @SneakyThrows public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { /* @@ -118,6 +130,60 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { assertEquals(QUERY_NAME, neuralQueryBuilder.queryName()); } + @SneakyThrows + public void testFromXContent_whenBuiltWithFilter_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "k": int, + "boost": 10.0, + "_name": "something", + "filter": { + "match_all": {} + } + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .field(FILTER_FIELD.getPreferredName(), TEST_FILTER) + .endObject() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(MatchAllQueryBuilder.NAME), + MatchAllQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilder.contentType().xContent(), + BytesReference.bytes(xContentBuilder) + ); + contentParser.nextToken(); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(K, neuralQueryBuilder.k()); + assertEquals(BOOST, neuralQueryBuilder.boost(), 0.0); + assertEquals(QUERY_NAME, neuralQueryBuilder.queryName()); + assertEquals(TEST_FILTER, neuralQueryBuilder.filter()); + } + @SneakyThrows public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { /* @@ -196,10 +262,44 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } + @SneakyThrows + public void testFromXContent_whenBuiltWithInvalidFilter_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "k": int, + "boost": 10.0, + "filter": 12 + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .field(FILTER_FIELD.getPreferredName(), 12) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(ParsingException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + @SuppressWarnings("unchecked") @SneakyThrows public void testToXContent() { - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).modelId(MODEL_ID).queryText(QUERY_TEXT).k(K); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .modelId(MODEL_ID) + .queryText(QUERY_TEXT) + .k(K) + .filter(TEST_FILTER); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = neuralQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -225,6 +325,11 @@ public void testToXContent() { assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); assertEquals(K, secondInnerMap.get(K_FIELD.getPreferredName())); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + assertEquals( + xContentBuilderToMap(TEST_FILTER.toXContent(xContentBuilder, EMPTY_PARAMS)), + secondInnerMap.get(FILTER_FIELD.getPreferredName()) + ); } @SneakyThrows @@ -236,11 +341,19 @@ public void testStreams() { original.k(K); original.boost(BOOST); original.queryName(QUERY_NAME); + original.filter(TEST_FILTER); BytesStreamOutput streamOutput = new BytesStreamOutput(); original.writeTo(streamOutput); - NeuralQueryBuilder copy = new NeuralQueryBuilder(streamOutput.bytes().streamInput()); + FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) + ); + + NeuralQueryBuilder copy = new NeuralQueryBuilder(filterStreamInput); assertEquals(original, copy); } @@ -258,12 +371,16 @@ public void testHashAndEquals() { int k1 = 1; int k2 = 2; + QueryBuilder filter1 = new MatchAllQueryBuilder(); + QueryBuilder filter2 = new MatchNoneQueryBuilder(); + NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k1) .boost(boost1) - .queryName(queryName1); + .queryName(queryName1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline NeuralQueryBuilder neuralQueryBuilder_baselineCopy = new NeuralQueryBuilder().fieldName(fieldName1) @@ -271,13 +388,15 @@ public void testHashAndEquals() { .modelId(modelId1) .k(k1) .boost(boost1) - .queryName(queryName1); + .queryName(queryName1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline except default boost and query name NeuralQueryBuilder neuralQueryBuilder_defaultBoostAndQueryName = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) - .k(k1); + .k(k1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline except diff field name NeuralQueryBuilder neuralQueryBuilder_diffFieldName = new NeuralQueryBuilder().fieldName(fieldName2) @@ -285,7 +404,8 @@ public void testHashAndEquals() { .modelId(modelId1) .k(k1) .boost(boost1) - .queryName(queryName1); + .queryName(queryName1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline except diff query text NeuralQueryBuilder neuralQueryBuilder_diffQueryText = new NeuralQueryBuilder().fieldName(fieldName1) @@ -293,7 +413,8 @@ public void testHashAndEquals() { .modelId(modelId1) .k(k1) .boost(boost1) - .queryName(queryName1); + .queryName(queryName1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline except diff model ID NeuralQueryBuilder neuralQueryBuilder_diffModelId = new NeuralQueryBuilder().fieldName(fieldName1) @@ -301,7 +422,8 @@ public void testHashAndEquals() { .modelId(modelId2) .k(k1) .boost(boost1) - .queryName(queryName1); + .queryName(queryName1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline except diff k NeuralQueryBuilder neuralQueryBuilder_diffK = new NeuralQueryBuilder().fieldName(fieldName1) @@ -309,7 +431,8 @@ public void testHashAndEquals() { .modelId(modelId1) .k(k2) .boost(boost1) - .queryName(queryName1); + .queryName(queryName1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline except diff boost NeuralQueryBuilder neuralQueryBuilder_diffBoost = new NeuralQueryBuilder().fieldName(fieldName1) @@ -317,16 +440,35 @@ public void testHashAndEquals() { .modelId(modelId1) .k(k1) .boost(boost2) - .queryName(queryName1); + .queryName(queryName1) + .filter(filter1); // Identical to neuralQueryBuilder_baseline except diff query name NeuralQueryBuilder neuralQueryBuilder_diffQueryName = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName2) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except no filter + NeuralQueryBuilder neuralQueryBuilder_noFilter = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k1) .boost(boost1) .queryName(queryName2); + // Identical to neuralQueryBuilder_baseline except no filter + NeuralQueryBuilder neuralQueryBuilder_diffFilter = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName2) + .filter(filter2); + assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baseline); assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baseline.hashCode()); @@ -353,6 +495,12 @@ public void testHashAndEquals() { assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryName); assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryName.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_noFilter); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_noFilter.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffFilter); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffFilter.hashCode()); } @SneakyThrows @@ -411,4 +559,17 @@ public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); assertArrayEquals(TEST_VECTOR_SUPPLIER.get(), (float[]) knnQueryBuilder.vector(), 0.0f); } + + public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER); + QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); + assertTrue(queryBuilder instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; + assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index ae62f308f..0ea24f168 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -30,6 +30,7 @@ public class NeuralQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_VECTOR_FIELD_INDEX_NAME = "test-neural-multi-vector-field-index"; private static final String TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME = "test-neural-text-and-vector-field-index"; private static final String TEST_NESTED_INDEX_NAME = "test-neural-nested-index"; + private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; private static final String TEST_QUERY_TEXT = "Hello world"; private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; @@ -71,6 +72,7 @@ public void testBasicQuery() { TEST_QUERY_TEXT, modelId.get(), 1, + null, null ); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); @@ -104,6 +106,7 @@ public void testBoostQuery() { TEST_QUERY_TEXT, modelId.get(), 1, + null, null ); @@ -146,6 +149,7 @@ public void testRescoreQuery() { TEST_QUERY_TEXT, modelId.get(), 1, + null, null ); @@ -192,6 +196,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { TEST_QUERY_TEXT, modelId.get(), 1, + null, null ); NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( @@ -199,6 +204,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { TEST_QUERY_TEXT, modelId.get(), 1, + null, null ); @@ -245,6 +251,7 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { TEST_QUERY_TEXT, modelId.get(), 1, + null, null ); @@ -287,6 +294,7 @@ public void testNestedQuery() { TEST_QUERY_TEXT, modelId.get(), 1, + null, null ); @@ -298,6 +306,46 @@ public void testNestedQuery() { assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } + /** + * Tests filter query: + * { + * "query": { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 1, + * "filter": { + * "match": { + * "_id": { + * "query": "3" + * } + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testFilterQuery() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + modelId.get(), + 1, + null, + new MatchQueryBuilder("_id", "3") + ); + Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, neuralQueryBuilder, 3); + assertEquals(1, getHitCount(searchResponseAsMap)); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + assertEquals("3", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId.get(), testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { prepareKnnIndex( @@ -359,5 +407,31 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { ); assertEquals(1, getDocCount(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME)); } + + if (TEST_MULTI_DOC_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector).toArray()) + ); + assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME)); + } } }