Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a validate_schemas hook to clean up downstream validation code #76

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion merlin/dag/base_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ def compute_selector(
parents_selector: ColumnSelector,
dependencies_selector: ColumnSelector,
) -> ColumnSelector:
"""
Provides a hook method for sub-classes to override to implement
custom column selection logic.

Parameters
----------
input_schema : Schema
Schemas of the columns to apply this operator to
selector : ColumnSelector
Column selector to apply to the input schema
parents_selector : ColumnSelector
Combined selectors of the upstream parents feeding into this operator
dependencies_selector : ColumnSelector
Combined selectors of the upstream dependencies feeding into this operator

Returns
-------
ColumnSelector
Revised column selector to apply to the input schema
"""
self._validate_matching_cols(input_schema, selector, self.compute_selector.__name__)

return selector
Expand All @@ -61,6 +81,7 @@ def compute_input_schema(
) -> Schema:
"""Given the schemas coming from upstream sources and a column selector for the
input columns, returns a set of schemas for the input columns this operator will use

Parameters
-----------
root_schema: Schema
Expand All @@ -71,6 +92,7 @@ def compute_input_schema(
The combined schemas of the upstream dependencies feeding into this operator
col_selector: ColumnSelector
The column selector to apply to the input schema

Returns
-------
Schema
Expand All @@ -88,14 +110,17 @@ def compute_output_schema(
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Given a set of schemas and a column selector for the input columns,
"""
Given a set of schemas and a column selector for the input columns,
returns a set of schemas for the transformed columns this operator will produce

Parameters
-----------
input_schema: Schema
The schemas of the columns to apply this operator to
col_selector: ColumnSelector
The column selector to apply to the input schema

Returns
-------
Schema
Expand Down Expand Up @@ -131,6 +156,34 @@ def compute_output_schema(

return output_schema

def validate_schemas(
self,
parents_schema: Schema,
deps_schema: Schema,
input_schema: Schema,
output_schema: Schema,
strict_dtypes: bool = False,
):
"""
Provides a hook method that sub-classes can override to implement schema validation logic.

The base implementation is intentionally empty: it performs no checks and raises nothing.
Sub-class implementations should raise an exception if the schemas are not valid for the
operations they implement.

Parameters
----------
parents_schema : Schema
The combined schemas of the upstream parents feeding into this operator
deps_schema : Schema
The combined schemas of the upstream dependencies feeding into this operator
input_schema : Schema
The schemas of the columns to apply this operator to
output_schema : Schema
The schemas of the columns produced by this operator
strict_dtypes : bool, optional
Enables strict checking for column dtype matching if True, by default False
"""

def column_mapping(self, col_selector):
column_mapping = {}
for col_name in col_selector.names:
Expand Down Expand Up @@ -207,10 +260,12 @@ def _validate_matching_cols(self, schema, selector, method_name):
def output_column_names(self, col_selector: ColumnSelector) -> ColumnSelector:
"""Given a set of column names, returns the names of the transformed columns this
operator will produce

Parameters
-----------
columns: list of str, or list of list of str
The columns to apply this operator to

Returns
-------
list of str, or list of list of str
Expand All @@ -222,6 +277,7 @@ def output_column_names(self, col_selector: ColumnSelector) -> ColumnSelector:
def dependencies(self) -> List[Union[str, Any]]:
"""Defines an optional list of column dependencies for this operator. This lets you consume columns
that aren't part of the main transformation workflow.

Returns
-------
str, list of str or ColumnSelector, optional
Expand Down
54 changes: 50 additions & 4 deletions merlin/dag/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def validate_schemas(self, root_schema: Schema, strict_dtypes: bool = False):
f"expected dtype '{col_schema.dtype}'."
)

self.op.validate_schemas(
parents_schema, deps_schema, self.input_schema, self.output_schema, strict_dtypes
)

def __rshift__(self, operator):
"""Transforms this Node by applying a BaseOperator

Expand Down Expand Up @@ -376,7 +380,20 @@ def __repr__(self):
output = " output" if not self.children else ""
return f"<Node {self.label}{output}>"

def remove_inputs(self, input_cols):
def remove_inputs(self, input_cols: List[str]) -> List[str]:
"""
Remove input columns and all output columns that depend on them.

Parameters
----------
input_cols : List[str]
The input columns to remove

Returns
-------
List[str]
The output columns that were removed
"""
removed_outputs = _derived_output_cols(input_cols, self.column_mapping)

self.input_schema = self.input_schema.without(input_cols)
Expand Down Expand Up @@ -473,8 +490,33 @@ def _cols_repr(self):
def graph(self):
return _to_graphviz(self)

Nodable = Union[
"Node", str, List[str], ColumnSelector, List[Union["Node", str, List[str], ColumnSelector]]
]

@classmethod
def construct_from(cls, nodable):
def construct_from(
cls,
nodable: Nodable,
):
"""
Convert Node-like objects to a Node or list of Nodes.

Parameters
----------
nodable : Nodable
Node-like objects to convert to a Node or list of Nodes.

Returns
-------
Union["Node", List["Node"]]
New Node(s) corresponding to the Node-like input objects

Raises
------
TypeError
If supplied input cannot be converted to a Node or list of Nodes
"""
if isinstance(nodable, str):
return Node(ColumnSelector([nodable]))
if isinstance(nodable, ColumnSelector):
Expand All @@ -486,8 +528,12 @@ def construct_from(cls, nodable):
return Node(nodable)
else:
nodes = [Node.construct_from(node) for node in nodable]
non_selection_nodes = [node for node in nodes if not node.selector]
selection_nodes = [node.selector for node in nodes if node.selector]
non_selection_nodes = [
node for node in nodes if not (hasattr(node, "selector") and node.selector)
]
selection_nodes = [
node.selector for node in nodes if (hasattr(node, "selector") and node.selector)
]
selection_nodes = (
[Node(_combine_selectors(selection_nodes))] if selection_nodes else []
)
Expand Down