Exception when executing task graph with HighLevelGraph after a shuffle #1129

Open
TomAugspurger opened this issue Aug 28, 2024 · 3 comments

Comments

@TomAugspurger
Member

TomAugspurger commented Aug 28, 2024

Describe the issue:

In geopandas/dask-geopandas#303, dask-geopandas has a report that its custom sjoin method fails with a TypeError under some conditions. Internally, that method constructs a HighLevelGraph.

The observed failure is a TypeError raised by geopandas because dask fails to substitute the concrete (geo)DataFrame for the key (name, partition_number) when executing the task graph.

I've managed to produce a dask / dask-expr only version:

Minimal Complete Verifiable Example:

import dask.dataframe as dd
import pandas as pd
import dask
import dask_expr

dask.config.set(scheduler="single-threaded")


l1 = dd.from_pandas(pd.DataFrame({"a": [1, 2], "b": [0, 0]}), npartitions=1).shuffle("a")
r1 = dd.from_pandas(pd.DataFrame({"a": [1, 3], "c": [1, 1]}), npartitions=1).shuffle("a")


def func(left, right):
    assert isinstance(left, pd.DataFrame), f"Wrong type {left}"
    return pd.merge(left, right, how="inner")


dsk = {}
name = "myjoin-" + dask.base.tokenize(l1, r1, ensure_deterministic=True)
for i in range(1):
    dsk[(name, i)] = (
        func,
        (l1._name, i),
        (r1._name, i),
    )

divisions = [None] * (len(dsk) + 1)
g1 = dask.highlevelgraph.HighLevelGraph.from_collections(name, dsk, dependencies=[l1, r1])
x1 = dask_expr.from_graph(g1, func(l1._meta, r1._meta), divisions, dsk.keys(), "myjoin")
x1.compute()

That fails with

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[3], line 30
     28 g1 = dask.highlevelgraph.HighLevelGraph.from_collections(name, dsk, dependencies=[l1, r1])
     29 x1 = dask_expr.from_graph(g1, func(l1._meta, r1._meta), divisions, dsk.keys(), "myjoin")
---> 30 x1.compute()

File ~/gh/geopandas/dask-geopandas/.direnv/python-3.10/lib/python3.10/site-packages/dask_expr/_collection.py:477, in FrameBase.compute(self, fuse, **kwargs)
    475     out = out.repartition(npartitions=1)
    476 out = out.optimize(fuse=fuse)
--> 477 return DaskMethodsMixin.compute(out, **kwargs)

File ~/gh/geopandas/dask-geopandas/.direnv/python-3.10/lib/python3.10/site-packages/dask/base.py:376, in DaskMethodsMixin.compute(self, **kwargs)
    352 def compute(self, **kwargs):
    353     """Compute this dask collection
    354 
    355     This turns a lazy Dask collection into its in-memory equivalent.
   (...)
    374     dask.compute
    375     """
--> 376     (result,) = compute(self, traverse=False, **kwargs)
    377     return result

File ~/gh/geopandas/dask-geopandas/.direnv/python-3.10/lib/python3.10/site-packages/dask/base.py:662, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    659     postcomputes.append(x.__dask_postcompute__())
    661 with shorten_traceback():
--> 662     results = schedule(dsk, keys, **kwargs)
    664 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

Cell In[3], line 14, in func(left, right)
     13 def func(left, right):
---> 14     assert isinstance(left, pd.DataFrame), f"Wrong type {left}"
     15     return pd.merge(left, right, how="inner")

AssertionError: Wrong type ('rearrangebycolumn-8e5135a4ed46029cf9bb3c4bc7eff32e', 0)

As I mentioned, the shuffle there is important. Without that shuffle, things work fine.

Anything else we need to know?:

I'll take a look at this today.

Environment:

  • Dask version: '2024.8.1'
  • dask-expr version: '1.1.11'
@phofl
Collaborator

phofl commented Aug 28, 2024

Thanks for the report. This is not unexpected: the optimiser will most likely change l1._name and r1._name because we are changing the backing expression.

If you need the output keys, you can simulate this with

r1 = r1.optimize()
l1 = l1.optimize()

This is roughly what happens internally in dask-expr: we don't call optimise there but .lower(), to create an expression that we can execute (you can see this in the __dask_graph__ method). optimise is preferable, though.
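To make the failure mode concrete, here is a minimal sketch in plain Python. It is a toy executor, not dask's scheduler, and the key names `shuffle-before` and `shuffle-after` are hypothetical stand-ins for the names before and after optimisation. It relies on one rule of dask-style task graphs: a task argument is replaced by a computed value only when it is a key of the graph, and anything else is passed through as a literal.

```python
def execute(graph, key):
    """Toy recursive executor for a dict-based task graph (illustration only)."""
    task = graph[key]
    if isinstance(task, tuple) and task and callable(task[0]):
        func, *args = task
        resolved = [
            # Substitute an argument only if it is a key of the graph;
            # a stale key is not found and is passed through unchanged.
            execute(graph, arg) if isinstance(arg, tuple) and arg in graph else arg
            for arg in args
        ]
        return func(*resolved)
    return task  # plain data, e.g. a partition


def describe(left):
    """Report the type of the object the task actually received."""
    return type(left).__name__


# The input layer as it exists after optimisation (hypothetical name).
input_layer = {("shuffle-after", 0): [1, 2]}

# Graph built from a name captured before optimisation: the key is stale.
broken = {**input_layer, ("myjoin", 0): (describe, ("shuffle-before", 0))}

# Graph built from the name read after optimisation: the key resolves.
fixed = {**input_layer, ("myjoin", 0): (describe, ("shuffle-after", 0))}

print(execute(broken, ("myjoin", 0)))  # the function receives the key tuple
print(execute(fixed, ("myjoin", 0)))   # the function receives the partition data
```

In the broken graph the function is handed the key tuple itself, which mirrors the `Wrong type (..., 0)` assertion in the report.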

@TomAugspurger
Member Author

Gotcha, makes sense.

Do you have any initial guesses on a good path forward? I was thinking to try to rewrite that dask-geopandas method as a new dask_expr.Expr, with the hope that it would integrate into the whole optimization.

A potential workaround in dask-geopandas is to call optimize() on the inputs before constructing the HLG. It fixes the reported error, but I'm not sure how robust this will be.

@phofl
Collaborator

phofl commented Aug 28, 2024

> A potential workaround in dask-geopandas is to call optimize() on the inputs before constructing the HLG. It fixes the reported error, but I'm not sure how robust this will be.

This is definitely what I would recommend as long as you are using HLGs for this. Otherwise optimisation won't be triggered at all, which might leave stuff on the table.

> Do you have any initial guesses on a good path forward? I was thinking to try to rewrite that dask-geopandas method as a new dask_expr.Expr, with the hope that it would integrate into the whole optimization.

When you create an expression for this, then you basically pass in left and right expressions and the expression tree will keep the references to the correct changed expression. The graph is only constructed after everything else was done (meaning optimisations), so if you implement the layer method on the expression, keys won't change anymore. Does this help? Happy to help you through the process of creating the expression
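A rough sketch of that idea in plain Python, with loud caveats: every class and function here (`Expr`, `Source`, `Shuffle`, `LoweredShuffle`, `Join`, `build_graph`, `dangling_keys`) is a hypothetical toy, not the dask-expr API. It only illustrates the ordering: the expression holds references to its operands, optimisation rewrites the tree, and the layer is built afterwards from the operands' final names.

```python
import hashlib


class Expr:
    """Toy expression node; its name is derived from its type and operands."""

    def __init__(self, *operands):
        self.operands = operands

    @property
    def _name(self):
        parts = [getattr(op, "_name", op) for op in self.operands]
        payload = repr((type(self).__name__, parts))
        digest = hashlib.md5(payload.encode()).hexdigest()[:8]
        return f"{type(self).__name__.lower()}-{digest}"

    def optimize(self):
        # Default: rebuild this node on top of its optimised operands.
        ops = [op.optimize() if isinstance(op, Expr) else op for op in self.operands]
        return type(self)(*ops)

    def layer(self):
        raise NotImplementedError  # abstract nodes cannot be executed


class Source(Expr):
    def layer(self):
        return {(self._name, 0): self.operands[0]}


class LoweredShuffle(Expr):
    def layer(self):
        return {(self._name, 0): (sorted, (self.operands[0]._name, 0))}


class Shuffle(Expr):
    def optimize(self):
        # Hypothetical rewrite: the node is replaced, so its name changes.
        return LoweredShuffle(self.operands[0].optimize())


def inner_join(left, right):
    return sorted(set(left) & set(right))


class Join(Expr):
    def layer(self):
        # Names are read here, after optimisation, so the keys cannot be stale.
        left, right = self.operands
        return {(self._name, 0): (inner_join, (left._name, 0), (right._name, 0))}


def build_graph(expr):
    """Optimise first, then collect the layers of the final tree."""
    expr = expr.optimize()
    dsk, stack = {}, [expr]
    while stack:
        node = stack.pop()
        dsk.update(node.layer())
        stack.extend(op for op in node.operands if isinstance(op, Expr))
    return expr, dsk


def dangling_keys(dsk):
    """Key-like task arguments that do not exist in the graph."""
    return [
        arg
        for task in dsk.values()
        if isinstance(task, tuple) and callable(task[0])
        for arg in task[1:]
        if isinstance(arg, tuple) and arg not in dsk
    ]


join = Join(Shuffle(Source([2, 1])), Shuffle(Source([3, 1])))
optimized, dsk = build_graph(join)
print(dangling_keys(dsk))                     # no dangling references
print((join.operands[0]._name, 0) in dsk)     # the pre-optimisation name is gone
```

The pre-optimisation shuffle name never appears in the final graph, yet the join layer has no dangling references because it reads its operands' names only when the layer is generated.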
