106 changes: 106 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,12 +20,14 @@ package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.JavaConverters._

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._


@@ -76,6 +78,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// TODO: support accumulator in multiple UDF
protected val accumulator = funcs.head.funcs.head.accumulator

// Expose a ServerSocket to support method calls via socket from Python side.
private[spark] var serverSocket: Option[ServerSocket] = None

// Authentication helper used when serving method calls via socket from Python side.
private lazy val authHelper = new SocketAuthHelper(SparkEnv.get.conf)

def compute(
inputIterator: Iterator[IN],
partitionIndex: Int,
@@ -180,7 +188,73 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
dataOut.writeInt(partitionIndex)
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// Init a ServerSocket to accept method calls from Python side.
val isBarrier = context.isInstanceOf[BarrierTaskContext]
if (isBarrier) {
serverSocket = Some(new ServerSocket(/* port */ 0,
/* backlog */ 1,
InetAddress.getByName("localhost")))
// A call to accept() for ServerSocket shall block infinitely.
serverSocket.map(_.setSoTimeout(0))
new Thread("accept-connections") {
setDaemon(true)

override def run(): Unit = {
while (!serverSocket.get.isClosed()) {
var sock: Socket = null
try {
sock = serverSocket.get.accept()
// Wait for function call from python side.
sock.setSoTimeout(10000)
val input = new DataInputStream(sock.getInputStream())
@squito (Contributor), Aug 27, 2018:
Why isn't authentication the first thing which happens on this connection? I don't think anything bad can happen in this case, but it just makes it more likely we leave a security hole here later on.

Contributor:
(I'd also like to do some refactoring of the socket setup code in Python, and that can go further if we do authentication first here.)

Contributor (Author):
Thanks for catching this; yeah, I agree it would be better to move the authentication before recognizing functions.

Contributor:
OK, I'm doing this -- SPARK-25253, will open a PR shortly.
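A minimal sketch of the auth-first ordering discussed in this thread (hypothetical; the actual change landed separately under SPARK-25253): authenticate each accepted connection before reading the function identifier, so an unauthenticated peer never reaches dispatch. The authenticate method here is an assumed stand-in for SocketAuthHelper.authClient, not Spark's implementation.

import java.io.DataInputStream
import java.net.{ServerSocket, Socket}

object AuthFirstAcceptLoop {
  // Assumed stand-in for SocketAuthHelper.authClient: verify a shared secret
  // read from the client and throw if it does not match.
  def authenticate(sock: Socket): Unit = { /* secret exchange elided */ }

  def serve(server: ServerSocket, handle: (Int, Socket) => Unit): Unit = {
    while (!server.isClosed) {
      var sock: Socket = null
      try {
        sock = server.accept()
        sock.setSoTimeout(10000)
        // Authenticate before parsing anything, so a bad client is rejected
        // up front rather than after its request has been dispatched.
        authenticate(sock)
        val input = new DataInputStream(sock.getInputStream)
        handle(input.readInt(), sock)
      } finally {
        if (sock != null) sock.close()
      }
    }
  }
}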

input.readInt() match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
// The barrier() function may wait infinitely, socket shall not timeout
// before the function finishes.
sock.setSoTimeout(0)
barrierAndServe(sock)

case _ =>
val out = new DataOutputStream(new BufferedOutputStream(
sock.getOutputStream))
writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
}
} catch {
case e: SocketException if e.getMessage.contains("Socket closed") =>
// It is possible that the ServerSocket is not closed, but the native socket
// has already been closed, we shall catch and silently ignore this case.
} finally {
if (sock != null) {
sock.close()
}
}
}
}
}.start()
}
val secret = if (isBarrier) {
authHelper.secret
} else {
""
}
// Close ServerSocket on task completion.
serverSocket.foreach { server =>
context.addTaskCompletionListener(_ => server.close())
Member:
This is failing the Scala 2.12 build:

[error] /Users/d_tsai/dev/apache-spark/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:242: ambiguous reference to overloaded definition,
[error] both method addTaskCompletionListener in class TaskContext of type [U](f: org.apache.spark.TaskContext => U)org.apache.spark.TaskContext
[error] and  method addTaskCompletionListener in class TaskContext of type (listener: org.apache.spark.util.TaskCompletionListener)org.apache.spark.TaskContext
[error] match argument types (org.apache.spark.TaskContext => Unit)
[error]           context.addTaskCompletionListener(_ => server.close())
[error]                   ^
[error] one error found
[error] Compile failed at Aug 24, 2018 1:56:06 PM [31.582s]

Member:
Addressed in #22229.
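Under Scala 2.12 a bare lambda can SAM-convert to TaskCompletionListener, so it matches both overloads quoted in the error and the call is ambiguous. One way to sidestep this, sketched below, is to pass the SAM type explicitly instead of a function literal (the fix that actually landed is in #22229):

import java.net.ServerSocket
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

object CloseOnCompletion {
  // An explicit TaskCompletionListener selects the non-generic overload, so
  // 2.12's SAM conversion no longer makes the function literal ambiguous.
  def install(context: TaskContext, server: ServerSocket): Unit = {
    context.addTaskCompletionListener(new TaskCompletionListener {
      override def onTaskCompletion(ctx: TaskContext): Unit = server.close()
    })
  }
}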

}
val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
if (boundPort == -1) {
val message = "ServerSocket failed to bind to Java side."
logError(message)
throw new SparkException(message)
} else if (isBarrier) {
logDebug(s"Started ServerSocket on port $boundPort.")
}
// Write out the TaskContextInfo
dataOut.writeBoolean(isBarrier)
dataOut.writeInt(boundPort)
val secretBytes = secret.getBytes(UTF_8)
dataOut.writeInt(secretBytes.length)
dataOut.write(secretBytes, 0, secretBytes.length)
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
@@ -243,6 +317,32 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
}
}

/**
* Gateway to call BarrierTaskContext.barrier().
*/
def barrierAndServe(sock: Socket): Unit = {
require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")

authHelper.authClient(sock)

val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
context.asInstanceOf[BarrierTaskContext].barrier()
writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
} catch {
case e: SparkException =>
writeUTF(e.getMessage, out)
} finally {
out.close()
}
}

def writeUTF(str: String, dataOut: DataOutputStream) {
val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
}

abstract class ReaderIterator(
@@ -465,3 +565,9 @@ private[spark] object SpecialLengths {
val NULL = -5
val START_ARROW_STREAM = -6
}

private[spark] object BarrierTaskContextMessageProtocol {
val BARRIER_FUNCTION = 1
val BARRIER_RESULT_SUCCESS = "success"
val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
}
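For reference, a hypothetical Scala rendering of the client side of this protocol, mirroring what _load_from_socket in python/pyspark/taskcontext.py (below) does: write the function id, run the secret exchange, then read the length-prefixed UTF-8 reply framed by writeUTF. The exchangeSecret parameter is an assumption standing in for the do_server_auth handshake.

import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8

object BarrierClientSketch {
  def callBarrier(port: Int,
      exchangeSecret: (DataOutputStream, DataInputStream) => Unit): String = {
    val sock = new Socket("localhost", port)
    try {
      val out = new DataOutputStream(sock.getOutputStream)
      val in = new DataInputStream(sock.getInputStream)
      out.writeInt(1) // BarrierTaskContextMessageProtocol.BARRIER_FUNCTION
      out.flush()
      exchangeSecret(out, in) // stand-in for the do_server_auth handshake
      // writeUTF frames the reply as an int length followed by UTF-8 bytes.
      val buf = new Array[Byte](in.readInt())
      in.readFully(buf)
      new String(buf, UTF_8)
    } finally {
      sock.close()
    }
  }
}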
7 changes: 7 additions & 0 deletions python/pyspark/serializers.py
@@ -713,6 +713,13 @@ def write_int(value, stream):
stream.write(struct.pack("!i", value))


def read_bool(stream):
length = stream.read(1)
if not length:
raise EOFError
return struct.unpack("!?", length)[0]


def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
144 changes: 144 additions & 0 deletions python/pyspark/taskcontext.py
@@ -16,6 +16,10 @@
#

from __future__ import print_function
import socket

from pyspark.java_gateway import do_server_auth
from pyspark.serializers import write_int, UTF8Deserializer


class TaskContext(object):
@@ -95,3 +99,143 @@ def getLocalProperty(self, key):
Get a local property set upstream in the driver, or None if it is missing.
"""
return self._localProperties.get(key, None)


BARRIER_FUNCTION = 1


def _load_from_socket(port, auth_secret):
"""
Load data from a given socket; this is a blocking method, so it only returns when the socket
connection has been closed.

This is copied from context.py, with the message protocol modified.

Member:
It would be nicer if we can deduplicate it later.

"""
sock = None
# Support for both IPv4 and IPv6.
# On most of IPv6-ready systems, IPv6 will take precedence.
for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = socket.socket(af, socktype, proto)
try:
# Do not allow timeout for socket reading operation.
sock.settimeout(None)
sock.connect(sa)
except socket.error:
sock.close()
sock = None
continue
break
if not sock:
raise Exception("could not open socket")

# We don't really need a socket file here, it's just for convenience that we can reuse the
# do_server_auth() function and data serialization methods.
sockfile = sock.makefile("rwb", 65536)

# Make a barrier() function call.
write_int(BARRIER_FUNCTION, sockfile)
sockfile.flush()

# Do server auth.
do_server_auth(sockfile, auth_secret)

# Collect result.
res = UTF8Deserializer().loads(sockfile)

# Release resources.
sockfile.close()
sock.close()

return res


class BarrierTaskContext(TaskContext):

"""
.. note:: Experimental

A TaskContext with extra info and tooling for a barrier stage. To access the BarrierTaskContext
for a running task, use:
L{BarrierTaskContext.get()}.

.. versionadded:: 2.4.0
"""

_port = None
_secret = None

def __init__(self):
"""Construct a BarrierTaskContext, use get instead"""
pass

@classmethod
def _getOrCreate(cls):
"""Internal function to get or create global BarrierTaskContext."""
if cls._taskContext is None:
Contributor:
Q: Does it handle python worker reuse?

Contributor (Author):
IIUC, reusing a python worker just means we start the python worker from a daemon thread; it shall not affect the input/output files related to worker.py.

cls._taskContext = BarrierTaskContext()
return cls._taskContext

@classmethod
def get(cls):
"""
Return the currently active BarrierTaskContext. This can be called inside of user functions
to access contextual information about running tasks.

.. note:: Must be called on the worker, not the driver. Returns None if not initialized.
"""
return cls._taskContext

@classmethod
def _initialize(cls, port, secret):
"""
Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called
after BarrierTaskContext is initialized.
"""
cls._port = port
cls._secret = secret

def barrier(self):
"""
.. note:: Experimental

Sets a global barrier and waits until all tasks in this stage hit this barrier.
Note this method is only allowed for a BarrierTaskContext.

.. versionadded:: 2.4.0
"""
if self._port is None or self._secret is None:
raise Exception("Not supported to call barrier() before initialize " +
"BarrierTaskContext.")
else:
_load_from_socket(self._port, self._secret)

def getTaskInfos(self):
Contributor (Author):
This is temporarily not available.

Contributor (Author):
Fixed.

"""
.. note:: Experimental

Returns all task infos in this barrier stage; the task infos are ordered by
partitionId.
Note this method is only allowed for a BarrierTaskContext.

.. versionadded:: 2.4.0
"""
if self._port is None or self._secret is None:
raise Exception("Not supported to call getTaskInfos() before initialize " +
"BarrierTaskContext.")
else:
addresses = self._localProperties.get("addresses", "")
return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]


class BarrierTaskInfo(object):
"""
.. note:: Experimental

Carries all task infos of a barrier task.

.. versionadded:: 2.4.0
"""

def __init__(self, address):
self.address = address
36 changes: 35 additions & 1 deletion python/pyspark/tests.py
@@ -70,7 +70,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
from pyspark import shuffle
from pyspark.profiler import BasicProfiler
from pyspark.taskcontext import TaskContext
from pyspark.taskcontext import BarrierTaskContext, TaskContext

_have_scipy = False
_have_numpy = False
@@ -588,6 +588,40 @@ def test_get_local_property(self):
finally:
self.sc.setLocalProperty(key, None)

def test_barrier(self):
"""
Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
within a stage.
"""
rdd = self.sc.parallelize(range(10), 4)

def f(iterator):
yield sum(iterator)

def context_barrier(x):
tc = BarrierTaskContext.get()
time.sleep(random.randint(1, 10))
tc.barrier()
return time.time()

times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
self.assertTrue(max(times) - min(times) < 1)

def test_barrier_infos(self):
"""
Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
barrier stage.
"""
rdd = self.sc.parallelize(range(10), 4)

def f(iterator):
yield sum(iterator)

taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
.getTaskInfos()).collect()
self.assertTrue(len(taskInfos) == 4)
self.assertTrue(len(taskInfos[0]) == 4)


class RDDTests(ReusedPySparkTestCase):

16 changes: 13 additions & 3 deletions python/pyspark/worker.py
@@ -28,10 +28,10 @@
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.java_gateway import do_server_auth
from pyspark.taskcontext import TaskContext
from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.files import SparkFiles
from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, \
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
@@ -259,8 +259,18 @@ def main(infile, outfile):
"PYSPARK_DRIVER_PYTHON are correctly set.") %
("%d.%d" % sys.version_info[:2], version))

# read inputs only for a barrier task
isBarrier = read_bool(infile)
boundPort = read_int(infile)
secret = UTF8Deserializer().loads(infile)
# initialize global state
taskContext = TaskContext._getOrCreate()
taskContext = None
if isBarrier:
taskContext = BarrierTaskContext._getOrCreate()
BarrierTaskContext._initialize(boundPort, secret)
else:
taskContext = TaskContext._getOrCreate()
# read inputs for TaskContext info
taskContext._stageId = read_int(infile)
taskContext._partitionId = read_int(infile)
taskContext._attemptNumber = read_int(infile)