diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 18f1cc50b7a..1165d420a87 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -343,28 +343,29 @@ def func(x): pass -def assert_balanced(inp, out, c, s, *workers): +def assert_balanced(inp, expected, c, s, *workers): steal = s.extensions['stealing'] steal._pc.callback_time = 1000000000 counter = itertools.count() B = BANDWIDTH tasks = list(concat(inp)) - data = yield c._scatter(range(len(tasks))) - - for t, f in zip(tasks, data): - s.nbytes[f.key] = BANDWIDTH * t - s.task_duration[str(int(t))] = 1 + data_seq = itertools.count() futures = [] - data_seq = iter(data) for w, ts in zip(workers, inp): for t in ts: - dat = next(data_seq) if t else 123 + if t: + [dat] = yield c._scatter([next(data_seq)], workers=w.address) + s.nbytes[dat.key] = BANDWIDTH * t + else: + dat = 123 + s.task_duration[str(int(t))] = 1 f = c.submit(func, dat, key='%d-%d' % (int(t), next(counter)), - workers=w.address, allow_other_workers=True) + workers=w.address, allow_other_workers=True, + pure=False) futures.append(f) - while not any(s.processing.values()): + while len(s.rprocessing) < len(futures): yield gen.sleep(0.001) s.extensions['stealing'].balance() @@ -374,15 +375,15 @@ def assert_balanced(inp, out, c, s, *workers): for w in workers] result2 = sorted(result, reverse=True) - out2 = sorted(out, reverse=True) + expected2 = sorted(expected, reverse=True) - if result2 != out2: + if result2 != expected2: import pdb; pdb.set_trace() - assert result2 == out2 + assert result2 == expected2 -@pytest.mark.parametrize('inp,out', [ +@pytest.mark.parametrize('inp,expected', [ ([[1], []], # don't move unnecessarily [[1], []]), @@ -420,12 +421,12 @@ def assert_balanced(inp, out, c, s, *workers): [[0, 0], [0, 0], [0], [0]]), ([[4, 2, 2, 2, 2, 1, 1], - [4, 2, 1, 1, 1], + [4, 2, 1, 1, 1, 1], [], [], []], [[4, 2, 2, 2, 2], - [4, 2, 1], + [4, 2, 1, 1], [1, 1], [1], [1]]), @@ -437,8 +438,8 @@ def assert_balanced(inp, out, c, s, *workers): [1, 1], [1, 1], [1, 1], [1, 1, 1]]) ]) -def test_balance(inp, out): - test = lambda *args, **kwargs: assert_balanced(inp, out, *args, **kwargs) +def test_balance(inp, expected): + test = lambda *args, **kwargs: assert_balanced(inp, expected, *args, **kwargs) test = gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * len(inp))(test) test()