Skip to content

Commit

Permalink
Add synced dict between cluster and scheduler to store cluster info (#…
Browse files Browse the repository at this point in the history
…5033)

Adds a `_cluster_info` attribute to all `Cluster` objects: a dictionary that is periodically synced to the scheduler. Any info already stored on the scheduler when `_start` runs is merged into the `Cluster`'s dict, and that dict is then written back to the scheduler at every `scheduler_sync_interval` (one second by default).
  • Loading branch information
jacobtomlinson authored Sep 9, 2021
1 parent 3fba8f2 commit 5bf60e3
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 10 deletions.
48 changes: 38 additions & 10 deletions distributed/deploy/cluster.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import datetime
import logging
import threading
Expand All @@ -9,7 +10,7 @@
from tornado.ioloop import PeriodicCallback

import dask.config
from dask.utils import _deprecated, format_bytes, parse_timedelta
from dask.utils import _deprecated, format_bytes, parse_timedelta, typename
from dask.widgets import get_template

from ..core import Status
Expand Down Expand Up @@ -44,9 +45,8 @@ class Cluster:
"""

_supports_scaling = True
name = None

def __init__(self, asynchronous, quiet=False, name=None):
def __init__(self, asynchronous, quiet=False, name=None, scheduler_sync_interval=1):
self.scheduler_info = {"workers": {}}
self.periodic_callbacks = {}
self._asynchronous = asynchronous
Expand All @@ -56,13 +56,24 @@ def __init__(self, asynchronous, quiet=False, name=None):
self.quiet = quiet
self.scheduler_comm = None
self._adaptive = None
self._sync_interval = parse_timedelta(
scheduler_sync_interval, default="seconds"
)

if name is None:
name = str(uuid.uuid4())[:8]

if name is not None:
self.name = name
elif self.name is None:
self.name = str(uuid.uuid4())[:8]
self._cluster_info = {"name": name, "type": typename(type(self))}
self.status = Status.created

@property
def name(self):
    """Return the cluster name held in ``self._cluster_info``.

    The name lives under the ``"name"`` key of the info dictionary rather
    than in a separate attribute, so it is part of whatever
    ``_sync_cluster_info`` sends to the scheduler.
    """
    info = self._cluster_info
    return info["name"]

@name.setter
def name(self, name):
    """Store *name* under the ``"name"`` key of ``self._cluster_info``."""
    info = self._cluster_info
    info["name"] = name

async def _start(self):
comm = await self.scheduler_comm.live_comm()
await comm.write({"op": "subscribe_worker_status"})
Expand All @@ -71,8 +82,25 @@ async def _start(self):
self._watch_worker_status_task = asyncio.ensure_future(
self._watch_worker_status(comm)
)

info = await self.scheduler_comm.get_metadata(
keys=["cluster-manager-info"], default={}
)
self._cluster_info.update(info)

self.periodic_callbacks["sync-cluster-info"] = PeriodicCallback(
self._sync_cluster_info, self._sync_interval * 1000
)
for pc in self.periodic_callbacks.values():
pc.start()
self.status = Status.running

async def _sync_cluster_info(self):
await self.scheduler_comm.set_metadata(
keys=["cluster-manager-info"],
value=copy.copy(self._cluster_info),
)

async def _close(self):
if self.status == Status.closed:
return
Expand All @@ -85,12 +113,12 @@ async def _close(self):
if self._watch_worker_status_task:
await self._watch_worker_status_task

for pc in self.periodic_callbacks.values():
pc.stop()

if self.scheduler_comm:
await self.scheduler_comm.close_rpc()

for pc in self.periodic_callbacks.values():
pc.stop()

self.status = Status.closed

def close(self, timeout=None):
Expand Down
2 changes: 2 additions & 0 deletions distributed/deploy/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
interface=None,
worker_class=None,
scheduler_kwargs=None,
scheduler_sync_interval=1,
**worker_kwargs,
):
if ip is not None:
Expand Down Expand Up @@ -241,6 +242,7 @@ def __init__(
asynchronous=asynchronous,
silence_logs=silence_logs,
security=security,
scheduler_sync_interval=scheduler_sync_interval,
)

def start_worker(self, *args, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(
silence_logs=False,
name=None,
shutdown_on_close=True,
scheduler_sync_interval=1,
):
self._created = weakref.WeakSet()

Expand Down Expand Up @@ -274,6 +275,7 @@ def __init__(
super().__init__(
asynchronous=asynchronous,
name=name,
scheduler_sync_interval=scheduler_sync_interval,
)

if not self.asynchronous:
Expand Down
27 changes: 27 additions & 0 deletions distributed/deploy/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,3 +1073,30 @@ async def test_local_cluster_redundant_kwarg(nanny):
async with Client(cluster) as c:
f = c.submit(sleep, 0)
await f


@pytest.mark.asyncio
async def test_cluster_info_sync():
    """Check that ``Cluster._cluster_info`` is pushed to scheduler metadata.

    A very short sync interval ("1ms") is used so that the periodic sync
    fires quickly; the test then polls the scheduler until the expected
    keys appear.
    """
    async with LocalCluster(
        processes=False, asynchronous=True, scheduler_sync_interval="1ms"
    ) as cluster:

        def scheduler_side_info():
            # Read the metadata directly from the in-process scheduler.
            return cluster.scheduler.get_metadata(keys=["cluster-manager-info"])

        async def wait_until_synced(key):
            # Poll until the periodic callback has pushed ``key``.
            while key not in scheduler_side_info():
                await asyncio.sleep(0.01)

        # The ``name`` property and the info dict must agree locally.
        assert cluster._cluster_info["name"] == cluster.name

        await wait_until_synced("name")

        # Same value whether read over the comm or directly on the scheduler.
        via_comm = await cluster.scheduler_comm.get_metadata(
            keys=["cluster-manager-info"]
        )
        assert via_comm["name"] == cluster.name
        direct = scheduler_side_info()
        assert direct["name"] == cluster.name

        # A key added locally after startup must also reach the scheduler.
        cluster._cluster_info["foo"] = "bar"
        await wait_until_synced("foo")

        direct = scheduler_side_info()
        assert direct["foo"] == "bar"

0 comments on commit 5bf60e3

Please sign in to comment.