Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the overhead of using LocalExecutor (esp. dtype validation) #219

Merged
merged 6 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 50 additions & 38 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def transform(
output_dtypes=None,
additional_columns=None,
capture_dtypes=False,
validate_dtypes=True,
):
"""
Transforms a single dataframe (possibly a partition of a Dask Dataframe)
Expand All @@ -67,11 +68,13 @@ def transform(
output_data = None

for node in nodes:
input_data = self._build_input_data(node, transformable, capture_dtypes=capture_dtypes)
input_data = self._build_input_data(
node, transformable, capture_dtypes=capture_dtypes, validate_dtypes=validate_dtypes
)

if node.op:
transformed_data = self._transform_data(
node, input_data, capture_dtypes=capture_dtypes
node, input_data, capture_dtypes=capture_dtypes, validate_dtypes=validate_dtypes
)
else:
transformed_data = input_data
Expand All @@ -85,7 +88,7 @@ def transform(

return output_data

def _build_input_data(self, node, transformable, capture_dtypes=False):
def _build_input_data(self, node, transformable, capture_dtypes=False, validate_dtypes=True):
"""
Recurse through the graph executing parent and dependency operators
to form the input dataframe for each output node
Expand Down Expand Up @@ -114,7 +117,12 @@ def _build_input_data(self, node, transformable, capture_dtypes=False):

for parent in node.parents_with_dependencies:
parent_output_cols = _get_unique(parent.output_schema.column_names)
parent_data = self.transform(transformable, [parent], capture_dtypes=capture_dtypes)
parent_data = self.transform(
transformable,
[parent],
capture_dtypes=capture_dtypes,
validate_dtypes=validate_dtypes,
)
if input_data is None or not len(input_data):
input_data = parent_data[parent_output_cols]
seen_columns = set(parent_output_cols)
Expand All @@ -141,7 +149,7 @@ def _build_input_data(self, node, transformable, capture_dtypes=False):

return input_data

def _transform_data(self, node, input_data, capture_dtypes=False):
def _transform_data(self, node, input_data, capture_dtypes=False, validate_dtypes=True):
"""
Run the transform represented by the final node in the graph
and check output dtypes against the output schema
Expand Down Expand Up @@ -171,41 +179,45 @@ def _transform_data(self, node, input_data, capture_dtypes=False):
output_data = node.op.transform(selection, input_data)

# Update or validate output_data dtypes
for col_name, output_col_schema in node.output_schema.column_schemas.items():
col_series = output_data[col_name]
col_dtype = col_series.dtype
is_list = is_list_dtype(col_series)

if is_list:
col_dtype = list_val_dtype(col_series)

# TODO: Add a utility that condenses the known methods of fetching dtypes
# from series/arrays into a single function, so that Tensorflow specific
# code doesn't leak into the executors
if not hasattr(col_dtype, "as_numpy_dtype") and hasattr(col_series, "numpy"):
col_dtype = col_series[0].cpu().numpy().dtype

output_data_schema = output_col_schema.with_dtype(col_dtype, is_list=is_list)

if capture_dtypes:
node.output_schema.column_schemas[col_name] = output_data_schema
elif len(output_data):
# Validate that the dtypes match but only if they both exist
# (since schemas may not have all dtypes specified, especially
# in the tests)
if (
output_col_schema.dtype
and output_data_schema.dtype
and output_col_schema.dtype.without_shape != md.string
and output_col_schema.dtype.without_shape
!= output_data_schema.dtype.without_shape
if capture_dtypes or validate_dtypes:
for col_name, output_col_schema in node.output_schema.column_schemas.items():
col_series = output_data[col_name]
output_data_dtype = col_series.dtype
is_list = is_list_dtype(col_series)

if is_list:
output_data_dtype = list_val_dtype(col_series)

# TODO: Add a utility that condenses the known methods of fetching dtypes
# from series/arrays into a single function, so that Tensorflow specific
# code doesn't leak into the executors
if not hasattr(output_data_dtype, "as_numpy_dtype") and hasattr(
col_series, "numpy"
):
raise TypeError(
f"Dtype discrepancy detected for column {col_name}: "
f"operator {node.op.label} reported dtype "
f"`{output_col_schema.dtype}` but returned dtype "
f"`{output_data_schema.dtype}`."
output_data_dtype = col_series[0].cpu().numpy().dtype

if capture_dtypes:
node.output_schema.column_schemas[col_name] = output_col_schema.with_dtype(
output_data_dtype, is_list=is_list
)
elif validate_dtypes and len(output_data):
# Validate that the dtypes match but only if they both exist
# (since schemas may not have all dtypes specified, especially
# in the tests)
output_schema_dtype = output_col_schema.dtype.without_shape
output_data_dtype = md.dtype(output_data_dtype).without_shape
if (
output_schema_dtype
and output_data_dtype
and output_schema_dtype != md.string
and output_schema_dtype != output_data_dtype
):
raise TypeError(
f"Dtype discrepancy detected for column {col_name}: "
f"operator {node.op.label} reported dtype "
f"`{output_schema_dtype}` but returned dtype "
f"`{output_data_dtype}`."
)
except Exception:
LOG.exception("Failed to transform operator %s", node.op)
raise
Expand Down
5 changes: 4 additions & 1 deletion merlin/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,7 @@ def without_shape(self):
DType
A copy of this object with the shape removed
"""
return self.with_shape(Shape())
if self.shape.dims is None:
return self

return replace(self, shape=Shape())
7 changes: 2 additions & 5 deletions merlin/dtypes/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@ def _matches(self, dtype, mapping, base_class=None):
# Some external dtype objects are not hashable, so they
# can't be used as dictionary keys. In that case, match
# against the dtype class instead.
hashable_dtype = dtype
try:
hash(dtype)
return dtype in mapping.keys()
except TypeError:
hashable_dtype = type(dtype)

return hashable_dtype in mapping.keys()
return type(dtype) in mapping.keys()