Skip to content

Commit

Permalink
Add finalizer to reap clusters when Python exits (#487)
Browse files Browse the repository at this point in the history
* Release 2022.5.0

* Add finalizer to shutdown experimental KubeCluster

* Add finalizer to reap clusters when Python exits

* Handle async clusters
  • Loading branch information
jacobtomlinson authored May 16, 2022
1 parent 8190529 commit bf1b6a4
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions dask_kubernetes/experimental/kubecluster.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

import asyncio
import atexit
from contextlib import suppress
from enum import Enum
import time
from typing import ClassVar
import weakref

import kubernetes_asyncio as kubernetes

from distributed.core import rpc
from distributed.core import Status, rpc
from distributed.deploy import Cluster

from distributed.utils import Log, Logs, LoopRunner
from distributed.utils import Log, Logs, LoopRunner, TimeoutError

from dask_kubernetes.common.auth import ClusterAuth
from dask_kubernetes.operator import (
Expand Down Expand Up @@ -103,6 +111,8 @@ class KubeCluster(Cluster):
KubeCluster.from_name
"""

_instances: ClassVar[weakref.WeakSet[KubeCluster]] = weakref.WeakSet()

def __init__(
self,
name,
Expand Down Expand Up @@ -133,6 +143,8 @@ def __init__(
self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
self.loop = self._loop_runner.loop

self._instances.add(self)

super().__init__(asynchronous=asynchronous, **kwargs)
if not self.asynchronous:
self._loop_runner.start()
Expand Down Expand Up @@ -363,11 +375,11 @@ async def _delete_worker_group(self, name):
name=f"{self.name}-cluster-{name}",
)

def close(self):
def close(self, timeout=3600):
"""Delete the dask cluster"""
return self.sync(self._close)
return self.sync(self._close, timeout=timeout)

async def _close(self):
async def _close(self, timeout=None):
await super()._close()
if self.shutdown_on_close:
async with kubernetes.client.api_client.ApiClient() as api_client:
Expand All @@ -379,7 +391,12 @@ async def _close(self):
namespace=self.namespace,
name=self.cluster_name,
)
start = time.time()
while (await self._get_cluster()) is not None:
if time.time() > start + timeout:
raise TimeoutError(
f"Timed out deleting cluster resource {self.cluster_name}"
)
await asyncio.sleep(1)

def scale(self, n, worker_group="default"):
Expand Down Expand Up @@ -537,3 +554,19 @@ def from_name(cls, name, **kwargs):
>>> cluster = KubeCluster.from_name(name="simple-cluster")
"""
return cls(name=name, create_mode=CreateMode.CONNECT_ONLY, **kwargs)


@atexit.register
def reap_clusters():
async def _reap_clusters():
for cluster in list(KubeCluster._instances):
if cluster.shutdown_on_close and cluster.status != Status.closed:
await ClusterAuth.load_first(cluster.auth)
with suppress(TimeoutError):
if cluster.asynchronous:
await cluster.close(timeout=10)
else:
cluster.close(timeout=10)

loop = asyncio.get_event_loop()
loop.run_until_complete(_reap_clusters())

0 comments on commit bf1b6a4

Please sign in to comment.