Overhaul update_who_has (dask#6342)
crusaderky committed May 20, 2022
1 parent 33fc50c commit 4e9385b
Showing 6 changed files with 312 additions and 72 deletions.
17 changes: 8 additions & 9 deletions distributed/scheduler.py
@@ -4184,7 +4184,9 @@ def stimulus_retry(self, keys, client=None):
return tuple(seen)

@log_errors
async def remove_worker(self, address, stimulus_id, safe=False, close=True):
async def remove_worker(
self, address: str, *, stimulus_id: str, safe: bool = False, close: bool = True
) -> Literal["OK", "already-removed"]:
"""
Remove worker from cluster
@@ -4193,7 +4195,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):
state.
"""
if self.status == Status.closed:
return
return "already-removed"

address = self.coerce_address(address)
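For orientation, a minimal call sketch against the new keyword-only signature; the scheduler object and address below are illustrative stand-ins, not part of this diff:

    # stimulus_id is now keyword-only, and the coroutine reports whether a worker
    # was actually removed or the scheduler was already closed.
    result = await scheduler.remove_worker(
        "tcp://127.0.0.1:40000", stimulus_id="example-stimulus", close=False
    )
    assert result in ("OK", "already-removed")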

@@ -7120,23 +7122,20 @@ def adaptive_target(self, target_duration=None):
to_close = self.workers_to_close()
return len(self.workers) - len(to_close)

def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
def request_acquire_replicas(
self, addr: str, keys: Iterable[str], *, stimulus_id: str
):
"""Asynchronously ask a worker to acquire a replica of the listed keys from
other workers. This is a fire-and-forget operation which offers no feedback for
success or failure, and is intended for housekeeping and not for computation.
"""
who_has = {}
for key in keys:
ts = self.tasks[key]
who_has[key] = {ws.address for ws in ts.who_has}

who_has = {key: {ws.address for ws in self.tasks[key].who_has} for key in keys}
if self.validate:
assert all(who_has.values())

self.stream_comms[addr].send(
{
"op": "acquire-replicas",
"keys": keys,
"who_has": who_has,
"stimulus_id": stimulus_id,
},
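As a side note on the hunk above, a standalone sketch showing that the new dict comprehension builds the same who_has mapping as the removed loop; the namedtuples stand in for the scheduler's WorkerState/TaskState and are illustrative only:

    from collections import namedtuple

    WS = namedtuple("WS", "address")   # stand-in for WorkerState
    TS = namedtuple("TS", "who_has")   # stand-in for TaskState
    tasks = {"x": TS({WS("tcp://w2"), WS("tcp://w3")}), "y": TS({WS("tcp://w3")})}
    keys = ["x", "y"]

    # Removed form: explicit loop
    who_has_loop = {}
    for key in keys:
        ts = tasks[key]
        who_has_loop[key] = {ws.address for ws in ts.who_has}

    # New form: equivalent dict comprehension
    who_has = {key: {ws.address for ws in tasks[key].who_has} for key in keys}
    assert who_has == who_has_loop == {"x": {"tcp://w2", "tcp://w3"}, "y": {"tcp://w3"}}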
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
@@ -300,7 +300,7 @@ async def get_data(self, comm, *args, **kwargs):
s.set_restrictions({fut1.key: [a.address, b.address]})
# It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled
# to be recomputed on B
await s.remove_worker(a.address, "foo", close=False, safe=True)
await s.remove_worker(a.address, stimulus_id="foo", close=False, safe=True)

while not b.tasks[fut1.key].state == "resumed":
await asyncio.sleep(0.01)
1 change: 1 addition & 0 deletions distributed/tests/test_stories.py
@@ -156,6 +156,7 @@ async def test_worker_story_with_deps(c, s, a, b):
assert stimulus_ids == {"compute-task"}
expected = [
("dep", "ensure-task-exists", "released"),
("dep", "update-who-has", [], [a.address]),
("dep", "released", "fetch", "fetch", {}),
("gather-dependencies", a.address, {"dep"}),
("dep", "fetch", "flight", "flight", {}),
1 change: 1 addition & 0 deletions distributed/tests/test_worker.py
@@ -2814,6 +2814,7 @@ def __getstate__(self):
story[b],
[
("x", "ensure-task-exists", "released"),
("x", "update-who-has", [], [a]),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", a, {"x"}),
("x", "fetch", "flight", "flight", {}),
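Reading the two expected-story updates above together, the new entry appears to take the shape (key, "update-who-has", old_replicas, new_replicas); a minimal sketch with placeholder values, inferred from these tests rather than from separate documentation:

    # The worker first learns that "x" can be fetched from one peer, so the
    # previously-known replica list is empty and the new list has one address.
    entry = ("x", "update-who-has", [], ["tcp://127.0.0.1:1234"])
    key, op, old_who_has, new_who_has = entry
    assert op == "update-who-has" and old_who_has == [] and new_who_has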
225 changes: 198 additions & 27 deletions distributed/tests/test_worker_state_machine.py
@@ -1,8 +1,11 @@
import asyncio
from contextlib import contextmanager
from itertools import chain

import pytest

from distributed import Worker
from distributed.core import Status
from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import assert_story, gen_cluster, inc
@@ -16,12 +19,13 @@
SendMessageToScheduler,
StateMachineEvent,
TaskState,
TaskStateState,
UniqueTaskHeap,
merge_recs_instructions,
)


async def wait_for_state(key, state, dask_worker):
async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)

@@ -245,28 +249,39 @@ def test_executefailure_to_dict():
assert ev3.traceback_text == "tb text"


@gen_cluster(client=True)
async def test_fetch_to_compute(c, s, a, b):
    # Block ensure_communicating to ensure we indeed know that the task is in
    # fetch and doesn't leave it accidentally
    old_out_connections, b.total_out_connections = b.total_out_connections, 0
    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
    f2 = c.submit(inc, f1, workers=[b.address], key="f2")

    await wait_for_state(f1.key, "fetch", b)
    await a.close()

    b.total_out_connections = old_out_connections
    b.comm_threshold_bytes = old_comm_threshold
@contextmanager
def freeze_data_fetching(w: Worker):
    """Prevent any task from transitioning from fetch to flight on the worker while
    inside the context.

    This is not the same as setting the worker to Status=paused, which would also
    inform the Scheduler and prevent further tasks to be enqueued on the worker.
    """
    old_out_connections = w.total_out_connections
    old_comm_threshold = w.comm_threshold_bytes
    w.total_out_connections = 0
    w.comm_threshold_bytes = 0
    yield
    w.total_out_connections = old_out_connections
    w.comm_threshold_bytes = old_comm_threshold
    # Jump-start ensure_communicating
    w.status = Status.paused
    w.status = Status.running


@gen_cluster(client=True)
async def test_fetch_to_compute(c, s, a, b):
    with freeze_data_fetching(b):
        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
        f2 = c.submit(inc, f1, workers=[b.address], key="f2")
        await wait_for_state(f1.key, "fetch", b)
        await a.close()

    await f2

    assert_story(
        b.log,
        # FIXME: This log should be replaced with an
        # StateMachineEvent/Instruction log
        # FIXME: This log should be replaced with a StateMachineEvent log
[
(f2.key, "compute-task", "released"),
# This is a "please fetch" request. We don't have anything like
@@ -285,20 +300,176 @@ async def test_fetch_to_compute(c, s, a, b):

@gen_cluster(client=True)
async def test_fetch_via_amm_to_compute(c, s, a, b):
    # Block ensure_communicating to ensure we indeed know that the task is in
    # fetch and doesn't leave it accidentally
    old_out_connections, b.total_out_connections = b.total_out_connections, 0
    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)

    await f1
    s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")

    await wait_for_state(f1.key, "fetch", b)
    await a.close()

    b.total_out_connections = old_out_connections
    b.comm_threshold_bytes = old_comm_threshold
    with freeze_data_fetching(b):
        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
        await f1
        s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
        await wait_for_state(f1.key, "fetch", b)
        await a.close()

    await f1

    assert_story(
        b.log,
        # FIXME: This log should be replaced with a StateMachineEvent log
        [
            (f1.key, "ensure-task-exists", "released"),
            (f1.key, "released", "fetch", "fetch", {}),
            (f1.key, "compute-task", "fetch"),
            (f1.key, "put-in-memory"),
        ],
    )

@pytest.mark.parametrize("as_deps", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
"""
as_deps=True
0. task x is a dependency of y1 and y2
1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, w1 will not try
contacting w2
as_deps=False
1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, w1 will not try
contacting w2
"""
x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
"x"
]
with freeze_data_fetching(w1):
if as_deps:
y1 = c.submit(inc, x, key="y1", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

await wait_for_state("x", "fetch", w1)
assert w1.tasks["x"].who_has == {w2.address, w3.address}

assert len(s.tasks["x"].who_has) == 2
await w2.close()
while len(s.tasks["x"].who_has) > 1:
await asyncio.sleep(0.01)

if as_deps:
y2 = c.submit(inc, x, key="y2", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

while w1.tasks["x"].who_has != {w3.address}:
await asyncio.sleep(0.01)

await wait_for_state("x", "memory", w1)

assert_story(
w1.story("request-dep"),
[("request-dep", w3.address, {"x"})],
# This tests that there has been no attempt to contact w2.
# If the assumption being tested breaks, this will fail 50% of the times.
strict=True,
)


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_fetch_to_missing(c, s, a, b):
"""
1. task x is a dependency of y
2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
3. x transitions released -> fetch -> flight; a connects to b
4. b responds it's busy. x transitions flight -> fetch
5. The busy state triggers an RPC call to Scheduler.who_has
6. the scheduler responds {"x": []}, because b in the meantime has lost the key.
7. x is transitioned fetch -> missing
"""
x = await c.scatter({"x": 1}, workers=[b.address])
b.total_in_connections = 0
# Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler
# doesn't keep track of acquire-replicas requests, so it won't proactively inform a
# when we call remove_worker later on
s.request_acquire_replicas(a.address, ["x"], stimulus_id="test")

# state will flip-flop between fetch and flight every 150ms, which is the retry
# period for busy workers.
await wait_for_state("x", "fetch", a)
assert b.address in a.busy_workers

# Sever connection between b and s, but not between b and a.
# If a tries fetching from b after this, b will keep responding {status: busy}.
b.periodic_callbacks["heartbeat"].stop()
await s.remove_worker(b.address, close=False, stimulus_id="test")

await wait_for_state("x", "missing", a)

assert_story(
a.story("x"),
[
("x", "ensure-task-exists", "released"),
("x", "update-who-has", [], [b.address]),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", b.address, {"x"}),
("x", "fetch", "flight", "flight", {}),
("request-dep", b.address, {"x"}),
("busy-gather", b.address, {"x"}),
("x", "flight", "fetch", "fetch", {}),
("x", "update-who-has", [b.address], []), # Called Scheduler.who_has
("x", "fetch", "missing", "missing", {}),
],
# There may be a round of find_missing() after this.
# Due to timings, there also may be multiple attempts to connect from a to b.
strict=False,
)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_self_denounce_missing_data(c, s, a):
x = c.submit(inc, 1, key="x")
await x

# Wipe x from the worker. This simulates the following condition:
# 1. The scheduler thinks a and b hold a replica of x.
# 2. b is unresponsive, but the scheduler doesn't know yet as it didn't time out.
# 3. The AMM decides to ask a to drop its replica.
a.handle_remove_replicas(keys=["x"], stimulus_id="test")

# Lose the message that would inform the scheduler.
# In theory, this should not happen.
# In practice, in case of a TCP connection fault, BatchedSend can drop messages.
assert a.batched_stream.buffer == [
{"key": "x", "stimulus_id": "test", "op": "release-worker-data"}
]
a.batched_stream.buffer.clear()

y = c.submit(inc, x, key="y")
# The scheduler tries computing y, but a responds that x is not available.
# The scheduler kicks off the computation of x and then y from scratch.
assert await y == 3

assert_story(
a.story("compute-task"),
[
("x", "compute-task", "released"),
# The scheduler tries computing y a first time and fails.
# This line would not be here if we didn't lose the
# {"op": "release-worker-data"} message earlier.
("y", "compute-task", "released"),
# The scheduler receives the {"op": "missing-data"} message from the
# worker. This makes the computation of y fail. The scheduler reschedules
# x and then y.
("x", "compute-task", "released"),
("y", "compute-task", "released"),
],
strict=True,
)

del x
while "x" in a.data:
await asyncio.sleep(0.01)
assert a.tasks["x"].state == "released"