@@ -64,10 +64,16 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

     // Attempt to connect, restart and retry once if it fails
     try {
-      new Socket(daemonHost, daemonPort)
+      val socket = new Socket(daemonHost, daemonPort)
+      val launchStatus = new DataInputStream(socket.getInputStream).readInt()
+      if (launchStatus != 0) {
+        throw new IllegalStateException("Python daemon failed to launch worker")
+      }
+      socket
     } catch {
       case exc: SocketException =>
-        logWarning("Python daemon unexpectedly quit, attempting to restart")
+        logWarning("Failed to open socket to Python daemon:", exc)
+        logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
         stopDaemon()
         startDaemon()
         new Socket(daemonHost, daemonPort)
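The new readInt() turns worker launch into an explicit handshake: the daemon acknowledges a successful fork with status 0 (see write_int(0, outfile) in daemon.py below) and reports a failed fork with -1. As a rough illustration of the client side of that handshake, here is a minimal Python sketch; connect_to_daemon and its error handling are hypothetical, and it assumes the status is a 4-byte big-endian int, which is what pyspark.serializers.write_int emits and java.io.DataInputStream.readInt expects:

import socket
import struct

def connect_to_daemon(daemon_host, daemon_port):
    # Connect to the daemon's listening socket (host/port are placeholders).
    sock = socket.create_connection((daemon_host, daemon_port))
    # Read exactly 4 bytes: the big-endian launch status written by the daemon.
    buf = b''
    while len(buf) < 4:
        chunk = sock.recv(4 - len(buf))
        if not chunk:
            raise EOFError("daemon closed connection during handshake")
        buf += chunk
    (launch_status,) = struct.unpack("!i", buf)
    if launch_status != 0:
        sock.close()
        raise RuntimeError("Python daemon failed to launch worker")
    return sock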
python/pyspark/daemon.py (179 changes: 71 additions, 108 deletions)
@@ -15,64 +15,39 @@
 # limitations under the License.
 #

+import numbers
 import os
 import signal
+import select
 import socket
 import sys
 import traceback
-import multiprocessing
-from ctypes import c_bool
 from errno import EINTR, ECHILD
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
 from pyspark.serializers import write_int

-try:
-    POOLSIZE = multiprocessing.cpu_count()
-except NotImplementedError:
-    POOLSIZE = 4
-
-exit_flag = multiprocessing.Value(c_bool, False)
-
-
-def should_exit():
-    global exit_flag
-    return exit_flag.value
-

 def compute_real_exit_code(exit_code):
     # SystemExit's code can be integer or string, but os._exit only accepts integers
-    import numbers
     if isinstance(exit_code, numbers.Integral):
         return exit_code
     else:
         return 1


-def worker(listen_sock):
+def worker(sock):
     """
     Called by a worker process after the fork().
     """
     # Redirect stdout to stderr
     os.dup2(2, 1)
     sys.stdout = sys.stderr  # The sys.stdout object is different from file descriptor 1

-    # Manager sends SIGHUP to request termination of workers in the pool
-    def handle_sighup(*args):
-        assert should_exit()
-    signal.signal(SIGHUP, handle_sighup)
-
-    # Cleanup zombie children
-    def handle_sigchld(*args):
-        pid = status = None
-        try:
-            while (pid, status) != (0, 0):
-                pid, status = os.waitpid(0, os.WNOHANG)
-        except EnvironmentError as err:
-            if err.errno == EINTR:
-                # retry
-                handle_sigchld()
-            elif err.errno != ECHILD:
-                raise
-    signal.signal(SIGCHLD, handle_sigchld)
+    signal.signal(SIGHUP, SIG_DFL)
+    signal.signal(SIGCHLD, SIG_DFL)
+    signal.signal(SIGTERM, SIG_DFL)

     # Blocks until the socket is closed by draining the input stream
     # until it raises an exception or returns EOF.
@@ -85,55 +60,23 @@ def waitSocketClose(sock):
         except:
             pass

-    # Handle clients
-    while not should_exit():
-        # Wait until a client arrives or we have to exit
-        sock = None
-        while not should_exit() and sock is None:
-            try:
-                sock, addr = listen_sock.accept()
-            except EnvironmentError as err:
-                if err.errno != EINTR:
-                    raise
-
-        if sock is not None:
-            # Fork a child to handle the client.
-            # The client is handled in the child so that the manager
-            # never receives SIGCHLD unless a worker crashes.
-            if os.fork() == 0:
-                # Leave the worker pool
-                signal.signal(SIGHUP, SIG_DFL)
-                signal.signal(SIGCHLD, SIG_DFL)
-                listen_sock.close()
-                # Read the socket using fdopen instead of socket.makefile() because the latter
-                # seems to be very slow; note that we need to dup() the file descriptor because
-                # otherwise writes also cause a seek that makes us miss data on the read side.
-                infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
-                outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
-                exit_code = 0
-                try:
-                    worker_main(infile, outfile)
-                except SystemExit as exc:
-                    exit_code = exc.code
-                finally:
-                    outfile.flush()
-                    # The Scala side will close the socket upon task completion.
-                    waitSocketClose(sock)
-                    os._exit(compute_real_exit_code(exit_code))
-            else:
-                sock.close()
-
-
-def launch_worker(listen_sock):
-    if os.fork() == 0:
-        try:
-            worker(listen_sock)
-        except Exception as err:
-            traceback.print_exc()
-            os._exit(1)
-        else:
-            assert should_exit()
-            os._exit(0)
+    # Read the socket using fdopen instead of socket.makefile() because the latter
+    # seems to be very slow; note that we need to dup() the file descriptor because
+    # otherwise writes also cause a seek that makes us miss data on the read side.
+    infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+    outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+    exit_code = 0
+    try:
+        write_int(0, outfile)  # Acknowledge that the fork was successful
+        outfile.flush()
+        worker_main(infile, outfile)
+    except SystemExit as exc:
+        exit_code = exc.code
+    finally:
+        outfile.flush()
+        # The Scala side will close the socket upon task completion.
+        waitSocketClose(sock)
+        os._exit(compute_real_exit_code(exit_code))

Review comment (Contributor Author):
This file had two different versions of this "cleanup zombie children" handle_sigchld function that were slightly different, so it might be a good idea to understand whether there's a reason why they differed.
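As context for the note above: a common shape for a "reap all zombies" SIGCHLD handler loops over waitpid with WNOHANG until no exited child remains. This is a generic sketch for comparison, not either of the versions that were in this file:

import errno
import os
import signal

def handle_sigchld(*args):
    # Reap every exited child without blocking.
    while True:
        try:
            pid, status = os.waitpid(-1, os.WNOHANG)
        except OSError as err:
            if err.errno == errno.EINTR:
                continue  # interrupted by another signal; retry
            if err.errno == errno.ECHILD:
                break     # no children left at all
            raise
        if pid == 0:
            break         # children exist, but none have exited yet
        # (pid, status) describe the reaped child here

signal.signal(signal.SIGCHLD, handle_sigchld)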


 def manager():
@@ -143,29 +86,28 @@ def manager():
     # Create a listening socket on the AF_INET loopback interface
     listen_sock = socket.socket(AF_INET, SOCK_STREAM)
     listen_sock.bind(('127.0.0.1', 0))
-    listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
+    listen_sock.listen(max(1024, SOMAXCONN))

Review comment (Contributor Author):
The 2 * POOLSIZE didn't really make sense here; it would only have an effect if multiprocessing.cpu_count() reported more than 512 cores.

     listen_host, listen_port = listen_sock.getsockname()
     write_int(listen_port, sys.stdout)

-    # Launch initial worker pool
-    for idx in range(POOLSIZE):
-        launch_worker(listen_sock)
-    listen_sock.close()
-
-    def shutdown():
-        global exit_flag
-        exit_flag.value = True
+    def shutdown(code):
+        signal.signal(SIGTERM, SIG_DFL)
+        # Send SIGHUP to notify workers of shutdown
+        os.kill(0, SIGHUP)
+        exit(code)

-    # Gracefully exit on SIGTERM, don't die on SIGHUP
-    signal.signal(SIGTERM, lambda signum, frame: shutdown())
-    signal.signal(SIGHUP, SIG_IGN)
+    def handle_sigterm(*args):
+        shutdown(1)
+    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
+    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

     # Cleanup zombie children
     def handle_sigchld(*args):
         try:
             pid, status = os.waitpid(0, os.WNOHANG)
-            if status != 0 and not should_exit():
-                raise RuntimeError("worker crashed: %s, %s" % (pid, status))
+            if status != 0:
+                msg = "worker %s crashed abruptly with exit status %s" % (pid, status)
+                print >> sys.stderr, msg
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
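The review note above about 2 * POOLSIZE is easy to verify numerically. This small sketch recomputes both backlogs the way the removed code defined POOLSIZE (the variable names here are only for illustration):

import multiprocessing
from socket import SOMAXCONN

try:
    POOLSIZE = multiprocessing.cpu_count()
except NotImplementedError:
    POOLSIZE = 4

# The middle term only changes the result when 2 * POOLSIZE > 1024,
# i.e. when cpu_count() reports more than 512 cores.
old_backlog = max(1024, 2 * POOLSIZE, SOMAXCONN)
new_backlog = max(1024, SOMAXCONN)
assert old_backlog == new_backlog or POOLSIZE > 512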
@@ -174,20 +116,41 @@ def handle_sigchld(*args):
     # Initialization complete
     sys.stdout.close()
     try:
-        while not should_exit():
+        while True:
             try:
-                # Spark tells us to exit by closing stdin
-                if os.read(0, 512) == '':
-                    shutdown()
-            except EnvironmentError as err:
-                if err.errno != EINTR:
-                    shutdown()
+                ready_fds = select.select([0, listen_sock], [], [])[0]
Review comment (Contributor):
select() should have a timeout here; otherwise, after SIGTERM, the daemon will not exit until some event happens.

Review comment (Contributor Author):
Does this updated code still need a timeout? The current code exits directly in the SIGTERM handler.
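For illustration only, the polling alternative the reviewer is suggesting would look roughly like this; wait_for_ready, should_exit, and the 1-second timeout are hypothetical, not part of this patch:

import select

SELECT_TIMEOUT = 1.0  # seconds; arbitrary value for the sketch

def wait_for_ready(fds, should_exit):
    # Poll with a timeout so the loop notices should_exit() even if
    # no file descriptor ever becomes ready.
    while not should_exit():
        ready = select.select(fds, [], [], SELECT_TIMEOUT)[0]
        if ready:
            return ready
    return []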

+            except select.error as ex:
+                if ex[0] == EINTR:
Review comment (Contributor Author):
@davies raised a good question about whether select.select() can return other errors here and whether we should try to handle those errors more gracefully. According to man select:

    An error return from select() indicates:

    [EAGAIN]    The kernel was (perhaps temporarily) unable to allocate the requested number of file descriptors.
    [EBADF]     One of the descriptor sets specified an invalid descriptor.
    [EINTR]     A signal was delivered before the time limit expired and before any of the selected events occurred.
    [EINVAL]    The specified time limit is invalid. One of its components is negative or too large.
    [EINVAL]    ndfs is greater than FD_SETSIZE and _DARWIN_UNLIMITED_SELECT is not defined.

I think only EINTR is recoverable here. I've updated this code to use the EINTR constant instead of the magic number 4.

Review comment (Contributor):
In most cases, EAGAIN should be recoverable; should we also catch that?

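A sketch of the more forgiving variant discussed here, retrying on EAGAIN as well as EINTR and re-raising everything else; select_with_retry is a hypothetical helper, not code from this patch (it uses the Python 2 select.error tuple form that the patch itself uses):

import select
from errno import EAGAIN, EINTR

def select_with_retry(read_fds):
    # Retry select() on transient errors; anything else is fatal.
    while True:
        try:
            return select.select(read_fds, [], [])[0]
        except select.error as ex:
            if ex[0] in (EINTR, EAGAIN):
                continue
            raise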
+                    continue
+                else:
+                    raise
+            if 0 in ready_fds:
+                # Spark told us to exit by closing stdin
+                shutdown(0)
+            if listen_sock in ready_fds:
+                sock, addr = listen_sock.accept()
+                # Launch a worker process
+                try:
+                    fork_return_code = os.fork()
+                    if fork_return_code == 0:
+                        listen_sock.close()
+                        try:
+                            worker(sock)
+                        except:
+                            traceback.print_exc()
+                            os._exit(1)
+                        else:
+                            os._exit(0)
+                    else:
+                        sock.close()
+                except OSError as e:
+                    print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
+                    outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+                    write_int(-1, outfile)  # Signal that the fork failed
+                    outfile.flush()
+                    sock.close()
     finally:
-        signal.signal(SIGTERM, SIG_DFL)
-        exit_flag.value = True
-        # Send SIGHUP to notify workers of shutdown
-        os.kill(0, SIGHUP)
+        shutdown(1)


 if __name__ == '__main__':
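Finally, for readers unfamiliar with the stdin-based shutdown protocol: closing the daemon's stdin makes fd 0 readable, and a subsequent read returns EOF. A minimal sketch under that assumption (stdin_was_closed is hypothetical; the patch itself simply calls shutdown(0) when fd 0 is ready, assuming Spark never writes to stdin):

import os
import select

def stdin_was_closed():
    # A closed stdin shows up as fd 0 "ready"; reading then returns EOF.
    ready = select.select([0], [], [], 0)[0]
    return bool(ready) and len(os.read(0, 512)) == 0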