From 0ad693cc4bc585d2d6a453c5550a126cb800da1d Mon Sep 17 00:00:00 2001 From: David Sutherland Date: Fri, 18 Jun 2021 23:56:44 +1200 Subject: [PATCH] Centralise REP and PUB into server --- cylc/flow/data_store_mgr.py | 8 +- cylc/flow/network/replier.py | 132 +++++++++++++++++ cylc/flow/network/server.py | 172 ++++++++++------------- cylc/flow/scheduler.py | 88 ++---------- cylc/flow/task_events_mgr.py | 4 +- tests/integration/test_examples.py | 4 +- tests/integration/test_publisher.py | 2 +- tests/integration/test_replier.py | 55 ++++++++ tests/integration/test_scan.py | 2 +- tests/integration/test_server.py | 15 -- tests/integration/test_workflow_files.py | 5 +- 11 files changed, 291 insertions(+), 196 deletions(-) create mode 100644 cylc/flow/network/replier.py create mode 100644 tests/integration/test_replier.py diff --git a/cylc/flow/data_store_mgr.py b/cylc/flow/data_store_mgr.py index 241ac7257b5..1b1726c1872 100644 --- a/cylc/flow/data_store_mgr.py +++ b/cylc/flow/data_store_mgr.py @@ -550,8 +550,8 @@ def generate_definition_elements(self): workflow.name = self.schd.workflow workflow.owner = self.schd.owner workflow.host = self.schd.host - workflow.port = self.schd.port or -1 - workflow.pub_port = self.schd.pub_port or -1 + workflow.port = self.schd.server.port or -1 + workflow.pub_port = self.schd.server.pub_port or -1 user_defined_meta = {} for key, val in config.cfg['meta'].items(): if key in ['title', 'description', 'URL']: @@ -1606,8 +1606,8 @@ def delta_workflow_ports(self): w_delta.last_updated = time() w_delta.stamp = f'{w_delta.id}@{w_delta.last_updated}' - w_delta.port = self.schd.port - w_delta.pub_port = self.schd.pub_port + w_delta.port = self.schd.server.port + w_delta.pub_port = self.schd.server.pub_port self.updates_pending = True def delta_broadcast(self): diff --git a/cylc/flow/network/replier.py b/cylc/flow/network/replier.py new file mode 100644 index 00000000000..bd0dc6fac02 --- /dev/null +++ b/cylc/flow/network/replier.py @@ -0,0 +1,132 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +"""Server for workflow runtime API.""" + +import getpass # noqa: F401 +from queue import Queue +from time import sleep + +import zmq + +from cylc.flow import LOG +from cylc.flow.network import encode_, decode_, ZMQSocketBase + + +class WorkflowReplier(ZMQSocketBase): + """Initiate the REP part of a ZMQ REQ-REP pattern. + + This class contains the logic for the ZMQ message replier. + + Usage: + * Define ... + + """ + + RECV_TIMEOUT = 1 + """Max time the Workflow Replier will wait for an incoming + message in seconds. + + We use a timeout here so as to give the _listener a chance to respond to + requests (i.e. stop) from its spawner (the scheduler). + + The alternative would be to spin up a client and send a message to the + server, this way seems safer. + + """ + + def __init__(self, server, context=None, barrier=None, + threaded=True, daemon=False): + super().__init__(zmq.REP, bind=True, context=context, + barrier=barrier, threaded=threaded, daemon=daemon) + self.server = server + self.workflow = server.schd.workflow + self.queue = None + + def _socket_options(self): + """Set socket options. + + Overwrites Base method. + + """ + # create socket + self.socket.RCVTIMEO = int(self.RECV_TIMEOUT) * 1000 + + def _bespoke_start(self): + """Setup start items, and run listener. + + Overwrites Base method. + + """ + # start accepting requests + self.queue = Queue() + self._listener() + + def _bespoke_stop(self): + """Stop the listener and Authenticator. + + Overwrites Base method. + + """ + LOG.debug('stopping zmq server...') + self.stopping = True + if self.queue is not None: + self.queue.put('STOP') + + def _listener(self): + """The server main loop, listen for and serve requests.""" + while True: + # process any commands passed to the listener by its parent process + if self.queue.qsize(): + command = self.queue.get() + if command == 'STOP': + break + raise ValueError('Unknown command "%s"' % command) + + try: + # wait RECV_TIMEOUT for a message + msg = self.socket.recv_string() + except zmq.error.Again: + # timeout, continue with the loop, this allows the listener + # thread to stop + continue + except zmq.error.ZMQError as exc: + LOG.exception('unexpected error: %s', exc) + continue + + # attempt to decode the message, authenticating the user in the + # process + try: + message = decode_(msg) + except Exception as exc: # purposefully catch generic exception + # failed to decode message, possibly resulting from failed + # authentication + LOG.exception('failed to decode message: "%s"', exc) + else: + # success case - serve the request + res = self.server.responder(message) + # send back the string to bytes response + if isinstance(res.get('data'), bytes): + response = res['data'] + else: + response = encode_(res).encode() + self.socket.send(response) + + # Note: we are using CurveZMQ to secure the messages (see + # self.curve_auth, self.socket.curve_...key etc.). We have set up + # public-key cryptography on the ZMQ messaging and sockets, so + # there is no need to encrypt messages ourselves before sending. + + sleep(0) # yield control to other threads diff --git a/cylc/flow/network/server.py b/cylc/flow/network/server.py index b79c808b7b5..f71f90a8195 100644 --- a/cylc/flow/network/server.py +++ b/cylc/flow/network/server.py @@ -16,19 +16,21 @@ """Server for workflow runtime API.""" import getpass # noqa: F401 -from queue import Queue from textwrap import dedent -from time import sleep +from threading import Barrier from graphql.execution.executors.asyncio import AsyncioExecutor import zmq +from zmq.auth.thread import ThreadAuthenticator -from cylc.flow import LOG -from cylc.flow.network import encode_, decode_, ZMQSocketBase +from cylc.flow import LOG, workflow_files +from cylc.flow.cfgspec.glbl_cfg import glbl_cfg from cylc.flow.network.authorisation import authorise from cylc.flow.network.graphql import ( CylcGraphQLBackend, IgnoreFieldMiddleware, instantiate_middleware ) +from cylc.flow.network.publisher import WorkflowPublisher +from cylc.flow.network.replier import WorkflowReplier from cylc.flow.network.resolvers import Resolvers from cylc.flow.network.schema import schema from cylc.flow.data_store_mgr import DELTAS_MAP @@ -66,7 +68,7 @@ def filter_none(dictionary): } -class WorkflowRuntimeServer(ZMQSocketBase): +class WorkflowRuntimeServer: """Workflow runtime service API facade exposed via zmq. This class contains the Cylc endpoints. @@ -116,27 +118,20 @@ class WorkflowRuntimeServer(ZMQSocketBase): """ - RECV_TIMEOUT = 1 - """Max time the WorkflowRuntimeServer will wait for an incoming - message in seconds. + def __init__(self, schd): - We use a timeout here so as to give the _listener a chance to respond to - requests (i.e. stop) from its spawner (the scheduler). + self.zmq_context = None + self.port = None + self.pub_port = None + self.replier = None + self.publisher = None + self.barrier = None + self.curve_auth = None + self.client_pub_key_dir = None - The alternative would be to spin up a client and send a message to the - server, this way seems safer. - - """ - - def __init__(self, schd, context=None, barrier=None, - threaded=True, daemon=False): - super().__init__(zmq.REP, bind=True, context=context, - barrier=barrier, threaded=threaded, daemon=daemon) self.schd = schd - self.workflow = schd.workflow self.public_priv = None # update in get_public_priv() self.endpoints = None - self.queue = None self.resolvers = Resolvers( self.schd.data_store_mgr, schd=self.schd @@ -145,82 +140,69 @@ def __init__(self, schd, context=None, barrier=None, IgnoreFieldMiddleware, ] - def _socket_options(self): - """Set socket options. - - Overwrites Base method. - - """ - # create socket - self.socket.RCVTIMEO = int(self.RECV_TIMEOUT) * 1000 - - def _bespoke_start(self): - """Setup start items, and run listener. - - Overwrites Base method. - - """ - # start accepting requests - self.queue = Queue() + def configure(self): self.register_endpoints() - self._listener() - - def _bespoke_stop(self): - """Stop the listener and Authenticator. - - Overwrites Base method. + # create thread sync barrier for setup + self.barrier = Barrier(3, timeout=10) + + # TODO: this in zmq asyncio context? + # Requires the scheduler main loop in asyncio first + # And use of concurrent.futures.ThreadPoolExecutor? + self.zmq_context = zmq.Context() + # create an authenticator for the ZMQ context + self.curve_auth = ThreadAuthenticator(self.zmq_context, log=LOG) + self.curve_auth.start() # start the authentication thread + + # Setting the location means that the CurveZMQ auth will only + # accept public client certificates from the given directory, as + # generated by a user when they initiate a ZMQ socket ready to + # connect to a server. + workflow_srv_dir = workflow_files.get_workflow_srv_dir( + self.schd.workflow) + client_pub_keyinfo = workflow_files.KeyInfo( + workflow_files.KeyType.PUBLIC, + workflow_files.KeyOwner.CLIENT, + workflow_srv_dir=workflow_srv_dir) + self.client_pub_key_dir = client_pub_keyinfo.key_path + + # Initial load for the localhost key. + self.curve_auth.configure_curve( + domain='*', + location=(self.client_pub_key_dir) + ) - """ - LOG.debug('stopping zmq server...') - self.stopping = True - if self.queue is not None: - self.queue.put('STOP') - - def _listener(self): - """The server main loop, listen for and serve requests.""" - while True: - # process any commands passed to the listener by its parent process - if self.queue.qsize(): - command = self.queue.get() - if command == 'STOP': - break - raise ValueError('Unknown command "%s"' % command) - - try: - # wait RECV_TIMEOUT for a message - msg = self.socket.recv_string() - except zmq.error.Again: - # timeout, continue with the loop, this allows the listener - # thread to stop - continue - except zmq.error.ZMQError as exc: - LOG.exception('unexpected error: %s', exc) - continue - - # attempt to decode the message, authenticating the user in the - # process - try: - message = decode_(msg) - except Exception as exc: # purposefully catch generic exception - # failed to decode message, possibly resulting from failed - # authentication - LOG.exception('failed to decode message: "%s"', exc) - else: - # success case - serve the request - res = self._receiver(message) - if message['command'] in PB_METHOD_MAP: - response = res['data'] - else: - response = encode_(res).encode() - # send back the string to bytes response - self.socket.send(response) - - # Note: we are using CurveZMQ to secure the messages (see - # self.curve_auth, self.socket.curve_...key etc.). We have set up - # public-key cryptography on the ZMQ messaging and sockets, so - # there is no need to encrypt messages ourselves before sending. - - sleep(0) # yield control to other threads + self.replier = WorkflowReplier( + self, context=self.zmq_context, barrier=self.barrier) + self.publisher = WorkflowPublisher( + self.schd.workflow, context=self.zmq_context, barrier=self.barrier) + + async def start(self): + """Start the TCP servers.""" + min_, max_ = glbl_cfg().get(['scheduler', 'run hosts', 'ports']) + self.replier.start(min_, max_) + self.publisher.start(min_, max_) + # wait for threads to setup socket ports before continuing + self.barrier.wait() + self.port = self.replier.port + self.pub_port = self.publisher.port + self.schd.data_store_mgr.delta_workflow_ports() + + async def stop(self, reason): + """Stop the TCP servers, and clean up authentication.""" + if self.replier: + self.replier.stop() + if self.publisher: + await self.publisher.publish( + [(b'shutdown', str(reason).encode('utf-8'))] + ) + self.publisher.stop() + if self.curve_auth: + self.curve_auth.stop() # stop the authentication thread + + def responder(self, message): + """Process message, coordinate publishing, return response.""" + # TODO: coordinate publishing. + return self._receiver(message) def _receiver(self, message): """Wrap incoming messages and dispatch them to exposed methods. diff --git a/cylc/flow/scheduler.py b/cylc/flow/scheduler.py index 08433096ece..8e8fadcd42a 100644 --- a/cylc/flow/scheduler.py +++ b/cylc/flow/scheduler.py @@ -27,15 +27,12 @@ from shlex import quote from subprocess import Popen, PIPE, DEVNULL import sys -from threading import Barrier from time import sleep, time import traceback from typing import Dict, Iterable, List, NoReturn, Optional, Set, Union from uuid import uuid4 import psutil -import zmq -from zmq.auth.thread import ThreadAuthenticator from metomi.isodatetime.parsers import TimePointParser @@ -66,7 +63,6 @@ from cylc.flow.timer import Timer from cylc.flow.network import API from cylc.flow.network.authentication import key_housekeeping -from cylc.flow.network.publisher import WorkflowPublisher from cylc.flow.network.schema import WorkflowStopMode from cylc.flow.network.server import WorkflowRuntimeServer from cylc.flow.option_parsers import verbosity_to_env @@ -153,7 +149,7 @@ class Scheduler: START_MESSAGE_PREFIX = 'Scheduler: ' START_MESSAGE_TMPL = ( START_MESSAGE_PREFIX + - 'url=%(comms_method)s://%(host)s:%(port)s/ pid=%(pid)s') + 'url=%(comms_method)s://%(host)s:%(port)s pid=%(pid)s') START_PUB_MESSAGE_PREFIX = 'Workflow publisher: ' START_PUB_MESSAGE_TMPL = ( START_PUB_MESSAGE_PREFIX + @@ -228,14 +224,7 @@ class Scheduler: auto_restart_time: Optional[float] = None # tcp / zmq - zmq_context: Optional[zmq.Context] = None - port: Optional[int] = None - pub_port: Optional[int] = None server: Optional[WorkflowRuntimeServer] = None - publisher: Optional[WorkflowPublisher] = None - barrier: Optional[Barrier] = None - curve_auth: Optional[ThreadAuthenticator] = None - client_pub_key_dir: Optional[str] = None # queue-released tasks awaiting job preparation pre_prep_tasks: Optional[List[TaskProxy]] = None @@ -273,9 +262,6 @@ def __init__(self, reg: str, options: Values) -> None: self.restored_stop_task_id = None - # create thread sync barrier for setup - self.barrier = Barrier(3, timeout=10) - self.timers: Dict[str, Timer] = {} async def install(self): @@ -336,36 +322,9 @@ async def initialise(self): self.workflow_db_mgr, self.data_store_mgr) self.flow_mgr = FlowMgr(self.workflow_db_mgr) - # *** Network Related *** - # TODO: this in zmq asyncio context? - # Requires the Cylc main loop in asyncio first - # And use of concurrent.futures.ThreadPoolExecutor? - self.zmq_context = zmq.Context() - # create an authenticator for the ZMQ context - self.curve_auth = ThreadAuthenticator(self.zmq_context, log=LOG) - self.curve_auth.start() # start the authentication thread - - # Setting the location means that the CurveZMQ auth will only - # accept public client certificates from the given directory, as - # generated by a user when they initiate a ZMQ socket ready to - # connect to a server. - workflow_srv_dir = workflow_files.get_workflow_srv_dir(self.workflow) - client_pub_keyinfo = workflow_files.KeyInfo( - workflow_files.KeyType.PUBLIC, - workflow_files.KeyOwner.CLIENT, - workflow_srv_dir=workflow_srv_dir) - self.client_pub_key_dir = client_pub_keyinfo.key_path - - # Initial load for the localhost key. - self.curve_auth.configure_curve( - domain='*', - location=(self.client_pub_key_dir) - ) + self.server = WorkflowRuntimeServer(self) + self.server.configure() - self.server = WorkflowRuntimeServer( - self, context=self.zmq_context, barrier=self.barrier) - self.publisher = WorkflowPublisher( - self.workflow, context=self.zmq_context, barrier=self.barrier) self.proc_pool = SubProcPool() self.command_queue = Queue() self.message_queue = Queue() @@ -539,17 +498,6 @@ async def configure(self): self.profiler.log_memory("scheduler.py: end configure") - async def start_servers(self): - """Start the TCP servers.""" - min_, max_ = glbl_cfg().get(['scheduler', 'run hosts', 'ports']) - self.server.start(min_, max_) - self.publisher.start(min_, max_) - # wait for threads to setup socket ports before continuing - self.barrier.wait() - self.port = self.server.port - self.pub_port = self.publisher.port - self.data_store_mgr.delta_workflow_ports() - async def log_start(self): if self.is_restart: n_restart = self.workflow_db_mgr.n_restart @@ -570,7 +518,7 @@ async def log_start(self): self.START_MESSAGE_TMPL % { 'comms_method': 'tcp', 'host': self.host, - 'port': self.port, + 'port': self.server.port, 'pid': os.getpid()}, extra=log_extra, ) @@ -578,7 +526,7 @@ async def log_start(self): self.START_PUB_MESSAGE_TMPL % { 'comms_method': 'tcp', 'host': self.host, - 'port': self.pub_port}, + 'port': self.server.pub_port}, extra=log_extra, ) LOG.info( @@ -611,7 +559,8 @@ async def run_scheduler(self): self ) ) - await self.publisher.publish(self.data_store_mgr.publish_deltas) + await self.server.publisher.publish( + self.data_store_mgr.publish_deltas) self.profiler.start() await self.main_loop() @@ -649,7 +598,7 @@ async def start(self): try: await self.initialise() await self.configure() - await self.start_servers() + await self.server.start() await self.log_start() self._configure_contact() except (KeyboardInterrupt, asyncio.CancelledError, Exception) as exc: @@ -754,8 +703,8 @@ def restart_remote_init(self): incomplete_init = False for platform in distinct_install_target_platforms: self.task_job_mgr.task_remote_mgr.remote_init( - platform, self.curve_auth, - self.client_pub_key_dir) + platform, self.server.curve_auth, + self.server.client_pub_key_dir) status = self.task_job_mgr.task_remote_mgr.remote_init_map[ platform['install target']] if status in (REMOTE_INIT_IN_PROGRESS, @@ -1065,7 +1014,7 @@ def get_contact_data(self) -> Dict[str, str]: fields.COMMAND: cli_format(proc.cmdline()), fields.PUBLISH_PORT: - str(self.publisher.port), # type: ignore + str(self.server.pub_port), # type: ignore fields.WORKFLOW_RUN_DIR_ON_WORKFLOW_HOST: # type: ignore self.workflow_run_dir, fields.UUID: @@ -1287,8 +1236,8 @@ def release_queued_tasks(self): for itask in self.task_job_mgr.submit_task_jobs( self.workflow, self.pre_prep_tasks, - self.curve_auth, - self.client_pub_key_dir, + self.server.curve_auth, + self.server.client_pub_key_dir, self.config.run_mode('simulation') ): # (Not using f"{itask}"_here to avoid breaking func tests) @@ -1625,7 +1574,7 @@ async def update_data_structure(self): # Publish updates: if self.data_store_mgr.publish_pending: self.data_store_mgr.publish_pending = False - await self.publisher.publish( + await self.server.publisher.publish( self.data_store_mgr.publish_deltas) if has_updated: # Database update @@ -1732,14 +1681,7 @@ async def _shutdown(self, reason: Exception) -> None: LOG.exception(exc) if self.server: - self.server.stop() - if self.publisher: - await self.publisher.publish( - [(b'shutdown', str(reason).encode('utf-8'))] - ) - self.publisher.stop() - if self.curve_auth: - self.curve_auth.stop() # stop the authentication thread + await self.server.stop(reason) # Flush errors and info before removing workflow contact file sys.stdout.flush() diff --git a/cylc/flow/task_events_mgr.py b/cylc/flow/task_events_mgr.py index 6ba65736abb..a7264a8ee7d 100644 --- a/cylc/flow/task_events_mgr.py +++ b/cylc/flow/task_events_mgr.py @@ -714,14 +714,14 @@ def _process_event_email(self, schd_ctx, ctx, id_keys): for label, value in [ ('workflow', schd_ctx.workflow), ("host", schd_ctx.host), - ("port", schd_ctx.port), + ("port", schd_ctx.server.port), ("owner", schd_ctx.owner)]: if value: stdin_str += "%s: %s\n" % (label, value) if self.mail_footer: stdin_str += (self.mail_footer + "\n") % { "host": schd_ctx.host, - "port": schd_ctx.port, + "port": schd_ctx.server.port, "owner": schd_ctx.owner, "workflow": schd_ctx.workflow} # SMTP server diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples.py index fa2faa21eb7..1ff04766093 100644 --- a/tests/integration/test_examples.py +++ b/tests/integration/test_examples.py @@ -99,7 +99,7 @@ async def test_shutdown(flow, scheduler, run, one_conf): schd = scheduler(reg) async with run(schd): pass - assert schd.server.socket.closed + assert schd.server.replier.socket.closed async def test_install(flow, scheduler, one_conf, run_dir): @@ -180,7 +180,7 @@ def killer(): # make sure the server socket has closed - a good indication of a # successful clean shutdown - assert schd.server.socket.closed + assert schd.server.replier.socket.closed @pytest.fixture(scope='module') diff --git a/tests/integration/test_publisher.py b/tests/integration/test_publisher.py index ac1e6c612a6..dadc3b7bea8 100644 --- a/tests/integration/test_publisher.py +++ b/tests/integration/test_publisher.py @@ -31,7 +31,7 @@ async def test_publisher(flow, scheduler, run, one_conf, port_range): subscriber = WorkflowSubscriber( schd.workflow, host=schd.host, - port=schd.publisher.port, + port=schd.server.pub_port, topics=[b'workflow'] ) diff --git a/tests/integration/test_replier.py b/tests/integration/test_replier.py new file mode 100644 index 00000000000..2bc77ad3d54 --- /dev/null +++ b/tests/integration/test_replier.py @@ -0,0 +1,55 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from async_timeout import timeout +import asyncio +from getpass import getuser + +import pytest + + +@pytest.mark.asyncio +@pytest.fixture(scope='module') +async def myflow(mod_flow, mod_scheduler, mod_run, mod_one_conf): + reg = mod_flow(mod_one_conf) + schd = mod_scheduler(reg) + async with mod_run(schd): + yield schd + + +@pytest.mark.asyncio +@pytest.fixture +async def accident(flow, scheduler, run, one_conf): + reg = flow(one_conf) + schd = scheduler(reg) + async with run(schd): + yield schd + + +@pytest.mark.asyncio +async def test_listener(accident): + """Test listener.""" + accident.server.replier.queue.put('STOP') + async with timeout(2): + # wait for the server to consume the STOP item from the queue + while True: + if accident.server.replier.queue.empty(): + break + await asyncio.sleep(0.01) + # ensure the server is "closed" + with pytest.raises(ValueError): + accident.server.replier.queue.put('foobar') + accident.server.replier._listener() diff --git a/tests/integration/test_scan.py b/tests/integration/test_scan.py index c0b1e267e48..c13c8eea723 100644 --- a/tests/integration/test_scan.py +++ b/tests/integration/test_scan.py @@ -369,7 +369,7 @@ async def test_scan_sigstop(flow, scheduler, run, one_conf, test_dir, caplog): schd = scheduler(reg) async with run(schd): # stop the server to make the flow un-responsive - schd.server.stop() + schd.server.stop('make-unresponsive') # try scanning the workflow pipe = scan(test_dir) | graphql_query(['status']) caplog.clear() diff --git a/tests/integration/test_server.py b/tests/integration/test_server.py index 9efbed4938b..cc36e47eed1 100644 --- a/tests/integration/test_server.py +++ b/tests/integration/test_server.py @@ -86,21 +86,6 @@ async def accident(flow, scheduler, run, one_conf): yield schd -async def test_listener(accident): - """Test listener.""" - accident.server.queue.put('STOP') - async with timeout(2): - # wait for the server to consume the STOP item from the queue - while True: - if accident.server.queue.empty(): - break - await asyncio.sleep(0.01) - # ensure the server is "closed" - with pytest.raises(ValueError): - accident.server.queue.put('foobar') - accident.server._listener() - - def test_receiver(accident): """Test receiver.""" msg_in = {'not_command': 'foobar', 'args': {}} diff --git a/tests/integration/test_workflow_files.py b/tests/integration/test_workflow_files.py index d9a09eeba77..6c4067d1e60 100644 --- a/tests/integration/test_workflow_files.py +++ b/tests/integration/test_workflow_files.py @@ -61,9 +61,8 @@ async def workflow(flow, scheduler, one_conf, run_dir): await schd.install() from collections import namedtuple - Server = namedtuple('Server', ['port']) - schd.server = Server(1234) - schd.publisher = Server(2345) + Server = namedtuple('Server', ['port', 'pub_port']) + schd.server = Server(1234, pub_port=2345) contact_data = schd.get_contact_data() contact_file = Path(