Merged
@@ -58,7 +58,7 @@ private Literals() {
   @SuppressWarnings("unchecked")
   static <T> Literal<T> from(T value) {
     Preconditions.checkNotNull(value, "Cannot create expression literal from null");
-    Preconditions.checkArgument(!NaNUtil.isNaN(value), "Cannot expression literal from NaN");
+    Preconditions.checkArgument(!NaNUtil.isNaN(value), "Cannot create expression literal from NaN");
 
     if (value instanceof Boolean) {
       return (Literal<T>) new Literals.BooleanLiteral((Boolean) value);
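The corrected message is what callers see when a predicate is built around a NaN literal, as the updated test below exercises. A minimal sketch of a failing call (column name illustrative):

    // Throws IllegalArgumentException: "Cannot create expression literal from NaN"
    Expressions.equal("doubleVal", Double.NaN);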
@@ -223,7 +223,7 @@ public void testInvalidateNaNInput() {
 
   private void assertInvalidateNaNThrows(Callable<UnboundPredicate<Double>> callable) {
     AssertHelpers.assertThrows("Should invalidate NaN input",
-        IllegalArgumentException.class, "Cannot expression literal from NaN",
+        IllegalArgumentException.class, "Cannot create expression literal from NaN",
         callable);
   }
 
@@ -34,11 +34,13 @@
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.util.DateTimeUtil;
+import org.apache.iceberg.util.NaNUtil;
 
 import static org.apache.iceberg.expressions.Expressions.and;
 import static org.apache.iceberg.expressions.Expressions.equal;
 import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
 import static org.apache.iceberg.expressions.Expressions.in;
+import static org.apache.iceberg.expressions.Expressions.isNaN;
 import static org.apache.iceberg.expressions.Expressions.isNull;
 import static org.apache.iceberg.expressions.Expressions.lessThan;
 import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
@@ -96,7 +98,8 @@ private static Expression translateLeaf(PredicateLeaf leaf) {
     String column = leaf.getColumnName();
     switch (leaf.getOperator()) {
       case EQUALS:
-        return equal(column, leafToLiteral(leaf));
+        Object literal = leafToLiteral(leaf);
+        return NaNUtil.isNaN(literal) ? isNaN(column) : equal(column, literal);
       case LESS_THAN:
         return lessThan(column, leafToLiteral(leaf));
       case LESS_THAN_EQUALS:
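The rewritten branch leans on NaNUtil.isNaN to detect a NaN literal of either floating-point width before choosing between isNaN(column) and equal(column, literal). A minimal sketch of such a check over boxed JDK types (the actual NaNUtil may differ in detail):

    // Sketch only: true for a Double or Float NaN, false for everything else.
    static boolean isNaN(Object value) {
      if (value instanceof Double) {
        return ((Double) value).isNaN();
      } else if (value instanceof Float) {
        return ((Float) value).isNaN();
      }
      return false; // non-floating-point literals can never be NaN
    }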
@@ -53,6 +53,17 @@ public void testEqualsOperand() {
     assertPredicatesMatch(expected, actual);
   }
 
+  @Test
+  public void testEqualsOperandRewrite() {
+    SearchArgument.Builder builder = SearchArgumentFactory.newBuilder();
+    SearchArgument arg = builder.startAnd().equals("float", PredicateLeaf.Type.FLOAT, Double.NaN).end().build();
+
+    UnboundPredicate expected = Expressions.isNaN("float");
+    UnboundPredicate actual = (UnboundPredicate) HiveIcebergFilterFactory.generateFilterExpression(arg);
+
+    assertPredicatesMatch(expected, actual);
+  }
+
   @Test
   public void testNotEqualsOperand() {
     SearchArgument.Builder builder = SearchArgumentFactory.newBuilder();
5 changes: 3 additions & 2 deletions pig/src/main/java/org/apache/iceberg/pig/IcebergStorage.java
@@ -42,6 +42,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.NaNUtil;
 import org.apache.pig.Expression;
 import org.apache.pig.Expression.BetweenExpression;
 import org.apache.pig.Expression.BinaryExpression;
@@ -234,8 +235,8 @@ private org.apache.iceberg.expressions.Expression convert(OpType op, Column col,
       case OP_GT: return Expressions.greaterThan(name, value);
       case OP_LE: return Expressions.lessThanOrEqual(name, value);
       case OP_LT: return Expressions.lessThan(name, value);
-      case OP_EQ: return Expressions.equal(name, value);
-      case OP_NE: return Expressions.notEqual(name, value);
+      case OP_EQ: return NaNUtil.isNaN(value) ? Expressions.isNaN(name) : Expressions.equal(name, value);
+      case OP_NE: return NaNUtil.isNaN(value) ? Expressions.notNaN(name) : Expressions.notEqual(name, value);
     }
 
     throw new RuntimeException(
14 changes: 12 additions & 2 deletions spark2/src/main/java/org/apache/iceberg/spark/SparkFilters.java
@@ -28,6 +28,7 @@
 import org.apache.iceberg.expressions.Expression.Operation;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.util.NaNUtil;
 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
 import org.apache.spark.sql.sources.And;
 import org.apache.spark.sql.sources.EqualNullSafe;
@@ -49,6 +50,7 @@
 import static org.apache.iceberg.expressions.Expressions.greaterThan;
 import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
 import static org.apache.iceberg.expressions.Expressions.in;
+import static org.apache.iceberg.expressions.Expressions.isNaN;
 import static org.apache.iceberg.expressions.Expressions.isNull;
 import static org.apache.iceberg.expressions.Expressions.lessThan;
 import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
@@ -113,13 +115,13 @@ public static Expression convert(Filter filter) {
         // comparison with null in normal equality is always null. this is probably a mistake.
         Preconditions.checkNotNull(eq.value(),
             "Expression is always false (eq is not null-safe): %s", filter);
-        return equal(eq.attribute(), convertLiteral(eq.value()));
+        return handleEqual(eq.attribute(), eq.value());
       } else {
         EqualNullSafe eq = (EqualNullSafe) filter;
         if (eq.value() == null) {
           return isNull(eq.attribute());
         } else {
-          return equal(eq.attribute(), convertLiteral(eq.value()));
+          return handleEqual(eq.attribute(), eq.value());
Contributor:
Why is this not directly inside Expressions.equal, so we can avoid duplication between Spark 2 and 3?

Contributor (Author):
I thought the conclusion we reached in this thread was to reject NaN in any predicate and let SparkFilters do the rewrites?

Contributor:
Yes, I agree. Rewriting filters should be done in translation to Iceberg so that we have simpler behavior and strong assumptions.
         }
       }
 
@@ -177,4 +179,12 @@ private static Object convertLiteral(Object value) {
     }
     return value;
   }
+
+  private static Expression handleEqual(String attribute, Object value) {
+    if (NaNUtil.isNaN(value)) {
+      return isNaN(attribute);
+    } else {
+      return equal(attribute, convertLiteral(value));
+    }
+  }
 }
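With handleEqual in place, an equality filter carrying a NaN literal converts to a NaN predicate instead of an equality that Literals.from would now reject. A hedged usage sketch (attribute name illustrative):

    // Yields the equivalent of Expressions.isNaN("doubleVal"); a non-NaN
    // value still goes through equal(attribute, convertLiteral(value)).
    Expression expr = SparkFilters.convert(
        new org.apache.spark.sql.sources.EqualTo("doubleVal", Double.NaN));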
244 changes: 244 additions & 0 deletions spark2/src/test/java/org/apache/iceberg/spark/source/TestSelect.java
@@ -0,0 +1,244 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.iceberg.spark.source;

import java.io.File;
import java.io.Serializable;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.events.Listeners;
import org.apache.iceberg.events.ScanEvent;
import org.apache.iceberg.expressions.And;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.relocated.com.google.common.base.Objects;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import static org.apache.iceberg.types.Types.NestedField.optional;

public class TestSelect {
  private static final HadoopTables TABLES = new HadoopTables(new Configuration());
  private static final Schema SCHEMA = new Schema(
      optional(1, "id", Types.IntegerType.get()),
      optional(2, "data", Types.StringType.get()),
      optional(3, "doubleVal", Types.DoubleType.get())
  );

  private static SparkSession spark;

  private static int scanEventCount = 0;
  private static ScanEvent lastScanEvent = null;

  private Table table;

  static {
    Listeners.register(event -> {
      scanEventCount += 1;
      lastScanEvent = event;
    }, ScanEvent.class);
  }

  @BeforeClass
  public static void startSpark() {
    spark = SparkSession.builder()
        .master("local[2]")
        .getOrCreate();
  }

  @AfterClass
  public static void stopSpark() {
    SparkSession currentSpark = spark;
    spark = null;
    currentSpark.stop();
  }

  @Rule
  public TemporaryFolder temp = new TemporaryFolder();

  private String tableLocation = null;

  @Before
  public void init() throws Exception {
    File tableDir = temp.newFolder();
    this.tableLocation = tableDir.toURI().toString();

    table = TABLES.create(SCHEMA, tableLocation);

    List<Record> rows = Lists.newArrayList(
        new Record(1, "a", 1.0),
        new Record(2, "b", 2.0),
        new Record(3, "c", Double.NaN)
    );

    Dataset<Row> df = spark.createDataFrame(rows, Record.class);

    df.select("id", "data", "doubleVal").write()
        .format("iceberg")
        .mode("append")
        .save(tableLocation);

    table.refresh();

    Dataset<Row> results = spark.read()
        .format("iceberg")
        .load(tableLocation);
    results.createOrReplaceTempView("table");

    scanEventCount = 0;
    lastScanEvent = null;
  }

  @Test
  public void testSelect() {
    List<Record> expected = ImmutableList.of(
        new Record(1, "a", 1.0), new Record(2, "b", 2.0), new Record(3, "c", Double.NaN));

    Assert.assertEquals("Should return all expected rows", expected,
        sql("select * from table", Encoders.bean(Record.class)));
  }

  @Test
  public void testSelectRewrite() {
    List<Record> expected = ImmutableList.of(new Record(3, "c", Double.NaN));

    Assert.assertEquals("Should return all expected rows", expected,
        sql("SELECT * FROM table where doubleVal = double('NaN')", Encoders.bean(Record.class)));
    Assert.assertEquals("Should create only one scan", 1, scanEventCount);
Contributor:
Shouldn't this validate more than just the number of scans?

Contributor (Author):
Yes, sorry, I forgot to revisit this after cleaning up the other changes. Since Spark 2 doesn't have Spark3Util.describe(), I wasn't sure at what level we should assert the expression so that we keep test coverage without coupling too tightly to the internal implementation. Let me know what you think of the updated test!

Contributor:
Looks good!

    Expression filter = lastScanEvent.filter();
    Assert.assertEquals("Should create AND expression", Expression.Operation.AND, filter.op());
    Expression left = ((And) filter).left();
    Expression right = ((And) filter).right();

    Assert.assertEquals("Left expression should be NOT_NULL",
        Expression.Operation.NOT_NULL, left.op());
    Assert.assertTrue("Left expression should contain column name 'doubleVal'",
        left.toString().contains("doubleVal"));
    Assert.assertEquals("Right expression should be IS_NAN",
        Expression.Operation.IS_NAN, right.op());
    Assert.assertTrue("Right expression should contain column name 'doubleVal'",
        right.toString().contains("doubleVal"));
  }

  @Test
  public void testProjection() {
    List<Integer> expected = ImmutableList.of(1, 2, 3);

    Assert.assertEquals("Should return all expected rows", expected, sql("SELECT id FROM table", Encoders.INT()));

    Assert.assertEquals("Should create only one scan", 1, scanEventCount);
    Assert.assertEquals("Should not push down a filter", Expressions.alwaysTrue(), lastScanEvent.filter());
    Assert.assertEquals("Should project only the id column",
        table.schema().select("id").asStruct(),
        lastScanEvent.projection().asStruct());
  }

  @Test
  public void testExpressionPushdown() {
    List<String> expected = ImmutableList.of("b");

    Assert.assertEquals("Should return all expected rows", expected,
        sql("SELECT data FROM table WHERE id = 2", Encoders.STRING()));

    Assert.assertEquals("Should create only one scan", 1, scanEventCount);
    Assert.assertEquals("Should project only id and data columns",
        table.schema().select("id", "data").asStruct(),
        lastScanEvent.projection().asStruct());
  }

  private <T> List<T> sql(String str, Encoder<T> encoder) {
    return spark.sql(str).as(encoder).collectAsList();
  }

  public static class Record implements Serializable {
    private Integer id;
    private String data;
    private Double doubleVal;

    public Record() {
    }

    Record(Integer id, String data, Double doubleVal) {
      this.id = id;
      this.data = data;
      this.doubleVal = doubleVal;
    }

    public void setId(Integer id) {
      this.id = id;
    }

    public void setData(String data) {
      this.data = data;
    }

    public void setDoubleVal(Double doubleVal) {
      this.doubleVal = doubleVal;
    }

    public Integer getId() {
      return id;
    }

    public String getData() {
      return data;
    }

    public Double getDoubleVal() {
      return doubleVal;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }

      Record record = (Record) o;
      return Objects.equal(id, record.id) && Objects.equal(data, record.data) &&
          Objects.equal(doubleVal, record.doubleVal);
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(id, data, doubleVal);
    }
  }
}
4 changes: 2 additions & 2 deletions spark3/src/main/java/org/apache/iceberg/spark/Spark3Util.java
@@ -533,9 +533,9 @@ public <T> String predicate(UnboundPredicate<T> pred) {
       case NOT_NULL:
         return pred.ref().name() + " IS NOT NULL";
       case IS_NAN:
-        return pred.ref().name() + " = NaN";
+        return "is_nan(" + pred.ref().name() + ")";
       case NOT_NAN:
-        return pred.ref().name() + " != NaN";
+        return "not_nan(" + pred.ref().name() + ")";
       case LT:
         return pred.ref().name() + " < " + sqlString(pred.literal());
       case LT_EQ:
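For illustration, assuming Spark3Util.describe(Expression) is the public entry point that reaches this describer (as the review thread above suggests), the new rendering would be:

    // Assumed entry point; output strings follow the change above.
    Spark3Util.describe(Expressions.isNaN("col"));   // "is_nan(col)"
    Spark3Util.describe(Expressions.notNaN("col"));  // "not_nan(col)"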