Skip to content

Commit

Permalink
Adjust the DaskExecutor API methods to take Datasets instead of d…
Browse files Browse the repository at this point in the history
…dfs (NVIDIA-Merlin#299)

* Adjust the `DaskExecutor` API methods to take `Dataset`s instead of ddfs

* Update docstrings to reflect changing Dask dataframes to Merlin datasets
  • Loading branch information
karlhigley authored May 1, 2023
1 parent fc67eec commit 7c66938
Showing 1 changed file with 44 additions and 45 deletions.
89 changes: 44 additions & 45 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from merlin.dag import ColumnSelector, Graph, Node
from merlin.dtypes.shape import DefaultShapes
from merlin.io import Dataset
from merlin.io.worker import clean_worker_cache

LOG = logging.getLogger("merlin")
Expand All @@ -51,20 +52,7 @@ def transform(
Transforms a single dataframe (possibly a partition of a Dask Dataframe)
by applying the operators from a collection of Nodes
"""
nodes = []
if isinstance(graph, Graph):
nodes.append(graph.output_node)
elif isinstance(graph, Node):
nodes.append(graph)
elif isinstance(graph, list):
nodes = graph
else:
raise TypeError(
f"LocalExecutor detected unsupported type of input for graph: {type(graph)}."
" `graph` argument must be either a `Graph` object (preferred)"
" or a list of `Node` objects (deprecated, but supported for backward "
" compatibility.)"
)
nodes = self._output_nodes(graph)

        # There's usually only one node, but it's possible to pass multiple nodes for `fit`
# If we have multiple, we concatenate their outputs into a single transformable
Expand All @@ -83,6 +71,24 @@ def transform(

return output_data

def _output_nodes(self, graph):
    """Normalize the `graph` argument into a list of output `Node`s.

    Parameters
    ----------
    graph : Graph, Node, or list of Node
        The graph (preferred), single node, or list of nodes to execute.
        Passing a list of `Node`s is deprecated but supported for
        backward compatibility.

    Returns
    -------
    list of Node
        The output node(s) to execute.

    Raises
    ------
    TypeError
        If `graph` is not a `Graph`, `Node`, or list of `Node`s.
    """
    if isinstance(graph, Graph):
        nodes = [graph.output_node]
    elif isinstance(graph, Node):
        nodes = [graph]
    elif isinstance(graph, list):
        nodes = graph
    else:
        # Use the concrete class name so the error is accurate when this
        # helper is reused by other executors (e.g. via DaskExecutor).
        raise TypeError(
            f"{type(self).__name__} detected unsupported type of input for graph:"
            f" {type(graph)}."
            " `graph` argument must be either a `Graph` object (preferred)"
            " or a list of `Node` objects (deprecated, but supported for backward"
            " compatibility.)"
        )

    return nodes

def _execute_node(self, node, transformable, capture_dtypes=False, strict=False):
upstream_outputs = self._run_upstream_transforms(
node, transformable, capture_dtypes, strict
Expand Down Expand Up @@ -263,7 +269,7 @@ def __getstate__(self):

def transform(
self,
ddf,
dataset,
graph,
output_dtypes=None,
additional_columns=None,
Expand All @@ -274,23 +280,12 @@ def transform(
        Transforms all partitions of a Merlin Dataset by applying the operators
        from a collection of Nodes
"""
nodes = []
if isinstance(graph, Graph):
nodes.append(graph.output_node)
elif isinstance(graph, Node):
nodes.append(graph)
elif isinstance(graph, list):
nodes = graph
else:
raise TypeError(
f"DaskExecutor detected unsupported type of input for graph: {type(graph)}."
" `graph` argument must be either a `Graph` object (preferred)"
" or a list of `Node` objects (deprecated, but supported for backward"
" compatibility.)"
)
nodes = self._executor._output_nodes(graph)

self._clear_worker_cache()

ddf = dataset.to_ddf()

# Check if we are only selecting columns (no transforms).
# If so, we should perform column selection at the ddf level.
# Otherwise, Dask will not push the column selection into the
Expand Down Expand Up @@ -323,27 +318,31 @@ def transform(
# don't require dtype information on the DDF this doesn't matter all that much
output_dtypes = type(ddf._meta)({k: [] for k in columns})

return ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
self._executor.transform,
nodes,
additional_columns=additional_columns,
capture_dtypes=capture_dtypes,
strict=strict,
meta=output_dtypes,
enforce_metadata=False,
return Dataset(
ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
self._executor.transform,
nodes,
additional_columns=additional_columns,
capture_dtypes=capture_dtypes,
strict=strict,
meta=output_dtypes,
enforce_metadata=False,
)
)
)

def fit(self, ddf, nodes, strict=False):
"""Calculates statistics for a set of nodes on the input dataframe
def fit(self, dataset: Dataset, graph, strict=False):
"""Calculates statistics for a set of nodes on the input dataset
Parameters
-----------
ddf: dask.Dataframe
The input dataframe to calculate statistics for. If there is a
dataset: merlin.io.Dataset
The input dataset to calculate statistics for. If there is a
train/test split this should be the training dataset only.
"""
nodes = self._executor._output_nodes(graph)

stats = []
for node in nodes:
# Check for additional input columns that aren't generated by parents
Expand All @@ -357,16 +356,16 @@ def fit(self, ddf, nodes, strict=False):

# apply transforms necessary for the inputs to the current column group, ignoring
# the transforms from the statop itself
transformed_ddf = self.transform(
ddf,
transformed = self.transform(
dataset,
node.parents_with_dependencies,
additional_columns=addl_input_cols,
capture_dtypes=True,
strict=strict,
)

try:
stats.append(node.op.fit(node.input_columns, transformed_ddf))
stats.append(node.op.fit(node.input_columns, transformed))
except Exception:
LOG.exception("Failed to fit operator %s", node.op)
raise
Expand Down

0 comments on commit 7c66938

Please sign in to comment.