diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 6530b176968f..2d823552dba2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -45,7 +45,8 @@ object PythonUDF { } /** - * A serialized version of a Python lambda function. + * A serialized version of a Python lambda function. This is a special expression, which needs a + * dedicated physical operator to execute it, and thus can't be pushed down to data sources. */ case class PythonUDF( name: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f32f2c7986dc..5b59ac7d2a9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1143,6 +1143,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case _: Repartition => true case _: ScriptTransformation => true case _: Sort => true + case _: BatchEvalPython => true + case _: ArrowEvalPython => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 254687ec0088..2df30a1a53ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} /** * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame. @@ -38,3 +38,30 @@ case class FlatMapGroupsInPandas( */ override val producedAttributes = AttributeSet(output) } + +trait BaseEvalPython extends UnaryNode { + + def udfs: Seq[PythonUDF] + + def resultAttrs: Seq[Attribute] + + override def output: Seq[Attribute] = child.output ++ resultAttrs + + override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references)) +} + +/** + * A logical plan that evaluates a [[PythonUDF]] + */ +case class BatchEvalPython( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: LogicalPlan) extends BaseEvalPython + +/** + * A logical plan that evaluates a [[PythonUDF]] with Apache Arrow. 
+ */ +case class ArrowEvalPython( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: LogicalPlan) extends BaseEvalPython diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 31540e81d125..c35e5de88636 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog -import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, Optimizer, PushDownPredicate, RemoveNoopOperators} import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.SchemaPruning import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs} @@ -32,14 +32,21 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDFs", Once, - Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+ + ExtractPythonUDFFromAggregate, + ExtractPythonUDFs, + // The eval-python node may be between Project/Filter and the scan node, which breaks + // column pruning and filter push-down. Here we rerun the related optimizer rules. + ColumnPruning, + PushDownPredicate, + RemoveNoopOperators) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("Schema Pruning", Once, SchemaPruning)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) - override def nonExcludableRules: Seq[String] = - super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName + override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+ + ExtractPythonUDFFromAggregate.ruleName :+ + ExtractPythonUDFs.ruleName /** * Optimization batches that are executed before the regular optimization batches (also before diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 61b167f50fd6..c114b1c87e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -23,7 +23,6 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType @@ -57,21 +56,11 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int) } } -/** - * A logical plan that evaluates a [[PythonUDF]]. 
- */ -case class ArrowEvalPython( - udfs: Seq[PythonUDF], - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) -} - /** * A physical plan that evaluates a [[PythonUDF]]. */ -case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) - extends EvalPythonExec(udfs, output, child) { +case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan) + extends EvalPythonExec(udfs, resultAttrs, child) { private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index eff709ef7f72..4f352782067c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -25,25 +25,14 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{StructField, StructType} -/** - * A logical plan that evaluates a [[PythonUDF]] - */ -case class BatchEvalPython( - udfs: Seq[PythonUDF], - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) -} - /** * A physical plan that evaluates a [[PythonUDF]] */ -case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) - extends EvalPythonExec(udfs, output, child) { +case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan) + extends EvalPythonExec(udfs, resultAttrs, child) { protected override def evaluate( funcs: Seq[ChainedPythonFunctions], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 67dcdd3732b4..10e1d5e00648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -57,10 +57,12 @@ import org.apache.spark.util.Utils * there should be always some rows buffered in the socket or Python process, so the pulling from * RowQueue ALWAYS happened after pushing into it. 
*/ -abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) +abstract class EvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan) extends UnaryExecNode { - override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) + override def output: Seq[Attribute] = child.output ++ resultAttrs + + override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references)) private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 380c31baa621..7f59d74b0cd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -158,21 +158,9 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { // If there aren't any, we are done. plan } else { - val inputsForPlan = plan.references ++ plan.outputSet - val prunedChildren = plan.children.map { child => - val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq - if (allNeededOutput.length != child.output.length) { - Project(allNeededOutput, child) - } else { - child - } - } - val planWithNewChildren = plan.withNewChildren(prunedChildren) - val attributeMap = mutable.HashMap[PythonUDF, Expression]() - val splitFilter = trySplitFilter(planWithNewChildren) // Rewrite the child that has the input required for the UDF - val newChildren = splitFilter.children.map { child => + val newChildren = plan.children.map { child => // Pick the UDF we are going to evaluate val validUdfs = udfs.filter { udf => // Check to make sure that the UDF can be evaluated with only the input of this child. @@ -191,9 +179,9 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => - ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child) + ArrowEvalPython(vectorizedUdfs, resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => - BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child) + BatchEvalPython(plainUdfs, resultAttrs, child) case _ => throw new AnalysisException( "Expected either Scalar Pandas UDFs or Batched UDFs but got both") @@ -211,7 +199,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") } - val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions { + val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF if attributeMap.contains(p) => attributeMap(p) } @@ -226,22 +214,4 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { } } } - - // Split the original FilterExec to two FilterExecs. Only push down the first few predicates - // that are all deterministic. 
- private def trySplitFilter(plan: LogicalPlan): LogicalPlan = { - plan match { - case filter: Filter => - val (candidates, nonDeterministic) = - splitConjunctivePredicates(filter.condition).partition(_.deterministic) - val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) - if (pushDown.nonEmpty) { - val newChild = Filter(pushDown.reduceLeft(And), filter.child) - Filter((rest ++ nonDeterministic).reduceLeft(And), newChild) - } else { - filter - } - case o => o - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala index 76b609d111ac..60cb40f16c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest} import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SharedSQLContext class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.newProductEncoder - import testImplicits.localSeqToDatasetHolder + import testImplicits._ val batchedPythonUDF = new MyDummyPythonUDF val scalarPandasUDF = new MyDummyScalarPandasUDF @@ -88,5 +87,40 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext { assert(pythonEvalNodes.size == 2) assert(arrowEvalNodes.size == 2) } + + test("Python UDF should not break column pruning/filter pushdown") { + withTempPath { f => + spark.range(10).select($"id".as("a"), $"id".as("b")) + .write.parquet(f.getCanonicalPath) + val df = spark.read.parquet(f.getCanonicalPath) + + withClue("column pruning") { + val query = df.filter(batchedPythonUDF($"a")).select($"a") + + val pythonEvalNodes = collectBatchExec(query.queryExecution.executedPlan) + assert(pythonEvalNodes.length == 1) + + val scanNodes = query.queryExecution.executedPlan.collect { + case scan: FileSourceScanExec => scan + } + assert(scanNodes.length == 1) + assert(scanNodes.head.output.map(_.name) == Seq("a")) + } + + withClue("filter pushdown") { + val query = df.filter($"a" > 1 && batchedPythonUDF($"a")) + val pythonEvalNodes = collectBatchExec(query.queryExecution.executedPlan) + assert(pythonEvalNodes.length == 1) + + val scanNodes = query.queryExecution.executedPlan.collect { + case scan: FileSourceScanExec => scan + } + assert(scanNodes.length == 1) + // 'a is not null and 'a > 1 + assert(scanNodes.head.dataFilters.length == 2) + assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a")) + } + } + } }
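
Note on the design: the patch works because the new BaseEvalPython trait reports its output as child.output ++ resultAttrs while its references cover only the UDF inputs, so the ordinary ColumnPruning and PushDownPredicate rules (re-run after ExtractPythonUDFs in SparkOptimizer) can reason about the eval-python node like any other unary operator, instead of relying on the removed trySplitFilter/manual-pruning logic. The following self-contained Scala sketch is a toy model only; ToyPlan, Scan, EvalPython and Filter are hypothetical stand-ins, not Spark's Catalyst classes. It illustrates why those two overrides are enough for pruning and pushdown to see through the node.

// Toy model of the BaseEvalPython idea; ToyPlan, Scan, EvalPython and Filter
// are hypothetical stand-ins, not Spark's Catalyst classes.
object EvalPythonSketch {

  sealed trait ToyPlan {
    def output: Seq[String]      // columns the node produces
    def references: Set[String]  // columns the node actually reads
  }

  // Leaf node standing in for a file scan.
  case class Scan(columns: Seq[String]) extends ToyPlan {
    def output: Seq[String] = columns
    def references: Set[String] = Set.empty
  }

  // Mirrors BaseEvalPython: the node appends the UDF result columns to the
  // child's output, but only *references* the UDF inputs.
  case class EvalPython(
      udfInputs: Set[String],
      resultAttrs: Seq[String],
      child: ToyPlan) extends ToyPlan {
    def output: Seq[String] = child.output ++ resultAttrs
    def references: Set[String] = udfInputs
  }

  // A filter whose condition reads the given columns.
  case class Filter(condition: Set[String], child: ToyPlan) extends ToyPlan {
    def output: Seq[String] = child.output
    def references: Set[String] = condition
  }

  def main(args: Array[String]): Unit = {
    val scan = Scan(Seq("a", "b"))
    // Roughly: SELECT a, udf(a) FROM t WHERE a > 1 -- column b is never used.
    val eval = EvalPython(udfInputs = Set("a"), resultAttrs = Seq("udf_a"), scan)
    val plan = Filter(condition = Set("a"), eval)

    // Column pruning: the filter and the UDF together only read "a" from the
    // scan, so a pruning rule may drop "b" from the scan's output.
    val neededFromScan = plan.references ++ eval.references
    println(s"columns needed from the scan: ${scan.output.filter(neededFromScan)}")

    // Predicate pushdown: the filter condition only touches columns that the
    // scan already produces (not the UDF result), so it can be evaluated
    // below EvalPython and eventually reach the data source.
    val pushable = plan.condition.subsetOf(scan.output.toSet)
    println(s"filter can be pushed below EvalPython: $pushable")
  }
}

Running the object prints that only column "a" is needed from the scan and that the filter condition can be evaluated below the eval-python node, which is the same behavior the new "Python UDF should not break column pruning/filter pushdown" test asserts against FileSourceScanExec.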