slow_terminate
crusaderky committed May 6, 2022
1 parent 70e5c90 commit a907175
Showing 3 changed files with 112 additions and 18 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yaml
@@ -119,8 +119,8 @@ jobs:
           set -o pipefail
           mkdir reports

-          pytest distributed \
-            -m "not avoid_ci and ${{ matrix.partition }}" --runslow \
+          pytest distributed/tests/test_worker_memory.py \
+            -m "not avoid_ci" --runslow \
             --leaks=fds,processes,threads \
             --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \
             --cov=distributed --cov-report=xml \
110 changes: 96 additions & 14 deletions distributed/tests/test_worker_memory.py
@@ -1,7 +1,11 @@
 from __future__ import annotations

 import asyncio
+import glob
 import logging
+import os
+import signal
+import sys
 import threading
 from collections import Counter, UserDict
 from time import sleep
@@ -11,7 +15,7 @@
 import dask.config

 import distributed.system
-from distributed import Client, Event, Nanny, Worker, wait
+from distributed import Client, Event, KilledWorker, Nanny, Scheduler, Worker, wait
 from distributed.compatibility import MACOS
 from distributed.core import Status
 from distributed.metrics import monotonic
@@ -680,6 +684,44 @@ async def test_manual_evict_proto(c, s, a):
         await asyncio.sleep(0.01)


+async def leak_until_restart(c: Client, s: Scheduler, a: Nanny) -> None:
+    s.allowed_failures = 0
+
+    def leak():
+        L = []
+        while True:
+            L.append(b"0" * 5_000_000)
+            sleep(0.01)
+
+    assert a.process
+    assert a.process.process
+    pid = a.process.pid
+    addr = a.worker_address
+    with captured_logger(logging.getLogger("distributed.worker_memory")) as logger:
+        future = c.submit(leak, key="leak")
+        while (
+            not a.process
+            or not a.process.process
+            or a.process.pid == pid
+            or a.worker_address == addr
+        ):
+            await asyncio.sleep(0.01)
+
+    # Test that the restarting message happened only once;
+    # see test_slow_terminate below.
+    assert logger.getvalue() == (
+        f"Worker {addr} (pid={pid}) exceeded 95% memory budget. Restarting...\n"
+    )
+
+    with pytest.raises(KilledWorker):
+        await future
+    assert s.tasks["leak"].suspicious == 1
+    assert await c.run(lambda dask_worker: "leak" in dask_worker.tasks) == {a.worker_address: False}
+    future.release()
+    while "leak" in s.tasks:
+        await asyncio.sleep(0.01)
+
+
 @pytest.mark.slow
 @gen_cluster(
     nthreads=[("", 1)],
@@ -689,21 +731,61 @@ async def test_manual_evict_proto(c, s, a):
     config={"distributed.worker.memory.monitor-interval": "10ms"},
 )
 async def test_nanny_terminate(c, s, a):
-    def leak():
-        L = []
-        while True:
-            L.append(b"0" * 5_000_000)
-            sleep(0.01)
-
-    before = a.process.pid
-    with captured_logger(logging.getLogger("distributed.worker_memory")) as logger:
-        future = c.submit(leak)
-        while a.process.pid == before:
-            await asyncio.sleep(0.01)
-
-    out = logger.getvalue()
-    assert "restart" in out.lower()
-    assert "memory" in out.lower()
+    await leak_until_restart(c, s, a)
+
+
+@pytest.mark.slow
+@gen_cluster(
+    nthreads=[("", 1)],
+    client=True,
+    Worker=Nanny,
+    worker_kwargs={"memory_limit": "400 MiB"},
+    config={"distributed.worker.memory.monitor-interval": "10ms"},
+)
+async def test_disk_cleanup_on_terminate(c, s, a):
+    """Test that the spilled data on disk is cleaned up when the nanny kills the worker"""
+    fut = c.submit(inc, 1, key="myspill")
+    await wait(fut)
+    await c.run(lambda dask_worker: dask_worker.data.evict())
+
+    glob_out = await c.run(
+        lambda dask_worker: glob.glob(dask_worker.local_directory + "/**/myspill")
+    )
+    spill_file = glob_out[a.worker_address][0]
+    assert os.path.exists(spill_file)
+    await leak_until_restart(c, s, a)
+    assert not os.path.exists(spill_file)
+
+
+@pytest.mark.slow
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    nthreads=[("", 1)],
+    worker_kwargs={"memory_limit": "400 MiB"},
+    config={"distributed.worker.memory.monitor-interval": "10ms"},
+)
+async def test_slow_terminate(c, s, a):
+    """A worker is slow to accept SIGTERM, e.g. because the
+    distributed.diskutils.WorkDir teardown is deleting tens of GB worth of spilled data.
+    """
+
+    def install_slow_sigterm_handler():
+        def cb(signo, frame):
+            # If something sends SIGTERM while the previous SIGTERM handler is running,
+            # you will eventually get RecursionError.
+            print(f"Received signal {signo}")
+            sleep(0.2)  # Longer than monitor-interval
+            print("Leaving handler")
+            sys.exit(0)
+
+        signal.signal(signal.SIGTERM, cb)
+
+    await c.run(install_slow_sigterm_handler)
+    # Test that SIGTERM is only sent once
+    await leak_until_restart(c, s, a)
+    # Test that SIGTERM can be sent again after the worker restarts
+    await leak_until_restart(c, s, a)


 @gen_cluster(
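A note on the failure mode that test_slow_terminate above exercises: when a further SIGTERM arrives while a Python SIGTERM handler is still running, the handler is re-entered on top of the current frame, so a monitor that kept signalling every 10ms would stack nested handler frames until RecursionError. A minimal, standalone sketch of that behaviour outside of dask (POSIX-only, illustrative names, not part of this commit):

import os
import signal
import sys
import time


def slow_handler(signo, frame):
    # Stands in for a worker whose SIGTERM teardown is slow, e.g. because it is
    # deleting a large spill directory. Another SIGTERM delivered during the
    # sleep re-enters this handler on top of the current frame.
    print(f"child: received signal {signo}", flush=True)
    time.sleep(0.2)
    print("child: leaving handler", flush=True)
    sys.exit(0)


if __name__ == "__main__":
    pid = os.fork()
    if pid == 0:
        # Child: behaves like the worker in test_slow_terminate
        signal.signal(signal.SIGTERM, slow_handler)
        while True:
            time.sleep(0.01)
    else:
        # Parent: behaves like an un-debounced memory monitor that keeps firing
        # while the worker is over budget; each extra SIGTERM nests the handler.
        time.sleep(0.1)  # give the child time to install its handler
        for _ in range(3):
            os.kill(pid, signal.SIGTERM)
            time.sleep(0.05)
        os.waitpid(pid, 0)

Running this prints "received signal 15" three times before a single "leaving handler", which is exactly the nesting that the _last_terminated_pid guard below avoids.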
16 changes: 14 additions & 2 deletions distributed/worker_memory.py
@@ -305,6 +305,7 @@ class NannyMemoryManager:
     memory_limit: int | None
     memory_terminate_fraction: float | Literal[False]
     memory_monitor_interval: float | None
+    _last_terminated_pid: int

     def __init__(
         self,
@@ -321,6 +322,8 @@ def __init__(
             default=False,
         )
         assert isinstance(self.memory_monitor_interval, (int, float))
+        self._last_terminated_pid = -1
+
         if self.memory_limit and self.memory_terminate_fraction is not False:
            pc = PeriodicCallback(
                partial(self.memory_monitor, nanny),
@@ -341,11 +344,20 @@ def memory_monitor(self, nanny: Nanny) -> None:
         except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
             return  # pragma: nocover

+        if proc.pid == self._last_terminated_pid:
+            # We already sent SIGTERM to the worker, but its handler is still running
+            # since the previous iteration of the memory_monitor - for example, it
+            # may be taking a long time deleting all the spilled data from disk.
+            return
+        self._last_terminated_pid = -1
+
         if memory / self.memory_limit > self.memory_terminate_fraction:
             logger.warning(
-                "Worker exceeded %d%% memory budget. Restarting",
-                100 * self.memory_terminate_fraction,
+                f"Worker {nanny.worker_address} (pid={process.pid}) exceeded "
+                f"{self.memory_terminate_fraction * 100:.0f}% memory budget. "
+                "Restarting...",
             )
+            self._last_terminated_pid = proc.pid
             process.terminate()

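The guard added to memory_monitor above amounts to "terminate each worker process at most once per pid". A condensed sketch of that debounce in isolation (the class and method names here are illustrative, not the NannyMemoryManager API):

class TerminateOnce:
    """Illustrative debounce: request termination at most once per pid.

    A periodic monitor may tick many more times while the process it already
    signalled is still busy in its SIGTERM handler; those ticks must not
    re-signal the same pid, while a new pid (a restarted worker) starts fresh.
    """

    def __init__(self, terminate_fraction: float = 0.95) -> None:
        self.terminate_fraction = terminate_fraction
        self._last_terminated_pid = -1

    def maybe_terminate(self, pid: int, memory: int, limit: int, terminate) -> bool:
        if pid == self._last_terminated_pid:
            # Already signalled this process; its SIGTERM handler may still be running.
            return False
        # A different pid means the worker was restarted; reset the latch.
        self._last_terminated_pid = -1
        if memory / limit > self.terminate_fraction:
            self._last_terminated_pid = pid
            terminate()
            return True
        return False

Called once per monitor interval with the current pid and memory reading, this fires terminate() on the first over-budget tick and then stays quiet until the pid changes.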
