Dynamic annotations #5207

Draft · wants to merge 7 commits into main
27 changes: 26 additions & 1 deletion distributed/scheduler.py
@@ -18,7 +18,7 @@
from datetime import timedelta
from functools import partial
from numbers import Number
from typing import Optional
from typing import List, Optional

import psutil
import sortedcontainers
@@ -3722,6 +3722,7 @@ def __init__(
"reschedule": self.reschedule,
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"annotate-task": self.annotate_task,
}

client_handlers = {
@@ -5529,6 +5530,30 @@ def send_all(self, client_msgs: dict, worker_msgs: dict):
except (CommClosedError, AttributeError):
self.loop.add_callback(self.remove_worker, address=worker)

def annotate_task(
self,
worker=None,
key: str = None,
annotations: Mapping = None,
dependents: bool = None,
):
"""Update the annotations and restrictions of a task"""

parent: SchedulerState = cast(SchedulerState, self)
if worker not in parent._workers_dv:
return
validate_key(key)
ts: TaskState = parent.tasks[key]

# Find tasks to annotate
tasks: List[TaskState] = [ts]
if dependents:
tasks.extend(ts.dependents)

if annotations:
for dts in tasks:
dts.annotations.update(annotations)
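
For context, this handler is driven by the worker over its batched stream: `Worker.handle_annotators` (later in this diff) sends messages of the shape sketched below, shown here with a hypothetical task key:

    self.batched_stream.send(
        {
            "op": "annotate-task",  # dispatched to Scheduler.annotate_task
            "key": "f1",  # hypothetical task key
            "annotations": {"executor": "gpu"},
            "dependents": True,  # also update the annotations of f1's dependents
        }
    )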

############################
# Less common interactions #
############################
106 changes: 105 additions & 1 deletion distributed/tests/test_worker.py
@@ -54,7 +54,13 @@
slowinc,
slowsum,
)
from distributed.worker import Worker, error_message, logger, parse_memory_limit
from distributed.worker import (
TaskState,
Worker,
error_message,
logger,
parse_memory_limit,
)

pytestmark = pytest.mark.ci1

@@ -2009,6 +2015,104 @@ async def test_process_executor(c, s, a, b):
assert (await future) != os.getpid()


@gen_cluster(client=True, nthreads=[])
async def test_annotator_choose_executor(c, s):
"""Test that it is possible for a task to choose executor"""

def get_thread_name(x):
return threading.current_thread().name

def set_executor_by_task_output(ts: TaskState, value: object) -> dict:
return {"annotations": {"executor": value}, "dependents": True}

def f(x):
return x

async with Worker(
s.address,
nthreads=1,
executor={
"exec1": ThreadPoolExecutor(1, thread_name_prefix="Executor1"),
"exec2": ThreadPoolExecutor(1, thread_name_prefix="Executor2"),
},
annotators=[set_executor_by_task_output],
):
dsk = {
"f1": (f, "exec1"),
"g1": (get_thread_name, "f1"),
"f2": (f, "exec2"),
"g2": (get_thread_name, "f2"),
}
res1, res2 = c.get(dsk, ["g1", "g2"], sync=False, asynchronous=True)
assert "Executor1" in await res1
assert "Executor2" in await res2
Member:
The fact that you had to use get/compute here is interesting. There are interesting challenges with doing this with futures / dependencies that might change. I still think this is a useful feature to have, but it seems like this approach might not be comprehensive if we want to solve things across update_graph calls. Agree or disagree?

Contributor Author:
Agree! I hadn't thought about this issue before today. As you say, it is still useful, but I think we should put this PR on hold for a bit and see if we can come up with a better approach.



@gen_cluster(client=True, nthreads=[])
async def test_annotator_raise_exception(c, s):
"""Test that it is possible for a task to choose executor"""

def broken_annotator(ts: TaskState, value: object) -> dict:
raise RuntimeError()

def f(x):
return x

async with Worker(
s.address,
nthreads=1,
annotators=[broken_annotator],
) as w:
res = await c.submit(f, "test")
assert any("ERROR" in log for log in w.get_logs())

# Notice, the task should succeed even when the annotator fails
assert res == "test"


@gen_cluster(client=True, nthreads=[])
async def test_annotator_choose_gpu_executor(c, s):
"""
Demonstrate the possibility of automatically choosing a
GPU executor based on the type of the task output
"""

numpy = pytest.importorskip("numpy")
cupy = pytest.importorskip("cupy")
dask_cuda = pytest.importorskip("dask_cuda.is_device_object")

def get_thread_name(x):
return threading.current_thread().name

def set_gpu_executor(ts: TaskState, value: object) -> dict:
if dask_cuda.is_device_object(value):
return {"annotations": {"executor": "gpu"}, "dependents": True}

def f(ary_type):
if ary_type == "numpy":
return numpy.arange(10)
elif ary_type == "cupy":
return cupy.arange(10)

async with Worker(
s.address,
nthreads=1,
executor={
"gpu": ThreadPoolExecutor(1, thread_name_prefix="GPU-Executor"),
},
annotators=[set_gpu_executor],
):
dsk = {
"f1": (f, "numpy"),
"g1": (get_thread_name, "f1"),
"f2": (f, "cupy"),
"g2": (get_thread_name, "f2"),
}
res1, res2 = c.get(dsk, ["g1", "g2"], sync=False, asynchronous=True)
assert "Dask-Default-Threads" in await res1
assert "GPU-Executor" in await res2


def kill_process():
import os
import signal
93 changes: 72 additions & 21 deletions distributed/worker.py
@@ -16,7 +16,7 @@
from datetime import timedelta
from inspect import isawaitable
from pickle import PicklingError
from typing import Dict, Hashable, Iterable, Optional
from typing import Callable, Dict, Hashable, Iterable, Mapping, Optional

from tlz import first, keymap, merge, pluck # noqa: F401
from tornado.ioloop import IOLoop, PeriodicCallback
@@ -174,7 +174,7 @@ def __init__(self, key, runspec=None):
self.who_has = set()
self.coming_from = None
self.waiting_for_data = set()
self.resource_restrictions = None
self.resource_restrictions = {}
self.exception = None
self.exception_text = ""
self.traceback = None
@@ -186,7 +186,7 @@ def __init__(self, key, runspec=None):
self.stop_time = None
self.metadata = {}
self.nbytes = None
self.annotations = None
self.annotations = {}
self.scheduler_holds_ref = False

def __repr__(self):
@@ -352,6 +352,10 @@ class Worker(ServerNode):
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
annotators: Iterable[Callable[[TaskState, object], Optional[Mapping]]], optional
List of annotators: functions that are given a `TaskState` and the task's
output and return either ``None`` (no change) or a mapping with two keys:
``"annotations"``, the annotations to apply, and ``"dependents"``, whether
the task's dependents should receive the same annotations.

Examples
--------
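
To make the annotator contract above concrete, here is a minimal sketch of one; the function name, threshold, and "big" executor are illustrative and not part of this PR:

    from dask.sizeof import sizeof

    def pin_large_outputs(ts: TaskState, value: object) -> dict:
        # Returning None (implicitly) leaves the task unchanged.
        if sizeof(value) > 100_000_000:
            # Annotate this task and all of its dependents.
            return {"annotations": {"executor": "big"}, "dependents": True}

A worker constructed with ``annotators=[pin_large_outputs]`` would then report these annotations back to the scheduler whenever a large result lands in memory.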
@@ -418,9 +422,10 @@ def __init__(
lifetime=None,
lifetime_stagger=None,
lifetime_restart=None,
annotators=None,
**kwargs,
):
self.tasks = dict()
self.tasks: Dict[str, TaskState] = dict()
self.waiting_for_data_count = 0
self.has_what = defaultdict(set)
self.pending_data_per_worker = defaultdict(deque)
@@ -687,6 +692,10 @@ def __init__(

self.low_level_profiler = low_level_profiler

self.annotators: Iterable[Callable[[TaskState, object], Mapping]] = (
annotators or ()
)

handlers = {
"gather": self.gather,
"run": self.run,
@@ -1496,11 +1505,15 @@ def update_data(self, comm=None, data=None, report=True, serializers=None):
self.log.append((key, "receive-from-scatter"))

if report:

self.log.append(
("Notifying scheduler about in-memory in update-data", list(data))
)
self.batched_stream.send({"op": "add-keys", "keys": list(data)})

if self.annotators:
for key, value in data.items():
self.handle_annotators(self.tasks[key], value)

info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}
return info

@@ -1646,7 +1659,7 @@ def add_task(
ts.duration = duration
if resource_restrictions:
ts.resource_restrictions = resource_restrictions
ts.annotations = annotations
ts.annotations = annotations or {}

who_has = who_has or {}

@@ -1896,6 +1909,7 @@ def transition_flight_memory(self, ts, value=None):

self.log.append(("Notifying scheduler about in-memory", ts.key))
self.batched_stream.send({"op": "add-keys", "keys": [ts.key]})
self.handle_annotators(ts, value)

except Exception as e:
logger.exception(e)
@@ -1919,7 +1933,7 @@ def transition_waiting_ready(self, ts):

self.has_what[self.address].discard(ts.key)

if ts.resource_restrictions is not None:
if ts.resource_restrictions:
self.constrained.append(ts.key)
return "constrained"
else:
@@ -1942,6 +1956,7 @@ def transition_waiting_done(self, ts, value=None):
ts.waiting_for_data.clear()
if value is not None:
self.put_key_in_memory(ts, value)
self.handle_annotators(ts, value)
self.send_task_state_to_scheduler(ts)
except Exception as e:
logger.exception(e)
@@ -2002,9 +2017,8 @@ def transition_executing_done(self, ts, value=no_value, report=True):
assert ts.key not in self.ready

out = None
if ts.resource_restrictions is not None:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity

if ts.state == "executing":
self.executing_count -= 1
Expand All @@ -2027,11 +2041,12 @@ def transition_executing_done(self, ts, value=no_value, report=True):
for d in ts.dependents:
d.waiting_for_data.add(ts.key)

if report and self.batched_stream and self.status == Status.running:
self.send_task_state_to_scheduler(ts)
if self.batched_stream and self.status == Status.running:
self.handle_annotators(ts, value)
if report:
self.send_task_state_to_scheduler(ts)
else:
raise CommClosedError

return out

except OSError:
@@ -2212,7 +2227,7 @@ def ensure_communicating(self):
pdb.set_trace()
raise

def send_task_state_to_scheduler(self, ts):
def send_task_state_to_scheduler(self, ts: TaskState):
if ts.key in self.data or self.actors.get(ts.key):
typ = ts.type
if ts.nbytes is None or typ is None:
@@ -2292,6 +2307,48 @@ def put_key_in_memory(self, ts, value, transition=True):

self.log.append((ts.key, "put-in-memory"))

def handle_annotators(
self,
ts: TaskState,
value: object = no_value,
):
"""Annotate task if applicable and report back to the scheduler"""

# Separate annotation updates into two cases:
a1 = {} # when the task's dependents should also be updated
a2 = {} # when only the task itself should be updated
Member:
Do you have thoughts on better names here?

Contributor Author:
Heh, yeah I can do that :)

for func in self.annotators:
try:
res = func(ts, value)
except Exception as e:
logger.exception(e)
else:
if res:
if res["dependents"]:
a1.update(res["annotations"])
else:
a2.update(res["annotations"])
if a1:
ts.annotations.update(a1)
self.batched_stream.send(
{
"op": "annotate-task",
"key": ts.key,
"annotations": a1,
"dependents": True,
}
)
if a2:
ts.annotations.update(a2)
self.batched_stream.send(
{
"op": "annotate-task",
"key": ts.key,
"annotations": a2,
"dependents": False,
}
)
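
As a usage note, an annotator that returns ``"dependents": False`` takes the second branch above, so the scheduler updates only the task itself. A task-only annotator might look like this sketch (hypothetical name and annotation key):

    def record_output_type(ts: TaskState, value: object) -> dict:
        # Annotate only this task with its output type; dependents are untouched.
        return {"annotations": {"output-type": type(value).__name__}, "dependents": False}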

def select_keys_for_gather(self, worker, dep):
assert isinstance(dep, str)
deps = {dep}
@@ -2851,12 +2908,9 @@ def actor_attribute(self, comm=None, actor=None, attribute=None):

def meets_resource_constraints(self, key):
ts = self.tasks[key]
if not ts.resource_restrictions:
return True
for resource, needed in ts.resource_restrictions.items():
if self.available_resources[resource] < needed:
return False

return True

async def _maybe_deserialize_task(self, ts):
Expand Down Expand Up @@ -2943,10 +2997,7 @@ async def execute(self, key):

args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)

if ts.annotations is not None and "executor" in ts.annotations:
executor = ts.annotations["executor"]
else:
executor = "default"
executor = ts.annotations.get("executor", "default")
assert executor in self.executors
assert key == ts.key
self.active_keys.add(ts.key)
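
As a closing note, the lookup above should also honor annotations attached on the client side, assuming the scheduler forwards graph annotations into ``add_task`` as the change earlier in this file suggests. A sketch using dask's public annotation API:

    import dask
    import dask.array as da

    x = da.random.random((1000, 1000), chunks=(100, 100))
    with dask.annotate(executor="gpu"):
        # Tasks in this layer carry {"executor": "gpu"}; note that graph
        # optimization (e.g. low-level fusion) can drop annotations.
        y = x.sum()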