Skip to content

Commit

Permalink
Centralise REP and PUB into server
Browse files Browse the repository at this point in the history
  • Loading branch information
dwsutherland committed Feb 11, 2022
1 parent 9a10c85 commit 0ad693c
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 196 deletions.
8 changes: 4 additions & 4 deletions cylc/flow/data_store_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ def generate_definition_elements(self):
workflow.name = self.schd.workflow
workflow.owner = self.schd.owner
workflow.host = self.schd.host
workflow.port = self.schd.port or -1
workflow.pub_port = self.schd.pub_port or -1
workflow.port = self.schd.server.port or -1
workflow.pub_port = self.schd.server.pub_port or -1
user_defined_meta = {}
for key, val in config.cfg['meta'].items():
if key in ['title', 'description', 'URL']:
Expand Down Expand Up @@ -1606,8 +1606,8 @@ def delta_workflow_ports(self):
w_delta.last_updated = time()
w_delta.stamp = f'{w_delta.id}@{w_delta.last_updated}'

w_delta.port = self.schd.port
w_delta.pub_port = self.schd.pub_port
w_delta.port = self.schd.server.port
w_delta.pub_port = self.schd.server.pub_port
self.updates_pending = True

def delta_broadcast(self):
Expand Down
132 changes: 132 additions & 0 deletions cylc/flow/network/replier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Server for workflow runtime API."""

import getpass # noqa: F401
from queue import Queue
from time import sleep

import zmq

from cylc.flow import LOG
from cylc.flow.network import encode_, decode_, ZMQSocketBase


class WorkflowReplier(ZMQSocketBase):
"""Initiate the REP part of a ZMQ REQ-REP pattern.
This class contains the logic for the ZMQ message replier.
Usage:
* Define ...
"""

RECV_TIMEOUT = 1
"""Max time the Workflow Replier will wait for an incoming
message in seconds.
We use a timeout here so as to give the _listener a chance to respond to
requests (i.e. stop) from its spawner (the scheduler).
The alternative would be to spin up a client and send a message to the
server, this way seems safer.
"""

def __init__(self, server, context=None, barrier=None,
threaded=True, daemon=False):
super().__init__(zmq.REP, bind=True, context=context,
barrier=barrier, threaded=threaded, daemon=daemon)
self.server = server
self.workflow = server.schd.workflow
self.queue = None

def _socket_options(self):
"""Set socket options.
Overwrites Base method.
"""
# create socket
self.socket.RCVTIMEO = int(self.RECV_TIMEOUT) * 1000

def _bespoke_start(self):
"""Setup start items, and run listener.
Overwrites Base method.
"""
# start accepting requests
self.queue = Queue()
self._listener()

def _bespoke_stop(self):
"""Stop the listener and Authenticator.
Overwrites Base method.
"""
LOG.debug('stopping zmq server...')
self.stopping = True
if self.queue is not None:
self.queue.put('STOP')

def _listener(self):
"""The server main loop, listen for and serve requests."""
while True:
# process any commands passed to the listener by its parent process
if self.queue.qsize():
command = self.queue.get()
if command == 'STOP':
break
raise ValueError('Unknown command "%s"' % command)

try:
# wait RECV_TIMEOUT for a message
msg = self.socket.recv_string()
except zmq.error.Again:
# timeout, continue with the loop, this allows the listener
# thread to stop
continue
except zmq.error.ZMQError as exc:
LOG.exception('unexpected error: %s', exc)
continue

# attempt to decode the message, authenticating the user in the
# process
try:
message = decode_(msg)
except Exception as exc: # purposefully catch generic exception
# failed to decode message, possibly resulting from failed
# authentication
LOG.exception('failed to decode message: "%s"', exc)
else:
# success case - serve the request
res = self.server.responder(message)
# send back the string to bytes response
if isinstance(res.get('data'), bytes):
response = res['data']
else:
response = encode_(res).encode()
self.socket.send(response)

# Note: we are using CurveZMQ to secure the messages (see
# self.curve_auth, self.socket.curve_...key etc.). We have set up
# public-key cryptography on the ZMQ messaging and sockets, so
# there is no need to encrypt messages ourselves before sending.

sleep(0) # yield control to other threads
172 changes: 77 additions & 95 deletions cylc/flow/network/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@
"""Server for workflow runtime API."""

import getpass # noqa: F401
from queue import Queue
from textwrap import dedent
from time import sleep
from threading import Barrier

from graphql.execution.executors.asyncio import AsyncioExecutor
import zmq
from zmq.auth.thread import ThreadAuthenticator

from cylc.flow import LOG
from cylc.flow.network import encode_, decode_, ZMQSocketBase
from cylc.flow import LOG, workflow_files
from cylc.flow.cfgspec.glbl_cfg import glbl_cfg
from cylc.flow.network.authorisation import authorise
from cylc.flow.network.graphql import (
CylcGraphQLBackend, IgnoreFieldMiddleware, instantiate_middleware
)
from cylc.flow.network.publisher import WorkflowPublisher
from cylc.flow.network.replier import WorkflowReplier
from cylc.flow.network.resolvers import Resolvers
from cylc.flow.network.schema import schema
from cylc.flow.data_store_mgr import DELTAS_MAP
Expand Down Expand Up @@ -66,7 +68,7 @@ def filter_none(dictionary):
}


class WorkflowRuntimeServer(ZMQSocketBase):
class WorkflowRuntimeServer:
"""Workflow runtime service API facade exposed via zmq.
This class contains the Cylc endpoints.
Expand Down Expand Up @@ -116,27 +118,20 @@ class WorkflowRuntimeServer(ZMQSocketBase):
"""

RECV_TIMEOUT = 1
"""Max time the WorkflowRuntimeServer will wait for an incoming
message in seconds.
def __init__(self, schd):

We use a timeout here so as to give the _listener a chance to respond to
requests (i.e. stop) from its spawner (the scheduler).
self.zmq_context = None
self.port = None
self.pub_port = None
self.replier = None
self.publisher = None
self.barrier = None
self.curve_auth = None
self.client_pub_key_dir = None

The alternative would be to spin up a client and send a message to the
server, this way seems safer.
"""

def __init__(self, schd, context=None, barrier=None,
threaded=True, daemon=False):
super().__init__(zmq.REP, bind=True, context=context,
barrier=barrier, threaded=threaded, daemon=daemon)
self.schd = schd
self.workflow = schd.workflow
self.public_priv = None # update in get_public_priv()
self.endpoints = None
self.queue = None
self.resolvers = Resolvers(
self.schd.data_store_mgr,
schd=self.schd
Expand All @@ -145,82 +140,69 @@ def __init__(self, schd, context=None, barrier=None,
IgnoreFieldMiddleware,
]

def _socket_options(self):
"""Set socket options.
Overwrites Base method.
"""
# create socket
self.socket.RCVTIMEO = int(self.RECV_TIMEOUT) * 1000

def _bespoke_start(self):
"""Setup start items, and run listener.
Overwrites Base method.
"""
# start accepting requests
self.queue = Queue()
def configure(self):
self.register_endpoints()
self._listener()

def _bespoke_stop(self):
"""Stop the listener and Authenticator.
Overwrites Base method.
# create thread sync barrier for setup
self.barrier = Barrier(3, timeout=10)

# TODO: this in zmq asyncio context?
# Requires the scheduler main loop in asyncio first
# And use of concurrent.futures.ThreadPoolExecutor?
self.zmq_context = zmq.Context()
# create an authenticator for the ZMQ context
self.curve_auth = ThreadAuthenticator(self.zmq_context, log=LOG)
self.curve_auth.start() # start the authentication thread

# Setting the location means that the CurveZMQ auth will only
# accept public client certificates from the given directory, as
# generated by a user when they initiate a ZMQ socket ready to
# connect to a server.
workflow_srv_dir = workflow_files.get_workflow_srv_dir(
self.schd.workflow)
client_pub_keyinfo = workflow_files.KeyInfo(
workflow_files.KeyType.PUBLIC,
workflow_files.KeyOwner.CLIENT,
workflow_srv_dir=workflow_srv_dir)
self.client_pub_key_dir = client_pub_keyinfo.key_path

# Initial load for the localhost key.
self.curve_auth.configure_curve(
domain='*',
location=(self.client_pub_key_dir)
)

"""
LOG.debug('stopping zmq server...')
self.stopping = True
if self.queue is not None:
self.queue.put('STOP')

def _listener(self):
"""The server main loop, listen for and serve requests."""
while True:
# process any commands passed to the listener by its parent process
if self.queue.qsize():
command = self.queue.get()
if command == 'STOP':
break
raise ValueError('Unknown command "%s"' % command)

try:
# wait RECV_TIMEOUT for a message
msg = self.socket.recv_string()
except zmq.error.Again:
# timeout, continue with the loop, this allows the listener
# thread to stop
continue
except zmq.error.ZMQError as exc:
LOG.exception('unexpected error: %s', exc)
continue

# attempt to decode the message, authenticating the user in the
# process
try:
message = decode_(msg)
except Exception as exc: # purposefully catch generic exception
# failed to decode message, possibly resulting from failed
# authentication
LOG.exception('failed to decode message: "%s"', exc)
else:
# success case - serve the request
res = self._receiver(message)
if message['command'] in PB_METHOD_MAP:
response = res['data']
else:
response = encode_(res).encode()
# send back the string to bytes response
self.socket.send(response)

# Note: we are using CurveZMQ to secure the messages (see
# self.curve_auth, self.socket.curve_...key etc.). We have set up
# public-key cryptography on the ZMQ messaging and sockets, so
# there is no need to encrypt messages ourselves before sending.

sleep(0) # yield control to other threads
self.replier = WorkflowReplier(
self, context=self.zmq_context, barrier=self.barrier)
self.publisher = WorkflowPublisher(
self.schd.workflow, context=self.zmq_context, barrier=self.barrier)

async def start(self):
"""Start the TCP servers."""
min_, max_ = glbl_cfg().get(['scheduler', 'run hosts', 'ports'])
self.replier.start(min_, max_)
self.publisher.start(min_, max_)
# wait for threads to setup socket ports before continuing
self.barrier.wait()
self.port = self.replier.port
self.pub_port = self.publisher.port
self.schd.data_store_mgr.delta_workflow_ports()

async def stop(self, reason):
"""Stop the TCP servers, and clean up authentication."""
if self.replier:
self.replier.stop()
if self.publisher:
await self.publisher.publish(
[(b'shutdown', str(reason).encode('utf-8'))]
)
self.publisher.stop()
if self.curve_auth:
self.curve_auth.stop() # stop the authentication thread

def responder(self, message):
"""Process message, coordinate publishing, return response."""
# TODO: coordinate publishing.
return self._receiver(message)

def _receiver(self, message):
"""Wrap incoming messages and dispatch them to exposed methods.
Expand Down
Loading

0 comments on commit 0ad693c

Please sign in to comment.