@@ -516,7 +516,8 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
object ColumnPruning extends Rule[LogicalPlan] {
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
output1.size == output2.size &&
output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
output1.zip(output2).forall(pair =>
pair._1.semanticEquals(pair._2) && pair._1.name == pair._2.name)

def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
// Prunes the unused columns from project list of Project/Aggregate/Expand
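Note: with this tightened sameOutput check, ColumnPruning only treats a Project as redundant when the child's output matches both semantically and by exact attribute name, so a projection that merely re-cases a column (e.g. id vs. ID) is no longer pruned away. A minimal standalone sketch of the idea, using a simplified stand-in for Catalyst's Attribute (illustrative only, not the real class):

// Simplified stand-in for Catalyst's Attribute: semantic equality by exprId only.
case class Attr(name: String, exprId: Long) {
  def semanticEquals(other: Attr): Boolean = exprId == other.exprId
}

def sameOutput(output1: Seq[Attr], output2: Seq[Attr]): Boolean =
  output1.size == output2.size &&
    output1.zip(output2).forall { case (a1, a2) =>
      a1.semanticEquals(a2) && a1.name == a2.name
    }

// "ID" and "id" refer to the same attribute semantically but differ in name,
// so the re-casing Project is kept and the user-specified case survives.
assert(!sameOutput(Seq(Attr("ID", 1L)), Seq(Attr("id", 1L))))
assert(sameOutput(Seq(Attr("ID", 1L)), Seq(Attr("ID", 1L))))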
@@ -680,7 +681,12 @@ object CollapseProject extends Rule[LogicalPlan] {
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
// Use transformUp to prevent infinite recursion.
val rewrittenUpper = upper.map(_.transformUp {
case a: Attribute => aliases.getOrElse(a, a)
case a: Attribute => if (aliases.contains(a)) {
val alias = aliases.get(a).get
alias.copy(name = a.name)(alias.exprId, alias.qualifier, alias.explicitMetadata)
} else {
a
}
})
// collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
rewrittenUpper.map { p =>
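Note: when CollapseProject inlines a lower alias into the upper projection, copying the alias with the upper attribute's name preserves the user-facing column case (e.g. SELECT ID FROM (SELECT id ...) still produces a column named ID). A rough sketch of that substitution with simplified stand-in types (not the actual Catalyst Attribute/Alias):

sealed trait NamedExpr { def name: String; def exprId: Long }
case class Attr(name: String, exprId: Long) extends NamedExpr
case class Alias(child: String, name: String, exprId: Long) extends NamedExpr

// Mirrors alias.copy(name = a.name)(...): substitute the attribute with the
// aliased expression from the lower Project, but keep the attribute's own name.
def substitute(a: Attr, aliases: Map[Long, Alias]): NamedExpr =
  aliases.get(a.exprId) match {
    case Some(alias) => alias.copy(name = a.name)
    case None        => a
  }

val aliases = Map(1L -> Alias(child = "a + b", name = "id", exprId = 1L))
// The collapsed projection keeps the upper-case name "ID" rather than "id".
assert(substitute(Attr("ID", 1L), aliases) == Alias("a + b", "ID", 1L))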
76 changes: 76 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
@@ -2853,6 +2854,81 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("Insert overwrite table command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Insert overwrite table command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " +
"BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " +
"FROM view1 CLUSTER BY COL3")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " +
"CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("SPARK-25144 'distinct' causes memory leak") {
val ds = List(Foo(Some("bar"))).toDS
val result = ds.flatMap(_.bar).distinct