Commit 5743c64

Davies Liu authored and committed
[SPARK-12981] [SQL] extract Python UDF in physical plan
## What changes were proposed in this pull request?

Currently we extract Python UDFs into a special logical plan node, EvaluatePython, in the analyzer. But EvaluatePython is not part of Catalyst, so many rules have no knowledge of it, which breaks many things (for example, filter pushdown and column pruning). We should instead treat Python UDFs as normal expressions until we want to evaluate them in the physical plan; we could extract them at the end of the optimizer, or during physical planning. This PR extracts Python UDFs in the physical plan.

Closes #10935

## How was this patch tested?

Added regression tests.

Author: Davies Liu <davies@databricks.com>

Closes #12127 from davies/py_udf.
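
For context, a minimal PySpark sketch of the affected pattern (it mirrors the regression test added in python/pyspark/sql/tests.py below; an existing SQLContext named sqlCtx is assumed):

# Sketch of the pattern this patch fixes (assumes an existing SQLContext
# `sqlCtx`, mirroring the regression test added below).
from pyspark.sql import Row
from pyspark.sql.functions import udf, col
from pyspark.sql.types import BooleanType

df = sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
my_filter = udf(lambda a: a == 1, BooleanType())

# A Python UDF used only to filter the output of distinct(): before this
# patch, extracting the UDF during analysis confused rules such as column
# pruning, because EvaluatePython was opaque to Catalyst.
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
assert sel.collect() == [Row(key=1)]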
1 parent 855ed44 commit 5743c64

File tree: 8 files changed (+64, -70 lines)


python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions
@@ -343,6 +343,15 @@ def test_broadcast_in_udf(self):
         [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
+    def test_udf_with_aggregate_function(self):
+        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a == 1, BooleanType())
+        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+        self.assertEqual(sel.collect(), [Row(key=1)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.sqlCtx.read.json(rdd)

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+    python.ExtractPythonUDFs,
     PlanSubqueries(sqlContext),
     EnsureRequirements(sqlContext.conf),
     CollapseCodegenStages(sqlContext.conf),
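
With ExtractPythonUDFs registered as the first preparation rule, the UDF is rewritten during physical planning into a BatchPythonEvaluation operator. A quick way to observe the effect (a sketch reusing df and my_filter from the earlier example; the exact plan output is not reproduced here):

# Sketch: inspect the physical plan (reuses `df` and `my_filter` from the
# earlier example). The Python UDF evaluation now shows up as a physical
# operator rather than as an EvaluatePython logical node.
df.select(col("key")).distinct().filter(my_filter(col("key"))).explain()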

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 0 additions & 2 deletions
@@ -392,8 +392,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ python.EvaluatePython(udfs, child, _) =>
-        python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 0 additions & 23 deletions
@@ -35,30 +35,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-/**
- * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
-    udfs: Seq[PythonUDF],
-    child: LogicalPlan,
-    resultAttribute: Seq[AttributeReference])
-  extends logical.UnaryNode {
-
-  def output: Seq[Attribute] = child.output ++ resultAttribute
-
-  // References should not include the produced attribute.
-  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
-}
-
-
 object EvaluatePython {
-  def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
-    val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
-      AttributeReference(s"pythonUDF$i", u.dataType)()
-    }
-    new EvaluatePython(udfs, child, resultAttrs)
-  }
-
   def takeAndServe(df: DataFrame, n: Int): Int = {
     registerPicklers()
     df.withNewExecutionId {

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 53 additions & 41 deletions
@@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
  * multiple child operators.
  */
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
 
   private def hasPythonUDF(e: Expression): Boolean = {
     e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -54,49 +54,61 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     case e => e.children.flatMap(collectEvaluatableUDF)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    // Skip EvaluatePython nodes.
-    case plan: EvaluatePython => plan
+  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+    case plan: SparkPlan => extract(plan)
+  }
 
-    case plan: LogicalPlan if plan.resolved =>
-      // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
-      if (udfs.isEmpty) {
-        // If there aren't any, we are done.
-        plan
-      } else {
-        val attributeMap = mutable.HashMap[PythonUDF, Expression]()
-        // Rewrite the child that has the input required for the UDF
-        val newChildren = plan.children.map { child =>
-          // Pick the UDF we are going to evaluate
-          val validUdfs = udfs.filter { case udf =>
-            // Check to make sure that the UDF can be evaluated with only the input of this child.
-            udf.references.subsetOf(child.outputSet)
-          }
-          if (validUdfs.nonEmpty) {
-            val evaluation = EvaluatePython(validUdfs, child)
-            attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
-            evaluation
-          } else {
-            child
-          }
+  /**
+   * Extract all the PythonUDFs from the current operator.
+   */
+  def extract(plan: SparkPlan): SparkPlan = {
+    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    if (udfs.isEmpty) {
+      // If there aren't any, we are done.
+      plan
+    } else {
+      val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+      // Rewrite the child that has the input required for the UDF
+      val newChildren = plan.children.map { child =>
+        // Pick the UDF we are going to evaluate
+        val validUdfs = udfs.filter { case udf =>
+          // Check to make sure that the UDF can be evaluated with only the input of this child.
+          udf.references.subsetOf(child.outputSet)
        }
-        // Other cases are disallowed as they are ambiguous or would require a cartesian
-        // product.
-        udfs.filterNot(attributeMap.contains).foreach { udf =>
-          if (udf.references.subsetOf(plan.inputSet)) {
-            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-          } else {
-            sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+        if (validUdfs.nonEmpty) {
+          val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+            AttributeReference(s"pythonUDF$i", u.dataType)()
          }
+          val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child)
+          attributeMap ++= validUdfs.zip(resultAttrs)
+          evaluation
+        } else {
+          child
+        }
+      }
+      // Other cases are disallowed as they are ambiguous or would require a cartesian
+      // product.
+      udfs.filterNot(attributeMap.contains).foreach { udf =>
+        if (udf.references.subsetOf(plan.inputSet)) {
+          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+        } else {
+          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
        }
+      }
+
+      val rewritten = plan.transformExpressions {
+        case p: PythonUDF if attributeMap.contains(p) =>
+          attributeMap(p)
+      }.withNewChildren(newChildren)
 
+      // extract remaining python UDFs recursively
+      val newPlan = extract(rewritten)
+      if (newPlan.output != plan.output) {
        // Trim away the new UDF value if it was only used for filtering or something.
-        logical.Project(
-          plan.output,
-          plan.transformExpressions {
-            case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
-          }.withNewChildren(newChildren))
+        execution.Project(plan.output, newPlan)
+      } else {
+        newPlan
      }
+    }
  }
}
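
One detail worth noting in extract(): when the rewritten plan's output gains pythonUDF* attributes that the original operator did not declare, the rule wraps it in a Project that restores the original output. So a UDF used only for filtering does not leak its temporary result column into the plan's output. A small check of that behavior from the Python side (a sketch, reusing sel from the first example):

# Sketch: the trailing Project trims the temporary pythonUDF0 attribute, so
# the filtered DataFrame keeps its original single-column schema (reuses
# `sel` from the first example).
assert sel.columns == ["key"]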

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala

Lines changed: 1 addition & 2 deletions
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType
 
@@ -30,7 +29,7 @@ case class PythonUDF(
     func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
-  extends Expression with Unevaluable with NonSQLExpression with Logging {
+  extends Expression with Unevaluable with NonSQLExpression {
 
   override def toString: String = s"$name(${children.mkString(", ")})"

sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala

Lines changed: 0 additions & 1 deletion
@@ -64,7 +64,6 @@ private[sql] class SessionState(ctx: SQLContext) {
   lazy val analyzer: Analyzer = {
     new Analyzer(catalog, functionRegistry, conf) {
       override val extendedResolutionRules =
-        python.ExtractPythonUDFs ::
         PreInsertCastAndRename ::
         DataSourceAnalysis ::
         (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala

Lines changed: 0 additions & 1 deletion
@@ -60,7 +60,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
       catalog.OrcConversions ::
       catalog.CreateTables ::
       catalog.PreInsertionCasts ::
-      python.ExtractPythonUDFs ::
       PreInsertCastAndRename ::
       DataSourceAnalysis ::
       (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
