Skip to content

Commit

Permalink
API, Spark: Fix aggregation pushdown on struct fields (apache#9176)
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh-jahagirdar authored Jan 31, 2024
1 parent 26d62c0 commit 9de693f
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ public int size() {
@Override
@SuppressWarnings("unchecked")
public <T> T get(int pos, Class<T> javaClass) {
return (T) value;
if (javaClass.isAssignableFrom(StructLike.class)) {
return (T) this;
} else {
return (T) value;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.apache.iceberg.CatalogUtil;
Expand All @@ -36,6 +37,7 @@
import org.apache.iceberg.spark.CatalogTestBase;
import org.apache.iceberg.spark.TestBase;
import org.apache.spark.sql.SparkSession;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
Expand Down Expand Up @@ -478,6 +480,126 @@ public void testAggregateWithComplexType() {
.isFalse();
}

@TestTemplate
public void testAggregationPushdownStructInteger() {
sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:BIGINT>) USING iceberg", tableName);
sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName);
sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName);

String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
String aggField = "struct_with_int.c1";
assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
assertExplainContains(
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
"count(struct_with_int.c1)",
"max(struct_with_int.c1)",
"min(struct_with_int.c1)");
}

@TestTemplate
public void testAggregationPushdownNestedStruct() {
sql(
"CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:STRUCT<c2:STRUCT<c3:STRUCT<c4:BIGINT>>>>) USING iceberg",
tableName);
sql(
"INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))",
tableName);
sql(
"INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))",
tableName);
sql(
"INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))",
tableName);

String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
String aggField = "struct_with_int.c1.c2.c3.c4";

assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);

assertExplainContains(
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
"count(struct_with_int.c1.c2.c3.c4)",
"max(struct_with_int.c1.c2.c3.c4)",
"min(struct_with_int.c1.c2.c3.c4)");
}

@TestTemplate
public void testAggregationPushdownStructTimestamp() {
sql(
"CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1:TIMESTAMP>) USING iceberg",
tableName);
sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
sql(
"INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", timestamp('2023-01-30T22:22:22Z')))",
tableName);
sql(
"INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", timestamp('2023-01-30T22:23:23Z')))",
tableName);

String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
String aggField = "struct_with_ts.c1";

assertAggregates(
sql(query, aggField, aggField, aggField, tableName),
2L,
new Timestamp(1675117403000L),
new Timestamp(1675117342000L));

assertExplainContains(
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
"count(struct_with_ts.c1)",
"max(struct_with_ts.c1)",
"min(struct_with_ts.c1)");
}

@TestTemplate
public void testAggregationPushdownOnBucketedColumn() {
sql(
"CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:INT>) USING iceberg PARTITIONED BY (bucket(8, id))",
tableName);

sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", tableName);
sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName);

String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
String aggField = "id";
assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 2L, 1L);
assertExplainContains(
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
"count(id)",
"max(id)",
"min(id)");
}

private void assertAggregates(
List<Object[]> actual, Object expectedCount, Object expectedMax, Object expectedMin) {
Object actualCount = actual.get(0)[0];
Object actualMax = actual.get(0)[1];
Object actualMin = actual.get(0)[2];

Assertions.assertThat(actualCount)
.as("Expected and actual count should equal")
.isEqualTo(expectedCount);
Assertions.assertThat(actualMax)
.as("Expected and actual max should equal")
.isEqualTo(expectedMax);
Assertions.assertThat(actualMin)
.as("Expected and actual min should equal")
.isEqualTo(expectedMin);
}

private void assertExplainContains(List<Object[]> explain, String... expectedFragments) {
String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
Arrays.stream(expectedFragments)
.forEach(
fragment ->
Assertions.assertThat(explainString.contains(fragment))
.isTrue()
.as("Expected to find plan fragment in explain plan"));
}

@TestTemplate
public void testAggregatePushDownInDeleteCopyOnWrite() {
sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
Expand Down

0 comments on commit 9de693f

Please sign in to comment.