-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-25095][PySpark] Python support for BarrierTaskContext #22085
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7b48829
289146d
05c9609
e234a0a
243a5a3
ba0ccad
2a8f3cb
1cacd40
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,12 +20,14 @@ package org.apache.spark.api.python | |
| import java.io._ | ||
| import java.net._ | ||
| import java.nio.charset.StandardCharsets | ||
| import java.nio.charset.StandardCharsets.UTF_8 | ||
| import java.util.concurrent.atomic.AtomicBoolean | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
|
|
||
| import org.apache.spark._ | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.security.SocketAuthHelper | ||
| import org.apache.spark.util._ | ||
|
|
||
|
|
||
|
|
@@ -76,6 +78,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( | |
| // TODO: support accumulator in multiple UDF | ||
| protected val accumulator = funcs.head.funcs.head.accumulator | ||
|
|
||
| // Expose a ServerSocket to support method calls via socket from Python side. | ||
| private[spark] var serverSocket: Option[ServerSocket] = None | ||
|
|
||
| // Authentication helper used when serving method calls via socket from Python side. | ||
| private lazy val authHelper = new SocketAuthHelper(SparkEnv.get.conf) | ||
|
|
||
| def compute( | ||
| inputIterator: Iterator[IN], | ||
| partitionIndex: Int, | ||
|
|
@@ -180,7 +188,73 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( | |
| dataOut.writeInt(partitionIndex) | ||
| // Python version of driver | ||
| PythonRDD.writeUTF(pythonVer, dataOut) | ||
| // Init a ServerSocket to accept method calls from Python side. | ||
| val isBarrier = context.isInstanceOf[BarrierTaskContext] | ||
| if (isBarrier) { | ||
| serverSocket = Some(new ServerSocket(/* port */ 0, | ||
| /* backlog */ 1, | ||
| InetAddress.getByName("localhost"))) | ||
| // A call to accept() for ServerSocket shall block infinitely. | ||
| serverSocket.map(_.setSoTimeout(0)) | ||
| new Thread("accept-connections") { | ||
| setDaemon(true) | ||
|
|
||
| override def run(): Unit = { | ||
| while (!serverSocket.get.isClosed()) { | ||
| var sock: Socket = null | ||
| try { | ||
| sock = serverSocket.get.accept() | ||
| // Wait for function call from python side. | ||
| sock.setSoTimeout(10000) | ||
| val input = new DataInputStream(sock.getInputStream()) | ||
| input.readInt() match { | ||
| case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => | ||
| // The barrier() function may wait infinitely, socket shall not timeout | ||
| // before the function finishes. | ||
| sock.setSoTimeout(0) | ||
| barrierAndServe(sock) | ||
|
|
||
| case _ => | ||
| val out = new DataOutputStream(new BufferedOutputStream( | ||
| sock.getOutputStream)) | ||
| writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) | ||
| } | ||
| } catch { | ||
| case e: SocketException if e.getMessage.contains("Socket closed") => | ||
| // It is possible that the ServerSocket is not closed, but the native socket | ||
| // has already been closed, we shall catch and silently ignore this case. | ||
| } finally { | ||
| if (sock != null) { | ||
| sock.close() | ||
| } | ||
| } | ||
| } | ||
| } | ||
| }.start() | ||
| } | ||
| val secret = if (isBarrier) { | ||
| authHelper.secret | ||
| } else { | ||
| "" | ||
| } | ||
| // Close ServerSocket on task completion. | ||
| serverSocket.foreach { server => | ||
| context.addTaskCompletionListener(_ => server.close()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is failing the Scala 2.12 build
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in #22229 |
||
| } | ||
| val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) | ||
| if (boundPort == -1) { | ||
| val message = "ServerSocket failed to bind to Java side." | ||
| logError(message) | ||
| throw new SparkException(message) | ||
| } else if (isBarrier) { | ||
| logDebug(s"Started ServerSocket on port $boundPort.") | ||
| } | ||
| // Write out the TaskContextInfo | ||
| dataOut.writeBoolean(isBarrier) | ||
| dataOut.writeInt(boundPort) | ||
| val secretBytes = secret.getBytes(UTF_8) | ||
| dataOut.writeInt(secretBytes.length) | ||
| dataOut.write(secretBytes, 0, secretBytes.length) | ||
| dataOut.writeInt(context.stageId()) | ||
| dataOut.writeInt(context.partitionId()) | ||
| dataOut.writeInt(context.attemptNumber()) | ||
|
|
@@ -243,6 +317,32 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( | |
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Gateway to call BarrierTaskContext.barrier() on behalf of a client connected over `sock`, | ||
| * and to report the outcome back on the same socket. | ||
| * | ||
| * On success BARRIER_RESULT_SUCCESS is written; if barrier() throws a SparkException, the | ||
| * exception message is written instead. Both replies are sent with writeUTF. | ||
| * | ||
| * @param sock the accepted client connection; it is passed to the auth helper before | ||
| *             barrier() is invoked | ||
| */ | ||
| def barrierAndServe(sock: Socket): Unit = { | ||
| // Fail fast if this runner never created a ServerSocket. | ||
| require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") | ||
|
|
||
| // Authenticate the peer before doing any work for it. | ||
| authHelper.authClient(sock) | ||
|
|
||
| val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) | ||
| try { | ||
| // May block for a long time; the caller disables the socket timeout before calling us. | ||
| context.asInstanceOf[BarrierTaskContext].barrier() | ||
| writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) | ||
| } catch { | ||
| case e: SparkException => | ||
| // Report the failure to the client instead of propagating it on this thread. | ||
| writeUTF(e.getMessage, out) | ||
| } finally { | ||
| // Closing the buffered stream also flushes it, so the reply is actually sent. | ||
| out.close() | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Writes `str` to `dataOut` as a length-prefixed UTF-8 payload: a 4-byte int holding the | ||
| * number of encoded bytes, followed by the encoded bytes themselves. | ||
| * | ||
| * The stream is not flushed here; callers flush or close it when the message is complete. | ||
| * | ||
| * @param str     the string to send | ||
| * @param dataOut the stream to write the length and payload to | ||
| */ | ||
| def writeUTF(str: String, dataOut: DataOutputStream): Unit = { | ||
| val bytes = str.getBytes(UTF_8) | ||
| dataOut.writeInt(bytes.length) | ||
| dataOut.write(bytes) | ||
| } | ||
| } | ||
|
|
||
| abstract class ReaderIterator( | ||
|
|
@@ -465,3 +565,9 @@ private[spark] object SpecialLengths { | |
| val NULL = -5 | ||
| val START_ARROW_STREAM = -6 | ||
| } | ||
|
|
||
| // Constants of the socket protocol used to serve BarrierTaskContext calls coming from Python. | ||
| private[spark] object BarrierTaskContextMessageProtocol { | ||
| // Function id read as an int from the client; selects the barrier() call. | ||
| val BARRIER_FUNCTION = 1 | ||
| // Reply written to the client when barrier() returns normally. | ||
| val BARRIER_RESULT_SUCCESS = "success" | ||
| // Reply written to the client when the function id it sent is not recognized. | ||
| val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,10 @@ | |
| # | ||
|
|
||
| from __future__ import print_function | ||
| import socket | ||
|
|
||
| from pyspark.java_gateway import do_server_auth | ||
| from pyspark.serializers import write_int, UTF8Deserializer | ||
|
|
||
|
|
||
| class TaskContext(object): | ||
|
|
@@ -95,3 +99,143 @@ def getLocalProperty(self, key): | |
| Get a local property set upstream in the driver, or None if it is missing. | ||
| """ | ||
| return self._localProperties.get(key, None) | ||
|
|
||
|
|
||
| # Function id written to the socket to request a barrier() call. | ||
| BARRIER_FUNCTION = 1 | ||
|
|
||
|
|
||
| def _load_from_socket(port, auth_secret): | ||
| """ | ||
| Connect to ``localhost:port``, request a barrier() call and return the reply. | ||
|
|
||
| This is a blocking call: the socket is opened without a timeout, the BARRIER_FUNCTION id | ||
| is sent, ``do_server_auth`` is run with ``auth_secret``, and the reply is then read with | ||
| ``UTF8Deserializer().loads``. | ||
|
|
||
| This is adapted from context.py, with a modified message protocol. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nicer if we can deduplicate it later. |
||
| :param port: port on localhost to connect to. | ||
| :param auth_secret: secret handed to ``do_server_auth``. | ||
| :return: the value read back by ``UTF8Deserializer().loads``. | ||
| :raises Exception: if none of the addresses resolved for localhost accepts the connection. | ||
| """ | ||
| sock = None | ||
| # Support for both IPv4 and IPv6. | ||
| # On most of IPv6-ready systems, IPv6 will take precedence. | ||
| for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): | ||
| af, socktype, proto, canonname, sa = res | ||
| sock = socket.socket(af, socktype, proto) | ||
| try: | ||
| # Do not allow timeout for socket reading operation. | ||
| sock.settimeout(None) | ||
| sock.connect(sa) | ||
| except socket.error: | ||
| # This address did not work; release the socket and try the next candidate. | ||
| sock.close() | ||
| sock = None | ||
| continue | ||
| # Connected successfully; stop trying further addresses. | ||
| break | ||
| if not sock: | ||
| raise Exception("could not open socket") | ||
|
|
||
| # We don't really need a socket file here, it's just for convenience that we can reuse the | ||
| # do_server_auth() function and data serialization methods. | ||
| sockfile = sock.makefile("rwb", 65536) | ||
|
|
||
| # Make a barrier() function call. | ||
| write_int(BARRIER_FUNCTION, sockfile) | ||
| # Flush so the function id is sent before authentication starts. | ||
| sockfile.flush() | ||
|
|
||
| # Do server auth. | ||
| do_server_auth(sockfile, auth_secret) | ||
|
|
||
| # Collect result. | ||
| res = UTF8Deserializer().loads(sockfile) | ||
|
|
||
| # Release resources. | ||
| # NOTE(review): these closes are skipped if auth or the read above raises. | ||
| sockfile.close() | ||
| sock.close() | ||
|
|
||
| return res | ||
|
|
||
|
|
||
| class BarrierTaskContext(TaskContext): | ||
|
|
||
| """ | ||
| .. note:: Experimental | ||
|
|
||
| A TaskContext with extra info and tooling for a barrier stage. To access the | ||
| BarrierTaskContext for a running task, use: | ||
| L{BarrierTaskContext.get()}. | ||
|
|
||
| .. versionadded:: 2.4.0 | ||
| """ | ||
|
|
||
| # Port and secret used by barrier() to reach the socket server; both stay None until | ||
| # _initialize() is called. | ||
| _port = None | ||
| _secret = None | ||
|
|
||
| def __init__(self): | ||
| """Construct a BarrierTaskContext. Use :meth:`get` instead of calling this directly.""" | ||
| pass | ||
|
|
||
| @classmethod | ||
| def _getOrCreate(cls): | ||
| """ | ||
| Internal function that returns the global BarrierTaskContext, creating it on first use. | ||
| """ | ||
| # NOTE(review): _taskContext is not declared in this class; presumably inherited from | ||
| # TaskContext -- confirm. | ||
| if cls._taskContext is None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: Does it handle python worker reuse?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC reuse python worker just means we start a python worker from a daemon thread, it shall not affect the input/output files related to worker.py. |
||
| cls._taskContext = BarrierTaskContext() | ||
| return cls._taskContext | ||
|
|
||
| @classmethod | ||
| def get(cls): | ||
| """ | ||
| Return the currently active BarrierTaskContext. This can be called inside of user functions | ||
| to access contextual information about running tasks. | ||
|
|
||
| .. note:: Must be called on the worker, not the driver. Returns None if not initialized. | ||
| """ | ||
| return cls._taskContext | ||
|
|
||
| @classmethod | ||
| def _initialize(cls, port, secret): | ||
| """ | ||
| Initialize BarrierTaskContext by recording the port and secret used by barrier(). | ||
| Other methods of BarrierTaskContext can only be called after it has been initialized. | ||
| """ | ||
| cls._port = port | ||
| cls._secret = secret | ||
|
|
||
| def barrier(self): | ||
| """ | ||
| .. note:: Experimental | ||
|
|
||
| Sets a global barrier and waits until all tasks in this stage hit this barrier. | ||
| Note this method is only allowed for a BarrierTaskContext. | ||
|
|
||
| Raises an Exception if the context has not been initialized. | ||
|
|
||
| .. versionadded:: 2.4.0 | ||
| """ | ||
| if self._port is None or self._secret is None: | ||
| raise Exception("Not supported to call barrier() before initialize " + | ||
| "BarrierTaskContext.") | ||
| else: | ||
| # NOTE(review): the reply returned by _load_from_socket is discarded here. | ||
| _load_from_socket(self._port, self._secret) | ||
|
|
||
| def getTaskInfos(self): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not available temporarily.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
| """ | ||
| .. note:: Experimental | ||
|
|
||
| Returns the task infos of all tasks in this barrier stage, ordered by partitionId. | ||
| Note this method is only allowed for a BarrierTaskContext. | ||
|
|
||
| Raises an Exception if the context has not been initialized. | ||
|
|
||
| .. versionadded:: 2.4.0 | ||
| """ | ||
| if self._port is None or self._secret is None: | ||
| raise Exception("Not supported to call getTaskInfos() before initialize " + | ||
| "BarrierTaskContext.") | ||
| else: | ||
| # The "addresses" local property is read as a comma-separated list. | ||
| # NOTE(review): when the property is missing, "".split(",") yields [""], so a single | ||
| # BarrierTaskInfo with an empty address is returned. | ||
| addresses = self._localProperties.get("addresses", "") | ||
| return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")] | ||
|
|
||
|
|
||
| class BarrierTaskInfo(object): | ||
| """ | ||
| .. note:: Experimental | ||
|
|
||
| Carries all task infos of a barrier task. Currently this is only the ``address`` value | ||
| passed to the constructor. | ||
|
|
||
| .. versionadded:: 2.4.0 | ||
| """ | ||
|
|
||
| def __init__(self, address): | ||
| # Stored as given; no validation or parsing is done here. | ||
| self.address = address | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why isn't authentication the first thing which happens on this connection? I don't think anything bad can happen in this case, but it just makes it more likely we leave a security hole here later on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I'd also like to do some refactoring of the socket setup code in python, and that can go further if we do authentication first here)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching this, yea I agree it would be better to move the authentication before recognising functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok I'm doing this -- SPARK-25253, will open a pr shortly