Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add updateable random scorer interface for vector index building #14181

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/** A bit vector scorer for scoring byte vectors. */
public class FlatBitVectorsScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
assert vectorValues instanceof ByteVectorValues;
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
return new BitRandomVectorScorerSupplier(byteVectorValues);
}
Expand All @@ -51,14 +51,13 @@ public RandomVectorScorer getRandomVectorScorer(
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
// Builds a bit-vector scorer for the given byte[] target. Note that similarityFunction is
// not consulted anywhere in this method.
// With assertions enabled (-ea) a non-byte vectorValues fails here with an AssertionError.
assert vectorValues instanceof ByteVectorValues;
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
return new BitRandomVectorScorer(byteVectorValues, target);
}
// Reached only when assertions are disabled and vectorValues is not a ByteVectorValues.
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
}

static class BitRandomVectorScorer implements RandomVectorScorer {
static class BitRandomVectorScorer implements UpdateableRandomVectorScorer {
private final ByteVectorValues vectorValues;
private final int bitDimensions;
private final byte[] query;
benwtrent marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -80,6 +79,11 @@ public int maxOrd() {
return vectorValues.size();
}

/**
 * Re-targets this scorer by overwriting the {@code query} buffer with the vector stored at the
 * given ordinal.
 *
 * @param node ordinal of the vector to copy into the query buffer
 * @throws IOException if reading the vector from {@code vectorValues} fails
 */
@Override
public void setScoringOrdinal(int node) throws IOException {
// Copy query.length bytes into the existing buffer instead of keeping the returned array.
// NOTE(review): presumably because vectorValue() may reuse its backing array between calls,
// and because score() reads `query` — neither is visible in this hunk; confirm.
System.arraycopy(vectorValues.vectorValue(node), 0, query, 0, query.length);
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
Expand All @@ -93,24 +97,26 @@ public Bits getAcceptOrds(Bits acceptDocs) {

static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final ByteVectorValues vectorValues;
protected final ByteVectorValues vectorValues1;
protected final ByteVectorValues vectorValues2;
protected final ByteVectorValues targetVectors;

public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
this.targetVectors = vectorValues.copy();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] query = vectorValues1.vectorValue(ord);
return new BitRandomVectorScorer(vectorValues2, query);
public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException {
byte[] query = new byte[vectorValues.dimension()];
if (ord == null) {
return new BitRandomVectorScorer(vectorValues, query);
}
System.arraycopy(targetVectors.vectorValue(ord), 0, query, 0, query.length);
return new BitRandomVectorScorer(targetVectors, query);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BitRandomVectorScorerSupplier(vectorValues.copy());
return new BitRandomVectorScorerSupplier(vectorValues);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
* Default implementation of {@link FlatVectorsScorer}.
Expand Down Expand Up @@ -89,24 +90,32 @@ public String toString() {
/** RandomVectorScorerSupplier for byte vectors */
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final ByteVectorValues vectors;
private final ByteVectorValues vectors1;
private final ByteVectorValues vectors2;
private final ByteVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;

private ByteScoringSupplier(
ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
targetVectors = vectors.copy();
this.similarityFunction = similarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException {
byte[] vector = new byte[vectors.dimension()];
if (ord != null) {
System.arraycopy(targetVectors.vectorValue(ord), 0, vector, 0, vector.length);
}
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length);
}

@Override
public float score(int node) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
return similarityFunction.compare(vector, targetVectors.vectorValue(node));
}
};
}
Expand All @@ -125,24 +134,31 @@ public String toString() {
/** RandomVectorScorerSupplier for float vectors */
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final FloatVectorValues vectors;
private final FloatVectorValues vectors1;
private final FloatVectorValues vectors2;
private final FloatVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;

private FloatScoringSupplier(
FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
targetVectors = vectors.copy();
this.similarityFunction = similarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException {
float[] vector = new float[vectors.dimension()];
if (ord != null) {
System.arraycopy(targetVectors.vectorValue(ord), 0, vector, 0, vector.length);
}
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
return similarityFunction.compare(vector, targetVectors.vectorValue(node));
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;
Expand Down Expand Up @@ -147,11 +148,21 @@ private ScalarQuantizedRandomVectorScorerSupplier(
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException {
final QuantizedByteVectorValues vectorsCopy = values.copy();
final byte[] queryVector = values.vectorValue(ord);
final float queryOffset = values.getScoreCorrectionConstant(ord);
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
byte[] queryVector = new byte[values.dimension()];
if (ord != null) {
System.arraycopy(values.vectorValue(ord), 0, queryVector, 0, queryVector.length);
}
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectorsCopy) {
float queryOffset = ord != null ? values.getScoreCorrectionConstant(ord) : 0;

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(vectorsCopy.vectorValue(node), 0, queryVector, 0, queryVector.length);
queryOffset = vectorsCopy.getScoreCorrectionConstant(node);
}

@Override
public float score(int node) throws IOException {
byte[] nodeVector = vectorsCopy.vectorValue(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
* Writes vector values to index segments.
Expand Down Expand Up @@ -507,7 +507,7 @@ static final class FlatCloseableRandomVectorScorerSupplier
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException {
return supplier.scorer(ord);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicWriter;

/**
Expand Down Expand Up @@ -561,6 +562,8 @@ private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private int lastDocID = -1;
private int node = 0;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
private final RandomVectorScorerSupplier scorerSupplier;
private UpdateableRandomVectorScorer scorer;

@SuppressWarnings("unchecked")
static FieldWriter<?> create(
Expand Down Expand Up @@ -601,7 +604,7 @@ static FieldWriter<?> create(
InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
RandomVectorScorerSupplier scorerSupplier =
scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE ->
scorer.getRandomVectorScorerSupplier(
Expand Down Expand Up @@ -631,7 +634,12 @@ public void addValue(int docID, T vectorValue) throws IOException {
+ "\" appears more than once in this document (only one value is allowed per field)");
}
flatFieldVectorsWriter.addValue(docID, vectorValue);
hnswGraphBuilder.addGraphNode(node);
if (scorer == null) {
scorer = scorerSupplier.scorer(node);
} else {
scorer.setScoringOrdinal(node);
}
hnswGraphBuilder.addGraphNode(node, scorer);
node++;
lastDocID = docID;
}
Expand Down
Loading