diff --git a/distributed/worker.py b/distributed/worker.py index 577e75775e..c38e45cdae 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -406,7 +406,6 @@ class Worker(BaseWorker, ServerNode): _dashboard_address: str | None _dashboard: bool _http_prefix: str - total_resources: dict[str, float] death_timeout: float | None lifetime: float | None lifetime_stagger: float | None @@ -628,7 +627,6 @@ def __init__( if resources is None: resources = dask.config.get("distributed.worker.resources") assert isinstance(resources, dict) - self.total_resources = resources.copy() self.death_timeout = parse_timedelta(death_timeout) @@ -754,7 +752,7 @@ def __init__( data=self.memory_manager.data, threads=self.threads, plugins=self.plugins, - resources=self.total_resources, + resources=resources, total_out_connections=total_out_connections, validate=validate, transition_counter_max=transition_counter_max, @@ -877,6 +875,7 @@ def data(self) -> MutableMapping[str, Any]: tasks = DeprecatedWorkerStateAttribute() target_message_size = DeprecatedWorkerStateAttribute() total_out_connections = DeprecatedWorkerStateAttribute() + total_resources = DeprecatedWorkerStateAttribute() transition_counter = DeprecatedWorkerStateAttribute() transition_counter_max = DeprecatedWorkerStateAttribute() validate = DeprecatedWorkerStateAttribute() @@ -1100,7 +1099,7 @@ async def _register_with_scheduler(self) -> None: }, types={k: typename(v) for k, v in self.data.items()}, now=time(), - resources=self.total_resources, + resources=self.state.total_resources, memory_limit=self.memory_manager.memory_limit, local_directory=self.local_directory, services=self.service_ports, @@ -1753,16 +1752,11 @@ def update_data( return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} async def set_resources(self, **resources) -> None: - for r, quantity in resources.items(): - if r in self.total_resources: - self.state.available_resources[r] += quantity - self.total_resources[r] - else: - 
self.state.available_resources[r] = quantity - self.total_resources[r] = quantity + self.state.set_resources(resources) await retry_operation( self.scheduler.set_resources, - resources=self.total_resources, + resources=self.state.total_resources, worker=self.contact_address, ) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 6759e50ff0..d04d79b8e7 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -81,7 +81,7 @@ "resumed", } READY: set[TaskStateState] = {"ready", "constrained"} - +RUNNING: set[TaskStateState] = {"executing", "long-running", "cancelled", "resumed"} NO_VALUE = "--no-value-sentinel--" @@ -1027,8 +1027,12 @@ class WorkerState: #: determining a last-in-first-out order between them. generation: int + #: ``{resource name: amount}``. Total resources available for task execution. + #: See :doc:`resources`. + total_resources: dict[str, float] + #: ``{resource name: amount}``. Current resources that aren't being currently - #: consumed by task execution. Always less or equal to ``Worker.total_resources``. + #: consumed by task execution. Always less than or equal to :attr:`total_resources`. #: See :doc:`resources`. 
available_resources: dict[str, float] @@ -1102,7 +1106,8 @@ def __init__( self.data = data if data is not None else {} self.threads = threads if threads is not None else {} self.plugins = plugins if plugins is not None else {} - self.available_resources = dict(resources) if resources is not None else {} + self.total_resources = dict(resources) if resources is not None else {} + self.available_resources = self.total_resources.copy() self.validate = validate self.tasks = {} @@ -3109,6 +3114,32 @@ def validate_state(self) -> None: for tss in self.data_needed.values(): assert len({ts.key for ts in tss}) == len(tss) + # Test that resources are consumed and released correctly + for resource, total in self.total_resources.items(): + available = self.available_resources[resource] + assert available >= 0 + allocated = 0.0 + for ts in self.tasks.values(): + if ts.resource_restrictions and ts.state in RUNNING: + allocated += ts.resource_restrictions.get(resource, 0) + # Amounts are floats: repeated += / -= may leave rounding residue, + # so compare with a tolerance instead of exact equality. + assert abs(available + allocated - total) <= 1e-9 * max(1.0, abs(total)) + + def set_resources(self, resources: dict[str, float]) -> None: + """Set the total amount of the given resources. + + The available amount of an already known resource is shifted by the + same delta as its total; an unknown resource is added with its whole + amount available. + """ + for r, quantity in resources.items(): + if r in self.total_resources: + self.available_resources[r] += quantity - self.total_resources[r] + else: + self.available_resources[r] = quantity + self.total_resources[r] = quantity + class BaseWorker(abc.ABC): """Wrapper around the :class:`WorkerState` that implements instructions handling.