Skip to content

Commit

Permalink
Support stealing of tasks that require heavy communication
Browse files Browse the repository at this point in the history
Previously we would assume that the worker that was assigned the task has most
of the data already.  This assumption is faulty for shuffle-like tasks where no
one has most of the data.  Now we properly penalize workers for data that they
don't own.
  • Loading branch information
mrocklin committed Dec 31, 2016
1 parent 17d5352 commit f49a2d7
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 12 deletions.
19 changes: 13 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,10 +1876,14 @@ def transition_waiting_processing(self, key):
self.unknown_durations[ks].add(key)
duration = 0.5

self.processing[worker][key] = duration
comm = (sum(self.nbytes[dep]
for dep in self.dependencies[key] - self.has_what[worker])
/ BANDWIDTH)

self.processing[worker][key] = duration + comm
self.rprocessing[key] = worker
self.occupancy[worker] += duration
self.total_occupancy += duration
self.occupancy[worker] += duration + comm
self.total_occupancy += duration + comm
self.task_state[key] = 'processing'
self.consume_resources(key, worker)
self.check_idle_saturated(worker)
Expand Down Expand Up @@ -1947,9 +1951,12 @@ def transition_processing_memory(self, key, nbytes=None, type=None,
if k in self.rprocessing:
w = self.rprocessing[k]
old = self.processing[w][k]
self.processing[w][k] = avg_duration
self.occupancy[w] += avg_duration - old
self.total_occupancy += avg_duration - old
comm = (sum(self.nbytes[d] for d in
self.dependencies[k] - self.has_what[w])
/ BANDWIDTH)
self.processing[w][k] = avg_duration + comm
self.occupancy[w] += avg_duration + comm - old
self.total_occupancy += avg_duration + comm - old

info['last-task'] = compute_stop

Expand Down
10 changes: 7 additions & 3 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def steal_time_ratio(self, key, split=None):
if split in fast_tasks:
return None, None
try:
compute_time = self.scheduler.task_duration[split]
worker = self.scheduler.rprocessing[key]
compute_time = self.scheduler.processing[worker][key]
except KeyError:
self.stealable_unknown_durations[split].add(key)
return None, None
Expand Down Expand Up @@ -153,6 +154,9 @@ def move_task(self, key, victim, thief):
self.scheduler.total_occupancy -= duration

duration = self.scheduler.task_duration.get(key_split(key), 0.5)
duration += sum(self.scheduler.nbytes[key] for key in
self.scheduler.dependencies[key] -
self.scheduler.has_what[thief]) / BANDWIDTH
self.scheduler.processing[thief][key] = duration
self.scheduler.rprocessing[key] = thief
self.scheduler.occupancy[thief] += duration
Expand Down Expand Up @@ -210,7 +214,7 @@ def balance(self):
for key in list(stealable):
i += 1
idl = idle[i % len(idle)]
duration = s.task_duration.get(key_split(key), 0.5)
duration = s.processing[sat][key]

if (occupancy[idl] + cost_multiplier * duration
<= occupancy[sat] - duration / 2):
Expand All @@ -232,7 +236,7 @@ def balance(self):
continue
i += 1
idl = idle[i % len(idle)]
duration = s.task_duration.get(key_split(key), 0.5)
duration = s.processing[sat][key]

if (occupancy[idl] + cost_multiplier * duration
<= occupancy[sat] - duration / 2):
Expand Down
19 changes: 16 additions & 3 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict, deque
from copy import deepcopy
from datetime import timedelta
from operator import add
from operator import add, mul
import sys
from time import sleep

Expand All @@ -21,13 +21,13 @@

from distributed import Nanny, Worker
from distributed.core import connect, read, write, close, rpc
from distributed.scheduler import validate_state, Scheduler
from distributed.scheduler import validate_state, Scheduler, BANDWIDTH
from distributed.client import _wait, _first_completed
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.worker import dumps_function, dumps_task
from distributed.utils_test import (inc, ignoring, dec, gen_cluster, gen_test,
loop, readone, slowinc)
loop, readone, slowinc, slowadd)
from distributed.utils import All
from distributed.utils_test import slow
from dask.compatibility import apply
Expand Down Expand Up @@ -855,3 +855,16 @@ def test_learn_occupancy_multiple_workers(c, s, a, b):

assert not any(v == 0.5 for vv in s.processing.values() for v in vv)
s.validate_state()


@gen_cluster(client=True)
def test_include_communication_in_occupancy(c, s, a, b):
s.task_duration['slowadd'] = 0.001
x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address)
y = c.submit(mul, b'1', int(BANDWIDTH * 1.5), workers=b.address)

z = c.submit(slowadd, x, y, delay=1)
while z.key not in s.rprocessing:
yield gen.sleep(0.01)

assert s.processing[b.address][z.key] > 1
18 changes: 18 additions & 0 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,21 @@ def test_restart(c, s, a, b):

assert not any(x for x in steal.stealable_all)
assert not any(x for L in steal.stealable.values() for x in L)


@gen_cluster(client=True)
def test_steal_communication_heavy_tasks(c, s, a, b):
s.task_duration['slowadd'] = 0.001
x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address)
y = c.submit(mul, b'1', int(BANDWIDTH), workers=b.address)

futures = [c.submit(slowadd, x, y, delay=1, pure=False, workers=a.address,
allow_other_workers=True)
for i in range(10)]

while not any(f.key in s.rprocessing for f in futures):
yield gen.sleep(0.01)

s.extensions['stealing'].balance()

assert s.processing[b.address]

0 comments on commit f49a2d7

Please sign in to comment.