Commit 69112a5

Refactor EvalPythonExec.
1 parent d49a3db commit 69112a5

File tree

3 files changed (+230, -229 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 30 additions & 95 deletions
@@ -17,110 +17,45 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.File
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
-import org.apache.spark.util.Utils
-
+import org.apache.spark.sql.types.StructType
 
 /**
  * A physical plan that evaluates a [[PythonUDF]],
  */
 case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
-  extends SparkPlan {
-
-  def children: Seq[SparkPlan] = child :: Nil
-
-  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
-
-  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
-    udf.children match {
-      case Seq(u: PythonUDF) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
-      case children =>
-        // There should not be any other UDFs, or the children can't be evaluated directly.
-        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
-    }
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val inputRDD = child.execute().map(_.copy())
-    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
-    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-
-    inputRDD.mapPartitions { iter =>
-      val context = TaskContext.get()
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = HybridRowQueue(context.taskMemoryManager(),
-        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
-      context.addTaskCompletionListener { _ =>
-        queue.close()
-      }
-
-      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-
-      // flatten all the arguments
-      val allInputs = new ArrayBuffer[Expression]
-      val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
-        input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
-          } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
-          }
-        }.toArray
-      }.toArray
-      val projection = newMutableProjection(allInputs, child.output)
-      val schemaIn = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
-        StructField(s"_$i", dt)
-      })
-
-      // Iterator to construct Arrow payloads. Add rows to queue to join later with the result.
-      val projectedRowIter = iter.map { inputRow =>
-        queue.add(inputRow.asInstanceOf[UnsafeRow])
-        projection(inputRow)
-      }
-
-      val inputIterator = ArrowConverters.toPayloadIterator(
-        projectedRowIter, schemaIn, conf.arrowMaxRecordsPerBatch, context)
-        .map(_.asPythonSerializable)
-
-      // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(
-        pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
-        .compute(inputIterator, context.partitionId(), context)
-
-      val outputRowIterator = ArrowConverters.fromPayloadIterator(
-        outputIterator.map(new ArrowPayload(_)), context)
-
-      // Verify that the output schema is correct
-      val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
-        .map { case (attr, i) => attr.withName(s"_$i") })
-      assert(schemaOut.equals(outputRowIterator.schema),
-        s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
-
-      val joined = new JoinedRow
-      val resultProj = UnsafeProjection.create(output, output)
-
-      outputRowIterator.map { outputRow =>
-        resultProj(joined(queue.remove(), outputRow))
-      }
+  extends EvalPythonExec(udfs, output, child) {
+
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      bufferSize: Int,
+      reuseWorker: Boolean,
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+    val inputIterator = ArrowConverters.toPayloadIterator(
+      iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable)
+
+    // Output iterator for results from Python.
+    val outputIterator = new PythonRunner(
+      funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
+      .compute(inputIterator, context.partitionId(), context)
+
+    val outputRowIterator = ArrowConverters.fromPayloadIterator(
+      outputIterator.map(new ArrowPayload(_)), context)
+
+    // Verify that the output schema is correct
+    val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
+      .map { case (attr, i) => attr.withName(s"_$i") })
+    assert(schemaOut.equals(outputRowIterator.schema),
+      s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+
+    outputRowIterator
   }
 }
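
Note: the per-partition plumbing deleted above (the HybridRowQueue that buffers input rows, the collectFunctions chaining of nested PythonUDFs, the argument flattening that builds argOffsets, and the final join of Python output with the buffered original rows) is presumably hoisted into the new EvalPythonExec parent that both operators now extend; that third changed file is not included in this view. A hedged sketch of it follows the BatchEvalPythonExec diff below.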

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

Lines changed: 58 additions & 134 deletions
@@ -17,154 +17,78 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.File
-
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
-import org.apache.spark.util.Utils
-
+import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
- * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
- *
- * Python evaluation works by sending the necessary (projected) input data via a socket to an
- * external Python process, and combine the result from the Python process with the original row.
- *
- * For each row we send to Python, we also put it in a queue first. For each output row from Python,
- * we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory.
- *
- * Here is a diagram to show how this works:
- *
- *            Downstream (for parent)
- *             /      \
- *            /     socket  (output of UDF)
- *           /         \
- *        RowQueue    Python
- *           \         /
- *            \     socket  (input of UDF)
- *             \     /
- *          upstream (from child)
- *
- * The rows sent to and received from Python are packed into batches (100 rows) and serialized,
- * there should be always some rows buffered in the socket or Python process, so the pulling from
- * RowQueue ALWAYS happened after pushing into it.
+ * A physical plan that evaluates a [[PythonUDF]]
  */
 case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
-  extends SparkPlan {
-
-  def children: Seq[SparkPlan] = child :: Nil
-
-  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
-
-  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
-    udf.children match {
-      case Seq(u: PythonUDF) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
-      case children =>
-        // There should not be any other UDFs, or the children can't be evaluated directly.
-        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
-    }
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val inputRDD = child.execute().map(_.copy())
-    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
-    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-
-    inputRDD.mapPartitions { iter =>
-      EvaluatePython.registerPicklers()  // register pickler for Row
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
-        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
-      TaskContext.get().addTaskCompletionListener({ ctx =>
-        queue.close()
-      })
-
-      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-
-      // flatten all the arguments
-      val allInputs = new ArrayBuffer[Expression]
-      val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
-        input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
-          } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
-          }
-        }.toArray
-      }.toArray
-      val projection = newMutableProjection(allInputs, child.output)
-      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
-      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
-
-      // enable memo iff we serialize the row with schema (schema and class should be memorized)
-      val pickle = new Pickler(needConversion)
-      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
-      // For each row, add it to the queue.
-      val inputIterator = iter.map { inputRow =>
-        queue.add(inputRow.asInstanceOf[UnsafeRow])
-        val row = projection(inputRow)
-        if (needConversion) {
-          EvaluatePython.toJava(row, schema)
-        } else {
-          // fast path for these types that does not need conversion in Python
-          val fields = new Array[Any](row.numFields)
-          var i = 0
-          while (i < row.numFields) {
-            val dt = dataTypes(i)
-            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-            i += 1
-          }
-          fields
-        }
-      }.grouped(100).map(x => pickle.dumps(x.toArray))
-
-      val context = TaskContext.get()
-
-      // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(
-        pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
-        .compute(inputIterator, context.partitionId(), context)
-
-      val unpickle = new Unpickler
-      val mutableRow = new GenericInternalRow(1)
-      val joined = new JoinedRow
-      val resultType = if (udfs.length == 1) {
-        udfs.head.dataType
+  extends EvalPythonExec(udfs, output, child) {
+
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      bufferSize: Int,
+      reuseWorker: Boolean,
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+    EvaluatePython.registerPicklers()  // register pickler for Row
+
+    val dataTypes = schema.map(_.dataType)
+    val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+
+    // enable memo iff we serialize the row with schema (schema and class should be memorized)
+    val pickle = new Pickler(needConversion)
+    // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+    // For each row, add it to the queue.
+    val inputIterator = iter.map { row =>
+      if (needConversion) {
+        EvaluatePython.toJava(row, schema)
       } else {
-        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
-      }
-      val resultProj = UnsafeProjection.create(output, output)
-      outputIterator.flatMap { pickedResult =>
-        val unpickledBatch = unpickle.loads(pickedResult)
-        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
-      }.map { result =>
-        val row = if (udfs.length == 1) {
-          // fast path for single UDF
-          mutableRow(0) = EvaluatePython.fromJava(result, resultType)
-          mutableRow
-        } else {
-          EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        // fast path for these types that does not need conversion in Python
+        val fields = new Array[Any](row.numFields)
+        var i = 0
+        while (i < row.numFields) {
+          val dt = dataTypes(i)
+          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+          i += 1
         }
-        resultProj(joined(queue.remove(), row))
+        fields
+      }
+    }.grouped(100).map(x => pickle.dumps(x.toArray))
+
+    // Output iterator for results from Python.
+    val outputIterator = new PythonRunner(
+      funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+      .compute(inputIterator, context.partitionId(), context)
+
+    val unpickle = new Unpickler
+    val mutableRow = new GenericInternalRow(1)
+    val resultType = if (udfs.length == 1) {
+      udfs.head.dataType
+    } else {
+      StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+    }
+    outputIterator.flatMap { pickedResult =>
+      val unpickledBatch = unpickle.loads(pickedResult)
+      unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+    }.map { result =>
+      if (udfs.length == 1) {
+        // fast path for single UDF
+        mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+        mutableRow
+      } else {
+        EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
       }
     }
   }
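
The third changed file, the new EvalPythonExec.scala that both operators above now extend, is not included in this view. The sketch below is a reconstruction for orientation only: the abstract evaluate() signature is copied from the two diffs, and the shared doExecute() body is assembled from the code those diffs delete, so the committed file may differ in details. The package declaration and imports (the same ones the subclasses dropped, e.g. java.io.File, ArrayBuffer, SparkEnv, RDD, Utils) are omitted for brevity.

// Sketch only -- reconstructed from the removed code above, not the committed EvalPythonExec.scala.
abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends SparkPlan {

  def children: Seq[SparkPlan] = child :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

  // Subclasses plug in the serialization strategy: pickled batches (BatchEvalPythonExec)
  // or Arrow payloads (ArrowEvalPythonExec).
  protected def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      bufferSize: Int,
      reuseWorker: Boolean,
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow]

  // Chain nested PythonUDFs into a single ChainedPythonFunctions (moved from both subclasses).
  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
        val (chained, children) = collectFunctions(u)
        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
      case children =>
        // There should not be any other UDFs, or the children can't be evaluated directly.
        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
    }
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val inputRDD = child.execute().map(_.copy())
    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

    inputRDD.mapPartitions { iter =>
      val context = TaskContext.get()

      // Buffer each original input row so it can be joined with the Python output later.
      val queue = HybridRowQueue(context.taskMemoryManager(),
        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
      context.addTaskCompletionListener { _ =>
        queue.close()
      }

      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

      // Flatten all the arguments, deduplicating semantically equal expressions.
      val allInputs = new ArrayBuffer[Expression]
      val dataTypes = new ArrayBuffer[DataType]
      val argOffsets = inputs.map { input =>
        input.map { e =>
          if (allInputs.exists(_.semanticEquals(e))) {
            allInputs.indexWhere(_.semanticEquals(e))
          } else {
            allInputs += e
            dataTypes += e.dataType
            allInputs.length - 1
          }
        }.toArray
      }.toArray
      val projection = newMutableProjection(allInputs, child.output)
      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
        StructField(s"_$i", dt)
      })

      // Enqueue the original row; send only the projected UDF arguments to Python.
      val projectedRowIter = iter.map { inputRow =>
        queue.add(inputRow.asInstanceOf[UnsafeRow])
        projection(inputRow)
      }

      // Format-specific round trip to the Python worker (pickle or Arrow).
      val outputRowIterator = evaluate(
        pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context)

      // Join each row coming back from Python with the buffered original row.
      val joined = new JoinedRow
      val resultProj = UnsafeProjection.create(output, output)
      outputRowIterator.map { outputRow =>
        resultProj(joined(queue.remove(), outputRow))
      }
    }
  }
}

Under this reading, each subclass keeps only the format-specific half of its old doExecute(), which is exactly what the evaluate() overrides in the two diffs above contain.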
