diff --git a/api/src/main/java/org/apache/iceberg/expressions/Literals.java b/api/src/main/java/org/apache/iceberg/expressions/Literals.java
index 8db51128102e..44c8b8be86d0 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/Literals.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/Literals.java
@@ -58,7 +58,7 @@ private Literals() {
   @SuppressWarnings("unchecked")
   static <T> Literal<T> from(T value) {
     Preconditions.checkNotNull(value, "Cannot create expression literal from null");
-    Preconditions.checkArgument(!NaNUtil.isNaN(value), "Cannot expression literal from NaN");
+    Preconditions.checkArgument(!NaNUtil.isNaN(value), "Cannot create expression literal from NaN");
 
     if (value instanceof Boolean) {
       return (Literal<T>) new Literals.BooleanLiteral((Boolean) value);
diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java
index 5625e4a056ab..3bb34072ffc6 100644
--- a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java
+++ b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionHelpers.java
@@ -223,7 +223,7 @@ public void testInvalidateNaNInput() {
 
   private void assertInvalidateNaNThrows(Callable<UnboundPredicate<Double>> callable) {
     AssertHelpers.assertThrows("Should invalidate NaN input",
-        IllegalArgumentException.class, "Cannot expression literal from NaN",
+        IllegalArgumentException.class, "Cannot create expression literal from NaN",
         callable);
   }
 
diff --git a/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergFilterFactory.java b/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergFilterFactory.java
index d4239ef3e332..33791c719bbe 100644
--- a/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergFilterFactory.java
+++ b/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergFilterFactory.java
@@ -34,11 +34,13 @@
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.util.DateTimeUtil;
+import org.apache.iceberg.util.NaNUtil;
 
 import static org.apache.iceberg.expressions.Expressions.and;
 import static org.apache.iceberg.expressions.Expressions.equal;
 import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
 import static org.apache.iceberg.expressions.Expressions.in;
+import static org.apache.iceberg.expressions.Expressions.isNaN;
 import static org.apache.iceberg.expressions.Expressions.isNull;
 import static org.apache.iceberg.expressions.Expressions.lessThan;
 import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
@@ -96,7 +98,8 @@ private static Expression translateLeaf(PredicateLeaf leaf) {
     String column = leaf.getColumnName();
     switch (leaf.getOperator()) {
       case EQUALS:
-        return equal(column, leafToLiteral(leaf));
+        Object literal = leafToLiteral(leaf);
+        return NaNUtil.isNaN(literal) ? isNaN(column) : equal(column, literal);
       case LESS_THAN:
         return lessThan(column, leafToLiteral(leaf));
       case LESS_THAN_EQUALS:
diff --git a/mr/src/test/java/org/apache/iceberg/mr/hive/TestHiveIcebergFilterFactory.java b/mr/src/test/java/org/apache/iceberg/mr/hive/TestHiveIcebergFilterFactory.java
index 2d6b632756d0..3d436dcbe323 100644
--- a/mr/src/test/java/org/apache/iceberg/mr/hive/TestHiveIcebergFilterFactory.java
+++ b/mr/src/test/java/org/apache/iceberg/mr/hive/TestHiveIcebergFilterFactory.java
@@ -53,6 +53,17 @@ public void testEqualsOperand() {
     assertPredicatesMatch(expected, actual);
   }
 
+  @Test
+  public void testEqualsOperandRewrite() {
+    SearchArgument.Builder builder = SearchArgumentFactory.newBuilder();
+    SearchArgument arg = builder.startAnd().equals("float", PredicateLeaf.Type.FLOAT, Double.NaN).end().build();
+
+    UnboundPredicate expected = Expressions.isNaN("float");
+    UnboundPredicate actual = (UnboundPredicate) HiveIcebergFilterFactory.generateFilterExpression(arg);
+
+    assertPredicatesMatch(expected, actual);
+  }
+
   @Test
   public void testNotEqualsOperand() {
     SearchArgument.Builder builder = SearchArgumentFactory.newBuilder();
diff --git a/pig/src/main/java/org/apache/iceberg/pig/IcebergStorage.java b/pig/src/main/java/org/apache/iceberg/pig/IcebergStorage.java
index 6a41f8fe0248..1daa33afedf6 100644
--- a/pig/src/main/java/org/apache/iceberg/pig/IcebergStorage.java
+++ b/pig/src/main/java/org/apache/iceberg/pig/IcebergStorage.java
@@ -42,6 +42,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.NaNUtil;
 import org.apache.pig.Expression;
 import org.apache.pig.Expression.BetweenExpression;
 import org.apache.pig.Expression.BinaryExpression;
@@ -234,8 +235,8 @@ private org.apache.iceberg.expressions.Expression convert(OpType op, Column col,
       case OP_GT: return Expressions.greaterThan(name, value);
       case OP_LE: return Expressions.lessThanOrEqual(name, value);
       case OP_LT: return Expressions.lessThan(name, value);
-      case OP_EQ: return Expressions.equal(name, value);
-      case OP_NE: return Expressions.notEqual(name, value);
+      case OP_EQ: return NaNUtil.isNaN(value) ? Expressions.isNaN(name) : Expressions.equal(name, value);
+      case OP_NE: return NaNUtil.isNaN(value) ? Expressions.notNaN(name) : Expressions.notEqual(name, value);
     }
 
     throw new RuntimeException(
diff --git a/spark2/src/main/java/org/apache/iceberg/spark/SparkFilters.java b/spark2/src/main/java/org/apache/iceberg/spark/SparkFilters.java
index 0e8bb67d7b7b..0703688b9773 100644
--- a/spark2/src/main/java/org/apache/iceberg/spark/SparkFilters.java
+++ b/spark2/src/main/java/org/apache/iceberg/spark/SparkFilters.java
@@ -28,6 +28,7 @@
 import org.apache.iceberg.expressions.Expression.Operation;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.util.NaNUtil;
 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
 import org.apache.spark.sql.sources.And;
 import org.apache.spark.sql.sources.EqualNullSafe;
@@ -49,6 +50,7 @@
 import static org.apache.iceberg.expressions.Expressions.greaterThan;
 import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
 import static org.apache.iceberg.expressions.Expressions.in;
+import static org.apache.iceberg.expressions.Expressions.isNaN;
 import static org.apache.iceberg.expressions.Expressions.isNull;
 import static org.apache.iceberg.expressions.Expressions.lessThan;
 import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
@@ -113,13 +115,13 @@ public static Expression convert(Filter filter) {
             // comparison with null in normal equality is always null. this is probably a mistake.
             Preconditions.checkNotNull(eq.value(),
                 "Expression is always false (eq is not null-safe): %s", filter);
-            return equal(eq.attribute(), convertLiteral(eq.value()));
+            return handleEqual(eq.attribute(), eq.value());
           } else {
             EqualNullSafe eq = (EqualNullSafe) filter;
             if (eq.value() == null) {
               return isNull(eq.attribute());
             } else {
-              return equal(eq.attribute(), convertLiteral(eq.value()));
+              return handleEqual(eq.attribute(), eq.value());
             }
           }
 
@@ -177,4 +179,12 @@ private static Object convertLiteral(Object value) {
     }
     return value;
   }
+
+  private static Expression handleEqual(String attribute, Object value) {
+    if (NaNUtil.isNaN(value)) {
+      return isNaN(attribute);
+    } else {
+      return equal(attribute, convertLiteral(value));
+    }
+  }
 }
diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSelect.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSelect.java
new file mode 100644
index 000000000000..5df767ecac64
--- /dev/null
+++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSelect.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.source;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.events.Listeners;
+import org.apache.iceberg.events.ScanEvent;
+import org.apache.iceberg.expressions.And;
+import org.apache.iceberg.expressions.Expression;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.hadoop.HadoopTables;
+import org.apache.iceberg.relocated.com.google.common.base.Objects;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import static org.apache.iceberg.types.Types.NestedField.optional;
+
+public class TestSelect {
+  private static final HadoopTables TABLES = new HadoopTables(new Configuration());
+  private static final Schema SCHEMA = new Schema(
+      optional(1, "id", Types.IntegerType.get()),
+      optional(2, "data", Types.StringType.get()),
+      optional(3, "doubleVal", Types.DoubleType.get())
+  );
+
+  private static SparkSession spark;
+
+  private static int scanEventCount = 0;
+  private static ScanEvent lastScanEvent = null;
+
+  private Table table;
+
+  static {
+    Listeners.register(event -> {
+      scanEventCount += 1;
+      lastScanEvent = event;
+    }, ScanEvent.class);
+  }
+
+  @BeforeClass
+  public static void startSpark() {
+    spark = SparkSession.builder()
+        .master("local[2]")
+        .getOrCreate();
+  }
+
+  @AfterClass
+  public static void stopSpark() {
+    SparkSession currentSpark = spark;
+    spark = null;
+    currentSpark.stop();
+  }
+
+  @Rule
+  public TemporaryFolder temp = new TemporaryFolder();
+
+  private String tableLocation = null;
+
+  @Before
+  public void init() throws Exception {
+    File tableDir = temp.newFolder();
+    this.tableLocation = tableDir.toURI().toString();
+
+    table = TABLES.create(SCHEMA, tableLocation);
+
+    List<Record> rows = Lists.newArrayList(
+        new Record(1, "a", 1.0),
+        new Record(2, "b", 2.0),
+        new Record(3, "c", Double.NaN)
+    );
+
+    Dataset<Row> df = spark.createDataFrame(rows, Record.class);
+
+    df.select("id", "data", "doubleVal").write()
+        .format("iceberg")
+        .mode("append")
+        .save(tableLocation);
+
+    table.refresh();
+
+    Dataset<Row> results = spark.read()
+        .format("iceberg")
+        .load(tableLocation);
+    results.createOrReplaceTempView("table");
+
+    scanEventCount = 0;
+    lastScanEvent = null;
+  }
+
+  @Test
+  public void testSelect() {
+    List<Record> expected = ImmutableList.of(
+        new Record(1, "a", 1.0), new Record(2, "b", 2.0), new Record(3, "c", Double.NaN));
+
+    Assert.assertEquals("Should return all expected rows", expected,
+        sql("select * from table", Encoders.bean(Record.class)));
+  }
+
+  @Test
+  public void testSelectRewrite() {
+    List<Record> expected = ImmutableList.of(new Record(3, "c", Double.NaN));
+
+    Assert.assertEquals("Should return all expected rows", expected,
+        sql("SELECT * FROM table where doubleVal = double('NaN')", Encoders.bean(Record.class)));
+    Assert.assertEquals("Should create only one scan", 1, scanEventCount);
+
+    Expression filter = lastScanEvent.filter();
+    Assert.assertEquals("Should create AND expression", Expression.Operation.AND, filter.op());
+    Expression left = ((And) filter).left();
+    Expression right = ((And) filter).right();
+
+    Assert.assertEquals("Left expression should be NOT_NULL",
+        Expression.Operation.NOT_NULL, left.op());
+    Assert.assertTrue("Left expression should contain column name 'doubleVal'",
+        left.toString().contains("doubleVal"));
+    Assert.assertEquals("Right expression should be IS_NAN",
+        Expression.Operation.IS_NAN, right.op());
+    Assert.assertTrue("Right expression should contain column name 'doubleVal'",
+        right.toString().contains("doubleVal"));
+  }
+
+  @Test
+  public void testProjection() {
+    List<Integer> expected = ImmutableList.of(1, 2, 3);
+
+    Assert.assertEquals("Should return all expected rows", expected, sql("SELECT id FROM table", Encoders.INT()));
+
+    Assert.assertEquals("Should create only one scan", 1, scanEventCount);
+    Assert.assertEquals("Should not push down a filter", Expressions.alwaysTrue(), lastScanEvent.filter());
+    Assert.assertEquals("Should project only the id column",
+        table.schema().select("id").asStruct(),
+        lastScanEvent.projection().asStruct());
+  }
+
+  @Test
+  public void testExpressionPushdown() {
+    List<String> expected = ImmutableList.of("b");
+
+    Assert.assertEquals("Should return all expected rows", expected,
+        sql("SELECT data FROM table WHERE id = 2", Encoders.STRING()));
+
+    Assert.assertEquals("Should create only one scan", 1, scanEventCount);
+    Assert.assertEquals("Should project only id and data columns",
+        table.schema().select("id", "data").asStruct(),
+        lastScanEvent.projection().asStruct());
+  }
+
+  private <T> List<T> sql(String str, Encoder<T> encoder) {
+    return spark.sql(str).as(encoder).collectAsList();
+  }
+
+  public static class Record implements Serializable {
+    private Integer id;
+    private String data;
+    private Double doubleVal;
+
+    public Record() {
+    }
+
+    Record(Integer id, String data, Double doubleVal) {
+      this.id = id;
+      this.data = data;
+      this.doubleVal = doubleVal;
+    }
+
+    public void setId(Integer id) {
+      this.id = id;
+    }
+
+    public void setData(String data) {
+      this.data = data;
+    }
+
+    public void setDoubleVal(Double doubleVal) {
+      this.doubleVal = doubleVal;
+    }
+
+    public Integer getId() {
+      return id;
+    }
+
+    public String getData() {
+      return data;
+    }
+
+    public Double getDoubleVal() {
+      return doubleVal;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      Record record = (Record) o;
+      return Objects.equal(id, record.id) && Objects.equal(data, record.data) &&
+          Objects.equal(doubleVal, record.doubleVal);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(id, data, doubleVal);
+    }
+  }
+}
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java
index 82f0de33b5a6..62045be29bb3 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java
@@ -533,9 +533,9 @@ public <T> String predicate(UnboundPredicate<T> pred) {
         case NOT_NULL:
           return pred.ref().name() + " IS NOT NULL";
         case IS_NAN:
-          return pred.ref().name() + " = NaN";
+          return "is_nan(" + pred.ref().name() + ")";
         case NOT_NAN:
-          return pred.ref().name() + " != NaN";
+          return "not_nan(" + pred.ref().name() + ")";
         case LT:
           return pred.ref().name() + " < " + sqlString(pred.literal());
         case LT_EQ:
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/SparkFilters.java b/spark3/src/main/java/org/apache/iceberg/spark/SparkFilters.java
index 95e5a101f890..c4fb57906774 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/SparkFilters.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/SparkFilters.java
@@ -30,6 +30,7 @@
 import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.util.NaNUtil;
 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
 import org.apache.spark.sql.sources.AlwaysFalse$;
 import org.apache.spark.sql.sources.AlwaysTrue$;
@@ -53,6 +54,7 @@
 import static org.apache.iceberg.expressions.Expressions.greaterThan;
 import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
 import static org.apache.iceberg.expressions.Expressions.in;
+import static org.apache.iceberg.expressions.Expressions.isNaN;
 import static org.apache.iceberg.expressions.Expressions.isNull;
 import static org.apache.iceberg.expressions.Expressions.lessThan;
 import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
@@ -135,13 +137,13 @@ public static Expression convert(Filter filter) {
             // comparison with null in normal equality is always null. this is probably a mistake.
             Preconditions.checkNotNull(eq.value(),
                 "Expression is always false (eq is not null-safe): %s", filter);
-            return equal(eq.attribute(), convertLiteral(eq.value()));
+            return handleEqual(eq.attribute(), eq.value());
           } else {
             EqualNullSafe eq = (EqualNullSafe) filter;
             if (eq.value() == null) {
               return isNull(eq.attribute());
             } else {
-              return equal(eq.attribute(), convertLiteral(eq.value()));
+              return handleEqual(eq.attribute(), eq.value());
             }
           }
 
@@ -199,4 +201,12 @@ private static Object convertLiteral(Object value) {
     }
     return value;
   }
+
+  private static Expression handleEqual(String attribute, Object value) {
+    if (NaNUtil.isNaN(value)) {
+      return isNaN(attribute);
+    } else {
+      return equal(attribute, convertLiteral(value));
+    }
+  }
 }
diff --git a/spark3/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java b/spark3/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java
index 51da0735f394..846e234cba07 100644
--- a/spark3/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java
+++ b/spark3/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java
@@ -49,8 +49,8 @@ public TestSelect(String catalogName, String implementation, Map
 
   @Before
   public void createTables() {
-    sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName);
-    sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName);
+    sql("CREATE TABLE %s (id bigint, data string, float float) USING iceberg", tableName);
+    sql("INSERT INTO %s VALUES (1, 'a', 1.0), (2, 'b', 2.0), (3, 'c', float('NaN'))", tableName);
 
     this.scanEventCount = 0;
     this.lastScanEvent = null;
@@ -63,11 +63,25 @@ public void removeTables() {
 
   @Test
   public void testSelect() {
-    List<Object[]> expected = ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"));
+    List<Object[]> expected = ImmutableList.of(
+        row(1L, "a", 1.0F), row(2L, "b", 2.0F), row(3L, "c", Float.NaN));
 
     assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName));
   }
 
+  @Test
+  public void testSelectRewrite() {
+    List<Object[]> expected = ImmutableList.of(row(3L, "c", Float.NaN));
+
+    assertEquals("Should return all expected rows", expected,
+        sql("SELECT * FROM %s where float = float('NaN')", tableName));
+
+    Assert.assertEquals("Should create only one scan", 1, scanEventCount);
+    Assert.assertEquals("Should push down expected filter",
+        "(float IS NOT NULL AND is_nan(float))",
+        Spark3Util.describe(lastScanEvent.filter()));
+  }
+
   @Test
   public void testProjection() {
     List<Object[]> expected = ImmutableList.of(row(1L), row(2L), row(3L));
@@ -88,11 +102,11 @@ public void testExpressionPushdown() {
     assertEquals("Should return all expected rows", expected, sql("SELECT data FROM %s WHERE id = 2", tableName));
 
     Assert.assertEquals("Should create only one scan", 1, scanEventCount);
-    Assert.assertEquals("Should not push down a filter",
+    Assert.assertEquals("Should push down expected filter",
         "(id IS NOT NULL AND id = 2)",
         Spark3Util.describe(lastScanEvent.filter()));
-    Assert.assertEquals("Should project only the id column",
-        validationCatalog.loadTable(tableIdent).schema().asStruct(),
+    Assert.assertEquals("Should project only id and data columns",
+        validationCatalog.loadTable(tableIdent).schema().select("id", "data").asStruct(),
         lastScanEvent.projection().asStruct());
   }