diff --git a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java index 85af30da1e8ae..fd835a63a11ac 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java @@ -39,6 +39,8 @@ import org.apache.arrow.vector.file.json.JsonFileReader; import org.apache.arrow.vector.file.json.JsonFileWriter; import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.FloatingPoint; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.cli.CommandLine; @@ -247,7 +249,7 @@ private static void compare(VectorSchemaRoot arrowRoot, VectorSchemaRoot jsonRoo for (int j = 0; j < valueCount; j++) { Object arrow = arrowVector.getAccessor().getObject(j); Object json = jsonVector.getAccessor().getObject(j); - if (!Objects.equal(arrow, json)) { + if (!equals(field.getType(), arrow, json)) { throw new IllegalArgumentException( "Different values in column:\n" + field + " at index " + j + ": " + arrow + " != " + json); } @@ -255,6 +257,53 @@ private static void compare(VectorSchemaRoot arrowRoot, VectorSchemaRoot jsonRoo } } + private static boolean equals(ArrowType type, final Object arrow, final Object json) { + if (type instanceof ArrowType.FloatingPoint) { + FloatingPoint fpType = (FloatingPoint) type; + switch (fpType.getPrecision()) { + case DOUBLE: + return equalEnough((Double)arrow, (Double)json); + case SINGLE: + return equalEnough((Float)arrow, (Float)json); + case HALF: + default: + throw new UnsupportedOperationException("unsupported precision: " + fpType); + } + } + return Objects.equal(arrow, json); + } + + static boolean equalEnough(Float f1, Float f2) { + if (f1 == null || f2 == null) { + return f1 == null && f2 == null; + } + if (f1.isNaN()) { + return f2.isNaN(); + } + if (f1.isInfinite()) { + return f2.isInfinite() && Math.signum(f1) == Math.signum(f2); + } + float average = Math.abs((f1 + f2) / 2); + float differenceScaled = Math.abs(f1 - f2) / (average == 0.0f ? 1f : average); + return differenceScaled < 1.0E-6f; + } + + static boolean equalEnough(Double f1, Double f2) { + if (f1 == null || f2 == null) { + return f1 == null && f2 == null; + } + if (f1.isNaN()) { + return f2.isNaN(); + } + if (f1.isInfinite()) { + return f2.isInfinite() && Math.signum(f1) == Math.signum(f2); + } + double average = Math.abs((f1 + f2) / 2); + double differenceScaled = Math.abs(f1 - f2) / (average == 0.0d ? 1d : average); + return differenceScaled < 1.0E-12d; + } + + private static void compareSchemas(Schema jsonSchema, Schema arrowSchema) { if (!arrowSchema.equals(jsonSchema)) { throw new IllegalArgumentException("Different schemas:\n" + arrowSchema + "\n" + jsonSchema); diff --git a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java index 464144b95a1aa..ee6196b74e0dc 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java @@ -22,6 +22,10 @@ import static org.apache.arrow.tools.ArrowFileTestFixtures.write; import static org.apache.arrow.tools.ArrowFileTestFixtures.writeData; import static org.apache.arrow.tools.ArrowFileTestFixtures.writeInput; +import static org.apache.arrow.tools.Integration.equalEnough; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.io.BufferedReader; @@ -39,9 +43,9 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter; import org.apache.arrow.vector.complex.writer.BigIntWriter; +import org.apache.arrow.vector.complex.writer.Float8Writer; import org.apache.arrow.vector.complex.writer.IntWriter; import org.junit.After; -import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -121,7 +125,7 @@ public void testJSONRoundTripWithVariableWidth() throws Exception { String i, o; int j = 0; while ((i = orig.readLine()) != null && (o = rt.readLine()) != null) { - Assert.assertEquals("line: " + j, i, o); + assertEquals("line: " + j, i, o); ++j; } } @@ -142,6 +146,33 @@ private BufferedReader readNormalized(File f) throws IOException { } + /** + * the test should not be sensitive to small variations in float representation + */ + @Test + public void testFloat() throws Exception { + File testValidInFile = testFolder.newFile("testValidFloatIn.arrow"); + File testInvalidInFile = testFolder.newFile("testAlsoValidFloatIn.arrow"); + File testJSONFile = testFolder.newFile("testValidOut.json"); + testJSONFile.delete(); + + // generate an arrow file + writeInputFloat(testValidInFile, allocator, 912.4140000000002, 912.414); + // generate a different arrow file + writeInputFloat(testInvalidInFile, allocator, 912.414, 912.4140000000002); + + Integration integration = new Integration(); + + // convert the "valid" file to json + String[] args1 = { "-arrow", testValidInFile.getAbsolutePath(), "-json", testJSONFile.getAbsolutePath(), "-command", Command.ARROW_TO_JSON.name()}; + integration.run(args1); + + // compare the "invalid" file to the "valid" json + String[] args3 = { "-arrow", testInvalidInFile.getAbsolutePath(), "-json", testJSONFile.getAbsolutePath(), "-command", Command.VALIDATE.name()}; + // this should fail + integration.run(args3); + } + @Test public void testInvalid() throws Exception { File testValidInFile = testFolder.newFile("testValidIn.arrow"); @@ -167,12 +198,28 @@ public void testInvalid() throws Exception { integration.run(args3); fail("should have failed"); } catch (IllegalArgumentException e) { - Assert.assertTrue(e.getMessage(), e.getMessage().contains("Different values in column")); - Assert.assertTrue(e.getMessage(), e.getMessage().contains("999")); + assertTrue(e.getMessage(), e.getMessage().contains("Different values in column")); + assertTrue(e.getMessage(), e.getMessage().contains("999")); } } + static void writeInputFloat(File testInFile, BufferAllocator allocator, double... f) throws FileNotFoundException, IOException { + try ( + BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", vectorAllocator, null)) { + ComplexWriter writer = new ComplexWriterImpl("root", parent); + MapWriter rootWriter = writer.rootAsMap(); + Float8Writer floatWriter = rootWriter.float8("float"); + for (int i = 0; i < f.length; i++) { + floatWriter.setPosition(i); + floatWriter.writeFloat8(f[i]); + } + writer.setValueCount(f.length); + write(parent.getChild("root"), testInFile); + } + } + static void writeInput2(File testInFile, BufferAllocator allocator) throws FileNotFoundException, IOException { int count = ArrowFileTestFixtures.COUNT; try ( @@ -192,4 +239,33 @@ static void writeInput2(File testInFile, BufferAllocator allocator) throws FileN } } + @Test + public void testFloatComp() { + assertTrue(equalEnough(912.4140000000002F, 912.414F)); + assertTrue(equalEnough(912.4140000000002D, 912.414D)); + assertTrue(equalEnough(912.414F, 912.4140000000002F)); + assertTrue(equalEnough(912.414D, 912.4140000000002D)); + assertFalse(equalEnough(912.414D, 912.4140001D)); + assertFalse(equalEnough(null, 912.414D)); + assertTrue(equalEnough((Float)null, null)); + assertTrue(equalEnough((Double)null, null)); + assertFalse(equalEnough(912.414D, null)); + assertFalse(equalEnough(Double.MAX_VALUE, Double.MIN_VALUE)); + assertFalse(equalEnough(Double.MIN_VALUE, Double.MAX_VALUE)); + assertTrue(equalEnough(Double.MAX_VALUE, Double.MAX_VALUE)); + assertTrue(equalEnough(Double.MIN_VALUE, Double.MIN_VALUE)); + assertTrue(equalEnough(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY)); + assertFalse(equalEnough(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)); + assertTrue(equalEnough(Double.NaN, Double.NaN)); + assertFalse(equalEnough(1.0, Double.NaN)); + assertFalse(equalEnough(Float.MAX_VALUE, Float.MIN_VALUE)); + assertFalse(equalEnough(Float.MIN_VALUE, Float.MAX_VALUE)); + assertTrue(equalEnough(Float.MAX_VALUE, Float.MAX_VALUE)); + assertTrue(equalEnough(Float.MIN_VALUE, Float.MIN_VALUE)); + assertTrue(equalEnough(Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY)); + assertFalse(equalEnough(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY)); + assertTrue(equalEnough(Float.NaN, Float.NaN)); + assertFalse(equalEnough(1.0F, Float.NaN)); + } + }