Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor KNNCodec to use new extension point #319

Merged
merged 12 commits into from
Mar 18, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.index.codec.BinaryDocValuesSub;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.util.BytesRef;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat;
import org.opensearch.knn.common.KNNConstants;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundDirectory;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.index.SegmentInfo;
Expand All @@ -26,22 +24,33 @@
*/
public class KNN80CompoundFormat extends CompoundFormat {

private final Logger logger = LogManager.getLogger(KNN80CompoundFormat.class);
private final CompoundFormat delegate;


public KNN80CompoundFormat() {
this.delegate = new Lucene50CompoundFormat();
}

/**
* Constructor that takes a delegate to handle non-overridden methods
*
* @param delegate CompoundFormat that will handle non-overridden methods
*/
public KNN80CompoundFormat(CompoundFormat delegate) {
this.delegate = delegate;
}

@Override
public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException {
return Codec.getDefault().compoundFormat().getCompoundReader(dir, si, context);
return delegate.getCompoundReader(dir, si, context);
}

@Override
public void write(Directory dir, SegmentInfo si, IOContext context) throws IOException {
for (KNNEngine knnEngine : KNNEngine.values()) {
writeEngineFiles(dir, si, context, knnEngine.getExtension());
}
Codec.getDefault().compoundFormat().write(dir, si, context);
delegate.write(dir, si, context);
}

private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,82 +67,84 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
@Override
public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
delegatee.addBinaryField(field, valuesProducer);
addKNNBinaryField(field, valuesProducer);
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(field, valuesProducer);
}
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {

// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}

// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();

// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;
// Create library index either from model or from scratch
String engineFileName;
String indexPath;
String tmpEngineFileName;

if (field.attributes().containsKey(MODEL_ID)) {
if (field.attributes().containsKey(MODEL_ID)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need null check for attributes before calling contains?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);

KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();
KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}
if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}

createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {
createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {

// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);
// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}
createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}

/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
//TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
/*
* Adds Footer to the serialized graph
* 1. Copies the serialized graph to new file.
* 2. Adds Footer to the new file.
*
* We had to create new file here because adding footer directly to the
* existing file will miss calculating checksum for the serialized graph
* bytes and result in index corruption issues.
*/
//TODO: I think this can be refactored to avoid this copy and then write
// https://github.com/opendistro-for-elasticsearch/k-NN/issues/330
try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context);
IndexOutput os = state.directory.createOutput(engineFileName, state.context)) {
os.copyBytes(is, is.length());
CodecUtil.writeFooter(os);
} catch (Exception ex) {
KNNCounter.GRAPH_INDEX_ERRORS.increment();
throw new RuntimeException("[KNN] Adding footer to serialized graph failed: " + ex);
} finally {
IOUtils.deleteFilesIgnoringExceptions(state.directory, tmpEngineFileName);
}
}

Expand Down Expand Up @@ -214,7 +216,7 @@ public void merge(MergeState mergeState) {
assert mergeState.mergeFieldInfos != null;
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
DocValuesType type = fieldInfo.getDocValuesType();
if (type == DocValuesType.BINARY) {
if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.DocValuesProducer;
Expand All @@ -19,11 +17,21 @@
* Encodes/Decodes per document values
*/
public class KNN80DocValuesFormat extends DocValuesFormat {
private final Logger logger = LogManager.getLogger(KNN80DocValuesFormat.class);
private final DocValuesFormat delegate = DocValuesFormat.forName(KNN80Codec.LUCENE_80);
private final DocValuesFormat delegate;

public KNN80DocValuesFormat() {
super(KNN80Codec.LUCENE_80);
this.delegate = DocValuesFormat.forName(KNN80Codec.LUCENE_80);
}

/**
* Constructor that takes delegate in order to handle non-overridden methods
*
* @param delegate DocValuesFormat to handle non-overridden methods
*/
public KNN80DocValuesFormat(DocValuesFormat delegate) {
super(delegate.getName());
this.delegate = delegate;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.index.codec.BinaryDocValuesSub;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
Expand All @@ -22,7 +22,7 @@
*/
class KNN80DocValuesReader extends EmptyDocValuesProducer {

private MergeState mergeState;
private final MergeState mergeState;

KNN80DocValuesReader(MergeState mergeState) {
this.mergeState = mergeState;
Expand Down
Loading