From 21739997c48965cae504d372dcbc71ff68004bfa Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Fri, 2 Aug 2024 10:18:26 +0100 Subject: [PATCH 1/7] typo fix (#8812) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 7dc30b8893..4b5791db80 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2100,7 +2100,7 @@ def _transition( "key": key, "start": start, "finish": finish, - "transistion_log": list(self.transition_log), + "transition_log": list(self.transition_log), }, ) if LOG_PDB: From 798183deaa7de3ae663314af584905a634c59e55 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 2 Aug 2024 14:59:38 +0200 Subject: [PATCH 2/7] Fix exception handling for ``WorkerPlugin.setup`` and ``WorkerPlugin.teardown`` (#8810) --- distributed/client.py | 3 +- .../diagnostics/tests/test_worker_plugin.py | 58 ++- distributed/shuffle/tests/test_shuffle.py | 343 +++++++++--------- distributed/worker.py | 13 +- 4 files changed, 228 insertions(+), 189 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ad283a352a..0601b0db5f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5409,8 +5409,7 @@ async def _unregister_worker_plugin(self, name, nanny=None): for response in responses.values(): if response["status"] == "error": - exc = response["exception"] - tb = response["traceback"] + _, exc, tb = clean_exception(**response) raise exc.with_traceback(tb) return responses diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 0f206512b8..001576afe3 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -1,13 +1,14 @@ from __future__ import annotations import asyncio +import logging import warnings import pytest from distributed import Worker, WorkerPlugin from distributed.protocol.pickle import dumps -from distributed.utils_test import async_poll_for, gen_cluster, inc +from distributed.utils_test import async_poll_for, captured_logger, gen_cluster, inc class MyPlugin(WorkerPlugin): @@ -423,3 +424,58 @@ def setup(self, worker): await c.register_plugin(second, idempotent=True) assert "idempotentplugin" in a.plugins assert a.plugins["idempotentplugin"].instance == "first" + + +class BrokenSetupPlugin(WorkerPlugin): + def setup(self, worker): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_register_plugin_with_broken_setup_to_existing_workers_raises(c, s, a): + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_plugin_with_broken_setup_on_new_worker_logs(c, s): + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + async with Worker(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +class BrokenTeardownPlugin(WorkerPlugin): + def teardown(self, worker): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_unregister_worker_plugin_with_broken_teardown_raises(c, s, a): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + await c.unregister_worker_plugin("TestPlugin1") + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_plugin_with_broken_teardown_logs_on_close(c, s): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + async with Worker(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 39f53528fd..443595b0bb 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -87,29 +87,32 @@ def lose_annotations(request): return request.param -async def check_worker_cleanup( +async def assert_worker_cleanup( worker: Worker, - closed: bool = False, + close: bool = False, interval: float = 0.01, timeout: int | None = 5, ) -> None: """Assert that the worker has no shuffle state""" - deadline = Deadline.after(timeout) plugin = worker.plugins["shuffle"] assert isinstance(plugin, ShuffleWorkerPlugin) - while plugin.shuffle_runs._runs and not deadline.expired: - await asyncio.sleep(interval) - assert not plugin.shuffle_runs._runs - if closed: + deadline = Deadline.after(timeout) + if close: + await worker.close() + assert "shuffle" not in worker.plugins assert plugin.closed + else: + while plugin.shuffle_runs._runs and not deadline.expired: + await asyncio.sleep(interval) + assert not plugin.shuffle_runs._runs for dirpath, dirnames, filenames in os.walk(worker.local_directory): assert "shuffle" not in dirpath for fn in dirnames + filenames: assert "shuffle" not in fn -async def check_scheduler_cleanup( +async def assert_scheduler_cleanup( scheduler: Scheduler, interval: float = 0.01, timeout: int | None = 5 ) -> None: """Assert that the scheduler has no shuffle state""" @@ -175,9 +178,9 @@ async def test_basic_cudf_support(c, s, a, b): result, expected = await c.compute([shuffled, df], sync=True) dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) def get_active_shuffle_run(shuffle_id: ShuffleId, worker: Worker) -> ShuffleRun: @@ -213,9 +216,9 @@ async def test_basic_integration(c, s, a, b, npartitions, disk): result, expected = await c.compute([shuffled, df], sync=True) dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("processes", [True, False]) @@ -260,9 +263,9 @@ async def test_shuffle_with_array_conversion(c, s, a, b, npartitions): else: await c.compute(out) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) def test_shuffle_before_categorize(loop_in_thread): @@ -295,9 +298,9 @@ async def test_concurrent(c, s, a, b): dd.assert_eq(x, df, check_index=False) dd.assert_eq(y, df, check_index=False) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -323,9 +326,9 @@ async def test_bad_disk(c, s, a, b): out = await c.compute(out) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) async def wait_until_worker_has_tasks( @@ -401,15 +404,14 @@ async def test_closed_worker_during_transfer(c, s, a, b): shuffled = df.shuffle("x") fut = c.compute([shuffled, df], sync=True) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) result, expected = await fut dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -428,16 +430,15 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): out = df.shuffle("x") out = c.compute(out.x.size) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) with pytest.raises(KilledWorker): await out assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1 await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -491,14 +492,13 @@ async def test_restarting_does_not_log_p2p_failed(c, s, a, b): out = df.shuffle("x") out = c.compute(out.x.size) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) await out assert not s.get_events("p2p") await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) class BlockedGetOrCreateShuffleRunManager(_ShuffleRunManager): @@ -538,7 +538,7 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b): shuffle_extB.shuffle_runs.block_get_or_create.set() await shuffle_extA.shuffle_runs.in_get_or_create.wait() - await b.close() + await assert_worker_cleanup(b, close=True) await async_poll_for( lambda: not any(ws.processing for ws in s.workers.values()), timeout=5 ) @@ -552,10 +552,9 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b): await async_poll_for(lambda: not a.state.tasks, timeout=10) assert not s.plugins["shuffle"].active_shuffles - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) + await assert_worker_cleanup(a) await c.close() - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -581,8 +580,8 @@ async def test_crashed_worker_during_transfer(c, s, a): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -648,15 +647,14 @@ def mock_get_worker_for_range_sharding( shuffled = df.shuffle("x") fut = c.compute([shuffled, df], sync=True) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b, 0.001) - await b.close() + await assert_worker_cleanup(b, close=True) result, expected = await fut dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -691,8 +689,8 @@ def mock_mock_get_worker_for_range_sharding( dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) # @pytest.mark.slow @@ -714,15 +712,14 @@ async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3): ) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w1) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w2) - await w3.close() + await assert_worker_cleanup(w3, close=True) result, expected = await fut dd.assert_eq(result, expected) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2) - await check_worker_cleanup(w3, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_worker_cleanup(w2) + await assert_scheduler_cleanup(s) class RaiseOnCloseShuffleRun(DataFrameShuffleRun): @@ -749,9 +746,8 @@ async def test_exception_on_close_cleans_up(c, s, caplog): with dask.config.set({"dataframe.shuffle.method": "p2p"}): shuffled = df.shuffle("x") await c.compute([shuffled, df], sync=True) - + await assert_worker_cleanup(w, close=True) assert any("test-exception-on-close" in record.message for record in caplog.records) - await check_worker_cleanup(w, closed=True) class BlockedInputsDoneShuffle(DataFrameShuffleRun): @@ -798,7 +794,7 @@ async def test_closed_worker_during_barrier(c, s, a, b): else: close_worker, alive_worker = b, a alive_shuffle = shuffleA - await close_worker.close() + await assert_worker_cleanup(close_worker, close=True) alive_shuffle.block_inputs_done.set() alive_shuffles = get_active_shuffle_runs(alive_worker) @@ -820,9 +816,8 @@ def shuffle_restarted(): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(close_worker, closed=True) - await check_worker_cleanup(alive_worker) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(alive_worker) + await assert_scheduler_cleanup(s) @mock.patch( @@ -861,7 +856,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): else: close_worker, alive_worker = b, a alive_shuffle = shuffleA - await close_worker.close() + await assert_worker_cleanup(close_worker, close=True) with pytest.raises(KilledWorker): await out @@ -870,9 +865,8 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): alive_shuffle.block_inputs_done.set() await c.close() - await check_worker_cleanup(close_worker, closed=True) - await check_worker_cleanup(alive_worker) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(alive_worker) + await assert_scheduler_cleanup(s) @mock.patch( @@ -909,7 +903,7 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): else: close_worker, alive_worker = a, b alive_shuffle = shuffleB - await close_worker.close() + await assert_worker_cleanup(close_worker, close=True) alive_shuffle.block_inputs_done.set() alive_shuffles = get_active_shuffle_runs(alive_worker) @@ -931,9 +925,8 @@ def shuffle_restarted(): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(close_worker, closed=True) - await check_worker_cleanup(alive_worker) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(alive_worker) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -981,8 +974,8 @@ def shuffle_restarted(): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -997,15 +990,14 @@ async def test_closed_worker_during_unpack(c, s, a, b): shuffled = df.shuffle("x") fut = c.compute([shuffled, df], sync=True) await wait_for_tasks_in_state(UNPACK_PREFIX, "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) result, expected = await fut dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -1024,16 +1016,15 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b): out = df.shuffle("x") out = c.compute(out.x.size) await wait_for_tasks_in_state(UNPACK_PREFIX, "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) with pytest.raises(KilledWorker): await out assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1 await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -1059,14 +1050,14 @@ async def test_crashed_worker_during_unpack(c, s, a): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) async def test_heartbeat(c, s, a, b): await a.heartbeat() - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1084,10 +1075,10 @@ async def test_heartbeat(c, s, a, b): assert s.plugins["shuffle"].heartbeats.values() await out - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del out - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @pytest.mark.skipif("not pa", reason="Requires PyArrow") @@ -1292,10 +1283,10 @@ async def test_head(c, s, a, b): assert list(os.walk(a.local_directory)) == a_files # cleaned up files? assert list(os.walk(b.local_directory)) == b_files - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del out - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) def test_split_by_worker(): @@ -1399,9 +1390,9 @@ async def test_clean_after_forgotten_early(c, s, a, b): await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, a) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) del out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1424,9 +1415,9 @@ async def test_tail(c, s, a, b): assert len(s.tasks) < ntasks_full del partial - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @@ -1454,9 +1445,9 @@ async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): await c.compute(out) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @@ -1485,9 +1476,9 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): with dask.config.set({"dataframe.shuffle.method": "p2p"}): await c.compute(df.shuffle("x")) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1532,8 +1523,8 @@ def block(df, in_event, block_event): assert result == expected await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1561,8 +1552,8 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a): assert result == expected await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)] * 3) @@ -1578,25 +1569,22 @@ async def test_closed_worker_between_repeats(c, s, w1, w2, w3): out = df.shuffle("x") await c.compute(out.head(compute=False)) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2) - await check_worker_cleanup(w3) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_worker_cleanup(w2) + await assert_worker_cleanup(w3) + await assert_scheduler_cleanup(s) - await w3.close() + await assert_worker_cleanup(w3, close=True) await c.compute(out.tail(compute=False)) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2) - await check_worker_cleanup(w3, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_worker_cleanup(w2) + await assert_scheduler_cleanup(s) - await w2.close() + await assert_worker_cleanup(w2, close=True) await c.compute(out.head(compute=False)) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2, closed=True) - await check_worker_cleanup(w3, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1616,11 +1604,11 @@ async def test_new_worker(c, s, a, b): async with Worker(s.address) as w: await c.compute(persisted) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_worker_cleanup(w) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_worker_cleanup(w) del persisted - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1644,9 +1632,9 @@ async def test_multi(c, s, a, b): out = await c.compute(out.size) assert out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.skipif( @@ -1694,10 +1682,10 @@ async def test_delete_some_results(c, s, a, b): x = x.partitions[: x.npartitions // 2] x = await c.compute(x.size) - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del x - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1719,11 +1707,11 @@ async def test_add_some_results(c, s, a, b): await c.compute(x.size) - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del x del y - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -1743,12 +1731,11 @@ async def test_clean_after_close(c, s, a, b): await wait_for_tasks_in_state("shuffle-transfer", "executing", 1, a) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await a.close() - await check_worker_cleanup(a, closed=True) + await assert_worker_cleanup(a, close=True) del out - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class DataFrameShuffleTestPool(AbstractShuffleTestPool): @@ -2115,9 +2102,9 @@ async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten): expected = await c.compute(df) dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class BlockedBarrierShuffleWorkerPlugin(ShuffleWorkerPlugin): @@ -2172,9 +2159,9 @@ async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): result, expected = await fut dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -2270,8 +2257,8 @@ async def test_shuffle_run_consistency(c, s, a): await out del out - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -2317,8 +2304,8 @@ async def test_fail_fetch_race(c, s, a): worker_plugin.block_barrier.set() del out - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) class BlockedShuffleAccessAndFailShuffleRunManager(_ShuffleRunManager): @@ -2393,7 +2380,7 @@ async def test_replace_stale_shuffle(c, s, a, b): await asyncio.sleep(0) # A is cleaned - await check_worker_cleanup(a) + await assert_worker_cleanup(a) # B is not cleaned assert shuffle_id in get_active_shuffle_runs(b) @@ -2424,9 +2411,9 @@ async def test_replace_stale_shuffle(c, s, a, b): await out del out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2444,9 +2431,9 @@ async def test_handle_null_partitions(c, s, a, b): result = await c.compute(ddf) dd.assert_eq(result, df) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2467,9 +2454,9 @@ def make_partition(i): expected = await expected dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2496,9 +2483,9 @@ async def test_handle_object_columns(c, s, a, b): result = await c.compute(shuffled) dd.assert_eq(result, df) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2529,9 +2516,9 @@ def make_partition(i): await c.close() del out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2555,9 +2542,9 @@ def make_partition(i): await c.compute(out) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2582,9 +2569,9 @@ async def test_handle_categorical_data(c, s, a, b): result, expected = await c.compute([shuffled, df], sync=True) dd.assert_eq(result, expected, check_categorical=False) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2620,8 +2607,8 @@ async def test_set_index(c, s, *workers): dd.assert_eq(result, df.set_index("a")) await c.close() - await asyncio.gather(*[check_worker_cleanup(w) for w in workers]) - await check_scheduler_cleanup(s) + await asyncio.gather(*[assert_worker_cleanup(w) for w in workers]) + await assert_scheduler_cleanup(s) def test_shuffle_with_existing_index(client): @@ -2741,9 +2728,9 @@ async def test_unpack_is_non_rootish(c, s, a, b): scheduler_plugin.block_barrier.set() result = await result - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class FlakyConnectionPool(ConnectionPool): @@ -2791,10 +2778,10 @@ async def test_flaky_connect_fails_without_retry(c, s, a, b): ): await c.compute(x) - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) await c.close() - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -2823,9 +2810,9 @@ async def test_flaky_connect_recover_with_retry(c, s, a, b): assert len(line) < 250 assert not line or line.startswith("Retrying") - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class BlockedAfterGatherDep(Worker): @@ -2900,9 +2887,9 @@ def make_partition(partition_id, size): for _, group in result.groupby("b"): assert group["a"].is_monotonic_increasing - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("disk", [True, False]) diff --git a/distributed/worker.py b/distributed/worker.py index 5778d60dac..18ef0aca86 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1229,7 +1229,7 @@ async def _register_with_scheduler(self) -> None: *( self.plugin_add(name=name, plugin=plugin) for name, plugin in response["worker-plugins"].items() - ) + ), ) logger.info(" Registered to: %26s", self.scheduler.address) @@ -1560,12 +1560,7 @@ async def close( # type: ignore # Cancel async instructions await BaseWorker.close(self, timeout=timeout) - teardowns = [ - plugin.teardown(self) - for plugin in self.plugins.values() - if hasattr(plugin, "teardown") - ] - await asyncio.gather(*(td for td in teardowns if isawaitable(td))) + await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins)) for extension in self.extensions.values(): if hasattr(extension, "close"): @@ -1870,13 +1865,14 @@ async def plugin_add( self.plugins[name] = plugin - logger.info("Starting Worker plugin %s" % name) + logger.info("Starting Worker plugin %s", name) if hasattr(plugin, "setup"): try: result = plugin.setup(worker=self) if isawaitable(result): result = await result except Exception as e: + logger.exception("Worker plugin %s failed to setup", name) if not catch_errors: raise return error_message(e) @@ -1893,6 +1889,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: if isawaitable(result): result = await result except Exception as e: + logger.exception("Worker plugin %s failed to teardown", name) return error_message(e) return {"status": "OK"} From fea5515030e3b79475e3555fec84e309177132f8 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 2 Aug 2024 15:02:45 +0200 Subject: [PATCH 3/7] Fix exception handling for ``NannyPlugin.setup`` and ``NannyPlugin.teardown`` (#8811) --- .../diagnostics/tests/test_nanny_plugin.py | 59 ++++++++++++++++++- distributed/nanny.py | 12 ++-- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/distributed/diagnostics/tests/test_nanny_plugin.py b/distributed/diagnostics/tests/test_nanny_plugin.py index db17fe70b5..3c481dce26 100644 --- a/distributed/diagnostics/tests/test_nanny_plugin.py +++ b/distributed/diagnostics/tests/test_nanny_plugin.py @@ -1,10 +1,12 @@ from __future__ import annotations +import logging + import pytest from distributed import Nanny, NannyPlugin from distributed.protocol.pickle import dumps -from distributed.utils_test import gen_cluster +from distributed.utils_test import captured_logger, gen_cluster @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) @@ -160,3 +162,58 @@ def setup(self, nanny): await c.register_plugin(second, idempotent=True) assert "idempotentplugin" in a.plugins assert a.plugins["idempotentplugin"].instance == "first" + + +class BrokenSetupPlugin(NannyPlugin): + def setup(self, nanny): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_register_plugin_with_broken_setup_to_existing_nannies_raises(c, s, a): + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_plugin_with_broken_setup_on_new_nanny_logs(c, s): + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + async with Nanny(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +class BrokenTeardownPlugin(NannyPlugin): + def teardown(self, nanny): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_unregister_nanny_plugin_with_broken_teardown_raises(c, s, a): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + await c.unregister_worker_plugin("TestPlugin1", nanny=True) + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + async with Nanny(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs diff --git a/distributed/nanny.py b/distributed/nanny.py index af0d9a62ad..7a14ee6576 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -477,13 +477,14 @@ async def plugin_add( self.plugins[name] = plugin - logger.info("Starting Nanny plugin %s" % name) + logger.info("Starting Nanny plugin %s", name) if hasattr(plugin, "setup"): try: result = plugin.setup(nanny=self) if isawaitable(result): result = await result except Exception as e: + logger.exception("Nanny plugin %s failed to setup", name) return error_message(e) if getattr(plugin, "restart", False): await self.restart(reason=f"nanny-plugin-{name}-restart") @@ -500,6 +501,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: if isawaitable(result): result = await result except Exception as e: + logger.exception("Nanny plugin %s failed to teardown", name) msg = error_message(e) return msg @@ -610,13 +612,7 @@ async def close( # type:ignore[override] await self.preloads.teardown() - teardowns = [ - plugin.teardown(self) - for plugin in self.plugins.values() - if hasattr(plugin, "teardown") - ] - - await asyncio.gather(*(td for td in teardowns if isawaitable(td))) + await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins)) self.stop() if self.process is not None: From 879e5924ccb4f1bb45e0090f6ca9e6f7eaf343ce Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 2 Aug 2024 15:55:56 +0200 Subject: [PATCH 4/7] Fail tasks exceeding `no-workers-timeout` (#8806) --- distributed/distributed-schema.yaml | 11 +- distributed/distributed.yaml | 2 +- distributed/scheduler.py | 365 +++++++++++++++++++++--- distributed/tests/test_scheduler.py | 71 +++-- distributed/tests/test_worker_memory.py | 4 +- 5 files changed, 383 insertions(+), 70 deletions(-) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 45534e9be8..f7e452383c 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -81,16 +81,15 @@ properties: - string - "null" description: | - Shut down the scheduler after this duration if there are pending tasks, - but no workers that can process them. This can either mean that there are - no workers running at all, or that there are idle workers but they've been - excluded through worker or resource restrictions. + Timeout for tasks in an unrunnable state. + + If task remains unrunnable for longer than this, it fails. A task is considered unrunnable IFF + it has no pending dependencies, and the task has restrictions that are not satisfied by + any available worker or no workers are running at all. In adaptive clusters, this timeout must be set to be safely higher than the time it takes for workers to spin up. - Works in conjunction with idle-timeout. - work-stealing: type: boolean description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 26b88322d0..250af10f7d 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -19,7 +19,7 @@ distributed: # after they have been removed from the scheduler events-cleanup-delay: 1h idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes" - no-workers-timeout: null # Shut down if there are tasks but no workers to process them + no-workers-timeout: null # If a task remains unrunnable for longer than this, it fails. work-stealing: True # workers should steal tasks from each other work-stealing-interval: 100ms # Callback time for work stealing worker-saturation: 1.1 # Send this fraction of nthreads root tasks to workers diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4b5791db80..adf99113b9 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -103,7 +103,7 @@ from distributed.event import EventExtension from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis from distributed.http import get_handlers -from distributed.metrics import time +from distributed.metrics import monotonic, time from distributed.multi_lock import MultiLockExtension from distributed.node import ServerNode from distributed.proctitle import setproctitle @@ -1681,8 +1681,8 @@ class SchedulerState: #: Tasks in the "queued" state, ordered by priority queued: HeapSet[TaskState] - #: Tasks in the "no-worker" state - unrunnable: set[TaskState] + #: Tasks in the "no-worker" state with the (monotonic) time when they became unrunnable + unrunnable: dict[TaskState, float] #: Subset of tasks that exist in memory on more than one worker replicated_tasks: set[TaskState] @@ -1755,7 +1755,7 @@ def __init__( host_info: dict[str, dict[str, Any]], resources: dict[str, dict[str, float]], tasks: dict[Key, TaskState], - unrunnable: set[TaskState], + unrunnable: dict[TaskState, float], queued: HeapSet[TaskState], validate: bool, plugins: Iterable[SchedulerPlugin] = (), @@ -2193,12 +2193,74 @@ def _transition_no_worker_processing(self, key: Key, stimulus_id: str) -> RecsMs assert ts in self.unrunnable if ws := self.decide_worker_non_rootish(ts): - self.unrunnable.discard(ts) + self.unrunnable.pop(ts, None) return self._add_to_processing(ts, ws, stimulus_id=stimulus_id) # If no worker, task just stays in `no-worker` return {}, {}, {} + def _transition_no_worker_erred( + self, + key: Key, + stimulus_id: str, + *, + # TODO: Which ones can actually be None? + cause: Key | None = None, + exception: Serialized | None = None, + traceback: Serialized | None = None, + exception_text: str | None = None, + traceback_text: str | None = None, + ) -> RecsMsgs: + ts = self.tasks[key] + + if self.validate: + assert not ts.actor, f"Actors can't be in `no-worker`: {ts}" + assert cause + assert ts in self.unrunnable + assert not ts.processing_on + + self.unrunnable.pop(ts) + + return self._propagate_erred( + ts, + cause=cause, + exception=exception, + traceback=traceback, + exception_text=exception_text, + traceback_text=traceback_text, + ) + + def _transition_queued_erred( + self, + key: Key, + stimulus_id: str, + *, + # TODO: Which ones can actually be None? + cause: Key | None = None, + exception: Serialized | None = None, + traceback: Serialized | None = None, + exception_text: str | None = None, + traceback_text: str | None = None, + ) -> RecsMsgs: + ts = self.tasks[key] + + if self.validate: + assert not ts.actor, f"Actors can't be in `no-worker`: {ts}" + assert cause + assert ts in self.queued + assert not ts.processing_on + + self.queued.remove(ts) + + return self._propagate_erred( + ts, + cause=cause, + exception=exception, + traceback=traceback, + exception_text=exception_text, + traceback_text=traceback_text, + ) + def decide_worker_rootish_queuing_disabled( self, ts: TaskState ) -> WorkerState | None: @@ -2730,8 +2792,6 @@ def _transition_processing_erred( Recommendations, client messages and worker messages to process """ ts = self.tasks[key] - recommendations: Recs = {} - client_msgs: Msgs = {} if self.validate: assert cause or ts.exception_blame @@ -2746,9 +2806,41 @@ def _transition_processing_erred( self._exit_processing_common(ts) + if self.validate: + assert not ts.processing_on + + return self._propagate_erred( + ts, + worker=worker, + cause=cause, + exception=exception, + traceback=traceback, + exception_text=exception_text, + traceback_text=traceback_text, + ) + + def _propagate_erred( + self, + ts: TaskState, + *, + worker: str | None = None, + cause: Key | None = None, + exception: Serialized | None = None, + traceback: Serialized | None = None, + exception_text: str | None = None, + traceback_text: str | None = None, + ) -> RecsMsgs: + recommendations: Recs = {} + client_msgs: Msgs = {} + + ts.state = "erred" + key = ts.key + if not ts.erred_on: ts.erred_on = set() - ts.erred_on.add(worker) + if worker is not None: + ts.erred_on.add(worker) + if exception is not None: ts.exception = exception ts.exception_text = exception_text @@ -2783,8 +2875,6 @@ def _transition_processing_erred( ts.waiters = None - ts.state = "erred" - report_msg = { "op": "task-erred", "key": key, @@ -2802,9 +2892,6 @@ def _transition_processing_erred( recommendations=recommendations, ) - if self.validate: - assert not ts.processing_on - return recommendations, client_msgs, {} def _transition_no_worker_released(self, key: Key, stimulus_id: str) -> RecsMsgs: @@ -2815,7 +2902,7 @@ def _transition_no_worker_released(self, key: Key, stimulus_id: str) -> RecsMsgs assert not ts.who_has assert not ts.waiting_on - self.unrunnable.remove(ts) + self.unrunnable.pop(ts) recommendations: Recs = {} self._propagate_released(ts, recommendations) @@ -2838,9 +2925,13 @@ def _transition_waiting_no_worker(self, key: Key, stimulus_id: str) -> RecsMsgs: if self.validate: self._validate_ready(ts) + assert ts not in self.unrunnable ts.state = "no-worker" - self.unrunnable.add(ts) + self.unrunnable[ts] = monotonic() + + if self.validate: + validate_unrunnable(self.unrunnable) return {}, {}, {} @@ -2873,7 +2964,7 @@ def _transition_queued_processing(self, key: Key, stimulus_id: str) -> RecsMsgs: def _remove_key(self, key: Key) -> None: ts = self.tasks.pop(key) assert ts.state == "forgotten" - self.unrunnable.discard(ts) + self.unrunnable.pop(ts, None) for cs in ts.who_wants or (): cs.wants_what.remove(ts) ts.who_wants = None @@ -2963,11 +3054,13 @@ def _transition_released_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs ("waiting", "memory"): _transition_waiting_memory, ("queued", "released"): _transition_queued_released, ("queued", "processing"): _transition_queued_processing, + ("queued", "erred"): _transition_queued_erred, ("processing", "released"): _transition_processing_released, ("processing", "memory"): _transition_processing_memory, ("processing", "erred"): _transition_processing_erred, ("no-worker", "released"): _transition_no_worker_released, ("no-worker", "processing"): _transition_no_worker_processing, + ("no-worker", "erred"): _transition_no_worker_erred, ("released", "forgotten"): _transition_released_forgotten, ("memory", "forgotten"): _transition_memory_forgotten, ("erred", "released"): _transition_erred_released, @@ -3597,7 +3690,6 @@ class Scheduler(SchedulerState, ServerNode): idle_timeout: float | None _no_workers_since: float | None # Note: not None iff there are pending tasks no_workers_timeout: float | None - _client_connections_added_total: int _client_connections_removed_total: int _workers_added_total: int @@ -3794,7 +3886,7 @@ async def post(self): self.generation = 0 self._last_client = None self._last_time = 0 - unrunnable = set() + unrunnable = {} queued = HeapSet(key=operator.attrgetter("priority")) self.datasets = {} @@ -4434,6 +4526,7 @@ async def add_worker( self._workers_added_total += 1 if ws.status == Status.running: self.running.add(ws) + self._refresh_no_workers_since() dh = self.host_info.get(host) if dh is None: @@ -5320,6 +5413,7 @@ async def remove_worker( recommendations: Recs = {} + timestamp = monotonic() processing_keys = {ts.key for ts in ws.processing} for ts in list(ws.processing): k = ts.key @@ -5390,6 +5484,9 @@ async def remove_worker( self.log_event("all", event_msg) self.transitions(recommendations, stimulus_id=stimulus_id) + # Make sure that the timestamp has been collected before tasks were transitioned to no-worker + # to ensure a meaningful error message. + self._refresh_no_workers_since(timestamp=timestamp) awaitables = [] for plugin in list(self.plugins.values()): @@ -5632,6 +5729,7 @@ def validate_key(self, key: Key, ts: TaskState | None = None) -> None: def validate_state(self, allow_overlap: bool = False) -> None: validate_state(self.tasks, self.workers, self.clients) + validate_unrunnable(self.unrunnable) if not (set(self.workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") @@ -5971,6 +6069,7 @@ def handle_worker_status_change( self.idle.pop(ws.address, None) self.idle_task_count.discard(ws) self.saturated.discard(ws) + self._refresh_no_workers_since() def handle_request_refresh_who_has( self, keys: Iterable[Key], worker: str, stimulus_id: str @@ -8552,39 +8651,134 @@ def check_idle(self) -> float | None: return self.idle_since def _check_no_workers(self) -> None: - """Shut down the scheduler if there have been tasks ready to run which have - nowhere to run for `distributed.scheduler.no-workers-timeout`, and there - aren't other tasks running. - """ - if self.status in (Status.closing, Status.closed): - return # pragma: nocover - if ( - (not self.queued and not self.unrunnable) - or (self.queued and self.workers) - or any(ws.processing for ws in self.workers.values()) + self.status in (Status.closing, Status.closed) + or self.no_workers_timeout is None ): - self._no_workers_since = None return - # 1. There are queued or unrunnable tasks and no workers at all - # 2. There are unrunnable tasks and no workers satisfy their restrictions - # (Only rootish tasks can be queued, and rootish tasks can't have restrictions) + now = monotonic() + stimulus_id = f"check-no-workers-timeout-{time()}" - if not self._no_workers_since: - self._no_workers_since = time() - return + recommendations: Recs = {} - if ( - self.no_workers_timeout - and time() > self._no_workers_since + self.no_workers_timeout - ): - logger.info( - "Tasks have been without any workers to run them for %s; " - "shutting scheduler down", - format_time(self.no_workers_timeout), + self._refresh_no_workers_since(now) + + affected = self._check_unrunnable_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id + ) + + affected.update( + self._check_queued_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id ) - self._ongoing_background_tasks.call_soon(self.close) + ) + self.transitions(recommendations, stimulus_id=stimulus_id) + if affected: + self.log_event( + "scheduler", + {"action": "no-workers-timeout-exceeded", "keys": affected}, + ) + + def _check_unrunnable_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: + assert self.no_workers_timeout + unsatisfied = [] + no_workers = [] + for ts, unrunnable_since in self.unrunnable.items(): + if timestamp <= unrunnable_since + self.no_workers_timeout: + # unrunnable is insertion-ordered, which means that unrunnable_since will + # be monotonically increasing in this loop. + break + if ( + self._no_workers_since is None + or self._no_workers_since >= unrunnable_since + ): + unsatisfied.append(ts) + else: + no_workers.append(ts) + if not unsatisfied and not no_workers: + return set() + + for ts in unsatisfied: + e = pickle.dumps( + NoValidWorkerError( + task=ts.key, + host_restrictions=(ts.host_restrictions or set()).copy(), + worker_restrictions=(ts.worker_restrictions or set()).copy(), + resource_restrictions=(ts.resource_restrictions or {}).copy(), + timeout=self.no_workers_timeout, + ), + ) + r = self.transition( + ts.key, + "erred", + exception=e, + cause=ts.key, + stimulus_id=stimulus_id, + ) + recommendations.update(r) + logger.error( + "Task %s marked as failed because it timed out waiting " + "for its restrictions to become satisfied.", + ts.key, + ) + self._fail_tasks_after_no_workers_timeout( + no_workers, recommendations, stimulus_id + ) + return {ts.key for ts in concat([unsatisfied, no_workers])} + + def _check_queued_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: + assert self.no_workers_timeout + + if self._no_workers_since is None: + return set() + + if timestamp <= self._no_workers_since + self.no_workers_timeout: + return set() + affected = list(self.queued) + self._fail_tasks_after_no_workers_timeout( + affected, recommendations, stimulus_id + ) + return {ts.key for ts in affected} + + def _fail_tasks_after_no_workers_timeout( + self, timed_out: Iterable[TaskState], recommendations: Recs, stimulus_id: str + ) -> None: + assert self.no_workers_timeout + + for ts in timed_out: + e = pickle.dumps( + NoWorkerError( + task=ts.key, + timeout=self.no_workers_timeout, + ), + ) + r = self.transition( + ts.key, + "erred", + exception=e, + cause=ts.key, + stimulus_id=stimulus_id, + ) + recommendations.update(r) + logger.error( + "Task %s marked as failed because it timed out waiting " + "without any running workers.", + ts.key, + ) + + def _refresh_no_workers_since(self, timestamp: float | None = None) -> None: + if self.running or not (self.queued or self.unrunnable): + self._no_workers_since = None + return + + if not self._no_workers_since: + self._no_workers_since = timestamp or monotonic() + return def adaptive_target(self, target_duration=None): """Desired number of workers based on the current workload @@ -8607,7 +8801,7 @@ def adaptive_target(self, target_duration=None): target_duration = parse_timedelta(target_duration) # CPU - queued = take(100, concat([self.queued, self.unrunnable])) + queued = take(100, concat([self.queued, self.unrunnable.keys()])) queued_occupancy = 0 for ts in queued: if ts.prefix.duration_average == -1: @@ -8894,6 +9088,25 @@ def validate_task_state(ts: TaskState) -> None: assert ts.state != "queued" +def validate_unrunnable(unrunnable: dict[TaskState, float]) -> None: + prev_unrunnable_since: float | None = None + prev_ts: TaskState | None = None + for ts, unrunnable_since in unrunnable.items(): + assert ts.state == "no-worker" + if prev_ts is not None: + assert prev_unrunnable_since is not None + # Ensure that unrunnable_since is monotonically increasing when iterating over unrunnable. + # _check_no_workers relies on this. + assert prev_unrunnable_since <= unrunnable_since, ( + prev_ts, + ts, + prev_unrunnable_since, + unrunnable_since, + ) + prev_ts = ts + prev_unrunnable_since = unrunnable_since + + def validate_worker_state(ws: WorkerState) -> None: for ts in ws.has_what or (): assert ts.who_has @@ -8988,6 +9201,68 @@ def __str__(self) -> str: ) +class NoValidWorkerError(Exception): + def __init__( + self, + task: Key, + host_restrictions: set[str], + worker_restrictions: set[str], + resource_restrictions: dict[str, float], + timeout: float, + ): + super().__init__( + task, host_restrictions, worker_restrictions, resource_restrictions, timeout + ) + + @property + def task(self) -> Key: + return self.args[0] + + @property + def host_restrictions(self) -> Any: + return self.args[1] + + @property + def worker_restrictions(self) -> Any: + return self.args[2] + + @property + def resource_restrictions(self) -> Any: + return self.args[3] + + @property + def timeout(self) -> float: + return self.args[4] + + def __str__(self) -> str: + return ( + f"Attempted to run task {self.task!r} but timed out after {format_time(self.timeout)} " + "waiting for a valid worker matching all restrictions.\n\nRestrictions:\n" + "host_restrictions={self.host_restrictions!s}\n" + "worker_restrictions={self.worker_restrictions!s}\n" + "resource_restrictions={self.resource_restrictions!s}\n" + ) + + +class NoWorkerError(Exception): + def __init__(self, task: Key, timeout: float): + super().__init__(task, timeout) + + @property + def task(self) -> Key: + return self.args[0] + + @property + def timeout(self) -> float: + return self.args[1] + + def __str__(self) -> str: + return ( + f"Attempted to run task {self.task!r} but timed out after {format_time(self.timeout)} " + "waiting without any running workers." + ) + + class WorkerStatusPlugin(SchedulerPlugin): """A plugin to share worker status with a remote observer diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 0a63557106..a3c3316999 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -47,7 +47,14 @@ from distributed.protocol import serialize from distributed.protocol.pickle import dumps, loads from distributed.protocol.serialize import ToPickle -from distributed.scheduler import KilledWorker, MemoryState, Scheduler, WorkerState +from distributed.scheduler import ( + KilledWorker, + MemoryState, + NoValidWorkerError, + NoWorkerError, + Scheduler, + WorkerState, +) from distributed.utils import TimeoutError, wait_for from distributed.utils_test import ( NO_AMM, @@ -2108,7 +2115,7 @@ def g(_, ev1, ev2): await ev2.set() -# @pytest.mark.slow +@pytest.mark.slow @gen_cluster( client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False} ) @@ -2445,7 +2452,7 @@ async def test_idle_timeout_no_workers(c, s): nthreads=[], config={"distributed.scheduler.no-workers-timeout": None}, ) -async def test_no_workers_timeout_disabled(c, s, a, b): +async def test_no_workers_timeout_disabled(c, s): """no-workers-timeout has been disabled""" future = c.submit(inc, 1, key="x") await wait_for_state("x", ("queued", "no-worker"), s) @@ -2455,7 +2462,13 @@ async def test_no_workers_timeout_disabled(c, s, a, b): s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running + async with Worker(s.address): + await future + + assert all( + event["action"] != "no-workers-timeout-exceeded" + for _, event in s.get_events("scheduler") + ) @pytest.mark.slow @@ -2466,17 +2479,23 @@ async def test_no_workers_timeout_disabled(c, s, a, b): ) async def test_no_workers_timeout_without_workers(c, s): """Trip no-workers-timeout when there are no workers available""" - # Don't trip scheduler shutdown when there are no tasks + future = c.submit(inc, 1, key="x") + await wait_for_state("x", ("queued", "no-worker"), s) s._check_no_workers() await asyncio.sleep(0.2) s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running + with pytest.raises(NoWorkerError if QUEUING_ON_BY_DEFAULT else NoValidWorkerError): + await future - future = c.submit(inc, 1) - while s.status != Status.closed: - await asyncio.sleep(0.01) + events = [ + event + for _, event in s.get_events("scheduler") + if event["action"] == "no-workers-timeout-exceeded" + ] + assert len(events) == 1 + assert events[0]["keys"] == {"x"} @pytest.mark.slow @@ -2489,8 +2508,16 @@ async def test_no_workers_timeout_bad_restrictions(c, s, a, b): task restrictions """ future = c.submit(inc, 1, key="x", workers=["127.0.0.2:1234"]) - while s.status != Status.closed: - await asyncio.sleep(0.01) + with pytest.raises(NoValidWorkerError): + await future + + events = [ + event + for _, event in s.get_events("scheduler") + if event["action"] == "no-workers-timeout-exceeded" + ] + assert len(events) == 1 + assert events[0]["keys"] == {"x"} @gen_cluster( @@ -2510,8 +2537,13 @@ async def test_no_workers_timeout_queued(c, s, a): s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running await ev.set() + await c.gather(futures) + + assert all( + event["action"] != "no-workers-timeout-exceeded" + for _, event in s.get_events("scheduler") + ) @pytest.mark.slow @@ -2532,14 +2564,21 @@ async def test_no_workers_timeout_processing(c, s, a, b): await asyncio.sleep(0.2) s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running + + with pytest.raises(NoValidWorkerError): + await y + + events = [ + event + for _, event in s.get_events("scheduler") + if event["action"] == "no-workers-timeout-exceeded" + ] + assert len(events) == 1 + assert events[0]["keys"] == {"y"} await ev.set() await x - while s.status != Status.closed: - await asyncio.sleep(0.01) - @gen_cluster(client=True, config={"distributed.scheduler.bandwidth": "100 GB"}) async def test_bandwidth(c, s, a, b): diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index fc473ba8d1..0994816d2e 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -538,7 +538,7 @@ def f(ev): z = c.submit(inc, 2, key="z") while "z" not in s.tasks or s.tasks["z"].state != "no-worker": await asyncio.sleep(0.01) - assert s.unrunnable == {s.tasks["z"]} + assert s.unrunnable.keys() == {s.tasks["z"]} # Test that a task that already started when the worker paused can complete # and its output can be retrieved. Also test that the now free slot won't be @@ -605,7 +605,7 @@ def f(ev): z = c.submit(inc, 2, key="z") while "z" not in s.tasks or s.tasks["z"].state != "no-worker": await asyncio.sleep(0.01) - assert s.unrunnable == {s.tasks["z"]} + assert s.unrunnable.keys() == {s.tasks["z"]} # Test that a task that already started when the worker paused can complete # and its output can be retrieved. Also test that the now free slot won't be From e0ce0fe455231bea4adfd881b07bbf28d85ed1a1 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 5 Aug 2024 12:43:01 +0200 Subject: [PATCH 5/7] Update large graph size warning to remove scatter recommendation (#8815) --- distributed/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 0601b0db5f..2ece27802a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3362,7 +3362,9 @@ def _graph_to_futures( warnings.warn( f"Sending large graph of size {format_bytes(pickled_size)}.\n" "This may cause some slowdown.\n" - "Consider scattering data ahead of time and using futures." + "Consider loading the data with Dask directly\n or using futures or " + "delayed objects to embed the data into the graph without repetition.\n" + "See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information." ) computations = self._get_computation_code( From 833831922d4f070003dff9eac1975e999bc22a6e Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:32:59 +0200 Subject: [PATCH 6/7] Run graph normalisation after dask order (#8818) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index adf99113b9..a7cd91a765 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4884,6 +4884,7 @@ async def update_graph( internal_priority = await offload( dask.order.order, dsk=dsk, dependencies=stripped_deps ) + dsk = valmap(_normalize_task, dsk) self._create_taskstate_from_graph( dsk=dsk, @@ -9383,5 +9384,4 @@ def _materialize_graph( deps.discard(k) dependencies[k] = deps - dsk = valmap(_normalize_task, dsk) return dsk, dependencies, annotations_by_type From 92fc0e24846d3edcbabc5861a3b9a812f3fc76d8 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 6 Aug 2024 15:18:25 -0500 Subject: [PATCH 7/7] bump version to 2024.8.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 44a79310f5..e7e729642f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.7.1", + "dask == 2024.8.0", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0",