diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2cc27d82f7d2..586c7fbed639 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -450,13 +450,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty =>
       d.copy(child = prunedChild(child, d.references))
 
-    // Prunes the unused columns from child of Aggregate/Expand/Generate
+    // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
     case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty =>
       f.copy(child = prunedChild(child, f.references))
     case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
+    case s @ ScriptTransformation(_, _, _, child, _)
+        if (child.outputSet -- s.references).nonEmpty =>
+      s.copy(child = prunedChild(child, s.references))
 
     // prune unrequired references
     case p @ Project(_, g: Generate) if p.references != g.outputSet =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 8b05ba32e6ee..f6db3c90ad96 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -140,6 +140,30 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized, expected)
   }
 
+  test("Column pruning for ScriptTransformation") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val query =
+      ScriptTransformation(
+        Seq('a, 'b),
+        "func",
+        Seq.empty,
+        input,
+        null).analyze
+    val optimized = Optimize.execute(query)
+
+    val expected =
+      ScriptTransformation(
+        Seq('a, 'b),
+        "func",
+        Seq.empty,
+        Project(
+          Seq('a, 'b),
+          input),
+        null).analyze
+
+    comparePlans(optimized, expected)
+  }
+
   test("Column pruning on Filter") {
     val input = LocalRelation('a.int, 'b.string, 'c.double)
     val plan1 = Filter('a > 1, input).analyze
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 5318b4650b01..5f73b7170c61 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -136,6 +136,25 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     }
     assert(e.getMessage.contains("Subprocess exited with status"))
   }
+
+  test("SPARK-24339 verify the result after pruning the unused columns") {
+    val rowsDf = Seq(
+      ("Bob", 16, 176),
+      ("Alice", 32, 164),
+      ("David", 60, 192),
+      ("Amy", 24, 180)).toDF("name", "age", "height")
+
+    checkAnswer(
+      rowsDf,
+      (child: SparkPlan) => new ScriptTransformationExec(
+        input = Seq(rowsDf.col("name").expr),
+        script = "cat",
+        output = Seq(AttributeReference("name", StringType)()),
+        child = child,
+        ioschema = serdeIOSchema
+      ),
+      rowsDf.select("name").collect())
+  }
 }
 
 private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
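
For reviewers, a minimal end-to-end sketch of the query shape this rule targets (a hypothetical demo, not part of the patch; it assumes a Hive-enabled SparkSession, since TRANSFORM ... USING requires Hive script transformation support):

import org.apache.spark.sql.SparkSession

object TransformPruningDemo {
  def main(args: Array[String]): Unit = {
    // Hive support is required for script transformation (TRANSFORM ... USING).
    val spark = SparkSession.builder()
      .appName("TransformPruningDemo")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    Seq(("Bob", 16, 176), ("Alice", 32, 164))
      .toDF("name", "age", "height")
      .createOrReplaceTempView("people")

    // Only "name" is referenced by the script. With the new ColumnPruning
    // case, the optimized plan places a Project(name) under the
    // ScriptTransformation, so "age" and "height" are never serialized
    // and piped to the external 'cat' subprocess.
    val transformed = spark.sql(
      "SELECT TRANSFORM (name) USING 'cat' AS (name STRING) FROM people")
    transformed.explain(extended = true)
    transformed.show()
  }
}

Before this change, the child of ScriptTransformation was left unpruned, so every column of the input relation was fed to the external process even when the script consumed only a subset; the new case mirrors the existing Aggregate/Expand/Generate handling.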