diff --git a/merlin/dag/executors.py b/merlin/dag/executors.py index 3bb48a1eb..c66a93bd4 100644 --- a/merlin/dag/executors.py +++ b/merlin/dag/executors.py @@ -28,7 +28,6 @@ ) from merlin.dag import ColumnSelector, Graph, Node from merlin.dtypes.shape import DefaultShapes -from merlin.io import Dataset from merlin.io.worker import clean_worker_cache LOG = logging.getLogger("merlin") @@ -52,7 +51,20 @@ def transform( Transforms a single dataframe (possibly a partition of a Dask Dataframe) by applying the operators from a collection of Nodes """ - nodes = self._output_nodes(graph) + nodes = [] + if isinstance(graph, Graph): + nodes.append(graph.output_node) + elif isinstance(graph, Node): + nodes.append(graph) + elif isinstance(graph, list): + nodes = graph + else: + raise TypeError( + f"LocalExecutor detected unsupported type of input for graph: {type(graph)}." + " `graph` argument must be either a `Graph` object (preferred)" + " or a list of `Node` objects (deprecated, but supported for backward " + " compatibility.)" + ) # There's usually only one node, but it's possibly to pass multiple nodes for `fit` # If we have multiple, we concatenate their outputs into a single transformable @@ -71,24 +83,6 @@ def transform( return output_data - def _output_nodes(self, graph): - nodes = [] - if isinstance(graph, Graph): - nodes.append(graph.output_node) - elif isinstance(graph, Node): - nodes.append(graph) - elif isinstance(graph, list): - nodes = graph - else: - raise TypeError( - f"LocalExecutor detected unsupported type of input for graph: {type(graph)}." - " `graph` argument must be either a `Graph` object (preferred)" - " or a list of `Node` objects (deprecated, but supported for backward " - " compatibility.)" - ) - - return nodes - def _execute_node(self, node, transformable, capture_dtypes=False, strict=False): upstream_outputs = self._run_upstream_transforms( node, transformable, capture_dtypes, strict @@ -269,7 +263,7 @@ def __getstate__(self): def transform( self, - dataset, + ddf, graph, output_dtypes=None, additional_columns=None, @@ -280,12 +274,23 @@ def transform( Transforms all partitions of a Dask Dataframe by applying the operators from a collection of Nodes """ - nodes = self._executor._output_nodes(graph) + nodes = [] + if isinstance(graph, Graph): + nodes.append(graph.output_node) + elif isinstance(graph, Node): + nodes.append(graph) + elif isinstance(graph, list): + nodes = graph + else: + raise TypeError( + f"DaskExecutor detected unsupported type of input for graph: {type(graph)}." + " `graph` argument must be either a `Graph` object (preferred)" + " or a list of `Node` objects (deprecated, but supported for backward" + " compatibility.)" + ) self._clear_worker_cache() - ddf = dataset.to_ddf() - # Check if we are only selecting columns (no transforms). # If so, we should perform column selection at the ddf level. # Otherwise, Dask will not push the column selection into the @@ -318,31 +323,27 @@ def transform( # don't require dtype information on the DDF this doesn't matter all that much output_dtypes = type(ddf._meta)({k: [] for k in columns}) - return Dataset( - ensure_optimize_dataframe_graph( - ddf=ddf.map_partitions( - self._executor.transform, - nodes, - additional_columns=additional_columns, - capture_dtypes=capture_dtypes, - strict=strict, - meta=output_dtypes, - enforce_metadata=False, - ) + return ensure_optimize_dataframe_graph( + ddf=ddf.map_partitions( + self._executor.transform, + nodes, + additional_columns=additional_columns, + capture_dtypes=capture_dtypes, + strict=strict, + meta=output_dtypes, + enforce_metadata=False, ) ) - def fit(self, dataset: Dataset, graph, strict=False): - """Calculates statistics for a set of nodes on the input dataset + def fit(self, ddf, nodes, strict=False): + """Calculates statistics for a set of nodes on the input dataframe Parameters ----------- - dataset: merlin.io.Dataset - The input dataset to calculate statistics for. If there is a + ddf: dask.Dataframe + The input dataframe to calculate statistics for. If there is a train/test split this should be the training dataset only. """ - nodes = self._executor._output_nodes(graph) - stats = [] for node in nodes: # Check for additional input columns that aren't generated by parents @@ -356,8 +357,8 @@ def fit(self, dataset: Dataset, graph, strict=False): # apply transforms necessary for the inputs to the current column group, ignoring # the transforms from the statop itself - transformed = self.transform( - dataset, + transformed_ddf = self.transform( + ddf, node.parents_with_dependencies, additional_columns=addl_input_cols, capture_dtypes=True, @@ -365,7 +366,7 @@ def fit(self, dataset: Dataset, graph, strict=False): ) try: - stats.append(node.op.fit(node.input_columns, transformed)) + stats.append(node.op.fit(node.input_columns, transformed_ddf)) except Exception: LOG.exception("Failed to fit operator %s", node.op) raise