Skip to content

Commit

Permalink
more corrective factor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
john-wagster committed Sep 5, 2024
1 parent a17f0fd commit 2fd2c3a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,7 @@ public static short decodeClusterIdFromByte(byte bClusterId) {
return (short) Byte.toUnsignedInt(bClusterId);
}

public abstract float getDistanceToCentroid() throws IOException;

public abstract float getMagnitude() throws IOException;

public abstract float getOOQ() throws IOException;

public abstract float getNormOC() throws IOException;

public abstract float getODotC() throws IOException;
public abstract float[] getCorrectiveTerms();

public abstract byte[] vectorValue() throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,8 @@ private void writeBinarizedVectors(
float[] corrections =
scalarQuantizer.quantizeForIndex(v, vector, clusterCenters[clusterId]);
binarizedVectorData.writeBytes(vector, vector.length);
// FIXME: handle of sim types like MIP such as COSINE?
if (scalarQuantizer.getSimilarity() == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
correctionsBuffer.putFloat(corrections[2]);
} else {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
for (int i = 0; i < corrections.length; i++) {
correctionsBuffer.putFloat(corrections[i]);
}
binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length);
correctionsBuffer.rewind();
Expand All @@ -260,14 +254,8 @@ private void writeBinarizedVectors(
for (float[] v : fieldData.getVectors()) {
float[] corrections = scalarQuantizer.quantizeForIndex(v, vector, clusterCenter);
binarizedVectorData.writeBytes(vector, vector.length);
// FIXME: handle of sim types like MIP such as COSINE?
if (scalarQuantizer.getSimilarity() == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
correctionsBuffer.putFloat(corrections[2]);
} else {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
for (int i = 0; i < corrections.length; i++) {
correctionsBuffer.putFloat(corrections[i]);
}
binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length);
correctionsBuffer.rewind();
Expand Down Expand Up @@ -335,14 +323,8 @@ private void writeSortedBinarizedVectors(
float[] corrections =
scalarQuantizer.quantizeForIndex(v, vector, clusterCenters[clusterId]);
binarizedVectorData.writeBytes(vector, vector.length);
// FIXME: handle of sim types like MIP such as COSINE?
if (scalarQuantizer.getSimilarity() == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
correctionsBuffer.putFloat(corrections[2]);
} else {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
for (int i = 0; i < corrections.length; i++) {
correctionsBuffer.putFloat(corrections[i]);
}
binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length);
correctionsBuffer.rewind();
Expand All @@ -353,14 +335,8 @@ private void writeSortedBinarizedVectors(
float[] v = fieldData.getVectors().get(ordinal);
float[] corrections = scalarQuantizer.quantizeForIndex(v, vector, clusterCenter);
binarizedVectorData.writeBytes(vector, vector.length);
// FIXME: handle of sim types like MIP such as COSINE?
if (scalarQuantizer.getSimilarity() == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
correctionsBuffer.putFloat(corrections[2]);
} else {
correctionsBuffer.putFloat(corrections[0]);
correctionsBuffer.putFloat(corrections[1]);
for (int i = 0; i < corrections.length; i++) {
correctionsBuffer.putFloat(corrections[i]);
}
binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length);
correctionsBuffer.rewind();
Expand Down Expand Up @@ -575,7 +551,7 @@ static void writeQueryBinarizedVectorData(
correctionsBuffer.putFloat(factors.lower());
correctionsBuffer.putFloat(factors.width());

// FIXME: handle other similarity types here like COSINE
// FIXME: handle other similarity types?
if (quantizer.getSimilarity() == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
correctionsBuffer.putFloat(factors.normVmC());
correctionsBuffer.putFloat(factors.vDotC());
Expand Down Expand Up @@ -607,15 +583,10 @@ static DocsWithFieldSet writeBinarizedVectorData(
if (vectorOrdToClusterOrdWriter != null) {
vectorOrdToClusterOrdWriter.add(binarizedByteVectorValues.clusterId());
}
// FIXME: handle other similarity functions the same as MIP such as COSINE
// TODO handle quantization output correctly
if (similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
output.writeInt(Float.floatToIntBits(binarizedByteVectorValues.getOOQ()));
output.writeInt(Float.floatToIntBits(binarizedByteVectorValues.getNormOC()));
output.writeInt(Float.floatToIntBits(binarizedByteVectorValues.getODotC()));
} else {
output.writeInt(Float.floatToIntBits(binarizedByteVectorValues.getDistanceToCentroid()));
output.writeInt(Float.floatToIntBits(binarizedByteVectorValues.getMagnitude()));
float[] corrections = binarizedByteVectorValues.getCorrectiveTerms();
for (int i = 0; i < corrections.length; i++) {
output.writeInt(Float.floatToIntBits(corrections[i]));
}
docsWithField.add(docV);
}
Expand Down Expand Up @@ -1175,29 +1146,8 @@ public short clusterId() {
return clusterId;
}

@Override
public float getDistanceToCentroid() {
return corrections[0];
}

@Override
public float getMagnitude() {
return corrections[1];
}

@Override
public float getOOQ() {
return corrections[0];
}

@Override
public float getNormOC() {
return corrections[1];
}

@Override
public float getODotC() {
return corrections[2];
public float[] getCorrectiveTerms() {
return corrections;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,28 +113,8 @@ public byte[] vectorValue(int targetOrd) throws IOException {
}

@Override
public float getDistanceToCentroid() {
return correctiveValues[0];
}

@Override
public float getMagnitude() {
return correctiveValues[1];
}

@Override
public float getOOQ() {
return correctiveValues[0];
}

@Override
public float getNormOC() {
return correctiveValues[1];
}

@Override
public float getODotC() {
return correctiveValues[2];
public float[] getCorrectiveTerms() {
return correctiveValues;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,9 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
float[] corrections =
quantizer.quantizeForIndex(vectorValues.vectorValue(), expectedVector, centroid);
assertArrayEquals(expectedVector, qvectorValues.vectorValue());
assertEquals(corrections[0], qvectorValues.getOOQ(), 0.00001f);
assertEquals(corrections[1], qvectorValues.getNormOC(), 0.00001f);
if (corrections.length == 3) {
assertEquals(corrections[2], qvectorValues.getODotC(), 0.00001f);
assertEquals(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, similarityFunction);
} else {
assertEquals(2, corrections.length);
assertEquals(corrections.length, qvectorValues.getCorrectiveTerms().length);
for (int i = 0; i < corrections.length; i++) {
assertEquals(corrections[i], qvectorValues.getCorrectiveTerms()[i], 0.00001f);
}
}
}
Expand Down

0 comments on commit 2fd2c3a

Please sign in to comment.