Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing serialization for knn vector from single array object to collection of floats #253

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.index.codec.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;

/**
* KNNIndexShard wraps IndexShard and adds methods to perform k-NN related operations against the shard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import org.apache.lucene.util.BytesRef;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

Expand Down Expand Up @@ -45,12 +46,11 @@ public float[] getValue() {
try {
BytesRef value = binaryDocValues.binaryValue();
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
ObjectInputStream objectStream = new ObjectInputStream(byteStream);
return (float[]) objectStream.readObject();
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
return vector;
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
} catch (ClassNotFoundException e) {
throw new RuntimeException((e));
}
}

Expand Down
19 changes: 5 additions & 14 deletions src/main/java/org/opensearch/knn/index/VectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,19 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.util.BytesRef;

import java.io.ByteArrayOutputStream;
import java.io.ObjectOutputStream;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

public class VectorField extends Field {

public VectorField(String name, float[] value, IndexableFieldType type) {
super(name, new BytesRef(), type);
try {
this.setBytesValue(floatToByte(value));
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer();
final byte[] floatToByte = vectorSerializer.floatToByteArray(value);
this.setBytesValue(floatToByte);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

public static byte[] floatToByte(float[] floats) throws Exception {
byte[] bytes;
try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
ObjectOutputStream objectStream = new ObjectOutputStream(byteStream);) {
objectStream.writeObject(floats);
bytes = byteStream.toByteArray();
}
return bytes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.codec.KNNCodecUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.Codec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelCache;
Expand Down Expand Up @@ -46,7 +46,7 @@

import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.codec.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;

/**
* This class writes the KNN docvalues to the segments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;
package org.opensearch.knn.index.codec.util;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;

public class KNNCodecUtil {
Expand All @@ -34,12 +33,10 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
ArrayList<Integer> docIdList = new ArrayList<>();
for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length);
ObjectInputStream objectStream = new ObjectInputStream(byteStream)) {
float[] vector = (float[]) objectStream.readObject();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
vectorList.add(vector);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
docIdList.add(doc);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/**
* Class implements KNNVectorSerializer based on standard Java serialization/deserialization as a single array object
*/
public class KNNVectorAsArraySerializer implements KNNVectorSerializer {
@Override
public byte[] floatToByteArray(float[] input) {
byte[] bytes;
try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
ObjectOutputStream objectStream = new ObjectOutputStream(byteStream);) {
objectStream.writeObject(input);
bytes = byteStream.toByteArray();
} catch (IOException e) {
throw new RuntimeException(e);
}
return bytes;
}

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
try {
final ObjectInputStream objectStream = new ObjectInputStream(byteStream);
final float[] vector = (float[]) objectStream.readObject();
return vector;
} catch (IOException e) {
throw new RuntimeException(e);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.stream.IntStream;

/**
* Class implements KNNVectorSerializer based on serialization/deserialization of array as collection of individual numbers
*/
public class KNNVectorAsCollectionOfFloatsSerializer implements KNNVectorSerializer {
private static final int BYTES_IN_FLOAT = 4;

@Override
public byte[] floatToByteArray(float[] input) {
final ByteBuffer bb = ByteBuffer.allocate(input.length * BYTES_IN_FLOAT).order(ByteOrder.BIG_ENDIAN);
IntStream.range(0, input.length).forEach((index) -> bb.putFloat(input[index]));
byte[] bytes = new byte[bb.flip().limit()];
bb.get(bytes);
return bytes;
}

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
if (byteStream == null || byteStream.available() % BYTES_IN_FLOAT != 0) {
throw new IllegalArgumentException("Byte stream cannot be deserialized to array of floats");
}
final byte[] vectorAsByteArray = new byte[byteStream.available()];
byteStream.read(vectorAsByteArray, 0, byteStream.available());
final int sizeOfFloatArray = vectorAsByteArray.length / BYTES_IN_FLOAT;
final float[] vector = new float[sizeOfFloatArray];
ByteBuffer.wrap(vectorAsByteArray).asFloatBuffer().get(vector);
return vector;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
* Interface abstracts the vector serializer object that is responsible for serialization and de-serialization of k-NN vector
*/
public interface KNNVectorSerializer {
/**
* Serializes array of floats to array of bytes
* @param input array that will be converted
* @return array of bytes that contains serialized input array
*/
byte[] floatToByteArray(float[] input);

/**
* Deserializes all bytes from the stream to array of floats
* @param byteStream stream of bytes that will be used for deserialization to array of floats
* @return array of floats deserialized from the stream
*/
float[] byteToFloatArray(ByteArrayInputStream byteStream);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

import com.google.common.collect.ImmutableMap;

import java.io.ByteArrayInputStream;
import java.io.ObjectStreamConstants;
import java.util.Arrays;
import java.util.Map;

import static org.opensearch.knn.index.codec.util.SerializationMode.ARRAY;
import static org.opensearch.knn.index.codec.util.SerializationMode.COLLECTION_OF_FLOATS;

/**
* Class abstracts Factory for KNNVectorSerializer implementations. Exact implementation constructed and returned based on
* either content of the byte array or directly based on serialization type.
*/
public class KNNVectorSerializerFactory {
private static Map<SerializationMode, KNNVectorSerializer> VECTOR_SERIALIZER_BY_TYPE = ImmutableMap.of(
ARRAY, new KNNVectorAsArraySerializer(),
COLLECTION_OF_FLOATS, new KNNVectorAsCollectionOfFloatsSerializer()
);

private static final int ARRAY_HEADER_OFFSET = 27;
private static final int BYTES_IN_FLOAT = 4;
private static final int BITS_IN_ONE_BYTE = 8;

/**
* Array represents first 6 bytes of the byte stream header as per Java serialization protocol described in details
* <a href="https://docs.oracle.com/javase/8/docs/platform/serialization/spec/protocol.html">here</a>.
*/
private static final byte[] SERIALIZATION_PROTOCOL_HEADER_PREFIX = new byte[] {
highByte(ObjectStreamConstants.STREAM_MAGIC),
lowByte(ObjectStreamConstants.STREAM_MAGIC),
highByte(ObjectStreamConstants.STREAM_VERSION),
lowByte(ObjectStreamConstants.STREAM_VERSION),
ObjectStreamConstants.TC_ARRAY,
ObjectStreamConstants.TC_CLASSDESC
};

public static KNNVectorSerializer getSerializerBySerializationMode(final SerializationMode serializationMode) {
return VECTOR_SERIALIZER_BY_TYPE.getOrDefault(serializationMode, new KNNVectorAsCollectionOfFloatsSerializer());
}

public static KNNVectorSerializer getDefaultSerializer() {
return getSerializerBySerializationMode(COLLECTION_OF_FLOATS);
}

public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayInputStream byteStream) {
final SerializationMode serializationMode = serializerModeFromStream(byteStream);
return getSerializerBySerializationMode(serializationMode);
}

private static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
int numberOfAvailableBytesInStream = byteStream.available();
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) {
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
}
final byte[] byteArray = new byte[SERIALIZATION_PROTOCOL_HEADER_PREFIX.length];
byteStream.read(byteArray, 0, SERIALIZATION_PROTOCOL_HEADER_PREFIX.length);
byteStream.reset();
//checking if stream protocol grammar in header is valid for serialized array
if (Arrays.equals(SERIALIZATION_PROTOCOL_HEADER_PREFIX, byteArray)) {
int numberOfAvailableBytesAfterHeader = numberOfAvailableBytesInStream - ARRAY_HEADER_OFFSET;
return getSerializerOrThrowError(numberOfAvailableBytesAfterHeader, ARRAY);
}
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
}

private static SerializationMode getSerializerOrThrowError(int numberOfRemainingBytes, final SerializationMode serializationMode) {
if (numberOfRemainingBytes % BYTES_IN_FLOAT == 0) {
return serializationMode;
}
throw new IllegalArgumentException(String.format("Byte stream cannot be deserialized to array of floats due to invalid length %d", numberOfRemainingBytes));
}

private static byte highByte(short shortValue) {
return (byte) (shortValue>> BITS_IN_ONE_BYTE);
}

private static byte lowByte(short shortValue) {
return (byte) shortValue;
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.util;

public enum SerializationMode {
ARRAY, COLLECTION_OF_FLOATS
}
2 changes: 1 addition & 1 deletion src/test/java/org/opensearch/knn/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.index.codec.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;

import java.io.BufferedReader;
import java.io.FileReader;
Expand Down
Loading