Commit

Don't connect to cluster subprocesses at shutdown (#6829)
gjoseph92 authored Aug 5, 2022
1 parent caf5189 commit e1f3779
Showing 1 changed file with 29 additions and 68 deletions.
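For orientation before the diff: teardown of the cluster() test fixture no longer connects to the scheduler and worker subprocesses over RPC to ask them to terminate; it relies solely on contextlib.ExitStack callbacks that hard-kill and join each subprocess, with a bounded join so a stuck process raises instead of hanging. A minimal, self-contained sketch of that pattern (only _kill_join and the shutdown_timeout name come from the commit; the child function, the 60s sleep, and the overall scaffolding here are illustrative):

import contextlib
import multiprocessing
import time


def _kill_join(proc, timeout):
    # Hard-kill the subprocess, reap it, and fail loudly if it lingers.
    proc.kill()
    proc.join(timeout)
    if proc.is_alive():
        raise multiprocessing.TimeoutError(
            f"Process {proc} did not shut down within {timeout}s"
        )
    proc.close()


def child():
    # Stand-in for a scheduler/worker subprocess in the real fixture.
    time.sleep(60)


if __name__ == "__main__":
    shutdown_timeout = 20
    with contextlib.ExitStack() as stack:
        proc = multiprocessing.Process(target=child, daemon=True)
        proc.start()
        # Callbacks run when the stack unwinds, even if later setup fails,
        # so teardown never needs the subprocess to answer a network request.
        stack.callback(_kill_join, proc, shutdown_timeout)
        # ... the test body would use the subprocess here ...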
distributed/utils_test.py (97 changes: 29 additions & 68 deletions)
@@ -568,9 +568,13 @@ def security():
     return tls_only_security()


-def _terminate_join(proc):
-    proc.terminate()
-    proc.join()
+def _kill_join(proc, timeout):
+    proc.kill()
+    proc.join(timeout)
+    if proc.is_alive():
+        raise multiprocessing.TimeoutError(
+            f"Process {proc} did not shut down within {timeout}s"
+        )
     proc.close()

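Separately, the helper switches from terminate() to kill() and gives join() a timeout. A rough, POSIX-oriented illustration of why that matters (stubborn_child and the 10-second timeout are invented for this demo, not part of the change): SIGTERM can be caught or ignored, and the old unbounded join could hang the test suite silently, whereas SIGKILL plus a bounded join either succeeds or fails loudly.

import multiprocessing
import signal
import time


def stubborn_child():
    # Ignores SIGTERM, mimicking a subprocess stuck in graceful shutdown.
    signal.signal(signal.SIGTERM, signal.SIG_IGN)
    while True:
        time.sleep(1)


if __name__ == "__main__":
    proc = multiprocessing.Process(target=stubborn_child, daemon=True)
    proc.start()
    time.sleep(0.5)  # give the child time to install its signal handler

    # Old pattern: proc.terminate(); proc.join() -- SIGTERM is ignored here,
    # so the unbounded join would block the test run forever.

    # New pattern: SIGKILL cannot be ignored, and the bounded join turns a
    # genuinely stuck process into an explicit error instead of a silent hang.
    proc.kill()
    proc.join(10)
    if proc.is_alive():
        raise multiprocessing.TimeoutError("child did not shut down within 10s")
    proc.close()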
@@ -586,7 +590,7 @@ def cluster(
     nanny=False,
     worker_kwargs=None,
     active_rpc_timeout=10,
-    disconnect_timeout=20,
+    shutdown_timeout=20,
     scheduler_kwargs=None,
     config=None,
 ):
@@ -618,7 +622,7 @@
             )
             ws.add(scheduler)
             scheduler.start()
-            stack.callback(_terminate_join, scheduler)
+            stack.callback(_kill_join, scheduler, shutdown_timeout)

             # Launch workers
             workers_by_pid = {}
@@ -640,7 +644,7 @@
                 )
                 ws.add(proc)
                 proc.start()
-                stack.callback(_terminate_join, proc)
+                stack.callback(_kill_join, proc, shutdown_timeout)
                 workers_by_pid[proc.pid] = {"proc": proc}

             saddr_or_exception = scheduler_q.get()
@@ -656,50 +660,27 @@

             start = time()
             try:
-                try:
-                    security = scheduler_kwargs["security"]
-                    rpc_kwargs = {
-                        "connection_args": security.get_connection_args("client")
-                    }
-                except KeyError:
-                    rpc_kwargs = {}
-
-                async def wait_for_workers():
-                    async with rpc(saddr, **rpc_kwargs) as s:
-                        while True:
-                            nthreads = await s.ncores_running()
-                            if len(nthreads) == nworkers:
-                                break
-                            if time() - start > 5:
-                                raise Exception("Timeout on cluster creation")
+                security = scheduler_kwargs["security"]
+                rpc_kwargs = {"connection_args": security.get_connection_args("client")}
+            except KeyError:
+                rpc_kwargs = {}
+
+            async def wait_for_workers():
+                async with rpc(saddr, **rpc_kwargs) as s:
+                    while True:
+                        nthreads = await s.ncores_running()
+                        if len(nthreads) == nworkers:
+                            break
+                        if time() - start > 5:
+                            raise Exception("Timeout on cluster creation")

-                _run_and_close_tornado(wait_for_workers)
+            _run_and_close_tornado(wait_for_workers)

-                # avoid sending processes down to function
-                yield {"address": saddr}, [
-                    {"address": w["address"], "proc": weakref.ref(w["proc"])}
-                    for w in workers_by_pid.values()
-                ]
-            finally:
-
-                async def close():
-                    logger.debug("Closing out test cluster")
-                    alive_workers = [
-                        w["address"]
-                        for w in workers_by_pid.values()
-                        if w["proc"].is_alive()
-                    ]
-                    await disconnect_all(
-                        alive_workers,
-                        timeout=disconnect_timeout,
-                        rpc_kwargs=rpc_kwargs,
-                    )
-                    if scheduler.is_alive():
-                        await disconnect(
-                            saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
-                        )
-
-                _run_and_close_tornado(close)
+            # avoid sending processes down to function
+            yield {"address": saddr}, [
+                {"address": w["address"], "proc": weakref.ref(w["proc"])}
+                for w in workers_by_pid.values()
+            ]
         try:
             client = default_client()
         except ValueError:
@@ -708,26 +689,6 @@ async def close():
             client.close()


-async def disconnect(addr, timeout=3, rpc_kwargs=None):
-    rpc_kwargs = rpc_kwargs or {}
-
-    async def do_disconnect():
-        async with rpc(addr, **rpc_kwargs) as w:
-            # If the worker was killed hard (e.g. sigterm) during test runtime,
-            # we do not know at this point and may not be able to connect
-            with suppress(EnvironmentError, CommClosedError):
-                # Do not request a reply since comms will be closed by the
-                # worker before a reply can be made and we will always trigger
-                # the timeout
-                await w.terminate(reply=False)
-
-    await asyncio.wait_for(do_disconnect(), timeout=timeout)
-
-
-async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
-    await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))
-
-
 def gen_test(
     timeout: float = _TEST_TIMEOUT,
     clean_kwargs: dict[str, Any] | None = None,
