From f94fdf7fd74a75c777b5b38ce970e0742d00091c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 1 Sep 2018 23:17:04 +0800 Subject: [PATCH 1/2] Fix ColumnPruning and CollapseProject on eliminating Project --- .../sql/catalyst/optimizer/Optimizer.scala | 16 ++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 76 +++++++++++++++++++ 2 files changed, 85 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 63a62cd0cbfe..5c51f9614469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -515,8 +515,7 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper */ object ColumnPruning extends Rule[LogicalPlan] { private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + output1.size == output2.size && output1.zip(output2).forall(pair => pair._1 == pair._2) def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand @@ -649,9 +648,12 @@ object CollapseProject extends Rule[LogicalPlan] { } } - private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = { - AttributeMap(projectList.collect { - case a: Alias => a.toAttribute -> a + private def collectAliases( + upper: Seq[NamedExpression], lower: Seq[NamedExpression]): AttributeMap[Alias] = { + AttributeMap(lower.zipWithIndex.collect { + case (a: Alias, index: Int) => + a.toAttribute -> + a.copy(name = upper(index).name)(a.exprId, a.qualifier, a.explicitMetadata) }) } @@ -659,7 +661,7 @@ object CollapseProject extends Rule[LogicalPlan] { upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { // Create a map of Aliases to their values from the lower projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliases = collectAliases(lower) + val aliases = collectAliases(upper, lower) // Collapse upper and lower Projects if and only if their overlapped expressions are all // deterministic. @@ -673,7 +675,7 @@ object CollapseProject extends Rule[LogicalPlan] { lower: Seq[NamedExpression]): Seq[NamedExpression] = { // Create a map of Aliases to their values from the lower projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliases = collectAliases(lower) + val aliases = collectAliases(upper, lower) // Substitute any attributes that are produced by the lower projection, so that we safely // eliminate it. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 01dc28d70184..4f734c79e0f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} @@ -2853,6 +2854,81 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("Insert overwrite table command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2(ID long) USING parquet") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Insert overwrite table command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + + "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " + + "FROM view1 CLUSTER BY COL3") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " + + "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + test("SPARK-25144 'distinct' causes memory leak") { val ds = List(Foo(Some("bar"))).toDS val result = ds.flatMap(_.bar).distinct From 96b4cca5addab62f9e97c1983484ea1e4a3d5059 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Sep 2018 14:25:33 +0800 Subject: [PATCH 2/2] fix test failure --- .../sql/catalyst/optimizer/Optimizer.scala | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5c51f9614469..208688efa21e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -515,7 +515,9 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper */ object ColumnPruning extends Rule[LogicalPlan] { private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && output1.zip(output2).forall(pair => pair._1 == pair._2) + output1.size == output2.size && + output1.zip(output2).forall(pair => + pair._1.semanticEquals(pair._2) && pair._1.name == pair._2.name) def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand @@ -648,12 +650,9 @@ object CollapseProject extends Rule[LogicalPlan] { } } - private def collectAliases( - upper: Seq[NamedExpression], lower: Seq[NamedExpression]): AttributeMap[Alias] = { - AttributeMap(lower.zipWithIndex.collect { - case (a: Alias, index: Int) => - a.toAttribute -> - a.copy(name = upper(index).name)(a.exprId, a.qualifier, a.explicitMetadata) + private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = { + AttributeMap(projectList.collect { + case a: Alias => a.toAttribute -> a }) } @@ -661,7 +660,7 @@ object CollapseProject extends Rule[LogicalPlan] { upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { // Create a map of Aliases to their values from the lower projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliases = collectAliases(upper, lower) + val aliases = collectAliases(lower) // Collapse upper and lower Projects if and only if their overlapped expressions are all // deterministic. @@ -675,14 +674,19 @@ object CollapseProject extends Rule[LogicalPlan] { lower: Seq[NamedExpression]): Seq[NamedExpression] = { // Create a map of Aliases to their values from the lower projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliases = collectAliases(upper, lower) + val aliases = collectAliases(lower) // Substitute any attributes that are produced by the lower projection, so that we safely // eliminate it. // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' // Use transformUp to prevent infinite recursion. val rewrittenUpper = upper.map(_.transformUp { - case a: Attribute => aliases.getOrElse(a, a) + case a: Attribute => if (aliases.contains(a)) { + val alias = aliases.get(a).get + alias.copy(name = a.name)(alias.exprId, alias.qualifier, alias.explicitMetadata) + } else { + a + } }) // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. rewrittenUpper.map { p =>