diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsFormat.java new file mode 100644 index 000000000000..b54cc438381a --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsFormat.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexOutput; + +/** + * Lucene 10.4 flat vector format, details TBD. + * + * @lucene.experimental + */ +public final class Lucene104FlatVectorsFormat extends FlatVectorsFormat { + + static final String NAME = "Lucene104FlatVectorsFormat"; + static final String META_CODEC_NAME = "Lucene104FlatVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene104FlatVectorsFormatData"; + static final String META_EXTENSION = "vemf"; + static final String VECTOR_DATA_EXTENSION = "vec"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + private final FlatVectorsScorer vectorsScorer; + + /** Constructs a format */ + public Lucene104FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { + super(NAME); + this.vectorsScorer = vectorsScorer; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene104FlatVectorsWriter(state, vectorsScorer); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene104FlatVectorsReader(state, vectorsScorer); + } + + @Override + public String toString() { + return "Lucene104FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')'; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsReader.java new file mode 100644 index 000000000000..4e46f987ef51 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsReader.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.codecs.lucene104;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
+
+import java.io.IOException;
+import java.util.Map;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
+import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.CorruptIndexException;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.IntObjectHashMap;
+import org.apache.lucene.store.ChecksumIndexInput;
+import org.apache.lucene.store.DataAccessHint;
+import org.apache.lucene.store.FileDataHint;
+import org.apache.lucene.store.FileTypeHint;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+
+/**
+ * Reads vectors from the index segments.
+ *
+ * @lucene.experimental
+ */
+public final class Lucene104FlatVectorsReader extends FlatVectorsReader {
+
+  private static final long SHALLOW_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(Lucene104FlatVectorsReader.class);
+
+  private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
+  private final IndexInput vectorData;
+  private final FieldInfos fieldInfos;
+  private final IOContext dataContext;
+
+  public Lucene104FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
+      throws IOException {
+    super(scorer);
+    int versionMeta = readMetadata(state);
+    this.fieldInfos = state.fieldInfos;
+    // Flat formats are used to randomly access vectors from their node ID that is stored
+    // in the HNSW graph.
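+    // (getMergeInstance() below temporarily switches this hint to SEQUENTIAL while merging,
+    // and finishMerge() restores it.)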
+ dataContext = + state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM); + try { + vectorData = + openDataInput( + state, + versionMeta, + Lucene104FlatVectorsFormat.VECTOR_DATA_EXTENSION, + Lucene104FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + dataContext); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + private int readMetadata(SegmentReadState state) throws IOException { + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, Lucene104FlatVectorsFormat.META_EXTENSION); + int versionMeta = -1; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + Lucene104FlatVectorsFormat.META_CODEC_NAME, + Lucene104FlatVectorsFormat.VERSION_START, + Lucene104FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + return versionMeta; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context) + throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + try { + int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + Lucene104FlatVectorsFormat.VERSION_START, + Lucene104FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + return in; + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, in); + throw t; + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = FieldEntry.create(meta, info); + fields.put(info.number, fieldEntry); + } + } + + @Override + public long ramBytesUsed() { + return Lucene104FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed(); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + return Map.of(Lucene104FlatVectorsFormat.VECTOR_DATA_EXTENSION, 0L); + } + + @Override + public void checkIntegrity() throws IOException { + CodecUtil.checksumEntireFile(vectorData); + } + + @Override + public FlatVectorsReader getMergeInstance() throws IOException { + // Update the read advice since vectors are guaranteed to be accessed sequentially for merge + vectorData.updateIOContext(dataContext.withHints(DataAccessHint.SEQUENTIAL)); + return this; + } + + private FieldEntry getFieldEntryOrThrow(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + return entry; + } + + private FieldEntry getFieldEntry(String 
field, VectorEncoding expectedEncoding) { + final FieldEntry fieldEntry = getFieldEntryOrThrow(field); + if (fieldEntry.vectorEncoding != expectedEncoding) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fieldEntry.vectorEncoding + + " expected: " + + expectedEncoding); + } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData, + fieldEntry.offsets); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + return OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData, + fieldEntry.offsets); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapFloatVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData, + fieldEntry.offsets), + target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapByteVectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData, + fieldEntry.offsets), + target); + } + + @Override + public void finishMerge() throws IOException { + // This makes sure that the access pattern hint is reverted back since HNSW implementation + // needs it + vectorData.updateIOContext(dataContext); + } + + @Override + public void close() throws IOException { + IOUtils.close(vectorData); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long vectorDataOffset, + long vectorDataLength, + int dimension, + int size, + OrdToDocDISIReaderConfiguration ordToDoc, + FieldInfo info, + long[] offsets) { + + FieldEntry { + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction()); + } + int infoVectorDimension = info.getVectorDimension(); + if (infoVectorDimension != dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + + info.name + + "\"; " + + infoVectorDimension + + " != " + + dimension); + } + } + + static FieldEntry create(IndexInput input, FieldInfo info) throws IOException { + final VectorEncoding vectorEncoding = 
readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final var dimension = input.readVInt(); + final var vectorDataOffset = input.readVLong(); + final var vectorDataLength = input.readVLong(); + + final var offsets = new long[input.readInt()]; + input.readLongs(offsets, 0, offsets.length); + + final var size = input.readInt(); + final var ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + vectorDataOffset, + vectorDataLength, + dimension, + size, + ordToDoc, + info, + offsets); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsWriter.java new file mode 100644 index 000000000000..9e00f65dda7b --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104FlatVectorsWriter.java @@ -0,0 +1,479 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; + +/** + * Writes vector values to index segments. 
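+ * Vectors are buffered on heap and written out per document at flush time, so that duplicate
+ * vectors within a single document can be stored only once.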
+ *
+ * @lucene.experimental
+ */
+public final class Lucene104FlatVectorsWriter extends FlatVectorsWriter {
+
+  private static final long SHALLOW_RAM_BYTES_USED =
+      RamUsageEstimator.shallowSizeOfInstance(Lucene104FlatVectorsWriter.class);
+
+  private final IndexOutput meta, vectorData;
+
+  private final List<FieldWriter<?>> fields = new ArrayList<>();
+  private boolean finished;
+
+  public Lucene104FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer)
+      throws IOException {
+    super(scorer);
+    String metaFileName =
+        IndexFileNames.segmentFileName(
+            state.segmentInfo.name, state.segmentSuffix, Lucene104FlatVectorsFormat.META_EXTENSION);
+
+    String vectorDataFileName =
+        IndexFileNames.segmentFileName(
+            state.segmentInfo.name,
+            state.segmentSuffix,
+            Lucene104FlatVectorsFormat.VECTOR_DATA_EXTENSION);
+
+    try {
+      meta = state.directory.createOutput(metaFileName, state.context);
+      vectorData = state.directory.createOutput(vectorDataFileName, state.context);
+
+      CodecUtil.writeIndexHeader(
+          meta,
+          Lucene104FlatVectorsFormat.META_CODEC_NAME,
+          Lucene104FlatVectorsFormat.VERSION_CURRENT,
+          state.segmentInfo.getId(),
+          state.segmentSuffix);
+      CodecUtil.writeIndexHeader(
+          vectorData,
+          Lucene104FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
+          Lucene104FlatVectorsFormat.VERSION_CURRENT,
+          state.segmentInfo.getId(),
+          state.segmentSuffix);
+    } catch (Throwable t) {
+      IOUtils.closeWhileSuppressingExceptions(t, this);
+      throw t;
+    }
+  }
+
+  @Override
+  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+    FieldWriter<?> newField = FieldWriter.create(fieldInfo);
+    fields.add(newField);
+    return newField;
+  }
+
+  /**
+   * This function somewhat transposes the .vec and .vem files before flushing to disk.
+   *
+   * <p>Today, the .vec file is partitioned per-field, and looks like:
+   *
+   * <pre>
+   *   # field 1 begin:
+   *   (vector for field 1, document d1) # position x0
+   *   (vector for field 1, document d2)
+   *   (vector for field 1, document d3)
+   *   # field 1 end, field 2 begin:
+   *   (vector for field 2, document d1) # position x1
+   *   (vector for field 2, document d3)
+   *   # field 2 end, field 3 begin:
+   *   (vector for field 3, document d1) # position x2
+   *   (vector for field 3, document d2)
+   *   # field 3 end, and so on...
+   * </pre>
+   *
+   * <p>The .vem file contains per-field tuples to denote (position, length) of the corresponding
+   * vector "block":
+   *
+   * <pre>
+   *   # (field number, offset of vector "block", length of vector "block")
+   *   # "..." represents other metadata, including dimension, ord -> doc mapping, etc.
+   *   (1, x0, x1 - x0, ...)
+   *   (2, x1, x2 - x1, ...)
+   *   # and so on...
+   * </pre>
+   *
+   * <p>This function changes the .vec to be partitioned per-document instead, something like:
+   *
+   * <pre>
+   *   # document d1 begin:
+   *   (vector for field 1, document d1) # position x0
+   *   (vector for field 2, document d1) # position x1
+   *   (vector for field 3, document d1) # position x2
+   *   # document d1 end, document d2 begin:
+   *   (vector for field 1, document d2) # position x3
+   *   (vector for field 3, document d2) # position x4
+   *   # document d2 end, document d3 begin:
+   *   (vector for field 1, document d3) # position x5
+   *   (vector for field 2, document d3) # position x6
+   *   # document d3 end, and so on...
+   * </pre>
+   *
+   * <p>Correspondingly, the .vem file will contain per-field mappings from ord -> position in the
+   * raw file:
+   *
+   * <pre>
+   *   # (field number, ord -> position mapping as array, ...)
+   *   # "..." represents other metadata, including dimension, ord -> doc mapping, etc., which is unchanged
+   *   (1, [x0, x3, x5], ...) # {ord 0 -> position x0, ord 1 -> position x3, ord 2 -> position x5}
+   *   (2, [x1, x4], ...) # {ord 0 -> position x1, ord 1 -> position x4}
+   *   (3, [x2, x6], ...) # {ord 0 -> position x2, ord 1 -> position x6}
+   *   # and so on...
+   * </pre>
+   *
+   * <p>This is done so that in case of duplicate vectors within a document, we can simply
+   * "point" to the pre-existing vector!
+   */
+  @Override
+  public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
+    // TODO *very* crude!
+
+    // index sort not supported for now
+    if (sortMap != null) {
+      throw new UnsupportedOperationException();
+    }
+
+    int numFields = fields.size();
+
+    // offsets[i][j] denotes offset of vector with ord j of field i
+    long[][] offsets = new long[numFields][];
+
+    // iterator over ord, docid, vector
+    KnnVectorValues.DocIndexIterator[] iterators = new KnnVectorValues.DocIndexIterator[numFields];
+    for (int i = 0; i < numFields; i++) {
+      FieldWriter<?> writer = fields.get(i);
+
+      offsets[i] = new long[writer.docsWithField.cardinality()];
+
+      DocIdSetIterator iterator =
+          Objects.requireNonNullElse(writer.docsWithField.iterator(), DocIdSetIterator.empty());
+      iterators[i] = KnnVectorValuesPublicAccess.fromDISILocal(iterator);
+
+      // initialize iteration
+      iterators[i].nextDoc();
+    }
+
+    // get first offset
+    long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
+
+    long bytesWritten = 0;
+    ByteBuffer buffer = ByteBuffer.allocate(8192).order(ByteOrder.LITTLE_ENDIAN);
+
+    // Go over all documents one by one
+    for (int i = 0; i < maxDoc; i++) {
+      // Record positions of ALL vectors in document i
+      Map<ByteVector, Long> byteOffsets = new HashMap<>();
+      Map<FloatVector, Long> floatOffsets = new HashMap<>();
+
+      for (int j = 0; j < numFields; j++) {
+        // If field j contains a vector for document i
+        if (iterators[j].docID() == i) {
+          FieldWriter<?> fieldWriter = fields.get(j);
+          int ord = iterators[j].index();
+
+          switch (fieldWriter.fieldInfo.getVectorEncoding()) {
+            case BYTE -> {
+              byte[] bytes = (byte[]) fieldWriter.vectors.get(ord);
+              ByteVector vector = new ByteVector(bytes);
+
+              // Check if we saw the vector earlier
+              Long lookup = byteOffsets.get(vector);
+              if (lookup == null) { // If the vector is new
+                // Record offset
+                offsets[j][ord] = bytesWritten;
+                byteOffsets.put(vector, bytesWritten);
+
+                // Write the vector
+                int vectorByteLength = bytes.length;
+                vectorData.writeBytes(bytes, vectorByteLength);
+                bytesWritten += vectorByteLength;
+              } else { // If the vector has been encountered before
+                // Simply "point" to the older offset
+                offsets[j][ord] = lookup;
+              }
+            }
+            case FLOAT32 -> {
+              float[] floats = (float[]) fieldWriter.vectors.get(ord);
+              FloatVector vector = new FloatVector(floats);
+
+              // Check if we saw the vector earlier
+              Long lookup = floatOffsets.get(vector);
+              if (lookup == null) { // If the vector is new
+                // Record offset
+                offsets[j][ord] = bytesWritten;
+                floatOffsets.put(vector, bytesWritten);
+
+                // Write the vector
+                int vectorByteLength = floats.length * Float.BYTES;
+                buffer.asFloatBuffer().put(floats);
+                vectorData.writeBytes(buffer.array(), vectorByteLength);
+                bytesWritten += vectorByteLength;
+              } else { // If the vector has been encountered before
+                // Simply "point" to the older offset
+                offsets[j][ord] = lookup;
+              }
+            }
+            default -> throw new IllegalArgumentException();
+          }
+
+          // Increment per-field iterator
+          iterators[j].nextDoc();
+        }
+      }
+    }
+
+    // Finally write per-field metadata
+    for (int i = 0; i < numFields; i++) {
+      FieldWriter<?> writer = fields.get(i);
+      writeMeta(
+          writer.fieldInfo,
+          maxDoc,
+          vectorDataOffset,
+          bytesWritten,
+          offsets[i],
+          writer.docsWithField);
+    }
+  }
+
+  record ByteVector(byte[] vector) {
+    @Override
+    public int hashCode() {
+      return Arrays.hashCode(vector);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj instanceof ByteVector(byte[] bytes)) {
+        return Arrays.equals(vector, bytes);
+      }
+      return false;
+    }
+  }
+
+  record FloatVector(float[] vector) {
+    @Override
+    public int hashCode() {
+      return Arrays.hashCode(vector);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj instanceof FloatVector(float[] floats)) {
+        return Arrays.equals(vector, floats);
+      }
+      return false;
+    }
+  }
+
+  private void writeMeta(
+      FieldInfo field,
+      int maxDoc,
+      long vectorDataOffset,
+      long vectorDataLength,
+      long[] offsets,
+      DocsWithFieldSet docsWithField)
+      throws IOException {
+    meta.writeInt(field.number);
+    meta.writeInt(field.getVectorEncoding().ordinal());
+    meta.writeInt(field.getVectorSimilarityFunction().ordinal());
+    meta.writeVInt(field.getVectorDimension());
+    meta.writeVLong(vectorDataOffset);
+    meta.writeVLong(vectorDataLength);
+
+    // write offsets (ord -> position mapping)
+    meta.writeInt(offsets.length);
+    byte[] buffer = new byte[offsets.length * Long.BYTES];
+    ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer().put(offsets);
+    meta.writeBytes(buffer, buffer.length);
+
+    // write docIDs
+    int count = docsWithField.cardinality();
+    meta.writeInt(count);
+    OrdToDocDISIReaderConfiguration.writeStoredMeta(
+        Lucene104FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT,
+        meta,
+        vectorData,
+        count,
+        maxDoc,
+        docsWithField);
+  }
+
+  @Override
+  public void finish() throws IOException {
+    if (finished) {
+      throw new IllegalStateException("already finished");
+    }
+    finished = true;
+    if (meta != null) {
+      // write end of fields marker
+      meta.writeInt(-1);
+      CodecUtil.writeFooter(meta);
+    }
+    if (vectorData != null) {
+      CodecUtil.writeFooter(vectorData);
+    }
+  }
+
+  @Override
+  public long ramBytesUsed() {
+    long total = SHALLOW_RAM_BYTES_USED;
+    for (FieldWriter<?> field : fields) {
+      total += field.ramBytesUsed();
+    }
+    return total;
+  }
+
+  @Override
+  public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+    throw new UnsupportedOperationException(); // TODO
+  }
+
+  @Override
+  public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+      FieldInfo fieldInfo, MergeState mergeState) {
+    throw new UnsupportedOperationException(); // TODO
+  }
+
+  @Override
+  public void close() throws IOException {
+    IOUtils.close(meta, vectorData);
+  }
+
+  private abstract static class FieldWriter<T> extends FlatFieldVectorsWriter<T> {
+    private static final long SHALLOW_RAM_BYTES_USED =
+        RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);
+    private final FieldInfo fieldInfo;
+    private final DocsWithFieldSet docsWithField;
+    private final List<T> vectors;
+    private boolean finished;
+
+    private int lastDocID = -1;
+
+    static FieldWriter<?> create(FieldInfo fieldInfo) {
+      int dim = fieldInfo.getVectorDimension();
+      return switch (fieldInfo.getVectorEncoding()) {
+        case BYTE ->
+            new FieldWriter<byte[]>(fieldInfo) {
+              @Override
+              public byte[] copyValue(byte[] value) {
+                return ArrayUtil.copyOfSubArray(value, 0, dim);
+              }
+            };
+        case FLOAT32 ->
+            new FieldWriter<float[]>(fieldInfo) {
+              @Override
+              public float[] copyValue(float[] value) {
+                return ArrayUtil.copyOfSubArray(value, 0, dim);
+              }
+            };
+      };
+    }
+
+    FieldWriter(FieldInfo fieldInfo) {
+      super();
+      this.fieldInfo = fieldInfo;
+      this.docsWithField = new DocsWithFieldSet();
+      vectors = new ArrayList<>();
+    }
+
+    @Override
+    public void addValue(int docID, T vectorValue) {
+      if (finished) {
+        throw new IllegalStateException("already finished, cannot add more values");
+      }
+      if (docID == lastDocID) {
throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)"); + } + assert docID > lastDocID; + T copy = copyValue(vectorValue); + docsWithField.add(docID); + vectors.add(copy); + lastDocID = docID; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_RAM_BYTES_USED; + if (vectors.size() == 0) return size; + return size + + docsWithField.ramBytesUsed() + + (long) vectors.size() + * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + + (long) vectors.size() + * fieldInfo.getVectorDimension() + * fieldInfo.getVectorEncoding().byteSize; + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() { + if (finished) { + return; + } + this.finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + } + + /** Utility class to expose {@link KnnVectorValues#fromDISI(DocIdSetIterator)} */ + private abstract static class KnnVectorValuesPublicAccess extends KnnVectorValues { + private static DocIndexIterator fromDISILocal(DocIdSetIterator iterator) { + return KnnVectorValues.fromDISI(iterator); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 1e78c8ea7aa2..fb69149b6b07 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -44,6 +44,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement protected final int byteSize; protected final VectorSimilarityFunction similarityFunction; protected final FlatVectorsScorer flatVectorsScorer; + protected final long[] offsets; OffHeapByteVectorValues( int dimension, @@ -51,7 +52,8 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement IndexInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction similarityFunction) { + VectorSimilarityFunction similarityFunction, + long[] offsets) { this.dimension = dimension; this.size = size; this.slice = slice; @@ -60,6 +62,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement binaryValue = byteBuffer.array(); this.similarityFunction = similarityFunction; this.flatVectorsScorer = flatVectorsScorer; + this.offsets = offsets; } @Override @@ -86,8 +89,17 @@ public IndexInput getSlice() { return slice; } + @Override + public long address(int ord) { + if (offsets == null) { + return (long) ord * byteSize; + } else { + return offsets[ord]; + } + } + private void readValue(int targetOrd) throws IOException { - slice.seek((long) targetOrd * byteSize); + slice.seek(address(targetOrd)); slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); } @@ -101,6 +113,29 @@ public static OffHeapByteVectorValues load( long vectorDataLength, IndexInput vectorData) throws IOException { + return load( + vectorSimilarityFunction, + flatVectorsScorer, + configuration, + vectorEncoding, + dimension, + vectorDataOffset, + vectorDataLength, + vectorData, + null); + } + + public static OffHeapByteVectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, + 
OrdToDocDISIReaderConfiguration configuration, + VectorEncoding vectorEncoding, + int dimension, + long vectorDataOffset, + long vectorDataLength, + IndexInput vectorData, + long[] offsets) + throws IOException { if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) { return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } @@ -112,7 +147,8 @@ public static OffHeapByteVectorValues load( bytesSlice, dimension, flatVectorsScorer, - vectorSimilarityFunction); + vectorSimilarityFunction, + offsets); } else { return new SparseOffHeapVectorValues( configuration, @@ -121,7 +157,8 @@ public static OffHeapByteVectorValues load( dimension, dimension, flatVectorsScorer, - vectorSimilarityFunction); + vectorSimilarityFunction, + offsets); } } @@ -137,13 +174,24 @@ public DenseOffHeapVectorValues( int byteSize, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction vectorSimilarityFunction) { - super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction); + super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction, null); + } + + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction vectorSimilarityFunction, + long[] offsets) { + super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction, offsets); } @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( - dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction, offsets); } @Override @@ -190,7 +238,8 @@ public SparseOffHeapVectorValues( int dimension, int byteSize, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction vectorSimilarityFunction) + VectorSimilarityFunction vectorSimilarityFunction, + long[] offsets) throws IOException { super( @@ -199,7 +248,8 @@ public SparseOffHeapVectorValues( slice, byteSize, flatVectorsScorer, - vectorSimilarityFunction); + vectorSimilarityFunction, + offsets); this.configuration = configuration; final RandomAccessInput addressesData = dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength); @@ -224,7 +274,8 @@ public SparseOffHeapVectorValues copy() throws IOException { dimension, byteSize, flatVectorsScorer, - similarityFunction); + similarityFunction, + offsets); } @Override @@ -280,7 +331,7 @@ public EmptyOffHeapVectorValues( int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction vectorSimilarityFunction) { - super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction); + super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction, null); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index b05aeb20347a..3ea1e52221d4 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -46,6 +46,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme protected final float[] value; protected final VectorSimilarityFunction similarityFunction; protected final FlatVectorsScorer flatVectorsScorer; + protected final long[] offsets; 
OffHeapFloatVectorValues( int dimension, @@ -53,7 +54,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme IndexInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction similarityFunction) { + VectorSimilarityFunction similarityFunction, + long[] offsets) { this.dimension = dimension; this.size = size; this.slice = slice; @@ -61,6 +63,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme this.similarityFunction = similarityFunction; this.flatVectorsScorer = flatVectorsScorer; value = new float[dimension]; + this.offsets = offsets; } @Override @@ -78,12 +81,21 @@ public IndexInput getSlice() { return slice; } + @Override + public long address(int ord) { + if (offsets == null) { + return (long) ord * byteSize; + } else { + return offsets[ord]; + } + } + @Override public float[] vectorValue(int targetOrd) throws IOException { if (lastOrd == targetOrd) { return value; } - slice.seek((long) targetOrd * byteSize); + slice.seek(address(targetOrd)); slice.readFloats(value, 0, value.length); lastOrd = targetOrd; return value; @@ -99,6 +111,29 @@ public static OffHeapFloatVectorValues load( long vectorDataLength, IndexInput vectorData) throws IOException { + return load( + vectorSimilarityFunction, + flatVectorsScorer, + configuration, + vectorEncoding, + dimension, + vectorDataOffset, + vectorDataLength, + vectorData, + null); + } + + public static OffHeapFloatVectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, + OrdToDocDISIReaderConfiguration configuration, + VectorEncoding vectorEncoding, + int dimension, + long vectorDataOffset, + long vectorDataLength, + IndexInput vectorData, + long[] offsets) + throws IOException { if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) { return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } @@ -111,7 +146,8 @@ public static OffHeapFloatVectorValues load( bytesSlice, byteSize, flatVectorsScorer, - vectorSimilarityFunction); + vectorSimilarityFunction, + offsets); } else { return new SparseOffHeapVectorValues( configuration, @@ -120,7 +156,8 @@ public static OffHeapFloatVectorValues load( dimension, byteSize, flatVectorsScorer, - vectorSimilarityFunction); + vectorSimilarityFunction, + offsets); } } @@ -137,13 +174,24 @@ public DenseOffHeapVectorValues( int byteSize, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { - super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction, null); + } + + public DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction, + long[] offsets) { + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction, offsets); } @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( - dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction, offsets); } @Override @@ -219,10 +267,18 @@ public SparseOffHeapVectorValues( int dimension, int byteSize, FlatVectorsScorer flatVectorsScorer, - VectorSimilarityFunction similarityFunction) + VectorSimilarityFunction similarityFunction, + long[] 
offsets) throws IOException { - super(dimension, configuration.size, slice, byteSize, flatVectorsScorer, similarityFunction); + super( + dimension, + configuration.size, + slice, + byteSize, + flatVectorsScorer, + similarityFunction, + offsets); this.configuration = configuration; final RandomAccessInput addressesData = dataIn.randomAccessSlice(configuration.addressesOffset, configuration.addressesLength); @@ -247,7 +303,8 @@ public SparseOffHeapVectorValues copy() throws IOException { dimension, byteSize, flatVectorsScorer, - similarityFunction); + similarityFunction, + offsets); } @Override @@ -341,7 +398,7 @@ public EmptyOffHeapVectorValues( int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { - super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction, null); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java index d265d9d29329..b015f7bd9535 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java @@ -24,6 +24,7 @@ import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104FlatVectorsFormat; import org.apache.lucene.index.MergePolicy; import org.apache.lucene.index.MergeScheduler; import org.apache.lucene.index.SegmentReadState; @@ -150,7 +151,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, and merging vectors on disk. */ private static final FlatVectorsFormat flatVectorsFormat = - new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + new Lucene104FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); private final int numMergeWorkers; private final TaskExecutor mergeExec; diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index 8e58f387a334..53ab9f5aeb9c 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -55,6 +55,14 @@ public int ordToDoc(int ord) { */ public abstract KnnVectorValues copy() throws IOException; + /** + * Temporary function to expose the address of a specific vector ordinal. This is needed because + * the assumption of address being ord * vectorByteSize no longer holds, and can now be arbitrary. 
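+   * Dense implementations keep returning {@code (long) ord * byteSize}, while the offset-based
+   * layout introduced here returns the position stored in its ord -> position table instead.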
+ */ + public long address(int ord) { + throw new UnsupportedOperationException(); + } + /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ public int getVectorByteLength() { return dimension() * getEncoding().byteSize; diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java index a8799c25a30a..421a678c08ef 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java @@ -32,6 +32,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer final int vectorByteSize; final MemorySegmentAccessInput input; + final KnnVectorValues values; final byte[] query; byte[] scratch; @@ -61,12 +62,13 @@ public static Optional create( super(values); this.input = input; this.vectorByteSize = values.getVectorByteLength(); + this.values = values; this.query = queryVector; } final MemorySegment getSegment(int ord) throws IOException { checkOrdinal(ord); - long byteOffset = (long) ord * vectorByteSize; + long byteOffset = values.address(ord); MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); if (seg == null) { if (scratch == null) { diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java index d3f08f1c47b2..2fd5c26abe33 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorer.java @@ -31,7 +31,6 @@ abstract sealed class Lucene99MemorySegmentFloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { final FloatVectorValues values; - final int vectorByteSize; final MemorySegment seg; final float[] query; final float[] scratchScores = new float[4]; @@ -63,7 +62,6 @@ public static Optional create( super(values); this.values = values; this.seg = seg; - this.vectorByteSize = values.getVectorByteLength(); this.query = query; } @@ -85,10 +83,10 @@ public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcep final int limit = numNodes & ~3; float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += 4) { - long offset1 = (long) nodes[i] * vectorByteSize; - long offset2 = (long) nodes[i + 1] * vectorByteSize; - long offset3 = (long) nodes[i + 2] * vectorByteSize; - long offset4 = (long) nodes[i + 3] * vectorByteSize; + long offset1 = values.address(nodes[i]); + long offset2 = values.address(nodes[i + 1]); + long offset3 = values.address(nodes[i + 2]); + long offset4 = values.address(nodes[i + 3]); vectorOp(seg, scratchScores, offset1, offset2, offset3, offset4, query.length); scores[i + 0] = normalizeRawScore(scratchScores[0]); maxScore = Math.max(maxScore, scores[i + 0]); @@ -102,9 +100,9 @@ public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcep // Handle remaining 1–3 nodes in bulk (if any) int remaining = numNodes - i; if (remaining > 0) { - long addr1 = (long) nodes[i] * vectorByteSize; - long addr2 = (remaining > 1) ? (long) nodes[i + 1] * vectorByteSize : addr1; - long addr3 = (remaining > 2) ? 
(long) nodes[i + 2] * vectorByteSize : addr1; + long addr1 = values.address(nodes[i]); + long addr2 = (remaining > 1) ? values.address(nodes[i + 1]) : addr1; + long addr3 = (remaining > 2) ? values.address(nodes[i + 2]) : addr1; vectorOp(seg, scratchScores, addr1, addr2, addr3, addr3, query.length); scores[i] = normalizeRawScore(scratchScores[0]); maxScore = Math.max(maxScore, scores[i]); diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java index 01ba6a05c806..5d00f47ce600 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFloatVectorScorerSupplier.java @@ -31,7 +31,6 @@ /** A score supplier of vectors whose element size is byte. */ public abstract sealed class Lucene99MemorySegmentFloatVectorScorerSupplier implements RandomVectorScorerSupplier { - final int vectorByteSize; final int maxOrd; final int dims; final MemorySegment seg; @@ -62,7 +61,6 @@ static Optional create( Lucene99MemorySegmentFloatVectorScorerSupplier(MemorySegment seg, FloatVectorValues values) { this.seg = seg; this.values = values; - this.vectorByteSize = values.getVectorByteLength(); this.maxOrd = values.size(); this.dims = values.dimension(); } @@ -282,8 +280,8 @@ abstract void vectorOp( @Override public float score(int node) { checkOrdinal(node); - long queryAddr = (long) queryOrd * vectorByteSize; - long addr = (long) node * vectorByteSize; + long queryAddr = values.address(queryOrd); + long addr = values.address(node); var raw = vectorOp(seg, queryAddr, addr, dims); return normalizeRawScore(raw); } @@ -291,14 +289,14 @@ public float score(int node) { @Override public float bulkScore(int[] nodes, float[] scores, int numNodes) { int i = 0; - long queryAddr = (long) queryOrd * vectorByteSize; + long queryAddr = values.address(queryOrd); float maxScore = Float.NEGATIVE_INFINITY; final int limit = numNodes & ~3; for (; i < limit; i += 4) { - long offset1 = (long) nodes[i] * vectorByteSize; - long offset2 = (long) nodes[i + 1] * vectorByteSize; - long offset3 = (long) nodes[i + 2] * vectorByteSize; - long offset4 = (long) nodes[i + 3] * vectorByteSize; + long offset1 = values.address(nodes[i]); + long offset2 = values.address(nodes[i + 1]); + long offset3 = values.address(nodes[i + 2]); + long offset4 = values.address(nodes[i + 3]); vectorOp(seg, scratchScores, queryAddr, offset1, offset2, offset3, offset4, dims); scores[i + 0] = normalizeRawScore(scratchScores[0]); maxScore = Math.max(maxScore, scores[i + 0]); @@ -312,9 +310,9 @@ public float bulkScore(int[] nodes, float[] scores, int numNodes) { // Handle remaining 1–3 nodes in bulk (if any) int remaining = numNodes - i; if (remaining > 0) { - long addr1 = (long) nodes[i] * vectorByteSize; - long addr2 = (remaining > 1) ? (long) nodes[i + 1] * vectorByteSize : addr1; - long addr3 = (remaining > 2) ? (long) nodes[i + 2] * vectorByteSize : addr1; + long addr1 = values.address(nodes[i]); + long addr2 = (remaining > 1) ? values.address(nodes[i + 1]) : addr1; + long addr3 = (remaining > 2) ? values.address(nodes[i + 2]) : addr1; vectorOp(seg, scratchScores, queryAddr, addr1, addr2, addr3, addr1, dims); scores[i] = normalizeRawScore(scratchScores[0]); maxScore = Math.max(maxScore, scores[i]);
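A note on the ByteVector/FloatVector wrapper records used by flush(): the per-document dedup maps need content-based key equality, but a record's generated equals()/hashCode() compare an array component by reference. The following standalone sketch (illustrative only, not part of the patch; class and variable names are made up) shows why the overrides are required:

    // Why flush() wraps raw vector arrays in records with overridden
    // equals/hashCode before using them as HashMap keys.
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;

    class ArrayKeyDemo {
      // Default record semantics: the float[] component is compared by reference.
      record Raw(float[] v) {}

      // Same pattern as FloatVector above: content-based equality over the array.
      record ByValue(float[] v) {
        @Override
        public int hashCode() {
          return Arrays.hashCode(v);
        }

        @Override
        public boolean equals(Object obj) {
          return obj instanceof ByValue(float[] other) && Arrays.equals(v, other);
        }
      }

      public static void main(String[] args) {
        float[] a = {1f, 2f};
        float[] b = {1f, 2f}; // equal contents, distinct array objects

        Map<Raw, Long> raw = new HashMap<>();
        raw.put(new Raw(a), 0L);
        System.out.println(raw.get(new Raw(b))); // null: duplicate goes undetected

        Map<ByValue, Long> byValue = new HashMap<>();
        byValue.put(new ByValue(a), 0L);
        System.out.println(byValue.get(new ByValue(b))); // 0: duplicate found
      }
    }

With reference-equality keys, every duplicate vector would be written to the .vec file again and the ord -> position table would never share positions.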