forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Changing serialization for knn vector from single array object to col…
…lection of floats (opensearch-project#253) * Changing serialization for knn vector from single array object to collection of floats rev2: * Addressing PR comments: - added getDefaultSerializer to Factory - moved SerializationMode enum to a separate file - added javadocs and comments - adjust format, added missing endline characters rev3: * Addressing multiple review comments: - replace Vector by KNNVector in class names and variables - fixed method names in Serializer interface - replace number of bytes in float from number to constant rev4: * Moving new classes under index.codec.util rev5: * Addressing multiple review comments: - rework factory method getSerializerByStreamContent - added test case for stream of unsupported content - removed exceptions from Serializer interface method's signatures, changed it to unchecked runtime exception - simplify license header in new classes Signed-off-by: Martin Gaievski <gaievski@amazon.com>
- Loading branch information
1 parent
d0d22ae
commit 2a78a82
Showing
13 changed files
with
355 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsArraySerializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.codec.util; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.ObjectInputStream; | ||
import java.io.ObjectOutputStream; | ||
|
||
/** | ||
* Class implements KNNVectorSerializer based on standard Java serialization/deserialization as a single array object | ||
*/ | ||
public class KNNVectorAsArraySerializer implements KNNVectorSerializer { | ||
@Override | ||
public byte[] floatToByteArray(float[] input) { | ||
byte[] bytes; | ||
try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); | ||
ObjectOutputStream objectStream = new ObjectOutputStream(byteStream);) { | ||
objectStream.writeObject(input); | ||
bytes = byteStream.toByteArray(); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
return bytes; | ||
} | ||
|
||
@Override | ||
public float[] byteToFloatArray(ByteArrayInputStream byteStream) { | ||
try { | ||
final ObjectInputStream objectStream = new ObjectInputStream(byteStream); | ||
final float[] vector = (float[]) objectStream.readObject(); | ||
return vector; | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} catch (ClassNotFoundException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
} |
40 changes: 40 additions & 0 deletions
40
...ain/java/org/opensearch/knn/index/codec/util/KNNVectorAsCollectionOfFloatsSerializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.codec.util; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.nio.ByteBuffer; | ||
import java.nio.ByteOrder; | ||
import java.util.stream.IntStream; | ||
|
||
/** | ||
* Class implements KNNVectorSerializer based on serialization/deserialization of array as collection of individual numbers | ||
*/ | ||
public class KNNVectorAsCollectionOfFloatsSerializer implements KNNVectorSerializer { | ||
private static final int BYTES_IN_FLOAT = 4; | ||
|
||
@Override | ||
public byte[] floatToByteArray(float[] input) { | ||
final ByteBuffer bb = ByteBuffer.allocate(input.length * BYTES_IN_FLOAT).order(ByteOrder.BIG_ENDIAN); | ||
IntStream.range(0, input.length).forEach((index) -> bb.putFloat(input[index])); | ||
byte[] bytes = new byte[bb.flip().limit()]; | ||
bb.get(bytes); | ||
return bytes; | ||
} | ||
|
||
@Override | ||
public float[] byteToFloatArray(ByteArrayInputStream byteStream) { | ||
if (byteStream == null || byteStream.available() % BYTES_IN_FLOAT != 0) { | ||
throw new IllegalArgumentException("Byte stream cannot be deserialized to array of floats"); | ||
} | ||
final byte[] vectorAsByteArray = new byte[byteStream.available()]; | ||
byteStream.read(vectorAsByteArray, 0, byteStream.available()); | ||
final int sizeOfFloatArray = vectorAsByteArray.length / BYTES_IN_FLOAT; | ||
final float[] vector = new float[sizeOfFloatArray]; | ||
ByteBuffer.wrap(vectorAsByteArray).asFloatBuffer().get(vector); | ||
return vector; | ||
} | ||
} |
28 changes: 28 additions & 0 deletions
28
src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.codec.util; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.IOException; | ||
|
||
/** | ||
* Interface abstracts the vector serializer object that is responsible for serialization and de-serialization of k-NN vector | ||
*/ | ||
public interface KNNVectorSerializer { | ||
/** | ||
* Serializes array of floats to array of bytes | ||
* @param input array that will be converted | ||
* @return array of bytes that contains serialized input array | ||
*/ | ||
byte[] floatToByteArray(float[] input); | ||
|
||
/** | ||
* Deserializes all bytes from the stream to array of floats | ||
* @param byteStream stream of bytes that will be used for deserialization to array of floats | ||
* @return array of floats deserialized from the stream | ||
*/ | ||
float[] byteToFloatArray(ByteArrayInputStream byteStream); | ||
} |
90 changes: 90 additions & 0 deletions
90
src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.codec.util; | ||
|
||
import com.google.common.collect.ImmutableMap; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ObjectStreamConstants; | ||
import java.util.Arrays; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.knn.index.codec.util.SerializationMode.ARRAY; | ||
import static org.opensearch.knn.index.codec.util.SerializationMode.COLLECTION_OF_FLOATS; | ||
|
||
/** | ||
* Class abstracts Factory for KNNVectorSerializer implementations. Exact implementation constructed and returned based on | ||
* either content of the byte array or directly based on serialization type. | ||
*/ | ||
public class KNNVectorSerializerFactory { | ||
private static Map<SerializationMode, KNNVectorSerializer> VECTOR_SERIALIZER_BY_TYPE = ImmutableMap.of( | ||
ARRAY, new KNNVectorAsArraySerializer(), | ||
COLLECTION_OF_FLOATS, new KNNVectorAsCollectionOfFloatsSerializer() | ||
); | ||
|
||
private static final int ARRAY_HEADER_OFFSET = 27; | ||
private static final int BYTES_IN_FLOAT = 4; | ||
private static final int BITS_IN_ONE_BYTE = 8; | ||
|
||
/** | ||
* Array represents first 6 bytes of the byte stream header as per Java serialization protocol described in details | ||
* <a href="https://docs.oracle.com/javase/8/docs/platform/serialization/spec/protocol.html">here</a>. | ||
*/ | ||
private static final byte[] SERIALIZATION_PROTOCOL_HEADER_PREFIX = new byte[] { | ||
highByte(ObjectStreamConstants.STREAM_MAGIC), | ||
lowByte(ObjectStreamConstants.STREAM_MAGIC), | ||
highByte(ObjectStreamConstants.STREAM_VERSION), | ||
lowByte(ObjectStreamConstants.STREAM_VERSION), | ||
ObjectStreamConstants.TC_ARRAY, | ||
ObjectStreamConstants.TC_CLASSDESC | ||
}; | ||
|
||
public static KNNVectorSerializer getSerializerBySerializationMode(final SerializationMode serializationMode) { | ||
return VECTOR_SERIALIZER_BY_TYPE.getOrDefault(serializationMode, new KNNVectorAsCollectionOfFloatsSerializer()); | ||
} | ||
|
||
public static KNNVectorSerializer getDefaultSerializer() { | ||
return getSerializerBySerializationMode(COLLECTION_OF_FLOATS); | ||
} | ||
|
||
public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayInputStream byteStream) { | ||
final SerializationMode serializationMode = serializerModeFromStream(byteStream); | ||
return getSerializerBySerializationMode(serializationMode); | ||
} | ||
|
||
private static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { | ||
int numberOfAvailableBytesInStream = byteStream.available(); | ||
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) { | ||
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS); | ||
} | ||
final byte[] byteArray = new byte[SERIALIZATION_PROTOCOL_HEADER_PREFIX.length]; | ||
byteStream.read(byteArray, 0, SERIALIZATION_PROTOCOL_HEADER_PREFIX.length); | ||
byteStream.reset(); | ||
//checking if stream protocol grammar in header is valid for serialized array | ||
if (Arrays.equals(SERIALIZATION_PROTOCOL_HEADER_PREFIX, byteArray)) { | ||
int numberOfAvailableBytesAfterHeader = numberOfAvailableBytesInStream - ARRAY_HEADER_OFFSET; | ||
return getSerializerOrThrowError(numberOfAvailableBytesAfterHeader, ARRAY); | ||
} | ||
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS); | ||
} | ||
|
||
private static SerializationMode getSerializerOrThrowError(int numberOfRemainingBytes, final SerializationMode serializationMode) { | ||
if (numberOfRemainingBytes % BYTES_IN_FLOAT == 0) { | ||
return serializationMode; | ||
} | ||
throw new IllegalArgumentException(String.format("Byte stream cannot be deserialized to array of floats due to invalid length %d", numberOfRemainingBytes)); | ||
} | ||
|
||
private static byte highByte(short shortValue) { | ||
return (byte) (shortValue>> BITS_IN_ONE_BYTE); | ||
} | ||
|
||
private static byte lowByte(short shortValue) { | ||
return (byte) shortValue; | ||
} | ||
|
||
|
||
} |
10 changes: 10 additions & 0 deletions
10
src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.codec.util; | ||
|
||
public enum SerializationMode { | ||
ARRAY, COLLECTION_OF_FLOATS | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.