Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Adjust the DaskExecutor API methods to take Datasets inst… #306

Merged
merged 1 commit into from
May 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 45 additions & 44 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from merlin.dag import ColumnSelector, Graph, Node
from merlin.dtypes.shape import DefaultShapes
from merlin.io import Dataset
from merlin.io.worker import clean_worker_cache

LOG = logging.getLogger("merlin")
Expand All @@ -52,7 +51,20 @@ def transform(
Transforms a single dataframe (possibly a partition of a Dask Dataframe)
by applying the operators from a collection of Nodes
"""
nodes = self._output_nodes(graph)
nodes = []
if isinstance(graph, Graph):
nodes.append(graph.output_node)
elif isinstance(graph, Node):
nodes.append(graph)
elif isinstance(graph, list):
nodes = graph
else:
raise TypeError(
f"LocalExecutor detected unsupported type of input for graph: {type(graph)}."
" `graph` argument must be either a `Graph` object (preferred)"
" or a list of `Node` objects (deprecated, but supported for backward "
" compatibility.)"
)

# There's usually only one node, but it's possibly to pass multiple nodes for `fit`
# If we have multiple, we concatenate their outputs into a single transformable
Expand All @@ -71,24 +83,6 @@ def transform(

return output_data

def _output_nodes(self, graph):
nodes = []
if isinstance(graph, Graph):
nodes.append(graph.output_node)
elif isinstance(graph, Node):
nodes.append(graph)
elif isinstance(graph, list):
nodes = graph
else:
raise TypeError(
f"LocalExecutor detected unsupported type of input for graph: {type(graph)}."
" `graph` argument must be either a `Graph` object (preferred)"
" or a list of `Node` objects (deprecated, but supported for backward "
" compatibility.)"
)

return nodes

def _execute_node(self, node, transformable, capture_dtypes=False, strict=False):
upstream_outputs = self._run_upstream_transforms(
node, transformable, capture_dtypes, strict
Expand Down Expand Up @@ -269,7 +263,7 @@ def __getstate__(self):

def transform(
self,
dataset,
ddf,
graph,
output_dtypes=None,
additional_columns=None,
Expand All @@ -280,12 +274,23 @@ def transform(
Transforms all partitions of a Dask Dataframe by applying the operators
from a collection of Nodes
"""
nodes = self._executor._output_nodes(graph)
nodes = []
if isinstance(graph, Graph):
nodes.append(graph.output_node)
elif isinstance(graph, Node):
nodes.append(graph)
elif isinstance(graph, list):
nodes = graph
else:
raise TypeError(
f"DaskExecutor detected unsupported type of input for graph: {type(graph)}."
" `graph` argument must be either a `Graph` object (preferred)"
" or a list of `Node` objects (deprecated, but supported for backward"
" compatibility.)"
)

self._clear_worker_cache()

ddf = dataset.to_ddf()

# Check if we are only selecting columns (no transforms).
# If so, we should perform column selection at the ddf level.
# Otherwise, Dask will not push the column selection into the
Expand Down Expand Up @@ -318,31 +323,27 @@ def transform(
# don't require dtype information on the DDF this doesn't matter all that much
output_dtypes = type(ddf._meta)({k: [] for k in columns})

return Dataset(
ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
self._executor.transform,
nodes,
additional_columns=additional_columns,
capture_dtypes=capture_dtypes,
strict=strict,
meta=output_dtypes,
enforce_metadata=False,
)
return ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
self._executor.transform,
nodes,
additional_columns=additional_columns,
capture_dtypes=capture_dtypes,
strict=strict,
meta=output_dtypes,
enforce_metadata=False,
)
)

def fit(self, dataset: Dataset, graph, strict=False):
"""Calculates statistics for a set of nodes on the input dataset
def fit(self, ddf, nodes, strict=False):
"""Calculates statistics for a set of nodes on the input dataframe

Parameters
-----------
dataset: merlin.io.Dataset
The input dataset to calculate statistics for. If there is a
ddf: dask.Dataframe
The input dataframe to calculate statistics for. If there is a
train/test split this should be the training dataset only.
"""
nodes = self._executor._output_nodes(graph)

stats = []
for node in nodes:
# Check for additional input columns that aren't generated by parents
Expand All @@ -356,16 +357,16 @@ def fit(self, dataset: Dataset, graph, strict=False):

# apply transforms necessary for the inputs to the current column group, ignoring
# the transforms from the statop itself
transformed = self.transform(
dataset,
transformed_ddf = self.transform(
ddf,
node.parents_with_dependencies,
additional_columns=addl_input_cols,
capture_dtypes=True,
strict=strict,
)

try:
stats.append(node.op.fit(node.input_columns, transformed))
stats.append(node.op.fit(node.input_columns, transformed_ddf))
except Exception:
LOG.exception("Failed to fit operator %s", node.op)
raise
Expand Down