From 51755d3157bc945c41d0b128901f6c8ce0850504 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Mon, 26 Jul 2021 15:58:10 +0200
Subject: [PATCH] Fix waiting_for_data_count bookkeeping and handle replica
 removal while a dependency is in flight

---
 distributed/tests/test_worker.py | 39 ++++++++++++++++++++++++++++++++
 distributed/worker.py            | 31 ++++++++++++++++++++-----------
 2 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 2b40599f1fc..41c4eb14c5c 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2665,3 +2665,42 @@ def reduce(*args, **kwargs):
 
     while any(w.tasks for w in workers):
         await asyncio.sleep(0.001)
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_who_has_consistent_remove_replica(c, s, *workers):
+
+    a = workers[0]
+    other_workers = {w for w in workers if w != a}
+    f1 = c.submit(inc, 1, key="f1", workers=[w.address for w in other_workers])
+    await wait(f1)
+    for w in other_workers:
+        _acquire_replica(s, w, f1)
+
+    while len(s.tasks[f1.key].who_has) != len(other_workers):
+        await asyncio.sleep(0)
+
+    f2 = c.submit(inc, f1, workers=[a.address])
+
+    # Wait until the worker has received the task and scheduled its dependency
+    # to be fetched, then remove the replica from the worker it is fetching
+    # the data from. Ensure this is handled gracefully and that no suspicious
+    # counters are incremented, since this is expected behaviour when
+    # replicas are removed.
+
+    while f1.key not in a.tasks or a.tasks[f1.key].state != "flight":
+        await asyncio.sleep(0)
+
+    coming_from = None
+    for w in other_workers:
+        if w.address == a.tasks[f1.key].coming_from:
+            coming_from = w
+            break
+
+    coming_from.handle_remove_replicas([f1.key], "test")
+
+    await f2
+
+    assert ("missing-dep", f1.key) in a.story(f1.key)
+    assert a.tasks[f1.key].suspicious_count == 0
+    assert s.tasks[f1.key].suspicious == 0
diff --git a/distributed/worker.py b/distributed/worker.py
index 1b89121bbee..5a8ca5e50c5 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1747,6 +1747,8 @@ def transition_released_waiting(self, ts, *, stimulus_id):
                 recommendations[ts] = "ready"
             else:
                 recommendations[ts] = "constrained"
+        else:
+            self.waiting_for_data_count += 1
         ts.state = "waiting"
         return recommendations, []
 
@@ -2078,8 +2080,14 @@ def transitions(self, recommendations: dict, stimulus_id):
         """
         s_msgs = []
         self._transitions(recommendations, s_msgs, stimulus_id)
-        for msg in s_msgs:
-            self.batched_stream.send(msg)
+        if not self.batched_stream.closed():
+            for msg in s_msgs:
+                self.batched_stream.send(msg)
+        else:
+            logger.debug(
+                "BatchedSend closed while transitioning tasks. %s tasks not sent.",
+                len(s_msgs),
+            )
 
     def maybe_transition_long_running(self, ts, stimulus_id, compute_duration=None):
         if ts.state == "executing":
@@ -2255,12 +2263,9 @@ def _put_key_in_memory(self, ts, value, stimulus_id):
         ts.type = type(value)
 
         for dep in ts.dependents:
-            try:
-                dep.waiting_for_data.remove(ts)
-                self.waiting_for_data_count -= 1
-            except KeyError:
-                pass
+            dep.waiting_for_data.discard(ts)
             if not dep.waiting_for_data:
+                self.waiting_for_data_count -= 1
                 recommendations[dep] = "ready"
 
         self.log.append((ts.key, "put-in-memory", stimulus_id, time()))
@@ -2460,7 +2465,8 @@ async def gather_dep(
                     )
                     recommendations[ts] = "fetch"
                 else:
-                    logger.debug(
-                        "Unexpected task state encountered for %s after gather_dep"
-                    )
+                    logger.warning(
+                        "Unexpected task state encountered for %s after gather_dep",
+                        ts.key,
+                    )
             del data, response
@@ -3366,6 +3372,7 @@ def validate_state(self):
             return
         try:
             assert self.executing_count >= 0
+            waiting_for_data_count = 0
             for ts in self.tasks.values():
                 assert ts.state is not None
                 # check that worker has task
@@ -3380,6 +3387,8 @@
                     # Might need better bookkeeping
                     assert dep.state is not None
                     assert ts in dep.dependents, ts
+                if ts.waiting_for_data:
+                    waiting_for_data_count += 1
                 for ts_wait in ts.waiting_for_data:
                     assert ts_wait.key in self.tasks
                     assert (
@@ -3391,7 +3400,7 @@
                 assert isinstance(ts.nbytes, int)
                 assert not ts.waiting_for_data
                 assert ts.key in self.data or ts.key in self.actors
-
+            assert self.waiting_for_data_count == waiting_for_data_count
             for worker, keys in self.has_what.items():
                 for k in keys:
                     assert worker in self.tasks[k].who_has
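
The invariant the worker.py changes enforce is that waiting_for_data_count always
equals the number of tasks whose waiting_for_data set is non-empty:
transition_released_waiting now increments the counter once per waiting task
rather than once per missing dependency, _put_key_in_memory decrements it only
when the last missing dependency arrives, and the new validate_state assertion
checks the equality directly. Below is a minimal, self-contained sketch of that
bookkeeping; TaskState, Worker, and the method names here are simplified
stand-ins for illustration, not dask's actual classes.

class TaskState:
    def __init__(self, key, dependencies=()):
        self.key = key
        # Keys of dependencies whose data has not arrived yet.
        self.waiting_for_data = set(dependencies)


class Worker:
    def __init__(self):
        self.tasks = {}
        self.waiting_for_data_count = 0

    def transition_to_waiting(self, ts):
        # Mirrors transition_released_waiting: count the *task* once when it
        # starts waiting on data, not once per missing dependency.
        self.tasks[ts.key] = ts
        if ts.waiting_for_data:
            self.waiting_for_data_count += 1

    def dependency_arrived(self, ts, dep_key):
        # Mirrors _put_key_in_memory: decrement only when the last missing
        # dependency arrives. The membership guard is a simplification of the
        # state transitions that prevent duplicate notifications in the real
        # worker; it keeps a repeated arrival from driving the counter
        # negative.
        if dep_key in ts.waiting_for_data:
            ts.waiting_for_data.discard(dep_key)
            if not ts.waiting_for_data:
                self.waiting_for_data_count -= 1

    def validate_state(self):
        # The new assertion in Worker.validate_state, in miniature.
        assert self.waiting_for_data_count == sum(
            1 for ts in self.tasks.values() if ts.waiting_for_data
        )


w = Worker()
f2 = TaskState("f2", dependencies=["f1"])
w.transition_to_waiting(f2)
w.validate_state()

w.dependency_arrived(f2, "f1")
w.dependency_arrived(f2, "f1")  # duplicate arrival: counter stays consistent
w.validate_state()
assert w.waiting_for_data_count == 0

Keeping the counter in lockstep with the emptiness of each waiting_for_data set
is what lets a replica disappear while a dependency is in flight, be re-fetched
from elsewhere, and still leave the counter consistent, which is the scenario
the new test exercises.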