From f49a2d71f3706a84594185b9eecec84780cf57dc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 31 Dec 2016 12:54:37 -0500 Subject: [PATCH] Support stealing of tasks that require heavy communication Previously we would assume that the worker that was assigned the task has most of the data already. This assumption is faulty for shuffle-like tasks where no one has most of the data. Now we properly penalize workers for data that they don't own. --- distributed/scheduler.py | 19 +++++++++++++------ distributed/stealing.py | 10 +++++++--- distributed/tests/test_scheduler.py | 19 ++++++++++++++++--- distributed/tests/test_steal.py | 18 ++++++++++++++++++ 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b0f735a1dc4..d5aa1f98be2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1876,10 +1876,14 @@ def transition_waiting_processing(self, key): self.unknown_durations[ks].add(key) duration = 0.5 - self.processing[worker][key] = duration + comm = (sum(self.nbytes[dep] + for dep in self.dependencies[key] - self.has_what[worker]) + / BANDWIDTH) + + self.processing[worker][key] = duration + comm self.rprocessing[key] = worker - self.occupancy[worker] += duration - self.total_occupancy += duration + self.occupancy[worker] += duration + comm + self.total_occupancy += duration + comm self.task_state[key] = 'processing' self.consume_resources(key, worker) self.check_idle_saturated(worker) @@ -1947,9 +1951,12 @@ def transition_processing_memory(self, key, nbytes=None, type=None, if k in self.rprocessing: w = self.rprocessing[k] old = self.processing[w][k] - self.processing[w][k] = avg_duration - self.occupancy[w] += avg_duration - old - self.total_occupancy += avg_duration - old + comm = (sum(self.nbytes[d] for d in + self.dependencies[k] - self.has_what[w]) + / BANDWIDTH) + self.processing[w][k] = avg_duration + comm + self.occupancy[w] += avg_duration + comm - old + self.total_occupancy += avg_duration + comm - old info['last-task'] = compute_stop diff --git a/distributed/stealing.py b/distributed/stealing.py index ecb411c04a5..d421c82e9de 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -122,7 +122,8 @@ def steal_time_ratio(self, key, split=None): if split in fast_tasks: return None, None try: - compute_time = self.scheduler.task_duration[split] + worker = self.scheduler.rprocessing[key] + compute_time = self.scheduler.processing[worker][key] except KeyError: self.stealable_unknown_durations[split].add(key) return None, None @@ -153,6 +154,9 @@ def move_task(self, key, victim, thief): self.scheduler.total_occupancy -= duration duration = self.scheduler.task_duration.get(key_split(key), 0.5) + duration += sum(self.scheduler.nbytes[key] for key in + self.scheduler.dependencies[key] - + self.scheduler.has_what[thief]) / BANDWIDTH self.scheduler.processing[thief][key] = duration self.scheduler.rprocessing[key] = thief self.scheduler.occupancy[thief] += duration @@ -210,7 +214,7 @@ def balance(self): for key in list(stealable): i += 1 idl = idle[i % len(idle)] - duration = s.task_duration.get(key_split(key), 0.5) + duration = s.processing[sat][key] if (occupancy[idl] + cost_multiplier * duration <= occupancy[sat] - duration / 2): @@ -232,7 +236,7 @@ def balance(self): continue i += 1 idl = idle[i % len(idle)] - duration = s.task_duration.get(key_split(key), 0.5) + duration = s.processing[sat][key] if (occupancy[idl] + cost_multiplier * duration <= occupancy[sat] - duration / 2): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index ca96d0919c3..ab851f6d6d3 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4,7 +4,7 @@ from collections import defaultdict, deque from copy import deepcopy from datetime import timedelta -from operator import add +from operator import add, mul import sys from time import sleep @@ -21,13 +21,13 @@ from distributed import Nanny, Worker from distributed.core import connect, read, write, close, rpc -from distributed.scheduler import validate_state, Scheduler +from distributed.scheduler import validate_state, Scheduler, BANDWIDTH from distributed.client import _wait, _first_completed from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.worker import dumps_function, dumps_task from distributed.utils_test import (inc, ignoring, dec, gen_cluster, gen_test, - loop, readone, slowinc) + loop, readone, slowinc, slowadd) from distributed.utils import All from distributed.utils_test import slow from dask.compatibility import apply @@ -855,3 +855,16 @@ def test_learn_occupancy_multiple_workers(c, s, a, b): assert not any(v == 0.5 for vv in s.processing.values() for v in vv) s.validate_state() + + +@gen_cluster(client=True) +def test_include_communication_in_occupancy(c, s, a, b): + s.task_duration['slowadd'] = 0.001 + x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address) + y = c.submit(mul, b'1', int(BANDWIDTH * 1.5), workers=b.address) + + z = c.submit(slowadd, x, y, delay=1) + while z.key not in s.rprocessing: + yield gen.sleep(0.01) + + assert s.processing[b.address][z.key] > 1 diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 74e956fad17..18f1cc50b7a 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -458,3 +458,21 @@ def test_restart(c, s, a, b): assert not any(x for x in steal.stealable_all) assert not any(x for L in steal.stealable.values() for x in L) + + +@gen_cluster(client=True) +def test_steal_communication_heavy_tasks(c, s, a, b): + s.task_duration['slowadd'] = 0.001 + x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address) + y = c.submit(mul, b'1', int(BANDWIDTH), workers=b.address) + + futures = [c.submit(slowadd, x, y, delay=1, pure=False, workers=a.address, + allow_other_workers=True) + for i in range(10)] + + while not any(f.key in s.rprocessing for f in futures): + yield gen.sleep(0.01) + + s.extensions['stealing'].balance() + + assert s.processing[b.address]