From 7476d911c2dd5da4617e7b4069da88051d589555 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 30 Jun 2022 15:17:26 +0100 Subject: [PATCH] transitions caused by worker death use old 'worker-connect' stimulus_id --- distributed/scheduler.py | 29 +++++++++++++---------------- distributed/tests/test_scheduler.py | 17 ----------------- 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index adb1bbc45f7..9eadaf51288 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3553,7 +3553,7 @@ def heartbeat_worker( @log_errors async def add_worker( self, - comm=None, + comm: Comm, *, address: str, status: str, @@ -3575,8 +3575,8 @@ async def add_worker( versions: dict[str, Any] | None = None, nanny=None, extra=None, - stimulus_id=None, - ): + stimulus_id: str, + ) -> None: """Add a new worker to the cluster""" address = self.coerce_address(address, resolve_address) address = normalize_address(address) @@ -3591,8 +3591,7 @@ async def add_worker( f"Keys: {list(nbytes)}" ) logger.error(err) - if comm: - await comm.write({"status": "error", "message": err, "time": time()}) + await comm.write({"status": "error", "message": err, "time": time()}) return if name in self.aliases: @@ -3602,8 +3601,7 @@ async def add_worker( "message": "name taken, %s" % name, "time": time(), } - if comm: - await comm.write(msg) + await comm.write(msg) return self.log_event(address, {"action": "add-worker"}) @@ -3682,7 +3680,6 @@ async def add_worker( "worker-plugins": self.worker_plugins, } - cs: ClientState version_warning = version_module.error_message( version_module.get_versions(), merge( @@ -3694,12 +3691,11 @@ async def add_worker( ) msg.update(version_warning) - if comm: - await comm.write(msg) + await comm.write(msg) + # This will keep running until the worker is removed + await self.handle_worker(comm, address) - await self.handle_worker(comm=comm, worker=address, stimulus_id=stimulus_id) - - async def 
add_nanny(self, comm): + async def add_nanny(self) -> dict[str, Any]: msg = { "status": "OK", "nanny-plugins": self.nanny_plugins, @@ -4807,7 +4803,7 @@ async def handle_request_refresh_who_has( } ) - async def handle_worker(self, comm=None, worker=None, stimulus_id=None): + async def handle_worker(self, comm: Comm, worker: str) -> None: """ Listen to responses from a single worker @@ -4817,7 +4813,6 @@ async def handle_worker(self, comm=None, worker=None, stimulus_id=None): -------- Scheduler.handle_client: Equivalent coroutine for clients """ - assert stimulus_id comm.name = "Scheduler connection to worker" worker_comm = self.stream_comms[worker] worker_comm.start(comm) @@ -4827,7 +4822,9 @@ async def handle_worker(self, comm=None, worker=None, stimulus_id=None): finally: if worker in self.stream_comms: worker_comm.abort() - await self.remove_worker(address=worker, stimulus_id=stimulus_id) + await self.remove_worker( + worker, stimulus_id=f"handle-worker-cleanup-{time()}" + ) def add_plugin( self, diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index b90a7689e5c..175375dc19f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1339,23 +1339,6 @@ async def test_scheduler_file(): await s.close() -@pytest.mark.xfail() -@gen_cluster(client=True, nthreads=[]) -async def test_non_existent_worker(c, s): - with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - await s.add_worker( - address="127.0.0.1:5738", - status="running", - nthreads=2, - nbytes={}, - host_info={}, - ) - futures = c.map(inc, range(10)) - await asyncio.sleep(0.300) - assert not s.workers - assert all(ts.state == "no-worker" for ts in s.tasks.values()) - - @pytest.mark.parametrize( "host", ["tcp://0.0.0.0", "tcp://127.0.0.1", "tcp://127.0.0.1:38275"] )