Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor resource restriction handling in WorkerState #6672

Merged
merged 11 commits into from
Jul 6, 2022
10 changes: 5 additions & 5 deletions distributed/tests/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ async def test_submit_many_non_overlapping_2(c, s, a, b):
assert b.state.executing_count <= 1

await wait(futures)
assert a.total_resources == a.state.available_resources
assert b.total_resources == b.state.available_resources
assert a.state.total_resources == a.state.available_resources
assert b.state.total_resources == b.state.available_resources


@gen_cluster(
Expand Down Expand Up @@ -232,7 +232,7 @@ async def test_minimum_resource(c, s, a):
assert a.state.executing_count <= 1

await wait(futures)
assert a.total_resources == a.state.available_resources
assert a.state.total_resources == a.state.available_resources


@gen_cluster(client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})])
Expand Down Expand Up @@ -271,7 +271,7 @@ async def test_balance_resources(c, s, a, b):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)])
async def test_set_resources(c, s, a):
await a.set_resources(A=2)
assert a.total_resources["A"] == 2
assert a.state.total_resources["A"] == 2
assert a.state.available_resources["A"] == 2
assert s.workers[a.address].resources == {"A": 2}
lock = Lock()
Expand All @@ -281,7 +281,7 @@ async def test_set_resources(c, s, a):
await asyncio.sleep(0.01)

await a.set_resources(A=3)
assert a.total_resources["A"] == 3
assert a.state.total_resources["A"] == 3
assert a.state.available_resources["A"] == 2
assert s.workers[a.address].resources == {"A": 3}

Expand Down
2 changes: 1 addition & 1 deletion distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,7 +2449,7 @@ def ws_with_running_task(ws, request):

The task may or may not raise secede(); the tests using this fixture run twice.
"""
ws.available_resources = {"R": 1}
ws.set_resources(R=1)
instructions = ws.handle_stimulus(
ComputeTaskEvent.dummy(
key="x", resource_restrictions={"R": 1}, stimulus_id="compute"
Expand Down
18 changes: 6 additions & 12 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,6 @@ class Worker(BaseWorker, ServerNode):
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
total_resources: dict[str, float]
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
Expand Down Expand Up @@ -628,7 +627,6 @@ def __init__(
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.total_resources = resources.copy()

self.death_timeout = parse_timedelta(death_timeout)

Expand Down Expand Up @@ -754,7 +752,7 @@ def __init__(
data=self.memory_manager.data,
threads=self.threads,
plugins=self.plugins,
resources=self.total_resources,
resources=resources,
total_out_connections=total_out_connections,
validate=validate,
transition_counter_max=transition_counter_max,
Expand Down Expand Up @@ -877,6 +875,7 @@ def data(self) -> MutableMapping[str, Any]:
tasks = DeprecatedWorkerStateAttribute()
target_message_size = DeprecatedWorkerStateAttribute()
total_out_connections = DeprecatedWorkerStateAttribute()
total_resources = DeprecatedWorkerStateAttribute()
transition_counter = DeprecatedWorkerStateAttribute()
transition_counter_max = DeprecatedWorkerStateAttribute()
validate = DeprecatedWorkerStateAttribute()
Expand Down Expand Up @@ -1100,7 +1099,7 @@ async def _register_with_scheduler(self) -> None:
},
types={k: typename(v) for k, v in self.data.items()},
now=time(),
resources=self.total_resources,
resources=self.state.total_resources,
memory_limit=self.memory_manager.memory_limit,
local_directory=self.local_directory,
services=self.service_ports,
Expand Down Expand Up @@ -1752,17 +1751,12 @@ def update_data(
)
return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}

async def set_resources(self, **resources) -> None:
for r, quantity in resources.items():
if r in self.total_resources:
self.state.available_resources[r] += quantity - self.total_resources[r]
else:
self.state.available_resources[r] = quantity
self.total_resources[r] = quantity
async def set_resources(self, **resources: float) -> None:
    """Set the total amounts of the given abstract resources on this worker.

    Updates the local :class:`WorkerState` totals (which also adjusts the
    available amounts by the same delta) and then informs the scheduler of
    the new totals, retrying the RPC on transient comm failures.

    Parameters
    ----------
    **resources:
        Mapping of resource name to new total amount, e.g. ``A=2``.
    """
    # Delegate bookkeeping of total/available resources to the state machine.
    self.state.set_resources(**resources)

    # Best-effort notification; retry_operation handles transient
    # communication errors with the scheduler.
    await retry_operation(
        self.scheduler.set_resources,
        resources=self.state.total_resources,
        worker=self.contact_address,
    )

Expand Down
55 changes: 35 additions & 20 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
}
READY: set[TaskStateState] = {"ready", "constrained"}


NO_VALUE = "--no-value-sentinel--"


Expand Down Expand Up @@ -1027,8 +1026,12 @@ class WorkerState:
#: determining a last-in-first-out order between them.
generation: int

#: ``{resource name: amount}``. Total resources available for task execution.
#: See :doc: `resources`.
total_resources: dict[str, float]

#: ``{resource name: amount}``. Current resources that aren't being currently
#: consumed by task execution. Always less or equal to ``Worker.total_resources``.
#: consumed by task execution. Always less than or equal to :attr:`total_resources`.
#: See :doc:`resources`.
available_resources: dict[str, float]

Expand Down Expand Up @@ -1102,7 +1105,8 @@ def __init__(
self.data = data if data is not None else {}
self.threads = threads if threads is not None else {}
self.plugins = plugins if plugins is not None else {}
self.available_resources = dict(resources) if resources is not None else {}
self.total_resources = dict(resources) if resources is not None else {}
self.available_resources = self.total_resources.copy()

self.validate = validate
self.tasks = {}
Expand Down Expand Up @@ -1445,15 +1449,11 @@ def _ensure_computing(self) -> RecsInstrs:
if ts in recs:
continue

if any(
self.available_resources[resource] < needed
for resource, needed in ts.resource_restrictions.items()
):
if not self._resource_restrictions_satisfied(ts):
break

self.constrained.popleft()
for resource, needed in ts.resource_restrictions.items():
self.available_resources[resource] -= needed
self._acquire_resources(ts)
hendrikmakait marked this conversation as resolved.
Show resolved Hide resolved

recs[ts] = "executing"
self.executing.add(ts)
Expand Down Expand Up @@ -1734,8 +1734,7 @@ def _transition_executing_rescheduled(
# Reschedule(), which is "cancelled"
assert ts.state in ("executing", "long-running"), ts

for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
self._release_resources(ts)
self.executing.discard(ts)

return merge_recs_instructions(
Expand Down Expand Up @@ -1831,8 +1830,7 @@ def _transition_executing_error(
*,
stimulus_id: str,
) -> RecsInstrs:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
self._release_resources(ts)
self.executing.discard(ts)

return merge_recs_instructions(
Expand Down Expand Up @@ -1977,9 +1975,7 @@ def _transition_cancelled_released(
self.executing.discard(ts)
self.in_flight_tasks.discard(ts)

for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity

self._release_resources(ts)
return self._transition_generic_released(ts, stimulus_id=stimulus_id)

def _transition_executing_released(
Expand All @@ -2006,10 +2002,7 @@ def _transition_generic_memory(
f"Tried to transition task {ts} to `memory` without data available"
)

if ts.resource_restrictions is not None:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity

self._release_resources(ts)
self.executing.discard(ts)
self.in_flight_tasks.discard(ts)
ts.coming_from = None
Expand Down Expand Up @@ -2351,6 +2344,20 @@ def _transition(
)
return recs, instructions

def _resource_restrictions_satisfied(self, ts: TaskState) -> bool:
return all(
self.available_resources[resource] >= needed
for resource, needed in ts.resource_restrictions.items()
)

def _acquire_resources(self, ts: TaskState) -> None:
crusaderky marked this conversation as resolved.
Show resolved Hide resolved
for resource, needed in ts.resource_restrictions.items():
self.available_resources[resource] -= needed

def _release_resources(self, ts: TaskState) -> None:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
hendrikmakait marked this conversation as resolved.
Show resolved Hide resolved

def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructions:
"""Process transitions until none are left

Expand Down Expand Up @@ -3109,6 +3116,14 @@ def validate_state(self) -> None:
for tss in self.data_needed.values():
assert len({ts.key for ts in tss}) == len(tss)

def set_resources(self, **resources: float) -> None:
    """Set new totals for the given resources.

    For a resource that already exists, the available amount is shifted by
    the difference between the new and the old total, so any amount
    currently consumed by running tasks remains accounted for. A resource
    seen for the first time starts fully available.
    """
    for name, new_total in resources.items():
        if name in self.total_resources:
            delta = new_total - self.total_resources[name]
            self.available_resources[name] += delta
        else:
            self.available_resources[name] = new_total
        self.total_resources[name] = new_total

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are many problems with the code of Worker.set_resources. It would deserve a much more thorough overhaul.
For the purpose of this PR, I'd rather revert this and just tamper with the state directly from Worker.set_resources.

Copy link
Member Author

@hendrikmakait hendrikmakait Jul 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One issue I see with this is that this method would be quite useful for tests that are currently setting attributes directly which results in a technically invalid worker state (https://github.com/dask/distributed/blob/main/distributed/utils_test.py#L2452). Given that you have already set up a ticket for the more thorough overhaul (#6677), I'd prefer leaving this as is, in particular since a follow-up PR will add more validations of worker state that check for consistency between total_resources, available_resources and any currently acquired resources.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would need to be a SetResourcesEvent anyway. I'd rather not have a throwaway refactoring.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, updated ws_with_running_task as well.


class BaseWorker(abc.ABC):
"""Wrapper around the :class:`WorkerState` that implements instructions handling.
Expand Down