Skip to content

Commit

Permalink
Make output of transform same type as input (#350)
Browse files Browse the repository at this point in the history
* add changes to support subgraph in multistage example

* fix linting issue around parameter rename

* add final convert format to executors to convert output to input type

* remove unneeded daskexecutor change
  • Loading branch information
jperez999 authored Jun 22, 2023
1 parent 9260b69 commit 61d65d3
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ class LocalExecutor:
def __init__(self, device=Device.GPU):
self.device = device if HAS_GPU else Device.CPU

@property
def target_format(self):
return (
DataFormats.PANDAS_DATAFRAME
| DataFormats.CUDF_DATAFRAME
| DataFormats.NUMPY_TENSOR_TABLE
| DataFormats.CUPY_TENSOR_TABLE
)

def transform(
self,
transformable,
Expand All @@ -60,6 +69,7 @@ def transform(
additional_columns=None,
capture_dtypes=False,
strict=False,
target_format=None,
):
"""
Transforms a single dataframe (possibly a partition of a Dask Dataframe)
Expand All @@ -79,7 +89,7 @@ def transform(
" or a list of `Node` objects (deprecated, but supported for backward "
" compatibility.)"
)

target_format = target_format or self.target_format
# There's usually only one node, but it's possibly to pass multiple nodes for `fit`
# If we have multiple, we concatenate their outputs into a single transformable
output_data = None
Expand All @@ -95,7 +105,7 @@ def transform(
[output_data, transformable[_get_unique(additional_columns)]]
)

return output_data
return _convert_format(output_data, _data_format(transformable))

def _execute_node(self, node, transformable, capture_dtypes=False, strict=False):
upstream_outputs = self._run_upstream_transforms(
Expand All @@ -105,6 +115,7 @@ def _execute_node(self, node, transformable, capture_dtypes=False, strict=False)
formatted_columns = self._standardize_formats(node, upstream_columns)
transform_input = self._merge_upstream_columns(formatted_columns)
transform_output = self._run_node_transform(node, transform_input, capture_dtypes, strict)
transform_output = _convert_format(transform_output, self.target_format)
return transform_output

def _run_upstream_transforms(self, node, transformable, capture_dtypes=False, strict=False):
Expand Down

0 comments on commit 61d65d3

Please sign in to comment.