core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala (273 changes: 169 additions & 104 deletions)
@@ -19,6 +19,7 @@ package org.apache.spark.api.python

import java.io._
import java.net._
+import java.util.concurrent.atomic.AtomicBoolean
import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._
@@ -72,14 +73,14 @@ private[spark] class PythonRDD(
    }
    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
    // Whether the worker has been released into the idle pool
-    @volatile var released = false
+    val released = new AtomicBoolean(false)

    // Start a thread to feed the process input from our parent's iterator
    val writerThread = new WriterThread(env, worker, split, context)

    context.addTaskCompletionListener { context =>
      writerThread.shutdownOnTaskCompletion()
-      if (!reuse_worker || !released) {
+      if (!reuse_worker || !released.get()) {
        try {
          worker.close()
        } catch {
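The flag changes from a `@volatile var` to an `AtomicBoolean` because the read loop that sets it moves into the static `PythonRDD.readPythonProcessSocket` helper below: a local `var` cannot be passed to a method by reference, while a mutable boolean box can be shared between the helper and the task-completion listener above. A rough Python analogue of the same pattern, using `threading.Event` as the shared flag (`WorkerHandle`, `read_loop`, and `on_task_completion` are illustrative names, not PySpark API):

```python
import threading


class WorkerHandle(object):
    """Illustrative stand-in for the (socket, released-flag) pair."""

    def __init__(self, sock):
        self.sock = sock
        self.released = threading.Event()  # plays the role of AtomicBoolean(false)


def read_loop(handle, reuse_worker):
    # ... read results from the worker; on END_OF_STREAM, hand it back:
    if reuse_worker:
        handle.released.set()  # analogue of released.set(true)


def on_task_completion(handle, reuse_worker):
    # mirrors: if (!reuse_worker || !released.get()) worker.close()
    if not reuse_worker or not handle.released.is_set():
        handle.sock.close()
```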
@@ -107,73 +108,17 @@ private[spark] class PythonRDD(
        if (writerThread.exception.isDefined) {
          throw writerThread.exception.get
        }
-        try {
-          stream.readInt() match {
-            case length if length > 0 =>
-              val obj = new Array[Byte](length)
-              stream.readFully(obj)
-              obj
-            case 0 => Array.empty[Byte]
-            case SpecialLengths.TIMING_DATA =>
-              // Timing data from worker
-              val bootTime = stream.readLong()
-              val initTime = stream.readLong()
-              val finishTime = stream.readLong()
-              val boot = bootTime - startTime
-              val init = initTime - bootTime
-              val finish = finishTime - initTime
-              val total = finishTime - startTime
-              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
-                init, finish))
-              val memoryBytesSpilled = stream.readLong()
-              val diskBytesSpilled = stream.readLong()
-              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-              read()
-            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-              // Signals that an exception has been thrown in python
-              val exLength = stream.readInt()
-              val obj = new Array[Byte](exLength)
-              stream.readFully(obj)
-              throw new PythonException(new String(obj, UTF_8),
-                writerThread.exception.getOrElse(null))
-            case SpecialLengths.END_OF_DATA_SECTION =>
-              // We've finished the data section of the output, but we can still
-              // read some accumulator updates:
-              val numAccumulatorUpdates = stream.readInt()
-              (1 to numAccumulatorUpdates).foreach { _ =>
-                val updateLen = stream.readInt()
-                val update = new Array[Byte](updateLen)
-                stream.readFully(update)
-                accumulator += Collections.singletonList(update)
-              }
-              // Check whether the worker is ready to be re-used.
-              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-                if (reuse_worker) {
-                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-                  released = true
-                }
-              }
-              null
-          }
-        } catch {
-
-          case e: Exception if context.isInterrupted =>
-            logDebug("Exception thrown after task interruption", e)
-            throw new TaskKilledException
-
-          case e: Exception if env.isStopped =>
-            logDebug("Exception thrown after context is stopped", e)
-            null // exit silently
-
-          case e: Exception if writerThread.exception.isDefined =>
-            logError("Python worker exited unexpectedly (crashed)", e)
-            logError("This may have been caused by a prior exception:", writerThread.exception.get)
-            throw writerThread.exception.get
-
-          case eof: EOFException =>
-            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
-        }
+        PythonRDD.readPythonProcessSocket(
+          stream,
+          startTime,
+          reuse_worker,
+          accumulator,
+          envVars,
+          pythonExec,
+          worker,
+          released,
+          writerThread.exception)
      }

      var _nextObj = read()
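The stream that `readPythonProcessSocket` consumes is a sequence of length-prefixed frames: a big-endian 4-byte int that is either a payload length (>= 0) or one of the negative `SpecialLengths` control codes. A minimal sketch of that framing, assuming PySpark's big-endian int convention (`read_frame` itself is illustrative and belongs to neither codebase):

```python
import struct

# control codes, mirroring SpecialLengths on the Scala side
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
END_OF_STREAM = -4


def read_frame(stream):
    """Return ('data', payload) for a length-prefixed payload,
    or ('control', code) for a negative SpecialLengths code."""
    (length,) = struct.unpack("!i", stream.read(4))
    if length >= 0:
        return ("data", stream.read(length))
    return ("control", length)
```

Everything the match expression above handles (timing data, Python exceptions, accumulator updates, end-of-stream) arrives through this one framing.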
@@ -192,6 +137,7 @@ private[spark] class PythonRDD(
  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
    extends Thread(s"stdout writer for $pythonExec") {

+
    @volatile private var _exception: Exception = null

    setDaemon(true)
@@ -208,45 +154,26 @@
    override def run(): Unit = Utils.logUncaughtExceptions {
      try {
        TaskContext.setTaskContext(context)
+
        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
        val dataOut = new DataOutputStream(stream)
-        // Partition index
-        dataOut.writeInt(split.index)
-        // Python version of driver
-        PythonRDD.writeUTF(pythonVer, dataOut)
-        // sparkFilesDir
-        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
-        // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size())
-        for (include <- pythonIncludes.asScala) {
-          PythonRDD.writeUTF(include, dataOut)
-        }
-        // Broadcast variables
-        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.asScala.map(_.id).toSet
-        // number of different broadcasts
-        val toRemove = oldBids.diff(newBids)
-        val cnt = toRemove.size + newBids.diff(oldBids).size
-        dataOut.writeInt(cnt)
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(- bid - 1) // bid >= 0
-          oldBids.remove(bid)
-        }
-        for (broadcast <- broadcastVars.asScala) {
-          if (!oldBids.contains(broadcast.id)) {
-            // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
-          }
-        }
-        dataOut.flush()
-        // Serialized command:
-        dataOut.writeInt(command.length)
-        dataOut.write(command)
+
+        PythonRDD.writeHeaderToStream(
+          dataOut,
+          pythonIncludes,
+          broadcastVars,
+          worker,
+          pythonVer,
+          split.index,
+          command)
+
+        // RDD mode
+        dataOut.writeInt(PySparkMode.RDD)
+
        // Data values
        PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
+
+        // Finish
        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
        dataOut.flush()
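After this change the writer thread emits, in order: the header (via `writeHeaderToStream`), the new mode selector, the serialized data, and the two terminators. A condensed sketch of that wire layout written into an in-memory buffer; `write_int` and `write_utf` are stand-ins for PySpark's serializer helpers, and the header fields are abbreviated:

```python
import io
import struct


def write_int(value, out):
    out.write(struct.pack("!i", value))


def write_utf(s, out):
    data = s.encode("utf-8")
    write_int(len(data), out)
    out.write(data)


out = io.BytesIO()
write_int(0, out)      # header: partition index
write_utf("2.7", out)  # header: driver Python version (illustrative value)
# ... includes, broadcasts, serialized command (see writeHeaderToStream below)
write_int(0, out)      # mode selector: PySparkMode.RDD
# ... pickled data values ...
write_int(-1, out)     # SpecialLengths.END_OF_DATA_SECTION
write_int(-4, out)     # SpecialLengths.END_OF_STREAM
```

The mode selector is the only frame added by this PR; everything before and after it is the pre-existing protocol.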
@@ -315,7 +242,12 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte]
  val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}

-private object SpecialLengths {
+private[spark] object PySparkMode {
+  val RDD = 0
+  val UDF = 1
+}
+
+private[spark] object SpecialLengths {
  val END_OF_DATA_SECTION = -1
  val PYTHON_EXCEPTION_THROWN = -2
  val TIMING_DATA = -3
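These constant objects must stay value-for-value in sync with their Python mirrors: `_PySparkMode` in worker.py (added below) and `SpecialLengths` in pyspark/serializers.py, which worker.py already imports. A condensed sketch of the Python side, showing the values this protocol relies on:

```python
class _PySparkMode(object):
    RDD = 0   # plain RDD computation
    UDF = 1   # batched UDF evaluation


class SpecialLengths(object):
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
```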
@@ -420,6 +352,139 @@ private[spark] object PythonRDD extends Logging {
    iter.foreach(write)
  }

+  def writeHeaderToStream(
+      dataOut: DataOutputStream,
+      pythonIncludes: JList[String],
+      broadcastVars: JList[Broadcast[PythonBroadcast]],
+      worker: Socket,
+      pythonVer: String,
+      partitionIndex: Int,
+      command: Array[Byte]): Unit = {
+    // Partition index
+    dataOut.writeInt(partitionIndex)
+    // Python version of driver
+    PythonRDD.writeUTF(pythonVer, dataOut)
+    // sparkFilesDir
+    PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+    // Python includes (*.zip and *.egg files)
+    dataOut.writeInt(pythonIncludes.size())
+    for (include <- pythonIncludes.asScala) {
+      PythonRDD.writeUTF(include, dataOut)
+    }
+    // Broadcast variables
+    val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+    val newBids = broadcastVars.asScala.map(_.id).toSet
+    // number of different broadcasts
+    val toRemove = oldBids.diff(newBids)
+    val cnt = toRemove.size + newBids.diff(oldBids).size
+    dataOut.writeInt(cnt)
+    for (bid <- toRemove) {
+      // remove the broadcast from the worker (negative id; bid >= 0, so -bid - 1 < 0)
+      dataOut.writeLong(-bid - 1)
+      oldBids.remove(bid)
+    }
+    for (broadcast <- broadcastVars.asScala) {
+      if (!oldBids.contains(broadcast.id)) {
+        // send new broadcast
+        dataOut.writeLong(broadcast.id)
+        PythonRDD.writeUTF(broadcast.value.path, dataOut)
+        oldBids.add(broadcast.id)
+      }
+    }
+
+    // Serialized command:
+    dataOut.writeInt(command.length)
+    dataOut.write(command)
+
+    dataOut.flush()
+  }
+
+  def readPythonProcessSocket(
+      stream: DataInputStream,
+      startTime: Long,
+      reuseWorker: Boolean,
+      accumulator: Accumulator[JList[Array[Byte]]],
+      envVars: JMap[String, String],
+      pythonExec: String,
+      worker: Socket,
+      released: AtomicBoolean,
+      writerThreadException: Option[Exception] = None): Array[Byte] = {
+    try {
+      stream.readInt() match {
+        case length if length > 0 =>
+          val obj = new Array[Byte](length)
+          stream.readFully(obj)
+          obj
+        case 0 => Array.empty[Byte]
+        case SpecialLengths.TIMING_DATA =>
+          // Timing data from worker
+          val bootTime = stream.readLong()
+          val initTime = stream.readLong()
+          val finishTime = stream.readLong()
+          val boot = bootTime - startTime
+          val init = initTime - bootTime
+          val finish = finishTime - initTime
+          val total = finishTime - startTime
+          logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
+            init, finish))
+          val memoryBytesSpilled = stream.readLong()
+          val diskBytesSpilled = stream.readLong()
+          TaskContext.get.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+          TaskContext.get.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+          readPythonProcessSocket(
+            stream,
+            startTime,
+            reuseWorker,
+            accumulator,
+            envVars,
+            pythonExec,
+            worker,
+            released,
+            writerThreadException)
+        case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+          // Signals that an exception has been thrown in python
+          val exLength = stream.readInt()
+          val obj = new Array[Byte](exLength)
+          stream.readFully(obj)
+          throw new PythonException(new String(obj, UTF_8), writerThreadException.getOrElse(null))
+        case SpecialLengths.END_OF_DATA_SECTION =>
+          // We've finished the data section of the output, but we can still
+          // read some accumulator updates:
+          val numAccumulatorUpdates = stream.readInt()
+          (1 to numAccumulatorUpdates).foreach { _ =>
+            val updateLen = stream.readInt()
+            val update = new Array[Byte](updateLen)
+            stream.readFully(update)
+            accumulator += Collections.singletonList(update)
+          }
+          // Check whether the worker is ready to be re-used.
+          if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+            if (reuseWorker) {
+              SparkEnv.get.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
+              released.set(true)
+            }
+          }
+          null
+      }
+    } catch {
+      case e: Exception if TaskContext.get.isInterrupted =>
+        logDebug("Exception thrown after task interruption", e)
+        throw new TaskKilledException
+
+      case e: Exception if SparkEnv.get.isStopped =>
+        logDebug("Exception thrown after context is stopped", e)
+        null // exit silently
+
+      case e: Exception if writerThreadException.isDefined =>
+        logError("Python worker exited unexpectedly (crashed)", e)
+        logError("This may have been caused by a prior exception:", writerThreadException.get)
+        throw writerThreadException.get
+
+      case eof: EOFException =>
+        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+    }
+  }

  /**
   * Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]],
   * key and value class.
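For reference, the worker consumes `writeHeaderToStream`'s output in exactly the same order it is written. A hedged sketch of that read sequence (`read_header` and its helpers are illustrative, self-contained equivalents of the pyspark.serializers functions; the broadcast-registry handling is abbreviated):

```python
import struct


def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]


def read_long(stream):
    return struct.unpack("!q", stream.read(8))[0]


def read_utf(stream):
    # matches PythonRDD.writeUTF: 4-byte length followed by UTF-8 bytes
    return stream.read(read_int(stream)).decode("utf-8")


def read_header(infile):
    split_index = calc = read_int(infile)       # partition index
    python_ver = read_utf(infile)               # Python version of driver
    spark_files_dir = read_utf(infile)          # sparkFilesDir
    includes = [read_utf(infile) for _ in range(read_int(infile))]
    for _ in range(read_int(infile)):           # broadcast adds/removes
        bid = read_long(infile)
        if bid >= 0:
            broadcast_path = read_utf(infile)   # new broadcast, load from file
        # else: a negative id means "drop broadcast -bid - 1"
    command = infile.read(read_int(infile))     # serialized command
    return split_index, python_ver, spark_files_dir, includes, command
```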
python/pyspark/worker.py (44 changes: 38 additions & 6 deletions)
@@ -32,6 +32,11 @@
    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
from pyspark import shuffle

+
+class _PySparkMode(object):
+    RDD = 0
+    UDF = 1
+
pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -101,14 +106,41 @@ def main(infile, outfile):
        func, profiler, deserializer, serializer = command
        init_time = time.time()

-        def process():
-            iterator = deserializer.load_stream(infile)
-            serializer.dump_stream(func(split_index, iterator), outfile)
+        pyspark_mode = read_int(infile)
+
+        if pyspark_mode == _PySparkMode.RDD:
+            def process():
+                iterator = deserializer.load_stream(infile)
+                serializer.dump_stream(func(split_index, iterator), outfile)
+
+            if profiler:
+                profiler.profile(process)
+            else:
+                process()
+        elif pyspark_mode == _PySparkMode.UDF:
+            pickle_serializer = PickleSerializer()
+
+            batch_length = read_int(infile)
+
-        if profiler:
-            profiler.profile(process)
+            while batch_length != SpecialLengths.END_OF_DATA_SECTION:
+                batch_bytes = infile.read(batch_length)
+                batch_objs = pickle_serializer.loads(batch_bytes)
+
+                udf_result_objs = list(func(split_index, iter(batch_objs)))
+                pickled_bytes = pickle_serializer.dumps(udf_result_objs)
+                write_int(len(pickled_bytes), outfile)
+
+                if sys.version_info[0:2] <= (2, 6):
+                    outfile.write(str(pickled_bytes))
+                else:
+                    outfile.write(pickled_bytes)
+
+                outfile.flush()
+
+                batch_length = read_int(infile)
        else:
-            process()
+            raise Exception("Unknown pyspark mode: {}".format(pyspark_mode))

    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
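The UDF branch above defines a simple request/response protocol: each inbound frame is a pickled batch of rows prefixed by its byte length, `END_OF_DATA_SECTION` (-1) terminates the input, and each reply is a length-prefixed pickled list of results. A self-contained round-trip of that framing, with `io.BytesIO` standing in for the worker's socket streams (the square UDF and the helper names are illustrative):

```python
import io
import pickle
import struct

END_OF_DATA_SECTION = -1


def write_int(value, out):
    out.write(struct.pack("!i", value))


def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]


# driver side: frame two pickled batches for the worker
infile = io.BytesIO()
for batch in ([1, 2, 3], [4, 5]):
    payload = pickle.dumps(batch)
    write_int(len(payload), infile)
    infile.write(payload)
write_int(END_OF_DATA_SECTION, infile)
infile.seek(0)

# worker side: the batch loop from main(), condensed
outfile = io.BytesIO()
func = lambda split_index, iterator: (x * x for x in iterator)

batch_length = read_int(infile)
while batch_length != END_OF_DATA_SECTION:
    batch_objs = pickle.loads(infile.read(batch_length))
    result = pickle.dumps(list(func(0, iter(batch_objs))))
    write_int(len(result), outfile)
    outfile.write(result)
    batch_length = read_int(infile)

outfile.seek(0)
assert pickle.loads(outfile.read(read_int(outfile))) == [1, 4, 9]
```

Note that results are flushed per batch, so the JVM can start consuming replies before the whole input section has been sent.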