Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AbstractKnnVectorQuery.seed for seeded HNSW #13635

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
60e3fab
implement seeded knn queries
Aug 6, 2024
82e7053
cleanup
Aug 6, 2024
7955148
ensure seed docs have a vector
Aug 6, 2024
40d972d
apply filter to seed queries
Aug 6, 2024
f36a4cd
Update lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader…
seanmacavaney Sep 5, 2024
539b29a
Update lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader…
seanmacavaney Sep 5, 2024
3df6ad2
Update lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQue…
seanmacavaney Sep 5, 2024
0508d87
Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSear…
seanmacavaney Sep 5, 2024
732f69c
Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSear…
seanmacavaney Sep 5, 2024
c02b4cc
Update lucene/core/src/java/org/apache/lucene/search/AbstractKnnVecto…
seanmacavaney Sep 5, 2024
285ebfe
Update lucene/core/src/java/org/apache/lucene/search/AbstractKnnVecto…
seanmacavaney Sep 5, 2024
b64d458
Merge branch 'main' into seeds
seanmacavaney Sep 5, 2024
9f1be67
mapping docIds to ordinals
Sep 10, 2024
244f46b
fixed test warning
Sep 10, 2024
b73e7a3
fix test warning
Sep 10, 2024
3134132
tidy
Sep 10, 2024
69db4d4
Update lucene/core/src/java/org/apache/lucene/search/AbstractKnnVecto…
seanmacavaney Sep 26, 2024
fe4bef3
Update lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQue…
seanmacavaney Sep 26, 2024
8e044f8
Update lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQu…
seanmacavaney Sep 26, 2024
33231b3
Update lucene/core/src/java/org/apache/lucene/index/FloatVectorValues…
seanmacavaney Sep 26, 2024
2e86e4f
Update lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.…
seanmacavaney Sep 26, 2024
c0c18b2
Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSear…
seanmacavaney Sep 26, 2024
6190aca
Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSear…
seanmacavaney Sep 26, 2024
a49ba2f
address review comments
Sep 26, 2024
8698b59
Merge branch 'main' into seeds
seanmacavaney Sep 26, 2024
1f8a9f4
merge issues
Sep 26, 2024
58f34df
addresses review comments
Sep 27, 2024
fc2129f
refactor wip
Oct 2, 2024
e8417d3
consistent naming
Oct 2, 2024
440b0d0
javadoc
Oct 2, 2024
87e75ab
tidy
Oct 2, 2024
7b3350f
javadoc typo
Oct 2, 2024
5bb40c2
javadoc
Oct 2, 2024
216bfc4
test fixes
Oct 2, 2024
bccf15d
Merge branch 'main' into seeds
seanmacavaney Oct 2, 2024
b6725c7
merge resolution
Oct 2, 2024
0cfd99b
merging
Oct 2, 2024
dd63bb0
refactor as decorator
Oct 2, 2024
1635870
javadoc
Oct 2, 2024
c8a512a
apply decorator elsewhere
Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,12 @@ public ByteVectorValues getByteVectorValues(String field) {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

Expand Down Expand Up @@ -265,7 +270,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

Expand All @@ -243,11 +248,17 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry));
getAcceptOrds(acceptDocs, fieldEntry),
seedDocs);
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocsBits,
DocIdSetIterator seedDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand Down Expand Up @@ -224,7 +225,12 @@ public ByteVectorValues getByteVectorValues(String field) {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

Expand All @@ -244,7 +250,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand Down Expand Up @@ -261,7 +262,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

Expand All @@ -277,11 +283,17 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs),
seedDocs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This question is perhaps a side effect of the history of this pull request and/or a naive question on my part: when do we, and when don't we, change the implementation (beyond the method signature) of classes in backwards_codecs? I note that here there are changes for 94 and 95, but no change for 92, and 99 has no change (as yet). My speculation (and it is only that) is that perhaps the 94 and 95 changes need to move to 99, and that anything in backwards_codecs would get only the method signature change, with no implementation change?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good points -- I think this is a question for @benwtrent?

}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

Expand All @@ -297,7 +309,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs),
seedDocs);
}

private HnswGraph getGraph(FieldEntry entry) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand Down Expand Up @@ -285,7 +286,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

Expand All @@ -312,11 +318,17 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs),
seedDocs);
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

Expand All @@ -343,7 +355,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
vectorValues.getAcceptOrds(acceptDocs),
seedDocs);
}

/** Get knn graph values; used for testing */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
FloatVectorValues values = getFloatVectorValues(field);
if (target.length != values.dimension()) {
Expand Down Expand Up @@ -210,7 +215,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
ByteVectorValues values = getByteVectorValues(field);
if (target.length != values.dimension()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public void testIndexAndSearchBitVectors() throws IOException {
try (IndexReader reader = DirectoryReader.open(w)) {
LeafReader r = getOnlyLeafReader(reader);
TopKnnCollector collector = new TopKnnCollector(3, Integer.MAX_VALUE);
r.searchNearestVectors("v1", vectors[0], collector, null);
r.searchNearestVectors("v1", vectors[0], collector, null, null);
TopDocs topDocs = collector.topDocs();
assertEquals(3, topDocs.scoreDocs.length);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.NamedSPILoader;
Expand Down Expand Up @@ -138,13 +139,21 @@ public ByteVectorValues getByteVectorValues(String field) {

@Override
public void search(
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs) {
throw new UnsupportedOperationException();
}

@Override
public void search(
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
Expand Down Expand Up @@ -82,9 +83,16 @@ protected KnnVectorsReader() {}
* @param knnCollector a KnnResults collector and relevant settings for gathering vector results
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param seedDocs {@link DocIdSetIterator} over the documents used to seed the search, or {@code
seanmacavaney marked this conversation as resolved.
Show resolved Hide resolved
* null} to perform a search without seeds.
*/
public abstract void search(
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException;
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException;

/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
Expand All @@ -110,9 +118,16 @@ public abstract void search(
* @param knnCollector a KnnResults collector and relevant settings for gathering vector results
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param seedDocs {@link DocIdSetIterator} over the documents used to seed the search, or {@code
seanmacavaney marked this conversation as resolved.
Show resolved Hide resolved
* null} to perform a search without seeds.
*/
public abstract void search(
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException;
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException;

/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -56,13 +57,23 @@ public FlatVectorsScorer getFlatVectorScorer() {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seedDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDoc,
DocIdSetIterator seedDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand Down Expand Up @@ -247,7 +248,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
float[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seeds)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

throws IOException {
search(
fields.get(field),
Expand All @@ -258,7 +264,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
public void search(
String field,
byte[] target,
KnnCollector knnCollector,
Bits acceptDocs,
DocIdSetIterator seeds)
throws IOException {
search(
fields.get(field),
Expand Down
Loading