update_who_has can remove workers #6342

Merged
merged 12 commits on Jun 1, 2022
28 changes: 3 additions & 25 deletions distributed/tests/test_worker.py
@@ -46,6 +46,8 @@
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils_test import (
BlockedGatherDep,
BlockedGetData,
TaskStateMetadataPlugin,
_LockedCommPool,
assert_story,
@@ -2618,7 +2620,7 @@ def sink(a, b, *args):
if peer_addr == a.address and msg["op"] == "get_data":
break

# Provoke an "impossible transision exception"
# Provoke an "impossible transition exception"
# By choosing a state which doesn't exist we're not running into validation
# errors and the state machine should raise if we want to transition from
# fetch to memory
@@ -3137,30 +3139,6 @@ async def test_task_flight_compute_oserror(c, s, a, b):
assert_story(sum_story, expected_sum_story, strict=True)


class BlockedGatherDep(Worker):
def __init__(self, *args, **kwargs):
self.in_gather_dep = asyncio.Event()
self.block_gather_dep = asyncio.Event()
super().__init__(*args, **kwargs)

async def gather_dep(self, *args, **kwargs):
self.in_gather_dep.set()
await self.block_gather_dep.wait()
return await super().gather_dep(*args, **kwargs)


class BlockedGetData(Worker):
def __init__(self, *args, **kwargs):
self.in_get_data = asyncio.Event()
self.block_get_data = asyncio.Event()
super().__init__(*args, **kwargs)

async def get_data(self, comm, *args, **kwargs):
self.in_get_data.set()
await self.block_get_data.wait()
return await super().get_data(comm, *args, **kwargs)


@gen_cluster(client=True, nthreads=[])
async def test_gather_dep_cancelled_rescheduled(c, s):
"""At time of writing, the gather_dep implementation filtered tasks again
215 changes: 186 additions & 29 deletions distributed/tests/test_worker_state_machine.py
@@ -3,9 +3,17 @@

import pytest

from distributed import Worker, wait
from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc
from distributed.utils_test import (
BlockedGetData,
_LockedCommPool,
assert_story,
freeze_data_fetching,
gen_cluster,
inc,
)
from distributed.worker_state_machine import (
ExecuteFailureEvent,
ExecuteSuccessEvent,
@@ -16,12 +24,13 @@
SendMessageToScheduler,
StateMachineEvent,
TaskState,
TaskStateState,
UniqueTaskHeap,
merge_recs_instructions,
)


async def wait_for_state(key, state, dask_worker):
async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)

@@ -247,26 +256,17 @@ def test_executefailure_to_dict():

@gen_cluster(client=True)
async def test_fetch_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
f2 = c.submit(inc, f1, workers=[b.address], key="f2")

await wait_for_state(f1.key, "fetch", b)
await a.close()

b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold
with freeze_data_fetching(b):
f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
f2 = c.submit(inc, f1, workers=[b.address], key="f2")
await wait_for_state(f1.key, "fetch", b)
await a.close()

await f2

assert_story(
b.log,
# FIXME: This log should be replaced with an
# StateMachineEvent/Instruction log
# FIXME: This log should be replaced with a StateMachineEvent log
[
(f2.key, "compute-task", "released"),
# This is a "please fetch" request. We don't have anything like
@@ -285,23 +285,180 @@ async def test_fetch_to_compute(c, s, a, b):

@gen_cluster(client=True)
async def test_fetch_via_amm_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
with freeze_data_fetching(b):
f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
await f1
s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
await wait_for_state(f1.key, "fetch", b)
await a.close()

await f1
s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")

await wait_for_state(f1.key, "fetch", b)
await a.close()
assert_story(
b.log,
# FIXME: This log should be replaced with a StateMachineEvent log
[
(f1.key, "ensure-task-exists", "released"),
(f1.key, "released", "fetch", "fetch", {}),
(f1.key, "compute-task", "fetch"),
(f1.key, "put-in-memory"),
],
)
Comment on lines +297 to +306

Member:
Why is this log necessary now? It wasn't there before.

Collaborator Author:
To show evidence that recomputing f1 was in fact what happened, as opposed to a.close() not working as intended in the test, which would cause a successful fetch->flight->memory cycle.


b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold

await f1
@pytest.mark.parametrize("as_deps", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
"""
as_deps=True
0. task x is a dependency of y1 and y2
1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, w1 will not try
contacting w2

as_deps=False
1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, w1 will not try
contacting w2
"""
x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
"x"
]

# Make sure find_missing is not involved
w1.periodic_callbacks["find-missing"].stop()

with freeze_data_fetching(w1, jump_start=True):
if as_deps:
y1 = c.submit(inc, x, key="y1", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

await wait_for_state("x", "fetch", w1)
assert w1.tasks["x"].who_has == {w2.address, w3.address}

assert len(s.tasks["x"].who_has) == 2
await w2.close()
while len(s.tasks["x"].who_has) > 1:
await asyncio.sleep(0.01)

if as_deps:
y2 = c.submit(inc, x, key="y2", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

while w1.tasks["x"].who_has != {w3.address}:
await asyncio.sleep(0.01)

await wait_for_state("x", "memory", w1)
assert_story(
w1.story("request-dep"),
[("request-dep", w3.address, {"x"})],
# This tests that there has been no attempt to contact w2.
# If the assumption being tested breaks, this will fail 50% of the time.
strict=True,
)


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_fetch_to_missing(c, s, a, b):
"""
1. task x is a dependency of y
2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
3. x transitions released -> fetch -> flight; a connects to b
4. b responds it's busy. x transitions flight -> fetch
5. The busy state triggers an RPC call to Scheduler.who_has
6. the scheduler responds {"x": []}, because b in the meantime has lost the key.
7. x is transitioned fetch -> missing
Comment on lines +377 to +378

Member:
Why would the scheduler respond with this message? Why would the worker need to transition to missing? I think update_who_has should not do anything in this case; rather, the scheduler should eventually tell the worker to release the entire task.

If the "update who_has" were not done via a dedicated RPC, this entire scenario would not even be possible, would it?

Collaborator Author:
> the scheduler should eventually tell the worker to release the entire task.

This will happen for dependencies tracked by {op: compute-task}, but not for keys fetched through {op: acquire-replicas}.

> Why would the scheduler respond with this message?

1. Because a third worker recently sent to the scheduler {op: missing-data, errant_worker: b, key: x}.
2. Because the network between b and the scheduler is malfunctioning, so the scheduler hasn't received a heartbeat from b for 5 minutes, but for whatever reason the network between workers a and b is still functioning. This second use case is the one reproduced in the test (lines 404-407). I could have done the first one, but it would have been more tedious to implement due to time sensitivity.

Member:
> Because a third worker recently sent to the scheduler {op: missing-data, errant_worker: b, key: x}

For a dependency this should still be followed by an appropriate free-keys message, such that the key will be properly released by this worker. It feels like the AMM should keep track of the workers it asked to fetch a key, to allow for the same mechanism. I would feel more comfortable if both "types" of tasks were handled the same on the worker side.

> Because the network between b and the scheduler is malfunctioning

If the network between a worker and the scheduler is down, we're closing the worker.

"""
x = await c.scatter({"x": 1}, workers=[b.address])
b.total_in_connections = 0
# Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler
# doesn't keep track of acquire-replicas requests, so it won't proactively inform a
# when we call remove_worker later on
s.request_acquire_replicas(a.address, ["x"], stimulus_id="test")

# state will flip-flop between fetch and flight every 150ms, which is the retry
# period for busy workers.
await wait_for_state("x", "fetch", a)
assert b.address in a.busy_workers

# Sever connection between b and s, but not between b and a.
# If a tries fetching from b after this, b will keep responding {status: busy}.
b.periodic_callbacks["heartbeat"].stop()
await s.remove_worker(b.address, close=False, stimulus_id="test")
Comment on lines +392 to +395

Member:
IIUC you are preparing the scheduler state such that it believes the worker is dead. This is only partially true after this statement, because the stream is still open. This is flagged as a problem in #6390, which suggests we should close the stream as well, i.e. remove the close keyword.

This bothers me a bit because I believe that in a real-world setup where this connection is severed, the cluster would be able to self-heal, since the connection between A and B would also close. That would be a different mechanism than what this test is specifically testing.


await wait_for_state("x", "missing", a)

assert_story(
a.story("x"),
[
("x", "ensure-task-exists", "released"),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", b.address, {"x"}),
("x", "fetch", "flight", "flight", {}),
("request-dep", b.address, {"x"}),
("busy-gather", b.address, {"x"}),
("x", "flight", "fetch", "fetch", {}),
("x", "fetch", "missing", "missing", {}),
],
# There may be a round of find_missing() after this.
# Due to timings, there also may be multiple attempts to connect from a to b.
strict=False,
)


@pytest.mark.skip(reason="https://github.com/dask/distributed/issues/6446")
@gen_cluster(client=True)
async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2):
"""A task is stuck in 'fetch' state because all workers that hold a replica are in
flight. While in this state, a new replica appears on a different worker and the
scheduler informs the waiting worker through a new acquire-replicas or
compute-task op.

In real life, this will typically happen when the Active Memory Manager replicates a
key to multiple workers and some workers are much faster than others to acquire it,
due to unrelated tasks being in flight, so 2 seconds later the AMM reiterates the
request, passing a larger who_has.

Test that, when this happens, the task is immediately acquired from the new worker,
without waiting for the original replica holders to get out of flight.
"""
# Make sure find_missing is not involved
w1.periodic_callbacks["find-missing"].stop()

async with BlockedGetData(s.address) as w3:
x = c.submit(inc, 1, key="x", workers=[w3.address])
y = c.submit(inc, 2, key="y", workers=[w3.address])
await wait([x, y])
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
await w3.in_get_data.wait()
assert w1.tasks["x"].state == "flight"
s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
# This cannot progress beyond fetch because w3 is already in flight
await wait_for_state("y", "fetch", w1)

# Simulate that the AMM also requires that w2 acquires a replica of y.
# The replica lands on w2 soon afterwards, while w3->w1 comms remain blocked by
# unrelated transfers (x in our case).
w2.update_data({"y": 3}, report=True)
ws2 = s.workers[w2.address]
while ws2 not in s.tasks["y"].who_has:
await asyncio.sleep(0.01)

# 2 seconds later, the AMM reiterates that w1 should acquire a replica of y
s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
await wait_for_state("y", "memory", w1)

# Finally let the other worker get out of flight
w3.block_get_data.set()
await wait_for_state("x", "memory", w1)


@gen_cluster(client=True)
86 changes: 86 additions & 0 deletions distributed/utils_test.py
@@ -2261,3 +2261,89 @@ def wait_for_log_line(
if match in line:
return line
i += 1


class BlockedGatherDep(Worker):
"""A Worker that sets event `in_gather_dep` the first time it enters the gather_dep
method and then does not initiate any comms, thus leaving the task(s) in flight
indefinitely, until the test sets `block_gather_dep`

Example
-------
.. code-block:: python

@gen_cluster()
async def test1(s, a, b):
async with BlockedGatherDep(s.address) as x:
# [do something to cause x to fetch data from a or b]
await x.in_gather_dep.wait()
# [do something that must happen while the tasks are in flight]
x.block_gather_dep.set()
# [from this moment on, x is a regular worker]

See also
--------
BlockedGetData
"""

def __init__(self, *args, **kwargs):
self.in_gather_dep = asyncio.Event()
self.block_gather_dep = asyncio.Event()
super().__init__(*args, **kwargs)

async def gather_dep(self, *args, **kwargs):
self.in_gather_dep.set()
await self.block_gather_dep.wait()
return await super().gather_dep(*args, **kwargs)


class BlockedGetData(Worker):
"""A Worker that sets event `in_get_data` the first time it enters the get_data
method and then does not answer the comms, thus leaving the task(s) in flight
indefinitely, until the test sets `block_get_data`

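Example
-------
A minimal usage sketch (hypothetical test; assumes a second worker ``a`` that
is made to fetch the key from the blocked worker):

.. code-block:: python

    @gen_cluster(client=True, nthreads=[("", 1)])
    async def test1(c, s, a):
        async with BlockedGetData(s.address) as b:
            x = c.submit(inc, 1, key="x", workers=[b.address])
            await x
            s.request_acquire_replicas(a.address, ["x"], stimulus_id="test")
            await b.in_get_data.wait()
            # [do something that must happen while the transfer is stalled]
            b.block_get_data.set()
            # [from this moment on, b is a regular worker]
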
See also
--------
BlockedGatherDep
"""

def __init__(self, *args, **kwargs):
self.in_get_data = asyncio.Event()
self.block_get_data = asyncio.Event()
super().__init__(*args, **kwargs)

async def get_data(self, comm, *args, **kwargs):
self.in_get_data.set()
await self.block_get_data.wait()
return await super().get_data(comm, *args, **kwargs)


@contextmanager
def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
"""Prevent any task from transitioning from fetch to flight on the worker while
inside the context, simulating a situation where the worker's network comms are
saturated.

This is not the same as setting the worker to Status=paused, which would also
inform the Scheduler and prevent further tasks from being enqueued on the worker.

Parameters
----------
w: Worker
The Worker on which tasks will not transition from fetch to flight
jump_start: bool
If False, tasks will remain in fetch state after exiting the context, until
something else triggers ensure_communicating.
If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated
worker moving out of in_flight_workers.
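
Example
-------
A minimal usage sketch (hypothetical test, mirroring test_fetch_to_compute in
test_worker_state_machine.py):

.. code-block:: python

    @gen_cluster(client=True)
    async def test1(c, s, a, b):
        with freeze_data_fetching(b, jump_start=True):
            x = c.submit(inc, 1, key="x", workers=[a.address])
            y = c.submit(inc, x, key="y", workers=[b.address])
            # x reaches the fetch state on b but cannot transition to flight
            while "x" not in b.tasks or b.tasks["x"].state != "fetch":
                await asyncio.sleep(0.01)
            # [do something that must happen while x is stuck in fetch]
        # jump_start=True retriggers ensure_communicating on exit,
        # so b can now fetch x from a and compute y
        assert await y == 3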
"""
old_out_connections = w.total_out_connections
old_comm_threshold = w.comm_threshold_bytes
w.total_out_connections = 0
w.comm_threshold_bytes = 0
yield
w.total_out_connections = old_out_connections
w.comm_threshold_bytes = old_comm_threshold
if jump_start:
w.status = Status.paused
w.status = Status.running