diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 190bc7c30c9..4affb556bdf 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -3604,29 +3604,27 @@ async def close_worker(self, comm=None, worker=None, safe=None):
     def heartbeat_worker(
         self,
         comm=None,
-        address=None,
-        resolve_address=True,
-        now=None,
-        resources=None,
-        host_info=None,
-        metrics=None,
-        executing=None,
+        *,
+        address,
+        resolve_address: bool = True,
+        now: float = None,
+        resources: dict = None,
+        host_info: dict = None,
+        metrics: dict,
+        executing: dict = None,
     ):
         parent: SchedulerState = cast(SchedulerState, self)
         address = self.coerce_address(address, resolve_address)
         address = normalize_address(address)
-        if address not in parent._workers:
+        ws: WorkerState = parent._workers.get(address)
+        if ws is None:
             return {"status": "missing"}

         host = get_address_host(address)
         local_now = time()
-        now = now or time()
-        assert metrics
         host_info = host_info or {}

-        dh: dict = parent._host_info.get(host)
-        if dh is None:
-            parent._host_info[host] = dh = dict()
+        dh: dict = parent._host_info.setdefault(host, {})
         dh["last-seen"] = local_now

         frac = 1 / len(parent._workers)
@@ -3650,26 +3648,20 @@ def heartbeat_worker(
                 1 - alpha
             )

-        ws: WorkerState = parent._workers[address]
-
-        ws._last_seen = time()
-
+        ws._last_seen = local_now
         if executing is not None:
             ws._executing = {
                 parent._tasks[key]: duration for key, duration in executing.items()
             }

-        if metrics:
-            ws._metrics = metrics
+        ws._metrics = metrics

         if host_info:
-            dh: dict = parent._host_info.get(host)
-            if dh is None:
-                parent._host_info[host] = dh = dict()
+            dh: dict = parent._host_info.setdefault(host, {})
             dh.update(host_info)

-        delay = time() - now
-        ws._time_delay = delay
+        if now:
+            ws._time_delay = local_now - now

         if resources:
             self.add_resources(worker=address, resources=resources)
@@ -3678,7 +3670,7 @@ def heartbeat_worker(

         return {
             "status": "OK",
-            "time": time(),
+            "time": local_now,
             "heartbeat-interval": heartbeat_interval(len(parent._workers)),
         }

@@ -3756,7 +3748,7 @@ async def add_worker(
         parent._total_nthreads += nthreads
         parent._aliases[name] = address

-        response = self.heartbeat_worker(
+        self.heartbeat_worker(
            address=address,
            resolve_address=resolve_address,
            now=now,
@@ -5331,7 +5323,7 @@ async def rebalance(self, comm=None, keys=None, workers=None):
                map(first, sorted(worker_bytes.items(), key=second, reverse=True))
            )

-            recipients = iter(reversed(sorted_workers))
+            recipients = reversed(sorted_workers)
            recipient = next(recipients)
            msgs = []  # (sender, recipient, key)
            for sender in sorted_workers[: len(workers) // 2]:
@@ -5343,11 +5335,8 @@ async def rebalance(self, comm=None, keys=None, workers=None):
                )

                try:
-                    while worker_bytes[sender] > avg:
-                        while (
-                            worker_bytes[recipient] < avg
-                            and worker_bytes[sender] > avg
-                        ):
+                    while avg < worker_bytes[sender]:
+                        while worker_bytes[recipient] < avg < worker_bytes[sender]:
                            ts, nb = next(sender_keys)
                            if ts not in tasks_by_worker[recipient]:
                                tasks_by_worker[recipient].add(ts)
@@ -5355,7 +5344,7 @@ async def rebalance(self, comm=None, keys=None, workers=None):
                                msgs.append((sender, recipient, ts))
                                worker_bytes[sender] -= nb
                                worker_bytes[recipient] += nb
-                        if worker_bytes[sender] > avg:
+                        if avg < worker_bytes[sender]:
                            recipient = next(recipients)
                except StopIteration:
                    break
@@ -5386,7 +5375,7 @@ async def rebalance(self, comm=None, keys=None, workers=None):
                },
            )

-            if not all(r["status"] == "OK" for r in result):
+            if any(r["status"] != "OK" for r in result):
                return {
"missing-data", "keys": tuple( @@ -5687,7 +5676,7 @@ async def retire_workers( workers: list (optional) List of worker addresses to retire. If not provided we call ``workers_to_close`` which finds a good set - workers_names: list (optional) + names: list (optional) List of worker names to retire. remove: bool (defaults to True) Whether or not to remove the worker metadata immediately or else @@ -5715,30 +5704,31 @@ async def retire_workers( with log_errors(): async with self._lock if lock else empty_context: if names is not None: + if workers is not None: + raise TypeError("names and workers are mutually exclusive") if names: logger.info("Retire worker names %s", names) names = set(map(str, names)) - workers = [ + workers = { ws._address for ws in parent._workers_dv.values() if str(ws._name) in names - ] - if workers is None: + } + elif workers is None: while True: try: workers = self.workers_to_close(**kwargs) - if workers: - workers = await self.retire_workers( - workers=workers, - remove=remove, - close_workers=close_workers, - lock=False, - ) - return workers - else: + if not workers: return {} + return await self.retire_workers( + workers=workers, + remove=remove, + close_workers=close_workers, + lock=False, + ) except KeyError: # keys left during replicate pass + workers = { parent._workers_dv[w] for w in workers if w in parent._workers_dv } @@ -5750,22 +5740,21 @@ async def retire_workers( keys = set.union(*[w.has_what for w in workers]) keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} - other_workers = set(parent._workers_dv.values()) - workers if keys: - if other_workers: - logger.info("Moving %d keys to other workers", len(keys)) - await self.replicate( - keys=keys, - workers=[ws._address for ws in other_workers], - n=1, - delete=False, - lock=False, - ) - else: + other_workers = set(parent._workers_dv.values()) - workers + if not other_workers: return {} + logger.info("Moving %d keys to other workers", len(keys)) + await self.replicate( + keys=keys, + workers=[ws._address for ws in other_workers], + n=1, + delete=False, + lock=False, + ) worker_keys = {ws._address: ws.identity() for ws in workers} - if close_workers and worker_keys: + if close_workers: await asyncio.gather( *[self.close_worker(worker=w, safe=True) for w in worker_keys] ) diff --git a/distributed/worker.py b/distributed/worker.py index 276b41fd853..f448f9730e1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -25,7 +25,7 @@ from dask.utils import format_bytes, funcname from dask.system import CPU_COUNT -from tlz import pluck, merge, first, keymap +from tlz import pluck, first, keymap from tornado import gen from tornado.ioloop import IOLoop, PeriodicCallback @@ -806,8 +806,7 @@ def local_dir(self): return self.local_directory async def get_metrics(self): - now = time() - core = dict( + out = dict( executing=self.executing_count, in_memory=len(self.data), ready=len(self.ready), @@ -818,17 +817,19 @@ async def get_metrics(self): "types": keymap(typename, self.bandwidth_types), }, ) - custom = {} + out.update(self.monitor.recent()) + for k, metric in self.metrics.items(): try: result = metric(self) if isawaitable(result): result = await result - custom[k] = result + # In case of collision, prefer core metrics + out.setdefault(k, result) except Exception: # TODO: log error once pass - return merge(custom, self.monitor.recent(), core) + return out async def get_startup_information(self): result = {} @@ -934,56 +935,57 @@ def _update_latency(self, latency): 
self.digests["latency"].add(latency) async def heartbeat(self): - if not self.heartbeat_active: - self.heartbeat_active = True - logger.debug("Heartbeat: %s" % self.address) - try: - start = time() - response = await retry_operation( - self.scheduler.heartbeat_worker, - address=self.contact_address, - now=time(), - metrics=await self.get_metrics(), - executing={ - key: start - self.tasks[key].start_time - for key in self.active_threads.values() - if key in self.tasks - }, - ) - end = time() - middle = (start + end) / 2 + if self.heartbeat_active: + logger.debug("Heartbeat skipped: channel busy") + return - self._update_latency(end - start) + self.heartbeat_active = True + logger.debug("Heartbeat: %s", self.address) + try: + start = time() + response = await retry_operation( + self.scheduler.heartbeat_worker, + address=self.contact_address, + now=start, + metrics=await self.get_metrics(), + executing={ + key: start - self.tasks[key].start_time + for key in self.active_threads.values() + if key in self.tasks + }, + ) + end = time() + middle = (start + end) / 2 + + self._update_latency(end - start) - if response["status"] == "missing": - for i in range(10): - if self.status != Status.running: - break - else: - await asyncio.sleep(0.05) + if response["status"] == "missing": + for i in range(10): + if self.status != Status.running: + break else: - await self._register_with_scheduler() - return - self.scheduler_delay = response["time"] - middle - self.periodic_callbacks["heartbeat"].callback_time = ( - response["heartbeat-interval"] * 1000 - ) - self.bandwidth_workers.clear() - self.bandwidth_types.clear() - except CommClosedError: - logger.warning("Heartbeat to scheduler failed") - if not self.reconnect: - await self.close(report=False) - except IOError as e: - # Scheduler is gone. Respect distributed.comm.timeouts.connect - if "Timed out trying to connect" in str(e): - await self.close(report=False) + await asyncio.sleep(0.05) else: - raise e - finally: - self.heartbeat_active = False - else: - logger.debug("Heartbeat skipped: channel busy") + await self._register_with_scheduler() + return + self.scheduler_delay = response["time"] - middle + self.periodic_callbacks["heartbeat"].callback_time = ( + response["heartbeat-interval"] * 1000 + ) + self.bandwidth_workers.clear() + self.bandwidth_types.clear() + except CommClosedError: + logger.warning("Heartbeat to scheduler failed") + if not self.reconnect: + await self.close(report=False) + except IOError as e: + # Scheduler is gone. 
+            if "Timed out trying to connect" in str(e):
+                await self.close(report=False)
+            else:
+                raise e
+        finally:
+            self.heartbeat_active = False

    async def handle_scheduler(self, comm):
        try:
@@ -2759,26 +2761,25 @@ async def execute(self, key, report=False):
                    self.transition(ts, "memory", value=value)
                    if self.digests is not None:
                        self.digests["task-duration"].add(result["stop"] - result["start"])
+                elif isinstance(result.pop("actual-exception"), Reschedule):
+                    self.batched_stream.send({"op": "reschedule", "key": ts.key})
+                    self.transition(ts, "rescheduled", report=False)
+                    self.release_key(ts.key, report=False)
                else:
-                    if isinstance(result.pop("actual-exception"), Reschedule):
-                        self.batched_stream.send({"op": "reschedule", "key": ts.key})
-                        self.transition(ts, "rescheduled", report=False)
-                        self.release_key(ts.key, report=False)
-                    else:
-                        ts.exception = result["exception"]
-                        ts.traceback = result["traceback"]
-                        logger.warning(
-                            " Compute Failed\n"
-                            "Function: %s\n"
-                            "args: %s\n"
-                            "kwargs: %s\n"
-                            "Exception: %s\n",
-                            str(funcname(function))[:1000],
-                            convert_args_to_str(args2, max_len=1000),
-                            convert_kwargs_to_str(kwargs2, max_len=1000),
-                            repr(result["exception"].data),
-                        )
-                        self.transition(ts, "error")
+                    ts.exception = result["exception"]
+                    ts.traceback = result["traceback"]
+                    logger.warning(
+                        "Compute Failed\n"
+                        "Function: %s\n"
+                        "args: %s\n"
+                        "kwargs: %s\n"
+                        "Exception: %r\n",
+                        str(funcname(function))[:1000],
+                        convert_args_to_str(args2, max_len=1000),
+                        convert_kwargs_to_str(kwargs2, max_len=1000),
+                        result["exception"].data,
+                    )
+                    self.transition(ts, "error")

            logger.debug("Send compute response to scheduler: %s, %s", ts.key, result)

@@ -3365,8 +3366,6 @@ class Reschedule(Exception):
    the task.
    """

-    pass
-

def parse_memory_limit(memory_limit, nthreads, total_cores=CPU_COUNT):
    if memory_limit is None:
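Usage sketch (not part of the patch; ``scheduler``, ``worker`` and ``addr`` are hypothetical stand-ins for a running Scheduler, a Worker and its address). With the changes above, ``Scheduler.heartbeat_worker`` is keyword-only with a required ``metrics`` argument, and ``Scheduler.retire_workers`` refuses mixed ``names``/``workers`` arguments:

    # heartbeat_worker is now keyword-only; metrics is required and the
    # returned "time" field reuses the scheduler's local_now timestamp
    response = scheduler.heartbeat_worker(address=addr, metrics=await worker.get_metrics())
    assert response["status"] in ("OK", "missing")

    # retire by name or by address; passing both now raises TypeError
    await scheduler.retire_workers(names=["alice"])
    await scheduler.retire_workers(workers=[addr])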