Skip to content

Commit

Permalink
update stealing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Dec 31, 2016
1 parent f49a2d7 commit e53ac96
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,28 +343,29 @@ def func(x):
pass


def assert_balanced(inp, out, c, s, *workers):
def assert_balanced(inp, expected, c, s, *workers):
steal = s.extensions['stealing']
steal._pc.callback_time = 1000000000
counter = itertools.count()
B = BANDWIDTH
tasks = list(concat(inp))
data = yield c._scatter(range(len(tasks)))

for t, f in zip(tasks, data):
s.nbytes[f.key] = BANDWIDTH * t
s.task_duration[str(int(t))] = 1
data_seq = itertools.count()

futures = []
data_seq = iter(data)
for w, ts in zip(workers, inp):
for t in ts:
dat = next(data_seq) if t else 123
if t:
[dat] = yield c._scatter([next(data_seq)], workers=w.address)
s.nbytes[dat.key] = BANDWIDTH * t
else:
dat = 123
s.task_duration[str(int(t))] = 1
f = c.submit(func, dat, key='%d-%d' % (int(t), next(counter)),
workers=w.address, allow_other_workers=True)
workers=w.address, allow_other_workers=True,
pure=False)
futures.append(f)

while not any(s.processing.values()):
while len(s.rprocessing) < len(futures):
yield gen.sleep(0.001)

s.extensions['stealing'].balance()
Expand All @@ -374,15 +375,15 @@ def assert_balanced(inp, out, c, s, *workers):
for w in workers]

result2 = sorted(result, reverse=True)
out2 = sorted(out, reverse=True)
expected2 = sorted(expected, reverse=True)

if result2 != out2:
if result2 != expected2:
import pdb; pdb.set_trace()

assert result2 == out2
assert result2 == expected2


@pytest.mark.parametrize('inp,out', [
@pytest.mark.parametrize('inp,expected', [
([[1], []], # don't move unnecessarily
[[1], []]),
Expand Down Expand Up @@ -420,12 +421,12 @@ def assert_balanced(inp, out, c, s, *workers):
[[0, 0], [0, 0], [0], [0]]),
([[4, 2, 2, 2, 2, 1, 1],
[4, 2, 1, 1, 1],
[4, 2, 1, 1, 1, 1],
[],
[],
[]],
[[4, 2, 2, 2, 2],
[4, 2, 1],
[4, 2, 1, 1],
[1, 1],
[1],
[1]]),
Expand All @@ -437,8 +438,8 @@ def assert_balanced(inp, out, c, s, *workers):
[1, 1], [1, 1], [1, 1],
[1, 1, 1]])
])
def test_balance(inp, out):
test = lambda *args, **kwargs: assert_balanced(inp, out, *args, **kwargs)
def test_balance(inp, expected):
test = lambda *args, **kwargs: assert_balanced(inp, expected, *args, **kwargs)
test = gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * len(inp))(test)
test()

Expand Down

0 comments on commit e53ac96

Please sign in to comment.