Skip to content

Commit

Permalink
Revert "Adjust the DaskExecutor API methods to take Datasets inst…
Browse files Browse the repository at this point in the history
…ead of ddfs (#299)" (#306)

This reverts commit 7c66938.
  • Loading branch information
karlhigley authored May 2, 2023
1 parent 02b3e32 commit d742664
Showing 1 changed file with 45 additions and 44 deletions.
89 changes: 45 additions & 44 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from merlin.dag import ColumnSelector, Graph, Node
from merlin.dtypes.shape import DefaultShapes
from merlin.io import Dataset
from merlin.io.worker import clean_worker_cache

LOG = logging.getLogger("merlin")
Expand All @@ -52,7 +51,20 @@ def transform(
Transforms a single dataframe (possibly a partition of a Dask Dataframe)
by applying the operators from a collection of Nodes
"""
nodes = self._output_nodes(graph)
nodes = []
if isinstance(graph, Graph):
nodes.append(graph.output_node)
elif isinstance(graph, Node):
nodes.append(graph)
elif isinstance(graph, list):
nodes = graph
else:
raise TypeError(
f"LocalExecutor detected unsupported type of input for graph: {type(graph)}."
" `graph` argument must be either a `Graph` object (preferred)"
" or a list of `Node` objects (deprecated, but supported for backward "
" compatibility.)"
)

# There's usually only one node, but it's possible to pass multiple nodes for `fit`
# If we have multiple, we concatenate their outputs into a single transformable
Expand All @@ -71,24 +83,6 @@ def transform(

return output_data

def _output_nodes(self, graph):
    """Normalize the `graph` argument into a list of nodes.

    A `Graph` yields a one-element list holding its `output_node`, a single
    `Node` is wrapped in a one-element list, and a list is handed back as-is
    (the same object, not a copy).

    Raises
    ------
    TypeError
        If `graph` is not a `Graph`, a `Node`, or a list.
    """
    # Type checks run in the same order as before: Graph, Node, then list.
    if isinstance(graph, Graph):
        return [graph.output_node]
    if isinstance(graph, Node):
        return [graph]
    if isinstance(graph, list):
        return graph

    raise TypeError(
        f"LocalExecutor detected unsupported type of input for graph: {type(graph)}."
        " `graph` argument must be either a `Graph` object (preferred)"
        " or a list of `Node` objects (deprecated, but supported for backward "
        " compatibility.)"
    )

def _execute_node(self, node, transformable, capture_dtypes=False, strict=False):
upstream_outputs = self._run_upstream_transforms(
node, transformable, capture_dtypes, strict
Expand Down Expand Up @@ -269,7 +263,7 @@ def __getstate__(self):

def transform(
self,
dataset,
ddf,
graph,
output_dtypes=None,
additional_columns=None,
Expand All @@ -280,12 +274,23 @@ def transform(
Transforms all partitions of a Dask Dataframe by applying the operators
from a collection of Nodes
"""
nodes = self._executor._output_nodes(graph)
nodes = []
if isinstance(graph, Graph):
nodes.append(graph.output_node)
elif isinstance(graph, Node):
nodes.append(graph)
elif isinstance(graph, list):
nodes = graph
else:
raise TypeError(
f"DaskExecutor detected unsupported type of input for graph: {type(graph)}."
" `graph` argument must be either a `Graph` object (preferred)"
" or a list of `Node` objects (deprecated, but supported for backward"
" compatibility.)"
)

self._clear_worker_cache()

ddf = dataset.to_ddf()

# Check if we are only selecting columns (no transforms).
# If so, we should perform column selection at the ddf level.
# Otherwise, Dask will not push the column selection into the
Expand Down Expand Up @@ -318,31 +323,27 @@ def transform(
# don't require dtype information on the DDF this doesn't matter all that much
output_dtypes = type(ddf._meta)({k: [] for k in columns})

return Dataset(
ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
self._executor.transform,
nodes,
additional_columns=additional_columns,
capture_dtypes=capture_dtypes,
strict=strict,
meta=output_dtypes,
enforce_metadata=False,
)
return ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
self._executor.transform,
nodes,
additional_columns=additional_columns,
capture_dtypes=capture_dtypes,
strict=strict,
meta=output_dtypes,
enforce_metadata=False,
)
)

def fit(self, dataset: Dataset, graph, strict=False):
"""Calculates statistics for a set of nodes on the input dataset
def fit(self, ddf, nodes, strict=False):
"""Calculates statistics for a set of nodes on the input dataframe
Parameters
-----------
dataset: merlin.io.Dataset
The input dataset to calculate statistics for. If there is a
ddf: dask.Dataframe
The input dataframe to calculate statistics for. If there is a
train/test split this should be the training dataset only.
"""
nodes = self._executor._output_nodes(graph)

stats = []
for node in nodes:
# Check for additional input columns that aren't generated by parents
Expand All @@ -356,16 +357,16 @@ def fit(self, dataset: Dataset, graph, strict=False):

# apply transforms necessary for the inputs to the current column group, ignoring
# the transforms from the statop itself
transformed = self.transform(
dataset,
transformed_ddf = self.transform(
ddf,
node.parents_with_dependencies,
additional_columns=addl_input_cols,
capture_dtypes=True,
strict=strict,
)

try:
stats.append(node.op.fit(node.input_columns, transformed))
stats.append(node.op.fit(node.input_columns, transformed_ddf))
except Exception:
LOG.exception("Failed to fit operator %s", node.op)
raise
Expand Down

0 comments on commit d742664

Please sign in to comment.