4 changes: 2 additions & 2 deletions python/pyspark/sql/functions.py
@@ -26,7 +26,7 @@

 from pyspark import since, SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq

@@ -1414,7 +1414,7 @@ def __init__(self, func, returnType, name=None):
     def _create_judf(self, name):
         f, returnType = self.func, self.returnType  # put them in closure `func`
         func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
+        ser = BatchedSerializer(PickleSerializer(), 100)

Can we pull this out into a constant? And also the same value in the Python, and put a comment on each saying that they have to be equal? It's very dangerous if this value goes out of sync.

Contributor Author: These two values don't need to be the same.

Good point; I was still thinking about my first attempt, which involved a blocking queue.

         command = (func, None, ser, ser)
         sc = SparkContext._active_spark_context
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
@@ -338,7 +338,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
   def children: Seq[SparkPlan] = child :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val childResults = child.execute().map(_.copy())
+    val buffer = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

Contributor: If I understand this correctly, we are assuming the following in order for this to work:

  1. Each task gets its own copy of the deserialized closure, and thus its own copy of the queue.
  2. All closures are serialized together in one shot, rather than in multiple places (e.g. they are all done in the serializer, not in the ctor of the RDD).
  3. The Java serializer does not serialize objects twice within the same stream, since it uses this to detect cycles. When they are deserialized, they still point to the same copy.

Contributor Author: Yes

+    val childResults = child.execute().map(_.copy()).map { row =>
+      buffer.add(row)
+      row
+    }

     val parent = childResults.mapPartitions { iter =>
       EvaluatePython.registerPicklers()  // register pickler for Row
@@ -354,7 +358,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       }
     }

-    val pyRDD = new PythonRDD(
+    new PythonRDD(
       parent,
       udf.command,
       udf.envVars,
@@ -372,17 +376,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       }
     }.mapPartitions { iter =>
       val row = new GenericMutableRow(1)
+      val joined = new JoinedRow
       iter.map { result =>
         row(0) = EvaluatePython.fromJava(result, udf.dataType)
-        row: InternalRow
-      }
-    }
-
-    childResults.zip(pyRDD).mapPartitions { iter =>
-      val joinedRow = new JoinedRow()
-      iter.map {
-        case (row, udfResult) =>
-          joinedRow(row, udfResult)
+        joined(buffer.poll(), row)
       }
     }
   }
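
A note on the review thread above, outside the diff itself: the third assumption -- that the Java serializer writes a shared object only once per stream, so the deserialized closures still point at a single queue -- can be checked in isolation. The sketch below is not Spark code; SharedQueueDemo, Holder, and roundTrip are made-up names used only to illustrate the property under that assumption.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import java.util.concurrent.ConcurrentLinkedQueue

    object SharedQueueDemo {
      // Stand-in for a closure that captured the queue (name is illustrative only).
      case class Holder(queue: ConcurrentLinkedQueue[Int], tag: String)

      // Serialize and deserialize through one stream, i.e. "in one shot".
      def roundTrip[T](value: T): T = {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject(value)
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
        in.readObject().asInstanceOf[T]
      }

      def main(args: Array[String]): Unit = {
        val queue = new ConcurrentLinkedQueue[Int]()
        // Both holders capture the same queue and are written together.
        val (producer, consumer) = roundTrip((Holder(queue, "producer"), Holder(queue, "consumer")))
        // The queue is written once and back-referenced within the stream,
        // so the two deserialized holders still share one queue instance.
        println(producer.queue eq consumer.queue)  // true
        producer.queue.add(1)
        println(consumer.queue.poll())             // 1
      }
    }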
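
Separately, a stripped-down sketch of the join pattern the new doExecute relies on: each row is remembered in the queue as it is handed to the worker and is polled back, in FIFO order, to be joined with the corresponding UDF result, rather than zipping the child output with the Python RDD. This is not Spark code; BufferJoinSketch and the placeholder data are assumptions for illustration, and it assumes the worker returns exactly one result per input row, in order.

    import java.util.concurrent.ConcurrentLinkedQueue

    object BufferJoinSketch {
      def main(args: Array[String]): Unit = {
        val buffer = new ConcurrentLinkedQueue[String]()

        // Stand-in for the child plan's rows (placeholder data).
        val childRows = Iterator("a", "b", "c")

        // As each row is handed to the "worker", remember it in the queue.
        val sentToWorker = childRows.map { row => buffer.add(row); row }

        // Stand-in for the Python UDF: one result per input row, same order.
        val udfResults = sentToWorker.map(_.toUpperCase)

        // Join each result back to its input row by polling in FIFO order.
        val joined = udfResults.map(result => (buffer.poll(), result))

        joined.foreach(println)  // (a,A) (b,B) (c,C)
        // In the real plan the worker reads ahead asynchronously, which is
        // why a thread-safe queue is used rather than a plain local variable.
      }
    }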