Skip to content

Commit

Permalink
Partial revert of compute-task message format (#6626)
Browse files: browse the repository at this point in the history
Hotfix for #6624: reverts the compute-task message format almost to its original state from before #6410.
  • Loading branch information
fjetter authored Jun 24, 2022
1 parent f129485 commit a156c35
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 13 deletions.
13 changes: 9 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7321,16 +7321,21 @@ def _task_to_msg(
dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
},
"nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
"run_spec": ts.run_spec,
"run_spec": None,
"function": None,
"args": None,
"kwargs": None,
"resource_restrictions": ts.resource_restrictions,
"actor": ts.actor,
"annotations": ts.annotations,
}
if state.validate:
assert all(msg["who_has"].values())
if isinstance(msg["run_spec"], dict):
assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"})
assert msg["run_spec"].get("function")

if isinstance(ts.run_spec, dict):
msg.update(ts.run_spec)
else:
msg["run_spec"] = ts.run_spec

return msg

Expand Down
15 changes: 15 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3641,3 +3641,18 @@ async def test_worker_state_unique_regardless_of_address(s, w):
async def test_scheduler_close_fast_deprecated(s, w):
    """Check that calling ``s.close(fast=True)`` emits a ``FutureWarning``."""
    with pytest.warns(FutureWarning):
        await s.close(fast=True)


def test_runspec_regression_sync():
    """Regression test for https://github.com/dask/distributed/issues/6624.

    Runs a ``map_overlap`` computation through a ``Client`` and checks that
    the failure reaches the caller as the original ``IndexError``.
    """
    da = pytest.importorskip("dask.array")
    np = pytest.importorskip("numpy")

    with Client():
        chunked = da.random.random((20, 20), chunks=(5, 5))
        with_overlap = da.map_overlap(
            np.sum, chunked, depth=2, boundary="reflect"
        )

        # The computation itself is expected to fail. What matters here is
        # that it fails with its own IndexError rather than with a
        # serialization error that would surface as KilledWorker.
        with pytest.raises(IndexError):
            with_overlap.compute()
9 changes: 7 additions & 2 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,14 @@ def test_computetask_to_dict():
nbytes={"y": 123},
priority=(0,),
duration=123.45,
# Automatically converted to SerializedTask on init
run_spec={"function": b"blob", "args": b"blob"},
run_spec=None,
resource_restrictions={},
actor=False,
annotations={},
stimulus_id="test",
function=b"blob",
args=b"blob",
kwargs=None,
)
assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob")
ev2 = ev.to_loggable(handled=11.22)
Expand All @@ -258,6 +260,9 @@ def test_computetask_to_dict():
"annotations": {},
"stimulus_id": "test",
"handled": 11.22,
"function": None,
"args": None,
"kwargs": None,
}
ev3 = StateMachineEvent.from_dict(d)
assert isinstance(ev3, ComputeTaskEvent)
Expand Down
30 changes: 23 additions & 7 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,10 @@ class ComputeTaskEvent(StateMachineEvent):
nbytes: dict[str, int]
priority: tuple[int, ...]
duration: float
# Either a ready-made SerializedTask, or None when the task is supplied
# through the separate function/args/kwargs fields below; __post_init__
# then assembles run_spec from those three.
run_spec: SerializedTask | None
function: bytes | None
args: bytes | tuple | list | None
kwargs: bytes | dict[str, Any] | None
resource_restrictions: dict[str, float]
actor: bool
annotations: dict
Expand All @@ -663,19 +666,32 @@ def __post_init__(self) -> None:
if isinstance(self.priority, list): # type: ignore[unreachable]
self.priority = tuple(self.priority) # type: ignore[unreachable]

if isinstance(self.run_spec, dict):
self.run_spec = SerializedTask(**self.run_spec) # type: ignore[unreachable]
if self.function is not None:
assert self.run_spec is None
self.run_spec = SerializedTask(
function=self.function, args=self.args, kwargs=self.kwargs
)
elif not isinstance(self.run_spec, SerializedTask):
self.run_spec = SerializedTask(task=self.run_spec) # type: ignore[unreachable]
self.run_spec = SerializedTask(task=self.run_spec)

def to_loggable(self, *, handled: float) -> StateMachineEvent:
def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
    # Serialise a cleaned copy (see _clean) rather than self, so the
    # function/args/kwargs payload and run_spec contents are blanked out
    # in the resulting dict.
    return StateMachineEvent._to_dict(self._clean(), exclude=exclude)

def _clean(self) -> StateMachineEvent:
    """Return a shallow copy of this event with the task payload blanked.

    ``function``, ``args`` and ``kwargs`` are set to ``None`` and
    ``run_spec`` is replaced by an all-``None`` ``SerializedTask``.
    The original event is left untouched.
    """
    scrubbed = copy(self)
    scrubbed.function = None
    scrubbed.kwargs = None
    scrubbed.args = None
    scrubbed.run_spec = SerializedTask(
        task=None,
        function=None,
        args=None,
        kwargs=None,
    )
    return scrubbed

def to_loggable(self, *, handled: float) -> StateMachineEvent:
    """Return a cleaned copy of this event stamped with ``handled``.

    Parameters
    ----------
    handled:
        Value stored on the returned copy's ``handled`` attribute.

    Returns
    -------
    StateMachineEvent
        A copy produced by ``_clean()``; ``self`` is not modified.
    """
    # _clean() already blanks function/args/kwargs and resets run_spec,
    # so no further run_spec assignment is needed here.
    out = self._clean()
    out.handled = handled
    return out

def _after_from_dict(self) -> None:
    # Reset run_spec to an all-None SerializedTask. A single assignment is
    # enough: an earlier store to the same attribute would be overwritten
    # immediately.
    # NOTE(review): presumably invoked by StateMachineEvent.from_dict after
    # the instance is rebuilt -- confirm against the base class.
    self.run_spec = SerializedTask(
        task=None, function=None, args=None, kwargs=None
    )


@dataclass
Expand Down

0 comments on commit a156c35

Please sign in to comment.