From af5254b0fd4a11696f248d148c650f157496af6e Mon Sep 17 00:00:00 2001 From: Justin Uang Date: Tue, 8 Sep 2015 00:23:14 -0400 Subject: [PATCH 1/3] [SPARK-8632] [SQL] [PYSPARK] Poor Python UDF performance because of RDD caching - I wanted to reuse most of the logic from PythonRDD, so I pulled out two methods, writeHeaderToStream and readPythonProcessSocket - The worker.py now has a switch where it reads an int that either tells it to go into normal pyspark RDD mode, which is meant for a streaming two thread workflow, and pyspark UDF mode, which is meant to be called synchronously --- .../apache/spark/api/python/PythonRDD.scala | 257 +++++++++++------- python/pyspark/worker.py | 43 ++- .../spark/sql/execution/pythonUDFs.scala | 172 +++++++++--- 3 files changed, 317 insertions(+), 155 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b4d152b33660..aa0ca8df99d3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ +import java.util.concurrent.atomic.AtomicBoolean import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -72,14 +73,14 @@ private[spark] class PythonRDD( } val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool - @volatile var released = false + var released = new AtomicBoolean(false) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - if (!reuse_worker || !released) { + if (!reuse_worker || !released.get()) { try { worker.close() } catch { @@ -107,73 +108,17 @@ private[spark] class PythonRDD( if (writerThread.exception.isDefined) { throw writerThread.exception.get } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case 0 => Array.empty[Byte] - case SpecialLengths.TIMING_DATA => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj, UTF_8), - writerThread.exception.getOrElse(null)) - case SpecialLengths.END_OF_DATA_SECTION => - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - 
accumulator += Collections.singletonList(update) - } - // Check whether the worker is ready to be re-used. - if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released = true - } - } - null - } - } catch { - - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException - - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null // exit silently - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } + PythonRDD.readPythonProcessSocket( + stream, + startTime, + reuse_worker, + accumulator, + envVars, + pythonExec, + worker, + released, + writerThread.exception) } var _nextObj = read() @@ -192,6 +137,7 @@ private[spark] class PythonRDD( class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { + @volatile private var _exception: Exception = null setDaemon(true) @@ -208,45 +154,19 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(split.index) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size()) - for (include <- pythonIncludes.asScala) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.asScala.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size - dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } - for (broadcast <- broadcastVars.asScala) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - dataOut.flush() - // Serialized command: - dataOut.writeInt(command.length) - dataOut.write(command) + + PythonRDD.writeHeaderToStream(dataOut, pythonIncludes, broadcastVars, worker, pythonVer, split.index, command) + + // RDD mode + dataOut.writeInt(PySparkMode.RDD) + // Data values PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + + // Finish dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -315,7 +235,12 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private object SpecialLengths { +object PySparkMode { + val RDD = 0 + val UDF = 1 +} + +object SpecialLengths { val END_OF_DATA_SECTION = -1 val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 
@@ -420,6 +345,130 @@ private[spark] object PythonRDD extends Logging { iter.foreach(write) } + def writeHeaderToStream( + dataOut: DataOutputStream, + pythonIncludes: JList[String], + broadcastVars: JList[Broadcast[PythonBroadcast]], + worker: Socket, + pythonVer: String, + partitionIndex: Int, + command: Array[Byte]) { + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size()) + for (include <- pythonIncludes.asScala) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.asScala.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + for (broadcast <- broadcastVars.asScala) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) + + dataOut.flush() + } + + def readPythonProcessSocket( + stream: DataInputStream, + startTime: Long, + reuseWorker: Boolean, + accumulator: Accumulator[JList[Array[Byte]]], + envVars: JMap[String, String], + pythonExec: String, + worker: Socket, + released: AtomicBoolean, + writerThreadException: Option[Exception] = None) : Array[Byte] = { + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + TaskContext.get.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + TaskContext.get.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + readPythonProcessSocket(stream, startTime, reuseWorker, accumulator, envVars, pythonExec, worker, released, writerThreadException) + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + throw new PythonException(new String(obj, UTF_8), writerThreadException.getOrElse(null)) + case SpecialLengths.END_OF_DATA_SECTION => + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + // Check whether the worker is ready to be re-used. 
+ if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker) { + SparkEnv.get.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released.set(true) + } + } + null + } + } catch { + case e: Exception if TaskContext.get.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException + + case e: Exception if SparkEnv.get.isStopped => + logDebug("Exception thrown after context is stopped", e) + null // exit silently + + case e: Exception if writerThreadException.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThreadException.get) + throw writerThreadException.get + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + /** * Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]], * key and value class. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 42c2f8b75933..bfe57dbb7316 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,6 +32,10 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer from pyspark import shuffle +class PySparkMode(object): + RDD = 0 + UDF = 1 + pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -101,14 +105,41 @@ def main(infile, outfile): func, profiler, deserializer, serializer = command init_time = time.time() - def process(): - iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) + pyspark_mode = read_int(infile) + + if pyspark_mode == PySparkMode.RDD: + def process(): + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) + + if profiler: + profiler.profile(process) + else: + process() + elif pyspark_mode == PySparkMode.UDF: + pickle_serializer = PickleSerializer() + + batch_length = read_int(infile) - if profiler: - profiler.profile(process) + while batch_length != SpecialLengths.END_OF_DATA_SECTION: + batch_bytes = infile.read(batch_length) + batch_objs = pickle_serializer.loads(batch_bytes) + + udf_result_objs = list(func(split_index, iter(batch_objs))) + pickled_bytes = pickle_serializer.dumps(udf_result_objs) + write_int(len(pickled_bytes), outfile) + + if sys.version_info[0:2] <= (2, 6): + outfile.write(str(pickled_bytes)) + else: + outfile.write(pickled_bytes) + + outfile.flush() + + batch_length = read_int(infile) else: - process() + raise Exception("Unknown pyspark mode: {}".format(pyspark_mode)) + except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 5a58d846ad80..d6dce6dcd4b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.execution -import java.io.OutputStream -import java.util.{List => JList, Map => JMap} +import java.io.{BufferedInputStream, DataInputStream, DataOutputStream, BufferedOutputStream, OutputStream} +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean +import java.util.{List => JList, Map => JMap, Collections} + +import org.apache.spark.api.python import scala.collection.JavaConverters._ import net.razorvine.pickle._ 
import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.api.python._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.{Logging => SparkLogging, TaskContext, SparkEnv, Accumulator} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -328,8 +332,8 @@ case class EvaluatePython( /** * :: DeveloperApi :: - * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. + * Use a synchronous batched based system into order to calculate a [[PythonUDF]] without having to + * risk deadlock. */ @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) @@ -340,50 +344,128 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) - val parent = childResults.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row + val bufferSize = childResults.context.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = childResults.context.conf.getBoolean("spark.python.worker.reuse", true) + + childResults.mapPartitionsWithIndex { case (partitionIndex, iter) => + val startTime = System.currentTimeMillis + val env = SparkEnv.get + + val envVars = udf.envVars + val localdir = env.blockManager.diskBlockManager.localDirs.map( + f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuseWorker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + + val worker: Socket = env.createPythonWorker(udf.pythonExec, envVars.asScala.toMap) + + var released = new AtomicBoolean(false) + + TaskContext.get.addTaskCompletionListener { context => + if (!reuseWorker || !released.get()) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + val dataOut = new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize)) + val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + PythonRDD.writeHeaderToStream( + dataOut, + udf.pythonIncludes, + udf.broadcastVars, + worker, + udf.pythonVer, + partitionIndex, // partition number isn't used + udf.command) + + // UDF mode + dataOut.writeInt(PySparkMode.UDF) + + EvaluatePython.registerPicklers() + val pickle = new Pickler + val unpickle = new Unpickler + val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - EvaluatePython.toJava(currentRow(row), schema) - }.toArray - pickle.dumps(toBePickled) - } - } - - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { 
iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - } - }.mapPartitions { iter => - val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) - } + val groupedIterator = iter.grouped(100) + + // Add a sentinel at the end so that we can know when to finish the pyspark protocol + val groupedIteratorWithEnd = groupedIterator ++ Seq(BatchPythonEvaluation.SentinelEnd) + + groupedIteratorWithEnd.map { inputRows => + if (inputRows == BatchPythonEvaluation.SentinelEnd) { + // Finish + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + + // This should return null and complete the cleanup procedure + val shouldBeNull = PythonRDD.readPythonProcessSocket( + dataIn, + startTime, + reuseWorker, + udf.accumulator, + udf.envVars, + udf.pythonExec, + worker, + released) + + if (shouldBeNull != null) { + throw new RuntimeException("Cleanup procedure failed") + } + + // Return empty sequence, which will be flattened into nothing + Seq() + } else { + val toBePickled = inputRows.map { row => + EvaluatePython.toJava(currentRow(row), schema) + }.toArray + val batchBytes = pickle.dumps(toBePickled) + + dataOut.writeInt(batchBytes.length) + dataOut.write(batchBytes) + dataOut.flush() + + val myObj = PythonRDD.readPythonProcessSocket( + dataIn, + startTime, + reuseWorker, + udf.accumulator, + udf.envVars, + udf.pythonExec, + worker, + released) + + val unpickledBatch = unpickle.loads(myObj) + val unpickledBatchBuffer = unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + + // TODO optimize performance + + val udfResults = unpickledBatchBuffer.map { result => + InternalRow.apply(EvaluatePython.fromJava(result, udf.dataType)) + } + + inputRows.zip(udfResults).map { + case (row, udfResult) => + new JoinedRow(row, udfResult) + } + } + }.flatten } } } + +private object BatchPythonEvaluation { + val SentinelEnd = Seq(InternalRow("END")) +} \ No newline at end of file From 2e812508d5458e19414a245179f795d74d9e63fd Mon Sep 17 00:00:00 2001 From: Justin Uang Date: Thu, 10 Sep 2015 14:06:17 -0400 Subject: [PATCH 2/3] Style changes and make SpecialLengths and PySparkMode private to spark --- .../apache/spark/api/python/PythonRDD.scala | 24 +++++++++++++++---- python/pyspark/worker.py | 1 + .../spark/sql/execution/pythonUDFs.scala | 7 +++--- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index aa0ca8df99d3..5fa2c724f0a1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -158,7 +158,14 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - PythonRDD.writeHeaderToStream(dataOut, pythonIncludes, broadcastVars, worker, pythonVer, split.index, command) + PythonRDD.writeHeaderToStream( + dataOut, + pythonIncludes, + broadcastVars, + worker, + pythonVer, + split.index, + command) // RDD mode dataOut.writeInt(PySparkMode.RDD) @@ -235,12 +242,12 @@ 
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -object PySparkMode { +private[spark] object PySparkMode { val RDD = 0 val UDF = 1 } -object SpecialLengths { +private[spark] object SpecialLengths { val END_OF_DATA_SECTION = -1 val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 @@ -424,7 +431,16 @@ private[spark] object PythonRDD extends Logging { val diskBytesSpilled = stream.readLong() TaskContext.get.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) TaskContext.get.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - readPythonProcessSocket(stream, startTime, reuseWorker, accumulator, envVars, pythonExec, worker, released, writerThreadException) + readPythonProcessSocket( + stream, + startTime, + reuseWorker, + accumulator, + envVars, + pythonExec, + worker, + released, + writerThreadException) case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index bfe57dbb7316..37b1d176aa13 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,6 +32,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer from pyspark import shuffle + class PySparkMode(object): RDD = 0 UDF = 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index d6dce6dcd4b3..59f80fd03646 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -374,7 +374,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: } } - val dataOut = new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize)) + val dataOut = new DataOutputStream( + new BufferedOutputStream(worker.getOutputStream, bufferSize)) val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) PythonRDD.writeHeaderToStream( @@ -403,7 +404,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: // Add a sentinel at the end so that we can know when to finish the pyspark protocol val groupedIteratorWithEnd = groupedIterator ++ Seq(BatchPythonEvaluation.SentinelEnd) - groupedIteratorWithEnd.map { inputRows => + groupedIteratorWithEnd.map { inputRows => if (inputRows == BatchPythonEvaluation.SentinelEnd) { // Finish dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) @@ -468,4 +469,4 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: private object BatchPythonEvaluation { val SentinelEnd = Seq(InternalRow("END")) -} \ No newline at end of file +} From 7fe4a0e4992ae8dffdb0af50644c0e4a573cb974 Mon Sep 17 00:00:00 2001 From: Justin Uang Date: Mon, 14 Sep 2015 18:48:30 -0400 Subject: [PATCH 3/3] Make PySparkMode private also in Python --- python/pyspark/worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 37b1d176aa13..b668badee401 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -33,7 +33,7 @@ from pyspark import shuffle -class PySparkMode(object): +class _PySparkMode(object): RDD = 0 UDF = 1 @@ -108,7 +108,7 @@ def main(infile, outfile): pyspark_mode = read_int(infile) - if pyspark_mode == PySparkMode.RDD: + if 
pyspark_mode == _PySparkMode.RDD: def process(): iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) @@ -117,7 +117,7 @@ def process(): profiler.profile(process) else: process() - elif pyspark_mode == PySparkMode.UDF: + elif pyspark_mode == _PySparkMode.UDF: pickle_serializer = PickleSerializer() batch_length = read_int(infile)
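
For reference, a minimal, self-contained sketch of the synchronous batched protocol that the PySparkMode.UDF path implements. Everything in it is illustrative rather than actual PySpark code: the socketpair, the frame helpers, and the doubling UDF stand in for the real worker launcher, the writeHeaderToStream header exchange, accumulator updates, and the END_OF_STREAM handshake, all of which are omitted here. What it does show is the core loop: the driver writes one length-prefixed pickled batch of 100 rows, blocks until the worker writes back a length-prefixed pickled result, and finishes by sending the END_OF_DATA_SECTION sentinel, so neither side ever buffers more than one batch.

import pickle
import socket
import struct
import threading

END_OF_DATA_SECTION = -1  # mirrors SpecialLengths.END_OF_DATA_SECTION in the patch


def write_int(value, outfile):
    # 4-byte big-endian int, the same framing PySpark's write_int/read_int use
    outfile.write(struct.pack(">i", value))


def read_int(infile):
    return struct.unpack(">i", infile.read(4))[0]


def worker_loop(infile, outfile, udf):
    # Worker side (the UDF branch of worker.py, reduced to its core): read a
    # pickled batch, apply the UDF, write the pickled results back, repeat.
    batch_length = read_int(infile)
    while batch_length != END_OF_DATA_SECTION:
        batch = pickle.loads(infile.read(batch_length))
        result = pickle.dumps([udf(x) for x in batch])
        write_int(len(result), outfile)
        outfile.write(result)
        outfile.flush()
        batch_length = read_int(infile)


def evaluate_batches(infile, outfile, batches):
    # Driver side (roughly what BatchPythonEvaluation does per partition):
    # send one batch, then synchronously read its results before the next one.
    for batch in batches:
        payload = pickle.dumps(batch)
        write_int(len(payload), outfile)
        outfile.write(payload)
        outfile.flush()
        yield pickle.loads(infile.read(read_int(infile)))
    write_int(END_OF_DATA_SECTION, outfile)
    outfile.flush()


if __name__ == "__main__":
    driver_sock, worker_sock = socket.socketpair()
    worker = threading.Thread(
        target=worker_loop,
        args=(worker_sock.makefile("rb"), worker_sock.makefile("wb"), lambda x: x * 2))
    worker.start()

    rows = list(range(250))
    batches = [rows[i:i + 100] for i in range(0, len(rows), 100)]  # iter.grouped(100)
    with driver_sock.makefile("rb") as d_in, driver_sock.makefile("wb") as d_out:
        for result in evaluate_batches(d_in, d_out, batches):
            print(result[:3], "... (%d values)" % len(result))
    worker.join()

Because every read is paired with the write that immediately precedes it, the driver and the worker never wait on each other with unflushed data in between, which is the property this approach relies on to avoid the two-thread streaming setup (and the RDD caching it forces on the UDF path) that plain PythonRDD mode needs.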