Initial implementation of dynamic annotations
madsbk committed Aug 12, 2021
1 parent 7392de6 commit c56bc9b
Showing 3 changed files with 126 additions and 19 deletions.
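
For context, this commit adds a dynamic_annotations keyword to Worker: an iterable of callbacks that the worker invokes with a finishing task's TaskState and its output value, and that may mutate the task's annotations (or resource restrictions) in place. A minimal usage sketch, modeled on the tests in this commit; the callback name, executor mapping, and predicate are illustrative, not part of the change:

# Sketch only: assumes the dynamic_annotations Worker keyword and the
# worker-side TaskState exercised by the tests in this commit.
from distributed import Worker
from distributed.worker import TaskState

def pick_executor(ts: TaskState, value: object) -> None:
    # Runs on the worker after the task produces `value`; the chosen
    # annotation is reported to the scheduler and inherited by dependents.
    if isinstance(value, str):
        ts.annotations["executor"] = value

# async with Worker(scheduler_address,
#                   executor={"exec1": ..., "exec2": ...},
#                   dynamic_annotations=[pick_executor]):
#     ...
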
28 changes: 27 additions & 1 deletion distributed/scheduler.py
@@ -2662,6 +2662,8 @@ def transition_processing_memory(
typename: str = None,
worker=None,
startstops=None,
annotations: dict = None,
resource_restrictions: dict = None,
**kwargs,
):
ws: WorkerState
@@ -2765,7 +2767,15 @@ def transition_processing_memory(
_remove_from_processing(self, ts)

_add_to_memory(
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
self,
ts,
ws,
recommendations,
client_msgs,
type=type,
typename=typename,
annotations=annotations,
resource_restrictions=resource_restrictions,
)

if self._validate:
@@ -7595,6 +7605,8 @@ def _add_to_memory(
client_msgs: dict,
type=None,
typename: str = None,
annotations: dict = None,
resource_restrictions: dict = None,
):
"""
Add *ts* to the set of in-memory tasks.
@@ -7619,6 +7631,20 @@
if not s: # new task ready to run
recommendations[dts._key] = "processing"

if annotations:
for dts in deps:
if dts._annotations:
dts._annotations.update(annotations)
else:
dts._annotations = annotations

if resource_restrictions:
for dts in deps:
if dts._resource_restrictions:
dts._resource_restrictions.update(resource_restrictions)
else:
dts._resource_restrictions = resource_restrictions

for dts in ts._dependencies:
s = dts._waiters
s.discard(ts)
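
For orientation (not part of the diff): the _add_to_memory hunk above merges the annotations and resource restrictions reported by a finishing task into each of its dependents on the scheduler, so an annotation chosen dynamically on a worker is inherited downstream. In a hypothetical two-task graph the effect is:

# Illustrative only: client, f, and g are stand-ins, and the worker is assumed
# to run a dynamic_annotations hook that sets ts.annotations["executor"]
# from f's output (see the tests in the next file).
x = client.submit(f, "exec1")  # hook annotates x's TaskState when x completes
y = client.submit(g, x)        # scheduler copies x's annotations onto y
# y now carries {"executor": "exec1"} and runs on that executor.
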
80 changes: 79 additions & 1 deletion distributed/tests/test_worker.py
@@ -50,7 +50,13 @@
slowinc,
slowsum,
)
from distributed.worker import Worker, error_message, logger, parse_memory_limit
from distributed.worker import (
TaskState,
Worker,
error_message,
logger,
parse_memory_limit,
)

pytestmark = pytest.mark.ci1

@@ -2006,6 +2012,78 @@ async def test_process_executor(c, s, a, b):
assert (await future) != os.getpid()


@gen_cluster(client=True, nthreads=[])
async def test_executor_by_dynamic_annotation(c, s):
"""Test that it is possible for a task to choose executor"""

def get_thread_name(x):
return threading.current_thread().name

def set_executor_by_task_output(ts: TaskState, value: object) -> None:
ts.annotations["executor"] = value

def f(x):
return x

async with Worker(
s.address,
nthreads=1,
executor={
"exec1": ThreadPoolExecutor(1, thread_name_prefix="Executor1"),
"exec2": ThreadPoolExecutor(1, thread_name_prefix="Executor2"),
},
dynamic_annotations=[set_executor_by_task_output],
):
res = c.submit(f, "exec1")
res = await c.submit(get_thread_name, res)
assert "Executor1" in res

res = c.submit(f, "exec2")
res = await c.submit(get_thread_name, res)
assert "Executor2" in res


@gen_cluster(client=True, nthreads=[])
async def test_gpu_executor_by_dynamic_annotation(c, s):
"""
Demonstrate the possibility of automatically choosing a
GPU executor based on the type of the task output
"""

numpy = pytest.importorskip("numpy")
cupy = pytest.importorskip("cupy")
dask_cuda = pytest.importorskip("dask_cuda.is_device_object")

def get_thread_name(x):
return threading.current_thread().name

def set_gpu_executor(ts: TaskState, value: object) -> None:
if dask_cuda.is_device_object(value):
ts.annotations["executor"] = "gpu"

def f(ary_type):
if ary_type == "numpy":
return numpy.arange(10)
elif ary_type == "cupy":
return cupy.arange(10)

async with Worker(
s.address,
nthreads=1,
executor={
"gpu": ThreadPoolExecutor(1, thread_name_prefix="GPU-Executor"),
},
dynamic_annotations=[set_gpu_executor],
):
res = c.submit(f, "numpy")
res = await c.submit(get_thread_name, res)
assert "Dask-Default-Threads" in res

res = c.submit(f, "cupy")
res = await c.submit(get_thread_name, res)
assert "GPU-Executor" in res


def kill_process():
import os
import signal
37 changes: 20 additions & 17 deletions distributed/worker.py
@@ -16,7 +16,7 @@
from datetime import timedelta
from inspect import isawaitable
from pickle import PicklingError
from typing import Dict, Hashable, Iterable, Optional
from typing import Callable, Dict, Hashable, Iterable, Optional

from tlz import first, keymap, merge, pluck # noqa: F401
from tornado.ioloop import IOLoop, PeriodicCallback
@@ -168,7 +168,7 @@ def __init__(self, key, runspec=None):
self.who_has = set()
self.coming_from = None
self.waiting_for_data = set()
self.resource_restrictions = None
self.resource_restrictions = {}
self.exception = None
self.traceback = None
self.type = None
@@ -178,7 +178,7 @@ def __init__(self, key, runspec=None):
self.stop_time = None
self.metadata = {}
self.nbytes = None
self.annotations = None
self.annotations = {}
self.scheduler_holds_ref = False

def __repr__(self):
@@ -407,9 +407,10 @@ def __init__(
lifetime=None,
lifetime_stagger=None,
lifetime_restart=None,
dynamic_annotations=None,
**kwargs,
):
self.tasks = dict()
self.tasks: Dict[str, TaskState] = dict()
self.waiting_for_data_count = 0
self.has_what = defaultdict(set)
self.pending_data_per_worker = defaultdict(deque)
@@ -672,6 +673,10 @@ def __init__(

self.low_level_profiler = low_level_profiler

self.dynamic_annotations: Iterable[Callable[[TaskState, object], None]] = (
dynamic_annotations or ()
)

handlers = {
"gather": self.gather,
"run": self.run,
@@ -1629,7 +1634,7 @@ def add_task(
ts.duration = duration
if resource_restrictions:
ts.resource_restrictions = resource_restrictions
ts.annotations = annotations
ts.annotations = annotations or {}

who_has = who_has or {}

@@ -1902,7 +1907,7 @@ def transition_waiting_ready(self, ts):

self.has_what[self.address].discard(ts.key)

if ts.resource_restrictions is not None:
if ts.resource_restrictions:
self.constrained.append(ts.key)
return "constrained"
else:
@@ -1983,9 +1988,11 @@ def transition_executing_done(self, ts, value=no_value, report=True):
assert ts.key not in self.ready

out = None
if ts.resource_restrictions is not None:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity

for func in self.dynamic_annotations:
func(ts, value)

if ts.state == "executing":
self.executing_count -= 1
@@ -2191,7 +2198,7 @@ def ensure_communicating(self):
pdb.set_trace()
raise

def send_task_state_to_scheduler(self, ts):
def send_task_state_to_scheduler(self, ts: TaskState):
if ts.key in self.data or self.actors.get(ts.key):
typ = ts.type
if ts.nbytes is None or typ is None:
@@ -2217,6 +2224,8 @@ def send_task_state_to_scheduler(self, ts):
"type": typ_serialized,
"typename": typename(typ),
"metadata": ts.metadata,
"annotations": ts.annotations,
"resource_restrictions": ts.resource_restrictions,
}
elif ts.exception is not None:
d = {
@@ -2792,12 +2801,9 @@ def actor_attribute(self, comm=None, actor=None, attribute=None):

def meets_resource_constraints(self, key):
ts = self.tasks[key]
if not ts.resource_restrictions:
return True
for resource, needed in ts.resource_restrictions.items():
if self.available_resources[resource] < needed:
return False

return True

async def _maybe_deserialize_task(self, ts):
@@ -2884,10 +2890,7 @@ async def execute(self, key):

args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)

if ts.annotations is not None and "executor" in ts.annotations:
executor = ts.annotations["executor"]
else:
executor = "default"
executor = ts.annotations.get("executor", "default")
assert executor in self.executors
assert key == ts.key
self.active_keys.add(ts.key)
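
The worker.py changes above call each registered hook in transition_executing_done with the finishing task's TaskState and its value, so the annotations and resource restrictions that send_task_state_to_scheduler reports include any updates the hooks make. Because resource restrictions are propagated to dependents the same way as annotations, a hook could in principle impose them dynamically as well; a hypothetical sketch, assuming a worker started with resources={"GPU": 1}:

# Hypothetical hook, not part of this commit: tasks that consume a GPU-backed
# result would then compete for the worker's "GPU" resource.
def restrict_gpu_consumers(ts, value):
    if type(value).__module__.split(".")[0] == "cupy":
        ts.resource_restrictions["GPU"] = 1

# worker = await Worker(scheduler_address, resources={"GPU": 1},
#                       dynamic_annotations=[restrict_gpu_consumers])
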
