Skip to content

Commit

Permalink
Add graph creation stats to the KNNStats API (#1141)
Browse files Browse the repository at this point in the history
* Add graph stats to KNN Stats API

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
  • Loading branch information
ryanbogan authored Oct 3, 2023
1 parent 78aba55 commit 9e3e046
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
* Add graph creation stats to the KNNStats API. [#1141](https://github.com/opensearch-project/k-NN/pull/1141)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.ChecksumIndexInput;
import org.opensearch.common.StopWatch;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.MediaTypeRegistry;
Expand All @@ -35,6 +36,7 @@
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.stats.KNNGraphValue;

import java.io.Closeable;
import java.io.IOException;
Expand All @@ -53,6 +55,7 @@
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize;

/**
* This class writes the KNN docvalues to the segments
Expand All @@ -76,7 +79,13 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
delegatee.addBinaryField(field, valuesProducer);
if (isKNNBinaryFieldRequired(field)) {
addKNNBinaryField(field, valuesProducer);
StopWatch stopWatch = new StopWatch();
stopWatch.start();
addKNNBinaryField(field, valuesProducer, false, true);
stopWatch.stop();
long time_in_millis = stopWatch.totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis);
logger.warn("Refresh operation complete in " + time_in_millis + " ms");
}
}

Expand All @@ -97,14 +106,21 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) {
return KNNEngine.getEngine(engineName);
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
throws IOException {
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}
long arraySize = calculateArraySize(pair.vectors, pair.serializationMode);
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
// Create library index either from model or from scratch
Expand Down Expand Up @@ -135,6 +151,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}

if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
}

if (isRefresh) {
recordRefreshStats();
}

// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
Expand All @@ -143,6 +167,19 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
writeFooter(indexPath, engineFileName);
}

private void recordMergeStats(int length, long arraySize) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement();
KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize);
KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment();
KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length);
KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize);
}

private void recordRefreshStats() {
KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.INDEX_THREAD_QTY,
Expand Down Expand Up @@ -210,7 +247,13 @@ public void merge(MergeState mergeState) {
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
DocValuesType type = fieldInfo.getDocValuesType();
if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState));
StopWatch stopWatch = new StopWatch();
stopWatch.start();
addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true, false);
stopWatch.stop();
long time_in_millis = stopWatch.totalTime().millis();
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis);
logger.warn("Merge operation complete in " + time_in_millis + " ms");
}
}
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,72 @@ public class KNNCodecUtil {

public static final String HNSW_EXTENSION = ".hnsw";
public static final String HNSW_COMPOUND_EXTENSION = ".hnswc";
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
public static final int JAVA_REFERENCE_SIZE = 4;
// Each array in Java has a header that is 12 bytes
public static final int JAVA_ARRAY_HEADER_SIZE = 12;
// Java rounds each array size up to multiples of 8 bytes
public static final int JAVA_ROUNDING_NUMBER = 8;

public static final class Pair {
public Pair(int[] docs, float[][] vectors) {
public Pair(int[] docs, float[][] vectors, SerializationMode serializationMode) {
this.docs = docs;
this.vectors = vectors;
this.serializationMode = serializationMode;
}

public int[] docs;
public float[][] vectors;
public SerializationMode serializationMode;
}

public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException {
ArrayList<float[]> vectorList = new ArrayList<>();
ArrayList<Integer> docIdList = new ArrayList<>();
SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS;
for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
vectorList.add(vector);
}
docIdList.add(doc);
}
return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorList.toArray(new float[][] {}));
return new KNNCodecUtil.Pair(
docIdList.stream().mapToInt(Integer::intValue).toArray(),
vectorList.toArray(new float[][] {}),
serializationMode
);
}

public static long calculateArraySize(float[][] vectors, SerializationMode serializationMode) {
int vectorLength = vectors[0].length;
int numVectors = vectors.length;
if (serializationMode == SerializationMode.ARRAY) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
} else {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
}
}

public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayIn
return getSerializerBySerializationMode(serializationMode);
}

private static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
int numberOfAvailableBytesInStream = byteStream.available();
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) {
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
Expand Down
95 changes: 95 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNGraphValue.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.stats;

import java.util.concurrent.atomic.AtomicLong;

/**
* Contains a map to keep track of different graph values
*/
public enum KNNGraphValue {

REFRESH_TOTAL_OPERATIONS("total"),
REFRESH_TOTAL_TIME_IN_MILLIS("total_time_in_millis"),
MERGE_CURRENT_OPERATIONS("current"),
MERGE_CURRENT_DOCS("current_docs"),
MERGE_CURRENT_SIZE_IN_BYTES("current_size_in_bytes"),
MERGE_TOTAL_OPERATIONS("total"),
MERGE_TOTAL_TIME_IN_MILLIS("total_time_in_millis"),
MERGE_TOTAL_DOCS("total_docs"),
MERGE_TOTAL_SIZE_IN_BYTES("total_size_in_bytes");

private String name;
private AtomicLong value;

/**
* Constructor
*
* @param name name of the graph value
*/
KNNGraphValue(String name) {
this.name = name;
this.value = new AtomicLong(0);
}

/**
* Get name of value
*
* @return name
*/
public String getName() {
return name;
}

/**
* Get the graph value
*
* @return value
*/
public Long getValue() {
return value.get();
}

/**
* Increment the graph value
*/
public void increment() {
value.getAndIncrement();
}

/**
* Decrement the graph value
*/
public void decrement() {
value.getAndDecrement();
}

/**
* Increment the graph value by a specified amount
*
* @param delta The amount to increment
*/
public void incrementBy(long delta) {
value.getAndAdd(delta);
}

/**
* Decrement the graph value by a specified amount
*
* @param delta The amount to decrement
*/
public void decrementBy(long delta) {
value.set(value.get() - delta);
}

/**
* @param value graph value
* Set the graph value
*/
public void set(long value) {
this.value.set(value);
}
}
29 changes: 29 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.time.temporal.ChronoUnit;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

/**
* Class represents all stats the plugin keeps track of
Expand Down Expand Up @@ -84,6 +85,7 @@ private Map<String, KNNStat<?>> buildStatsMap() {
addEngineStats(builder);
addScriptStats(builder);
addModelStats(builder);
addGraphStats(builder);
return builder.build();
}

Expand Down Expand Up @@ -169,4 +171,31 @@ private void addModelStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getTrainingSizeAsPercentage))
);
}

private void addGraphStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
builder.put(StatNames.GRAPH_STATS.getName(), new KNNStat<>(false, new Supplier<Map<String, Map<String, Object>>>() {
@Override
public Map<String, Map<String, Object>> get() {
return createGraphStatsMap();
}
}));
}

private Map<String, Map<String, Object>> createGraphStatsMap() {
Map<String, Object> mergeMap = new HashMap<>();
mergeMap.put(KNNGraphValue.MERGE_CURRENT_OPERATIONS.getName(), KNNGraphValue.MERGE_CURRENT_OPERATIONS.getValue());
mergeMap.put(KNNGraphValue.MERGE_CURRENT_DOCS.getName(), KNNGraphValue.MERGE_CURRENT_DOCS.getValue());
mergeMap.put(KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.getName(), KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_OPERATIONS.getName(), KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getName(), KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_DOCS.getName(), KNNGraphValue.MERGE_TOTAL_DOCS.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getName(), KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
Map<String, Object> refreshMap = new HashMap<>();
refreshMap.put(KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getName(), KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue());
refreshMap.put(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getName(), KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue());
Map<String, Map<String, Object>> graphStatsMap = new HashMap<>();
graphStatsMap.put(StatNames.MERGE.getName(), mergeMap);
graphStatsMap.put(StatNames.REFRESH.getName(), refreshMap);
return graphStatsMap;
}
}
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/plugin/stats/StatNames.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ public enum StatNames {
TRAINING_MEMORY_USAGE("training_memory_usage"),
TRAINING_MEMORY_USAGE_PERCENTAGE("training_memory_usage_percentage"),
SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName()),
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName());
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()),
GRAPH_STATS("graph_stats"),
REFRESH("refresh"),
MERGE("merge");

private String name;

Expand Down
Loading

0 comments on commit 9e3e046

Please sign in to comment.