Skip to content

Commit

Permalink
Changing serialization for knn vector from single array object to col…
Browse files Browse the repository at this point in the history
…lection of floats (opensearch-project#253)

* Changing serialization for knn vector from single array object to collection of floats

rev2:
* Addressing PR comments:
- added getDefaultSerializer to Factory
- moved SerializationMode enum to a separate file
- added javadocs and comments
- adjust format, added missing endline characters
rev3:
* Addressing multiple review comments:
- replace Vector by KNNVector in class names and variables
- fixed method names in Serializer interface
- replace number of bytes in float from number to constant
rev4:
* Moving new classes under index.codec.util
rev5:
* Addressing multiple review comments:
- rework factory method getSerializerByStreamContent
- added test case for stream of unsupported content
- removed exceptions from Serializer interface method's signatures, changed it to unchecked runtime exception
- simplify license header in new classes

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski authored Jan 13, 2022
1 parent d0d22ae commit 2a78a82
Show file tree
Hide file tree
Showing 13 changed files with 355 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.index.codec.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;

/**
* KNNIndexShard wraps IndexShard and adds methods to perform k-NN related operations against the shard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import org.apache.lucene.util.BytesRef;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

Expand Down Expand Up @@ -45,12 +46,11 @@ public float[] getValue() {
try {
BytesRef value = binaryDocValues.binaryValue();
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
ObjectInputStream objectStream = new ObjectInputStream(byteStream);
return (float[]) objectStream.readObject();
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
return vector;
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
} catch (ClassNotFoundException e) {
throw new RuntimeException((e));
}
}

Expand Down
19 changes: 5 additions & 14 deletions src/main/java/org/opensearch/knn/index/VectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,19 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.util.BytesRef;

import java.io.ByteArrayOutputStream;
import java.io.ObjectOutputStream;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

public class VectorField extends Field {

public VectorField(String name, float[] value, IndexableFieldType type) {
super(name, new BytesRef(), type);
try {
this.setBytesValue(floatToByte(value));
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer();
final byte[] floatToByte = vectorSerializer.floatToByteArray(value);
this.setBytesValue(floatToByte);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

public static byte[] floatToByte(float[] floats) throws Exception {
byte[] bytes;
try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
ObjectOutputStream objectStream = new ObjectOutputStream(byteStream);) {
objectStream.writeObject(floats);
bytes = byteStream.toByteArray();
}
return bytes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.codec.KNNCodecUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.Codec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelCache;
Expand Down Expand Up @@ -46,7 +46,7 @@

import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.codec.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;

/**
* This class writes the KNN docvalues to the segments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;
package org.opensearch.knn.index.codec.util;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;

public class KNNCodecUtil {
Expand All @@ -34,12 +33,10 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
ArrayList<Integer> docIdList = new ArrayList<>();
for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length);
ObjectInputStream objectStream = new ObjectInputStream(byteStream)) {
float[] vector = (float[]) objectStream.readObject();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
vectorList.add(vector);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
docIdList.add(doc);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/**
* Class implements KNNVectorSerializer based on standard Java serialization/deserialization as a single array object
*/
public class KNNVectorAsArraySerializer implements KNNVectorSerializer {
@Override
public byte[] floatToByteArray(float[] input) {
byte[] bytes;
try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
ObjectOutputStream objectStream = new ObjectOutputStream(byteStream);) {
objectStream.writeObject(input);
bytes = byteStream.toByteArray();
} catch (IOException e) {
throw new RuntimeException(e);
}
return bytes;
}

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
try {
final ObjectInputStream objectStream = new ObjectInputStream(byteStream);
final float[] vector = (float[]) objectStream.readObject();
return vector;
} catch (IOException e) {
throw new RuntimeException(e);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.stream.IntStream;

/**
* Class implements KNNVectorSerializer based on serialization/deserialization of array as collection of individual numbers
*/
public class KNNVectorAsCollectionOfFloatsSerializer implements KNNVectorSerializer {
private static final int BYTES_IN_FLOAT = 4;

@Override
public byte[] floatToByteArray(float[] input) {
final ByteBuffer bb = ByteBuffer.allocate(input.length * BYTES_IN_FLOAT).order(ByteOrder.BIG_ENDIAN);
IntStream.range(0, input.length).forEach((index) -> bb.putFloat(input[index]));
byte[] bytes = new byte[bb.flip().limit()];
bb.get(bytes);
return bytes;
}

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
if (byteStream == null || byteStream.available() % BYTES_IN_FLOAT != 0) {
throw new IllegalArgumentException("Byte stream cannot be deserialized to array of floats");
}
final byte[] vectorAsByteArray = new byte[byteStream.available()];
byteStream.read(vectorAsByteArray, 0, byteStream.available());
final int sizeOfFloatArray = vectorAsByteArray.length / BYTES_IN_FLOAT;
final float[] vector = new float[sizeOfFloatArray];
ByteBuffer.wrap(vectorAsByteArray).asFloatBuffer().get(vector);
return vector;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
* Interface abstracts the vector serializer object that is responsible for serialization and de-serialization of k-NN vector
*/
public interface KNNVectorSerializer {
/**
* Serializes array of floats to array of bytes
* @param input array that will be converted
* @return array of bytes that contains serialized input array
*/
byte[] floatToByteArray(float[] input);

/**
* Deserializes all bytes from the stream to array of floats
* @param byteStream stream of bytes that will be used for deserialization to array of floats
* @return array of floats deserialized from the stream
*/
float[] byteToFloatArray(ByteArrayInputStream byteStream);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import com.google.common.collect.ImmutableMap;

import java.io.ByteArrayInputStream;
import java.io.ObjectStreamConstants;
import java.util.Arrays;
import java.util.Map;

import static org.opensearch.knn.index.codec.util.SerializationMode.ARRAY;
import static org.opensearch.knn.index.codec.util.SerializationMode.COLLECTION_OF_FLOATS;

/**
* Class abstracts Factory for KNNVectorSerializer implementations. Exact implementation constructed and returned based on
* either content of the byte array or directly based on serialization type.
*/
public class KNNVectorSerializerFactory {
private static Map<SerializationMode, KNNVectorSerializer> VECTOR_SERIALIZER_BY_TYPE = ImmutableMap.of(
ARRAY, new KNNVectorAsArraySerializer(),
COLLECTION_OF_FLOATS, new KNNVectorAsCollectionOfFloatsSerializer()
);

private static final int ARRAY_HEADER_OFFSET = 27;
private static final int BYTES_IN_FLOAT = 4;
private static final int BITS_IN_ONE_BYTE = 8;

/**
* Array represents first 6 bytes of the byte stream header as per Java serialization protocol described in details
* <a href="https://docs.oracle.com/javase/8/docs/platform/serialization/spec/protocol.html">here</a>.
*/
private static final byte[] SERIALIZATION_PROTOCOL_HEADER_PREFIX = new byte[] {
highByte(ObjectStreamConstants.STREAM_MAGIC),
lowByte(ObjectStreamConstants.STREAM_MAGIC),
highByte(ObjectStreamConstants.STREAM_VERSION),
lowByte(ObjectStreamConstants.STREAM_VERSION),
ObjectStreamConstants.TC_ARRAY,
ObjectStreamConstants.TC_CLASSDESC
};

public static KNNVectorSerializer getSerializerBySerializationMode(final SerializationMode serializationMode) {
return VECTOR_SERIALIZER_BY_TYPE.getOrDefault(serializationMode, new KNNVectorAsCollectionOfFloatsSerializer());
}

public static KNNVectorSerializer getDefaultSerializer() {
return getSerializerBySerializationMode(COLLECTION_OF_FLOATS);
}

public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayInputStream byteStream) {
final SerializationMode serializationMode = serializerModeFromStream(byteStream);
return getSerializerBySerializationMode(serializationMode);
}

private static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
int numberOfAvailableBytesInStream = byteStream.available();
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) {
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
}
final byte[] byteArray = new byte[SERIALIZATION_PROTOCOL_HEADER_PREFIX.length];
byteStream.read(byteArray, 0, SERIALIZATION_PROTOCOL_HEADER_PREFIX.length);
byteStream.reset();
//checking if stream protocol grammar in header is valid for serialized array
if (Arrays.equals(SERIALIZATION_PROTOCOL_HEADER_PREFIX, byteArray)) {
int numberOfAvailableBytesAfterHeader = numberOfAvailableBytesInStream - ARRAY_HEADER_OFFSET;
return getSerializerOrThrowError(numberOfAvailableBytesAfterHeader, ARRAY);
}
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
}

private static SerializationMode getSerializerOrThrowError(int numberOfRemainingBytes, final SerializationMode serializationMode) {
if (numberOfRemainingBytes % BYTES_IN_FLOAT == 0) {
return serializationMode;
}
throw new IllegalArgumentException(String.format("Byte stream cannot be deserialized to array of floats due to invalid length %d", numberOfRemainingBytes));
}

private static byte highByte(short shortValue) {
return (byte) (shortValue>> BITS_IN_ONE_BYTE);
}

private static byte lowByte(short shortValue) {
return (byte) shortValue;
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

public enum SerializationMode {
ARRAY, COLLECTION_OF_FLOATS
}
2 changes: 1 addition & 1 deletion src/test/java/org/opensearch/knn/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.index.codec.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;

import java.io.BufferedReader;
import java.io.FileReader;
Expand Down
Loading

0 comments on commit 2a78a82

Please sign in to comment.