Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-401: Floating point vectors should do an approximate comparison… #223

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion java/tools/src/main/java/org/apache/arrow/tools/Integration.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import org.apache.arrow.vector.file.json.JsonFileReader;
import org.apache.arrow.vector.file.json.JsonFileWriter;
import org.apache.arrow.vector.schema.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.ArrowType.FloatingPoint;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.cli.CommandLine;
Expand Down Expand Up @@ -247,14 +249,61 @@ private static void compare(VectorSchemaRoot arrowRoot, VectorSchemaRoot jsonRoo
for (int j = 0; j < valueCount; j++) {
Object arrow = arrowVector.getAccessor().getObject(j);
Object json = jsonVector.getAccessor().getObject(j);
if (!Objects.equal(arrow, json)) {
if (!equals(field.getType(), arrow, json)) {
throw new IllegalArgumentException(
"Different values in column:\n" + field + " at index " + j + ": " + arrow + " != " + json);
}
}
}
}

private static boolean equals(ArrowType type, final Object arrow, final Object json) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my proposed functions in the JIRA issue (ARROW-401). The code here is insufficient for a number of situations.

if (type instanceof ArrowType.FloatingPoint) {
FloatingPoint fpType = (FloatingPoint) type;
switch (fpType.getPrecision()) {
case DOUBLE:
return equalEnough((Double)arrow, (Double)json);
case SINGLE:
return equalEnough((Float)arrow, (Float)json);
case HALF:
default:
throw new UnsupportedOperationException("unsupported precision: " + fpType);
}
}
return Objects.equal(arrow, json);
}

/**
 * Compares two boxed floats for approximate equality using a relative tolerance.
 *
 * <p>Two nulls are equal; a null never equals a non-null value. NaN only matches NaN, and an
 * infinity only matches an infinity of the same sign. Otherwise the absolute difference is
 * divided by the magnitude of the mean of both values (or by 1 when that mean is exactly zero)
 * and the result must be strictly below 1.0E-6.
 *
 * @param f1 first value, may be null
 * @param f2 second value, may be null
 * @return true if both values are considered equal
 */
static boolean equalEnough(Float f1, Float f2) {
  if (f1 == null || f2 == null) {
    return f1 == null && f2 == null;
  }
  if (f1.isNaN()) {
    return f2.isNaN();
  }
  if (f1.isInfinite()) {
    return f2.isInfinite() && Math.signum(f1) == Math.signum(f2);
  }
  // f1 is finite from here on. The sum of two large finite values can still overflow to
  // infinity, which would turn the scaled difference into 0 and report clearly different
  // values as equal. When that happens, average the halves instead: that cannot overflow
  // for finite inputs, and an infinite f2 still yields an infinite average (=> NaN => false).
  float sum = f1 + f2;
  float average = Math.abs(Float.isInfinite(sum) ? f1 / 2 + f2 / 2 : sum / 2);
  float differenceScaled = Math.abs(f1 - f2) / (average == 0.0f ? 1f : average);
  return differenceScaled < 1.0E-6f;
}

/**
 * Compares two boxed doubles for approximate equality using a relative tolerance.
 *
 * <p>Two nulls are equal; a null never equals a non-null value. NaN only matches NaN, and an
 * infinity only matches an infinity of the same sign. Otherwise the absolute difference is
 * divided by the magnitude of the mean of both values (or by 1 when that mean is exactly zero)
 * and the result must be strictly below 1.0E-12.
 *
 * @param f1 first value, may be null
 * @param f2 second value, may be null
 * @return true if both values are considered equal
 */
static boolean equalEnough(Double f1, Double f2) {
  if (f1 == null || f2 == null) {
    return f1 == null && f2 == null;
  }
  if (f1.isNaN()) {
    return f2.isNaN();
  }
  if (f1.isInfinite()) {
    return f2.isInfinite() && Math.signum(f1) == Math.signum(f2);
  }
  // f1 is finite from here on. The sum of two large finite values can still overflow to
  // infinity, which would turn the scaled difference into 0 and report clearly different
  // values as equal. When that happens, average the halves instead: that cannot overflow
  // for finite inputs, and an infinite f2 still yields an infinite average (=> NaN => false).
  double sum = f1 + f2;
  double average = Math.abs(Double.isInfinite(sum) ? f1 / 2 + f2 / 2 : sum / 2);
  double differenceScaled = Math.abs(f1 - f2) / (average == 0.0d ? 1d : average);
  return differenceScaled < 1.0E-12d;
}


private static void compareSchemas(Schema jsonSchema, Schema arrowSchema) {
if (!arrowSchema.equals(jsonSchema)) {
throw new IllegalArgumentException("Different schemas:\n" + arrowSchema + "\n" + jsonSchema);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
import static org.apache.arrow.tools.ArrowFileTestFixtures.write;
import static org.apache.arrow.tools.ArrowFileTestFixtures.writeData;
import static org.apache.arrow.tools.ArrowFileTestFixtures.writeInput;
import static org.apache.arrow.tools.Integration.equalEnough;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.BufferedReader;
Expand All @@ -39,9 +43,9 @@
import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter;
import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter;
import org.apache.arrow.vector.complex.writer.BigIntWriter;
import org.apache.arrow.vector.complex.writer.Float8Writer;
import org.apache.arrow.vector.complex.writer.IntWriter;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -121,7 +125,7 @@ public void testJSONRoundTripWithVariableWidth() throws Exception {
String i, o;
int j = 0;
while ((i = orig.readLine()) != null && (o = rt.readLine()) != null) {
Assert.assertEquals("line: " + j, i, o);
assertEquals("line: " + j, i, o);
++j;
}
}
Expand All @@ -142,6 +146,33 @@ private BufferedReader readNormalized(File f) throws IOException {
}


/**
 * Validation must tolerate tiny differences in floating point representation: two arrow files
 * holding the same values up to rounding noise are expected to validate against the same JSON.
 */
@Test
public void testFloat() throws Exception {
  File referenceArrowFile = testFolder.newFile("testValidFloatIn.arrow");
  File nearlyEqualArrowFile = testFolder.newFile("testAlsoValidFloatIn.arrow");
  File jsonFile = testFolder.newFile("testValidOut.json");
  // only the path is needed up front; the file itself is removed again before the conversion
  jsonFile.delete();

  // reference arrow file
  writeInputFloat(referenceArrowFile, allocator, 912.4140000000002, 912.414);
  // second arrow file: same two values swapped, so each row differs only by rounding noise
  writeInputFloat(nearlyEqualArrowFile, allocator, 912.414, 912.4140000000002);

  Integration integration = new Integration();

  // convert the reference arrow file to json
  integration.run(new String[] {
      "-arrow", referenceArrowFile.getAbsolutePath(),
      "-json", jsonFile.getAbsolutePath(),
      "-command", Command.ARROW_TO_JSON.name()});

  // validate the second arrow file against that json; no exception is expected here,
  // which is the whole point of this test
  integration.run(new String[] {
      "-arrow", nearlyEqualArrowFile.getAbsolutePath(),
      "-json", jsonFile.getAbsolutePath(),
      "-command", Command.VALIDATE.name()});
}

@Test
public void testInvalid() throws Exception {
File testValidInFile = testFolder.newFile("testValidIn.arrow");
Expand All @@ -167,12 +198,28 @@ public void testInvalid() throws Exception {
integration.run(args3);
fail("should have failed");
} catch (IllegalArgumentException e) {
Assert.assertTrue(e.getMessage(), e.getMessage().contains("Different values in column"));
Assert.assertTrue(e.getMessage(), e.getMessage().contains("999"));
assertTrue(e.getMessage(), e.getMessage().contains("Different values in column"));
assertTrue(e.getMessage(), e.getMessage().contains("999"));
}

}

/**
 * Fills a float8 column named "float" (under a "root" map) with the given values, one per row,
 * and passes the resulting "root" child vector together with {@code testInFile} to the
 * {@code write} fixture.
 *
 * @param testInFile target file handed to the {@code write} fixture
 * @param allocator parent allocator; a child allocator is created from it and closed on exit
 * @param f values to store, in row order
 */
static void writeInputFloat(File testInFile, BufferAllocator allocator, double... f) throws FileNotFoundException, IOException {
  try (
      BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE);
      MapVector parent = new MapVector("parent", vectorAllocator, null)) {
    ComplexWriter writer = new ComplexWriterImpl("root", parent);
    Float8Writer column = writer.rootAsMap().float8("float");
    int row = 0;
    while (row < f.length) {
      column.setPosition(row);
      column.writeFloat8(f[row]);
      row++;
    }
    writer.setValueCount(f.length);
    write(parent.getChild("root"), testInFile);
  }
}

static void writeInput2(File testInFile, BufferAllocator allocator) throws FileNotFoundException, IOException {
int count = ArrowFileTestFixtures.COUNT;
try (
Expand All @@ -192,4 +239,33 @@ static void writeInput2(File testInFile, BufferAllocator allocator) throws FileN
}
}

/**
 * Checks both equalEnough overloads: rounding noise, nulls, extreme magnitudes,
 * infinities and NaN.
 */
@Test
public void testFloatComp() {
  // --- float overload ---
  // rounding noise is tolerated, in either argument order
  assertTrue(equalEnough(912.4140000000002F, 912.414F));
  assertTrue(equalEnough(912.414F, 912.4140000000002F));
  // nulls
  assertTrue(equalEnough((Float)null, null));
  // extreme magnitudes
  assertFalse(equalEnough(Float.MAX_VALUE, Float.MIN_VALUE));
  assertFalse(equalEnough(Float.MIN_VALUE, Float.MAX_VALUE));
  assertTrue(equalEnough(Float.MAX_VALUE, Float.MAX_VALUE));
  assertTrue(equalEnough(Float.MIN_VALUE, Float.MIN_VALUE));
  // infinities must agree in sign
  assertTrue(equalEnough(Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY));
  assertFalse(equalEnough(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY));
  // NaN only matches NaN
  assertTrue(equalEnough(Float.NaN, Float.NaN));
  assertFalse(equalEnough(1.0F, Float.NaN));

  // --- double overload ---
  // rounding noise is tolerated, in either argument order
  assertTrue(equalEnough(912.4140000000002D, 912.414D));
  assertTrue(equalEnough(912.414D, 912.4140000000002D));
  // a larger deviation is rejected
  assertFalse(equalEnough(912.414D, 912.4140001D));
  // nulls
  assertTrue(equalEnough((Double)null, null));
  assertFalse(equalEnough(null, 912.414D));
  assertFalse(equalEnough(912.414D, null));
  // extreme magnitudes
  assertFalse(equalEnough(Double.MAX_VALUE, Double.MIN_VALUE));
  assertFalse(equalEnough(Double.MIN_VALUE, Double.MAX_VALUE));
  assertTrue(equalEnough(Double.MAX_VALUE, Double.MAX_VALUE));
  assertTrue(equalEnough(Double.MIN_VALUE, Double.MIN_VALUE));
  // infinities must agree in sign
  assertTrue(equalEnough(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY));
  assertFalse(equalEnough(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY));
  // NaN only matches NaN
  assertTrue(equalEnough(Double.NaN, Double.NaN));
  assertFalse(equalEnough(1.0, Double.NaN));
}

}