Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial revert of compute-task message format #6626

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7321,16 +7321,21 @@ def _task_to_msg(
dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
},
"nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
"run_spec": ts.run_spec,
"run_spec": None,
"function": None,
"args": None,
"kwargs": None,
Comment on lines +7324 to +7327
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't 100% see the point of initializing these all with None but I also don't see the harm in it.

"resource_restrictions": ts.resource_restrictions,
"actor": ts.actor,
"annotations": ts.annotations,
}
if state.validate:
assert all(msg["who_has"].values())
if isinstance(msg["run_spec"], dict):
assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"})
assert msg["run_spec"].get("function")

if isinstance(ts.run_spec, dict):
msg.update(ts.run_spec)
else:
msg["run_spec"] = ts.run_spec

return msg

Expand Down
15 changes: 15 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3641,3 +3641,18 @@ async def test_worker_state_unique_regardless_of_address(s, w):
async def test_scheduler_close_fast_deprecated(s, w):
with pytest.warns(FutureWarning):
await s.close(fast=True)


def test_runspec_regression_sync():
# https://github.com/dask/distributed/issues/6624

da = pytest.importorskip("dask.array")
np = pytest.importorskip("numpy")
with Client():
v = da.random.random((20, 20), chunks=(5, 5))

overlapped = da.map_overlap(np.sum, v, depth=2, boundary="reflect")
# This computation is somehow broken but we want to avoid catching any
# serialization errors that result in KilledWorker
with pytest.raises(IndexError):
overlapped.compute()
Comment on lines +3657 to +3658
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is obviously not ideal but I simply took the reproducer from #6624
I do not hit the serialization error anymore but rather an IndexError ¯_(ツ)_/¯

9 changes: 7 additions & 2 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,14 @@ def test_computetask_to_dict():
nbytes={"y": 123},
priority=(0,),
duration=123.45,
# Automatically converted to SerializedTask on init
run_spec={"function": b"blob", "args": b"blob"},
run_spec=None,
resource_restrictions={},
actor=False,
annotations={},
stimulus_id="test",
function=b"blob",
args=b"blob",
kwargs=None,
)
assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob")
ev2 = ev.to_loggable(handled=11.22)
Expand All @@ -258,6 +260,9 @@ def test_computetask_to_dict():
"annotations": {},
"stimulus_id": "test",
"handled": 11.22,
"function": None,
"args": None,
"kwargs": None,
}
ev3 = StateMachineEvent.from_dict(d)
assert isinstance(ev3, ComputeTaskEvent)
Expand Down
30 changes: 23 additions & 7 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,10 @@ class ComputeTaskEvent(StateMachineEvent):
nbytes: dict[str, int]
priority: tuple[int, ...]
duration: float
run_spec: SerializedTask
run_spec: SerializedTask | None
function: bytes | None
args: bytes | tuple | list | None | None
kwargs: bytes | dict[str, Any] | None
resource_restrictions: dict[str, float]
actor: bool
annotations: dict
Expand All @@ -663,19 +666,32 @@ def __post_init__(self) -> None:
if isinstance(self.priority, list): # type: ignore[unreachable]
self.priority = tuple(self.priority) # type: ignore[unreachable]

if isinstance(self.run_spec, dict):
self.run_spec = SerializedTask(**self.run_spec) # type: ignore[unreachable]
if self.function is not None:
assert self.run_spec is None
self.run_spec = SerializedTask(
function=self.function, args=self.args, kwargs=self.kwargs
)
elif not isinstance(self.run_spec, SerializedTask):
self.run_spec = SerializedTask(task=self.run_spec) # type: ignore[unreachable]
self.run_spec = SerializedTask(task=self.run_spec)

def to_loggable(self, *, handled: float) -> StateMachineEvent:
def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
return StateMachineEvent._to_dict(self._clean(), exclude=exclude)

def _clean(self) -> StateMachineEvent:
out = copy(self)
out.function = None
out.kwargs = None
out.args = None
out.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None)
return out

def to_loggable(self, *, handled: float) -> StateMachineEvent:
out = self._clean()
out.handled = handled
out.run_spec = SerializedTask(task=None)
return out

def _after_from_dict(self) -> None:
self.run_spec = SerializedTask(task=None)
self.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None)


@dataclass
Expand Down