Warn unreachable for scheduler.py (#6611)
fjetter authored Jun 23, 2022
1 parent 1c48633 commit f0680c9
Showing 2 changed files with 23 additions and 32 deletions.
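Context for the diff below: the setup.cfg change at the end of this commit deletes the `warn_unreachable = false` override under `[mypy-distributed.scheduler]`, so the `warn_unreachable = true` setting visible in that hunk's context now applies to distributed/scheduler.py as well. Most of the scheduler changes exist to satisfy that check: results of `.get()` calls were previously annotated with non-optional types and silenced with `# type: ignore`, which made the following `is None` branches look unreachable to mypy. The snippet below is a minimal sketch of that pattern and its fix, not code from the commit; `host_info` and the `lookup_*` names are illustrative only.

from __future__ import annotations

# Minimal sketch, assuming mypy runs with warn_unreachable = true.
host_info: dict[str, dict] = {}


def lookup_before(h: str) -> dict:
    # The "dict" annotation plus "# type: ignore" tells mypy that dh can
    # never be None, so the body of the "if dh is None" branch is reported
    # as unreachable once warn_unreachable is enabled.
    dh: dict = host_info.get(h)  # type: ignore
    if dh is None:
        host_info[h] = dh = {}
    return dh


def lookup_after(h: str) -> dict:
    # Without the annotation, dh is inferred as "dict | None"; the None
    # check is reachable, and dh is narrowed to "dict" before it is returned.
    dh = host_info.get(h)
    if dh is None:
        host_info[h] = dh = {}
    return dh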
54 changes: 23 additions & 31 deletions distributed/scheduler.py
@@ -1385,18 +1385,17 @@ def new_task(
self, key: str, spec: object, state: str, computation: Computation = None
) -> TaskState:
"""Create a new task, and associated states"""
- ts: TaskState = TaskState(key, spec)
+ ts = TaskState(key, spec)
ts._state = state

- tp: TaskPrefix
prefix_key = key_split(key)
- tp = self.task_prefixes.get(prefix_key) # type: ignore
+ tp = self.task_prefixes.get(prefix_key)
if tp is None:
self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key)
ts.prefix = tp

group_key = ts.group_key
- tg: TaskGroup = self.task_groups.get(group_key) # type: ignore
+ tg = self.task_groups.get(group_key)
if tg is None:
self.task_groups[group_key] = tg = TaskGroup(group_key)
if computation:
@@ -1432,7 +1431,7 @@ def _transition(
Scheduler.transitions : transitive version of this function
"""
try:
- ts: TaskState = self.tasks.get(key) # type: ignore
+ ts = self.tasks.get(key)
if ts is None:
return {}, {}, {}
start = ts._state
@@ -1574,12 +1573,7 @@ def _transitions(
"""
keys: set = set()
recommendations = recommendations.copy()
- msgs: list
- new_msgs: list
- new: tuple
- new_recs: dict
- new_cmsgs: dict
- new_wmsgs: dict

while recommendations:
key, finish = recommendations.popitem()
keys.add(key)
@@ -1589,13 +1583,13 @@

recommendations.update(new_recs)
for c, new_msgs in new_cmsgs.items():
- msgs = client_msgs.get(c) # type: ignore
+ msgs = client_msgs.get(c)
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs
for w, new_msgs in new_wmsgs.items():
- msgs = worker_msgs.get(w) # type: ignore
+ msgs = worker_msgs.get(w)
if msgs is not None:
msgs.extend(new_msgs)
else:
@@ -1795,6 +1789,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
(self.idle or self.workers).values(),
key=partial(self.worker_objective, ts),
)
+ assert ws
tg.last_worker_tasks_left = math.floor(
(len(tg) / self.total_nthreads) * ws.nthreads
)
@@ -1820,6 +1815,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
n_workers: int = len(wp_vals)
if n_workers < 20: # smart but linear in small case
ws = min(wp_vals, key=operator.attrgetter("occupancy"))
+ assert ws
if ws.occupancy == 0:
# special case to use round-robin; linear search
# for next worker with zero occupancy (or just
@@ -1942,13 +1938,11 @@ def transition_processing_memory(
startstops=None,
**kwargs,
):
- ws: WorkerState
- wws: WorkerState
recommendations: dict = {}
client_msgs: dict = {}
worker_msgs: dict = {}
try:
- ts: TaskState = self.tasks[key]
+ ts = self.tasks[key]

assert worker
assert isinstance(worker, str)
@@ -2615,13 +2609,13 @@ def get_task_duration(self, ts: TaskState) -> float:
if duration >= 0:
return duration

- s: set = self.unknown_durations.get(ts.prefix.name) # type: ignore
+ s = self.unknown_durations.get(ts.prefix.name)
if s is None:
self.unknown_durations[ts.prefix.name] = s = set()
s.add(ts)
return self.UNKNOWN_TASK_DURATION

- def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
+ def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
"""Return set of currently valid workers for key
If all workers are valid then this returns ``None``.
@@ -2631,7 +2625,7 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
* host_restrictions
* resource_restrictions
"""
- s: set = None # type: ignore
+ s: set | None = None

if ts.worker_restrictions:
s = {addr for addr in ts.worker_restrictions if addr in self.workers}
@@ -2643,7 +2637,7 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
# XXX need HostState?
sl: list = []
for h in hr:
- dh: dict = self.host_info.get(h) # type: ignore
+ dh = self.host_info.get(h)
if dh is not None:
sl.append(dh["addresses"])

@@ -2654,9 +2648,9 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
s |= ss

if ts.resource_restrictions:
- dw: dict = {}
+ dw = {}
for resource, required in ts.resource_restrictions.items():
- dr: dict = self.resources.get(resource) # type: ignore
+ dr = self.resources.get(resource)
if dr is None:
self.resources[resource] = dr = {}

@@ -2775,10 +2769,9 @@ def bulk_schedule_after_adding_worker(self, ws: WorkerState):
immediately, without waiting for the batch to end, we can't rely on worker-side
ordering, so the recommendations are sorted by priority order here.
"""
- ts: TaskState
tasks = []
for ts in self.unrunnable:
- valid: set = self.valid_workers(ts)
+ valid = self.valid_workers(ts)
if valid is None or ws in valid:
tasks.append(ts)
# These recommendations will generate {"op": "compute-task"} messages
@@ -3648,11 +3641,11 @@ async def add_worker(
if ws.status == Status.running:
self.running.add(ws)

- dh: dict = self.host_info.get(host) # type: ignore
+ dh = self.host_info.get(host)
if dh is None:
self.host_info[host] = dh = {}

- dh_addresses: set = dh.get("addresses") # type: ignore
+ dh_addresses = dh.get("addresses")
if dh_addresses is None:
dh["addresses"] = dh_addresses = set()
dh["nthreads"] = 0
@@ -5794,7 +5787,7 @@ def workers_to_close(
comm=None,
memory_ratio: int | float | None = None,
n: int | None = None,
- key: Callable[[WorkerState], Hashable] | None = None,
+ key: Callable[[WorkerState], Hashable] | bytes | None = None,
minimum: int | None = None,
target: int | None = None,
attribute: str = "address",
@@ -6148,7 +6141,7 @@ def update_data(
logger.debug("Update data %s", who_has)

for key, workers in who_has.items():
- ts: TaskState = self.tasks.get(key) # type: ignore
+ ts = self.tasks.get(key)
if ts is None:
ts = self.new_task(key, None, "memory")
ts.state = "memory"
@@ -6157,7 +6150,7 @@
ts.set_nbytes(ts_nbytes)

for w in workers:
- ws: WorkerState = self.workers[w]
+ ws = self.workers[w]
if ws not in ts.who_has:
self.add_replica(ts, ws)
self.report({"op": "key-in-memory", "key": key, "workers": list(workers)})
@@ -6172,7 +6165,6 @@ def report_on_key(self, key: str = None, ts: TaskState = None, client: str = Non
key = ts.key
else:
assert False, (key, ts)
- return

if ts is not None:
report_msg = _task_to_report_msg(ts)
@@ -6602,7 +6594,7 @@ def add_resources(self, worker: str, resources=None):
ws.used_resources = {}
for resource, quantity in ws.resources.items():
ws.used_resources[resource] = 0
- dr: dict = self.resources.get(resource, None)
+ dr = self.resources.get(resource, None)
if dr is None:
self.resources[resource] = dr = {}
dr[worker] = quantity
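A second theme in the scheduler diff above is moving types out of comments and into real annotations: `valid_workers` now returns `set[WorkerState] | None` instead of `set` with the optional part in a trailing comment, and its local sentinel becomes `s: set | None = None`. The sketch below shows that annotation style in isolation; the function body is simplified and is not the scheduler's actual selection logic.

from __future__ import annotations


class WorkerState:
    def __init__(self, address: str) -> None:
        self.address = address


def valid_workers_sketch(
    restrictions: set[str], workers: dict[str, WorkerState]
) -> set[WorkerState] | None:
    # None is the sentinel for "no restrictions: every worker is valid".
    # Spelling the optional type in the signature (rather than a comment)
    # lets mypy check both the None case and the set case at every call site.
    s: set[WorkerState] | None = None
    if restrictions:
        s = {ws for addr, ws in workers.items() if addr in restrictions}
    return s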
1 change: 0 additions & 1 deletion setup.cfg
@@ -123,7 +123,6 @@ warn_unreachable = true
allow_incomplete_defs = true
[mypy-distributed.scheduler]
allow_incomplete_defs = true
- warn_unreachable = false
[mypy-distributed.worker]
allow_incomplete_defs = true
