From 8c3617b55dacecc26212da2e34a275102070b7db Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Sun, 15 May 2022 21:30:11 +0100
Subject: [PATCH] Unit tests

---
 .../tests/test_worker_state_machine.py | 183 ++++++++++++++----
 distributed/worker.py                  |   5 +-
 2 files changed, 153 insertions(+), 35 deletions(-)

diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 9295b9987d9..4253a05d3e3 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -1,8 +1,11 @@
 import asyncio
+from contextlib import contextmanager
 from itertools import chain
 
 import pytest
 
+from distributed import Worker
+from distributed.core import Status
 from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
 from distributed.utils_test import assert_story, gen_cluster, inc
@@ -16,12 +19,13 @@
     SendMessageToScheduler,
     StateMachineEvent,
     TaskState,
+    TaskStateState,
     UniqueTaskHeap,
     merge_recs_instructions,
 )
 
 
-async def wait_for_state(key, state, dask_worker):
+async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
     while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
         await asyncio.sleep(0.005)
 
@@ -245,60 +249,173 @@ def test_executefailure_to_dict():
     assert ev3.traceback_text == "tb text"
 
 
-@gen_cluster(client=True)
-async def test_fetch_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
+@contextmanager
+def freeze_inbound_comms(w: Worker):
+    """Prevent any task from transitioning from fetch to flight on the worker while
+    inside the context.
 
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
-    f2 = c.submit(inc, f1, workers=[b.address], key="f2")
+    This is not the same as setting the worker to Status=paused, which would also
+    inform the Scheduler and prevent further tasks from being enqueued on the worker.
+    """
+    old_out_connections = w.total_out_connections
+    old_comm_threshold = w.comm_threshold_bytes
+    w.total_out_connections = 0
+    w.comm_threshold_bytes = 0
+    yield
+    w.total_out_connections = old_out_connections
+    w.comm_threshold_bytes = old_comm_threshold
+    # Jump-start ensure_communicating
+    w.status = Status.paused
+    w.status = Status.running
 
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
 
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
+@gen_cluster(client=True)
+async def test_fetch_to_compute(c, s, a, b):
+    with freeze_inbound_comms(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        f2 = c.submit(inc, f1, workers=[b.address], key="f2")
+        await wait_for_state("f1", "fetch", b)
+        await a.close()
 
     await f2
 
     assert_story(
         b.log,
-        # FIXME: This log should be replaced with an
-        # StateMachineEvent/Instruction log
+        # FIXME: This log should be replaced with a StateMachineEvent log
         [
-            (f2.key, "compute-task", "released"),
+            ("f2", "compute-task", "released"),
             # This is a "please fetch" request. We don't have anything like
             # this, yet. We don't see the request-dep signal in here because we
             # do not wait for the key to be actually scheduled
-            (f1.key, "ensure-task-exists", "released"),
+            ("f1", "ensure-task-exists", "released"),
             # After the worker failed, we're instructed to forget f2 before
             # something new comes in
-            ("free-keys", (f2.key,)),
-            (f1.key, "compute-task", "released"),
-            (f1.key, "put-in-memory"),
-            (f2.key, "compute-task", "released"),
+            ("free-keys", ("f2",)),
+            ("f1", "compute-task", "released"),
+            ("f1", "put-in-memory"),
+            ("f2", "compute-task", "released"),
         ],
     )
 
 
 @gen_cluster(client=True)
 async def test_fetch_via_amm_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+    with freeze_inbound_comms(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        await f1
+        s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
 
     await f1
 
-    s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
+    assert_story(
+        b.log,
+        # FIXME: This log should be replaced with a StateMachineEvent log
+        [
+            ("f1", "ensure-task-exists", "released"),
+            ("f1", "released", "fetch", "fetch", {}),
+            ("f1", "compute-task", "fetch"),
+            ("f1", "put-in-memory"),
+        ],
+    )
+
 
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
+@pytest.mark.parametrize("as_deps", [False, True])
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
+    """
+    as_deps=True
+        0. task x is a dependency of y1 and y2
+        1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
+        2. x transitions released -> fetch
+        3. the network stack is busy, so x does not transition to flight yet.
+        4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
+        5. when x finally reaches the top of the data_needed heap, w1 will not try
+           contacting w2
+
+    as_deps=False
+        1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
+        2. x transitions released -> fetch
+        3. the network stack is busy, so x does not transition to flight yet.
+        4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
+        5. when x finally reaches the top of the data_needed heap, w1 will not try
+           contacting w2
+    """
+    x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
+        "x"
+    ]
+    with freeze_inbound_comms(w1):
+        if as_deps:
+            y1 = c.submit(inc, x, key="y1", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
 
-    await f1
+        await wait_for_state("x", "fetch", w1)
+        assert w1.tasks["x"].who_has == {w2.address, w3.address}
+
+        assert len(s.tasks["x"].who_has) == 2
+        s.handle_missing_data(
+            key="x", worker="na", errant_worker=w2.address, stimulus_id="test"
+        )
+        assert len(s.tasks["x"].who_has) == 1
+
+        if as_deps:
+            y2 = c.submit(inc, x, key="y2", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        while w1.tasks["x"].who_has != {w3.address}:
+            await asyncio.sleep(0.01)
+
+    await wait_for_state("x", "memory", w1)
+
+    assert_story(
+        w1.story("request-dep"),
+        [("request-dep", w3.address, {"x"})],
+        # This tests that there has been no attempt to contact w2.
+        # If the assumption being tested breaks, this will fail 50% of the time.
+        strict=True,
+    )
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_fetch_to_missing(c, s, a, b):
+    """
+    1. task x is a dependency of y
+    2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
+    3. x transitions released -> fetch -> flight; a connects to b
+    4. b responds it's busy. x transitions flight -> fetch
+    5. The busy state triggers an RPC call to Scheduler.who_has
+    6. the scheduler responds {"x": []}, because b in the meantime has lost the key.
+    7. x is transitioned fetch -> missing
+    """
+    x = (await c.scatter({"x": 1}, workers=[b.address]))["x"]
+    b.total_in_connections = 0
+    with freeze_inbound_comms(a):
+        y = c.submit(inc, x, key="y", workers=[a.address])
+        await wait_for_state("x", "fetch", a)
+        # Do not use handle_missing_data, since it would cause the scheduler to call
+        # handle_free_keys(["y"]) on a
+        s.remove_replica(ts=s.tasks["x"], ws=s.workers[b.address])
+    # We used a scheduler internal call, thus corrupting its state.
+    # Don't crash at the end of the test.
+    s.validate = False
+
+    await wait_for_state("x", "missing", a)
+    assert_story(
+        a.story("x"),
+        [
+            ("x", "ensure-task-exists", "released"),
+            ("x", "update-who-has", [], [b.address]),
+            ("x", "released", "fetch", "fetch", {}),
+            ("gather-dependencies", b.address, {"x"}),
+            ("x", "fetch", "flight", "flight", {}),
+            ("request-dep", b.address, {"x"}),
+            ("busy-gather", b.address, {"x"}),
+            ("x", "flight", "fetch", "fetch", {}),
+            ("x", "update-who-has", [b.address], []),  # Called Scheduler.who_has
+            ("x", "fetch", "missing", "missing", {}),
+        ],
+        strict=True,
+    )
diff --git a/distributed/worker.py b/distributed/worker.py
index acfc043760c..5e23a0e69dc 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -3433,10 +3433,11 @@ def done_event():
                 who_has = await retry_operation(
                     self.scheduler.who_has, keys=refresh_who_has
                 )
+                refresh_stimulus_id = f"refresh-who-has-{time()}"
                 recommendations, instructions = self._update_who_has(
-                    who_has, stimulus_id=stimulus_id
+                    who_has, stimulus_id=refresh_stimulus_id
                 )
-                self.transitions(recommendations, stimulus_id=stimulus_id)
+                self.transitions(recommendations, stimulus_id=refresh_stimulus_id)
                 self._handle_instructions(instructions)
 
     @log_errors