Include estimated communication time in per-task occupancy.

The scheduler now adds, to each processing task's expected duration, the
bytes of its dependencies not yet held by the assigned worker divided by
BANDWIDTH.  Work stealing reads that stored per-task estimate instead of
the bare task_duration value, and recomputes it for the thief on a move.

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index b0f735a1dc4..d5aa1f98be2 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1876,10 +1876,14 @@ def transition_waiting_processing(self, key):
                 self.unknown_durations[ks].add(key)
                 duration = 0.5
 
-            self.processing[worker][key] = duration
+            comm = (sum(self.nbytes[dep]
+                        for dep in self.dependencies[key] - self.has_what[worker])
+                    / BANDWIDTH)
+
+            self.processing[worker][key] = duration + comm
             self.rprocessing[key] = worker
-            self.occupancy[worker] += duration
-            self.total_occupancy += duration
+            self.occupancy[worker] += duration + comm
+            self.total_occupancy += duration + comm
             self.task_state[key] = 'processing'
             self.consume_resources(key, worker)
             self.check_idle_saturated(worker)
@@ -1947,9 +1951,12 @@ def transition_processing_memory(self, key, nbytes=None, type=None,
                     if k in self.rprocessing:
                         w = self.rprocessing[k]
                         old = self.processing[w][k]
-                        self.processing[w][k] = avg_duration
-                        self.occupancy[w] += avg_duration - old
-                        self.total_occupancy += avg_duration - old
+                        comm = (sum(self.nbytes[d] for d in
+                                    self.dependencies[k] - self.has_what[w])
+                                / BANDWIDTH)
+                        self.processing[w][k] = avg_duration + comm
+                        self.occupancy[w] += avg_duration + comm - old
+                        self.total_occupancy += avg_duration + comm - old
 
                 info['last-task'] = compute_stop
 
diff --git a/distributed/stealing.py b/distributed/stealing.py
index ecb411c04a5..d421c82e9de 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -122,7 +122,8 @@ def steal_time_ratio(self, key, split=None):
         if split in fast_tasks:
             return None, None
         try:
-            compute_time = self.scheduler.task_duration[split]
+            worker = self.scheduler.rprocessing[key]
+            compute_time = self.scheduler.processing[worker][key]
         except KeyError:
             self.stealable_unknown_durations[split].add(key)
             return None, None
@@ -153,6 +154,9 @@ def move_task(self, key, victim, thief):
             self.scheduler.total_occupancy -= duration
 
             duration = self.scheduler.task_duration.get(key_split(key), 0.5)
+            duration += sum(self.scheduler.nbytes[dep] for dep in
+                            self.scheduler.dependencies[key] -
+                            self.scheduler.has_what[thief]) / BANDWIDTH
             self.scheduler.processing[thief][key] = duration
             self.scheduler.rprocessing[key] = thief
             self.scheduler.occupancy[thief] += duration
@@ -210,7 +214,7 @@ def balance(self):
                     for key in list(stealable):
                         i += 1
                         idl = idle[i % len(idle)]
-                        duration = s.task_duration.get(key_split(key), 0.5)
+                        duration = s.processing[sat][key]
 
                         if (occupancy[idl] + cost_multiplier * duration
                           <= occupancy[sat] - duration / 2):
@@ -232,7 +236,7 @@ def balance(self):
                             continue
                         i += 1
                         idl = idle[i % len(idle)]
-                        duration = s.task_duration.get(key_split(key), 0.5)
+                        duration = s.processing[sat][key]
 
                         if (occupancy[idl] + cost_multiplier * duration
                           <= occupancy[sat] - duration / 2):
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index ca96d0919c3..ab851f6d6d3 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -4,7 +4,7 @@
 from collections import defaultdict, deque
 from copy import deepcopy
 from datetime import timedelta
-from operator import add
+from operator import add, mul
 import sys
 from time import sleep
 
@@ -21,13 +21,13 @@
 
 from distributed import Nanny, Worker
 from distributed.core import connect, read, write, close, rpc
-from distributed.scheduler import validate_state, Scheduler
+from distributed.scheduler import validate_state, Scheduler, BANDWIDTH
 from distributed.client import _wait, _first_completed
 from distributed.metrics import time
 from distributed.protocol.pickle import dumps
 from distributed.worker import dumps_function, dumps_task
 from distributed.utils_test import (inc, ignoring, dec, gen_cluster, gen_test,
-        loop, readone, slowinc)
+        loop, readone, slowinc, slowadd)
 from distributed.utils import All
 from distributed.utils_test import slow
 from dask.compatibility import apply
@@ -855,3 +855,16 @@ def test_learn_occupancy_multiple_workers(c, s, a, b):
 
     assert not any(v == 0.5 for vv in s.processing.values() for v in vv)
     s.validate_state()
+
+
+@gen_cluster(client=True)
+def test_include_communication_in_occupancy(c, s, a, b):
+    s.task_duration['slowadd'] = 0.001
+    x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address)
+    y = c.submit(mul, b'1', int(BANDWIDTH * 1.5), workers=b.address)
+
+    z = c.submit(slowadd, x, y, delay=1)
+    while z.key not in s.rprocessing:
+        yield gen.sleep(0.01)
+
+    assert s.processing[b.address][z.key] > 1
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 74e956fad17..18f1cc50b7a 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -458,3 +458,21 @@ def test_restart(c, s, a, b):
 
     assert not any(x for x in steal.stealable_all)
     assert not any(x for L in steal.stealable.values() for x in L)
+
+
+@gen_cluster(client=True)
+def test_steal_communication_heavy_tasks(c, s, a, b):
+    s.task_duration['slowadd'] = 0.001
+    x = c.submit(mul, b'0', int(BANDWIDTH), workers=a.address)
+    y = c.submit(mul, b'1', int(BANDWIDTH), workers=b.address)
+
+    futures = [c.submit(slowadd, x, y, delay=1, pure=False, workers=a.address,
+                        allow_other_workers=True)
+               for i in range(10)]
+
+    while not any(f.key in s.rprocessing for f in futures):
+        yield gen.sleep(0.01)
+
+    s.extensions['stealing'].balance()
+
+    assert s.processing[b.address]