AMM: tentatively stabilize flaky tests around worker pause (#5735)
crusaderky authored Feb 3, 2022
1 parent 834421b commit 30ffa9c
Showing 5 changed files with 73 additions and 29 deletions.
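The change repeated across these files is to give the test workers a short memory monitor interval, so that when a test drops memory_pause_fraction to a tiny value the worker's memory monitor notices almost immediately and the scheduler-side wait loops converge quickly, instead of racing a much slower default monitor. A minimal sketch of that pattern follows; the test name is hypothetical, while the decorator arguments and the pause/wait idiom are taken directly from the diff below.

import asyncio

from distributed.core import Status
from distributed.utils_test import gen_cluster


@gen_cluster(
    client=True,
    nthreads=[("", 1)] * 2,
    worker_kwargs={"memory_monitor_interval": "20ms"},
)
async def test_pause_pattern(c, s, a, b):
    # With memory_pause_fraction this small, any nonzero process memory exceeds
    # the pause threshold on the memory monitor's next 20ms poll.
    a.memory_pause_fraction = 1e-15
    # Wait until the scheduler has observed the paused status; with a short
    # monitor interval this loop finishes quickly and the test stays deterministic.
    while s.workers[a.address].status != Status.paused:
        await asyncio.sleep(0.01)
    assert a.status == Status.paused

The same idiom appears in each hunk below, differing only in cluster size and in which scheduler feature (AMM drop/replicate, scheduling, stealing) is under test.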
76 changes: 52 additions & 24 deletions distributed/tests/test_active_memory_manager.py
@@ -403,7 +403,12 @@ async def test_drop_with_bad_candidates(c, s, a, b):
     assert s.tasks["x"].who_has == {ws0, ws1}


-@gen_cluster(client=True, nthreads=[("", 1)] * 10, config=demo_config("drop", n=1))
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 10,
+    config=demo_config("drop", n=1),
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_drop_prefers_paused_workers(c, s, *workers):
     x = await c.scatter({"x": 1}, broadcast=True)
     ts = s.tasks["x"]
@@ -420,7 +425,11 @@ async def test_drop_prefers_paused_workers(c, s, *workers):


 @pytest.mark.slow
-@gen_cluster(client=True, config=demo_config("drop"))
+@gen_cluster(
+    client=True,
+    config=demo_config("drop"),
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b):
     """If there is exactly 1 worker that holds a replica of a task that isn't paused or
     retiring, and there are 1+ paused/retiring workers with the same task, don't drop
@@ -431,21 +440,25 @@ async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b):
     b is running and has no dependent tasks
     """
     x = (await c.scatter({"x": 1}, broadcast=True))["x"]
-    y = c.submit(slowinc, x, delay=2, key="y", workers=[a.address])
+    y = c.submit(slowinc, x, delay=2.5, key="y", workers=[a.address])

     while "y" not in a.tasks or a.tasks["y"].state != "executing":
         await asyncio.sleep(0.01)
     a.memory_pause_fraction = 1e-15
     while s.workers[a.address].status != Status.paused:
         await asyncio.sleep(0.01)
     assert s.tasks["y"].state == "processing"
     assert a.tasks["y"].state == "executing"

     s.extensions["amm"].run_once()
     await y
     assert len(s.tasks["x"].who_has) == 2


-@gen_cluster(client=True, config=demo_config("drop"))
+@gen_cluster(
+    client=True,
+    config=demo_config("drop"),
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b):
     """If there is exactly 1 worker that holds a replica of a task that isn't paused or
     retiring, and there are 1+ paused/retiring workers with the same task, don't drop
@@ -470,7 +483,7 @@ async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b):
 @gen_cluster(
     client=True,
     config=demo_config("drop"),
-    worker_kwargs={"memory_monitor_interval": "50ms"},
+    worker_kwargs={"memory_monitor_interval": "20ms"},
 )
 async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause):
     """If there is exactly 1 worker that holds a replica of a task that isn't paused or
@@ -505,7 +518,12 @@ async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause


 @pytest.mark.slow
-@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=demo_config("drop"))
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 3,
+    config=demo_config("drop"),
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_drop_with_paused_workers_with_running_tasks_5(c, s, w1, w2, w3):
     """If there is exactly 1 worker that holds a replica of a task that isn't paused or
     retiring, and there are 1+ paused/retiring workers with the same task, don't drop
@@ -517,27 +535,28 @@ async def test_drop_with_paused_workers_with_running_tasks_5(c, s, w1, w2, w3):
     w3 is running and with dependent tasks executing on it
     """
     x = (await c.scatter({"x": 1}, broadcast=True))["x"]
-    y1 = c.submit(slowinc, x, delay=2, key="y1", workers=[w1.address])
-    y2 = c.submit(slowinc, x, delay=2, key="y2", workers=[w3.address])
-    while (
-        "y1" not in w1.tasks
-        or w1.tasks["y1"].state != "executing"
-        or "y2" not in w3.tasks
-        or w3.tasks["y2"].state != "executing"
-    ):
+    y1 = c.submit(slowinc, x, delay=2.5, key="y1", workers=[w1.address])
+    y2 = c.submit(slowinc, x, delay=2.5, key="y2", workers=[w3.address])
+
+    def executing() -> bool:
+        return (
+            "y1" in w1.tasks
+            and w1.tasks["y1"].state == "executing"
+            and "y2" in w3.tasks
+            and w3.tasks["y2"].state == "executing"
+        )
+
+    while not executing():
         await asyncio.sleep(0.01)
     w1.memory_pause_fraction = 1e-15
     while s.workers[w1.address].status != Status.paused:
         await asyncio.sleep(0.01)
-    assert s.tasks["y1"].state == "processing"
-    assert s.tasks["y2"].state == "processing"
-    assert w1.tasks["y1"].state == "executing"
-    assert w3.tasks["y2"].state == "executing"
+    assert executing()

     s.extensions["amm"].run_once()
-    await y1
-    await y2
-    assert {ws.address for ws in s.tasks["x"].who_has} == {w1.address, w3.address}
+    while {ws.address for ws in s.tasks["x"].who_has} != {w1.address, w3.address}:
+        await asyncio.sleep(0.01)
+    assert executing()


 @gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("replicate", n=2))
@@ -648,7 +667,12 @@ async def test_replicate_to_candidates_with_key(c, s, a, b):
     assert s.tasks["x"].who_has == {ws0}


-@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=demo_config("replicate"))
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 3,
+    config=demo_config("replicate"),
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_replicate_avoids_paused_workers_1(c, s, w0, w1, w2):
     w1.memory_pause_fraction = 1e-15
     while s.workers[w1.address].status != Status.paused:
@@ -662,7 +686,11 @@ async def test_replicate_avoids_paused_workers_1(c, s, w0, w1, w2):
     assert "x" not in w1.data


-@gen_cluster(client=True, config=demo_config("replicate"))
+@gen_cluster(
+    client=True,
+    config=demo_config("replicate"),
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_replicate_avoids_paused_workers_2(c, s, a, b):
     b.memory_pause_fraction = 1e-15
     while s.workers[b.address].status != Status.paused:
6 changes: 5 additions & 1 deletion distributed/tests/test_client.py
@@ -5865,7 +5865,11 @@ def bad_fn(x):
 @pytest.mark.parametrize("workers_arg", [False, True])
 @pytest.mark.parametrize("direct", [False, True])
 @pytest.mark.parametrize("broadcast", [False, True, 10])
-@gen_cluster(client=True, nthreads=[("", 1)] * 10)
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 10,
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_scatter_and_replicate_avoid_paused_workers(
     c, s, *workers, workers_arg, direct, broadcast
 ):
12 changes: 10 additions & 2 deletions distributed/tests/test_scheduler.py
@@ -3258,7 +3258,11 @@ async def test_set_restrictions(c, s, a, b):
     await f


-@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 3,
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_avoid_paused_workers(c, s, w1, w2, w3):
     w2.memory_pause_fraction = 1e-15
     while s.workers[w2.address].status != Status.paused:
@@ -3271,7 +3275,11 @@ async def test_avoid_paused_workers(c, s, w1, w2, w3):
     assert len(w1.data) + len(w3.data) == 8


-@gen_cluster(client=True, nthreads=[("", 1)])
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_unpause_schedules_unrannable_tasks(c, s, a):
     a.memory_pause_fraction = 1e-15
     while s.workers[a.address].status != Status.paused:
6 changes: 5 additions & 1 deletion distributed/tests/test_steal.py
@@ -838,7 +838,11 @@ async def test_steal_twice(c, s, a, b):
     await asyncio.gather(*(w.close() for w in workers))


-@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 3,
+    worker_kwargs={"memory_monitor_interval": "20ms"},
+)
 async def test_paused_workers_must_not_steal(c, s, w1, w2, w3):
     w2.memory_pause_fraction = 1e-15
     while s.workers[w2.address].status != Status.paused:
2 changes: 1 addition & 1 deletion distributed/tests/test_worker.py
@@ -1214,7 +1214,7 @@ def f(n):
     nthreads=[("127.0.0.1", 2)],
     client=True,
     worker_kwargs={
-        "memory_monitor_interval": 10,
+        "memory_monitor_interval": "20ms",
         "memory_spill_fraction": False,  # don't spill
         "memory_target_fraction": False,
         "memory_pause_fraction": 0.5,
