Skip to content

Commit

Permalink
Makes sure KNNVectorValues aren't recreated unnecessarily when quantization isn't needed (#2133)
Browse files Browse the repository at this point in the history



Signed-off-by: Tejas Shah <shatejas@amazon.com>
  • Loading branch information
shatejas authored Sep 23, 2024
1 parent 5423cc1 commit e33afa5
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
### Refactoring
* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133)

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.17...2.x)
### Features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues;
Expand Down Expand Up @@ -82,19 +83,19 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
for (final NativeEngineFieldVectorsWriter<?> field : fields) {
final FieldInfo fieldInfo = field.getFieldInfo();
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
int totalLiveDocs = getLiveDocs(getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()));
int totalLiveDocs = field.getVectors().size();
if (totalLiveDocs > 0) {
KNNVectorValues<?> knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());

final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValues, totalLiveDocs);
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);

knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();

writer.flushIndex(knnVectorValues, totalLiveDocs);

long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
Expand All @@ -110,17 +111,20 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);

final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
int totalLiveDocs = getLiveDocs(getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState));
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge(
vectorDataType,
fieldInfo,
mergeState
);
int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
if (totalLiveDocs == 0) {
log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
return;
}

KNNVectorValues<?> knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState);
final QuantizationState quantizationState = train(fieldInfo, knnVectorValues, totalLiveDocs);
final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);

knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();

Expand Down Expand Up @@ -191,27 +195,36 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
final VectorDataType vectorDataType,
final FieldInfo fieldInfo,
final MergeState mergeState
) throws IOException {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
) {
try {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
}
} catch (final IOException e) {
log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
}
}

private QuantizationState train(final FieldInfo fieldInfo, final KNNVectorValues<?> knnVectorValues, final int totalLiveDocs)
throws IOException {
private QuantizationState train(
final FieldInfo fieldInfo,
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
final int totalLiveDocs
) throws IOException {

final QuantizationService quantizationService = QuantizationService.getInstance();
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -176,6 +177,11 @@ public void testFlush() {
throw new RuntimeException(e);
}
});

knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(expectedVectorValues.size())
);
}
}

Expand Down Expand Up @@ -264,6 +270,11 @@ public void testFlush_WithQuantization() {
throw new RuntimeException(e);
}
});

knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(expectedVectorValues.size() * 2)
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -144,6 +145,10 @@ public void testMerge() {
if (!mergedVectors.isEmpty()) {
verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size());
assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L);
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues),
times(2)
);
} else {
verifyNoInteractions(nativeIndexWriter);
}
Expand Down Expand Up @@ -211,6 +216,10 @@ public void testMerge_WithQuantization() {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState);
verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size());
assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L);
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues),
times(3)
);
} else {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
verifyNoInteractions(nativeIndexWriter);
Expand Down

0 comments on commit e33afa5

Please sign in to comment.