17 | 17 |
18 | 18 | package org.apache.spark.sql.execution.python |
19 | 19 |
20 | | -import java.io.File |
21 | | - |
22 | 20 | import scala.collection.JavaConverters._ |
23 | | -import scala.collection.mutable.ArrayBuffer |
24 | 21 |
25 | 22 | import net.razorvine.pickle.{Pickler, Unpickler} |
26 | 23 |
27 | | -import org.apache.spark.{SparkEnv, TaskContext} |
| 24 | +import org.apache.spark.TaskContext |
28 | 25 | import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} |
29 | | -import org.apache.spark.rdd.RDD |
30 | 26 | import org.apache.spark.sql.catalyst.InternalRow |
31 | 27 | import org.apache.spark.sql.catalyst.expressions._ |
32 | 28 | import org.apache.spark.sql.execution.SparkPlan |
33 | | -import org.apache.spark.sql.types.{DataType, StructField, StructType} |
34 | | -import org.apache.spark.util.Utils |
35 | | - |
| 29 | +import org.apache.spark.sql.types.{StructField, StructType} |
36 | 30 |
37 | 31 | /** |
38 | | - * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time. |
39 | | - * |
40 | | - * Python evaluation works by sending the necessary (projected) input data via a socket to an |
41 | | - * external Python process, and combine the result from the Python process with the original row. |
42 | | - * |
43 | | - * For each row we send to Python, we also put it in a queue first. For each output row from Python, |
44 | | - * we drain the queue to find the original input row. Note that if the Python process is way too |
45 | | - * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory. |
46 | | - * |
47 | | - * Here is a diagram to show how this works: |
48 | | - * |
49 | | - * Downstream (for parent) |
50 | | - * / \ |
51 | | - * / socket (output of UDF) |
52 | | - * / \ |
53 | | - * RowQueue Python |
54 | | - * \ / |
55 | | - * \ socket (input of UDF) |
56 | | - * \ / |
57 | | - * upstream (from child) |
58 | | - * |
59 | | - * The rows sent to and received from Python are packed into batches (100 rows) and serialized, |
60 | | - * there should be always some rows buffered in the socket or Python process, so the pulling from |
61 | | - * RowQueue ALWAYS happened after pushing into it. |
| 32 | + * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
62 | 33 | */ |
63 | 34 | case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) |
64 | | - extends SparkPlan { |
65 | | - |
66 | | - def children: Seq[SparkPlan] = child :: Nil |
67 | | - |
68 | | - override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) |
69 | | - |
70 | | - private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { |
71 | | - udf.children match { |
72 | | - case Seq(u: PythonUDF) => |
73 | | - val (chained, children) = collectFunctions(u) |
74 | | - (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) |
75 | | - case children => |
76 | | - // There should not be any other UDFs, or the children can't be evaluated directly. |
77 | | - assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) |
78 | | - (ChainedPythonFunctions(Seq(udf.func)), udf.children) |
79 | | - } |
80 | | - } |
81 | | - |
82 | | - protected override def doExecute(): RDD[InternalRow] = { |
83 | | - val inputRDD = child.execute().map(_.copy()) |
84 | | - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) |
85 | | - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) |
86 | | - |
87 | | - inputRDD.mapPartitions { iter => |
88 | | - EvaluatePython.registerPicklers() // register pickler for Row |
89 | | - |
90 | | - // The queue used to buffer input rows so we can drain it to |
91 | | - // combine input with output from Python. |
92 | | - val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(), |
93 | | - new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) |
94 | | - TaskContext.get().addTaskCompletionListener({ ctx => |
95 | | - queue.close() |
96 | | - }) |
97 | | - |
98 | | - val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip |
99 | | - |
100 | | - // flatten all the arguments |
101 | | - val allInputs = new ArrayBuffer[Expression] |
102 | | - val dataTypes = new ArrayBuffer[DataType] |
103 | | - val argOffsets = inputs.map { input => |
104 | | - input.map { e => |
105 | | - if (allInputs.exists(_.semanticEquals(e))) { |
106 | | - allInputs.indexWhere(_.semanticEquals(e)) |
107 | | - } else { |
108 | | - allInputs += e |
109 | | - dataTypes += e.dataType |
110 | | - allInputs.length - 1 |
111 | | - } |
112 | | - }.toArray |
113 | | - }.toArray |
114 | | - val projection = newMutableProjection(allInputs, child.output) |
115 | | - val schema = StructType(dataTypes.map(dt => StructField("", dt))) |
116 | | - val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) |
117 | | - |
118 | | - // enable memo iff we serialize the row with schema (schema and class should be memorized) |
119 | | - val pickle = new Pickler(needConversion) |
120 | | - // Input iterator to Python: input rows are grouped so we send them in batches to Python. |
121 | | - // For each row, add it to the queue. |
122 | | - val inputIterator = iter.map { inputRow => |
123 | | - queue.add(inputRow.asInstanceOf[UnsafeRow]) |
124 | | - val row = projection(inputRow) |
125 | | - if (needConversion) { |
126 | | - EvaluatePython.toJava(row, schema) |
127 | | - } else { |
128 | | - // fast path for these types that does not need conversion in Python |
129 | | - val fields = new Array[Any](row.numFields) |
130 | | - var i = 0 |
131 | | - while (i < row.numFields) { |
132 | | - val dt = dataTypes(i) |
133 | | - fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) |
134 | | - i += 1 |
135 | | - } |
136 | | - fields |
137 | | - } |
138 | | - }.grouped(100).map(x => pickle.dumps(x.toArray)) |
139 | | - |
140 | | - val context = TaskContext.get() |
141 | | - |
142 | | - // Output iterator for results from Python. |
143 | | - val outputIterator = new PythonRunner( |
144 | | - pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) |
145 | | - .compute(inputIterator, context.partitionId(), context) |
146 | | - |
147 | | - val unpickle = new Unpickler |
148 | | - val mutableRow = new GenericInternalRow(1) |
149 | | - val joined = new JoinedRow |
150 | | - val resultType = if (udfs.length == 1) { |
151 | | - udfs.head.dataType |
| 35 | + extends EvalPythonExec(udfs, output, child) { |
| 36 | + |
| 37 | + protected override def evaluate( |
| 38 | + funcs: Seq[ChainedPythonFunctions], |
| 39 | + bufferSize: Int, |
| 40 | + reuseWorker: Boolean, |
| 41 | + argOffsets: Array[Array[Int]], |
| 42 | + iter: Iterator[InternalRow], |
| 43 | + schema: StructType, |
| 44 | + context: TaskContext): Iterator[InternalRow] = { |
| 45 | + EvaluatePython.registerPicklers() // register pickler for Row |
| 46 | + |
| 47 | + val dataTypes = schema.map(_.dataType) |
| 48 | + val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) |
| 49 | + |
| 50 | + // enable memo iff we serialize the row with schema (schema and class should be memoized)
| 51 | + val pickle = new Pickler(needConversion) |
| 52 | + // Input iterator to Python: input rows are grouped and sent to Python in
| 53 | + // batches of 100 pickled rows.
| 54 | + val inputIterator = iter.map { row => |
| 55 | + if (needConversion) { |
| 56 | + EvaluatePython.toJava(row, schema) |
152 | 57 | } else { |
153 | | - StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) |
154 | | - } |
155 | | - val resultProj = UnsafeProjection.create(output, output) |
156 | | - outputIterator.flatMap { pickedResult => |
157 | | - val unpickledBatch = unpickle.loads(pickedResult) |
158 | | - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala |
159 | | - }.map { result => |
160 | | - val row = if (udfs.length == 1) { |
161 | | - // fast path for single UDF |
162 | | - mutableRow(0) = EvaluatePython.fromJava(result, resultType) |
163 | | - mutableRow |
164 | | - } else { |
165 | | - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] |
| 58 | + // fast path for types that do not need conversion in Python
| 59 | + val fields = new Array[Any](row.numFields) |
| 60 | + var i = 0 |
| 61 | + while (i < row.numFields) { |
| 62 | + val dt = dataTypes(i) |
| 63 | + fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) |
| 64 | + i += 1 |
166 | 65 | } |
167 | | - resultProj(joined(queue.remove(), row)) |
| 66 | + fields |
| 67 | + } |
| 68 | + }.grouped(100).map(x => pickle.dumps(x.toArray)) |
| 69 | + |
| 70 | + // Output iterator for results from Python. |
| 71 | + val outputIterator = new PythonRunner( |
| 72 | + funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) |
| 73 | + .compute(inputIterator, context.partitionId(), context) |
| 74 | + |
| 75 | + val unpickle = new Unpickler |
| 76 | + val mutableRow = new GenericInternalRow(1) |
| 77 | + val resultType = if (udfs.length == 1) { |
| 78 | + udfs.head.dataType |
| 79 | + } else { |
| 80 | + StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) |
| 81 | + } |
| 82 | + outputIterator.flatMap { pickedResult => |
| 83 | + val unpickledBatch = unpickle.loads(pickedResult) |
| 84 | + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala |
| 85 | + }.map { result => |
| 86 | + if (udfs.length == 1) { |
| 87 | + // fast path for single UDF |
| 88 | + mutableRow(0) = EvaluatePython.fromJava(result, resultType) |
| 89 | + mutableRow |
| 90 | + } else { |
| 91 | + EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] |
168 | 92 | } |
169 | 93 | } |
170 | 94 | } |
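
The new parent class EvalPythonExec referenced in this diff lives in a separate file that is not shown here. Based on the constructor arguments, the `evaluate` signature that BatchEvalPythonExec now overrides, and the logic removed from this file (row queue, UDF chaining, argument-offset bookkeeping, result join), a rough sketch of what that base class presumably contains follows. Everything beyond the signatures visible in the diff is an assumption, not the actual Spark source.

// Sketch only: EvalPythonExec is not part of this diff. The abstract `evaluate`
// signature matches the override above; the rest is reconstructed from the code
// removed from BatchEvalPythonExec and may differ from the real Spark source.
package org.apache.spark.sql.execution.python

import java.io.File

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.ChainedPythonFunctions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils

abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends SparkPlan {

  def children: Seq[SparkPlan] = child :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

  // How one partition of projected input rows is evaluated against the Python UDFs;
  // BatchEvalPythonExec implements this by sending pickled batches of 100 rows.
  protected def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      bufferSize: Int,
      reuseWorker: Boolean,
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow]

  // Collapse nested PythonUDFs into one chained function per top-level UDF.
  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
        val (chained, children) = collectFunctions(u)
        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
      case children =>
        // There should not be any other UDFs, or the children can't be evaluated directly.
        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
    }
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val inputRDD = child.execute().map(_.copy())
    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

    inputRDD.mapPartitions { iter =>
      // Buffer every input row so it can be joined back with the UDF output later.
      val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
      TaskContext.get().addTaskCompletionListener({ ctx => queue.close() })

      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

      // Deduplicate UDF arguments; argOffsets maps each UDF's arguments to columns
      // of the projected input row.
      val allInputs = new ArrayBuffer[Expression]
      val dataTypes = new ArrayBuffer[DataType]
      val argOffsets = inputs.map { input =>
        input.map { e =>
          if (allInputs.exists(_.semanticEquals(e))) {
            allInputs.indexWhere(_.semanticEquals(e))
          } else {
            allInputs += e
            dataTypes += e.dataType
            allInputs.length - 1
          }
        }.toArray
      }.toArray
      val projection = newMutableProjection(allInputs, child.output)
      val schema = StructType(dataTypes.map(dt => StructField("", dt)))

      // Add each original row to the queue, then hand the projected row to evaluate().
      val projectedRowIter = iter.map { inputRow =>
        queue.add(inputRow.asInstanceOf[UnsafeRow])
        projection(inputRow)
      }

      val context = TaskContext.get()
      val outputRowIter = evaluate(
        pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context)

      // Join each UDF result back with the original input row it was computed from.
      val joined = new JoinedRow
      val resultProj = UnsafeProjection.create(output, output)
      outputRowIter.map { outputRow => resultProj(joined(queue.remove(), outputRow)) }
    }
  }
}

With the partition setup factored out this way, an alternative evaluation strategy (for example, one that ships columnar batches instead of pickled rows) would only need to override `evaluate`; the queue bookkeeping and argument deduplication stay in one place.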