Skip to content

Commit

Permalink
Overhaul update_who_has
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 13, 2022
1 parent 50d2911 commit c39288a
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 31 deletions.
1 change: 1 addition & 0 deletions distributed/tests/test_stories.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ async def test_worker_story_with_deps(c, s, a, b):
assert stimulus_ids == {"compute-task"}
expected = [
("dep", "ensure-task-exists", "released"),
("dep", "update-who-has", [], [a.address]),
("dep", "released", "fetch", "fetch", {}),
("gather-dependencies", a.address, {"dep"}),
("dep", "fetch", "flight", "flight", {}),
Expand Down
1 change: 1 addition & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2814,6 +2814,7 @@ def __getstate__(self):
story[b],
[
("x", "ensure-task-exists", "released"),
("x", "update-who-has", [], [a]),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", a, {"x"}),
("x", "fetch", "flight", "flight", {}),
Expand Down
108 changes: 77 additions & 31 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,8 +1873,12 @@ def handle_acquire_replicas(
if ts.state != "memory":
recommendations[ts] = "fetch"

self.update_who_has(who_has)
recommendations, instructions = merge_recs_instructions(
(recommendations, []),
self._update_who_has(who_has, stimulus_id=stimulus_id),
)
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if self.validate:
for key in keys:
Expand Down Expand Up @@ -1977,7 +1981,10 @@ def handle_compute_task(
for dep_key, value in nbytes.items():
self.tasks[dep_key].nbytes = value

self.update_who_has(who_has)
recommendations, instructions = merge_recs_instructions(
(recommendations, instructions),
self._update_who_has(who_has, stimulus_id=stimulus_id),
)
else: # pragma: nocover
raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")

Expand Down Expand Up @@ -3136,8 +3143,14 @@ def _select_keys_for_gather(

while tasks:
ts = tasks.peek()
if ts.state != "fetch" or ts.key in all_keys_to_gather:
if (
ts.state != "fetch"
# Do not acquire the same key twice if multiple workers hold replicas
or ts.key in all_keys_to_gather
# A replica is still available (otherwise status would not be 'fetch'
# anymore), but not on this worker. See _update_who_has().
or worker not in ts.who_has
):
tasks.pop()
continue
if total_bytes + ts.get_nbytes() > self.target_message_size:
Expand Down Expand Up @@ -3422,7 +3435,11 @@ def done_event():
who_has = await retry_operation(
self.scheduler.who_has, keys=refresh_who_has
)
self.update_who_has(who_has)
recommendations, instructions = self._update_who_has(
who_has, stimulus_id=stimulus_id
)
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

@log_errors
def _readd_busy_worker(self, worker: str) -> None:
Expand All @@ -3445,47 +3462,76 @@ async def find_missing(self) -> None:
self.scheduler.who_has,
keys=[ts.key for ts in self._missing_dep_flight],
)
self.update_who_has(who_has)
recommendations: Recs = {}
recommendations, instructions = self._update_who_has(
who_has, stimulus_id=stimulus_id
)
for ts in self._missing_dep_flight:
if ts.who_has:
assert ts not in recommendations
recommendations[ts] = "fetch"
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

finally:
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
].callback_time = self.periodic_callbacks["heartbeat"].callback_time

def update_who_has(self, who_has: dict[str, Collection[str]]) -> None:
try:
for dep, workers in who_has.items():
if not workers:
continue
def _update_who_has(
self, who_has: dict[str, Collection[str]], *, stimulus_id: str
) -> RecsInstrs:
recs: Recs = {}
instructions: Instructions = []

if dep in self.tasks:
dep_ts = self.tasks[dep]
if self.address in workers and self.tasks[dep].state != "memory":
logger.debug(
"Scheduler claims worker %s holds data for task %s which is not true.",
self.name,
dep,
)
# Do not mutate the input dict. That's rude
workers = set(workers) - {self.address}
dep_ts.who_has.update(workers)
for key, workers in who_has.items():
ts = self.tasks.get(key)
if not ts:
continue
workers = set(workers)

for worker in workers:
self.has_what[worker].add(dep)
self.data_needed_per_worker[worker].push(dep_ts)
except Exception as e: # pragma: no cover
logger.exception(e)
if LOG_PDB:
import pdb
if self.address in workers and ts.state != "memory":
logger.debug(
"Scheduler claims worker %s holds data for task %s which is not true.",
self.name,
ts,
)
workers.remove(self.address)
instructions.append(
MissingDataMsg(
key=key, errant_worker=self.address, stimulus_id=stimulus_id
)
)

pdb.set_trace()
raise
if ts.who_has == workers:
continue

self.log.append(
(
key,
"update-who-has",
list(ts.who_has),
list(workers),
stimulus_id,
time(),
)
)

for worker in ts.who_has - workers:
self.has_what[worker].discard(key)
# Can't remove from self.data_needed_per_worker; there is logic
# in _select_keys_for_gather to deal with this

for worker in workers - ts.who_has:
self.has_what[worker].add(key)
if ts.state == "fetch":
self.data_needed_per_worker[worker].push(ts)

ts.who_has = workers
if not workers and ts.state == "fetch":
recs[ts] = "missing"

return recs, instructions

def handle_steal_request(self, key: str, stimulus_id: str) -> None:
# There may be a race condition between stealing and releasing a task.
Expand Down

0 comments on commit c39288a

Please sign in to comment.