diff --git a/merlin/dag/executors.py b/merlin/dag/executors.py
index 2ff6cec92..a0196844d 100644
--- a/merlin/dag/executors.py
+++ b/merlin/dag/executors.py
@@ -44,6 +44,7 @@ def transform(
         output_dtypes=None,
         additional_columns=None,
         capture_dtypes=False,
+        validate_dtypes=True,
     ):
         """
         Transforms a single dataframe (possibly a partition of a Dask Dataframe)
@@ -67,11 +68,13 @@ def transform(
         output_data = None
 
         for node in nodes:
-            input_data = self._build_input_data(node, transformable, capture_dtypes=capture_dtypes)
+            input_data = self._build_input_data(
+                node, transformable, capture_dtypes=capture_dtypes, validate_dtypes=validate_dtypes
+            )
 
             if node.op:
                 transformed_data = self._transform_data(
-                    node, input_data, capture_dtypes=capture_dtypes
+                    node, input_data, capture_dtypes=capture_dtypes, validate_dtypes=validate_dtypes
                 )
             else:
                 transformed_data = input_data
@@ -85,7 +88,7 @@ def transform(
 
         return output_data
 
-    def _build_input_data(self, node, transformable, capture_dtypes=False):
+    def _build_input_data(self, node, transformable, capture_dtypes=False, validate_dtypes=True):
         """
         Recurse through the graph executing parent and dependency operators
         to form the input dataframe for each output node
@@ -114,7 +117,12 @@ def _build_input_data(self, node, transformable, capture_dtypes=False):
 
             for parent in node.parents_with_dependencies:
                 parent_output_cols = _get_unique(parent.output_schema.column_names)
-                parent_data = self.transform(transformable, [parent], capture_dtypes=capture_dtypes)
+                parent_data = self.transform(
+                    transformable,
+                    [parent],
+                    capture_dtypes=capture_dtypes,
+                    validate_dtypes=validate_dtypes,
+                )
                 if input_data is None or not len(input_data):
                     input_data = parent_data[parent_output_cols]
                     seen_columns = set(parent_output_cols)
@@ -141,7 +149,7 @@ def _build_input_data(self, node, transformable, capture_dtypes=False):
 
         return input_data
 
-    def _transform_data(self, node, input_data, capture_dtypes=False):
+    def _transform_data(self, node, input_data, capture_dtypes=False, validate_dtypes=True):
         """
         Run the transform represented by the final node in the graph
         and check output dtypes against the output schema
@@ -171,41 +179,45 @@ def _transform_data(self, node, input_data, capture_dtypes=False):
             output_data = node.op.transform(selection, input_data)
 
             # Update or validate output_data dtypes
-            for col_name, output_col_schema in node.output_schema.column_schemas.items():
-                col_series = output_data[col_name]
-                col_dtype = col_series.dtype
-                is_list = is_list_dtype(col_series)
-
-                if is_list:
-                    col_dtype = list_val_dtype(col_series)
-
-                # TODO: Add a utility that condenses the known methods of fetching dtypes
-                # from series/arrays into a single function, so that Tensorflow specific
-                # code doesn't leak into the executors
-                if not hasattr(col_dtype, "as_numpy_dtype") and hasattr(col_series, "numpy"):
-                    col_dtype = col_series[0].cpu().numpy().dtype
-
-                output_data_schema = output_col_schema.with_dtype(col_dtype, is_list=is_list)
-
-                if capture_dtypes:
-                    node.output_schema.column_schemas[col_name] = output_data_schema
-                elif len(output_data):
-                    # Validate that the dtypes match but only if they both exist
-                    # (since schemas may not have all dtypes specified, especially
-                    # in the tests)
-                    if (
-                        output_col_schema.dtype
-                        and output_data_schema.dtype
-                        and output_col_schema.dtype.without_shape != md.string
-                        and output_col_schema.dtype.without_shape
-                        != output_data_schema.dtype.without_shape
+            if capture_dtypes or validate_dtypes:
+                for col_name, output_col_schema in node.output_schema.column_schemas.items():
+                    col_series = output_data[col_name]
+                    output_data_dtype = col_series.dtype
+                    is_list = is_list_dtype(col_series)
+
+                    if is_list:
+                        output_data_dtype = list_val_dtype(col_series)
+
+                    # TODO: Add a utility that condenses the known methods of fetching dtypes
+                    # from series/arrays into a single function, so that Tensorflow specific
+                    # code doesn't leak into the executors
+                    if not hasattr(output_data_dtype, "as_numpy_dtype") and hasattr(
+                        col_series, "numpy"
                     ):
-                        raise TypeError(
-                            f"Dtype discrepancy detected for column {col_name}: "
-                            f"operator {node.op.label} reported dtype "
-                            f"`{output_col_schema.dtype}` but returned dtype "
-                            f"`{output_data_schema.dtype}`."
+                        output_data_dtype = col_series[0].cpu().numpy().dtype
+
+                    if capture_dtypes:
+                        node.output_schema.column_schemas[col_name] = output_col_schema.with_dtype(
+                            output_data_dtype, is_list=is_list
                         )
+                    elif validate_dtypes and len(output_data):
+                        # Validate that the dtypes match but only if they both exist
+                        # (since schemas may not have all dtypes specified, especially
+                        # in the tests)
+                        output_schema_dtype = output_col_schema.dtype.without_shape
+                        output_data_dtype = md.dtype(output_data_dtype).without_shape
+                        if (
+                            output_schema_dtype
+                            and output_data_dtype
+                            and output_schema_dtype != md.string
+                            and output_schema_dtype != output_data_dtype
+                        ):
+                            raise TypeError(
+                                f"Dtype discrepancy detected for column {col_name}: "
+                                f"operator {node.op.label} reported dtype "
+                                f"`{output_schema_dtype}` but returned dtype "
+                                f"`{output_data_dtype}`."
+                            )
         except Exception:
             LOG.exception("Failed to transform operator %s", node.op)
             raise
diff --git a/merlin/dtypes/base.py b/merlin/dtypes/base.py
index 05ea5ed26..d486089bf 100644
--- a/merlin/dtypes/base.py
+++ b/merlin/dtypes/base.py
@@ -172,4 +172,7 @@ def without_shape(self):
         DType
             A copy of this object with the shape removed
         """
-        return self.with_shape(Shape())
+        if self.shape.dims is None:
+            return self
+
+        return replace(self, shape=Shape())
diff --git a/merlin/dtypes/mapping.py b/merlin/dtypes/mapping.py
index 7c7aef29e..68e68a6a0 100644
--- a/merlin/dtypes/mapping.py
+++ b/merlin/dtypes/mapping.py
@@ -167,10 +167,7 @@ def _matches(self, dtype, mapping, base_class=None):
         # Some external dtype objects are not hashable, so they
         # can't be used as dictionary keys. In that case, match
        # against the dtype class instead.
-        hashable_dtype = dtype
         try:
-            hash(dtype)
+            return dtype in mapping.keys()
         except TypeError:
-            hashable_dtype = type(dtype)
-
-        return hashable_dtype in mapping.keys()
+            return type(dtype) in mapping.keys()
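
For context, a minimal usage sketch of the new flag; this is not part of the patch. It assumes the `transform` method shown above belongs to `LocalExecutor` in merlin/dag/executors.py, and `df` and `output_node` are hypothetical placeholders for a dataframe and a built operator-graph node.

    from merlin.dag.executors import LocalExecutor

    executor = LocalExecutor()

    # Default behaviour: output dtypes are checked against each node's output
    # schema, and a mismatch raises TypeError (validate_dtypes=True).
    validated = executor.transform(df, [output_node])

    # Skip the dtype check, e.g. when schemas intentionally leave dtypes unset.
    unchecked = executor.transform(df, [output_node], validate_dtypes=False)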