update_who_has can remove workers (dask#6342)
crusaderky authored Jun 1, 2022
1 parent c2b28cf commit 715d7be
Showing 4 changed files with 333 additions and 90 deletions.
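As context for the tests below: the behavioral change named in the commit title is that, when the scheduler refreshes a worker's who_has, entries can now shrink (workers can be removed) as well as grow. A minimal sketch of that idea, using simplified stand-in structures rather than the real worker state machine (all names below are illustrative only, not the actual implementation):

# Hypothetical, simplified illustration -- not the actual implementation.
from dataclasses import dataclass, field

@dataclass
class TaskState:
    key: str
    state: str = "fetch"
    who_has: set[str] = field(default_factory=set)

def update_who_has(tasks: dict[str, TaskState], who_has: dict[str, list[str]]) -> None:
    # Overwrite each task's candidate replica holders. Unlike a purely
    # additive update, workers absent from the new listing are dropped,
    # and a fetch task left with no known holders becomes missing.
    for key, workers in who_has.items():
        ts = tasks.get(key)
        if ts is None:
            continue
        ts.who_has = set(workers)  # may shrink, not only grow
        if not ts.who_has and ts.state == "fetch":
            ts.state = "missing"

tasks = {"x": TaskState("x", who_has={"w2", "w3"})}
update_who_has(tasks, {"x": ["w3"]})  # w2 lost its replica
assert tasks["x"].who_has == {"w3"}
update_who_has(tasks, {"x": []})  # no known replicas left
assert tasks["x"].state == "missing"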
28 changes: 3 additions & 25 deletions distributed/tests/test_worker.py
@@ -46,6 +46,8 @@
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils_test import (
BlockedGatherDep,
BlockedGetData,
TaskStateMetadataPlugin,
_LockedCommPool,
assert_story,
@@ -2618,7 +2620,7 @@ def sink(a, b, *args):
if peer_addr == a.address and msg["op"] == "get_data":
break

# Provoke an "impossible transision exception"
# Provoke an "impossible transition exception"
# By choosing a state which doesn't exist we're not running into validation
# errors and the state machine should raise if we want to transition from
# fetch to memory
@@ -3137,30 +3139,6 @@ async def test_task_flight_compute_oserror(c, s, a, b):
assert_story(sum_story, expected_sum_story, strict=True)


class BlockedGatherDep(Worker):
def __init__(self, *args, **kwargs):
self.in_gather_dep = asyncio.Event()
self.block_gather_dep = asyncio.Event()
super().__init__(*args, **kwargs)

async def gather_dep(self, *args, **kwargs):
self.in_gather_dep.set()
await self.block_gather_dep.wait()
return await super().gather_dep(*args, **kwargs)


class BlockedGetData(Worker):
def __init__(self, *args, **kwargs):
self.in_get_data = asyncio.Event()
self.block_get_data = asyncio.Event()
super().__init__(*args, **kwargs)

async def get_data(self, comm, *args, **kwargs):
self.in_get_data.set()
await self.block_get_data.wait()
return await super().get_data(comm, *args, **kwargs)


@gen_cluster(client=True, nthreads=[])
async def test_gather_dep_cancelled_rescheduled(c, s):
"""At time of writing, the gather_dep implementation filtered tasks again
215 changes: 186 additions & 29 deletions distributed/tests/test_worker_state_machine.py
@@ -3,9 +3,17 @@

import pytest

from distributed import Worker, wait
from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc
from distributed.utils_test import (
BlockedGetData,
_LockedCommPool,
assert_story,
freeze_data_fetching,
gen_cluster,
inc,
)
from distributed.worker_state_machine import (
ExecuteFailureEvent,
ExecuteSuccessEvent,
@@ -16,12 +24,13 @@
SendMessageToScheduler,
StateMachineEvent,
TaskState,
TaskStateState,
UniqueTaskHeap,
merge_recs_instructions,
)


async def wait_for_state(key, state, dask_worker):
async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)

@@ -247,26 +256,17 @@ def test_executefailure_to_dict():

@gen_cluster(client=True)
async def test_fetch_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
f2 = c.submit(inc, f1, workers=[b.address], key="f2")

await wait_for_state(f1.key, "fetch", b)
await a.close()

b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold
with freeze_data_fetching(b):
f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
f2 = c.submit(inc, f1, workers=[b.address], key="f2")
await wait_for_state(f1.key, "fetch", b)
await a.close()

await f2

assert_story(
b.log,
# FIXME: This log should be replaced with an
# StateMachineEvent/Instruction log
# FIXME: This log should be replaced with a StateMachineEvent log
[
(f2.key, "compute-task", "released"),
# This is a "please fetch" request. We don't have anything like
@@ -285,23 +285,180 @@ async def test_fetch_to_compute(c, s, a, b):

@gen_cluster(client=True)
async def test_fetch_via_amm_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
with freeze_data_fetching(b):
f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
await f1
s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
await wait_for_state(f1.key, "fetch", b)
await a.close()

await f1
s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")

await wait_for_state(f1.key, "fetch", b)
await a.close()
assert_story(
b.log,
# FIXME: This log should be replaced with a StateMachineEvent log
[
(f1.key, "ensure-task-exists", "released"),
(f1.key, "released", "fetch", "fetch", {}),
(f1.key, "compute-task", "fetch"),
(f1.key, "put-in-memory"),
],
)

b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold

await f1
@pytest.mark.parametrize("as_deps", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
"""
as_deps=True
0. task x is a dependency of y1 and y2
1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, w1 will not try
contacting w2
as_deps=False
1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, w1 will not try
contacting w2
"""
x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
"x"
]

# Make sure find_missing is not involved
w1.periodic_callbacks["find-missing"].stop()

with freeze_data_fetching(w1, jump_start=True):
if as_deps:
y1 = c.submit(inc, x, key="y1", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

await wait_for_state("x", "fetch", w1)
assert w1.tasks["x"].who_has == {w2.address, w3.address}

assert len(s.tasks["x"].who_has) == 2
await w2.close()
while len(s.tasks["x"].who_has) > 1:
await asyncio.sleep(0.01)

if as_deps:
y2 = c.submit(inc, x, key="y2", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

while w1.tasks["x"].who_has != {w3.address}:
await asyncio.sleep(0.01)

await wait_for_state("x", "memory", w1)
assert_story(
w1.story("request-dep"),
[("request-dep", w3.address, {"x"})],
# This tests that there has been no attempt to contact w2.
# If the assumption being tested breaks, this will fail 50% of the time.
strict=True,
)


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_fetch_to_missing(c, s, a, b):
"""
1. task x is a dependency of y
2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
3. x transitions released -> fetch -> flight; a connects to b
4. b responds it's busy. x transitions flight -> fetch
5. The busy state triggers an RPC call to Scheduler.who_has
6. the scheduler responds {"x": []}, because b has in the meantime been removed from the scheduler, so no replica of the key is known.
7. x is transitioned fetch -> missing
"""
x = await c.scatter({"x": 1}, workers=[b.address])
b.total_in_connections = 0
# Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler
# doesn't keep track of acquire-replicas requests, so it won't proactively inform a
# when we call remove_worker later on
s.request_acquire_replicas(a.address, ["x"], stimulus_id="test")

# state will flip-flop between fetch and flight every 150ms, which is the retry
# period for busy workers.
await wait_for_state("x", "fetch", a)
assert b.address in a.busy_workers

# Sever connection between b and s, but not between b and a.
# If a tries fetching from b after this, b will keep responding {status: busy}.
b.periodic_callbacks["heartbeat"].stop()
await s.remove_worker(b.address, close=False, stimulus_id="test")

await wait_for_state("x", "missing", a)

assert_story(
a.story("x"),
[
("x", "ensure-task-exists", "released"),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", b.address, {"x"}),
("x", "fetch", "flight", "flight", {}),
("request-dep", b.address, {"x"}),
("busy-gather", b.address, {"x"}),
("x", "flight", "fetch", "fetch", {}),
("x", "fetch", "missing", "missing", {}),
],
# There may be a round of find_missing() after this.
# Due to timings, there also may be multiple attempts to connect from a to b.
strict=False,
)


@pytest.mark.skip(reason="https://github.com/dask/distributed/issues/6446")
@gen_cluster(client=True)
async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2):
"""A task is stuck in 'fetch' state because all workers that hold a replica are in
flight. While in this state, a new replica appears on a different worker and the
scheduler informs the waiting worker through a new acquire-replicas or
compute-task op.
In real life, this will typically happen when the Active Memory Manager replicates a
key to multiple workers and some workers are much faster than others to acquire it
(e.g. because unrelated transfers are already in flight); 2 seconds later, the AMM
reiterates the request, passing a larger who_has.
Test that, when this happens, the task is immediately acquired from the new worker,
without waiting for the original replica holders to get out of flight.
"""
# Make sure find_missing is not involved
w1.periodic_callbacks["find-missing"].stop()

async with BlockedGetData(s.address) as w3:
x = c.submit(inc, 1, key="x", workers=[w3.address])
y = c.submit(inc, 2, key="y", workers=[w3.address])
await wait([x, y])
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
await w3.in_get_data.wait()
assert w1.tasks["x"].state == "flight"
s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
# This cannot progress beyond fetch because w3 is already in flight
await wait_for_state("y", "fetch", w1)

# Simulate that the AMM also requires that w2 acquires a replica of y.
# The replica lands on w2 soon afterwards, while w3->w1 comms remain blocked by
# unrelated transfers (x in our case).
w2.update_data({"y": 3}, report=True)
ws2 = s.workers[w2.address]
while ws2 not in s.tasks["y"].who_has:
await asyncio.sleep(0.01)

# 2 seconds later, the AMM reiterates that w1 should acquire a replica of y
s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
await wait_for_state("y", "memory", w1)

# Finally let the other worker get out of flight
w3.block_get_data.set()
await wait_for_state("x", "memory", w1)


@gen_cluster(client=True)
86 changes: 86 additions & 0 deletions distributed/utils_test.py
@@ -2261,3 +2261,89 @@ def wait_for_log_line(
if match in line:
return line
i += 1


class BlockedGatherDep(Worker):
"""A Worker that sets event `in_gather_dep` the first time it enters the gather_dep
method and then does not initiate any comms, thus leaving the task(s) in flight
indefinitely, until the test sets `block_gather_dep`.
Example
-------
.. code-block:: python
@gen_test()
async def test1(s, a, b):
async with BlockedGatherDep(s.address) as x:
# [do something to cause x to fetch data from a or b]
await x.in_gather_dep.wait()
# [do something that must happen while the tasks are in flight]
x.block_gather_dep.set()
# [from this moment on, x is a regular worker]
See also
--------
BlockedGetData
"""

def __init__(self, *args, **kwargs):
self.in_gather_dep = asyncio.Event()
self.block_gather_dep = asyncio.Event()
super().__init__(*args, **kwargs)

async def gather_dep(self, *args, **kwargs):
self.in_gather_dep.set()
await self.block_gather_dep.wait()
return await super().gather_dep(*args, **kwargs)


class BlockedGetData(Worker):
"""A Worker that sets event `in_get_data` the first time it enters the get_data
method and then does not answer the comms, thus leaving the task(s) in flight
indefinitely, until the test sets `block_get_data`.
See also
--------
BlockedGatherDep
"""

def __init__(self, *args, **kwargs):
self.in_get_data = asyncio.Event()
self.block_get_data = asyncio.Event()
super().__init__(*args, **kwargs)

async def get_data(self, comm, *args, **kwargs):
self.in_get_data.set()
await self.block_get_data.wait()
return await super().get_data(comm, *args, **kwargs)
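
By symmetry with BlockedGatherDep above, a hypothetical usage sketch (the fixture signature mirrors the BlockedGatherDep docstring example; the bracketed steps are placeholders):

@gen_test()
async def test1(s, a, b):
    async with BlockedGetData(s.address) as x:
        # [do something to cause a or b to fetch data from x]
        await x.in_get_data.wait()
        # [do something that must happen while the get_data comm is hanging]
        x.block_get_data.set()
        # [from this moment on, x is a regular worker]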


@contextmanager
def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
"""Prevent any task from transitioning from fetch to flight on the worker while
inside the context, simulating a situation where the worker's network comms are
saturated.
This is not the same as setting the worker to Status=paused, which would also
inform the Scheduler and prevent further tasks from being enqueued on the worker.
Parameters
----------
w: Worker
The Worker on which tasks will not transition from fetch to flight
jump_start: bool
If False, tasks will remain in fetch state after exiting the context, until
something else triggers ensure_communicating.
If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated
worker moving out of in_flight_workers.
"""
old_out_connections = w.total_out_connections
old_comm_threshold = w.comm_threshold_bytes
w.total_out_connections = 0
w.comm_threshold_bytes = 0
yield
w.total_out_connections = old_out_connections
w.comm_threshold_bytes = old_comm_threshold
if jump_start:
w.status = Status.paused
w.status = Status.running
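
A hypothetical usage sketch of freeze_data_fetching, modeled on test_fetch_to_compute above (the gen_cluster fixtures and the wait_for_state helper from the tests are assumed):

@gen_cluster(client=True)
async def test_example(c, s, a, b):
    with freeze_data_fetching(b, jump_start=True):
        x = c.submit(inc, 1, workers=[a.address], key="x", allow_other_workers=True)
        y = c.submit(inc, x, workers=[b.address], key="y")
        # x must be transferred to b, but cannot leave the fetch state while frozen
        await wait_for_state("x", "fetch", b)
        # [manipulate cluster state here, e.g. await a.close()]
    # jump_start=True re-triggers ensure_communicating on exit
    await y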
