From 59dab18a9a0654097edfc37be427a88856d395da Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:28:40 -0500 Subject: [PATCH 1/4] Refactor the vector scorer interface to allow reuse during HNSW graph building --- .../bitvectors/FlatBitVectorsScorer.java | 10 ++- .../codecs/hnsw/DefaultFlatVectorScorer.java | 28 ++++-- .../hnsw/ScalarQuantizedVectorScorer.java | 17 +++- .../lucene99/Lucene99FlatVectorsWriter.java | 4 +- .../lucene99/Lucene99HnswVectorsWriter.java | 12 ++- .../Lucene99ScalarQuantizedVectorScorer.java | 68 ++++++++++---- .../Lucene99ScalarQuantizedVectorsWriter.java | 4 +- .../CloseableRandomVectorScorerSupplier.java | 4 +- .../util/hnsw/HnswConcurrentMergeBuilder.java | 16 +++- .../lucene/util/hnsw/HnswGraphBuilder.java | 88 ++++++++++++------- .../hnsw/InitializedHnswGraphBuilder.java | 8 ++ .../lucene/util/hnsw/NeighborArray.java | 18 ++-- .../lucene/util/hnsw/RandomVectorScorer.java | 5 +- .../util/hnsw/RandomVectorScorerSupplier.java | 8 +- .../hnsw/UpdateableRandomVectorScorer.java | 67 ++++++++++++++ ...MemorySegmentByteVectorScorerSupplier.java | 60 ++++++++++--- 16 files changed, 313 insertions(+), 104 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/UpdateableRandomVectorScorer.java diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java index 8ffcc1c8d50e..812858c4968f 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java @@ -26,6 +26,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; /** A bit vector scorer for scoring byte vectors. */ public class FlatBitVectorsScorer implements FlatVectorsScorer { @@ -58,7 +59,7 @@ public RandomVectorScorer getRandomVectorScorer( throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues"); } - static class BitRandomVectorScorer implements RandomVectorScorer { + static class BitRandomVectorScorer implements UpdateableRandomVectorScorer { private final ByteVectorValues vectorValues; private final int bitDimensions; private final byte[] query; @@ -80,6 +81,11 @@ public int maxOrd() { return vectorValues.size(); } + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(vectorValues.vectorValue(node), 0, query, 0, query.length); + } + @Override public int ordToDoc(int ord) { return vectorValues.ordToDoc(ord); @@ -103,7 +109,7 @@ public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOExc } @Override - public RandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(int ord) throws IOException { byte[] query = vectorValues1.vectorValue(ord); return new BitRandomVectorScorer(vectorValues2, query); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index 3e506037969a..bafd2209820e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; /** * Default implementation of {@link FlatVectorsScorer}. @@ -90,23 +91,29 @@ public String toString() { private static final class ByteScoringSupplier implements RandomVectorScorerSupplier { private final ByteVectorValues vectors; private final ByteVectorValues vectors1; - private final ByteVectorValues vectors2; private final VectorSimilarityFunction similarityFunction; private ByteScoringSupplier( ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; vectors1 = vectors.copy(); - vectors2 = vectors.copy(); this.similarityFunction = similarityFunction; } @Override - public RandomVectorScorer scorer(int ord) { - return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) { + public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = new byte[vectors.dimension()]; + System.arraycopy(vectors1.vectorValue(ord), 0, vector, 0, vector.length); + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) { + + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(vectors1.vectorValue(node), 0, vector, 0, vector.length); + } + @Override public float score(int node) throws IOException { - return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node)); + return similarityFunction.compare(vector, vectors1.vectorValue(ord)); } }; } @@ -138,12 +145,19 @@ private FloatScoringSupplier( } @Override - public RandomVectorScorer scorer(int ord) { - return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) { + public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + float[] vector = new float[vectors.dimension()]; + System.arraycopy(vectors1.vectorValue(ord), 0, vector, 0, vector.length); + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) { @Override public float score(int node) throws IOException { return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node)); } + + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(vectors1.vectorValue(node), 0, vector, 0, vector.length); + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java index ceb826aa3a11..11f6569dd88d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java @@ -24,6 +24,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -147,11 +148,19 @@ private ScalarQuantizedRandomVectorScorerSupplier( } @Override - public RandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(int ord) throws IOException { final QuantizedByteVectorValues vectorsCopy = values.copy(); - final byte[] queryVector = values.vectorValue(ord); - final float queryOffset = values.getScoreCorrectionConstant(ord); - return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) { + byte[] queryVector = new byte[values.dimension()]; + System.arraycopy(values.vectorValue(ord), 0, queryVector, 0, queryVector.length); + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectorsCopy) { + float queryOffset = values.getScoreCorrectionConstant(ord); + + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(vectorsCopy.vectorValue(node), 0, queryVector, 0, queryVector.length); + queryOffset = vectorsCopy.getScoreCorrectionConstant(node); + } + @Override public float score(int node) throws IOException { byte[] nodeVector = vectorsCopy.vectorValue(node); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index b731e758b7a8..33426e41aa77 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -52,8 +52,8 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; /** * Writes vector values to index segments. @@ -507,7 +507,7 @@ static final class FlatCloseableRandomVectorScorerSupplier } @Override - public RandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(int ord) throws IOException { return supplier.scorer(ord); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index e219157ab986..af6c30a2c272 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -57,6 +57,7 @@ import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicWriter; /** @@ -561,6 +562,8 @@ private static class FieldWriter extends KnnFieldVectorsWriter { private int lastDocID = -1; private int node = 0; private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private final RandomVectorScorerSupplier scorerSupplier; + private UpdateableRandomVectorScorer scorer; @SuppressWarnings("unchecked") static FieldWriter create( @@ -601,7 +604,7 @@ static FieldWriter create( InfoStream infoStream) throws IOException { this.fieldInfo = fieldInfo; - RandomVectorScorerSupplier scorerSupplier = + scorerSupplier = switch (fieldInfo.getVectorEncoding()) { case BYTE -> scorer.getRandomVectorScorerSupplier( @@ -631,7 +634,12 @@ public void addValue(int docID, T vectorValue) throws IOException { + "\" appears more than once in this document (only one value is allowed per field)"); } flatFieldVectorsWriter.addValue(docID, vectorValue); - hnswGraphBuilder.addGraphNode(node); + if (scorer == null) { + scorer = scorerSupplier.scorer(node); + } else { + scorer.setScoringOrdinal(node); + } + hnswGraphBuilder.addGraphNode(node, scorer); node++; lastDocID = docID; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java index 8fc417e22f07..0c3c07703a1a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -26,6 +26,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -87,7 +88,7 @@ public String toString() { return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')'; } - static RandomVectorScorer fromVectorSimilarity( + static UpdateableRandomVectorScorer fromVectorSimilarity( byte[] targetBytes, float offsetCorrection, VectorSimilarityFunction sim, @@ -120,12 +121,13 @@ static void checkDimensions(int queryLen, int fieldLen) { } } - private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( - byte[] targetBytes, - float offsetCorrection, - float constMultiplier, - QuantizedByteVectorValues values, - FloatToFloatFunction scoreAdjustmentFunction) { + private static UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer + dotProductFactory( + byte[] targetBytes, + float offsetCorrection, + float constMultiplier, + QuantizedByteVectorValues values, + FloatToFloatFunction scoreAdjustmentFunction) { if (values.getScalarQuantizer().getBits() <= 4) { if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) { return new CompressedInt4DotProduct( @@ -138,7 +140,8 @@ private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction); } - private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer { + private static class Euclidean + extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer { private final float constMultiplier; private final byte[] targetBytes; private final QuantizedByteVectorValues values; @@ -157,14 +160,20 @@ public float score(int node) throws IOException { float adjustedDistance = squareDistance * constMultiplier; return 1 / (1f + adjustedDistance); } + + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length); + } } /** Calculates dot product on quantized vectors, applying the appropriate corrections */ - private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { + private static class DotProduct + extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer { private final float constMultiplier; private final QuantizedByteVectorValues values; private final byte[] targetBytes; - private final float offsetCorrection; + private float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; public DotProduct( @@ -191,15 +200,24 @@ public float score(int vectorOrdinal) throws IOException { float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; return scoreAdjustmentFunction.apply(adjustedDistance); } + + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length); + offsetCorrection = values.getScoreCorrectionConstant(node); + } } + // TODO consider splitting this into two classes. right now the "query" vector is always + // decompressed + // it could stay compressed if we had a compressed version of the target vector private static class CompressedInt4DotProduct - extends RandomVectorScorer.AbstractRandomVectorScorer { + extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer { private final float constMultiplier; private final QuantizedByteVectorValues values; private final byte[] compressedVector; private final byte[] targetBytes; - private final float offsetCorrection; + private float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; private CompressedInt4DotProduct( @@ -230,13 +248,20 @@ public float score(int vectorOrdinal) throws IOException { float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; return scoreAdjustmentFunction.apply(adjustedDistance); } + + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length); + offsetCorrection = values.getScoreCorrectionConstant(node); + } } - private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { + private static class Int4DotProduct + extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer { private final float constMultiplier; private final QuantizedByteVectorValues values; private final byte[] targetBytes; - private final float offsetCorrection; + private float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; public Int4DotProduct( @@ -263,6 +288,12 @@ public float score(int vectorOrdinal) throws IOException { float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; return scoreAdjustmentFunction.apply(adjustedDistance); } + + @Override + public void setScoringOrdinal(int node) throws IOException { + System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length); + offsetCorrection = values.getScoreCorrectionConstant(node); + } } @FunctionalInterface @@ -276,27 +307,26 @@ private static final class ScalarQuantizedRandomVectorScorerSupplier private final VectorSimilarityFunction vectorSimilarityFunction; private final QuantizedByteVectorValues values; private final QuantizedByteVectorValues values1; - private final QuantizedByteVectorValues values2; public ScalarQuantizedRandomVectorScorerSupplier( QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction) throws IOException { this.values = values; this.values1 = values.copy(); - this.values2 = values.copy(); this.vectorSimilarityFunction = vectorSimilarityFunction; } @Override - public RandomVectorScorer scorer(int ord) throws IOException { - byte[] vectorValue = values1.vectorValue(ord); + public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + byte[] vectorValue = new byte[values.dimension()]; + System.arraycopy(values1.vectorValue(ord), 0, vectorValue, 0, vectorValue.length); float offsetCorrection = values1.getScoreCorrectionConstant(ord); return fromVectorSimilarity( vectorValue, offsetCorrection, vectorSimilarityFunction, values.getScalarQuantizer().getConstantMultiplier(), - values2); + values1); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 39f3a81983b4..6bd18dd9ca6f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -58,8 +58,8 @@ import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.QuantizedVectorsReader; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -1128,7 +1128,7 @@ static final class ScalarQuantizedCloseableRandomVectorScorerSupplier } @Override - public RandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(int ord) throws IOException { return supplier.scorer(ord); } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java index 148963e7dc98..8c66147f1a8b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java @@ -20,8 +20,8 @@ import java.io.Closeable; /** - * A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to - * close after use + * A supplier that creates {@link UpdateableRandomVectorScorer} from an ordinal. Caller should be + * sure to close after use * *

NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily * closeable diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java index d9d58c829d3d..c7c659fdb525 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java @@ -142,6 +142,7 @@ private static final class ConcurrentMergeWorker extends HnswGraphBuilder { private final BitSet initializedNodes; private int batchSize = DEFAULT_BATCH_SIZE; + private UpdateableRandomVectorScorer scorer; private ConcurrentMergeWorker( RandomVectorScorerSupplier scorerSupplier, @@ -190,12 +191,25 @@ private int getStartPos(int maxOrd) { } } + @Override + public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException { + if (initializedNodes != null && initializedNodes.get(node)) { + return; + } + super.addGraphNode(node, scorer); + } + @Override public void addGraphNode(int node) throws IOException { if (initializedNodes != null && initializedNodes.get(node)) { return; } - super.addGraphNode(node); + if (scorer == null) { + scorer = scorerSupplier.scorer(node); + } else { + scorer.setScoringOrdinal(node); + } + addGraphNode(node, scorer); } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index d3cef0bc6d10..ba8e523adc83 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -61,7 +61,7 @@ public class HnswGraphBuilder implements HnswBuilder { private final double ml; private final SplittableRandom random; - private final RandomVectorScorerSupplier scorerSupplier; + protected final RandomVectorScorerSupplier scorerSupplier; private final HnswGraphSearcher graphSearcher; private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search private final GraphBuilderKnnCollector @@ -191,8 +191,14 @@ protected void addVectors(int minOrd, int maxOrd) throws IOException { if (infoStream.isEnabled(HNSW_COMPONENT)) { infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")"); } + UpdateableRandomVectorScorer scorer = null; for (int node = minOrd; node < maxOrd; node++) { - addGraphNode(node); + if (scorer == null) { + scorer = scorerSupplier.scorer(node); + } else { + scorer.setScoringOrdinal(node); + } + addGraphNode(node, scorer); if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { t = printGraphBuildStatus(node, start, t); } @@ -203,30 +209,10 @@ private void addVectors(int maxOrd) throws IOException { addVectors(0, maxOrd); } - @Override - public void addGraphNode(int node) throws IOException { - /* - Note: this implementation is thread safe when graph size is fixed (e.g. when merging) - The process of adding a node is roughly: - 1. Add the node to all level from top to the bottom, but do not connect it to any other node, - nor try to promote itself to an entry node before the connection is done. (Unless the graph is empty - and this is the first node, in that case we set the entry node and return) - 2. Do the search from top to bottom, remember all the possible neighbours on each level the node - is on. - 3. Add the neighbor to the node from bottom to top level, when adding the neighbour, - we always add all the outgoing links first before adding incoming link such that - when a search visits this node, it can always find a way out - 4. If the node has level that is less or equal to graph level, then we're done here. - If the node has level larger than graph level, then we need to promote the node - as the entry node. If, while we add the node to the graph, the entry node has changed - (which means the graph level has changed as well), we need to reinsert the node - to the newly introduced levels (repeating step 2,3 for new levels) and again try to - promote the node to entry node. - */ + public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException { if (frozen) { throw new IllegalStateException("Graph builder is already frozen"); } - RandomVectorScorer scorer = scorerSupplier.scorer(node); final int nodeLevel = getRandomGraphLevel(ml, random); // first add nodes to all levels for (int level = nodeLevel; level >= 0; level--) { @@ -271,7 +257,7 @@ to the newly introduced levels (repeating step 2,3 for new levels) and again try // then do connections from bottom up for (int i = 0; i < scratchPerLevel.length; i++) { - addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]); + addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], scorer); } lowestUnsetLevel += scratchPerLevel.length; assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1; @@ -296,6 +282,30 @@ to the newly introduced levels (repeating step 2,3 for new levels) and again try } while (true); } + @Override + public void addGraphNode(int node) throws IOException { + /* + Note: this implementation is thread safe when graph size is fixed (e.g. when merging) + The process of adding a node is roughly: + 1. Add the node to all level from top to the bottom, but do not connect it to any other node, + nor try to promote itself to an entry node before the connection is done. (Unless the graph is empty + and this is the first node, in that case we set the entry node and return) + 2. Do the search from top to bottom, remember all the possible neighbours on each level the node + is on. + 3. Add the neighbor to the node from bottom to top level, when adding the neighbour, + we always add all the outgoing links first before adding incoming link such that + when a search visits this node, it can always find a way out + 4. If the node has level that is less or equal to graph level, then we're done here. + If the node has level larger than graph level, then we need to promote the node + as the entry node. If, while we add the node to the graph, the entry node has changed + (which means the graph level has changed as well), we need to reinsert the node + to the newly introduced levels (repeating step 2,3 for new levels) and again try to + promote the node to entry node. + */ + UpdateableRandomVectorScorer scorer = scorerSupplier.scorer(node); + addGraphNode(node, scorer); + } + private long printGraphBuildStatus(int node, long start, long t) { long now = System.nanoTime(); infoStream.message( @@ -309,7 +319,8 @@ private long printGraphBuildStatus(int node, long start, long t) { return now; } - private void addDiverseNeighbors(int level, int node, NeighborArray candidates) + private void addDiverseNeighbors( + int level, int node, NeighborArray candidates, UpdateableRandomVectorScorer scorer) throws IOException { /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it * is closer to target than it is to any of the already-selected neighbors (ie selected in this method, @@ -318,7 +329,7 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates) NeighborArray neighbors = hnsw.getNeighbors(level, node); assert neighbors.size() == 0; // new node int maxConnOnLevel = level == 0 ? M * 2 : M; - boolean[] mask = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel); + boolean[] mask = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, scorer); // Link the selected nodes to the new node, and the new node to the selected nodes (again // applying diversity heuristic) @@ -334,13 +345,13 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates) Lock lock = hnswLock.write(level, nbr); try { NeighborArray nbrsOfNbr = getGraph().getNeighbors(level, nbr); - nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorerSupplier); + nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorer); } finally { lock.unlock(); } } else { NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); - nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorerSupplier); + nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorer); } } } @@ -350,7 +361,11 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates) * are selected */ private boolean[] selectAndLinkDiverse( - NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException { + NeighborArray neighbors, + NeighborArray candidates, + int maxConnOnLevel, + UpdateableRandomVectorScorer scorer) + throws IOException { boolean[] mask = new boolean[candidates.size()]; // Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) { @@ -359,7 +374,8 @@ private boolean[] selectAndLinkDiverse( int cNode = candidates.nodes()[i]; float cScore = candidates.scores()[i]; assert cNode <= hnsw.maxNodeId(); - if (diversityCheck(cNode, cScore, neighbors)) { + scorer.setScoringOrdinal(cNode); + if (diversityCheck(cScore, neighbors, scorer)) { mask[i] = true; // here we don't need to lock, because there's no incoming link so no others is able to // discover this node such that no others will modify this neighbor array as well @@ -381,15 +397,14 @@ private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborAr } /** - * @param candidate the vector of a new candidate neighbor of a node n * @param score the score of the new candidate and node n, to be compared with scores of the * candidate and n's neighbors * @param neighbors the neighbors selected so far * @return whether the candidate is diverse given the existing neighbors */ - private boolean diversityCheck(int candidate, float score, NeighborArray neighbors) + private boolean diversityCheck( + float score, NeighborArray neighbors, RandomVectorScorer scorer) throws IOException { - RandomVectorScorer scorer = scorerSupplier.scorer(candidate); for (int i = 0; i < neighbors.size(); i++) { float neighborSimilarity = scorer.score(neighbors.nodes()[i]); if (neighborSimilarity >= score) { @@ -452,6 +467,7 @@ private boolean connectComponents(int level) throws IOException { // while linking GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(2); int[] eps = new int[1]; + UpdateableRandomVectorScorer scorer = null; for (Component c : components) { if (c != c0) { if (c.start() == NO_MORE_DOCS) { @@ -463,7 +479,11 @@ private boolean connectComponents(int level) throws IOException { beam.clear(); eps[0] = c0.start(); - RandomVectorScorer scorer = scorerSupplier.scorer(c.start()); + if (scorer == null) { + scorer = scorerSupplier.scorer(c.start()); + } else { + scorer.setScoringOrdinal(c.start()); + } // find the closest node in the largest component to the lowest-numbered node in this // component that has room to make a connection graphSearcher.searchLevel(beam, scorer, level, eps, hnsw, notFullyConnected); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java index 7dff036ddde4..c4c790c70258 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java @@ -98,6 +98,14 @@ public InitializedHnswGraphBuilder( this.initializedNodes = initializedNodes; } + @Override + public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException { + if (initializedNodes.get(node)) { + return; + } + super.addGraphNode(node, scorer); + } + @Override public void addGraphNode(int node) throws IOException { if (initializedNodes.get(node)) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index 716364a39dc2..c7745120a5fa 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -87,14 +87,15 @@ public void addOutOfOrder(int newNode, float newScore) { * @param nodeId node Id of the owner of this NeighbourArray */ public void addAndEnsureDiversity( - int newNode, float newScore, int nodeId, RandomVectorScorerSupplier scorerSupplier) + int newNode, float newScore, int nodeId, UpdateableRandomVectorScorer scorer) throws IOException { addOutOfOrder(newNode, newScore); if (size < nodes.length) { return; } // we're oversize, need to do diversity check and pop out the least diverse neighbour - removeIndex(findWorstNonDiverse(nodeId, scorerSupplier)); + scorer.setScoringOrdinal(nodeId); + removeIndex(findWorstNonDiverse(scorer)); assert size == nodes.length - 1; } @@ -235,9 +236,7 @@ private int descSortFindRightMostInsertionPoint(float newScore, int bound) { * Find first non-diverse neighbour among the list of neighbors starting from the most distant * neighbours */ - private int findWorstNonDiverse(int nodeOrd, RandomVectorScorerSupplier scorerSupplier) - throws IOException { - RandomVectorScorer scorer = scorerSupplier.scorer(nodeOrd); + private int findWorstNonDiverse(UpdateableRandomVectorScorer scorer) throws IOException { int[] uncheckedIndexes = sort(scorer); assert uncheckedIndexes != null : "We will always have something unchecked"; int uncheckedCursor = uncheckedIndexes.length - 1; @@ -246,7 +245,8 @@ private int findWorstNonDiverse(int nodeOrd, RandomVectorScorerSupplier scorerSu // no unchecked node left break; } - if (isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorerSupplier)) { + scorer.setScoringOrdinal(nodes[i]); + if (isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorer)) { return i; } if (i == uncheckedIndexes[uncheckedCursor]) { @@ -257,13 +257,9 @@ private int findWorstNonDiverse(int nodeOrd, RandomVectorScorerSupplier scorerSu } private boolean isWorstNonDiverse( - int candidateIndex, - int[] uncheckedIndexes, - int uncheckedCursor, - RandomVectorScorerSupplier scorerSupplier) + int candidateIndex, int[] uncheckedIndexes, int uncheckedCursor, RandomVectorScorer scorer) throws IOException { float minAcceptedSimilarity = scores[candidateIndex]; - RandomVectorScorer scorer = scorerSupplier.scorer(nodes[candidateIndex]); if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { // the candidate itself is unchecked for (int i = candidateIndex - 1; i >= 0; i--) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index a135df436991..d2ea8e28a246 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -21,7 +21,10 @@ import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.util.Bits; -/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ +/** + * A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. This + * class isn't thread-safe and should be used by a single thread. + */ public interface RandomVectorScorer { /** * Returns the score between the query and the provided node. diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java index f8436f061d6a..5ee84d5594e9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java @@ -22,13 +22,13 @@ /** A supplier that creates {@link RandomVectorScorer} from an ordinal. */ public interface RandomVectorScorerSupplier { /** - * This creates a {@link RandomVectorScorer} for scoring random nodes in batches against the given - * ordinal. + * This creates a {@link UpdateableRandomVectorScorer} for scoring random nodes in batches against + * the given ordinal. Optionally allowing the ordinal to be updated. * * @param ord the ordinal of the node to compare - * @return a new {@link RandomVectorScorer} + * @return a new {@link UpdateableRandomVectorScorer} */ - RandomVectorScorer scorer(int ord) throws IOException; + UpdateableRandomVectorScorer scorer(int ord) throws IOException; /** * Make a copy of the supplier, which will copy the underlying vectorValues so the copy is safe to diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/UpdateableRandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/UpdateableRandomVectorScorer.java new file mode 100644 index 000000000000..ac07cf0fd6e3 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/UpdateableRandomVectorScorer.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.util.Bits; + +/** + * Just like a {@link RandomVectorScorer} but allows the scoring ordinal to be changed. Useful + * during indexing operations + * + * @lucene.internal + */ +public interface UpdateableRandomVectorScorer extends RandomVectorScorer { + /** + * Changes the scoring ordinal to the given node. If the same scorer object is being used + * continually, this can be used to avoid creating a new scorer for each node. + * + * @param node the node to score against + * @throws IOException if an exception occurs initializing the scorer for the given node + */ + void setScoringOrdinal(int node) throws IOException; + + /** Creates a default scorer for random access vectors. */ + abstract class AbstractUpdateableRandomVectorScorer implements UpdateableRandomVectorScorer { + private final KnnVectorValues values; + + /** + * Creates a new scorer for the given vector values. + * + * @param values the vector values + */ + public AbstractUpdateableRandomVectorScorer(KnnVectorValues values) { + this.values = values; + } + + @Override + public int maxOrd() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return values.getAcceptOrds(acceptDocs); + } + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java index 02c71561122d..e93a5d735077 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java @@ -25,8 +25,8 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; /** A score supplier of vectors whose element size is byte. */ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier @@ -110,15 +110,24 @@ static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerS } @Override - public RandomVectorScorer scorer(int ord) { + public UpdateableRandomVectorScorer scorer(int ord) { checkOrdinal(ord); - return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int queryOrd = ord; + @Override public float score(int node) throws IOException { checkOrdinal(node); - float raw = PanamaVectorUtilSupport.cosine(getFirstSegment(ord), getSecondSegment(node)); + float raw = + PanamaVectorUtilSupport.cosine(getFirstSegment(queryOrd), getSecondSegment(node)); return (1 + raw) / 2; } + + @Override + public void setScoringOrdinal(int node) { + checkOrdinal(node); + queryOrd = node; + } }; } @@ -135,17 +144,25 @@ static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorSco } @Override - public RandomVectorScorer scorer(int ord) { + public UpdateableRandomVectorScorer scorer(int ord) { checkOrdinal(ord); - return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int queryOrd = ord; + @Override public float score(int node) throws IOException { checkOrdinal(node); // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len float raw = - PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); + PanamaVectorUtilSupport.dotProduct(getFirstSegment(queryOrd), getSecondSegment(node)); return 0.5f + raw / (float) (values.dimension() * (1 << 15)); } + + @Override + public void setScoringOrdinal(int node) { + checkOrdinal(node); + queryOrd = node; + } }; } @@ -162,16 +179,25 @@ static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScor } @Override - public RandomVectorScorer scorer(int ord) { + public UpdateableRandomVectorScorer scorer(int ord) { checkOrdinal(ord); - return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int queryOrd = ord; + @Override public float score(int node) throws IOException { checkOrdinal(node); float raw = - PanamaVectorUtilSupport.squareDistance(getFirstSegment(ord), getSecondSegment(node)); + PanamaVectorUtilSupport.squareDistance( + getFirstSegment(queryOrd), getSecondSegment(node)); return 1 / (1f + raw); } + + @Override + public void setScoringOrdinal(int node) { + checkOrdinal(node); + queryOrd = node; + } }; } @@ -188,19 +214,27 @@ static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVect } @Override - public RandomVectorScorer scorer(int ord) { + public UpdateableRandomVectorScorer scorer(int ord) { checkOrdinal(ord); - return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int queryOrd = ord; + @Override public float score(int node) throws IOException { checkOrdinal(node); float raw = - PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); + PanamaVectorUtilSupport.dotProduct(getFirstSegment(queryOrd), getSecondSegment(node)); if (raw < 0) { return 1 / (1 + -1 * raw); } return raw + 1; } + + @Override + public void setScoringOrdinal(int node) { + checkOrdinal(node); + queryOrd = node; + } }; } From 5e2ae5f2d92b91ffbfeacbaa03de79245eda7929 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:42:40 -0500 Subject: [PATCH 2/4] fixing scorers --- .../apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index bafd2209820e..7f90197ccb3e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -113,7 +113,7 @@ public void setScoringOrdinal(int node) throws IOException { @Override public float score(int node) throws IOException { - return similarityFunction.compare(vector, vectors1.vectorValue(ord)); + return similarityFunction.compare(vector, vectors1.vectorValue(node)); } }; } @@ -133,14 +133,12 @@ public String toString() { private static final class FloatScoringSupplier implements RandomVectorScorerSupplier { private final FloatVectorValues vectors; private final FloatVectorValues vectors1; - private final FloatVectorValues vectors2; private final VectorSimilarityFunction similarityFunction; private FloatScoringSupplier( FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; vectors1 = vectors.copy(); - vectors2 = vectors.copy(); this.similarityFunction = similarityFunction; } @@ -151,7 +149,7 @@ public UpdateableRandomVectorScorer scorer(int ord) throws IOException { return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) { @Override public float score(int node) throws IOException { - return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node)); + return similarityFunction.compare(vector, vectors1.vectorValue(node)); } @Override From 9b7fe53458513e06333b5ed164f5c8ef2a0c005a Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 30 Jan 2025 09:14:47 -0500 Subject: [PATCH 3/4] iter --- .../codecs/bitvectors/FlatBitVectorsScorer.java | 11 ++++------- .../org/apache/lucene/util/hnsw/HnswGraphBuilder.java | 3 +-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java index 812858c4968f..e182eaa2c4e3 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java @@ -34,7 +34,6 @@ public class FlatBitVectorsScorer implements FlatVectorsScorer { public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - assert vectorValues instanceof ByteVectorValues; if (vectorValues instanceof ByteVectorValues byteVectorValues) { return new BitRandomVectorScorerSupplier(byteVectorValues); } @@ -52,7 +51,6 @@ public RandomVectorScorer getRandomVectorScorer( public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { - assert vectorValues instanceof ByteVectorValues; if (vectorValues instanceof ByteVectorValues byteVectorValues) { return new BitRandomVectorScorer(byteVectorValues, target); } @@ -65,7 +63,8 @@ static class BitRandomVectorScorer implements UpdateableRandomVectorScorer { private final byte[] query; BitRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) { - this.query = query; + this.query = new byte[query.length]; + System.arraycopy(query, 0, this.query, 0, query.length); this.bitDimensions = vectorValues.dimension() * Byte.SIZE; this.vectorValues = vectorValues; } @@ -100,23 +99,21 @@ public Bits getAcceptOrds(Bits acceptDocs) { static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier { protected final ByteVectorValues vectorValues; protected final ByteVectorValues vectorValues1; - protected final ByteVectorValues vectorValues2; public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException { this.vectorValues = vectorValues; this.vectorValues1 = vectorValues.copy(); - this.vectorValues2 = vectorValues.copy(); } @Override public UpdateableRandomVectorScorer scorer(int ord) throws IOException { byte[] query = vectorValues1.vectorValue(ord); - return new BitRandomVectorScorer(vectorValues2, query); + return new BitRandomVectorScorer(vectorValues1, query); } @Override public RandomVectorScorerSupplier copy() throws IOException { - return new BitRandomVectorScorerSupplier(vectorValues.copy()); + return new BitRandomVectorScorerSupplier(vectorValues); } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index ba8e523adc83..4266660fb146 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -402,8 +402,7 @@ private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborAr * @param neighbors the neighbors selected so far * @return whether the candidate is diverse given the existing neighbors */ - private boolean diversityCheck( - float score, NeighborArray neighbors, RandomVectorScorer scorer) + private boolean diversityCheck(float score, NeighborArray neighbors, RandomVectorScorer scorer) throws IOException { for (int i = 0; i < neighbors.size(); i++) { float neighborSimilarity = scorer.score(neighbors.nodes()[i]); From 0d1de688fdde3bbe357107053e04e7adf87f0591 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 30 Jan 2025 14:58:44 -0500 Subject: [PATCH 4/4] addressing pr comments --- .../bitvectors/FlatBitVectorsScorer.java | 17 ++++++---- .../codecs/hnsw/DefaultFlatVectorScorer.java | 28 +++++++++------- .../hnsw/ScalarQuantizedVectorScorer.java | 8 +++-- .../lucene99/Lucene99FlatVectorsWriter.java | 2 +- .../Lucene99ScalarQuantizedVectorScorer.java | 15 +++++---- .../Lucene99ScalarQuantizedVectorsWriter.java | 2 +- .../lucene/util/hnsw/HnswGraphBuilder.java | 16 +++------- .../util/hnsw/RandomVectorScorerSupplier.java | 5 +-- ...MemorySegmentByteVectorScorerSupplier.java | 32 ++++++++++++------- 9 files changed, 69 insertions(+), 56 deletions(-) diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java index e182eaa2c4e3..5f6f129504b2 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java @@ -63,8 +63,7 @@ static class BitRandomVectorScorer implements UpdateableRandomVectorScorer { private final byte[] query; BitRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) { - this.query = new byte[query.length]; - System.arraycopy(query, 0, this.query, 0, query.length); + this.query = query; this.bitDimensions = vectorValues.dimension() * Byte.SIZE; this.vectorValues = vectorValues; } @@ -98,17 +97,21 @@ public Bits getAcceptOrds(Bits acceptDocs) { static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier { protected final ByteVectorValues vectorValues; - protected final ByteVectorValues vectorValues1; + protected final ByteVectorValues targetVectors; public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException { this.vectorValues = vectorValues; - this.vectorValues1 = vectorValues.copy(); + this.targetVectors = vectorValues.copy(); } @Override - public UpdateableRandomVectorScorer scorer(int ord) throws IOException { - byte[] query = vectorValues1.vectorValue(ord); - return new BitRandomVectorScorer(vectorValues1, query); + public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException { + byte[] query = new byte[vectorValues.dimension()]; + if (ord == null) { + return new BitRandomVectorScorer(vectorValues, query); + } + System.arraycopy(targetVectors.vectorValue(ord), 0, query, 0, query.length); + return new BitRandomVectorScorer(targetVectors, query); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index 7f90197ccb3e..5976ca8f89f7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -90,30 +90,32 @@ public String toString() { /** RandomVectorScorerSupplier for bytes vector */ private static final class ByteScoringSupplier implements RandomVectorScorerSupplier { private final ByteVectorValues vectors; - private final ByteVectorValues vectors1; + private final ByteVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; private ByteScoringSupplier( ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; - vectors1 = vectors.copy(); + targetVectors = vectors.copy(); this.similarityFunction = similarityFunction; } @Override - public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException { byte[] vector = new byte[vectors.dimension()]; - System.arraycopy(vectors1.vectorValue(ord), 0, vector, 0, vector.length); + if (ord != null) { + System.arraycopy(targetVectors.vectorValue(ord), 0, vector, 0, vector.length); + } return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) { @Override public void setScoringOrdinal(int node) throws IOException { - System.arraycopy(vectors1.vectorValue(node), 0, vector, 0, vector.length); + System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length); } @Override public float score(int node) throws IOException { - return similarityFunction.compare(vector, vectors1.vectorValue(node)); + return similarityFunction.compare(vector, targetVectors.vectorValue(node)); } }; } @@ -132,29 +134,31 @@ public String toString() { /** RandomVectorScorerSupplier for Float vector */ private static final class FloatScoringSupplier implements RandomVectorScorerSupplier { private final FloatVectorValues vectors; - private final FloatVectorValues vectors1; + private final FloatVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; private FloatScoringSupplier( FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; - vectors1 = vectors.copy(); + targetVectors = vectors.copy(); this.similarityFunction = similarityFunction; } @Override - public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException { float[] vector = new float[vectors.dimension()]; - System.arraycopy(vectors1.vectorValue(ord), 0, vector, 0, vector.length); + if (ord != null) { + System.arraycopy(targetVectors.vectorValue(ord), 0, vector, 0, vector.length); + } return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) { @Override public float score(int node) throws IOException { - return similarityFunction.compare(vector, vectors1.vectorValue(node)); + return similarityFunction.compare(vector, targetVectors.vectorValue(node)); } @Override public void setScoringOrdinal(int node) throws IOException { - System.arraycopy(vectors1.vectorValue(node), 0, vector, 0, vector.length); + System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length); } }; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java index 11f6569dd88d..533a7a54b4d7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java @@ -148,12 +148,14 @@ private ScalarQuantizedRandomVectorScorerSupplier( } @Override - public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException { final QuantizedByteVectorValues vectorsCopy = values.copy(); byte[] queryVector = new byte[values.dimension()]; - System.arraycopy(values.vectorValue(ord), 0, queryVector, 0, queryVector.length); + if (ord != null) { + System.arraycopy(values.vectorValue(ord), 0, queryVector, 0, queryVector.length); + } return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectorsCopy) { - float queryOffset = values.getScoreCorrectionConstant(ord); + float queryOffset = ord != null ? values.getScoreCorrectionConstant(ord) : 0; @Override public void setScoringOrdinal(int node) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 33426e41aa77..7c45b056c7f5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -507,7 +507,7 @@ static final class FlatCloseableRandomVectorScorerSupplier } @Override - public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException { return supplier.scorer(ord); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java index 0c3c07703a1a..b4755008a264 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -306,27 +306,30 @@ private static final class ScalarQuantizedRandomVectorScorerSupplier private final VectorSimilarityFunction vectorSimilarityFunction; private final QuantizedByteVectorValues values; - private final QuantizedByteVectorValues values1; + private final QuantizedByteVectorValues targetVectors; public ScalarQuantizedRandomVectorScorerSupplier( QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction) throws IOException { this.values = values; - this.values1 = values.copy(); + this.targetVectors = values.copy(); this.vectorSimilarityFunction = vectorSimilarityFunction; } @Override - public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException { byte[] vectorValue = new byte[values.dimension()]; - System.arraycopy(values1.vectorValue(ord), 0, vectorValue, 0, vectorValue.length); - float offsetCorrection = values1.getScoreCorrectionConstant(ord); + float offsetCorrection = 0; + if (ord != null) { + System.arraycopy(targetVectors.vectorValue(ord), 0, vectorValue, 0, vectorValue.length); + offsetCorrection = targetVectors.getScoreCorrectionConstant(ord); + } return fromVectorSimilarity( vectorValue, offsetCorrection, vectorSimilarityFunction, values.getScalarQuantizer().getConstantMultiplier(), - values1); + targetVectors); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 6bd18dd9ca6f..46c0d9fd450b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -1128,7 +1128,7 @@ static final class ScalarQuantizedCloseableRandomVectorScorerSupplier } @Override - public UpdateableRandomVectorScorer scorer(int ord) throws IOException { + public UpdateableRandomVectorScorer scorer(Integer ord) throws IOException { return supplier.scorer(ord); } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 4266660fb146..9415f598617a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -191,13 +191,9 @@ protected void addVectors(int minOrd, int maxOrd) throws IOException { if (infoStream.isEnabled(HNSW_COMPONENT)) { infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")"); } - UpdateableRandomVectorScorer scorer = null; + UpdateableRandomVectorScorer scorer = scorerSupplier.scorer(null); for (int node = minOrd; node < maxOrd; node++) { - if (scorer == null) { - scorer = scorerSupplier.scorer(node); - } else { - scorer.setScoringOrdinal(node); - } + scorer.setScoringOrdinal(node); addGraphNode(node, scorer); if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { t = printGraphBuildStatus(node, start, t); @@ -466,7 +462,7 @@ private boolean connectComponents(int level) throws IOException { // while linking GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(2); int[] eps = new int[1]; - UpdateableRandomVectorScorer scorer = null; + UpdateableRandomVectorScorer scorer = scorerSupplier.scorer(null); for (Component c : components) { if (c != c0) { if (c.start() == NO_MORE_DOCS) { @@ -478,11 +474,7 @@ private boolean connectComponents(int level) throws IOException { beam.clear(); eps[0] = c0.start(); - if (scorer == null) { - scorer = scorerSupplier.scorer(c.start()); - } else { - scorer.setScoringOrdinal(c.start()); - } + scorer.setScoringOrdinal(c.start()); // find the closest node in the largest component to the lowest-numbered node in this // component that has room to make a connection graphSearcher.searchLevel(beam, scorer, level, eps, hnsw, notFullyConnected); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java index 5ee84d5594e9..236739cbce5f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java @@ -25,10 +25,11 @@ public interface RandomVectorScorerSupplier { * This creates a {@link UpdateableRandomVectorScorer} for scoring random nodes in batches against * the given ordinal. Optionally allowing the ordinal to be updated. * - * @param ord the ordinal of the node to compare + * @param ord the ordinal of the node to compare. If null, the {@link + * UpdateableRandomVectorScorer} needs to be initialized with a valid ordinal before scoring. * @return a new {@link UpdateableRandomVectorScorer} */ - UpdateableRandomVectorScorer scorer(int ord) throws IOException; + UpdateableRandomVectorScorer scorer(Integer ord) throws IOException; /** * Make a copy of the supplier, which will copy the underlying vectorValues so the copy is safe to diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java index e93a5d735077..42b291d66cc6 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java @@ -110,10 +110,12 @@ static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerS } @Override - public UpdateableRandomVectorScorer scorer(int ord) { - checkOrdinal(ord); + public UpdateableRandomVectorScorer scorer(Integer ord) { + if (ord != null) { + checkOrdinal(ord); + } return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { - private int queryOrd = ord; + private int queryOrd = ord == null ? 0 : ord; @Override public float score(int node) throws IOException { @@ -144,10 +146,12 @@ static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorSco } @Override - public UpdateableRandomVectorScorer scorer(int ord) { - checkOrdinal(ord); + public UpdateableRandomVectorScorer scorer(Integer ord) { + if (ord != null) { + checkOrdinal(ord); + } return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { - private int queryOrd = ord; + private int queryOrd = ord == null ? 0 : ord; @Override public float score(int node) throws IOException { @@ -179,10 +183,12 @@ static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScor } @Override - public UpdateableRandomVectorScorer scorer(int ord) { - checkOrdinal(ord); + public UpdateableRandomVectorScorer scorer(Integer ord) { + if (ord != null) { + checkOrdinal(ord); + } return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { - private int queryOrd = ord; + private int queryOrd = ord == null ? 0 : ord; @Override public float score(int node) throws IOException { @@ -214,10 +220,12 @@ static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVect } @Override - public UpdateableRandomVectorScorer scorer(int ord) { - checkOrdinal(ord); + public UpdateableRandomVectorScorer scorer(Integer ord) { + if (ord != null) { + checkOrdinal(ord); + } return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { - private int queryOrd = ord; + private int queryOrd = ord == null ? 0 : ord; @Override public float score(int node) throws IOException {