diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 6b01271c46b7..dc90bfa477cb 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -42,7 +42,11 @@ API Changes
New Features
---------------------
-(No changes)
+
+* GITHUB#14084, GITHUB#13635, GITHUB#13634: Adds new `SeededKnnByteVectorQuery` and `SeededKnnFloatVectorQuery`
+ queries. These queries allow for the vector search entry points to be initialized via a `seed` query. This follows
+ the research provided via https://arxiv.org/abs/2307.16779. (Sean MacAvaney, Ben Trent).
+
Improvements
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
index 35144055830c..05157ab65cb5 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
@@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
- private final byte[] target;
+ protected final byte[] target;
/**
* Find the k
nearest documents to the target vector according to the vectors in the
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java
index 43bac9fbc309..f694d8f7085c 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java
@@ -85,4 +85,58 @@ public interface KnnCollector {
* @return The collected top documents
*/
TopDocs topDocs();
+
+ /**
+ * KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend
+ * the object with new behaviors.
+ *
+ * @lucene.experimental
+ */
+ abstract class Decorator implements KnnCollector {
+ private final KnnCollector collector;
+
+ public Decorator(KnnCollector collector) {
+ this.collector = collector;
+ }
+
+ @Override
+ public boolean earlyTerminated() {
+ return collector.earlyTerminated();
+ }
+
+ @Override
+ public void incVisitedCount(int count) {
+ collector.incVisitedCount(count);
+ }
+
+ @Override
+ public long visitedCount() {
+ return collector.visitedCount();
+ }
+
+ @Override
+ public long visitLimit() {
+ return collector.visitLimit();
+ }
+
+ @Override
+ public int k() {
+ return collector.k();
+ }
+
+ @Override
+ public boolean collect(int docId, float similarity) {
+ return collector.collect(docId, similarity);
+ }
+
+ @Override
+ public float minCompetitiveSimilarity() {
+ return collector.minCompetitiveSimilarity();
+ }
+
+ @Override
+ public TopDocs topDocs() {
+ return collector.topDocs();
+ }
+ }
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
index d2aaf4296eda..c7d6fdb3608d 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
@@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
- private final float[] target;
+ protected final float[] target;
/**
* Find the k
nearest documents to the target vector according to the vectors in the
diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java
new file mode 100644
index 000000000000..980b6869c34f
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.search.knn.KnnCollectorManager;
+import org.apache.lucene.search.knn.SeededKnnCollectorManager;
+
+/**
+ * This is a version of knn byte vector query that provides a query seed to initiate the vector
+ * search. NOTE: The underlying format is free to ignore the provided seed
+ *
+ *
See "Lexically-Accelerated Dense
+ * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir).
+ * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and
+ * Development in Information Retrieval Pages 152 - 162
+ *
+ * @lucene.experimental
+ */
+public class SeededKnnByteVectorQuery extends KnnByteVectorQuery {
+ final Query seed;
+ final Weight seedWeight;
+
+ /**
+ * Construct a new SeededKnnByteVectorQuery instance
+ *
+ * @param field knn byte vector field to query
+ * @param target the query vector
+ * @param k number of neighbors to return
+ * @param filter a filter on the neighbors to return
+ * @param seed a query seed to initiate the vector format search
+ */
+ public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
+ super(field, target, k, filter);
+ this.seed = Objects.requireNonNull(seed);
+ this.seedWeight = null;
+ }
+
+ SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) {
+ super(field, target, k, filter);
+ this.seed = null;
+ this.seedWeight = Objects.requireNonNull(seedWeight);
+ }
+
+ @Override
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ if (seedWeight != null) {
+ return super.rewrite(indexSearcher);
+ }
+ BooleanQuery.Builder booleanSeedQueryBuilder =
+ new BooleanQuery.Builder()
+ .add(seed, BooleanClause.Occur.MUST)
+ .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
+ if (filter != null) {
+ booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
+ }
+ Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
+ Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
+ SeededKnnByteVectorQuery rewritten =
+ new SeededKnnByteVectorQuery(field, target, k, filter, seedWeight);
+ return rewritten.rewrite(indexSearcher);
+ }
+
+ @Override
+ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
+ if (seedWeight == null) {
+ throw new UnsupportedOperationException("must be rewritten before constructing manager");
+ }
+ return new SeededKnnCollectorManager(
+ super.getKnnCollectorManager(k, searcher),
+ seedWeight,
+ k,
+ leaf -> {
+ ByteVectorValues vv = leaf.getByteVectorValues(field);
+ if (vv == null) {
+ ByteVectorValues.checkField(leaf.getContext().reader(), field);
+ }
+ return vv;
+ });
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java
new file mode 100644
index 000000000000..02a33bdcdef7
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.search.knn.KnnCollectorManager;
+import org.apache.lucene.search.knn.SeededKnnCollectorManager;
+
+/**
+ * This is a version of knn float vector query that provides a query seed to initiate the vector
+ * search. NOTE: The underlying format is free to ignore the provided seed.
+ *
+ *
See "Lexically-Accelerated Dense
+ * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir).
+ * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and
+ * Development in Information Retrieval Pages 152 - 162
+ *
+ * @lucene.experimental
+ */
+public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery {
+ final Query seed;
+ final Weight seedWeight;
+
+ /**
+ * Construct a new SeededKnnFloatVectorQuery instance
+ *
+ * @param field knn float vector field to query
+ * @param target the query vector
+ * @param k number of neighbors to return
+ * @param filter a filter on the neighbors to return
+ * @param seed a query seed to initiate the vector format search
+ */
+ public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
+ super(field, target, k, filter);
+ this.seed = Objects.requireNonNull(seed);
+ this.seedWeight = null;
+ }
+
+ SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) {
+ super(field, target, k, filter);
+ this.seed = null;
+ this.seedWeight = Objects.requireNonNull(seedWeight);
+ }
+
+ @Override
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ if (seedWeight != null) {
+ return super.rewrite(indexSearcher);
+ }
+ BooleanQuery.Builder booleanSeedQueryBuilder =
+ new BooleanQuery.Builder()
+ .add(seed, BooleanClause.Occur.MUST)
+ .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
+ if (filter != null) {
+ booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
+ }
+ Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
+ Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
+ SeededKnnFloatVectorQuery rewritten =
+ new SeededKnnFloatVectorQuery(field, target, k, filter, seedWeight);
+ return rewritten.rewrite(indexSearcher);
+ }
+
+ @Override
+ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
+ if (seedWeight == null) {
+ throw new UnsupportedOperationException("must be rewritten before constructing manager");
+ }
+ return new SeededKnnCollectorManager(
+ super.getKnnCollectorManager(k, searcher),
+ seedWeight,
+ k,
+ leaf -> {
+ FloatVectorValues vv = leaf.getFloatVectorValues(field);
+ if (vv == null) {
+ FloatVectorValues.checkField(leaf.getContext().reader(), field);
+ }
+ return vv;
+ });
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java
index 2a1f312fbc58..2dc2f035b90f 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java
@@ -45,51 +45,19 @@ public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) th
return new TimeLimitingKnnCollector(collector);
}
- class TimeLimitingKnnCollector implements KnnCollector {
- private final KnnCollector collector;
-
- TimeLimitingKnnCollector(KnnCollector collector) {
- this.collector = collector;
+ class TimeLimitingKnnCollector extends KnnCollector.Decorator {
+ public TimeLimitingKnnCollector(KnnCollector collector) {
+ super(collector);
}
@Override
public boolean earlyTerminated() {
- return queryTimeout.shouldExit() || collector.earlyTerminated();
- }
-
- @Override
- public void incVisitedCount(int count) {
- collector.incVisitedCount(count);
- }
-
- @Override
- public long visitedCount() {
- return collector.visitedCount();
- }
-
- @Override
- public long visitLimit() {
- return collector.visitLimit();
- }
-
- @Override
- public int k() {
- return collector.k();
- }
-
- @Override
- public boolean collect(int docId, float similarity) {
- return collector.collect(docId, similarity);
- }
-
- @Override
- public float minCompetitiveSimilarity() {
- return collector.minCompetitiveSimilarity();
+ return queryTimeout.shouldExit() || super.earlyTerminated();
}
@Override
public TopDocs topDocs() {
- TopDocs docs = collector.topDocs();
+ TopDocs docs = super.topDocs();
// Mark results as partial if timeout is met
TotalHits.Relation relation =
diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java
new file mode 100644
index 000000000000..9e7b44b571df
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.knn;
+
+import org.apache.lucene.search.DocIdSetIterator;
+
+/** Provides entry points for the kNN search */
+public interface EntryPointProvider {
+ /** Iterator of valid entry points for the kNN search */
+ DocIdSetIterator entryPoints();
+
+ /** Number of valid entry points for the kNN search */
+ int numberOfEntryPoints();
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java
new file mode 100644
index 000000000000..c3c4f62901ee
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.knn;
+
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.KnnCollector;
+
+/**
+ * A {@link KnnCollector} that provides seeded knn collection. See usage in {@link
+ * SeededKnnCollectorManager}.
+ *
+ * @lucene.experimental
+ */
+class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider {
+ private final DocIdSetIterator entryPoints;
+ private final int numberOfEntryPoints;
+
+ SeededKnnCollector(
+ KnnCollector collector, DocIdSetIterator entryPoints, int numberOfEntryPoints) {
+ super(collector);
+ this.entryPoints = entryPoints;
+ this.numberOfEntryPoints = numberOfEntryPoints;
+ }
+
+ @Override
+ public DocIdSetIterator entryPoints() {
+ return entryPoints;
+ }
+
+ @Override
+ public int numberOfEntryPoints() {
+ return numberOfEntryPoints;
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java
new file mode 100644
index 000000000000..7631db6e3022
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.knn;
+
+import java.io.IOException;
+import java.util.Arrays;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopScoreDocCollector;
+import org.apache.lucene.search.TopScoreDocCollectorManager;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.util.IOFunction;
+
+/**
+ * A {@link KnnCollectorManager} that provides seeded knn collection. See usage in {@link
+ * org.apache.lucene.search.SeededKnnFloatVectorQuery} and {@link
+ * org.apache.lucene.search.SeededKnnByteVectorQuery}.
+ */
+public class SeededKnnCollectorManager implements KnnCollectorManager {
+ private final KnnCollectorManager delegate;
+ private final Weight seedWeight;
+ private final int k;
+ private final IOFunction vectorValuesSupplier;
+
+ public SeededKnnCollectorManager(
+ KnnCollectorManager delegate,
+ Weight seedWeight,
+ int k,
+ IOFunction vectorValuesSupplier) {
+ this.delegate = delegate;
+ this.seedWeight = seedWeight;
+ this.k = k;
+ this.vectorValuesSupplier = vectorValuesSupplier;
+ }
+
+ @Override
+ public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException {
+ // Execute the seed query
+ TopScoreDocCollector seedCollector =
+ new TopScoreDocCollectorManager(k, null, Integer.MAX_VALUE).newCollector();
+ final LeafReader leafReader = ctx.reader();
+ final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx);
+ if (leafCollector != null) {
+ try {
+ BulkScorer scorer = seedWeight.bulkScorer(ctx);
+ if (scorer != null) {
+ scorer.score(
+ leafCollector,
+ leafReader.getLiveDocs(),
+ 0 /* min */,
+ DocIdSetIterator.NO_MORE_DOCS /* max */);
+ }
+ } catch (
+ @SuppressWarnings("unused")
+ CollectionTerminatedException e) {
+ }
+ leafCollector.finish();
+ }
+
+ TopDocs seedTopDocs = seedCollector.topDocs();
+ KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader);
+ final KnnCollector delegateCollector = delegate.newCollector(visitedLimit, ctx);
+ if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) {
+ return delegateCollector;
+ }
+ KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator();
+ DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs));
+ return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length);
+ }
+
+ private static class MappedDISI extends DocIdSetIterator {
+ KnnVectorValues.DocIndexIterator indexedDISI;
+ DocIdSetIterator sourceDISI;
+
+ private MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) {
+ this.indexedDISI = indexedDISI;
+ this.sourceDISI = sourceDISI;
+ }
+
+ /**
+ * Advances the source iterator to the first document number that is greater than or equal to
+ * the provided target and returns the corresponding index.
+ */
+ @Override
+ public int advance(int target) throws IOException {
+ int newTarget = sourceDISI.advance(target);
+ if (newTarget != NO_MORE_DOCS) {
+ indexedDISI.advance(newTarget);
+ }
+ return docID();
+ }
+
+ @Override
+ public long cost() {
+ return sourceDISI.cost();
+ }
+
+ @Override
+ public int docID() {
+ if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) {
+ return NO_MORE_DOCS;
+ }
+ return indexedDISI.index();
+ }
+
+ /** Advances to the next document in the source iterator and returns the corresponding index. */
+ @Override
+ public int nextDoc() throws IOException {
+ int newTarget = sourceDISI.nextDoc();
+ if (newTarget != NO_MORE_DOCS) {
+ indexedDISI.advance(newTarget);
+ }
+ return docID();
+ }
+ }
+
+ private static class TopDocsDISI extends DocIdSetIterator {
+ private final int[] sortedDocIds;
+ private int idx = -1;
+
+ private TopDocsDISI(TopDocs topDocs) {
+ sortedDocIds = new int[topDocs.scoreDocs.length];
+ for (int i = 0; i < topDocs.scoreDocs.length; i++) {
+ sortedDocIds[i] = topDocs.scoreDocs[i].doc;
+ }
+ Arrays.sort(sortedDocIds);
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ return slowAdvance(target);
+ }
+
+ @Override
+ public long cost() {
+ return sortedDocIds.length;
+ }
+
+ @Override
+ public int docID() {
+ if (idx == -1) {
+ return -1;
+ } else if (idx >= sortedDocIds.length) {
+ return DocIdSetIterator.NO_MORE_DOCS;
+ } else {
+ return sortedDocIds[idx];
+ }
+ }
+
+ @Override
+ public int nextDoc() {
+ idx += 1;
+ return docID();
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index 46d6c93d52c3..e8f0d316fd81 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -20,8 +20,10 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
+import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
+import org.apache.lucene.search.knn.EntryPointProvider;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
@@ -52,7 +54,9 @@ public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) {
}
/**
- * Searches HNSW graph for the nearest neighbors of a query vector.
+ * Searches the HNSW graph for the nearest neighbors of a query vector. If entry points are
+ * directly provided via the knnCollector, then the search will be initialized at those points.
+ * Otherwise, the search will discover the best entry point per the normal HNSW search algorithm.
*
* @param scorer the scorer to compare the query with the nodes
* @param knnCollector a collector of top knn results to be returned
@@ -67,7 +71,30 @@ public static void search(
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph)));
- search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
+ final int[] entryPoints;
+ if (knnCollector instanceof EntryPointProvider epp) {
+ if (epp.numberOfEntryPoints() <= 0) {
+ throw new IllegalArgumentException("The number of entry points must be > 0");
+ }
+ DocIdSetIterator eps = epp.entryPoints();
+ entryPoints = new int[epp.numberOfEntryPoints()];
+ int idx = 0;
+ while (idx < entryPoints.length) {
+ int entryPointOrdInt = eps.nextDoc();
+ if (entryPointOrdInt == NO_MORE_DOCS) {
+ throw new IllegalArgumentException(
+ "The number of entry points provided is less than the number of entry points requested");
+ }
+ assert entryPointOrdInt < getGraphSize(graph);
+ entryPoints[idx++] = entryPointOrdInt;
+ }
+ // This is an invalid case, but we should check it
+ assert entryPoints.length > 0;
+ // We use provided entry point ordinals to search the complete graph (level 0)
+ graphSearcher.searchLevel(knnCollector, scorer, 0, entryPoints, graph, acceptOrds);
+ } else {
+ search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
+ }
}
/**
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java
index ed1a5ffb59fa..5225fe700ab9 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java
@@ -24,54 +24,24 @@
/**
* Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId
*/
-public final class OrdinalTranslatedKnnCollector implements KnnCollector {
+public final class OrdinalTranslatedKnnCollector extends KnnCollector.Decorator {
- private final KnnCollector in;
private final IntToIntFunction vectorOrdinalToDocId;
- public OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) {
- this.in = in;
+ public OrdinalTranslatedKnnCollector(
+ KnnCollector collector, IntToIntFunction vectorOrdinalToDocId) {
+ super(collector);
this.vectorOrdinalToDocId = vectorOrdinalToDocId;
}
- @Override
- public boolean earlyTerminated() {
- return in.earlyTerminated();
- }
-
- @Override
- public void incVisitedCount(int count) {
- in.incVisitedCount(count);
- }
-
- @Override
- public long visitedCount() {
- return in.visitedCount();
- }
-
- @Override
- public long visitLimit() {
- return in.visitLimit();
- }
-
- @Override
- public int k() {
- return in.k();
- }
-
@Override
public boolean collect(int vectorId, float similarity) {
- return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity);
- }
-
- @Override
- public float minCompetitiveSimilarity() {
- return in.minCompetitiveSimilarity();
+ return super.collect(vectorOrdinalToDocId.apply(vectorId), similarity);
}
@Override
public TopDocs topDocs() {
- TopDocs td = in.topDocs();
+ TopDocs td = super.topDocs();
return new TopDocs(
new TotalHits(
visitedCount(),
diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java
index 2023ee73391d..1e485515a62b 100644
--- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java
+++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java
@@ -17,6 +17,7 @@
package org.apache.lucene.document;
import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite;
+import java.nio.file.Path;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
@@ -24,19 +25,27 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.MatchNoDocsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.SeededKnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.tests.codecs.vector.ConfigurableMCodec;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.LuceneTestCase.Monster;
+import org.junit.BeforeClass;
@TimeoutSuite(millis = 86_400_000) // 24 hour timeout
@Monster("takes ~10 minutes and needs extra heap, disk space, file handles")
public class TestManyKnnDocs extends LuceneTestCase {
// gradlew -p lucene/core test --tests TestManyKnnDocs -Ptests.heapsize=16g -Dtests.monster=true
- public void testLargeSegment() throws Exception {
+ private static Path testDir;
+
+ @BeforeClass
+ public static void init_index() throws Exception {
IndexWriterConfig iwc = new IndexWriterConfig();
iwc.setCodec(
new ConfigurableMCodec(
@@ -46,27 +55,138 @@ public void testLargeSegment() throws Exception {
mp.setMaxMergeAtOnce(256); // avoid intermediate merges (waste of time with HNSW?)
mp.setSegmentsPerTier(256); // only merge once at the end when we ask
iwc.setMergePolicy(mp);
- String fieldName = "field";
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
- try (Directory dir = FSDirectory.open(createTempDir("ManyKnnVectorDocs"));
+ try (Directory dir = FSDirectory.open(testDir = createTempDir("ManyKnnVectorDocs"));
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numVectors = 2088992;
- float[] vector = new float[1];
- Document doc = new Document();
- doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction));
for (int i = 0; i < numVectors; i++) {
+ float[] vector = new float[1];
+ Document doc = new Document();
vector[0] = (i % 256);
+ doc.add(new KnnFloatVectorField("field", vector, similarityFunction));
+ doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES));
+ doc.add(new StoredField("intValue", i));
iw.addDocument(doc);
}
// merge to single segment and then verify
iw.forceMerge(1);
iw.commit();
+ }
+ }
+
+ public void testLargeSegmentKnn() throws Exception {
+ try (Directory dir = FSDirectory.open(testDir)) {
IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir));
- TopDocs docs = searcher.search(new KnnFloatVectorQuery("field", new float[] {120}, 10), 5);
- assertEquals(5, docs.scoreDocs.length);
+ for (int i = 0; i < 256; i++) {
+ Query filterQuery = new MatchAllDocsQuery();
+ float[] vector = new float[128];
+ vector[0] = i;
+ vector[1] = 1;
+ TopDocs docs =
+ searcher.search(new KnnFloatVectorQuery("field", vector, 10, filterQuery), 5);
+ assertEquals(5, docs.scoreDocs.length);
+ Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
+ String s = "";
+ for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
+ s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n";
+ }
+ assertEquals(s, i + 256, d.getField("intValue").numericValue());
+ }
+ }
+ }
+
+ public void testLargeSegmentSeededExact() throws Exception {
+ try (Directory dir = FSDirectory.open(testDir)) {
+ IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir));
+ for (int i = 0; i < 256; i++) {
+ Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 256));
+ Query filterQuery = new MatchAllDocsQuery();
+ float[] vector = new float[128];
+ vector[0] = i;
+ vector[1] = 1;
+ TopDocs docs =
+ searcher.search(
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ assertEquals(5, docs.scoreDocs.length);
+ String s = "";
+ for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
+ s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n";
+ }
+ Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
+ assertEquals(s, i + 256, d.getField("intValue").numericValue());
+ }
+ }
+ }
+
+ public void testLargeSegmentSeededNearby() throws Exception {
+ try (Directory dir = FSDirectory.open(testDir)) {
+ IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir));
+ for (int i = 0; i < 256; i++) {
+ Query seedQuery = KeywordField.newExactQuery("int", "" + i);
+ Query filterQuery = new MatchAllDocsQuery();
+ float[] vector = new float[128];
+ vector[0] = i;
+ vector[1] = 1;
+ TopDocs docs =
+ searcher.search(
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ assertEquals(5, docs.scoreDocs.length);
+ String s = "";
+ for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
+ s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n";
+ }
+ Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
+ assertEquals(s, i + 256, d.getField("intValue").numericValue());
+ }
+ }
+ }
+
+ public void testLargeSegmentSeededDistant() throws Exception {
+ try (Directory dir = FSDirectory.open(testDir)) {
+ IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir));
+ for (int i = 0; i < 256; i++) {
+ Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 128));
+ Query filterQuery = new MatchAllDocsQuery();
+ float[] vector = new float[128];
+ vector[0] = i;
+ vector[1] = 1;
+ TopDocs docs =
+ searcher.search(
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ assertEquals(5, docs.scoreDocs.length);
+ Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
+ String s = "";
+ for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
+ s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n";
+ }
+ assertEquals(s, i + 256, d.getField("intValue").numericValue());
+ }
+ }
+ }
+
+ public void testLargeSegmentSeededNone() throws Exception {
+ try (Directory dir = FSDirectory.open(testDir)) {
+ IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir));
+ for (int i = 0; i < 256; i++) {
+ Query seedQuery = new MatchNoDocsQuery();
+ Query filterQuery = new MatchAllDocsQuery();
+ float[] vector = new float[128];
+ vector[0] = i;
+ vector[1] = 1;
+ TopDocs docs =
+ searcher.search(
+ new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5);
+ assertEquals(5, docs.scoreDocs.length);
+ Document d = searcher.storedFields().document(docs.scoreDocs[0].doc);
+ String s = "";
+ for (int j = 0; j < docs.scoreDocs.length - 1; j++) {
+ s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n";
+ }
+ assertEquals(s, i + 256, d.getField("intValue").numericValue());
+ }
}
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
index b45d6e8fb641..21219e0e1d99 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
@@ -61,7 +61,7 @@ Field getKnnVectorField(String name, float[] vector) {
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}
- private static byte[] floatToBytes(float[] query) {
+ static byte[] floatToBytes(float[] query) {
byte[] bytes = new byte[query.length];
for (int i = 0; i < query.length; i++) {
assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0
@@ -109,7 +109,7 @@ public void testVectorEncodingMismatch() throws IOException {
}
}
- private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
+ static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, target, k, filter);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
index 5dcb6f97df93..ece2b385654e 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
@@ -259,7 +259,7 @@ public void testDocAndScoreQueryBasics() throws IOException {
}
}
- private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {
+ static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {
public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
super(field, target, k, filter);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java
new file mode 100644
index 000000000000..d0fb8c95e035
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes;
+
+import java.io.IOException;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntPoint;
+import org.apache.lucene.document.KnnByteVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.QueryTimeout;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.TestVectorUtil;
+
+public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
+
+ private static final Query MATCH_NONE = new MatchNoDocsQuery();
+
+ @Override
+ AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
+ return new SeededKnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, MATCH_NONE);
+ }
+
+ @Override
+ AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
+ return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, MATCH_NONE);
+ }
+
+ @Override
+ float[] randomVector(int dim) {
+ byte[] b = TestVectorUtil.randomVectorBytes(dim);
+ float[] v = new float[b.length];
+ int vi = 0;
+ for (int i = 0; i < v.length; i++) {
+ v[vi++] = b[i];
+ }
+ return v;
+ }
+
+ @Override
+ Field getKnnVectorField(
+ String name, float[] vector, VectorSimilarityFunction similarityFunction) {
+ return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction);
+ }
+
+ @Override
+ Field getKnnVectorField(String name, float[] vector) {
+ return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
+ }
+
+ /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */
+ public void testRandomWithSeed() throws IOException {
+ int numDocs = 1000;
+ int dimension = atLeast(5);
+ int numIters = atLeast(10);
+ int numDocsWithVector = 0;
+ try (Directory d = newDirectoryForTest()) {
+ // Always use the default kNN format to have predictable behavior around when it hits
+ // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
+ // format
+ // implementation.
+ IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
+ RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
+ for (int i = 0; i < numDocs; i++) {
+ Document doc = new Document();
+ if (random().nextBoolean()) {
+ // Randomly skip some vectors to test the mapping from docid to ordinals
+ doc.add(getKnnVectorField("field", randomVector(dimension)));
+ numDocsWithVector += 1;
+ }
+ doc.add(new NumericDocValuesField("tag", i));
+ doc.add(new IntPoint("tag", i));
+ w.addDocument(doc);
+ }
+ w.forceMerge(1);
+ w.close();
+
+ try (IndexReader reader = DirectoryReader.open(d)) {
+ IndexSearcher searcher = newSearcher(reader);
+ for (int i = 0; i < numIters; i++) {
+ int k = random().nextInt(80) + 1;
+ int n = random().nextInt(100) + 1;
+ // we may get fewer results than requested if there are deletions, but this test doesn't
+ // check that
+ assert reader.hasDeletions() == false;
+
+ // All documents as seeds
+ Query seed1 = new MatchAllDocsQuery();
+ Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
+ SeededKnnByteVectorQuery query =
+ new SeededKnnByteVectorQuery(
+ "field", floatToBytes(randomVector(dimension)), k, filter, seed1);
+ TopDocs results = searcher.search(query, n);
+ int expected = Math.min(Math.min(n, k), numDocsWithVector);
+
+ assertEquals(expected, results.scoreDocs.length);
+ assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+ // verify the results are in descending score order
+ float last = Float.MAX_VALUE;
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ assertTrue(scoreDoc.score <= last);
+ last = scoreDoc.score;
+ }
+
+ // Restrictive seed query -- 6 documents
+ Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
+ query =
+ new SeededKnnByteVectorQuery(
+ "field", floatToBytes(randomVector(dimension)), k, null, seed2);
+ results = searcher.search(query, n);
+ expected = Math.min(Math.min(n, k), reader.numDocs());
+ assertEquals(expected, results.scoreDocs.length);
+ assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+ // verify the results are in descending score order
+ last = Float.MAX_VALUE;
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ assertTrue(scoreDoc.score <= last);
+ last = scoreDoc.score;
+ }
+
+ // No seed documents -- falls back on full approx search
+ Query seed3 = new MatchNoDocsQuery();
+ query =
+ new SeededKnnByteVectorQuery(
+ "field", floatToBytes(randomVector(dimension)), k, null, seed3);
+ results = searcher.search(query, n);
+ expected = Math.min(Math.min(n, k), reader.numDocs());
+ assertEquals(expected, results.scoreDocs.length);
+ assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+ // verify the results are in descending score order
+ last = Float.MAX_VALUE;
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ assertTrue(scoreDoc.score <= last);
+ last = scoreDoc.score;
+ }
+ }
+ }
+ }
+ }
+
+ private static class ThrowingKnnVectorQuery extends SeededKnnByteVectorQuery {
+
+ public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
+ super(field, target, k, filter, seed);
+ }
+
+ private ThrowingKnnVectorQuery(
+ String field, byte[] target, int k, Query filter, Weight seedWeight) {
+ super(field, target, k, filter, seedWeight);
+ }
+
+ @Override
+ // This is test only and we need to overwrite the inner rewrite to throw
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ if (seedWeight != null) {
+ return super.rewrite(indexSearcher);
+ }
+ BooleanQuery.Builder booleanSeedQueryBuilder =
+ new BooleanQuery.Builder()
+ .add(seed, BooleanClause.Occur.MUST)
+ .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
+ if (filter != null) {
+ booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
+ }
+ Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
+ Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
+ return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight)
+ .rewrite(indexSearcher);
+ }
+
+ @Override
+ protected TopDocs exactSearch(
+ LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
+ throw new UnsupportedOperationException("exact search is not supported");
+ }
+
+ @Override
+ public String toString(String field) {
+ return null;
+ }
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java
new file mode 100644
index 000000000000..d5630037ef74
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntPoint;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.QueryTimeout;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.TestVectorUtil;
+
+public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
+ private static final Query MATCH_NONE = new MatchNoDocsQuery();
+
+ @Override
+ KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
+ return new SeededKnnFloatVectorQuery(field, query, k, queryFilter, MATCH_NONE);
+ }
+
+ @Override
+ AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
+ return new ThrowingKnnVectorQuery(field, vec, k, query, MATCH_NONE);
+ }
+
+ @Override
+ float[] randomVector(int dim) {
+ return TestVectorUtil.randomVector(dim);
+ }
+
+ @Override
+ Field getKnnVectorField(
+ String name, float[] vector, VectorSimilarityFunction similarityFunction) {
+ return new KnnFloatVectorField(name, vector, similarityFunction);
+ }
+
+ @Override
+ Field getKnnVectorField(String name, float[] vector) {
+ return new KnnFloatVectorField(name, vector);
+ }
+
+ /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */
+ public void testRandomWithSeed() throws IOException {
+ int numDocs = 1000;
+ int dimension = atLeast(5);
+ int numIters = atLeast(10);
+ int numDocsWithVector = 0;
+ try (Directory d = newDirectoryForTest()) {
+ // Always use the default kNN format to have predictable behavior around when it hits
+ // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
+ // format
+ // implementation.
+ IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
+ RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
+ for (int i = 0; i < numDocs; i++) {
+ Document doc = new Document();
+ if (random().nextBoolean()) {
+ // Randomly skip some vectors to test the mapping from docid to ordinals
+ doc.add(getKnnVectorField("field", randomVector(dimension)));
+ numDocsWithVector += 1;
+ }
+ doc.add(new NumericDocValuesField("tag", i));
+ doc.add(new IntPoint("tag", i));
+ w.addDocument(doc);
+ }
+ w.forceMerge(1);
+ w.close();
+
+ try (IndexReader reader = DirectoryReader.open(d)) {
+ IndexSearcher searcher = newSearcher(reader);
+ for (int i = 0; i < numIters; i++) {
+ int k = random().nextInt(80) + 1;
+ int n = random().nextInt(100) + 1;
+ // we may get fewer results than requested if there are deletions, but this test doesn't
+ // check that
+ assert reader.hasDeletions() == false;
+
+ // All documents as seeds
+ Query seed1 = new MatchAllDocsQuery();
+ Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
+ AbstractKnnVectorQuery query =
+ new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, filter, seed1);
+ TopDocs results = searcher.search(query, n);
+ int expected = Math.min(Math.min(n, k), numDocsWithVector);
+
+ assertEquals(expected, results.scoreDocs.length);
+ assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+ // verify the results are in descending score order
+ float last = Float.MAX_VALUE;
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ assertTrue(scoreDoc.score <= last);
+ last = scoreDoc.score;
+ }
+
+ // Restrictive seed query -- 6 documents
+ Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
+ query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed2);
+ results = searcher.search(query, n);
+ expected = Math.min(Math.min(n, k), reader.numDocs());
+ assertEquals(expected, results.scoreDocs.length);
+ assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+ // verify the results are in descending score order
+ last = Float.MAX_VALUE;
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ assertTrue(scoreDoc.score <= last);
+ last = scoreDoc.score;
+ }
+
+ // No seed documents -- falls back on full approx search
+ Query seed3 = new MatchNoDocsQuery();
+ query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed3);
+ results = searcher.search(query, n);
+ expected = Math.min(Math.min(n, k), reader.numDocs());
+ assertEquals(expected, results.scoreDocs.length);
+ assertTrue(results.totalHits.value() >= results.scoreDocs.length);
+ // verify the results are in descending score order
+ last = Float.MAX_VALUE;
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ assertTrue(scoreDoc.score <= last);
+ last = scoreDoc.score;
+ }
+ }
+ }
+ }
+ }
+
+ private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery {
+
+ private ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
+ super(field, target, k, filter, seed);
+ }
+
+ private ThrowingKnnVectorQuery(
+ String field, float[] target, int k, Query filter, Weight seedWeight) {
+ super(field, target, k, filter, seedWeight);
+ }
+
+ @Override
+ // This is test only and we need to overwrite the inner rewrite to throw
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ if (seedWeight != null) {
+ return super.rewrite(indexSearcher);
+ }
+ BooleanQuery.Builder booleanSeedQueryBuilder =
+ new BooleanQuery.Builder()
+ .add(seed, BooleanClause.Occur.MUST)
+ .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
+ if (filter != null) {
+ booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
+ }
+ Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
+ Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
+ return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight)
+ .rewrite(indexSearcher);
+ }
+
+ @Override
+ protected TopDocs exactSearch(
+ LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) {
+ throw new UnsupportedOperationException("exact search is not supported");
+ }
+
+ @Override
+ public String toString(String field) {
+ return null;
+ }
+ }
+}