diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 3295172dd51..7eebbb6aef2 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -203,7 +203,7 @@ jobs:
           set -o pipefail
           mkdir reports

-          pytest distributed \
+          pytest distributed/deploy/tests/test_spec_cluster.py \  # DNM
            -m "not avoid_ci and ${{ matrix.partition }}" --runslow \
            --leaks=fds,processes,threads \
            --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \
diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py
index de48a231ad4..619d73c01fa 100644
--- a/distributed/deploy/spec.py
+++ b/distributed/deploy/spec.py
@@ -379,13 +379,17 @@ async def _correct_state_internal(self) -> None:
                 self._created.add(worker)
                 workers.append(worker)
             if workers:
-                await asyncio.wait(
-                    [asyncio.create_task(_wrap_awaitable(w)) for w in workers]
-                )
+                worker_futs = [asyncio.ensure_future(w) for w in workers]
+                await asyncio.wait(worker_futs)
+                self.workers.update(dict(zip(to_open, workers)))
                 for w in workers:
                     w._cluster = weakref.ref(self)
+                # Collect exceptions from failed workers. This must happen after all
+                # *other* workers have finished initialising, so that we can have a
+                # proper teardown.
+                await asyncio.gather(*worker_futs)
+                for w in workers:
                     await w  # for tornado gen.coroutine support
-                self.workers.update(dict(zip(to_open, workers)))

     def _update_worker_status(self, op, msg):
         if op == "remove":
@@ -467,10 +471,14 @@ async def _close(self):
         await super()._close()

     async def __aenter__(self):
-        await self
-        await self._correct_state()
-        assert self.status == Status.running
-        return self
+        try:
+            await self
+            await self._correct_state()
+            assert self.status == Status.running
+            return self
+        except Exception:
+            await self.close()
+            raise

     def _threads_per_worker(self) -> int:
         """Return the number of threads per worker for new workers"""
diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py
index f875db0c3ed..8f6550cc4d0 100644
--- a/distributed/deploy/tests/test_spec_cluster.py
+++ b/distributed/deploy/tests/test_spec_cluster.py
@@ -207,7 +207,7 @@ async def test_restart():
         await asyncio.sleep(0.01)


-@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")
+@pytest.mark.repeat(50)  # DNM
 @gen_test()
 async def test_broken_worker():
     class BrokenWorkerException(Exception):
@@ -216,7 +216,6 @@ class BrokenWorkerException(Exception):
     class BrokenWorker(Worker):
         def __await__(self):
             async def _():
-                self.status = Status.closed
                 raise BrokenWorkerException("Worker Broken")

             return _().__await__()
@@ -226,13 +225,9 @@ async def _():
         workers={"good": {"cls": Worker}, "bad": {"cls": BrokenWorker}},
         scheduler=scheduler,
     )
-    try:
-        with pytest.raises(BrokenWorkerException, match=r"Worker Broken"):
-            async with cluster:
-                pass
-    finally:
-        # FIXME: SpecCluster leaks if SpecCluster.__aenter__ raises
-        await cluster.close()
+    with pytest.raises(BrokenWorkerException, match=r"Worker Broken"):
+        async with cluster:
+            pass


 @pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")
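
Note for reviewers: the spec.py change leans on the different failure semantics of asyncio.wait and asyncio.gather. Below is a minimal standalone sketch of that pattern; start_all is a hypothetical helper for illustration only, not part of this patch:

    import asyncio

    async def start_all(awaitables):
        # Schedule every awaitable up front so that sibling workers keep
        # initialising even if one of them raises.
        futs = [asyncio.ensure_future(a) for a in awaitables]
        # asyncio.wait() never propagates task exceptions; it just blocks
        # until all tasks have settled and returns (done, pending).
        await asyncio.wait(futs)
        # asyncio.gather() on the already-finished futures re-raises the
        # first exception, *after* every task has completed, so the caller
        # can tear down whatever did start successfully.
        await asyncio.gather(*futs)

This ordering is what lets the new __aenter__ safely call self.close() and re-raise: by the time the worker exception surfaces, no worker is still mid-startup, so the teardown sees a consistent cluster state.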