From c6369dd06ae109ffeab6fcc7e24f4adca691e354 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 7 Oct 2024 14:43:31 -0700 Subject: [PATCH] Fix lucene codec after lucene version bumped to 9.12 Signed-off-by: Navneet Verma --- CHANGELOG.md | 1 + .../codec/KNN9120Codec/KNN9120Codec.java | 61 +++++++++++++++++++ .../NativeEngineFieldVectorsWriter.java | 32 ++++++++-- .../NativeEngines990KnnVectorsWriter.java | 8 ++- .../knn/index/codec/KNNCodecVersion.java | 21 ++++++- ...KNNScalarQuantizedVectorsFormatParams.java | 14 ++++- .../services/org.apache.lucene.codecs.Codec | 1 + .../NativeEngineFieldVectorsWriterTests.java | 51 ++++++++++++---- ...eEngines990KnnVectorsWriterFlushTests.java | 17 ++++-- ...eEngines990KnnVectorsWriterMergeTests.java | 16 +++-- ...alarQuantizedVectorsFormatParamsTests.java | 53 +++++++++++++++- 11 files changed, 241 insertions(+), 34 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f615a78fb..7bc3019df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Documentation * Fix sed command in DEVELOPER_GUIDE.md to append a new line character '\n'. [#2181](https://github.com/opensearch-project/k-NN/pull/2181) ### Maintenance +* Fix lucene codec after lucene version bumped to 9.12. 
[#2195](https://github.com/opensearch-project/k-NN/pull/2195) ### Refactoring * Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java new file mode 100644 index 000000000..a370197ec --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.Builder; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.codec.KNNFormatFacade; + +/** + * KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12 + */ +public class KNN9120Codec extends FilterCodec { + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0; + private final KNNFormatFacade knnFormatFacade; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + + /** + * No arg constructor that uses Lucene912 as the delegate + */ + public KNN9120Codec() { + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); + } + + /** + * Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec + * and a unique name to this ctor. 
+ * + * @param delegate codec that will perform all operations this codec does not override + * @param knnVectorsFormat per field format for KnnVector + */ + @Builder + protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); + perFieldKnnVectorsFormat = knnVectorsFormat; + } + + @Override + public DocValuesFormat docValuesFormat() { + return knnFormatFacade.docValuesFormat(); + } + + @Override + public CompoundFormat compoundFormat() { + return knnFormatFacade.compoundFormat(); + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return perFieldKnnVectorsFormat; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java index 1abb84944..389c76e49 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java @@ -13,11 +13,13 @@ import lombok.Getter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.RamUsageEstimator; +import java.io.IOException; import java.util.HashMap; import java.util.Map; @@ -44,22 +46,37 @@ class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { @Getter private final DocsWithFieldSet docsWithField; private final InfoStream infoStream; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; - static NativeEngineFieldVectorsWriter create(final FieldInfo fieldInfo, final InfoStream infoStream) { + @SuppressWarnings("unchecked") + static 
NativeEngineFieldVectorsWriter create( + final FieldInfo fieldInfo, + final FlatFieldVectorsWriter flatFieldVectorsWriter, + final InfoStream infoStream + ) { switch (fieldInfo.getVectorEncoding()) { case FLOAT32: - return new NativeEngineFieldVectorsWriter(fieldInfo, infoStream); + return new NativeEngineFieldVectorsWriter<>( + fieldInfo, + (FlatFieldVectorsWriter) flatFieldVectorsWriter, + infoStream + ); case BYTE: - return new NativeEngineFieldVectorsWriter(fieldInfo, infoStream); + return new NativeEngineFieldVectorsWriter<>(fieldInfo, (FlatFieldVectorsWriter) flatFieldVectorsWriter, infoStream); } throw new IllegalStateException("Unsupported Vector encoding : " + fieldInfo.getVectorEncoding()); } - private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) { + private NativeEngineFieldVectorsWriter( + final FieldInfo fieldInfo, + final FlatFieldVectorsWriter flatFieldVectorsWriter, + final InfoStream infoStream + ) { this.fieldInfo = fieldInfo; this.infoStream = infoStream; vectors = new HashMap<>(); this.docsWithField = new DocsWithFieldSet(); + this.flatFieldVectorsWriter = flatFieldVectorsWriter; } /** @@ -70,7 +87,7 @@ private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStre * @param vectorValue T */ @Override - public void addValue(int docID, T vectorValue) { + public void addValue(int docID, T vectorValue) throws IOException { if (docID == lastDocID) { throw new IllegalArgumentException( "[NativeEngineKNNVectorWriter]VectorValuesField \"" @@ -81,6 +98,8 @@ public void addValue(int docID, T vectorValue) { // TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative // graph build support on the JNI layer. assert docID > lastDocID; + // ensuring that vector is provided to flatFieldWriter. 
+ flatFieldVectorsWriter.addValue(docID, vectorValue); vectors.put(docID, vectorValue); docsWithField.add(docID); lastDocID = docID; @@ -105,6 +124,7 @@ public long ramBytesUsed() { return SHALLOW_SIZE + docsWithField.ramBytesUsed() + (long) this.vectors.size() * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() * RamUsageEstimator.shallowSizeOfInstance( Integer.class - ) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize; + ) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter + .ramBytesUsed(); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 2f22565c9..eccad41c8 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -65,9 +65,13 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla */ @Override public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOException { - final NativeEngineFieldVectorsWriter newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream); + final NativeEngineFieldVectorsWriter newField = NativeEngineFieldVectorsWriter.create( + fieldInfo, + flatVectorsWriter.addField(fieldInfo), + segmentWriteState.infoStream + ); fields.add(newField); - return flatVectorsWriter.addField(fieldInfo, newField); + return newField; } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 419505aa2..46171ce9f 100644 --- 
a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -12,12 +12,14 @@ import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec; import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec; -import org.apache.lucene.codecs.lucene99.Lucene99Codec; +import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec; +import org.apache.lucene.codecs.lucene912.Lucene912Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; +import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec; import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec; import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; @@ -110,9 +112,24 @@ public enum KNNCodecVersion { .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), KNN990Codec::new + ), + + V_9_12_0( + "KNN9120Codec", + new Lucene912Codec(), + new KNN990PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN9120Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .build(), + KNN9120Codec::new ); - private static final KNNCodecVersion CURRENT = V_9_9_0; + private static final KNNCodecVersion CURRENT = V_9_12_0; private final String codecName; private final Codec defaultCodecDelegate; diff --git 
a/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java index e2d31183b..3498119c1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java @@ -31,7 +31,8 @@ public KNNScalarQuantizedVectorsFormatParams(Map params, int def Map sqEncoderParams = encoderMethodComponentContext.getParameters(); this.initConfidenceInterval(sqEncoderParams); this.initBits(sqEncoderParams); - this.initCompressFlag(); + // compression flag should be set after bits has been initialised as compressionFlag depends on bits. + this.setCompressionFlag(); } @Override @@ -76,7 +77,14 @@ private void initBits(final Map params) { this.bits = LUCENE_SQ_DEFAULT_BITS; } - private void initCompressFlag() { - this.compressFlag = true; + private void setCompressionFlag() { + if (this.bits <= 0) { + throw new IllegalArgumentException( + "Either bits are set to less than 0 or they have not been initialized." + " Bit value: " + this.bits + ); + } + // This check is coming from Lucene. 
Code ref: + // https://github.com/apache/lucene/blob/branch_9_12/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java#L113-L116 + this.compressFlag = this.bits <= 4; } } diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec index 401b094b2..7a8916981 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -7,4 +7,5 @@ org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec org.opensearch.knn.index.codec.KNN950Codec.KNN950Codec org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec +org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec org.opensearch.knn.index.codec.KNN990Codec.UnitTestCodec diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java index 29e3531cf..6e6a51b88 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java @@ -11,6 +11,8 @@ package org.opensearch.knn.index.codec.KNN990Codec; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.InfoStream; @@ -21,82 +23,109 @@ public class NativeEngineFieldVectorsWriterTests extends KNNCodecTestCase { @SuppressWarnings("unchecked") + @SneakyThrows public void testCreate_ForDifferentInputs_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); + final 
FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); - floatWriter.addValue(1, new float[] { 1.0f, 2.0f }); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); + final float[] floatVector = new float[] { 1.0f, 2.0f }; + floatWriter.addValue(1, floatVector); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(1, floatVector); Mockito.verify(fieldInfo).getVectorEncoding(); + Mockito.verify(mockedFlatFieldVectorsWriter).addValue(1, floatVector); + final byte[] byteVector = new byte[] { 1, 2 }; + final FlatFieldVectorsWriter mockedFlatFieldByteVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(1, byteVector); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter.create( fieldInfo, + mockedFlatFieldByteVectorsWriter, InfoStream.getDefault() ); Assert.assertNotNull(byteWriter); Mockito.verify(fieldInfo, Mockito.times(2)).getVectorEncoding(); - byteWriter.addValue(1, new byte[] { 1, 2 }); + byteWriter.addValue(1, byteVector); + Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(1, byteVector); } @SuppressWarnings("unchecked") + @SneakyThrows public void testAddValue_ForDifferentInputs_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); - final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + final FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); final float[] vec1 = 
new float[] { 1.0f, 2.0f }; final float[] vec2 = new float[] { 2.0f, 2.0f }; + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(1, vec1); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(2, vec2); + final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); floatWriter.addValue(1, vec1); floatWriter.addValue(2, vec2); + Mockito.verify(mockedFlatFieldVectorsWriter).addValue(1, vec1); + Mockito.verify(mockedFlatFieldVectorsWriter).addValue(2, vec2); Assert.assertEquals(vec1, floatWriter.getVectors().get(1)); Assert.assertEquals(vec2, floatWriter.getVectors().get(2)); Mockito.verify(fieldInfo).getVectorEncoding(); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); - final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + final FlatFieldVectorsWriter mockedFlatFieldByteVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); final byte[] bvec1 = new byte[] { 1, 2 }; final byte[] bvec2 = new byte[] { 2, 2 }; + Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(1, bvec1); + Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(2, bvec2); + final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, mockedFlatFieldByteVectorsWriter, InfoStream.getDefault()); byteWriter.addValue(1, bvec1); byteWriter.addValue(2, bvec2); Assert.assertEquals(bvec1, byteWriter.getVectors().get(1)); Assert.assertEquals(bvec2, byteWriter.getVectors().get(2)); Mockito.verify(fieldInfo, Mockito.times(2)).getVectorEncoding(); + Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(1, bvec1); + Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(2, bvec2); } @SuppressWarnings("unchecked") 
+ @SneakyThrows public void testCopyValue_whenValidInput_thenException() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); expectThrows(UnsupportedOperationException.class, () -> floatWriter.copyValue(new float[3])); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); expectThrows(UnsupportedOperationException.class, () -> byteWriter.copyValue(new byte[3])); } @SuppressWarnings("unchecked") + @SneakyThrows public void testRamByteUsed_whenValidInput_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2); + FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.when(mockedFlatFieldVectorsWriter.ramBytesUsed()).thenReturn(1L); final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. 
Assert.assertTrue(floatWriter.ramBytesUsed() > 0); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. Assert.assertTrue(byteWriter.ramBytesUsed() > 0); + Mockito.verify(mockedFlatFieldVectorsWriter, Mockito.times(2)).ramBytesUsed(); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 9f74b2c10..5bb6d1926 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -8,6 +8,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -16,6 +17,7 @@ import org.mockito.Mock; import org.mockito.MockedConstruction; import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; @@ -68,6 +70,8 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa @Mock private NativeIndexWriter nativeIndexWriter; + private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; + private NativeEngines990KnnVectorsWriter 
objectUnderTest; private final String description; @@ -78,6 +82,9 @@ public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); + Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); } @ParametersFactory @@ -139,8 +146,9 @@ public void testFlush() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); try { objectUnderTest.addField(fieldInfo); @@ -227,8 +235,9 @@ public void testFlush_WithQuantization() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); try { objectUnderTest.addField(fieldInfo); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index 41940c4d4..af18cd281 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ 
b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -9,6 +9,7 @@ import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -19,6 +20,7 @@ import org.mockito.Mock; import org.mockito.MockedConstruction; import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; @@ -74,12 +76,16 @@ public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCa private final String description; private final Map mergedVectors; + private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); + Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); } @ParametersFactory @@ -120,8 +126,9 @@ public void testMerge() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); mergedVectorValuesMockedStatic.when(() -> 
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); @@ -184,8 +191,9 @@ public void testMerge_WithQuantization() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); diff --git a/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java b/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java index 573e826e0..b7394b06a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java @@ -12,12 +12,14 @@ package org.opensearch.knn.index.codec.params; import junit.framework.TestCase; +import org.junit.Assert; import org.opensearch.knn.index.engine.MethodComponentContext; import java.util.HashMap; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; @@ -39,7 +41,7 @@ public void testInitParams_whenCalled_thenReturnDefaultParams() { assertEquals(DEFAULT_MAX_CONNECTIONS, 
knnScalarQuantizedVectorsFormatParams.getMaxConnections()); assertEquals(DEFAULT_BEAM_WIDTH, knnScalarQuantizedVectorsFormatParams.getBeamWidth()); assertNull(knnScalarQuantizedVectorsFormatParams.getConfidenceInterval()); - assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); + assertFalse(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits()); } @@ -65,10 +67,57 @@ public void testInitParams_whenCalled_thenReturnParams() { assertEquals(m, knnScalarQuantizedVectorsFormatParams.getMaxConnections()); assertEquals(efConstruction, knnScalarQuantizedVectorsFormatParams.getBeamWidth()); assertEquals((float) MINIMUM_CONFIDENCE_INTERVAL, knnScalarQuantizedVectorsFormatParams.getConfidenceInterval()); - assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); + assertFalse(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits()); } + public void testInitParams_whenBitsIs4_thenReturnParams() { + int m = 64; + int efConstruction = 128; + + Map encoderParams = new HashMap<>(); + encoderParams.put(LUCENE_SQ_CONFIDENCE_INTERVAL, MINIMUM_CONFIDENCE_INTERVAL); + encoderParams.put(LUCENE_SQ_BITS, 4); + MethodComponentContext encoderComponentContext = new MethodComponentContext(ENCODER_SQ, encoderParams); + + Map params = new HashMap<>(); + params.put(METHOD_ENCODER_PARAMETER, encoderComponentContext); + params.put(METHOD_PARAMETER_M, m); + params.put(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction); + + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + params, + DEFAULT_MAX_CONNECTIONS, + DEFAULT_BEAM_WIDTH + ); + + assertEquals(m, knnScalarQuantizedVectorsFormatParams.getMaxConnections()); + assertEquals(efConstruction, knnScalarQuantizedVectorsFormatParams.getBeamWidth()); + assertEquals((float) 
MINIMUM_CONFIDENCE_INTERVAL, knnScalarQuantizedVectorsFormatParams.getConfidenceInterval()); + assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); + assertEquals(4, knnScalarQuantizedVectorsFormatParams.getBits()); + } + + public void testInitParams_whenBitsIs0_thenThrowException() { + int m = 64; + int efConstruction = 128; + + Map encoderParams = new HashMap<>(); + encoderParams.put(LUCENE_SQ_CONFIDENCE_INTERVAL, MINIMUM_CONFIDENCE_INTERVAL); + encoderParams.put(LUCENE_SQ_BITS, 0); + MethodComponentContext encoderComponentContext = new MethodComponentContext(ENCODER_SQ, encoderParams); + + Map params = new HashMap<>(); + params.put(METHOD_ENCODER_PARAMETER, encoderComponentContext); + params.put(METHOD_PARAMETER_M, m); + params.put(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> new KNNScalarQuantizedVectorsFormatParams(params, DEFAULT_MAX_CONNECTIONS, DEFAULT_BEAM_WIDTH) + ); + } + public void testValidate_whenCalled_thenReturnTrue() { Map params = getDefaultParamsForConstructor(); KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(