Skip to content

Commit

Permalink
Merge pull request #412 from MSeal/concurrencyRaceFixes
Browse files Browse the repository at this point in the history
Fixed socket binding race conditions
  • Loading branch information
Carreau authored Jul 26, 2019
2 parents 638bc6f + d1b2d01 commit 4e3769d
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 39 deletions.
67 changes: 52 additions & 15 deletions ipykernel/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,64 @@ def __init__(self, context, addr=None):
Thread.__init__(self)
self.context = context
self.transport, self.ip, self.port = addr
if self.port == 0:
if addr[0] == 'tcp':
s = socket.socket()
# '*' means all interfaces to 0MQ, which is '' to socket.socket
s.bind(('' if self.ip == '*' else self.ip, 0))
self.port = s.getsockname()[1]
s.close()
elif addr[0] == 'ipc':
self.port = 1
while os.path.exists("%s-%s" % (self.ip, self.port)):
self.port = self.port + 1
else:
raise ValueError("Unrecognized zmq transport: %s" % addr[0])
self.original_port = self.port
if self.original_port == 0:
self.pick_port()
self.addr = (self.ip, self.port)
self.daemon = True

def pick_port(self):
    """Pick a candidate port for the heartbeat channel and store it in ``self.port``.

    For the ``tcp`` transport a throwaway socket is bound to port 0 and the
    port number the OS assigned is recorded. For the ``ipc`` transport the
    first integer ``N >= 1`` for which the path ``"<ip>-<N>"`` does not
    exist is used.

    The result is only a candidate: the probe socket is closed before
    returning, so the port can still be taken by someone else before the
    real bind happens.

    Returns:
        The chosen port (also stored in ``self.port``).

    Raises:
        ValueError: if ``self.transport`` is neither ``'tcp'`` nor ``'ipc'``.
    """
    if self.transport == 'tcp':
        s = socket.socket()
        try:
            # '*' means all interfaces to 0MQ, which is '' to socket.socket
            s.bind(('' if self.ip == '*' else self.ip, 0))
            self.port = s.getsockname()[1]
        finally:
            # Release the probe socket even when bind()/getsockname() fails,
            # otherwise the file descriptor leaks on every failed attempt.
            s.close()
    elif self.transport == 'ipc':
        self.port = 1
        while os.path.exists("%s-%s" % (self.ip, self.port)):
            self.port = self.port + 1
    else:
        raise ValueError("Unrecognized zmq transport: %s" % self.transport)
    return self.port

def _try_bind_socket(self):
    """Make a single attempt to bind ``self.socket`` to the current port.

    The endpoint is ``<transport>://<ip>`` followed by ``:<port>`` for tcp
    or ``-<port>`` for any other transport. Returns whatever
    ``self.socket.bind`` returns; errors from ``bind`` propagate.
    """
    separator = '-'
    if self.transport == 'tcp':
        separator = ':'
    endpoint = '%s://%s' % (self.transport, self.ip)
    endpoint += separator + str(self.port)
    return self.socket.bind(endpoint)

def _bind_socket(self):
    """Bind the heartbeat socket, re-picking the port on address conflicts.

    When the port was requested explicitly (``self.original_port`` is
    truthy) exactly one bind attempt is made. When it was auto-selected
    (``original_port == 0``), an "address in use" failure triggers
    ``pick_port()`` and another attempt, up to 100 attempts in total.

    Raises:
        zmq.ZMQError: on the final failed attempt, on any error that is not
            an "address in use" error, or when the conflicting port was not
            auto-selected.
    """
    # errno.WSAEADDRINUSE is only defined on Windows.
    win_in_use = getattr(errno, 'WSAEADDRINUSE', None)

    # Try up to 100 times to bind a port when in conflict to avoid
    # infinite attempts in bad setups
    max_attempts = 1 if self.original_port else 100
    for attempt in range(max_attempts):
        try:
            self._try_bind_socket()
        except zmq.ZMQError as ze:
            if attempt == max_attempts - 1:
                raise
            # Raise if we have any error not related to socket binding
            if ze.errno != errno.EADDRINUSE and ze.errno != win_in_use:
                raise
            # Only an auto-selected port may be replaced; a port the caller
            # asked for must never be silently swapped for another one.
            if self.original_port == 0:
                self.pick_port()
            else:
                raise
        else:
            # Bound successfully: stop here. Without this the loop would
            # keep going and try to bind the same socket again.
            return

def run(self):
self.socket = self.context.socket(zmq.ROUTER)
self.socket.linger = 1000
c = ':' if self.transport == 'tcp' else '-'
self.socket.bind('%s://%s' % (self.transport, self.ip) + c + str(self.port))
try:
self._bind_socket()
except Exception:
self.socket.close()
raise

while True:
try:
zmq.device(zmq.QUEUE, self.socket, self.socket)
Expand Down
6 changes: 3 additions & 3 deletions ipykernel/inprocess/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
#-----------------------------------------------------------------------------

class SocketABC(with_metaclass(abc.ABCMeta, object)):

@abc.abstractmethod
def recv_multipart(self, flags=0, copy=True, track=False):
raise NotImplementedError

@abc.abstractmethod
def send_multipart(self, msg_parts, flags=0, copy=True, track=False):
raise NotImplementedError

@classmethod
def register(cls, other_cls):
if other_cls is not DummySocket:
Expand All @@ -47,7 +47,7 @@ class DummySocket(HasTraits):
message_sent = Int(0) # Should be an Event
context = Instance(zmq.Context)
def _context_default(self):
return zmq.Context.instance()
return zmq.Context()

#-------------------------------------------------------------------------
# Socket interface
Expand Down
24 changes: 13 additions & 11 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,14 @@ def _check_mp_mode(self):
return MASTER
else:
return CHILD

def start(self):
    """Start the IOPub thread and register ``stop`` to run at interpreter exit."""
    io_thread = self.thread
    io_thread.start()
    # Stopping explicitly at exit keeps the thread from blocking process
    # exit; the daemon=True flag set on the thread does not appear to be
    # enough on its own (reason unknown).
    atexit.register(self.stop)

def stop(self):
"""Stop the IOPub thread"""
if not self.thread.is_alive():
Expand All @@ -183,8 +183,10 @@ def stop(self):
self.thread.join()
if hasattr(self._local, 'event_pipe'):
self._local.event_pipe.close()

def close(self):
    """Close the wrapped socket and drop the reference to it.

    Does nothing when ``self.closed`` is already true, so calling it
    repeatedly is harmless.
    """
    if not self.closed:
        self.socket.close()
        self.socket = None

Expand All @@ -206,11 +208,11 @@ def schedule(self, f):

def send_multipart(self, *args, **kwargs):
    """Hand a multipart send over to ``schedule`` instead of sending here.

    The arguments are captured in a callback that calls ``_really_send``.
    When that callback runs is up to ``schedule``: normally in the IO
    thread, or immediately if that thread isn't running (e.g. in a forked
    process).
    """
    def _deferred_send():
        return self._really_send(*args, **kwargs)

    self.schedule(_deferred_send)

def _really_send(self, msg, *args, **kwargs):
"""The callback that actually sends messages"""
mp_mode = self._check_mp_mode()
Expand All @@ -231,10 +233,10 @@ def _really_send(self, msg, *args, **kwargs):
class BackgroundSocket(object):
"""Wrapper around IOPub thread that provides zmq send[_multipart]"""
io_thread = None

def __init__(self, io_thread):
    """Remember the IO thread object that this wrapper delegates to."""
    self.io_thread = io_thread

def __getattr__(self, attr):
"""Wrap socket attr access for backward-compatibility"""
if attr.startswith('__') and attr.endswith('__'):
Expand All @@ -245,15 +247,15 @@ def __getattr__(self, attr):
DeprecationWarning, stacklevel=2)
return getattr(self.io_thread.socket, attr)
super(BackgroundSocket, self).__getattr__(attr)

def __setattr__(self, attr, value):
    """Set an attribute on the wrapper, or forward it to the zmq socket.

    ``io_thread`` and dunder names (``__x__``) are stored on the wrapper
    itself. Any other name is set on ``self.io_thread.socket`` for
    backward compatibility, after emitting a ``DeprecationWarning``.
    """
    # The closing parenthesis used to sit after endswith(...), which passed
    # a bool to startswith() and raised TypeError for every forwarded name.
    if attr == 'io_thread' or (attr.startswith('__') and attr.endswith('__')):
        super(BackgroundSocket, self).__setattr__(attr, value)
    else:
        warnings.warn("Setting zmq Socket attribute %s on BackgroundSocket" % attr,
                      DeprecationWarning, stacklevel=2)
        setattr(self.io_thread.socket, attr, value)

def send(self, msg, *args, **kwargs):
    """Send a single frame by wrapping it in a list for ``send_multipart``."""
    parts = [msg]
    return self.send_multipart(parts, *args, **kwargs)

Expand All @@ -264,7 +266,7 @@ def send_multipart(self, *args, **kwargs):

class OutStream(TextIOBase):
"""A file like object that publishes the stream to a 0MQ PUB socket.
Output is handed off to an IO Thread
"""

Expand Down Expand Up @@ -419,7 +421,7 @@ def writable(self):

def _flush_buffer(self):
"""clear the current buffer and return the current buffer data.
This should only be called in the IO thread.
"""
data = u''
Expand Down
62 changes: 56 additions & 6 deletions ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import atexit
import os
import sys
import errno
import signal
import traceback
import logging
Expand Down Expand Up @@ -112,6 +113,14 @@ class IPKernelApp(BaseIPythonApplication, InteractiveShellApp,
kernel = Any()
poller = Any() # don't restrict this even though current pollers are all Threads
heartbeat = Instance(Heartbeat, allow_none=True)

context = Any()
shell_socket = Any()
control_socket = Any()
stdin_socket = Any()
iopub_socket = Any()
iopub_thread = Any()

ports = Dict()

subcommands = {
Expand Down Expand Up @@ -171,7 +180,7 @@ def init_poller(self):
# Parent polling doesn't work if ppid == 1 to start with.
self.poller = ParentPollerUnix()

def _bind_socket(self, s, port):
def _try_bind_socket(self, s, port):
iface = '%s://%s' % (self.transport, self.ip)
if self.transport == 'tcp':
if port <= 0:
Expand All @@ -190,6 +199,25 @@ def _bind_socket(self, s, port):
s.bind("ipc://%s" % path)
return port

def _bind_socket(self, s, port):
    """Bind socket ``s`` and return the port reported by ``_try_bind_socket``.

    An explicitly requested (truthy) ``port`` gets one attempt. A falsy
    ``port`` gets up to 100 attempts, retrying only on "address in use"
    errors. Any other ``zmq.ZMQError``, or a failure on the last attempt,
    is re-raised.
    """
    # errno.WSAEADDRINUSE only exists on Windows; None stands in elsewhere.
    in_use_codes = (errno.EADDRINUSE, getattr(errno, 'WSAEADDRINUSE', None))

    # Cap the retries so a bad setup cannot loop forever.
    attempts = 100 if not port else 1
    final_attempt = attempts - 1
    for attempt in range(attempts):
        try:
            bound_port = self._try_bind_socket(s, port)
        except zmq.ZMQError as ze:
            # Give up on anything unrelated to a port conflict, or once
            # the retry budget is spent.
            if ze.errno not in in_use_codes or attempt == final_attempt:
                raise
        else:
            return bound_port

def write_connection_file(self):
"""write connection info to JSON file"""
cf = self.abs_connection_file
Expand Down Expand Up @@ -229,9 +257,9 @@ def init_connection_file(self):
def init_sockets(self):
# Create a context, a session, and the kernel sockets.
self.log.info("Starting the kernel at pid: %i", os.getpid())
context = zmq.Context.instance()
# Uncomment this to try closing the context.
# atexit.register(context.term)
assert self.context is None, "init_sockets cannot be called twice!"
self.context = context = zmq.Context()
atexit.register(self.close)

self.shell_socket = context.socket(zmq.ROUTER)
self.shell_socket.linger = 1000
Expand Down Expand Up @@ -279,6 +307,26 @@ def init_heartbeat(self):
self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
self.heartbeat.start()

def close(self):
    """Shut down the kernel's zmq resources in a fixed order.

    Order: heartbeat socket and its context, then the iopub thread, then
    the shell/control/stdin sockets, and finally the application context.
    """
    log = self.log
    log.info("Cleaning up sockets")

    heartbeat = self.heartbeat
    if heartbeat:
        log.debug("Closing heartbeat channel")
        heartbeat.socket.close()
        heartbeat.context.term()

    iopub = self.iopub_thread
    if iopub:
        log.debug("Closing iopub channel")
        iopub.stop()
        iopub.close()

    for name in ('shell', 'control', 'stdin'):
        log.debug("Closing %s channel", name)
        sock = getattr(self, name + "_socket", None)
        # Skip channels that were never created or are already closed.
        if not sock or sock.closed:
            continue
        sock.close()

    log.debug("Terminating zmq context")
    self.context.term()
    log.debug("Terminated zmq context")

def log_connection_info(self):
"""display connection info, and store ports"""
basename = os.path.basename(self.connection_file)
Expand Down Expand Up @@ -477,8 +525,8 @@ def initialize(self, argv=None):
try:
self.init_signal()
except:
# Catch exception when initializing signal fails, eg when running the
# kernel on a separate thread
# Catch exception when initializing signal fails, eg when running the
# kernel on a separate thread
if self.log_level < logging.CRITICAL:
self.log.error("Unable to initialize signal:", exc_info=True)
self.init_kernel()
Expand Down Expand Up @@ -506,8 +554,10 @@ def start(self):
except KeyboardInterrupt:
pass


launch_new_instance = IPKernelApp.launch_instance


def main():
"""Run an IPKernel as an application"""
app = IPKernelApp.instance()
Expand Down
4 changes: 2 additions & 2 deletions ipykernel/tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
KC = KM = None


def setup():
def setup_function():
"""start the global kernel (if it isn't running) and return its client"""
global KM, KC
KM, KC = start_new_kernel()
flush_channels(KC)


def teardown():
def teardown_function():
KC.stop_channels()
KM.shutdown_kernel(now=True)

Expand Down
Loading

0 comments on commit 4e3769d

Please sign in to comment.