
Close workers gracefully #2892

Merged
merged 13 commits on Jul 31, 2019
26 changes: 25 additions & 1 deletion distributed/cli/dask_worker.py
@@ -175,6 +175,30 @@
@click.option(
"--dashboard-prefix", type=str, default="", help="Prefix for the dashboard"
)
@click.option(
"--lifetime",
type=str,
default="",
help="If provided, shut down the worker after this duration. "
"Note that if combined with --nanny (default) "
"the worker will restart after this time",
)
@click.option(
"--lifetime-stagger",
type=str,
default="0 seconds",
show_default=True,
help="Random amount by which to stagger lifetime values",
)
@click.option(
"--lifetime-restart/--no-lifetime-restart",
"lifetime_restart",
default=False,
show_default=True,
required=False,
help="Whether or not to restart the worker after the lifetime lapses. "
"This assumes that you are using the --lifetime and --nanny keywords",
)
@click.option(
"--preload",
type=str,
@@ -347,7 +371,7 @@ def del_pid_file():
dashboard_address=dashboard_address if dashboard else None,
service_kwargs={"dashboard": {"prefix": dashboard_prefix}},
name=name if nprocs == 1 or not name else name + "-" + str(i),
**kwargs,
**kwargs
)
for i in range(nprocs)
]
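Taken together with the existing --nanny flag, the new options let a worker be launched with a bounded, staggered lifetime straight from the command line; an illustrative invocation (scheduler address and durations are hypothetical, not from the PR) would be `dask-worker tcp://scheduler:8786 --lifetime "1 hour" --lifetime-stagger "5 minutes" --lifetime-restart`.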
1 change: 1 addition & 0 deletions distributed/diagnostics/__init__.py
@@ -2,6 +2,7 @@

from ..utils import ignoring
from .graph_layout import GraphLayout
from .plugin import SchedulerPlugin

with ignoring(ImportError):
from .progressbar import progress
5 changes: 4 additions & 1 deletion distributed/distributed.yaml
@@ -33,7 +33,6 @@ distributed:
key: null
cert: null


worker:
blocked-handlers: []
multiprocessing-method: forkserver
@@ -44,6 +43,10 @@ distributed:
preload: []
preload-argv: []
daemon: True
lifetime:
duration: null # Time after which to gracefully shutdown the worker
stagger: 0 seconds # Random amount by which to stagger lifetimes
restart: False # Do we resurrect the worker after the lifetime deadline?

profile:
interval: 10ms # Time between statistical profiling queries
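The same behaviour can be driven through Dask's configuration system instead of CLI flags. A minimal sketch, with illustrative values, that sets the new keys before a worker is created:

```python
import dask

# Illustrative values; the keys mirror the distributed.yaml entries above.
dask.config.set({
    "distributed.worker.lifetime.duration": "1 hour",
    "distributed.worker.lifetime.stagger": "5 minutes",
    "distributed.worker.lifetime.restart": True,
})

# Worker.__init__ falls back to these config values whenever the lifetime,
# lifetime_stagger, or lifetime_restart keywords are not passed explicitly
# (see the worker.py changes further down).
```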
25 changes: 25 additions & 0 deletions distributed/tests/test_nanny.py
@@ -13,9 +13,11 @@
from toolz import valmap, first
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Event

import dask
from distributed import Nanny, rpc, Scheduler, Worker
from distributed.diagnostics import SchedulerPlugin
from distributed.core import CommClosedError
from distributed.metrics import time
from distributed.protocol.pickle import dumps
@@ -398,3 +400,26 @@ async def test_nanny_closes_cleanly(cleanup):
assert not n.process
assert not proc.is_alive()
assert proc.exitcode == 0


@pytest.mark.slow
@pytest.mark.asyncio
async def test_lifetime(cleanup):
counter = 0
event = Event()

class Plugin(SchedulerPlugin):
def add_worker(self, **kwargs):
pass

def remove_worker(self, **kwargs):
nonlocal counter
counter += 1
if counter == 2: # wait twice, then trigger closing event
event.set()

async with Scheduler() as s:
s.add_plugin(Plugin())
async with Nanny(s.address) as a:
async with Nanny(s.address, lifetime="500 ms", lifetime_restart=True) as b:
await event.wait()
39 changes: 39 additions & 0 deletions distributed/tests/test_worker.py
@@ -1498,3 +1498,42 @@ async def test_worker_listens_on_same_interface_by_default(Worker):
assert s.ip in {"127.0.0.1", "localhost"}
async with Worker(s.address) as w:
assert s.ip == w.ip


@gen_cluster(client=True)
async def test_close_gracefully(c, s, a, b):
futures = c.map(slowinc, range(200), delay=0.1)
while not b.data:
await gen.sleep(0.1)

mem = set(b.data)
proc = set(b.executing)

await b.close_gracefully()

assert b.status == "closed"
assert b.address not in s.workers
assert mem.issubset(set(a.data))
for key in proc:
assert s.tasks[key].state in ("processing", "memory")


@pytest.mark.slow
@pytest.mark.asyncio
async def test_lifetime(cleanup):
async with Scheduler() as s:
async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b:
async with Client(s.address, asynchronous=True) as c:
futures = c.map(slowinc, range(200), delay=0.1)
await gen.sleep(1.5)
assert b.status != "running"
await b.finished()

assert set(b.data).issubset(a.data) # successfully moved data over


@gen_cluster(client=True, worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"})
async def test_lifetime_stagger(c, s, a, b):
assert a.lifetime != b.lifetime
assert 8 <= a.lifetime <= 12
assert 8 <= b.lifetime <= 12
62 changes: 57 additions & 5 deletions distributed/worker.py
@@ -252,6 +252,16 @@ class Worker(ServerNode):
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False

Examples
--------
@@ -311,6 +321,9 @@ def __init__(
low_level_profiler=dask.config.get("distributed.worker.profile.low-level"),
validate=False,
profile_cycle_interval=None,
lifetime=None,
lifetime_stagger="0s",
lifetime_restart=None,
**kwargs
):
self.tasks = dict()
@@ -656,6 +669,23 @@ def __init__(
self.plugins = {}
self._pending_plugins = plugins

self.lifetime = lifetime or dask.config.get(
"distributed.worker.lifetime.duration"
)
lifetime_stagger = lifetime_stagger or dask.config.get(
"distributed.worker.lifetime.stagger"
)
self.lifetime_restart = lifetime_restart or dask.config.get(
"distributed.worker.lifetime.restart"
)
if isinstance(self.lifetime, str):
self.lifetime = parse_timedelta(self.lifetime)
if isinstance(lifetime_stagger, str):
lifetime_stagger = parse_timedelta(lifetime_stagger)
if self.lifetime:
self.lifetime += (random.random() * 2 - 1) * lifetime_stagger
self.io_loop.call_later(self.lifetime, self.close_gracefully)
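For reference, the staggering above draws the effective lifetime uniformly from lifetime ± lifetime_stagger. A small standalone sketch of the same arithmetic, with illustrative durations and assuming `parse_timedelta` is importable from `distributed.utils`:

```python
import random

from distributed.utils import parse_timedelta  # assumed location of the helper

lifetime = parse_timedelta("1 hour")      # 3600.0 seconds (illustrative value)
stagger = parse_timedelta("5 minutes")    # 300.0 seconds (illustrative value)

# random.random() * 2 - 1 is uniform over [-1, 1], so the effective lifetime
# is uniform over [lifetime - stagger, lifetime + stagger], here [3300, 3900] s.
effective_lifetime = lifetime + (random.random() * 2 - 1) * stagger
```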

Worker._instances.add(self)

##################
@@ -960,19 +990,22 @@ def _close(self, *args, **kwargs):
warnings.warn("Worker._close has moved to Worker.close", stacklevel=2)
return self.close(*args, **kwargs)

async def close(self, report=True, timeout=10, nanny=True, executor_wait=True):
async def close(
self, report=True, timeout=10, nanny=True, executor_wait=True, safe=False
):
with log_errors():
if self.status in ("closed", "closing"):
await self.finished()
return

self.reconnect = False
disable_gc_diagnosis()

try:
logger.info("Stopping worker at %s", self.address)
except ValueError: # address not available if already closed
logger.info("Stopping worker")
if self.status != "running":
if self.status not in ("running", "closing-gracefully"):
logger.info("Closed worker has not yet started: %s", self.status)
self.status = "closing"

@@ -996,7 +1029,9 @@ async def close(self, report=True, timeout=10, nanny=True, executor_wait=True):
if report:
await gen.with_timeout(
timedelta(seconds=timeout),
self.scheduler.unregister(address=self.contact_address),
self.scheduler.unregister(
address=self.contact_address, safe=safe
),
)
self.scheduler.close_rpc()
self.actor_executor._work_queue.queue.clear()
@@ -1030,6 +1065,23 @@ async def close(self, report=True, timeout=10, nanny=True, executor_wait=True):

setproctitle("dask-worker [closed]")

async def close_gracefully(self):
""" Gracefully shut down a worker

This first informs the scheduler that we're shutting down, and asks it
to move our data elsewhere. Afterwards, we close as normal
"""
if self.status.startswith("closing"):
await self.finished()

if self.status == "closed":
return

logger.info("Closing worker gracefully: %s", self.address)
self.status = "closing-gracefully"
await self.scheduler.retire_workers(workers=[self.address], remove=False)
await self.close(safe=True, nanny=not self.lifetime_restart)
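
For context outside the test suite, a minimal sketch of invoking the graceful path by hand; the two-worker setup is illustrative and mirrors the async tests added above:

```python
import asyncio

from distributed import Scheduler, Worker

async def main():
    # Hypothetical in-process scheduler and workers, as in the tests above.
    async with Scheduler() as s:
        async with Worker(s.address) as a, Worker(s.address) as b:
            # Retire one worker: the scheduler replicates its data to peers
            # (retire_workers with remove=False) before the worker closes.
            await b.close_gracefully()
            assert b.status == "closed"

asyncio.run(main())
```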

async def terminate(self, comm, report=True):
await self.close(report=report)
return "OK"
@@ -1546,7 +1598,7 @@ def transition_executing_done(self, key, value=no_value, report=True):
if key in self.dep_state:
self.transition_dep(key, "memory")

if report and self.batched_stream:
if report and self.batched_stream and self.status == "running":
self.send_task_state_to_scheduler(key)
else:
raise CommClosedError
@@ -2283,7 +2335,7 @@ def ensure_computing(self):

async def execute(self, key, report=False):
executor_error = None
if self.status in ("closing", "closed"):
if self.status in ("closing", "closed", "closing-gracefully"):
return
try:
if key not in self.executing or key not in self.task_state: