Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Creating tables with a dask-cudf DataFrame #219

Closed
2 changes: 2 additions & 0 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def create_table(
format: str = None,
persist: bool = True,
schema_name: str = None,
gpu: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -199,6 +200,7 @@ def create_table(
table_name=table_name,
format=format,
persist=persist,
gpu=gpu,
**kwargs,
)
self.schema[schema_name].tables[table_name.lower()] = dc
Expand Down
12 changes: 9 additions & 3 deletions dask_sql/input_utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def to_dc(
table_name: str,
format: str = None,
persist: bool = True,
gpu: bool = False,
**kwargs,
) -> DataContainer:
"""
Expand All @@ -45,7 +46,7 @@ def to_dc(
maybe persist them to cluster memory before.
"""
filled_get_dask_dataframe = lambda *args: cls._get_dask_dataframe(
*args, table_name=table_name, format=format, **kwargs,
*args, table_name=table_name, format=format, gpu=gpu, **kwargs,
)

if isinstance(input_item, list):
Expand All @@ -60,7 +61,12 @@ def to_dc(

@classmethod
def _get_dask_dataframe(
cls, input_item: InputType, table_name: str, format: str = None, **kwargs,
cls,
input_item: InputType,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):
plugin_list = cls.get_plugins()

Expand All @@ -69,7 +75,7 @@ def _get_dask_dataframe(
input_item, table_name=table_name, format=format, **kwargs
):
return plugin.to_dc(
input_item, table_name=table_name, format=format, **kwargs
input_item, table_name=table_name, format=format, gpu=gpu, **kwargs
)

raise ValueError(f"Do not understand the input type {type(input_item)}")
16 changes: 14 additions & 2 deletions dask_sql/input_utils/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ def is_correct_input(
):
return isinstance(input_item, str)

def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
def to_dc(
self,
input_item: Any,
table_name: str,
format: str = None,
gpu: bool = False,
**kwargs,
):

if format == "memory":
client = default_client()
Expand All @@ -27,7 +34,12 @@ def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
format = extension.lstrip(".")

try:
read_function = getattr(dd, f"read_{format}")
if gpu:
import dask_cudf

read_function = getattr(dask_cudf, f"read_{format}")
else:
read_function = getattr(dd, f"read_{format}")
except AttributeError:
raise AttributeError(f"Can not read files of format {format}")

Expand Down
6 changes: 6 additions & 0 deletions dask_sql/physical/rel/custom/create_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,17 @@ def convert(
except KeyError:
raise AttributeError("Parameters must include a 'location' parameter.")

try:
gpu = kwargs.pop("gpu")
except KeyError:
gpu = False

sarahyurick marked this conversation as resolved.
Show resolved Hide resolved
context.create_table(
table_name,
location,
format=format,
persist=persist,
schema_name=schema_name,
gpu=gpu,
**kwargs,
)