diff --git a/api/src/main/java/org/apache/iceberg/expressions/BoundUnaryPredicate.java b/api/src/main/java/org/apache/iceberg/expressions/BoundUnaryPredicate.java index 528dd1f11300..3989c49a3c3f 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/BoundUnaryPredicate.java +++ b/api/src/main/java/org/apache/iceberg/expressions/BoundUnaryPredicate.java @@ -19,6 +19,8 @@ package org.apache.iceberg.expressions; +import org.apache.iceberg.util.NaNUtil; + public class BoundUnaryPredicate extends BoundPredicate { BoundUnaryPredicate(Operation op, BoundTerm term) { super(op, term); @@ -46,6 +48,10 @@ public boolean test(T value) { return value == null; case NOT_NULL: return value != null; + case IS_NAN: + return NaNUtil.isNaN(value); /* NaNUtil.isNaN is false for null, so a null value never tests as NaN */ + case NOT_NAN: + return !NaNUtil.isNaN(value); /* consequently a null value tests as not-NaN */ default: throw new IllegalStateException("Invalid operation for BoundUnaryPredicate: " + op()); } @@ -58,6 +64,10 @@ public String toString() { return "is_null(" + term() + ")"; case NOT_NULL: return "not_null(" + term() + ")"; + case IS_NAN: + return "is_nan(" + term() + ")"; /* same rendering as UnboundPredicate.toString for this operation */ + case NOT_NAN: + return "not_nan(" + term() + ")"; default: return "Invalid unary predicate: operation = " + op(); } diff --git a/api/src/main/java/org/apache/iceberg/expressions/Evaluator.java b/api/src/main/java/org/apache/iceberg/expressions/Evaluator.java index b80a13b27f8f..a4d7b67b8bde 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/Evaluator.java +++ b/api/src/main/java/org/apache/iceberg/expressions/Evaluator.java @@ -25,6 +25,7 @@ import org.apache.iceberg.StructLike; import org.apache.iceberg.expressions.ExpressionVisitors.BoundVisitor; import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.NaNUtil; /** * Evaluates an {@link Expression} for data described by a {@link StructType}. 
@@ -91,6 +92,16 @@ public Boolean notNull(Bound valueExpr) { return valueExpr.eval(struct) != null; } + @Override + public Boolean isNaN(Bound valueExpr) { + return NaNUtil.isNaN(valueExpr.eval(struct)); + } + + @Override + public Boolean notNaN(Bound valueExpr) { + return !NaNUtil.isNaN(valueExpr.eval(struct)); + } + @Override public Boolean lt(Bound valueExpr, Literal lit) { Comparator cmp = lit.comparator(); diff --git a/api/src/main/java/org/apache/iceberg/expressions/Expression.java b/api/src/main/java/org/apache/iceberg/expressions/Expression.java index 18606a6ddb42..b49f0070ec8a 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/Expression.java +++ b/api/src/main/java/org/apache/iceberg/expressions/Expression.java @@ -30,6 +30,8 @@ enum Operation { FALSE, IS_NULL, NOT_NULL, + IS_NAN, + NOT_NAN, LT, LT_EQ, GT, @@ -52,6 +54,10 @@ public Operation negate() { return Operation.NOT_NULL; case NOT_NULL: return Operation.IS_NULL; + case IS_NAN: + return Operation.NOT_NAN; + case NOT_NAN: + return Operation.IS_NAN; case LT: return Operation.GT_EQ; case LT_EQ: diff --git a/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java b/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java index eab693c48158..cbec7485521d 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java +++ b/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java @@ -75,6 +75,14 @@ public R notNull(BoundReference ref) { return null; } + public R isNaN(BoundReference ref) { + throw new UnsupportedOperationException(this.getClass().getName() + " does not implement isNaN"); + } + + public R notNaN(BoundReference ref) { + throw new UnsupportedOperationException(this.getClass().getName() + " does not implement notNaN"); + } + public R lt(BoundReference ref, Literal lit) { return null; } @@ -143,6 +151,10 @@ public R predicate(BoundPredicate pred) { return isNull((BoundReference) pred.term()); case NOT_NULL: 
return notNull((BoundReference) pred.term()); + case IS_NAN: + return isNaN((BoundReference) pred.term()); + case NOT_NAN: + return notNaN((BoundReference) pred.term()); default: throw new IllegalStateException("Invalid operation for BoundUnaryPredicate: " + pred.op()); } @@ -176,6 +188,14 @@ public R notNull(Bound expr) { return null; } + public R isNaN(Bound expr) { + throw new UnsupportedOperationException(this.getClass().getName() + " does not implement isNaN"); + } + + public R notNaN(Bound expr) { + throw new UnsupportedOperationException(this.getClass().getName() + " does not implement notNaN"); + } + public R lt(Bound expr, Literal lit) { return null; } @@ -241,6 +261,10 @@ public R predicate(BoundPredicate pred) { return isNull(pred.term()); case NOT_NULL: return notNull(pred.term()); + case IS_NAN: + return isNaN(pred.term()); + case NOT_NAN: + return notNaN(pred.term()); default: throw new IllegalStateException("Invalid operation for BoundUnaryPredicate: " + pred.op()); } diff --git a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java index 22abf70d454e..1dba525f21fb 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java +++ b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java @@ -26,6 +26,7 @@ import org.apache.iceberg.transforms.Transform; import org.apache.iceberg.transforms.Transforms; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.NaNUtil; /** * Factory methods for creating {@link Expression expressions}. 
@@ -123,51 +124,79 @@ public static UnboundPredicate notNull(UnboundTerm expr) { return new UnboundPredicate<>(Expression.Operation.NOT_NULL, expr); } + public static UnboundPredicate isNaN(String name) { + return new UnboundPredicate<>(Expression.Operation.IS_NAN, ref(name)); + } + + public static UnboundPredicate isNaN(UnboundTerm expr) { + return new UnboundPredicate<>(Expression.Operation.IS_NAN, expr); + } + + public static UnboundPredicate notNaN(String name) { + return new UnboundPredicate<>(Expression.Operation.NOT_NAN, ref(name)); + } + + public static UnboundPredicate notNaN(UnboundTerm expr) { + return new UnboundPredicate<>(Expression.Operation.NOT_NAN, expr); + } + public static UnboundPredicate lessThan(String name, T value) { + validateInput("lessThan", value); return new UnboundPredicate<>(Expression.Operation.LT, ref(name), value); } public static UnboundPredicate lessThan(UnboundTerm expr, T value) { + validateInput("lessThan", value); return new UnboundPredicate<>(Expression.Operation.LT, expr, value); } public static UnboundPredicate lessThanOrEqual(String name, T value) { + validateInput("lessThanOrEqual", value); return new UnboundPredicate<>(Expression.Operation.LT_EQ, ref(name), value); } public static UnboundPredicate lessThanOrEqual(UnboundTerm expr, T value) { + validateInput("lessThanOrEqual", value); return new UnboundPredicate<>(Expression.Operation.LT_EQ, expr, value); } public static UnboundPredicate greaterThan(String name, T value) { + validateInput("greaterThan", value); return new UnboundPredicate<>(Expression.Operation.GT, ref(name), value); } public static UnboundPredicate greaterThan(UnboundTerm expr, T value) { + validateInput("greaterThan", value); return new UnboundPredicate<>(Expression.Operation.GT, expr, value); } public static UnboundPredicate greaterThanOrEqual(String name, T value) { + validateInput("greaterThanOrEqual", value); return new UnboundPredicate<>(Expression.Operation.GT_EQ, ref(name), value); } public 
static UnboundPredicate greaterThanOrEqual(UnboundTerm expr, T value) { + validateInput("greaterThanOrEqual", value); return new UnboundPredicate<>(Expression.Operation.GT_EQ, expr, value); } public static UnboundPredicate equal(String name, T value) { + validateInput("equal", value); return new UnboundPredicate<>(Expression.Operation.EQ, ref(name), value); } public static UnboundPredicate equal(UnboundTerm expr, T value) { + validateInput("equal", value); return new UnboundPredicate<>(Expression.Operation.EQ, expr, value); } public static UnboundPredicate notEqual(String name, T value) { + validateInput("notEqual", value); return new UnboundPredicate<>(Expression.Operation.NOT_EQ, ref(name), value); } public static UnboundPredicate notEqual(UnboundTerm expr, T value) { + validateInput("notEqual", value); return new UnboundPredicate<>(Expression.Operation.NOT_EQ, expr, value); } @@ -216,29 +245,43 @@ public static UnboundPredicate notIn(UnboundTerm expr, Iterable val } public static UnboundPredicate predicate(Operation op, String name, T value) { + validateInput(op.toString(), value); return predicate(op, name, Literals.from(value)); } public static UnboundPredicate predicate(Operation op, String name, Literal lit) { - Preconditions.checkArgument(op != Operation.IS_NULL && op != Operation.NOT_NULL, + Preconditions.checkArgument( + op != Operation.IS_NULL && op != Operation.NOT_NULL && op != Operation.IS_NAN && op != Operation.NOT_NAN, "Cannot create %s predicate inclusive a value", op); return new UnboundPredicate(op, ref(name), lit); } public static UnboundPredicate predicate(Operation op, String name, Iterable values) { + validateInput(op.toString(), values); return predicate(op, ref(name), values); } public static UnboundPredicate predicate(Operation op, String name) { - Preconditions.checkArgument(op == Operation.IS_NULL || op == Operation.NOT_NULL, + Preconditions.checkArgument( + op == Operation.IS_NULL || op == Operation.NOT_NULL || op == Operation.IS_NAN || 
op == Operation.NOT_NAN, "Cannot create %s predicate without a value", op); return new UnboundPredicate<>(op, ref(name)); } private static UnboundPredicate predicate(Operation op, UnboundTerm expr, Iterable values) { + validateInput(op.toString(), values); return new UnboundPredicate<>(op, expr, values); } + private static void validateInput(String op, T value) { /* rejects a NaN literal; NaNUtil.isNaN is false for null and non-floating values, so those pass */ + Preconditions.checkArgument(!NaNUtil.isNaN(value), "Cannot create %s predicate with NaN", op); /* template overload: the message is formatted only when the check fails */ + } + /* Iterable overload: rejects the input if any element is NaN. */ + private static void validateInput(String op, Iterable values) { + Preconditions.checkArgument(Lists.newArrayList(values).stream().noneMatch(NaNUtil::isNaN), + "Cannot create %s predicate with NaN", op); + } + public static True alwaysTrue() { return True.INSTANCE; } diff --git a/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java b/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java index 344a453fb521..e452d27287a8 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java +++ b/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java @@ -76,6 +76,7 @@ public boolean eval(ContentFile file) { private class MetricsEvalVisitor extends BoundExpressionVisitor { private Map valueCounts = null; private Map nullCounts = null; + private Map nanCounts = null; private Map lowerBounds = null; private Map upperBounds = null; @@ -93,6 +94,7 @@ private boolean eval(ContentFile file) { this.valueCounts = file.valueCounts(); this.nullCounts = file.nullValueCounts(); + this.nanCounts = file.nanValueCounts(); this.lowerBounds = file.lowerBounds(); this.upperBounds = file.upperBounds(); @@ -150,6 +152,34 @@ public Boolean notNull(BoundReference ref) { return ROWS_MIGHT_MATCH; } + @Override + public Boolean isNaN(BoundReference ref) { + Integer id = ref.fieldId(); + /* a recorded NaN count of zero means no row in the file can satisfy is_nan */ + if (nanCounts != null && nanCounts.containsKey(id) && nanCounts.get(id) == 0) { + return ROWS_CANNOT_MATCH; + } + 
// when there's no nanCounts information, but we already know the column only contains null, + // it's guaranteed that there's no NaN value + if (containsNullsOnly(id)) { + return ROWS_CANNOT_MATCH; + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean notNaN(BoundReference ref) { + Integer id = ref.fieldId(); + /* when every recorded value is NaN, no row in the file can satisfy not_nan */ + if (containsNaNsOnly(id)) { + return ROWS_CANNOT_MATCH; + } + + return ROWS_MIGHT_MATCH; + } + @Override public Boolean lt(BoundReference ref, Literal lit) { Integer id = ref.fieldId(); @@ -347,5 +377,10 @@ private boolean containsNullsOnly(Integer id) { nullCounts != null && nullCounts.containsKey(id) && valueCounts.get(id) - nullCounts.get(id) == 0; } + /* true only when NaN counts and value counts are both available and the field's NaN count equals its value count */ + private boolean containsNaNsOnly(Integer id) { + return nanCounts != null && nanCounts.containsKey(id) && + valueCounts != null && nanCounts.get(id).equals(valueCounts.get(id)); + } } } diff --git a/api/src/main/java/org/apache/iceberg/expressions/ManifestEvaluator.java b/api/src/main/java/org/apache/iceberg/expressions/ManifestEvaluator.java index e0f4728b5e84..d072959419d7 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/ManifestEvaluator.java +++ b/api/src/main/java/org/apache/iceberg/expressions/ManifestEvaluator.java @@ -134,14 +134,31 @@ public Boolean notNull(BoundReference ref) { int pos = Accessors.toPosition(ref.accessor()); // containsNull encodes whether at least one partition value is null, lowerBound is null if // all partition values are null. - ByteBuffer lowerBound = stats.get(pos).lowerBound(); - if (lowerBound == null) { + if (stats.get(pos).containsNull() && stats.get(pos).lowerBound() == null) { + return ROWS_CANNOT_MATCH; // all values are null + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean isNaN(BoundReference ref) { /* prunes only the all-null case; the partition summary read here carries no NaN information */ + int pos = Accessors.toPosition(ref.accessor()); + // containsNull encodes whether at least one partition value is null, lowerBound is null if + // all partition values are null. 
+ if (stats.get(pos).containsNull() && stats.get(pos).lowerBound() == null) { return ROWS_CANNOT_MATCH; // all values are null } return ROWS_MIGHT_MATCH; } + @Override + public Boolean notNaN(BoundReference ref) { + // we don't have enough information to tell if there is no NaN value + return ROWS_MIGHT_MATCH; + } + @Override public Boolean lt(BoundReference ref, Literal lit) { int pos = Accessors.toPosition(ref.accessor()); diff --git a/api/src/main/java/org/apache/iceberg/expressions/ResidualEvaluator.java b/api/src/main/java/org/apache/iceberg/expressions/ResidualEvaluator.java index 3ea3d7bb7e21..791d484ecb75 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/ResidualEvaluator.java +++ b/api/src/main/java/org/apache/iceberg/expressions/ResidualEvaluator.java @@ -28,6 +28,7 @@ import org.apache.iceberg.StructLike; import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor; import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.util.NaNUtil; /** * Finds the residuals for an {@link Expression} the partitions in the given {@link PartitionSpec}. @@ -152,6 +153,16 @@ public Expression notNull(BoundReference ref) { return (ref.eval(struct) != null) ? alwaysTrue() : alwaysFalse(); } + @Override + public Expression isNaN(BoundReference ref) { + return NaNUtil.isNaN(ref.eval(struct)) ? alwaysTrue() : alwaysFalse(); + } + + @Override + public Expression notNaN(BoundReference ref) { + return NaNUtil.isNaN(ref.eval(struct)) ? 
alwaysFalse() : alwaysTrue(); + } + @Override public Expression lt(BoundReference ref, Literal lit) { Comparator cmp = lit.comparator(); diff --git a/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java b/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java index 8fd0b602459d..d46a8216b1ec 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java +++ b/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java @@ -57,10 +57,10 @@ public StrictMetricsEvaluator(Schema schema, Expression unbound) { } /** - * Test whether the file may contain records that match the expression. + * Test whether all records within the file match the expression. * * @param file a data file - * @return false if the file cannot contain rows that match the expression, true otherwise. + * @return false if the file may contain any row that doesn't match the expression, true otherwise. */ public boolean eval(ContentFile file) { // TODO: detect the case where a column is missing from the file using file's max field id. 
@@ -73,6 +73,7 @@ public boolean eval(ContentFile file) { private class MetricsEvalVisitor extends BoundExpressionVisitor { private Map valueCounts = null; private Map nullCounts = null; + private Map nanCounts = null; private Map lowerBounds = null; private Map upperBounds = null; @@ -83,6 +84,7 @@ private boolean eval(ContentFile file) { this.valueCounts = file.valueCounts(); this.nullCounts = file.nullValueCounts(); + this.nanCounts = file.nanValueCounts(); this.lowerBounds = file.lowerBounds(); this.upperBounds = file.upperBounds(); @@ -118,7 +120,7 @@ public Boolean or(Boolean leftResult, Boolean rightResult) { public Boolean isNull(BoundReference ref) { // no need to check whether the field is required because binding evaluates that case // if the column has any non-null values, the expression does not match - Integer id = ref.fieldId(); + int id = ref.fieldId(); Preconditions.checkNotNull(struct.field(id), "Cannot filter by nested column: %s", schema.findField(id)); @@ -133,7 +135,7 @@ public Boolean isNull(BoundReference ref) { public Boolean notNull(BoundReference ref) { // no need to check whether the field is required because binding evaluates that case // if the column has any null values, the expression does not match - Integer id = ref.fieldId(); + int id = ref.fieldId(); Preconditions.checkNotNull(struct.field(id), "Cannot filter by nested column: %s", schema.findField(id)); @@ -144,6 +146,32 @@ public Boolean notNull(BoundReference ref) { return ROWS_MIGHT_NOT_MATCH; } + @Override + public Boolean isNaN(BoundReference ref) { /* NOTE(review): unlike isNull/notNull in this visitor, no nested-column precondition is applied here - confirm this is intended */ + int id = ref.fieldId(); + /* every row matches is_nan only when all recorded values for the field are NaN */ + if (containsNaNsOnly(id)) { + return ROWS_MUST_MATCH; + } + + return ROWS_MIGHT_NOT_MATCH; + } + + @Override + public Boolean notNaN(BoundReference ref) { + int id = ref.fieldId(); + /* a recorded NaN count of zero means every row satisfies not_nan */ + if (nanCounts != null && nanCounts.containsKey(id) && nanCounts.get(id) == 0) { + return ROWS_MUST_MATCH; + } + /* even without a NaN count, a column holding only nulls has no NaN value */ + if (containsNullsOnly(id)) { + return ROWS_MUST_MATCH; + } + + return ROWS_MIGHT_NOT_MATCH; + } + 
@Override public Boolean lt(BoundReference ref, Literal lit) { // Rows must match when: <----------Min----Max---X-------> @@ -383,5 +411,10 @@ private boolean containsNullsOnly(Integer id) { nullCounts != null && nullCounts.containsKey(id) && valueCounts.get(id) - nullCounts.get(id) == 0; } + + private boolean containsNaNsOnly(Integer id) { + return nanCounts != null && nanCounts.containsKey(id) && + valueCounts != null && nanCounts.get(id).equals(valueCounts.get(id)); + } } } diff --git a/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java b/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java index e11806e2f65e..1a53395995c5 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java +++ b/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java @@ -27,6 +27,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types.StructType; import org.apache.iceberg.util.CharSequenceSet; @@ -127,11 +128,27 @@ private Expression bindUnaryOperation(BoundTerm boundTerm) { return Expressions.alwaysTrue(); } return new BoundUnaryPredicate<>(Operation.NOT_NULL, boundTerm); + case IS_NAN: + if (floatingType(boundTerm.type().typeId())) { + return new BoundUnaryPredicate<>(Operation.IS_NAN, boundTerm); + } else { + throw new ValidationException("IsNaN cannot be used with a non-floating-point column"); + } + case NOT_NAN: + if (floatingType(boundTerm.type().typeId())) { + return new BoundUnaryPredicate<>(Operation.NOT_NAN, boundTerm); + } else { + throw new ValidationException("NotNaN cannot be used with a non-floating-point column"); + } default: - throw new ValidationException("Operation must be IS_NULL or NOT_NULL"); + throw new ValidationException("Operation must be IS_NULL, 
NOT_NULL, IS_NAN, or NOT_NAN"); } } + private boolean floatingType(Type.TypeID typeID) { + return Type.TypeID.DOUBLE.equals(typeID) || Type.TypeID.FLOAT.equals(typeID); + } + private Expression bindLiteralOperation(BoundTerm boundTerm) { Literal lit = literal().to(boundTerm.type()); @@ -210,6 +227,10 @@ public String toString() { return "is_null(" + term() + ")"; case NOT_NULL: return "not_null(" + term() + ")"; + case IS_NAN: + return "is_nan(" + term() + ")"; + case NOT_NAN: + return "not_nan(" + term() + ")"; case LT: return term() + " < " + literal(); case LT_EQ: diff --git a/api/src/main/java/org/apache/iceberg/util/NaNUtil.java b/api/src/main/java/org/apache/iceberg/util/NaNUtil.java new file mode 100644 index 000000000000..4a0176629bc2 --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/util/NaNUtil.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.util; +/** Static helper for detecting floating-point NaN values. */ +public class NaNUtil { + /* private constructor: the class only exposes a static method */ + private NaNUtil() { + } + /** Returns true only when value is a Double or a Float holding NaN; null and values of any other type return false. */ + public static boolean isNaN(Object value) { + if (value == null) { + return false; + } + + if (value instanceof Double) { + return Double.isNaN((Double) value); + } else if (value instanceof Float) { + return Float.isNaN((Float) value); + } else { + return false; + } + } +} diff --git a/api/src/test/java/org/apache/iceberg/TestHelpers.java b/api/src/test/java/org/apache/iceberg/TestHelpers.java index 92d87ac66045..091b032dadbb 100644 --- a/api/src/test/java/org/apache/iceberg/TestHelpers.java +++ b/api/src/test/java/org/apache/iceberg/TestHelpers.java @@ -311,16 +311,18 @@ public static class TestDataFile implements DataFile { private final long recordCount; private final Map valueCounts; private final Map nullValueCounts; + private final Map nanValueCounts; private final Map lowerBounds; private final Map upperBounds; public TestDataFile(String path, StructLike partition, long recordCount) { - this(path, partition, recordCount, null, null, null, null); + this(path, partition, recordCount, null, null, null, null, null); } public TestDataFile(String path, StructLike partition, long recordCount, Map valueCounts, Map nullValueCounts, + Map nanValueCounts, Map lowerBounds, Map upperBounds) { this.path = path; @@ -328,6 +330,7 @@ public TestDataFile(String path, StructLike partition, long recordCount, this.recordCount = recordCount; this.valueCounts = valueCounts; this.nullValueCounts = nullValueCounts; + this.nanValueCounts = nanValueCounts; this.lowerBounds = lowerBounds; this.upperBounds = upperBounds; } @@ -384,7 +387,7 @@ public Map nullValueCounts() { @Override public Map nanValueCounts() { - return null; // will be updated in a separate pr soon + return nanValueCounts; } @Override diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestEvaluator.java b/api/src/test/java/org/apache/iceberg/expressions/TestEvaluator.java index a437f7483558..911e3ff300c0 
100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestEvaluator.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestEvaluator.java @@ -38,12 +38,14 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.isNull; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.not; import static org.apache.iceberg.expressions.Expressions.notEqual; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.expressions.Expressions.or; import static org.apache.iceberg.expressions.Expressions.predicate; @@ -59,7 +61,10 @@ public class TestEvaluator { Types.NestedField.required(17, "s2", Types.StructType.of( Types.NestedField.required(18, "s3", Types.StructType.of( Types.NestedField.required(19, "s4", Types.StructType.of( - Types.NestedField.required(20, "i", Types.IntegerType.get())))))))))); + Types.NestedField.required(20, "i", Types.IntegerType.get()))))))))), + optional(21, "s5", Types.StructType.of( + Types.NestedField.required(22, "s6", Types.StructType.of( + Types.NestedField.required(23, "f", Types.FloatType.get())))))); @Test public void testLessThan() { @@ -256,6 +261,32 @@ public void testNotNull() { TestHelpers.Row.of(3))))))); } + @Test + public void testIsNan() { + Evaluator evaluator = new Evaluator(STRUCT, isNaN("y")); + Assert.assertTrue("NaN is NaN", evaluator.eval(TestHelpers.Row.of(1, Double.NaN, 3))); + Assert.assertFalse("2 is not NaN", 
evaluator.eval(TestHelpers.Row.of(1, 2.0, 3))); + + Evaluator structEvaluator = new Evaluator(STRUCT, isNaN("s5.s6.f")); + Assert.assertTrue("NaN is NaN", structEvaluator.eval(TestHelpers.Row.of(1, 2, 3, null, + TestHelpers.Row.of(TestHelpers.Row.of(Float.NaN))))); + Assert.assertFalse("4F is not NaN", structEvaluator.eval(TestHelpers.Row.of(1, 2, 3, null, + TestHelpers.Row.of(TestHelpers.Row.of(4F))))); + } + + @Test + public void testNotNaN() { + Evaluator evaluator = new Evaluator(STRUCT, notNaN("y")); + Assert.assertFalse("NaN is NaN", evaluator.eval(TestHelpers.Row.of(1, Double.NaN, 3))); + Assert.assertTrue("2 is not NaN", evaluator.eval(TestHelpers.Row.of(1, 2.0, 3))); + + Evaluator structEvaluator = new Evaluator(STRUCT, notNaN("s5.s6.f")); + Assert.assertFalse("NaN is NaN", structEvaluator.eval(TestHelpers.Row.of(1, 2, 3, null, + TestHelpers.Row.of(TestHelpers.Row.of(Float.NaN))))); + Assert.assertTrue("4F is not NaN", structEvaluator.eval(TestHelpers.Row.of(1, 2, 3, null, + TestHelpers.Row.of(TestHelpers.Row.of(4F))))); + } + @Test public void testAnd() { Evaluator evaluator = new Evaluator(STRUCT, and(equal("x", 7), notNull("z"))); diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java index 302e4a66d05c..2ac1e5973b57 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java @@ -19,7 +19,9 @@ package org.apache.iceberg.expressions; +import java.util.concurrent.Callable; import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.transforms.Transforms; import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.NestedField; import org.apache.iceberg.types.Types.StructType; @@ -45,6 +47,8 @@ import static org.apache.iceberg.expressions.Expressions.notIn; import static 
org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.expressions.Expressions.or; +import static org.apache.iceberg.expressions.Expressions.predicate; +import static org.apache.iceberg.expressions.Expressions.ref; import static org.apache.iceberg.expressions.Expressions.rewriteNot; import static org.apache.iceberg.expressions.Expressions.truncate; import static org.apache.iceberg.expressions.Expressions.year; @@ -187,4 +191,44 @@ public void testMultiAnd() { Assert.assertEquals(expected.toString(), actual.toString()); } + + @Test + public void testInvalidateNaNInput() { + assertInvalidateNaNThrows("lessThan", () -> lessThan("a", Double.NaN)); + assertInvalidateNaNThrows("lessThan", () -> lessThan(self("a"), Double.NaN)); + + assertInvalidateNaNThrows("lessThanOrEqual", () -> lessThanOrEqual("a", Double.NaN)); + assertInvalidateNaNThrows("lessThanOrEqual", () -> lessThanOrEqual(self("a"), Double.NaN)); + + assertInvalidateNaNThrows("greaterThan", () -> greaterThan("a", Double.NaN)); + assertInvalidateNaNThrows("greaterThan", () -> greaterThan(self("a"), Double.NaN)); + + assertInvalidateNaNThrows("greaterThanOrEqual", () -> greaterThanOrEqual("a", Double.NaN)); + assertInvalidateNaNThrows("greaterThanOrEqual", () -> greaterThanOrEqual(self("a"), Double.NaN)); + + assertInvalidateNaNThrows("equal", () -> equal("a", Double.NaN)); + assertInvalidateNaNThrows("equal", () -> equal(self("a"), Double.NaN)); + + assertInvalidateNaNThrows("notEqual", () -> notEqual("a", Double.NaN)); + assertInvalidateNaNThrows("notEqual", () -> notEqual(self("a"), Double.NaN)); + + assertInvalidateNaNThrows("IN", () -> in("a", 1.0D, 2.0D, Double.NaN)); + assertInvalidateNaNThrows("IN", () -> in(self("a"), 1.0D, 2.0D, Double.NaN)); + + assertInvalidateNaNThrows("NOT_IN", () -> notIn("a", 1.0D, 2.0D, Double.NaN)); + assertInvalidateNaNThrows("NOT_IN", () -> notIn(self("a"), 1.0D, 2.0D, Double.NaN)); + + assertInvalidateNaNThrows("EQ", () -> 
predicate(Expression.Operation.EQ, "a", Double.NaN)); + } + + private void assertInvalidateNaNThrows(String operation, Callable> callable) { + AssertHelpers.assertThrows("Should invalidate NaN input", + IllegalArgumentException.class, String.format("Cannot create %s predicate with NaN", operation), + callable); + } + + private UnboundTerm self(String name) { + return new UnboundTransform<>(ref(name), Transforms.identity(Types.DoubleType.get())); + } + } diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionSerialization.java b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionSerialization.java index 0f10881f7524..d57b7ea62aff 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionSerialization.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionSerialization.java @@ -48,6 +48,8 @@ public void testExpressions() throws Exception { Expressions.notEqual("col", "abc"), Expressions.notNull("maybeNull"), Expressions.isNull("maybeNull2"), + Expressions.isNaN("maybeNaN"), + Expressions.notNaN("maybeNaN2"), Expressions.not(Expressions.greaterThan("a", 10)), Expressions.and(Expressions.greaterThanOrEqual("a", 0), Expressions.lessThan("a", 3)), Expressions.or(Expressions.lessThan("a", 0), Expressions.greaterThan("a", 10)), @@ -132,7 +134,8 @@ private static boolean equals(Predicate left, Predicate right) { return false; } - if (left.op() == Operation.IS_NULL || left.op() == Operation.NOT_NULL) { + if (left.op() == Operation.IS_NULL || left.op() == Operation.NOT_NULL || + left.op() == Operation.IS_NAN || left.op() == Operation.NOT_NAN) { return true; } diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveManifestEvaluator.java b/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveManifestEvaluator.java index a35d97988f77..0fb2b1ac7753 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveManifestEvaluator.java +++ 
b/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveManifestEvaluator.java @@ -36,12 +36,14 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.isNull; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.not; import static org.apache.iceberg.expressions.Expressions.notEqual; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.expressions.Expressions.or; import static org.apache.iceberg.expressions.Expressions.startsWith; @@ -54,7 +56,9 @@ public class TestInclusiveManifestEvaluator { required(1, "id", Types.IntegerType.get()), optional(4, "all_nulls", Types.StringType.get()), optional(5, "some_nulls", Types.StringType.get()), - optional(6, "no_nulls", Types.StringType.get()) + optional(6, "no_nulls", Types.StringType.get()), + optional(7, "float", Types.FloatType.get()), + optional(8, "all_nulls_double", Types.DoubleType.get()) ); private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA) @@ -63,6 +67,8 @@ public class TestInclusiveManifestEvaluator { .identity("all_nulls") .identity("some_nulls") .identity("no_nulls") + .identity("float") + .identity("all_nulls_double") .build(); private static final int INT_MIN_VALUE = 30; @@ -82,7 +88,12 @@ public class TestInclusiveManifestEvaluator { new TestHelpers.TestFieldSummary(false, INT_MIN, INT_MAX), new TestHelpers.TestFieldSummary(true, null, null), new TestHelpers.TestFieldSummary(true, 
STRING_MIN, STRING_MAX), - new TestHelpers.TestFieldSummary(false, STRING_MIN, STRING_MAX))); + new TestHelpers.TestFieldSummary(false, STRING_MIN, STRING_MAX), + new TestHelpers.TestFieldSummary(false, + toByteBuffer(Types.FloatType.get(), 0F), + toByteBuffer(Types.FloatType.get(), 20F)), + new TestHelpers.TestFieldSummary(true, null, null) + )); @Test public void testAllNulls() { @@ -111,6 +122,24 @@ public void testNoNulls() { Assert.assertFalse("Should skip: non-null column contains no null values", shouldRead); } + @Test + public void testIsNaN() { + boolean shouldRead = ManifestEvaluator.forRowFilter(isNaN("float"), SPEC, true).eval(FILE); + Assert.assertTrue("Should read: no information on if there are nan value in float column", shouldRead); + + shouldRead = ManifestEvaluator.forRowFilter(isNaN("all_nulls_double"), SPEC, true).eval(FILE); + Assert.assertFalse("Should skip: all null column doesn't contain nan value", shouldRead); + } + + @Test + public void testNotNaN() { + boolean shouldRead = ManifestEvaluator.forRowFilter(notNaN("float"), SPEC, true).eval(FILE); + Assert.assertTrue("Should read: no information on if there are nan value in float column", shouldRead); + + shouldRead = ManifestEvaluator.forRowFilter(notNaN("all_nulls_double"), SPEC, true).eval(FILE); + Assert.assertTrue("Should read: all null column contains non nan value", shouldRead); + } + @Test public void testMissingColumn() { AssertHelpers.assertThrows("Should complain about missing column in expression", @@ -123,7 +152,8 @@ public void testMissingStats() { Expression[] exprs = new Expression[] { lessThan("id", 5), lessThanOrEqual("id", 30), equal("id", 70), greaterThan("id", 78), greaterThanOrEqual("id", 90), notEqual("id", 101), - isNull("id"), notNull("id"), startsWith("all_nulls", "a") + isNull("id"), notNull("id"), startsWith("all_nulls", "a"), + isNaN("float"), notNaN("float") }; for (Expression expr : exprs) { diff --git 
a/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveMetricsEvaluator.java b/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveMetricsEvaluator.java index 3863261e788c..56135c8a331b 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveMetricsEvaluator.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestInclusiveMetricsEvaluator.java @@ -40,12 +40,14 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.isNull; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.not; import static org.apache.iceberg.expressions.Expressions.notEqual; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.expressions.Expressions.or; import static org.apache.iceberg.expressions.Expressions.startsWith; @@ -60,7 +62,14 @@ public class TestInclusiveMetricsEvaluator { required(3, "required", Types.StringType.get()), optional(4, "all_nulls", Types.StringType.get()), optional(5, "some_nulls", Types.StringType.get()), - optional(6, "no_nulls", Types.StringType.get()) + optional(6, "no_nulls", Types.StringType.get()), + optional(7, "all_nans", Types.DoubleType.get()), + optional(8, "some_nans", Types.FloatType.get()), + optional(9, "no_nans", Types.FloatType.get()), + optional(10, "all_nulls_double", Types.DoubleType.get()), + optional(11, "all_nans_v1_stats", Types.FloatType.get()), + optional(12, "nan_and_null_only", Types.DoubleType.get()), + 
optional(13, "no_nan_stats", Types.DoubleType.get()) ); private static final int INT_MIN_VALUE = 30; @@ -68,27 +77,50 @@ public class TestInclusiveMetricsEvaluator { private static final DataFile FILE = new TestDataFile("file.avro", Row.of(), 50, // any value counts, including nulls - ImmutableMap.of( - 4, 50L, - 5, 50L, - 6, 50L), + ImmutableMap.<Integer, Long>builder() + .put(4, 50L) + .put(5, 50L) + .put(6, 50L) + .put(7, 50L) + .put(8, 50L) + .put(9, 50L) + .put(10, 50L) + .put(11, 50L) + .put(12, 50L) + .put(13, 50L) + .build(), // null value counts + ImmutableMap.<Integer, Long>builder() + .put(4, 50L) + .put(5, 10L) + .put(6, 0L) + .put(10, 50L) + .put(11, 0L) + .put(12, 1L) + .build(), + // nan value counts ImmutableMap.of( - 4, 50L, - 5, 10L, - 6, 0L), + 7, 50L, + 8, 10L, + 9, 0L), // lower bounds ImmutableMap.of( - 1, toByteBuffer(IntegerType.get(), INT_MIN_VALUE)), + 1, toByteBuffer(IntegerType.get(), INT_MIN_VALUE), + 11, toByteBuffer(Types.FloatType.get(), Float.NaN), + 12, toByteBuffer(Types.DoubleType.get(), Double.NaN)), // upper bounds ImmutableMap.of( - 1, toByteBuffer(IntegerType.get(), INT_MAX_VALUE))); + 1, toByteBuffer(IntegerType.get(), INT_MAX_VALUE), + 11, toByteBuffer(Types.FloatType.get(), Float.NaN), + 12, toByteBuffer(Types.DoubleType.get(), Double.NaN))); private static final DataFile FILE_2 = new TestDataFile("file_2.avro", Row.of(), 50, // any value counts, including nulls ImmutableMap.of(3, 20L), // null value counts ImmutableMap.of(3, 2L), + // nan value counts + null, // lower bounds ImmutableMap.of(3, toByteBuffer(StringType.get(), "aa")), // upper bounds @@ -99,6 +131,8 @@ public class TestInclusiveMetricsEvaluator { ImmutableMap.of(3, 20L), // null value counts ImmutableMap.of(3, 2L), + // nan value counts + null, // lower bounds ImmutableMap.of(3, toByteBuffer(StringType.get(), "1str1")), // upper bounds @@ -109,6 +143,8 @@ public class TestInclusiveMetricsEvaluator { ImmutableMap.of(3, 20L), // null value counts ImmutableMap.of(3, 2L), + // nan value
counts + null, // lower bounds ImmutableMap.of(3, toByteBuffer(StringType.get(), "abc")), // upper bounds @@ -156,6 +192,54 @@ public void testNoNulls() { Assert.assertFalse("Should skip: non-null column contains no null values", shouldRead); } + @Test + public void testIsNaN() { + boolean shouldRead = new InclusiveMetricsEvaluator(SCHEMA, isNaN("all_nans")).eval(FILE); + Assert.assertTrue("Should read: at least one nan value in all nan column", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, isNaN("some_nans")).eval(FILE); + Assert.assertTrue("Should read: at least one nan value in some nan column", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, isNaN("no_nans")).eval(FILE); + Assert.assertFalse("Should skip: no-nans column contains no nan values", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, isNaN("all_nulls_double")).eval(FILE); + Assert.assertFalse("Should skip: all-null column doesn't contain nan value", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, isNaN("no_nan_stats")).eval(FILE); + Assert.assertTrue("Should read: no guarantee on if contains nan value without nan stats", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, isNaN("all_nans_v1_stats")).eval(FILE); + Assert.assertTrue("Should read: at least one nan value in all nan column", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, isNaN("nan_and_null_only")).eval(FILE); + Assert.assertTrue("Should read: at least one nan value in nan and nulls only column", shouldRead); + } + + @Test + public void testNotNaN() { + boolean shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNaN("all_nans")).eval(FILE); + Assert.assertFalse("Should skip: column with all nans will not contain non-nan", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNaN("some_nans")).eval(FILE); + Assert.assertTrue("Should read: at least one non-nan value in some nan column", shouldRead); + + 
shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNaN("no_nans")).eval(FILE); + Assert.assertTrue("Should read: at least one non-nan value in no nan column", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNaN("all_nulls_double")).eval(FILE); + Assert.assertTrue("Should read: at least one non-nan value in all null column", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNaN("no_nan_stats")).eval(FILE); + Assert.assertTrue("Should read: no guarantee on if contains nan value without nan stats", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNaN("all_nans_v1_stats")).eval(FILE); + Assert.assertTrue("Should read: no guarantee on if contains nan value without nan stats", shouldRead); + + shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNaN("nan_and_null_only")).eval(FILE); + Assert.assertTrue("Should read: at least one null value in nan and nulls only column", shouldRead); + } + @Test public void testRequiredColumn() { boolean shouldRead = new InclusiveMetricsEvaluator(SCHEMA, notNull("required")).eval(FILE); @@ -179,7 +263,7 @@ public void testMissingStats() { Expression[] exprs = new Expression[] { lessThan("no_stats", 5), lessThanOrEqual("no_stats", 30), equal("no_stats", 70), greaterThan("no_stats", 78), greaterThanOrEqual("no_stats", 90), notEqual("no_stats", 101), - isNull("no_stats"), notNull("no_stats") + isNull("no_stats"), notNull("no_stats"), isNaN("some_nans"), notNaN("some_nans") }; for (Expression expr : exprs) { @@ -195,7 +279,7 @@ public void testZeroRecordFile() { Expression[] exprs = new Expression[] { lessThan("id", 5), lessThanOrEqual("id", 30), equal("id", 70), greaterThan("id", 78), greaterThanOrEqual("id", 90), notEqual("id", 101), isNull("some_nulls"), - notNull("some_nulls") + notNull("some_nulls"), isNaN("some_nans"), notNaN("some_nans"), }; for (Expression expr : exprs) { diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestPredicateBinding.java 
b/api/src/test/java/org/apache/iceberg/expressions/TestPredicateBinding.java index 37b90c15c3c7..db2029d3044b 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestPredicateBinding.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestPredicateBinding.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.StructType; @@ -35,11 +36,13 @@ import static org.apache.iceberg.expressions.Expression.Operation.GT; import static org.apache.iceberg.expressions.Expression.Operation.GT_EQ; import static org.apache.iceberg.expressions.Expression.Operation.IN; +import static org.apache.iceberg.expressions.Expression.Operation.IS_NAN; import static org.apache.iceberg.expressions.Expression.Operation.IS_NULL; import static org.apache.iceberg.expressions.Expression.Operation.LT; import static org.apache.iceberg.expressions.Expression.Operation.LT_EQ; import static org.apache.iceberg.expressions.Expression.Operation.NOT_EQ; import static org.apache.iceberg.expressions.Expression.Operation.NOT_IN; +import static org.apache.iceberg.expressions.Expression.Operation.NOT_NAN; import static org.apache.iceberg.expressions.Expression.Operation.NOT_NULL; import static org.apache.iceberg.expressions.Expressions.ref; import static org.apache.iceberg.types.Types.NestedField.optional; @@ -318,6 +321,65 @@ public void testNotNull() { Expressions.alwaysTrue(), unbound.bind(required)); } + @Test + public void testIsNaN() { + // double + StructType struct = StructType.of(optional(21, "d", Types.DoubleType.get())); + + UnboundPredicate unbound = new UnboundPredicate<>(IS_NAN, ref("d")); + Expression expr = unbound.bind(struct); + BoundPredicate bound = assertAndUnwrap(expr); + Assert.assertEquals("Should use the same operation", IS_NAN, bound.op()); + 
Assert.assertEquals("Should use the correct field", 21, bound.ref().fieldId()); + Assert.assertTrue("Should be a unary predicate", bound.isUnaryPredicate()); + + // float + struct = StructType.of(optional(21, "f", Types.FloatType.get())); + + unbound = new UnboundPredicate<>(IS_NAN, ref("f")); + expr = unbound.bind(struct); + bound = assertAndUnwrap(expr); + Assert.assertEquals("Should use the same operation", IS_NAN, bound.op()); + Assert.assertEquals("Should use the correct field", 21, bound.ref().fieldId()); + Assert.assertTrue("Should be a unary predicate", bound.isUnaryPredicate()); + + // string (non-compatible) + StructType strStruct = StructType.of(optional(21, "s", Types.StringType.get())); + AssertHelpers.assertThrows("Should complain about incompatible type binding", + ValidationException.class, "IsNaN cannot be used with a non-floating-point column", + () -> new UnboundPredicate<>(IS_NAN, ref("s")).bind(strStruct)); + } + + @Test + public void testNotNaN() { + // double + StructType struct = StructType.of(optional(21, "d", Types.DoubleType.get())); + + UnboundPredicate unbound = new UnboundPredicate<>(NOT_NAN, ref("d")); + Expression expr = unbound.bind(struct); + BoundPredicate bound = assertAndUnwrap(expr); + Assert.assertEquals("Should use the same operation", NOT_NAN, bound.op()); + Assert.assertEquals("Should use the correct field", 21, bound.ref().fieldId()); + Assert.assertTrue("Should be a unary predicate", bound.isUnaryPredicate()); + + // float + struct = StructType.of(optional(21, "f", Types.FloatType.get())); + + unbound = new UnboundPredicate<>(NOT_NAN, ref("f")); + expr = unbound.bind(struct); + bound = assertAndUnwrap(expr); + Assert.assertEquals("Should use the same operation", NOT_NAN, bound.op()); + Assert.assertEquals("Should use the correct field", 21, bound.ref().fieldId()); + Assert.assertTrue("Should be a unary predicate", bound.isUnaryPredicate()); + + // string (non-compatible) + StructType strStruct = StructType.of(optional(21, 
"s", Types.StringType.get())); + AssertHelpers.assertThrows("Should complain about incompatible type binding", + ValidationException.class, "NotNaN cannot be used with a non-floating-point column", + () -> new UnboundPredicate<>(NOT_NAN, ref("s")).bind(strStruct)); + + } + @Test public void testInPredicateBinding() { StructType struct = StructType.of( @@ -392,6 +454,7 @@ public void testInPredicateBindingConversionToEq() { Assert.assertEquals("Should change the IN operation to EQ", EQ, bound.op()); } + @Test public void testInPredicateBindingConversionDedupToEq() { StructType struct = StructType.of(required(15, "d", Types.DecimalType.of(9, 2))); diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java b/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java index 6116fca2aa05..8df5475546e9 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java @@ -26,6 +26,7 @@ import org.apache.iceberg.TestHelpers.TestDataFile; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.IntegerType; import org.apache.iceberg.types.Types.StringType; import org.junit.Assert; @@ -36,12 +37,14 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.isNull; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.not; import static 
org.apache.iceberg.expressions.Expressions.notEqual; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.expressions.Expressions.or; import static org.apache.iceberg.types.Conversions.toByteBuffer; @@ -56,7 +59,14 @@ public class TestStrictMetricsEvaluator { optional(4, "all_nulls", StringType.get()), optional(5, "some_nulls", StringType.get()), optional(6, "no_nulls", StringType.get()), - required(7, "always_5", IntegerType.get()) + required(7, "always_5", IntegerType.get()), + optional(8, "all_nans", Types.DoubleType.get()), + optional(9, "some_nans", Types.FloatType.get()), + optional(10, "no_nans", Types.FloatType.get()), + optional(11, "all_nulls_double", Types.DoubleType.get()), + optional(12, "all_nans_v1_stats", Types.FloatType.get()), + optional(13, "nan_and_null_only", Types.DoubleType.get()), + optional(14, "no_nan_stats", Types.DoubleType.get()) ); private static final int INT_MIN_VALUE = 30; @@ -64,35 +74,59 @@ public class TestStrictMetricsEvaluator { private static final DataFile FILE = new TestDataFile("file.avro", Row.of(), 50, // any value counts, including nulls - ImmutableMap.of( - 4, 50L, - 5, 50L, - 6, 50L), + ImmutableMap.<Integer, Long>builder() + .put(4, 50L) + .put(5, 50L) + .put(6, 50L) + .put(8, 50L) + .put(9, 50L) + .put(10, 50L) + .put(11, 50L) + .put(12, 50L) + .put(13, 50L) + .put(14, 50L) + .build(), // null value counts + ImmutableMap.<Integer, Long>builder() + .put(4, 50L) + .put(5, 10L) + .put(6, 0L) + .put(11, 50L) + .put(12, 0L) + .put(13, 1L) + .build(), + // nan value counts ImmutableMap.of( - 4, 50L, - 5, 10L, - 6, 0L), + 8, 50L, + 9, 10L, + 10, 0L), // lower bounds ImmutableMap.of( 1, toByteBuffer(IntegerType.get(), INT_MIN_VALUE), - 7, toByteBuffer(IntegerType.get(), 5)), + 7, toByteBuffer(IntegerType.get(), 5), + 12, toByteBuffer(Types.FloatType.get(), Float.NaN), + 13,
toByteBuffer(Types.DoubleType.get(), Double.NaN)), // upper bounds ImmutableMap.of( 1, toByteBuffer(IntegerType.get(), INT_MAX_VALUE), - 7, toByteBuffer(IntegerType.get(), 5))); + 7, toByteBuffer(IntegerType.get(), 5), + 12, toByteBuffer(Types.FloatType.get(), Float.NaN), + 13, toByteBuffer(Types.DoubleType.get(), Double.NaN))); private static final DataFile FILE_2 = new TestDataFile("file_2.avro", Row.of(), 50, // any value counts, including nulls ImmutableMap.of( 4, 50L, 5, 50L, - 6, 50L), + 6, 50L, + 8, 50L), // null value counts ImmutableMap.of( 4, 50L, 5, 10L, 6, 0L), + // nan value counts + null, // lower bounds ImmutableMap.of(5, toByteBuffer(StringType.get(), "bbb")), // upper bounds @@ -109,6 +143,8 @@ public class TestStrictMetricsEvaluator { 4, 50L, 5, 10L, 6, 0L), + // nan value counts + null, // lower bounds ImmutableMap.of(5, toByteBuffer(StringType.get(), "bbb")), // upper bounds @@ -159,6 +195,54 @@ public void testSomeNulls() { Assert.assertFalse("Should not match: equal on some nulls column", shouldRead); } + @Test + public void testIsNaN() { + boolean shouldRead = new StrictMetricsEvaluator(SCHEMA, isNaN("all_nans")).eval(FILE); + Assert.assertTrue("Should match: all values are nan", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, isNaN("some_nans")).eval(FILE); + Assert.assertFalse("Should not match: at least one non-nan value in some nan column", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, isNaN("no_nans")).eval(FILE); + Assert.assertFalse("Should not match: at least one non-nan value in no nan column", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, isNaN("all_nulls_double")).eval(FILE); + Assert.assertFalse("Should not match: at least one non-nan value in all null column", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, isNaN("no_nan_stats")).eval(FILE); + Assert.assertFalse("Should not match: cannot determine without nan stats", shouldRead); + + shouldRead = new 
StrictMetricsEvaluator(SCHEMA, isNaN("all_nans_v1_stats")).eval(FILE); + Assert.assertFalse("Should not match: cannot determine without nan stats", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, isNaN("nan_and_null_only")).eval(FILE); + Assert.assertFalse("Should not match: null values are not nan", shouldRead); + } + + @Test + public void testNotNaN() { + boolean shouldRead = new StrictMetricsEvaluator(SCHEMA, notNaN("all_nans")).eval(FILE); + Assert.assertFalse("Should not match: all values are nan", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, notNaN("some_nans")).eval(FILE); + Assert.assertFalse("Should not match: at least one nan value in some nan column", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, notNaN("no_nans")).eval(FILE); + Assert.assertTrue("Should match: no value is nan", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, notNaN("all_nulls_double")).eval(FILE); + Assert.assertTrue("Should match: no nan value in all null column", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, notNaN("no_nan_stats")).eval(FILE); + Assert.assertFalse("Should not match: cannot determine without nan stats", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, notNaN("all_nans_v1_stats")).eval(FILE); + Assert.assertFalse("Should not match: all values are nan", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, notNaN("nan_and_null_only")).eval(FILE); + Assert.assertFalse("Should not match: null values are not nan", shouldRead); + } + @Test public void testRequiredColumn() { boolean shouldRead = new StrictMetricsEvaluator(SCHEMA, notNull("required")).eval(FILE); @@ -182,7 +266,7 @@ public void testMissingStats() { Expression[] exprs = new Expression[] { lessThan("no_stats", 5), lessThanOrEqual("no_stats", 30), equal("no_stats", 70), greaterThan("no_stats", 78), greaterThanOrEqual("no_stats", 90), notEqual("no_stats", 101), - isNull("no_stats"), 
notNull("no_stats") + isNull("no_stats"), notNull("no_stats"), isNaN("all_nans"), notNaN("all_nans") }; for (Expression expr : exprs) { @@ -198,7 +282,7 @@ public void testZeroRecordFile() { Expression[] exprs = new Expression[] { lessThan("id", 5), lessThanOrEqual("id", 30), equal("id", 70), greaterThan("id", 78), greaterThanOrEqual("id", 90), notEqual("id", 101), isNull("some_nulls"), - notNull("some_nulls") + notNull("some_nulls"), isNaN("all_nans"), notNaN("all_nans") }; for (Expression expr : exprs) { diff --git a/api/src/test/java/org/apache/iceberg/transforms/TestResiduals.java b/api/src/test/java/org/apache/iceberg/transforms/TestResiduals.java index 2264800b9117..92c0a1efe902 100644 --- a/api/src/test/java/org/apache/iceberg/transforms/TestResiduals.java +++ b/api/src/test/java/org/apache/iceberg/transforms/TestResiduals.java @@ -42,8 +42,10 @@ import static org.apache.iceberg.expressions.Expressions.equal; import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.or; public class TestResiduals { @@ -153,7 +155,9 @@ public void testUnpartitionedResiduals() { Expressions.notNull("c"), Expressions.isNull("d"), Expressions.in("e", 1, 2, 3), - Expressions.notIn("f", 1, 2, 3) + Expressions.notIn("f", 1, 2, 3), + Expressions.notNaN("g"), + Expressions.isNaN("h"), }; for (Expression expr : expressions) { @@ -235,6 +239,78 @@ public void testNotIn() { Assert.assertEquals("Residual should be alwaysFalse", alwaysFalse(), residual); } + @Test + public void testIsNaN() { + Schema schema = new Schema( + Types.NestedField.optional(50, "double", Types.DoubleType.get()), + 
Types.NestedField.optional(51, "float", Types.FloatType.get()) + ); + + // test double field + PartitionSpec spec = PartitionSpec.builderFor(schema) + .identity("double") + .build(); + + ResidualEvaluator resEval = ResidualEvaluator.of(spec, + isNaN("double"), true); + + Expression residual = resEval.residualFor(Row.of(Double.NaN)); + Assert.assertEquals("Residual should be alwaysTrue", alwaysTrue(), residual); + + residual = resEval.residualFor(Row.of(2D)); + Assert.assertEquals("Residual should be alwaysFalse", alwaysFalse(), residual); + + // test float field + spec = PartitionSpec.builderFor(schema) + .identity("float") + .build(); + + resEval = ResidualEvaluator.of(spec, + isNaN("float"), true); + + residual = resEval.residualFor(Row.of(Float.NaN)); + Assert.assertEquals("Residual should be alwaysTrue", alwaysTrue(), residual); + + residual = resEval.residualFor(Row.of(3F)); + Assert.assertEquals("Residual should be alwaysFalse", alwaysFalse(), residual); + } + + @Test + public void testNotNaN() { + Schema schema = new Schema( + Types.NestedField.optional(50, "double", Types.DoubleType.get()), + Types.NestedField.optional(51, "float", Types.FloatType.get()) + ); + + // test double field + PartitionSpec spec = PartitionSpec.builderFor(schema) + .identity("double") + .build(); + + ResidualEvaluator resEval = ResidualEvaluator.of(spec, + notNaN("double"), true); + + Expression residual = resEval.residualFor(Row.of(Double.NaN)); + Assert.assertEquals("Residual should be alwaysFalse", alwaysFalse(), residual); + + residual = resEval.residualFor(Row.of(2D)); + Assert.assertEquals("Residual should be alwaysTrue", alwaysTrue(), residual); + + // test float field + spec = PartitionSpec.builderFor(schema) + .identity("float") + .build(); + + resEval = ResidualEvaluator.of(spec, + notNaN("float"), true); + + residual = resEval.residualFor(Row.of(Float.NaN)); + Assert.assertEquals("Residual should be alwaysFalse", alwaysFalse(), residual); + + residual = 
resEval.residualFor(Row.of(3F)); + Assert.assertEquals("Residual should be alwaysTrue", alwaysTrue(), residual); + } + @Test public void testNotInTimestamp() { Schema schema = new Schema( diff --git a/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java b/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java index 1d40be14446c..b82d6eb663d5 100644 --- a/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java +++ b/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java @@ -47,9 +47,9 @@ import org.apache.iceberg.parquet.ParquetMetricsRowGroupFilter; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.DoubleType; import org.apache.iceberg.types.Types.FloatType; import org.apache.iceberg.types.Types.IntegerType; -import org.apache.iceberg.types.Types.LongType; import org.apache.iceberg.types.Types.StringType; import org.apache.orc.OrcFile; import org.apache.orc.Reader; @@ -73,12 +73,14 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.isNull; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.not; import static org.apache.iceberg.expressions.Expressions.notEqual; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.expressions.Expressions.or; import static 
org.apache.iceberg.expressions.Expressions.startsWith; @@ -106,14 +108,17 @@ public TestMetricsRowGroupFilter(String format) { required(1, "id", IntegerType.get()), optional(2, "no_stats_parquet", StringType.get()), required(3, "required", StringType.get()), - optional(4, "all_nulls", LongType.get()), + optional(4, "all_nulls", DoubleType.get()), optional(5, "some_nulls", StringType.get()), optional(6, "no_nulls", StringType.get()), optional(7, "struct_not_null", structFieldType), optional(9, "not_in_file", FloatType.get()), optional(10, "str", StringType.get()), optional(11, "map_not_null", - Types.MapType.ofRequired(12, 13, StringType.get(), IntegerType.get())) + Types.MapType.ofRequired(12, 13, StringType.get(), IntegerType.get())), + optional(14, "all_nans", DoubleType.get()), + optional(15, "some_nans", FloatType.get()), + optional(16, "no_nans", DoubleType.get()) ); private static final Types.StructType _structFieldType = @@ -123,11 +128,14 @@ public TestMetricsRowGroupFilter(String format) { required(1, "_id", IntegerType.get()), optional(2, "_no_stats_parquet", StringType.get()), required(3, "_required", StringType.get()), - optional(4, "_all_nulls", LongType.get()), + optional(4, "_all_nulls", DoubleType.get()), optional(5, "_some_nulls", StringType.get()), optional(6, "_no_nulls", StringType.get()), optional(7, "_struct_not_null", _structFieldType), - optional(10, "_str", StringType.get()) + optional(10, "_str", StringType.get()), + optional(14, "_all_nans", Types.DoubleType.get()), + optional(15, "_some_nans", FloatType.get()), + optional(16, "_no_nans", Types.DoubleType.get()) ); private static final String TOO_LONG_FOR_STATS_PARQUET; @@ -184,6 +192,9 @@ public void createOrcInputFile() throws IOException { record.setField("_some_nulls", (i % 10 == 0) ? 
null : "some"); // includes some null values record.setField("_no_nulls", ""); // optional, but always non-null record.setField("_str", i + "str" + i); + record.setField("_all_nans", Double.NaN); // never non-nan + record.setField("_some_nans", (i % 10 == 0) ? Float.NaN : 2F); // includes some nan values + record.setField("_no_nans", 3D); // optional, but always non-nan GenericRecord structNotNull = GenericRecord.create(_structFieldType); structNotNull.setField("_int_field", INT_MIN_VALUE + i); @@ -223,6 +234,9 @@ private void createParquetInputFile() throws IOException { builder.set("_all_nulls", null); // never non-null builder.set("_some_nulls", (i % 10 == 0) ? null : "some"); // includes some null values builder.set("_no_nulls", ""); // optional, but always non-null + builder.set("_all_nans", Double.NaN); // never non-nan + builder.set("_some_nans", (i % 10 == 0) ? Float.NaN : 2F); // includes some nan values + builder.set("_no_nans", 3D); // optional, but always non-nan builder.set("_str", i + "str" + i); Record structNotNull = new Record(structSchema); @@ -281,6 +295,37 @@ public void testNoNulls() { Assert.assertTrue("Should read: struct type is not skipped", shouldRead); } + @Test + public void testIsNaN() { + boolean shouldRead = shouldRead(isNaN("all_nans")); + Assert.assertTrue("Should read: NaN counts are not tracked in Parquet metrics", shouldRead); + + shouldRead = shouldRead(isNaN("some_nans")); + Assert.assertTrue("Should read: NaN counts are not tracked in Parquet metrics", shouldRead); + + shouldRead = shouldRead(isNaN("no_nans")); + Assert.assertTrue("Should read: NaN counts are not tracked in Parquet metrics", shouldRead); + + shouldRead = shouldRead(isNaN("all_nulls")); + Assert.assertFalse("Should skip: all null column will not contain nan value", shouldRead); + } + + @Test + public void testNotNaN() { + boolean shouldRead = shouldRead(notNaN("all_nans")); + Assert.assertTrue("Should read: NaN counts are not tracked in Parquet metrics", 
shouldRead); + + shouldRead = shouldRead(notNaN("some_nans")); + Assert.assertTrue("Should read: NaN counts are not tracked in Parquet metrics", shouldRead); + + shouldRead = shouldRead(notNaN("no_nans")); + Assert.assertTrue("Should read: NaN counts are not tracked in Parquet metrics", shouldRead); + + shouldRead = shouldRead(notNaN("all_nulls")); + Assert.assertTrue("Should read: NaN counts are not tracked in Parquet metrics", shouldRead); + + } + @Test public void testRequiredColumn() { boolean shouldRead = shouldRead(notNull("required")); diff --git a/orc/src/main/java/org/apache/iceberg/orc/ExpressionToSearchArgument.java b/orc/src/main/java/org/apache/iceberg/orc/ExpressionToSearchArgument.java index 00bc6ff2c505..4211cf09ccd5 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/ExpressionToSearchArgument.java +++ b/orc/src/main/java/org/apache/iceberg/orc/ExpressionToSearchArgument.java @@ -122,6 +122,37 @@ public Action notNull(Bound expr) { .end(); } + @Override + public Action isNaN(Bound expr) { + return () -> this.builder.equals( + idToColumnName.get(expr.ref().fieldId()), + type(expr.ref().type()), + literal(expr.ref().type(), getNaNForType(expr.ref().type()))); + } + + private Object getNaNForType(Type type) { + switch (type.typeId()) { + case FLOAT: + return Float.NaN; + case DOUBLE: + return Double.NaN; + default: + throw new IllegalArgumentException("Cannot get NaN value for type " + type.typeId()); + } + } + + @Override + public Action notNaN(Bound expr) { + return () -> { + this.builder.startOr(); + isNull(expr).invoke(); + this.builder.startNot(); + isNaN(expr).invoke(); + this.builder.end(); // end NOT + this.builder.end(); // end OR + }; + } + @Override public Action lt(Bound expr, Literal lit) { return () -> this.builder.lessThan(idToColumnName.get(expr.ref().fieldId()), diff --git a/orc/src/test/java/org/apache/iceberg/orc/TestExpressionToSearchArgument.java b/orc/src/test/java/org/apache/iceberg/orc/TestExpressionToSearchArgument.java 
index 8012050a8112..a7c77b111be4 100644 --- a/orc/src/test/java/org/apache/iceberg/orc/TestExpressionToSearchArgument.java +++ b/orc/src/test/java/org/apache/iceberg/orc/TestExpressionToSearchArgument.java @@ -51,11 +51,13 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.isNull; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.notEqual; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.types.Types.NestedField.optional; import static org.apache.iceberg.types.Types.NestedField.required; @@ -75,7 +77,9 @@ public void testPrimitiveTypes() { required(8, "time", Types.TimeType.get()), required(9, "tsTz", Types.TimestampType.withZone()), required(10, "ts", Types.TimestampType.withoutZone()), - required(11, "decimal", Types.DecimalType.of(38, 2)) + required(11, "decimal", Types.DecimalType.of(38, 2)), + required(12, "float2", Types.FloatType.get()), + required(13, "double2", Types.DoubleType.get()) ); Expression expr = and( @@ -86,7 +90,8 @@ public void testPrimitiveTypes() { and( and(equal("boolean", true), notEqual("string", "test")), and(in("decimal", BigDecimal.valueOf(-12345, 2), BigDecimal.valueOf(12345, 2)), notIn("time", 100L, 200L)) - ) + ), + and(isNaN("float2"), notNaN("double2")) ); Expression boundFilter = Binder.bind(schema.asStruct(), expr, true); SearchArgument expected = SearchArgumentFactory.newBuilder() @@ -99,6 +104,8 @@ public void testPrimitiveTypes() 
{ .startOr().isNull("`string`", Type.STRING).startNot().equals("`string`", Type.STRING, "test").end().end() .in("`decimal`", Type.DECIMAL, new HiveDecimalWritable("-123.45"), new HiveDecimalWritable("123.45")) .startOr().isNull("`time`", Type.LONG).startNot().in("`time`", Type.LONG, 100L, 200L).end().end() + .equals("`float2`", Type.FLOAT, Double.NaN) + .startOr().isNull("`double2`", Type.FLOAT).startNot().equals("`double2`", Type.FLOAT, Double.NaN).end().end() .end() .build(); @@ -178,17 +185,19 @@ public void testUnsupportedTypes() { public void testNestedPrimitives() { Schema schema = new Schema( optional(1, "struct", Types.StructType.of( - required(2, "long", Types.LongType.get()) + required(2, "long", Types.LongType.get()), + required(11, "float", Types.FloatType.get()) )), optional(3, "list", Types.ListType.ofRequired(4, Types.LongType.get())), - optional(5, "map", Types.MapType.ofRequired(6, 7, Types.LongType.get(), Types.LongType.get())), + optional(5, "map", Types.MapType.ofRequired(6, 7, Types.LongType.get(), Types.DoubleType.get())), optional(8, "listOfStruct", Types.ListType.ofRequired(9, Types.StructType.of( required(10, "long", Types.LongType.get())))) ); Expression expr = and( and(equal("struct.long", 1), equal("list.element", 2)), - and(equal("map.key", 3), equal("listOfStruct.long", 4)) + and(equal("map.key", 3), equal("listOfStruct.long", 4)), + and(isNaN("map.value"), notNaN("struct.float")) ); Expression boundFilter = Binder.bind(schema.asStruct(), expr, true); SearchArgument expected = SearchArgumentFactory.newBuilder() @@ -197,7 +206,13 @@ public void testNestedPrimitives() { .equals("`list`.`_elem`", Type.LONG, 2L) .equals("`map`.`_key`", Type.LONG, 3L) .equals("`listOfStruct`.`_elem`.`long`", Type.LONG, 4L) - .end() + .equals("`map`.`_value`", Type.FLOAT, Double.NaN) + .startOr() + .isNull("`struct`.`float`", Type.FLOAT) + .startNot().equals("`struct`.`float`", Type.FLOAT, Double.NaN) + .end() // not + .end() // or + .end() // and .build(); 
SearchArgument actual = ExpressionToSearchArgument.convert(boundFilter, ORCSchemaUtil.convert(schema)); diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetDictionaryRowGroupFilter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetDictionaryRowGroupFilter.java index 807bd3e6244e..d72cf49ee503 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetDictionaryRowGroupFilter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetDictionaryRowGroupFilter.java @@ -36,7 +36,9 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Comparators; import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.NaNUtil; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; import org.apache.parquet.column.page.DictionaryPage; @@ -148,6 +150,39 @@ public Boolean notNull(BoundReference ref) { return ROWS_MIGHT_MATCH; } + @Override + public Boolean isNaN(BoundReference ref) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, comparatorForNaNPredicate(ref)); + return dictionary.stream().anyMatch(NaNUtil::isNaN) ? ROWS_MIGHT_MATCH : ROWS_CANNOT_MATCH; + } + + @Override + public Boolean notNaN(BoundReference ref) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, comparatorForNaNPredicate(ref)); + return dictionary.stream().allMatch(NaNUtil::isNaN) ? 
ROWS_CANNOT_MATCH : ROWS_MIGHT_MATCH; + } + + private Comparator comparatorForNaNPredicate(BoundReference ref) { + // Construct the same comparator as in ComparableLiteral.comparator, ignoring null value order as dictionary + // cannot contain null values. + // No need to check type: incompatible types will be handled during expression binding. + return Comparators.forType(ref.type().asPrimitiveType()); + } + @Override public Boolean lt(BoundReference ref, Literal lit) { int id = ref.fieldId(); diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java index d3e3df3a86d5..fa9387535f59 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java @@ -181,6 +181,22 @@ FilterPredicate pred(Operation op, COL col, C value) { return FilterApi.eq(col, null); case NOT_NULL: return FilterApi.notEq(col, null); + case IS_NAN: + if (col.getColumnType().equals(Double.class)) { + return FilterApi.eq(col, (C) (Double) Double.NaN); + } else if (col.getColumnType().equals(Float.class)) { + return FilterApi.eq(col, (C) (Float) Float.NaN); + } else { + return AlwaysFalse.INSTANCE; + } + case NOT_NAN: + if (col.getColumnType().equals(Double.class)) { + return FilterApi.notEq(col, (C) (Double) Double.NaN); + } else if (col.getColumnType().equals(Float.class)) { + return FilterApi.notEq(col, (C) (Float) Float.NaN); + } else { + return AlwaysTrue.INSTANCE; + } case EQ: return FilterApi.eq(col, value); case NOT_EQ: diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java index e42f44026fce..ee026cc38100 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java +++ 
b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java @@ -176,6 +176,30 @@ public Boolean notNull(BoundReference ref) { return ROWS_MIGHT_MATCH; } + @Override + public Boolean isNaN(BoundReference ref) { + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && valueCount - colStats.getNumNulls() == 0) { + // (num nulls == value count) => all values are null => no nan values + return ROWS_CANNOT_MATCH; + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean notNaN(BoundReference ref) { + return ROWS_MIGHT_MATCH; + } + @Override public Boolean lt(BoundReference ref, Literal lit) { Integer id = ref.fieldId(); diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java index 894a8e46949f..a5e7a353093d 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java @@ -36,6 +36,7 @@ import org.apache.iceberg.io.InputFile; import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.DoubleType; import org.apache.iceberg.types.Types.FloatType; import org.apache.iceberg.types.Types.IntegerType; import org.apache.iceberg.types.Types.LongType; @@ -56,12 +57,14 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; import static org.apache.iceberg.expressions.Expressions.isNull; import static 
org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.not; import static org.apache.iceberg.expressions.Expressions.notEqual; import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNaN; import static org.apache.iceberg.expressions.Expressions.notNull; import static org.apache.iceberg.expressions.Expressions.or; import static org.apache.iceberg.expressions.Expressions.startsWith; @@ -82,7 +85,10 @@ public class TestDictionaryRowGroupFilter { optional(6, "no_nulls", StringType.get()), optional(7, "non_dict", StringType.get()), optional(8, "struct_not_null", structFieldType), - optional(10, "not_in_file", FloatType.get()) + optional(10, "not_in_file", FloatType.get()), + optional(11, "all_nans", DoubleType.get()), + optional(12, "some_nans", FloatType.get()), + optional(13, "no_nans", DoubleType.get()) ); private static final Types.StructType _structFieldType = @@ -96,7 +102,11 @@ public class TestDictionaryRowGroupFilter { optional(5, "_some_nulls", StringType.get()), optional(6, "_no_nulls", StringType.get()), optional(7, "_non_dict", StringType.get()), - optional(8, "_struct_not_null", _structFieldType) + optional(8, "_struct_not_null", _structFieldType), + optional(11, "_all_nans", DoubleType.get()), + optional(12, "_some_nans", FloatType.get()), + optional(13, "_no_nans", DoubleType.get()) + ); private static final String TOO_LONG_FOR_STATS; @@ -143,6 +153,9 @@ public void createInputFile() throws IOException { builder.set("_some_nulls", (i % 10 == 0) ? null : "some"); // includes some null values builder.set("_no_nulls", ""); // optional, but always non-null builder.set("_non_dict", UUID.randomUUID().toString()); // not dictionary-encoded + builder.set("_all_nans", Double.NaN); // never non-nan + builder.set("_some_nans", (i % 10 == 0) ? 
Float.NaN : 2F); // includes some nan values + builder.set("_no_nans", 3D); // optional, but always non-nan Record structNotNull = new Record(structSchema); structNotNull.put("_int_field", INT_MIN_VALUE + i); @@ -245,6 +258,36 @@ public void testRequiredColumn() { Assert.assertFalse("Should skip: required columns are always non-null", shouldRead); } + @Test + public void testIsNaNs() { + boolean shouldRead = new ParquetDictionaryRowGroupFilter(SCHEMA, isNaN("all_nans")) + .shouldRead(parquetSchema, rowGroupMetadata, dictionaryStore); + Assert.assertTrue("Should read: all_nans column will contain NaN", shouldRead); + + shouldRead = new ParquetDictionaryRowGroupFilter(SCHEMA, isNaN("some_nans")) + .shouldRead(parquetSchema, rowGroupMetadata, dictionaryStore); + Assert.assertTrue("Should read: some_nans column will contain NaN", shouldRead); + + shouldRead = new ParquetDictionaryRowGroupFilter(SCHEMA, isNaN("no_nans")) + .shouldRead(parquetSchema, rowGroupMetadata, dictionaryStore); + Assert.assertFalse("Should skip: no_nans column will not contain NaN", shouldRead); + } + + @Test + public void testNotNaNs() { + boolean shouldRead = new ParquetDictionaryRowGroupFilter(SCHEMA, notNaN("all_nans")) + .shouldRead(parquetSchema, rowGroupMetadata, dictionaryStore); + Assert.assertFalse("Should skip: all_nans column will not contain non-NaN", shouldRead); + + shouldRead = new ParquetDictionaryRowGroupFilter(SCHEMA, notNaN("some_nans")) + .shouldRead(parquetSchema, rowGroupMetadata, dictionaryStore); + Assert.assertTrue("Should read: some_nans column will contain non-NaN", shouldRead); + + shouldRead = new ParquetDictionaryRowGroupFilter(SCHEMA, notNaN("no_nans")) + .shouldRead(parquetSchema, rowGroupMetadata, dictionaryStore); + Assert.assertTrue("Should read: no_nans column will contain non-NaN", shouldRead); + } + @Test public void testStartsWith() { boolean shouldRead = new ParquetDictionaryRowGroupFilter(SCHEMA, startsWith("non_dict", "re")) diff --git 
a/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java index 10523d43991c..093deb5a9cfc 100644 --- a/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java +++ b/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java @@ -527,6 +527,10 @@ public String predicate(UnboundPredicate pred) { return pred.ref().name() + " IS NULL"; case NOT_NULL: return pred.ref().name() + " IS NOT NULL"; + case IS_NAN: + return pred.ref().name() + " = NaN"; + case NOT_NAN: + return pred.ref().name() + " != NaN"; case LT: return pred.ref().name() + " < " + sqlString(pred.literal()); case LT_EQ: