Unit tests
crusaderky committed May 15, 2022
1 parent 78736df commit 8c3617b
Showing 2 changed files with 153 additions and 35 deletions.
183 changes: 150 additions & 33 deletions distributed/tests/test_worker_state_machine.py
@@ -1,8 +1,11 @@
 import asyncio
+from contextlib import contextmanager
 from itertools import chain
 
 import pytest
 
+from distributed import Worker
+from distributed.core import Status
 from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
 from distributed.utils_test import assert_story, gen_cluster, inc
@@ -16,12 +19,13 @@
     SendMessageToScheduler,
     StateMachineEvent,
     TaskState,
+    TaskStateState,
     UniqueTaskHeap,
     merge_recs_instructions,
 )
 
 
-async def wait_for_state(key, state, dask_worker):
+async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
     while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
         await asyncio.sleep(0.005)
 
@@ -245,60 +249,173 @@ def test_executefailure_to_dict():
     assert ev3.traceback_text == "tb text"
 
 
-@gen_cluster(client=True)
-async def test_fetch_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
+@contextmanager
+def freeze_inbound_comms(w: Worker):
+    """Prevent any task from transitioning from fetch to flight on the worker while
+    inside the context.
 
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
-    f2 = c.submit(inc, f1, workers=[b.address], key="f2")
+    This is not the same as setting the worker to Status=paused, which would also
+    inform the Scheduler and prevent further tasks from being enqueued on the worker.
+    """
+    old_out_connections = w.total_out_connections
+    old_comm_threshold = w.comm_threshold_bytes
+    w.total_out_connections = 0
+    w.comm_threshold_bytes = 0
+    yield
+    w.total_out_connections = old_out_connections
+    w.comm_threshold_bytes = old_comm_threshold
+    # Jump-start ensure_communicating
+    w.status = Status.paused
+    w.status = Status.running
 
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
 
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
+@gen_cluster(client=True)
+async def test_fetch_to_compute(c, s, a, b):
+    with freeze_inbound_comms(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        f2 = c.submit(inc, f1, workers=[b.address], key="f2")
+        await wait_for_state("f1", "fetch", b)
+        await a.close()
 
     await f2
 
     assert_story(
         b.log,
-        # FIXME: This log should be replaced with an
-        # StateMachineEvent/Instruction log
+        # FIXME: This log should be replaced with a StateMachineEvent log
        [
-            (f2.key, "compute-task", "released"),
+            ("f2", "compute-task", "released"),
             # This is a "please fetch" request. We don't have anything like
             # this, yet. We don't see the request-dep signal in here because we
             # do not wait for the key to be actually scheduled
-            (f1.key, "ensure-task-exists", "released"),
+            ("f1", "ensure-task-exists", "released"),
             # After the worker failed, we're instructed to forget f2 before
             # something new comes in
-            ("free-keys", (f2.key,)),
-            (f1.key, "compute-task", "released"),
-            (f1.key, "put-in-memory"),
-            (f2.key, "compute-task", "released"),
+            ("free-keys", ("f2",)),
+            ("f1", "compute-task", "released"),
+            ("f1", "put-in-memory"),
+            ("f2", "compute-task", "released"),
         ],
     )
 
 
 @gen_cluster(client=True)
 async def test_fetch_via_amm_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
-
-    await f1
-    s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
-
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
-
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
-
-    await f1
+    with freeze_inbound_comms(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        await f1
+        s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
+
+    assert_story(
+        b.log,
+        # FIXME: This log should be replaced with a StateMachineEvent log
+        [
+            ("f1", "ensure-task-exists", "released"),
+            ("f1", "released", "fetch", "fetch", {}),
+            ("f1", "compute-task", "fetch"),
+            ("f1", "put-in-memory"),
+        ],
+    )
+
+
+@pytest.mark.parametrize("as_deps", [False, True])
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
+    """
+    as_deps=True
+        0. task x is a dependency of y1 and y2
+        1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
+        2. x transitions released -> fetch
+        3. the network stack is busy, so x does not transition to flight yet.
+        4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
+        5. when x finally reaches the top of the data_needed heap, w1 will not try
+           contacting w2
+
+    as_deps=False
+        1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
+        2. x transitions released -> fetch
+        3. the network stack is busy, so x does not transition to flight yet.
+        4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
+        5. when x finally reaches the top of the data_needed heap, w1 will not try
+           contacting w2
+    """
+    x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
+        "x"
+    ]
+    with freeze_inbound_comms(w1):
+        if as_deps:
+            y1 = c.submit(inc, x, key="y1", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        await wait_for_state("x", "fetch", w1)
+        assert w1.tasks["x"].who_has == {w2.address, w3.address}
+
+        assert len(s.tasks["x"].who_has) == 2
+        s.handle_missing_data(
+            key="x", worker="na", errant_worker=w2.address, stimulus_id="test"
+        )
+        assert len(s.tasks["x"].who_has) == 1
+
+        if as_deps:
+            y2 = c.submit(inc, x, key="y2", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        while w1.tasks["x"].who_has != {w3.address}:
+            await asyncio.sleep(0.01)
+
+    await wait_for_state("x", "memory", w1)
+
+    assert_story(
+        w1.story("request-dep"),
+        [("request-dep", w3.address, {"x"})],
+        # This tests that there has been no attempt to contact w2.
+        # If the assumption being tested breaks, this will fail 50% of the time.
+        strict=True,
+    )
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_fetch_to_missing(c, s, a, b):
+    """
+    1. task x is a dependency of y
+    2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
+    3. x transitions released -> fetch -> flight; a connects to b
+    4. b responds it's busy. x transitions flight -> fetch
+    5. the busy state triggers an RPC call to Scheduler.who_has
+    6. the scheduler responds {"x": []}, because b has lost the key in the meantime
+    7. x is transitioned fetch -> missing
+    """
+    x = (await c.scatter({"x": 1}, workers=[b.address]))["x"]
+    b.total_in_connections = 0
+    with freeze_inbound_comms(a):
+        y = c.submit(inc, x, key="y", workers=[a.address])
+        await wait_for_state("x", "fetch", a)
+        # Do not use handle_missing_data, since it would cause the scheduler to call
+        # handle_free_keys(["y"]) on a
+        s.remove_replica(ts=s.tasks["x"], ws=s.workers[b.address])
+        # We used a scheduler internal call, thus corrupting its state.
+        # Don't crash at the end of the test.
+        s.validate = False
+
+    await wait_for_state("x", "missing", a)
+    assert_story(
+        a.story("x"),
+        [
+            ("x", "ensure-task-exists", "released"),
+            ("x", "update-who-has", [], [b.address]),
+            ("x", "released", "fetch", "fetch", {}),
+            ("gather-dependencies", b.address, {"x"}),
+            ("x", "fetch", "flight", "flight", {}),
+            ("request-dep", b.address, {"x"}),
+            ("busy-gather", b.address, {"x"}),
+            ("x", "flight", "fetch", "fetch", {}),
+            ("x", "update-who-has", [b.address], []),  # Called Scheduler.who_has
+            ("x", "fetch", "missing", "missing", {}),
+        ],
+        strict=True,
+    )
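
As an aside, here is a minimal sketch (not part of this commit) of how the freeze_inbound_comms helper added above could be used in a new test. It assumes the code lives in the same test module as the helper and wait_for_state, with gen_cluster and inc coming from distributed.utils_test; the test name is made up for illustration.

from distributed.utils_test import gen_cluster, inc


@gen_cluster(client=True)
async def test_hold_dependency_in_fetch(c, s, a, b):
    # "x" is computed on worker a; "y" is pinned to worker b, so b must fetch "x".
    with freeze_inbound_comms(b):
        x = c.submit(inc, 1, key="x", workers=[a.address])
        y = c.submit(inc, x, key="y", workers=[b.address])
        # While b's gather connections are frozen, "x" stays in the fetch state.
        await wait_for_state("x", "fetch", b)
    # On exit the limits are restored and the worker status is toggled, which
    # jump-starts ensure_communicating, so the transfer and computation finish.
    assert await y == 3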
5 changes: 3 additions & 2 deletions distributed/worker.py
@@ -3433,10 +3433,11 @@ def done_event():
             who_has = await retry_operation(
                 self.scheduler.who_has, keys=refresh_who_has
             )
+            refresh_stimulus_id = f"refresh-who-has-{time()}"
             recommendations, instructions = self._update_who_has(
-                who_has, stimulus_id=stimulus_id
+                who_has, stimulus_id=refresh_stimulus_id
             )
-            self.transitions(recommendations, stimulus_id=stimulus_id)
+            self.transitions(recommendations, stimulus_id=refresh_stimulus_id)
             self._handle_instructions(instructions)
 
     @log_errors
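
For context on the hunk above: the worker state machine stamps every batch of transitions with a stimulus_id string so that entries in the worker's story can be traced back to the event that triggered them. The refreshed id follows the usual prefix-plus-timestamp pattern; below is a rough, illustrative sketch, where the helper name is hypothetical and not part of the diff.

from time import time


def fresh_stimulus_id(prefix: str) -> str:
    # A unique id per scheduler round-trip lets the transitions caused by the
    # who_has refresh be told apart from those of the original stimulus.
    return f"{prefix}-{time()}"


# e.g. "refresh-who-has-1652612345.678901"
print(fresh_stimulus_id("refresh-who-has"))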
