Skip to content

Commit

Permalink
Clarify that Worker.story() can request arbitrary log tags
Browse files (browse the repository at this point in the history)
Moved out of dask#6342
  • Loading branch information
crusaderky committed Jun 1, 2022
1 parent 715d7be commit 038fb72
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 29 deletions.
33 changes: 22 additions & 11 deletions distributed/_stories.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,55 @@
from __future__ import annotations

from typing import Iterable


def scheduler_story(keys: set, transition_log: Iterable) -> list:
def scheduler_story(
keys_or_stimuli: set[str], transition_log: Iterable[tuple]
) -> list[tuple]:
"""Creates a story from the scheduler transition log given a set of keys
describing tasks or stimuli.
Parameters
----------
keys : set
A set of task `keys` or `stimulus_id`'s
keys_or_stimuli : set[str]
Task keys or stimulus_id's
transition_log : iterable
The scheduler transition log
Returns
-------
story : list
story : list[tuple]
"""
return [t for t in transition_log if t[0] in keys or keys.intersection(t[3])]
return [
t
for t in transition_log
if t[0] in keys_or_stimuli or keys_or_stimuli.intersection(t[3])
]


def worker_story(keys: set, log: Iterable) -> list:
def worker_story(keys_or_tags: set[str], log: Iterable[tuple]) -> list:
"""Creates a story from the worker log given a set of keys
describing tasks or stimuli.
Parameters
----------
keys : set
A set of task `keys` or `stimulus_id`'s
keys_or_tags : set[str]
Task keys or arbitrary tags from the transition log, e.g. stimulus_id's
log : iterable
The worker log
Returns
-------
story : list
story : list[tuple]
"""
return [
msg
for msg in log
if any(key in msg for key in keys)
if any(key in msg for key in keys_or_tags)
or any(
key in c for key in keys for c in msg if isinstance(c, (tuple, list, set))
key in c
for key in keys_or_tags
for c in msg
if isinstance(c, (tuple, list, set))
)
]
14 changes: 8 additions & 6 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4286,11 +4286,11 @@ def collections_to_dsk(collections, *args, **kwargs):
"""Convert many collections into a single dask graph, after optimization"""
return collections_to_dsk(collections, *args, **kwargs)

async def _story(self, keys=(), on_error="raise"):
async def _story(self, *keys_or_tags: str, on_error="raise"):
assert on_error in ("raise", "ignore")

try:
flat_stories = await self.scheduler.get_story(keys=keys)
flat_stories = await self.scheduler.get_story(keys_or_stimuli=keys_or_tags)
flat_stories = [("scheduler", *msg) for msg in flat_stories]
except Exception:
if on_error == "raise":
Expand All @@ -4301,15 +4301,17 @@ async def _story(self, keys=(), on_error="raise"):
raise ValueError(f"on_error not in {'raise', 'ignore'}")

responses = await self.scheduler.broadcast(
msg={"op": "get_story", "keys": keys}, on_error=on_error
msg={"op": "get_story", "keys_or_tags": keys_or_tags}, on_error=on_error
)
for worker, stories in responses.items():
flat_stories.extend((worker, *msg) for msg in stories)
return flat_stories

def story(self, *keys_or_stimulus_ids, on_error="raise"):
"""Returns a cluster-wide story for the given keys or stimulus_id's"""
return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error)
def story(self, *keys_or_tags, on_error="raise"):
"""Returns a cluster-wide story for the given keys or transition log tags, such
as stimulus_id's
"""
return self.sync(self._story, *keys_or_tags, on_error=on_error)

def get_task_stream(
self,
Expand Down
15 changes: 9 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6540,13 +6540,16 @@ def transitions(self, recommendations: dict, stimulus_id: str):
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)

def story(self, *keys):
"""Get all transitions that touch one of the input keys"""
keys = {key.key if isinstance(key, TaskState) else key for key in keys}
return scheduler_story(keys, self.transition_log)
def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[tuple]:
"""Get all transitions that touch one of the input keys or stimulus_id's"""
keys_or_stimuli = {
key.key if isinstance(key, TaskState) else key
for key in keys_or_tasks_or_stimuli
}
return scheduler_story(keys_or_stimuli, self.transition_log)

async def get_story(self, keys=()):
return self.story(*keys)
async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]:
return self.story(*keys_or_stimuli)

transition_story = story

Expand Down
16 changes: 10 additions & 6 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2949,13 +2949,17 @@ def stateof(self, key: str) -> dict[str, Any]:
"data": key in self.data,
}

def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
"""Return all transitions involving one or more tasks"""
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
return worker_story(keys, self.log)
def story(self, *keys_or_tasks_or_tags: str | TaskState) -> list[tuple]:
"""Return all records from the transitions log involving one or more tasks;
it can also be used for arbitrary non-transition tags.
"""
keys_or_tags = {
e.key if isinstance(e, TaskState) else e for e in keys_or_tasks_or_tags
}
return worker_story(keys_or_tags, self.log)

async def get_story(self, keys=None):
return self.story(*keys)
async def get_story(self, keys_or_tags: Iterable[str]) -> list[tuple]:
return self.story(*keys_or_tags)

def stimulus_story(
self, *keys_or_tasks: str | TaskState
Expand Down

0 comments on commit 038fb72

Please sign in to comment.