Skip to content

Commit

Permalink
Single separate thread for server, replier, publisher
Browse files Browse the repository at this point in the history
  • Loading branch information
dwsutherland committed Mar 17, 2022
1 parent f02f217 commit 349bd08
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 124 deletions.
43 changes: 10 additions & 33 deletions cylc/flow/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import asyncio
import getpass
import json
from threading import Thread
from time import sleep

import zmq
import zmq.asyncio
Expand Down Expand Up @@ -92,7 +90,7 @@ def get_location(workflow: str):


class ZMQSocketBase:
"""Initiate the ZMQ socket bind for specified pattern on new thread.
"""Initiate the ZMQ socket bind for specified pattern.
NOTE: Security to be provided via zmq.auth (see PR #3359).
Expand All @@ -102,13 +100,6 @@ class ZMQSocketBase:
context (object, optional): instantiated ZeroMQ context, defaults
to zmq.asyncio.Context().
barrier (object, optional): threading.Barrier object for syncing with
other threads.
threaded (bool, optional): Start socket on separate thread.
daemon (bool, optional): daemonise socket thread.
This class is designed to be inherited by REP Server (REQ/REP)
and by PUB Publisher (PUB/SUB), as the start-up logic is similar.
Expand All @@ -117,22 +108,23 @@ class ZMQSocketBase:
"""

def __init__(self, pattern, workflow=None, bind=False, context=None,
barrier=None, threaded=False, daemon=False):
def __init__(
self,
pattern,
workflow=None,
bind=False,
context=None,
):
self.bind = bind
if context is None:
self.context = zmq.asyncio.Context()
else:
self.context = context
self.barrier = barrier
self.pattern = pattern
self.daemon = daemon
self.workflow = workflow
self.host = None
self.port = None
self.socket = None
self.threaded = threaded
self.thread = None
self.loop = None
self.stopping = False

Expand All @@ -141,16 +133,7 @@ def start(self, *args, **kwargs):
Pass arguments to _start_
"""
if self.threaded:
self.thread = Thread(
target=self._start_sequence,
args=args,
kwargs=kwargs,
daemon=self.daemon
)
self.thread.start()
else:
self._start_sequence(*args, **kwargs)
self._start_sequence(*args, **kwargs)

def _start_sequence(self, *args, **kwargs):
"""Create the thread async loop, and bind socket."""
Expand Down Expand Up @@ -228,9 +211,6 @@ def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None):
except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc:
raise CylcError(f'could not start Cylc ZMQ server: {exc}')

if self.barrier is not None:
self.barrier.wait()

# Keeping srv_public_key_loc as optional arg so as to not break interface
def _socket_connect(self, host, port, srv_public_key_loc=None):
"""Connect socket to stub."""
Expand Down Expand Up @@ -295,9 +275,8 @@ def _socket_options(self):
def _bespoke_start(self):
"""Initiate bespoke items on thread at start."""
self.stopping = False
sleep(0) # yield control to other threads

def stop(self, stop_loop=True):
def stop(self, stop_loop=False):
"""Stop the server.
Args:
Expand All @@ -307,8 +286,6 @@ def stop(self, stop_loop=True):
self._bespoke_stop()
if stop_loop and self.loop and self.loop.is_running():
self.loop.stop()
if self.thread and self.thread.is_alive():
self.thread.join() # Wait for processes to return
if self.socket and not self.socket.closed:
self.socket.close()
LOG.debug('...stopped')
Expand Down
12 changes: 4 additions & 8 deletions cylc/flow/network/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ class WorkflowPublisher(ZMQSocketBase):
"""

def __init__(self, workflow, context=None, barrier=None,
threaded=False, daemon=False):
super().__init__(zmq.PUB, bind=True, context=context,
barrier=barrier, threaded=threaded, daemon=daemon)
self.workflow = workflow
def __init__(self, server, context=None):
super().__init__(zmq.PUB, bind=True, context=context)
self.server = server
self.workflow = server.schd.workflow
self.topics = set()

def _socket_options(self):
Expand All @@ -70,9 +69,6 @@ def _socket_options(self):
def _bespoke_stop(self):
"""Bespoke stop items."""
LOG.debug('stopping zmq publisher...')
# Child of server object, parent to stop loop.
self.loop = None
self.stopping = True

async def send_multi(self, topic, data, serializer=None):
"""Send multi part message.
Expand Down
55 changes: 8 additions & 47 deletions cylc/flow/network/replier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import getpass # noqa: F401
from queue import Queue
from time import sleep

import zmq

Expand All @@ -35,57 +34,23 @@ class WorkflowReplier(ZMQSocketBase):
"""

RECV_TIMEOUT = 1
"""Max time the Workflow Replier will wait for an incoming
message in seconds.
We use a timeout here so as to give the _listener a chance to respond to
requests (i.e. stop) from its spawner (the scheduler).
The alternative would be to spin up a client and send a message to the
server, this way seems safer.
"""

def __init__(self, server, context=None, barrier=None,
threaded=True, daemon=False):
super().__init__(zmq.REP, bind=True, context=context,
barrier=barrier, threaded=threaded, daemon=daemon)
def __init__(self, server, context=None):
super().__init__(zmq.REP, bind=True, context=context)
self.server = server
self.workflow = server.schd.workflow
self.queue = None

def _socket_options(self):
"""Set socket options.
Overwrites Base method.
"""
# create socket
self.socket.RCVTIMEO = int(self.RECV_TIMEOUT) * 1000

def _bespoke_start(self):
"""Setup start items, and run listener.
Overwrites Base method.
"""
# start accepting requests
self.queue = Queue()
self._listener()

def _bespoke_stop(self):
"""Stop the listener and Authenticator.
Overwrites Base method.
"""
LOG.debug('stopping zmq server...')
self.stopping = True
LOG.debug('stopping zmq replier...')
if self.queue is not None:
self.queue.put('STOP')

def _listener(self):
def listener(self):
"""The server main loop, listen for and serve requests."""
while True:
# process any commands passed to the listener by its parent process
Expand All @@ -96,16 +61,14 @@ def _listener(self):
raise ValueError('Unknown command "%s"' % command)

try:
# wait RECV_TIMEOUT for a message
msg = self.socket.recv_string()
# Check for messages
msg = self.socket.recv_string(zmq.NOBLOCK)
except zmq.error.Again:
# timeout, continue with the loop, this allows the listener
# thread to stop
continue
# No messages, break to parent loop.
break
except zmq.error.ZMQError as exc:
LOG.exception('unexpected error: %s', exc)
continue

# attempt to decode the message, authenticating the user in the
# process
try:
Expand All @@ -128,5 +91,3 @@ def _listener(self):
# self.curve_auth, self.socket.curve_...key etc.). We have set up
# public-key cryptography on the ZMQ messaging and sockets, so
# there is no need to encrypt messages ourselves before sending.

sleep(0) # yield control to other threads
68 changes: 57 additions & 11 deletions cylc/flow/network/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Server for workflow runtime API."""

import asyncio
import getpass # noqa: F401
from queue import Queue
from textwrap import dedent
from threading import Barrier
from time import sleep
from typing import Any, Dict, List, Optional, Union

from graphql.execution import ExecutionResult
Expand Down Expand Up @@ -128,7 +130,8 @@ def __init__(self, schd):
self.pub_port = None
self.replier = None
self.publisher = None
self.barrier = None
self.loop = None
self.thread = None
self.curve_auth = None
self.client_pub_key_dir = None

Expand All @@ -143,10 +146,13 @@ def __init__(self, schd):
IgnoreFieldMiddleware,
]

self.queue = Queue()
self.publish_queue = Queue()
self.stopping = False
self.stopped = True

def configure(self):
self.register_endpoints()
# create thread sync barrier for setup
self.barrier = Barrier(2, timeout=10)

# TODO: this in zmq asyncio context?
# Requires the scheduler main loop in asyncio first
Expand Down Expand Up @@ -174,24 +180,39 @@ def configure(self):
location=(self.client_pub_key_dir)
)

self.replier = WorkflowReplier(
self, context=self.zmq_context, barrier=self.barrier)
self.publisher = WorkflowPublisher(
self.schd.workflow, context=self.zmq_context)
self.replier = WorkflowReplier(self, context=self.zmq_context)
self.publisher = WorkflowPublisher(self, context=self.zmq_context)

async def start(self):
def start(self, barrier):
"""Start the TCP servers."""
# set asyncio loop on thread
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

min_, max_ = glbl_cfg().get(['scheduler', 'run hosts', 'ports'])
self.replier.start(min_, max_)
self.publisher.start(min_, max_)
# wait for threads to setup socket ports before continuing
self.barrier.wait()
self.port = self.replier.port
self.pub_port = self.publisher.port
self.schd.data_store_mgr.delta_workflow_ports()

# wait for threads to setup socket ports before continuing
barrier.wait()

self.stopped = False

self.operate()

async def stop(self, reason):
"""Stop the TCP servers, and clean up authentication."""
self.queue.put('STOP')
if self.thread and self.thread.is_alive():
while not self.stopping:
sleep(0.2)

if self.replier:
self.replier.stop()
if self.publisher:
Expand All @@ -201,6 +222,31 @@ async def stop(self, reason):
self.publisher.stop()
if self.curve_auth:
self.curve_auth.stop() # stop the authentication thread
if self.loop and self.loop.is_running():
self.loop.stop()
if self.thread and self.thread.is_alive():
self.thread.join() # Wait for processes to return

self.stopped = True

# Main loop of the server: repeats until a 'STOP' message arrives on
# self.queue. Each pass handles one control message (if any), runs the
# replier's listener, publishes everything waiting on self.publish_queue,
# and then yields to other threads.
# NOTE(review): this capture has lost the original indentation, so the
# nesting described in the comments below is inferred from the statements.
def operate(self):
while True:
# Process messages from the scheduler. qsize() is checked first, so
# get() is only reached when the queue reports at least one item.
if self.queue.qsize():
message = self.queue.get()
# 'STOP' (queued by stop()) marks the server as stopping and ends the loop.
if message == 'STOP':
self.stopping = True
break
# Any other message is rejected with an error.
raise ValueError('Unknown message "%s"' % message)

# Hand over to the replier to check for and serve incoming requests.
# NOTE(review): presumably returns once no request is pending - confirm
# against the replier's listener implementation.
self.replier.listener()

# Drain everything queued for publication before the next pass.
while self.publish_queue.qsize():
articles = self.publish_queue.get()
# publish() is driven to completion on self.loop (assigned in start());
# presumably it is a coroutine - confirm.
self.loop.run_until_complete(self.publisher.publish(articles))

# Yield control to other threads
sleep(0)

def responder(self, message):
"""Process message, coordinate publishing, return response."""
Expand Down
Loading

0 comments on commit 349bd08

Please sign in to comment.