diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 26b8662718a6..3c58e2c35a4e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -26,7 +26,7 @@ from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.serializers import PickleSerializer, BatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq @@ -1414,7 +1414,7 @@ def __init__(self, func, returnType, name=None): def _create_judf(self, name): f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) - ser = AutoBatchedSerializer(PickleSerializer()) + ser = BatchedSerializer(PickleSerializer(), 100) command = (func, None, ser, ser) sc = SparkContext._active_spark_context pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 5a58d846ad80..02d34b1a3642 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -338,7 +338,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) + val buffer = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val childResults = child.execute().map(_.copy()).map { row => + buffer.add(row) + row + } val parent = childResults.mapPartitions { iter => EvaluatePython.registerPicklers() // register pickler for Row @@ -354,7 +358,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: } } - val pyRDD = new PythonRDD( + new PythonRDD( parent, udf.command, udf.envVars, @@ -372,17 +376,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: } }.mapPartitions { iter => val row = new GenericMutableRow(1) + val joined = new JoinedRow iter.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } - - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) + joined(buffer.poll(), row) } } }