diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6b01271c46b7..dc90bfa477cb 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -42,7 +42,11 @@ API Changes New Features --------------------- -(No changes) + +* GITHUB#14084, GITHUB#13635, GITHUB#13634: Adds new `SeededKnnByteVectorQuery` and `SeededKnnFloatVectorQuery` + queries. These queries allow for the vector search entry points to be initialized via a `seed` query. This follows + the research provided via https://arxiv.org/abs/2307.16779. (Sean MacAvaney, Ben Trent). + Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 35144055830c..05157ab65cb5 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final byte[] target; + protected final byte[] target; /** * Find the k nearest documents to the target vector according to the vectors in the diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index 43bac9fbc309..f694d8f7085c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -85,4 +85,58 @@ public interface KnnCollector { * @return The collected top documents */ TopDocs topDocs(); + + /** + * KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend + * the object with new behaviors. 
+ * + * @lucene.experimental + */ + abstract class Decorator implements KnnCollector { + private final KnnCollector collector; + + public Decorator(KnnCollector collector) { + this.collector = collector; + } + + @Override + public boolean earlyTerminated() { + return collector.earlyTerminated(); + } + + @Override + public void incVisitedCount(int count) { + collector.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return collector.visitedCount(); + } + + @Override + public long visitLimit() { + return collector.visitLimit(); + } + + @Override + public int k() { + return collector.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + return collector.collect(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return collector.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + return collector.topDocs(); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index d2aaf4296eda..c7d6fdb3608d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final float[] target; + protected final float[] target; /** * Find the k nearest documents to the target vector according to the vectors in the diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java new file mode 100644 index 000000000000..980b6869c34f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn byte vector query that provides a query seed to initiate the vector + * search. NOTE: The underlying format is free to ignore the provided seed. + * + *

<p>See "Lexically-Accelerated Dense + * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir). + * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and + * Development in Information Retrieval, Pages 152 - 162. + * + * @lucene.experimental + */ +public class SeededKnnByteVectorQuery extends KnnByteVectorQuery { + final Query seed; + final Weight seedWeight; + + /** + * Construct a new SeededKnnByteVectorQuery instance + * + * @param field knn byte vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnByteVectorQuery rewritten = + new SeededKnnByteVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager 
getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new SeededKnnCollectorManager( + super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> { + ByteVectorValues vv = leaf.getByteVectorValues(field); + if (vv == null) { + ByteVectorValues.checkField(leaf.getContext().reader(), field); + } + return vv; + }); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java new file mode 100644 index 000000000000..02a33bdcdef7 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn float vector query that provides a query seed to initiate the vector + * search. NOTE: The underlying format is free to ignore the provided seed. + * + *

<p>See "Lexically-Accelerated Dense + * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir). + * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and + * Development in Information Retrieval, Pages 152 - 162. + * + * @lucene.experimental + */ +public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery { + final Query seed; + final Weight seedWeight; + + /** + * Construct a new SeededKnnFloatVectorQuery instance + * + * @param field knn float vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnFloatVectorQuery rewritten = + new SeededKnnFloatVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager 
getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new SeededKnnCollectorManager( + super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> { + FloatVectorValues vv = leaf.getFloatVectorValues(field); + if (vv == null) { + FloatVectorValues.checkField(leaf.getContext().reader(), field); + } + return vv; + }); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java index 2a1f312fbc58..2dc2f035b90f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java @@ -45,51 +45,19 @@ public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) th return new TimeLimitingKnnCollector(collector); } - class TimeLimitingKnnCollector implements KnnCollector { - private final KnnCollector collector; - - TimeLimitingKnnCollector(KnnCollector collector) { - this.collector = collector; + class TimeLimitingKnnCollector extends KnnCollector.Decorator { + public TimeLimitingKnnCollector(KnnCollector collector) { + super(collector); } @Override public boolean earlyTerminated() { - return queryTimeout.shouldExit() || collector.earlyTerminated(); - } - - @Override - public void incVisitedCount(int count) { - collector.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return collector.visitedCount(); - } - - @Override - public long visitLimit() { - return collector.visitLimit(); - } - - @Override - public int k() { - return collector.k(); - } - - @Override - public boolean collect(int docId, float similarity) { - return collector.collect(docId, similarity); - } - - @Override - public float minCompetitiveSimilarity() { - return 
collector.minCompetitiveSimilarity(); + return queryTimeout.shouldExit() || super.earlyTerminated(); } @Override public TopDocs topDocs() { - TopDocs docs = collector.topDocs(); + TopDocs docs = super.topDocs(); // Mark results as partial if timeout is met TotalHits.Relation relation = diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java new file mode 100644 index 000000000000..9e7b44b571df --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; + +/** Provides entry points for the kNN search */ +public interface EntryPointProvider { + /** Iterator of valid entry points for the kNN search */ + DocIdSetIterator entryPoints(); + + /** Number of valid entry points for the kNN search */ + int numberOfEntryPoints(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java new file mode 100644 index 000000000000..c3c4f62901ee --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; + +/** + * A {@link KnnCollector} that provides seeded knn collection. See usage in {@link + * SeededKnnCollectorManager}. 
+ * + * @lucene.experimental + */ +class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider { + private final DocIdSetIterator entryPoints; + private final int numberOfEntryPoints; + + SeededKnnCollector( + KnnCollector collector, DocIdSetIterator entryPoints, int numberOfEntryPoints) { + super(collector); + this.entryPoints = entryPoints; + this.numberOfEntryPoints = numberOfEntryPoints; + } + + @Override + public DocIdSetIterator entryPoints() { + return entryPoints; + } + + @Override + public int numberOfEntryPoints() { + return numberOfEntryPoints; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java new file mode 100644 index 000000000000..7631db6e3022 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search.knn; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.TopScoreDocCollectorManager; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.IOFunction; + +/** + * A {@link KnnCollectorManager} that provides seeded knn collection. See usage in {@link + * org.apache.lucene.search.SeededKnnFloatVectorQuery} and {@link + * org.apache.lucene.search.SeededKnnByteVectorQuery}. + */ +public class SeededKnnCollectorManager implements KnnCollectorManager { + private final KnnCollectorManager delegate; + private final Weight seedWeight; + private final int k; + private final IOFunction vectorValuesSupplier; + + public SeededKnnCollectorManager( + KnnCollectorManager delegate, + Weight seedWeight, + int k, + IOFunction vectorValuesSupplier) { + this.delegate = delegate; + this.seedWeight = seedWeight; + this.k = k; + this.vectorValuesSupplier = vectorValuesSupplier; + } + + @Override + public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException { + // Execute the seed query + TopScoreDocCollector seedCollector = + new TopScoreDocCollectorManager(k, null, Integer.MAX_VALUE).newCollector(); + final LeafReader leafReader = ctx.reader(); + final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); + if (leafCollector != null) { + try { + BulkScorer scorer = seedWeight.bulkScorer(ctx); + if (scorer != null) { + scorer.score( + leafCollector, + 
leafReader.getLiveDocs(), + 0 /* min */, + DocIdSetIterator.NO_MORE_DOCS /* max */); + } + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { + } + leafCollector.finish(); + } + + TopDocs seedTopDocs = seedCollector.topDocs(); + KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader); + final KnnCollector delegateCollector = delegate.newCollector(visitedLimit, ctx); + if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) { + return delegateCollector; + } + KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator(); + DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs)); + return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length); + } + + private static class MappedDISI extends DocIdSetIterator { + KnnVectorValues.DocIndexIterator indexedDISI; + DocIdSetIterator sourceDISI; + + private MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) { + this.indexedDISI = indexedDISI; + this.sourceDISI = sourceDISI; + } + + /** + * Advances the source iterator to the first document number that is greater than or equal to + * the provided target and returns the corresponding index. + */ + @Override + public int advance(int target) throws IOException { + int newTarget = sourceDISI.advance(target); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + + @Override + public long cost() { + return sourceDISI.cost(); + } + + @Override + public int docID() { + if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return indexedDISI.index(); + } + + /** Advances to the next document in the source iterator and returns the corresponding index. 
*/ + @Override + public int nextDoc() throws IOException { + int newTarget = sourceDISI.nextDoc(); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + } + + private static class TopDocsDISI extends DocIdSetIterator { + private final int[] sortedDocIds; + private int idx = -1; + + private TopDocsDISI(TopDocs topDocs) { + sortedDocIds = new int[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedDocIds[i] = topDocs.scoreDocs[i].doc; + } + Arrays.sort(sortedDocIds); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return sortedDocIds.length; + } + + @Override + public int docID() { + if (idx == -1) { + return -1; + } else if (idx >= sortedDocIds.length) { + return DocIdSetIterator.NO_MORE_DOCS; + } else { + return sortedDocIds[idx]; + } + } + + @Override + public int nextDoc() { + idx += 1; + return docID(); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 46d6c93d52c3..e8f0d316fd81 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,8 +20,10 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.search.knn.EntryPointProvider; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -52,7 +54,9 @@ public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) { } /** - * Searches HNSW graph for the nearest neighbors of a query vector. 
+ * Searches the HNSW graph for the nearest neighbors of a query vector. If entry points are + * directly provided via the knnCollector, then the search will be initialized at those points. + * Otherwise, the search will discover the best entry point per the normal HNSW search algorithm. * * @param scorer the scorer to compare the query with the nodes * @param knnCollector a collector of top knn results to be returned @@ -67,7 +71,30 @@ public static void search( HnswGraphSearcher graphSearcher = new HnswGraphSearcher( new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); - search(scorer, knnCollector, graph, graphSearcher, acceptOrds); + final int[] entryPoints; + if (knnCollector instanceof EntryPointProvider epp) { + if (epp.numberOfEntryPoints() <= 0) { + throw new IllegalArgumentException("The number of entry points must be > 0"); + } + DocIdSetIterator eps = epp.entryPoints(); + entryPoints = new int[epp.numberOfEntryPoints()]; + int idx = 0; + while (idx < entryPoints.length) { + int entryPointOrdInt = eps.nextDoc(); + if (entryPointOrdInt == NO_MORE_DOCS) { + throw new IllegalArgumentException( + "The number of entry points provided is less than the number of entry points requested"); + } + assert entryPointOrdInt < getGraphSize(graph); + entryPoints[idx++] = entryPointOrdInt; + } + // This is an invalid case, but we should check it + assert entryPoints.length > 0; + // We use provided entry point ordinals to search the complete graph (level 0) + graphSearcher.searchLevel(knnCollector, scorer, 0, entryPoints, graph, acceptOrds); + } else { + search(scorer, knnCollector, graph, graphSearcher, acceptOrds); + } } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java index ed1a5ffb59fa..5225fe700ab9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java 
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -24,54 +24,24 @@ /** * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId */ -public final class OrdinalTranslatedKnnCollector implements KnnCollector { +public final class OrdinalTranslatedKnnCollector extends KnnCollector.Decorator { - private final KnnCollector in; private final IntToIntFunction vectorOrdinalToDocId; - public OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) { - this.in = in; + public OrdinalTranslatedKnnCollector( + KnnCollector collector, IntToIntFunction vectorOrdinalToDocId) { + super(collector); this.vectorOrdinalToDocId = vectorOrdinalToDocId; } - @Override - public boolean earlyTerminated() { - return in.earlyTerminated(); - } - - @Override - public void incVisitedCount(int count) { - in.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return in.visitedCount(); - } - - @Override - public long visitLimit() { - return in.visitLimit(); - } - - @Override - public int k() { - return in.k(); - } - @Override public boolean collect(int vectorId, float similarity) { - return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity); - } - - @Override - public float minCompetitiveSimilarity() { - return in.minCompetitiveSimilarity(); + return super.collect(vectorOrdinalToDocId.apply(vectorId), similarity); } @Override public TopDocs topDocs() { - TopDocs td = in.topDocs(); + TopDocs td = super.topDocs(); return new TopDocs( new TotalHits( visitedCount(), diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 2023ee73391d..1e485515a62b 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; 
import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import java.nio.file.Path; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -24,19 +25,27 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.SeededKnnFloatVectorQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.tests.codecs.vector.ConfigurableMCodec; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase.Monster; +import org.junit.BeforeClass; @TimeoutSuite(millis = 86_400_000) // 24 hour timeout @Monster("takes ~10 minutes and needs extra heap, disk space, file handles") public class TestManyKnnDocs extends LuceneTestCase { // gradlew -p lucene/core test --tests TestManyKnnDocs -Ptests.heapsize=16g -Dtests.monster=true - public void testLargeSegment() throws Exception { + private static Path testDir; + + @BeforeClass + public static void init_index() throws Exception { IndexWriterConfig iwc = new IndexWriterConfig(); iwc.setCodec( new ConfigurableMCodec( @@ -46,27 +55,138 @@ public void testLargeSegment() throws Exception { mp.setMaxMergeAtOnce(256); // avoid intermediate merges (waste of time with HNSW?) 
mp.setSegmentsPerTier(256); // only merge once at the end when we ask iwc.setMergePolicy(mp); - String fieldName = "field"; VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - try (Directory dir = FSDirectory.open(createTempDir("ManyKnnVectorDocs")); + try (Directory dir = FSDirectory.open(testDir = createTempDir("ManyKnnVectorDocs")); IndexWriter iw = new IndexWriter(dir, iwc)) { int numVectors = 2088992; - float[] vector = new float[1]; - Document doc = new Document(); - doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction)); for (int i = 0; i < numVectors; i++) { + float[] vector = new float[1]; + Document doc = new Document(); vector[0] = (i % 256); + doc.add(new KnnFloatVectorField("field", vector, similarityFunction)); + doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES)); + doc.add(new StoredField("intValue", i)); iw.addDocument(doc); } // merge to single segment and then verify iw.forceMerge(1); iw.commit(); + } + } + + public void testLargeSegmentKnn() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); - TopDocs docs = searcher.search(new KnnFloatVectorQuery("field", new float[] {120}, 10), 5); - assertEquals(5, docs.scoreDocs.length); + for (int i = 0; i < 256; i++) { + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search(new KnnFloatVectorQuery("field", vector, 10, filterQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededExact() throws Exception { + try 
(Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 256)); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededNearby() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + i); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededDistant() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 128)); + Query filterQuery = new 
MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededNone() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = new MatchNoDocsQuery(); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index b45d6e8fb641..21219e0e1d99 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -61,7 +61,7 @@ Field getKnnVectorField(String name, float[] vector) { return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); } - private static byte[] floatToBytes(float[] query) { + static byte[] 
floatToBytes(float[] query) { byte[] bytes = new byte[query.length]; for (int i = 0; i < query.length; i++) { assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0 @@ -109,7 +109,7 @@ public void testVectorEncodingMismatch() throws IOException { } } - private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { + static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { super(field, target, k, filter); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 5dcb6f97df93..ece2b385654e 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -259,7 +259,7 @@ public void testDocAndScoreQueryBasics() throws IOException { } } - private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { + static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { super(field, target, k, filter); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java new file mode 100644 index 000000000000..d0fb8c95e035 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { + + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + byte[] b = TestVectorUtil.randomVectorBytes(dim); + float[] v = new 
float[b.length]; + int vi = 0; + for (int i = 0; i < v.length; i++) { + v[vi++] = b[i]; + } + return v; + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + 
Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery(); + SeededKnnByteVectorQuery query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnByteVectorQuery { + + public 
ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, target, k, filter, seed); + } + + private ThrowingKnnVectorQuery( + String field, byte[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter, seedWeight); + } + + @Override + // This is test only and we need to overwrite the inner rewrite to throw + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight) + .rewrite(indexSearcher); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java new file mode 100644 index 000000000000..d5630037ef74 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnFloatVectorQuery(field, query, k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, vec, k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + return 
TestVectorUtil.randomVector(dim); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnFloatVectorField(name, vector); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); + AbstractKnnVectorQuery query = + new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery { + + private ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + 
super(field, target, k, filter, seed); + } + + private ThrowingKnnVectorQuery( + String field, float[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter, seedWeight); + } + + @Override + // This is test only and we need to overwrite the inner rewrite to throw + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight) + .rewrite(indexSearcher); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +}