-
Notifications
You must be signed in to change notification settings - Fork 916
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce basic "cudf" backend for Dask Expressions (#14805)
Mostly addresses #15027.

dask/dask-expr#728 exposed the necessary mechanisms for us to define a custom dask-expr backend for `cudf`. The new dispatching mechanisms are effectively the same as those in `dask.dataframe`. The only difference is that we are now registering/implementing "expression-based" collections.

This PR does the following:
- Defines a basic `DataFrameBackendEntrypoint` class for collection creation, and registers new collections using `get_collection_type`.
- Refactors the `dask_cudf` import structure to properly support the `"dataframe.query-planning"` configuration.
- Modifies CI to test dask-expr support for some of the `dask_cudf` tests. This coverage can be expanded in follow-up work.

~**Experimental Change**: This PR patches `dask_expr._expr.Expr.__new__` to enable type-based dispatching. This effectively allows us to surgically replace problematic `Expr` subclasses that do not work for cudf-backed data. For example, this PR replaces the upstream `TakeLast` expression to avoid using `squeeze` (since this method is not supported by cudf). This particular fix can be moved upstream relatively easily.
However, having this kind of "patching" mechanism may be valuable for more complicated pandas/cudf discrepancies.~

## Usage example

```python
from dask import config
config.set({"dataframe.query-planning": True})
import dask_cudf

df = dask_cudf.DataFrame.from_dict(
    {"x": range(100), "y": [1, 2, 3, 4] * 25, "z": ["1", "2"] * 50},
    npartitions=10,
)
df["y2"] = df["x"] + df["y"]
agg = df.groupby("y").agg({"y2": "mean"})["y2"]
agg.simplify().pprint()
```

Dask cuDF should now be using dask-expr for "query planning":

```
Projection: columns='y2'
  GroupbyAggregation: arg={'y2': 'mean'} observed=True split_out=1'y'
    Assign: y2=
      Projection: columns=['y']
        FromPandas: frame='<dataframe>' npartitions=10 columns=['x', 'y']
      Add:
        Projection: columns='x'
          FromPandas: frame='<dataframe>' npartitions=10 columns=['x', 'y']
        Projection: columns='y'
          FromPandas: frame='<dataframe>' npartitions=10 columns=['x', 'y']
```

## TODO

- [x] Add basic tests
- [x] Confirm that general design makes sense

**Follow Up Work**:
- Expand dask-expr test coverage
- Fix local and upstream bugs
- Add documentation once "critical mass" is reached

Authors:
- Richard (Rick) Zamora (https://github.com/rjzamora)
- Lawrence Mitchell (https://github.com/wence-)
- Vyas Ramasubramani (https://github.com/vyasr)
- Bradley Dice (https://github.com/bdice)

Approvers:
- Lawrence Mitchell (https://github.com/wence-)
- Ray Douglass (https://github.com/raydouglass)

URL: #14805
- Loading branch information
Showing
24 changed files
with
545 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,75 @@ | ||
# Copyright (c) 2018-2023, NVIDIA CORPORATION. | ||
# Copyright (c) 2018-2024, NVIDIA CORPORATION. | ||
|
||
from dask import config | ||
|
||
# For dask>2024.2.0, we can silence the loud deprecation | ||
# warning before importing `dask.dataframe` (this won't | ||
# do anything for dask==2024.2.0) | ||
config.set({"dataframe.query-planning-warning": False}) | ||
|
||
import dask.dataframe as dd | ||
from dask.dataframe import from_delayed | ||
|
||
import cudf | ||
|
||
from . import backends | ||
from ._version import __git_commit__, __version__ | ||
from .core import DataFrame, Series, concat, from_cudf, from_dask_dataframe | ||
from .groupby import groupby_agg | ||
from .io import read_csv, read_json, read_orc, read_text, to_orc | ||
from .core import concat, from_cudf, from_dask_dataframe | ||
from .expr import QUERY_PLANNING_ON | ||
|
||
|
||
def read_csv(*args, **kwargs):
    """Call ``dd.read_csv`` with the ``"dataframe.backend"`` option set to cudf.

    All positional and keyword arguments are forwarded unchanged, and the
    value returned by ``dd.read_csv`` is handed back to the caller. The
    configuration override is applied through a ``with`` block, so it is
    only in effect while the underlying reader is being called.
    """
    cudf_backend = {"dataframe.backend": "cudf"}
    with config.set(cudf_backend):
        collection = dd.read_csv(*args, **kwargs)
    return collection
|
||
|
||
def read_json(*args, **kwargs):
    """Call ``dd.read_json`` with the ``"dataframe.backend"`` option set to cudf.

    All positional and keyword arguments are forwarded unchanged, and the
    value returned by ``dd.read_json`` is handed back to the caller. The
    configuration override is applied through a ``with`` block, so it is
    only in effect while the underlying reader is being called.
    """
    cudf_backend = {"dataframe.backend": "cudf"}
    with config.set(cudf_backend):
        collection = dd.read_json(*args, **kwargs)
    return collection
|
||
|
||
def read_orc(*args, **kwargs):
    """Call ``dd.read_orc`` with the ``"dataframe.backend"`` option set to cudf.

    All positional and keyword arguments are forwarded unchanged, and the
    value returned by ``dd.read_orc`` is handed back to the caller. The
    configuration override is applied through a ``with`` block, so it is
    only in effect while the underlying reader is being called.
    """
    cudf_backend = {"dataframe.backend": "cudf"}
    with config.set(cudf_backend):
        collection = dd.read_orc(*args, **kwargs)
    return collection
|
||
|
||
def read_parquet(*args, **kwargs):
    """Call ``dd.read_parquet`` with the ``"dataframe.backend"`` option set to cudf.

    All positional and keyword arguments are forwarded unchanged, and the
    value returned by ``dd.read_parquet`` is handed back to the caller. The
    configuration override is applied through a ``with`` block, so it is
    only in effect while the underlying reader is being called.
    """
    cudf_backend = {"dataframe.backend": "cudf"}
    with config.set(cudf_backend):
        collection = dd.read_parquet(*args, **kwargs)
    return collection
|
||
|
||
def raise_not_implemented_error(attr_name):
    """Build a placeholder callable for an API that is unavailable with dask-expr.

    Parameters
    ----------
    attr_name : str
        Name of the top-level API being stubbed out. It is interpolated
        into the error message and used as the name of the returned stub.

    Returns
    -------
    callable
        A function that accepts any positional and keyword arguments and
        always raises ``NotImplementedError``.
    """

    def inner_func(*args, **kwargs):
        raise NotImplementedError(
            f"Top-level {attr_name} API is not available for dask-expr."
        )

    # Name the stub after the API it stands in for, so that reprs and
    # tracebacks show e.g. "to_orc" instead of a generic "inner_func".
    inner_func.__name__ = inner_func.__qualname__ = str(attr_name)
    return inner_func
|
||
|
||
# Choose which implementation backs the public names, depending on
# whether query planning is enabled.
if not QUERY_PLANNING_ON:
    # Query planning is off: take the collections and helpers from the
    # `core`, `groupby` and `io` modules of this package.
    from .core import DataFrame, Index, Series
    from .groupby import groupby_agg
    from .io import read_text, to_orc
else:
    # Query planning is on: take the collections from the `expr`
    # sub-package instead.
    from .expr._collection import DataFrame, Index, Series

    # These top-level helpers are replaced by stubs that raise
    # NotImplementedError when called.
    groupby_agg = raise_not_implemented_error("groupby_agg")
    read_text = raise_not_implemented_error("read_text")
    to_orc = raise_not_implemented_error("to_orc")
|
||
try: | ||
from .io import read_parquet | ||
except ImportError: | ||
pass | ||
|
||
__all__ = [ | ||
"DataFrame", | ||
"Series", | ||
"Index", | ||
"from_cudf", | ||
"from_dask_dataframe", | ||
"concat", | ||
"from_delayed", | ||
] | ||
|
||
|
||
if not hasattr(cudf.DataFrame, "mean"): | ||
cudf.DataFrame.mean = None | ||
del cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
from dask import config | ||
|
||
# Check if dask-dataframe is using dask-expr. | ||
# For dask>=2024.3.0, a null value will default to True | ||
QUERY_PLANNING_ON = config.get("dataframe.query-planning", None) is not False | ||
|
||
# Register custom expressions and collections | ||
try:
    # Imported only for their import-time effects; nothing from these
    # modules is referenced by name here.
    import dask_cudf.expr._collection
    import dask_cudf.expr._expr

except ImportError as err:
    # When query planning is off, the import failure is ignored
    # (presumably because the expression modules are not needed then).
    if QUERY_PLANNING_ON:
        # Dask *should* raise an error before this.
        # However, we can still raise here to be certain.
        # `from err` records the ImportError as the explicit cause, so the
        # original traceback is shown as the direct cause of this error.
        raise RuntimeError(
            "Failed to register the 'cudf' backend for dask-expr."
            " Please make sure you have dask-expr installed.\n"
            f"Error Message: {err}"
        ) from err
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
from functools import cached_property | ||
|
||
from dask_expr import ( | ||
DataFrame as DXDataFrame, | ||
FrameBase, | ||
Index as DXIndex, | ||
Series as DXSeries, | ||
get_collection_type, | ||
) | ||
from dask_expr._collection import new_collection | ||
from dask_expr._util import _raise_if_object_series | ||
|
||
from dask import config | ||
from dask.dataframe.core import is_dataframe_like | ||
|
||
import cudf | ||
|
||
## | ||
## Custom collection classes | ||
## | ||
|
||
|
||
# VarMixin can be removed if cudf#15179 is addressed. | ||
# See: https://github.com/rapidsai/cudf/issues/15179 | ||
class VarMixin:
    """Mixin providing a ``var`` implementation for cudf-backed collections."""

    def var(
        self,
        axis=0,
        skipna=True,
        ddof=1,
        numeric_only=False,
        split_every=False,
        **kwargs,
    ):
        """Return the variance of this collection as a new collection.

        Parameters
        ----------
        axis
            Axis to reduce over; normalized with ``self._validate_axis``.
        skipna, ddof, numeric_only
            Forwarded positionally to the underlying expression's ``var``.
        split_every
            Forwarded to the underlying expression's ``var`` as a keyword.
        **kwargs
            Accepted but not used by this implementation.

        Returns
        -------
        The object produced by ``new_collection`` for the variance
        expression.
        """
        _raise_if_object_series(self, "var")
        axis = self._validate_axis(axis)

        # Run `var` on the metadata object up front. The result is thrown
        # away, so this only lets any error it raises surface here.
        self._meta.var(axis=axis, skipna=skipna, numeric_only=numeric_only)

        if is_dataframe_like(self._meta) and numeric_only:
            # The set of numeric columns is taken from the pandas version
            # of the metadata - cudf does something weird here.
            numeric_columns = list(
                self._meta.to_pandas().var(numeric_only=True).index
            )
            selected = self[numeric_columns]
        else:
            selected = self

        variance_expr = selected.expr.var(
            axis, skipna, ddof, numeric_only, split_every=split_every
        )
        return new_collection(variance_expr)
|
||
|
||
class DataFrame(VarMixin, DXDataFrame):
    """``DXDataFrame`` subclass with cudf-specific ``from_dict``/``groupby``."""

    @classmethod
    def from_dict(cls, *args, **kwargs):
        """Call ``DXDataFrame.from_dict`` with the cudf backend configured.

        All arguments are forwarded unchanged; the ``"dataframe.backend"``
        override is applied through a ``with`` block around the call.
        """
        cudf_backend = {"dataframe.backend": "cudf"}
        with config.set(cudf_backend):
            return DXDataFrame.from_dict(*args, **kwargs)

    def groupby(
        self,
        by,
        group_keys=True,
        sort=None,
        observed=None,
        dropna=None,
        **kwargs,
    ):
        """Return a ``GroupBy`` object for this collection.

        Parameters
        ----------
        by
            Grouping key(s). A ``FrameBase`` instance is only accepted if
            it is also a ``DXSeries``.
        group_keys, sort, observed, dropna, **kwargs
            Forwarded to ``GroupBy`` as keyword arguments.

        Raises
        ------
        ValueError
            If ``by`` is a ``FrameBase`` that is not a ``DXSeries``.
        """
        # Imported here rather than at module level.
        from dask_cudf.expr._groupby import GroupBy

        by_is_frame = isinstance(by, FrameBase)
        if by_is_frame and not isinstance(by, DXSeries):
            raise ValueError(
                f"`by` must be a column name or list of columns, got {by}."
            )

        grouping_options = {
            "group_keys": group_keys,
            "sort": sort,
            "observed": observed,
            "dropna": dropna,
        }
        return GroupBy(self, by, **grouping_options, **kwargs)
|
||
|
||
class Series(VarMixin, DXSeries):
    """``DXSeries`` subclass adding cudf ``groupby`` and accessor support."""

    def groupby(self, by, **kwargs):
        """Return a ``SeriesGroupBy`` for this series, grouped by ``by``.

        Extra keyword arguments are forwarded to ``SeriesGroupBy``.
        """
        # Imported here rather than at module level.
        from dask_cudf.expr._groupby import SeriesGroupBy

        grouped = SeriesGroupBy(self, by, **kwargs)
        return grouped

    @cached_property
    def list(self):
        """``ListMethods`` accessor wrapping this series (built once, then cached)."""
        from dask_cudf.accessors import ListMethods

        accessor = ListMethods(self)
        return accessor

    @cached_property
    def struct(self):
        """``StructMethods`` accessor wrapping this series (built once, then cached)."""
        from dask_cudf.accessors import StructMethods

        accessor = StructMethods(self)
        return accessor
|
||
|
||
class Index(DXIndex):
    """``DXIndex`` subclass that adds no behavior of its own.

    Same as pandas (for now).
    """
|
||
|
||
# Map each cudf object type to the collection class defined in this
# module. The dispatched argument itself is ignored; only its type
# selects which class is returned.
@get_collection_type.register(cudf.DataFrame)
def _cudf_dataframe_collection(_):
    return DataFrame


@get_collection_type.register(cudf.Series)
def _cudf_series_collection(_):
    return Series


@get_collection_type.register(cudf.BaseIndex)
def _cudf_index_collection(_):
    return Index
Oops, something went wrong.