Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating tables with a dask-cudf DataFrame #219

Closed
2 changes: 2 additions & 0 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def create_table(
format: str = None,
persist: bool = True,
schema_name: str = None,
gpu: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -199,6 +200,7 @@ def create_table(
table_name=table_name,
format=format,
persist=persist,
gpu=gpu,
**kwargs,
)
self.schema[schema_name].tables[table_name.lower()] = dc
Expand Down
12 changes: 9 additions & 3 deletions dask_sql/input_utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def to_dc(
table_name: str,
format: str = None,
persist: bool = True,
gpu: bool = False,
**kwargs,
) -> DataContainer:
"""
Expand All @@ -45,7 +46,7 @@ def to_dc(
maybe persist them to cluster memory before.
"""
filled_get_dask_dataframe = lambda *args: cls._get_dask_dataframe(
*args, table_name=table_name, format=format, **kwargs,
*args, table_name=table_name, format=format, gpu=gpu, **kwargs,
)

if isinstance(input_item, list):
Expand All @@ -60,7 +61,12 @@ def to_dc(

@classmethod
def _get_dask_dataframe(
cls, input_item: InputType, table_name: str, format: str = None, **kwargs,
cls,
input_item: InputType,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):
plugin_list = cls.get_plugins()

Expand All @@ -69,7 +75,7 @@ def _get_dask_dataframe(
input_item, table_name=table_name, format=format, **kwargs
):
return plugin.to_dc(
input_item, table_name=table_name, format=format, **kwargs
input_item, table_name=table_name, format=format, gpu=gpu, **kwargs
)

raise ValueError(f"Do not understand the input type {type(input_item)}")
14 changes: 12 additions & 2 deletions dask_sql/input_utils/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,21 @@ def is_correct_input(
isinstance(input_item, intake.catalog.Catalog) or format == "intake"
)

def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item: Any,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs
):
table_name = kwargs.pop("intake_table_name", table_name)
catalog_kwargs = kwargs.pop("catalog_kwargs", {})

if isinstance(input_item, str):
input_item = intake.open_catalog(input_item, **catalog_kwargs)

return input_item[table_name].to_dask(**kwargs)
if gpu:
raise Exception("Intake does not support gpu")
else:
return input_item[table_name].to_dask(**kwargs)
16 changes: 14 additions & 2 deletions dask_sql/input_utils/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ def is_correct_input(
):
return isinstance(input_item, str)

def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item: Any,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):

if format == "memory":
client = default_client()
Expand All @@ -27,7 +34,12 @@ def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
format = extension.lstrip(".")

try:
read_function = getattr(dd, f"read_{format}")
if gpu:
import dask_cudf

read_function = getattr(dask_cudf, f"read_{format}")
else:
read_function = getattr(dd, f"read_{format}")
except AttributeError:
raise AttributeError(f"Can not read files of format {format}")

Expand Down
19 changes: 17 additions & 2 deletions dask_sql/input_utils/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@ def is_correct_input(
):
return isinstance(input_item, pd.DataFrame) or format == "dask"

def to_dc(self, input_item, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):
npartitions = kwargs.pop("npartitions", 1)
return dd.from_pandas(input_item, npartitions=npartitions, **kwargs)
if gpu:
import cudf
import dask_cudf

return dask_cudf.from_cudf(
cudf.from_pandas(input_item), npartitions=npartitions, **kwargs,
)
else:
return dd.from_pandas(input_item, npartitions=npartitions, **kwargs)
2 changes: 2 additions & 0 deletions dask_sql/physical/rel/custom/create_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ def convert(
except KeyError:
raise AttributeError("Parameters must include a 'location' parameter.")

gpu = kwargs.pop("gpu", False)
context.create_table(
table_name,
location,
format=format,
persist=persist,
schema_name=schema_name,
gpu=gpu,
**kwargs,
)