Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into server_close_refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed May 20, 2022
2 parents 7ad6996 + 6b91ec6 commit 859e196
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 14 deletions.
3 changes: 3 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def add_client(self, scheduler: Scheduler, client: str) -> None:
def remove_client(self, scheduler: Scheduler, client: str) -> None:
"""Run when a client disconnects"""

def log_event(self, name, msg) -> None:
    """Run when an event is logged

    The scheduler calls this hook on every registered plugin from its own
    ``log_event`` method, forwarding the arguments it received. The default
    implementation does nothing; subclasses override it to observe events.

    Exceptions raised by an implementation are caught and logged by the
    scheduler rather than propagated.

    Parameters
    ----------
    name
        The name (topic) under which the event was logged.
    msg
        The logged message, passed through unchanged.
    """


class WorkerPlugin:
"""Interface to extend the Worker
Expand Down
22 changes: 21 additions & 1 deletion distributed/diagnostics/tests/test_scheduler_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from distributed import Scheduler, SchedulerPlugin, Worker
from distributed import Scheduler, SchedulerPlugin, Worker, get_worker
from distributed.utils_test import gen_cluster, gen_test, inc


Expand Down Expand Up @@ -178,3 +178,23 @@ def start(self, scheduler):
assert "distributed.scheduler.pickle" in msg

assert n_plugins == len(s.plugins)


@gen_cluster(client=True)
async def test_log_event_plugin(c, s, a, b):
    """Check that a scheduler plugin's ``log_event`` hook sees worker events.

    A plugin recording every ``(name, msg)`` pair is registered through the
    client, a task logs an event from inside a worker, and the recorded
    events are then inspected on the scheduler object.
    """

    class EventPlugin(SchedulerPlugin):
        async def start(self, scheduler: Scheduler) -> None:
            # Keep the records on the scheduler itself so the test body can
            # read them back through ``s``.
            self.scheduler = scheduler
            self.scheduler._recorded_events = []  # type: ignore

        def log_event(self, name, msg):
            self.scheduler._recorded_events.append((name, msg))

    await c.register_scheduler_plugin(EventPlugin())

    def f():
        # Executed as a task, so the event is logged via the worker.
        get_worker().log_event("foo", 123)

    await c.submit(f)

    assert ("foo", 123) in s._recorded_events
6 changes: 6 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6941,6 +6941,12 @@ def log_event(self, name, msg):
self.event_counts[name] += 1
self._report_event(name, event)

for plugin in list(self.plugins.values()):
try:
plugin.log_event(name, msg)
except Exception:
logger.info("Plugin failed with exception", exc_info=True)

def _report_event(self, name, event):
for client in self.event_subscriber[name]:
self.report(
Expand Down
66 changes: 61 additions & 5 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import distributed
from distributed import Event, Lock, Worker
from distributed.client import wait
from distributed.utils_test import (
_LockedCommPool,
assert_story,
Expand Down Expand Up @@ -264,11 +265,12 @@ async def test_in_flight_lost_after_resumed(c, s, b):
block_get_data = asyncio.Lock()
in_get_data = asyncio.Event()

await block_get_data.acquire()
lock_executing = Lock()

def block_execution(lock):
with lock:
return
return 1

class BlockedGetData(Worker):
async def get_data(self, comm, *args, **kwargs):
Expand All @@ -281,15 +283,12 @@ async def get_data(self, comm, *args, **kwargs):
block_execution,
lock_executing,
workers=[a.address],
allow_other_workers=True,
key="fut1",
)
# Ensure fut1 is in memory but block any further execution afterwards to
# ensure we control when the recomputation happens
await fut1
await wait(fut1)
await lock_executing.acquire()
in_get_data.clear()
await block_get_data.acquire()
fut2 = c.submit(inc, fut1, workers=[b.address], key="fut2")

# This ensures that B already fetches the task, i.e. after this the task
Expand All @@ -298,6 +297,7 @@ async def get_data(self, comm, *args, **kwargs):
assert fut1.key in b.tasks
assert b.tasks[fut1.key].state == "flight"

s.set_restrictions({fut1.key: [a.address, b.address]})
# It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled
# to be recomputed on B
await s.remove_worker(a.address, "foo", close=False, safe=True)
Expand Down Expand Up @@ -396,3 +396,59 @@ def block_execution(event, lock):
await lock_executing.release()

assert await fut2 == 2


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3):
    """Exercise flight -> cancelled -> resumed for a task with dependencies.

    ``f3`` (which depends on ``f1`` and ``f2``) is computed on a third worker
    ``w1`` whose ``get_data`` is blocked. ``w2`` starts fetching ``f3`` for
    ``f4``; ``w1`` is then removed and the restrictions are changed so that
    ``f3`` has to be computed on ``w2`` while it is still in flight there.
    """
    # See https://github.com/dask/distributed/pull/6327#discussion_r872231090
    block_get_data_1 = asyncio.Lock()
    enter_get_data_1 = asyncio.Event()
    # Hold the lock up front; it is never released in this test, so any
    # get_data request served by w1 stays blocked.
    await block_get_data_1.acquire()

    class BlockGetDataWorker(Worker):
        """Worker that signals entry into ``get_data`` and then waits on a lock."""

        def __init__(self, *args, get_data_event, get_data_lock, **kwargs):
            self._get_data_event = get_data_event
            self._get_data_lock = get_data_lock
            super().__init__(*args, **kwargs)

        async def get_data(self, comm, *args, **kwargs):
            # Announce that a peer asked for data before blocking on the lock.
            self._get_data_event.set()
            async with self._get_data_lock:
                return await super().get_data(comm, *args, **kwargs)

    async with await BlockGetDataWorker(
        s.address,
        get_data_event=enter_get_data_1,
        get_data_lock=block_get_data_1,
        name="w1",
    ) as w1:

        f1 = c.submit(inc, 1, key="f1", workers=[w1.address])
        f2 = c.submit(inc, 2, key="f2", workers=[w1.address])
        f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address])

        await wait(f3)
        f4 = c.submit(inc, f3, key="f4", workers=[w2.address])

        # Proceed only once get_data has been requested from w1.
        await enter_get_data_1.wait()
        s.set_restrictions(
            {
                f1.key: {w3.address},
                f2.key: {w3.address},
                f3.key: {w2.address},
            }
        )
        await s.remove_worker(w1.address, "stim-id")

        await wait_for_state(f3.key, "resumed", w2)
        assert_story(
            w2.log,
            [
                (f3.key, "flight", "released", "cancelled", {}),
                # ...
                (f3.key, "cancelled", "waiting", "resumed", {}),
            ],
        )
    # w1 closed

    # inc(1) + inc(2) == 5, and inc(5) == 6
    assert await f4 == 6
12 changes: 4 additions & 8 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ def __init__(

# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
pc = PeriodicCallback(self.find_missing, 1000) # type: ignore
self._find_missing_running = False
self.periodic_callbacks["find-missing"] = pc

self._address = contact_address
Expand Down Expand Up @@ -1982,13 +1983,6 @@ def handle_compute_task(
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if self.validate:
# All previously unknown tasks that were created above by
# ensure_tasks_exists() have been transitioned to fetch or flight
assert all(
ts2.state != "released" for ts2 in (ts, *ts.dependencies)
), self.story(ts, *ts.dependencies)

########################
# Worker State Machine #
########################
Expand Down Expand Up @@ -3431,9 +3425,10 @@ def _readd_busy_worker(self, worker: str) -> None:

@log_errors
async def find_missing(self) -> None:
if not self._missing_dep_flight:
if self._find_missing_running or not self._missing_dep_flight:
return
try:
self._find_missing_running = True
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has
Expand All @@ -3451,6 +3446,7 @@ async def find_missing(self) -> None:
self.transitions(recommendations, stimulus_id=stimulus_id)

finally:
self._find_missing_running = False
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
Expand Down

0 comments on commit 859e196

Please sign in to comment.