diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c15899cb230e..bb57582684f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -338,6 +338,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Add where. val withFilter = relation.optionalMap(where)(filter) + // Add project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + val withProject = if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + // Create the attributes. val (attributes, schemaLess) = if (colTypeList != null) { // Typed return columns. @@ -358,7 +369,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging expressions, string(script), attributes, - withFilter, + withProject, withScriptIOSchema( ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 43ce093f8a7d..ea8bc72d618c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -265,11 +265,17 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle "func", Seq.empty, plans.table("e"), null) comparePlans(plan1, - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + p.copy( + child = p.child.where('f < 10).select(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + output = Seq('key.string, 'value.string))) comparePlans(plan2, - p.copy(output = Seq('c.string, 'd.string))) + p.copy( + child = p.child.select(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + output = Seq('c.string, 'd.string))) comparePlans(plan3, - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + p.copy( + child = p.child.select(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + output = Seq('c.int, 'd.decimal(10, 0)))) } test("use backticks in output of Script Transform") {