Skip to content

Commit

Permalink
Network refactor, unit tests, docstrings & comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dwsutherland committed Oct 25, 2019
1 parent 1dd65bf commit a13d538
Show file tree
Hide file tree
Showing 14 changed files with 847 additions and 492 deletions.
60 changes: 37 additions & 23 deletions bin/cylc-subscribe
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""cylc subscribe [OPTIONS] ARGS
"""cylc subscribe [OPTIONS] REG
(This command is for internal use.)
Invoke suite subscriber to receive published workflow output.
Expand All @@ -33,6 +33,7 @@ from cylc.flow.option_parsers import CylcOptionParser as COP
from cylc.flow.network.scan import get_scan_items_from_fs, re_compile_filters
from cylc.flow.network.subscriber import WorkflowSubscriber, process_delta_msg
from cylc.flow.terminal import cli_function
from cylc.flow.ws_data_mgr import DELTAS_MAP

if '--use-ssh' in sys.argv[1:]:
sys.argv.remove('--use-ssh')
Expand All @@ -41,47 +42,60 @@ if '--use-ssh' in sys.argv[1:]:
sys.exit(0)


def print_message(_, data):
def print_message(topic, data):
"""Print protobuf message."""
print(f'Received: {topic}')
sys.stdout.write(
json.dumps(MessageToDict(data), indent=4) + '\n')


def get_option_parser():
"""Augment options parser to current context."""
parser = COP(__doc__, comms=True, argdoc=[
('REG', 'Suite name'),
('[TOPICS]', 'Subscription topics to receive')])
parser = COP(__doc__, comms=True)

delta_keys = list(DELTAS_MAP)
pb_topics = ("Directly published data-store topics include: '" +
("', '").join(delta_keys[:-1]) +
"' and '" + delta_keys[-1] + "'.")

parser.add_option(
"-T", "--topics",
help="Specify a comma delimited list of subscription topics. "
+ pb_topics,
action="store", dest="topics", default='workflow')


return parser


@cli_function(get_option_parser)
def main(_, options, suite, topics=None):
def main(_, options, suite):
cre_owner, cre_name = re_compile_filters(None, ['.*'])
host = None
port = None
cre_owner, cre_name = re_compile_filters(None, ['.*'])
while True:
for s_reg, s_host, _, s_pub_port in get_scan_items_from_fs(
cre_owner, cre_name):
if s_reg == suite:
host = s_host
port = int(s_pub_port)
if options.host is None or options.port is None:
while True:
for s_reg, s_host, _, s_pub_port in get_scan_items_from_fs(
cre_owner, cre_name):
if s_reg == suite:
host = s_host
port = int(s_pub_port)
break
if host and port:
break
if host and port:
break
time.sleep(5)
time.sleep(5)
else:
host = options.host
port = options.port

print(f'Connecting to tcp://{host}:{port}')
topic_set = set()
if topics is None:
topic_set.add(b'workflow')
else:
for topic in topics.split(','):
topic_set.add(topic.encode('utf-8'))
for topic in options.topics.split(','):
topic_set.add(topic.encode('utf-8'))

subscriber = WorkflowSubscriber(host, port, topics=topic_set)

asyncio.ensure_future(
subscriber.loop.create_task(
subscriber.subscribe(
process_delta_msg,
func=print_message
Expand All @@ -90,7 +104,7 @@ def main(_, options, suite, topics=None):

# run Python run
try:
asyncio.get_event_loop().run_forever()
subscriber.loop.run_forever()
except KeyboardInterrupt:
print('\nDisconnecting')
subscriber.stop()
Expand Down
145 changes: 145 additions & 0 deletions cylc/flow/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,148 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Package for network interfaces to cylc suite server objects."""

import asyncio
from threading import Thread
from time import sleep

import zmq
import zmq.asyncio

from cylc.flow import LOG
from cylc.flow.exceptions import CylcError


class ZMQSocketBase:
"""Initiate the ZMQ socket bind for specified pattern on new thread.
NOTE: Security to be provided via zmq.auth (see PR #3359).
Args:
pattern (enum): ZeroMQ message pattern (zmq.PATTERN)
context (object): instantiated ZeroMQ context (i.e. zmq.Context())
barrier (object): threading.Barrier object for syncing with
other threads.
This class is designed to be inherited by REP Server (REQ/REP)
and by PUB Publisher (PUB/SUB), as the start-up logic is the same.
"""

def __init__(self, pattern, bind=False, context=None,
barrier=None, threaded=False):
self.bind = bind
if context is None:
self.context = zmq.asyncio.Context()
else:
self.context = context
self.barrier = barrier
self.pattern = pattern
self.port = None
self.socket = None
self.threaded = threaded
self.thread = None
self.loop = None
self.stopping = False

def start(self, *args, **kwargs):
"""Start the server.
Port range passed to socket creation in server thread.
Args:
min_port (int): minimum socket port number
max_port (int): maximum socket port number
"""
if self.threaded:
self.thread = Thread(
target=self._start_sequence,
args=args,
kwargs=kwargs
)
self.thread.start()
else:
self._start_sequence(*args, **kwargs)

def _start_sequence(self, *args, **kwargs):
"""Create the thread async loop, and bind socket."""
# set asyncio loop on thread
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

if self.bind:
self._socket_bind(*args, **kwargs)
else:
self._socket_connect(*args, **kwargs)

# initiate bespoke items
self._bespoke_start()

def _socket_bind(self, min_port, max_port):
"""Bind socket.
Will use a port range provided to select random ports.
"""
# create socket
self.socket = self.context.socket(self.pattern)
self._socket_options()

try:
if min_port == max_port:
self.port = min_port
self.socket.bind('tcp://*:%d' % min_port)
else:
self.port = self.socket.bind_to_random_port(
'tcp://*', min_port, max_port)
except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc:
self.socket.close()
raise CylcError('could not start Cylc ZMQ server: %s' % str(exc))

if self.barrier is not None:
self.barrier.wait()

def _socket_connect(self, host, port):
"""Connect socket to stub."""
self.socket = self.context.socket(self.pattern)
self._socket_options()
self.socket.connect(f'tcp://{host}:{port}')

def _socket_options(self):
"""Set socket options.
i.e. self.socket.sndhwm = 1000
Overwrite this method on inheritance.
"""
self.socket.sndhwm = 10000

def _bespoke_start(self):
"""Initiate bespoke items on thread at start.
Overwrite this method on inheritance.
"""
sleep(0) # yield control to other threads

def stop(self):
"""Stop the server."""
LOG.debug('stopping zmq server...')
self._bespoke_stop()
self.socket.close()
if self.threaded:
self.thread.join() # Wait for processes to return
LOG.debug('...stopped')

def _bespoke_stop(self):
"""Bespoke stop items.
Overwrite this method on inheritance.
"""
self.stopping = True
Loading

0 comments on commit a13d538

Please sign in to comment.