Remove alias resolving to fix queuing (#8933)
phofl authored Nov 13, 2024
1 parent e2f3f96 commit 53a5679
Showing 2 changed files with 18 additions and 18 deletions.
22 changes: 4 additions & 18 deletions distributed/scheduler.py
@@ -54,14 +54,9 @@
 
 import dask
 import dask.utils
-from dask._task_spec import (
-    DependenciesMapping,
-    GraphNode,
-    convert_legacy_graph,
-    resolve_aliases,
-)
+from dask._task_spec import DependenciesMapping, GraphNode, convert_legacy_graph
 from dask.base import TokenizationError, normalize_token, tokenize
-from dask.core import istask, reverse_dict, validate_key
+from dask.core import istask, validate_key
 from dask.typing import Key, no_default
 from dask.utils import (
     _deprecated,
@@ -9411,19 +9406,10 @@ def _materialize_graph(
     )
 
     dsk2 = convert_legacy_graph(dsk)
-    dependents = reverse_dict(DependenciesMapping(dsk2))
-    # This is removing weird references like "x-foo": "foo" which often make up
-    # a substantial part of the graph
-    # This also performs culling!
-    dsk3 = resolve_aliases(dsk2, keys, dependents)
-
-    logger.debug(
-        "Removing aliases. Started with %i and got %i left", len(dsk2), len(dsk3)
-    )
     # FIXME: There should be no need to fully materialize and copy this but some
     # sections in the scheduler are mutating it.
-    dependencies = {k: set(v) for k, v in DependenciesMapping(dsk3).items()}
-    return dsk3, dependencies, annotations_by_type
+    dependencies = {k: set(v) for k, v in DependenciesMapping(dsk2).items()}
+    return dsk2, dependencies, annotations_by_type
 
 
 def _cull(dsk: dict[Key, GraphNode], keys: set[Key]) -> dict[Key, GraphNode]:
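
For context, the removed resolve_aliases() step rewrote trivial alias entries of the form "x-foo": "foo" before the graph reached the rest of the scheduler. The sketch below, with hypothetical keys, only illustrates the kind of graph entry involved; it is not code from this change.

# Minimal sketch of an alias in a legacy (dict-based) dask graph.
# The keys are hypothetical; they only illustrate the "x-foo": "foo"
# shape mentioned in the removed comment above.
from operator import add

dsk = {
    "foo": (add, 1, 2),   # a real task: a callable plus its arguments
    "x-foo": "foo",       # an alias: no work of its own, just a reference
}

# resolve_aliases() used to collapse "x-foo" into "foo", shrinking the graph;
# per the commit title, this change stops doing that because it interfered
# with queuing.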
14 changes: 14 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -5304,3 +5304,17 @@ async def before_close(self):
 async def test_rootish_taskgroup_configuration(c, s, *workers):
     assert s.rootish_tg_threshold == 10
     assert s.rootish_tg_dependencies_threshold == 15
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_alias_resolving_break_queuing(c, s, a):
+    pytest.importorskip("numpy")
+    import dask.array as da
+
+    arr = da.random.random((90, 100), chunks=(10, 50))
+    result = arr.rechunk(((10, 7, 7, 6) * 3, (50, 50)))
+    result = result.sum(split_every=1000)
+    x = result.persist()
+    while not s.tasks:
+        await asyncio.sleep(0.01)
+    assert sum([s.is_rootish(v) for v in s.tasks.values()]) == 18
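
The assertion above checks that 18 of the materialized tasks are classified as root-ish by Scheduler.is_rootish, i.e. that this rechunk-then-sum workload remains eligible for queuing. A rough sketch for building the same array graph outside the test harness follows; it assumes dask and numpy are installed, and the printed task count is illustrative rather than guaranteed across dask versions.

# Rough sketch: construct the same array workload as the test above and
# inspect its task graph locally (no cluster needed). Assumes dask[array]
# and numpy are available; exact task counts may vary between dask versions.
import dask.array as da

arr = da.random.random((90, 100), chunks=(10, 50))
rechunked = arr.rechunk(((10, 7, 7, 6) * 3, (50, 50)))
total = rechunked.sum(split_every=1000)

graph = total.__dask_graph__()  # the collection's task graph (a mapping of keys)
print(len(graph))               # total number of keys in the graph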
