diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 63a62cd0cbfe..208688efa21e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -516,7 +516,8 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
 object ColumnPruning extends Rule[LogicalPlan] {
   private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
     output1.size == output2.size &&
-      output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+      output1.zip(output2).forall(pair =>
+        pair._1.semanticEquals(pair._2) && pair._1.name == pair._2.name)
 
   def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
     // Prunes the unused columns from project list of Project/Aggregate/Expand
@@ -680,7 +681,12 @@ object CollapseProject extends Rule[LogicalPlan] {
     // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
     // Use transformUp to prevent infinite recursion.
     val rewrittenUpper = upper.map(_.transformUp {
-      case a: Attribute => aliases.getOrElse(a, a)
+      case a: Attribute => if (aliases.contains(a)) {
+        val alias = aliases.get(a).get
+        alias.copy(name = a.name)(alias.exprId, alias.qualifier, alias.explicitMetadata)
+      } else {
+        a
+      }
     })
     // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
     rewrittenUpper.map { p =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 01dc28d70184..4f734c79e0f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
@@ -2853,6 +2854,81 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("Insert overwrite table command should output correct schema: basic") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).toDF("id")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
+        spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
+        spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Insert overwrite table command should output correct schema: complex") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
+        spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " +
+          "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS")
+        spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " +
+          "FROM view1 CLUSTER BY COL3")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(
+          StructField("COL1", LongType, true),
+          StructField("COL3", IntegerType, true),
+          StructField("COL2", IntegerType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Create table as select command should output correct schema: basic") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).toDF("id")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
+        spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Create table as select command should output correct schema: complex") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
+        spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " +
+          "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(
+          StructField("COL1", LongType, true),
+          StructField("COL3", IntegerType, true),
+          StructField("COL2", IntegerType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
   test("SPARK-25144 'distinct' causes memory leak") {
     val ds = List(Foo(Some("bar"))).toDS
     val result = ds.flatMap(_.bar).distinct