Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the column merging during graph traversal and concat selector #353

Merged
merged 4 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,19 @@ def _merge_upstream_columns(self, upstream_outputs, merge_fn=concat_columns):
combined_outputs = upstream_output
seen_columns = upstream_columns
else:
old_columns = seen_columns - upstream_columns
overlap_columns = seen_columns.intersection(upstream_columns)
new_columns = upstream_columns - seen_columns
merge_columns = []
if old_columns:
merge_columns.append(combined_outputs[list(old_columns)])
if overlap_columns:
merge_columns.append(upstream_output[list(overlap_columns)])
if new_columns:
combined_outputs = merge_fn(
[combined_outputs, upstream_output[list(new_columns)]]
)
merge_columns.append(upstream_output[list(new_columns)])
seen_columns.update(new_columns)
if merge_columns:
combined_outputs = merge_fn(merge_columns)
return combined_outputs

def _run_node_transform(self, node, input_data, capture_dtypes=False, strict=False):
Expand Down
78 changes: 33 additions & 45 deletions merlin/dag/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,20 @@
from typing import List, Union

from merlin.dag.base_operator import BaseOperator
from merlin.dag.ops import ConcatColumns, SelectionOp, SubsetColumns, SubtractionOp
from merlin.dag.ops import ConcatColumns, GroupingOp, SelectionOp, SubsetColumns, SubtractionOp
from merlin.dag.ops.udf import UDF
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema

# Type alias used to annotate the argument of the Node graph-building methods
# (add_dependency, add_parent, add_child, remove_child): a Node, an operator,
# a single column name, a list of column names, a ColumnSelector, or a list
# mixing any of these.
Nodable = Union[
"Node",
BaseOperator,
str,
List[str],
ColumnSelector,
List[Union["Node", BaseOperator, str, List[str], ColumnSelector]],
]


class Node:
"""A Node is a group of columns that you want to apply the same transformations to.
Expand Down Expand Up @@ -69,13 +78,7 @@ def selector(self, sel):
# These methods must maintain grouping
def add_dependency(
self,
dep: Union[
str,
List[str],
ColumnSelector,
"Node",
List[Union[str, List[str], "Node", ColumnSelector]],
],
dep: Nodable,
):
"""
Adding a dependency node to this node
Expand All @@ -99,13 +102,7 @@ def add_dependency(

def add_parent(
self,
parent: Union[
str,
List[str],
ColumnSelector,
"Node",
List[Union[str, List[str], "Node", ColumnSelector]],
],
parent: Nodable,
):
"""
Adding a parent node to this node
Expand All @@ -127,13 +124,7 @@ def add_parent(

def add_child(
self,
child: Union[
str,
List[str],
ColumnSelector,
"Node",
List[Union[str, List[str], "Node", ColumnSelector]],
],
child: Nodable,
):
"""
Adding a child node to this node
Expand All @@ -155,13 +146,7 @@ def add_child(

def remove_child(
self,
child: Union[
str,
List[str],
ColumnSelector,
"Node",
List[Union[str, List[str], "Node", ColumnSelector]],
],
child: Nodable,
):
"""
Removing a child node from this node
Expand Down Expand Up @@ -209,7 +194,6 @@ def compute_schemas(self, root_schema: Schema, preserve_dtypes: bool = False):
self.input_schema = self.op.compute_input_schema(
root_schema, parents_schema, deps_schema, self.selector
)

self.selector = self.op.compute_selector(
self.input_schema, self.selector, parents_selector, dependencies_selector
)
Expand Down Expand Up @@ -335,18 +319,26 @@ def __add__(self, other):
other_nodes = [other_nodes]

for other_node in other_nodes:
# If the other node is a `[]`, we want to maintain grouping
# so create a selection node that we can use to do that
if isinstance(other_node, list):
grouped_node = Node.construct_from(GroupingOp())
for node in other_node:
grouped_node.add_parent(node)
child.add_dependency(grouped_node)
# If the other node is a `+` node, we want to collapse it into this `+` node to
# avoid creating a cascade of repeated `+`s that we'd need to optimize out by
# re-combining them later in order to clean up the graph
if not isinstance(other_node, list) and isinstance(other_node.op, ConcatColumns):
elif isinstance(other_node.op, ConcatColumns):
child.dependencies += other_node.grouped_parents_with_dependencies
else:
child.add_dependency(other_node)

return child

# handle the "column_name" + Node case
__radd__ = __add__
def __radd__(self, other):
    """Handle ``other + node`` when ``other`` is not itself a Node.

    ``other`` (for example a column name) is first converted with
    ``Node.construct_from``; the addition is then delegated to that new
    node's ``__add__`` so the converted operand stays on the left-hand side.
    """
    # Call __add__ explicitly rather than using `+` so no reflected-operator
    # dispatch can route back into this method.
    return Node.construct_from(other).__add__(self)

def __sub__(self, other):
"""Removes columns from this Node with another to return a new Node
Expand Down Expand Up @@ -534,14 +526,6 @@ def _cols_repr(self):
def graph(self):
return _to_graphviz(self)

Nodable = Union[
"Node",
str,
List[str],
ColumnSelector,
List[Union["Node", str, List[str], ColumnSelector]],
]

@classmethod
def construct_from(
cls,
Expand Down Expand Up @@ -589,7 +573,11 @@ def construct_from(
selection_nodes = (
[Node(_combine_selectors(selection_nodes))] if selection_nodes else []
)
return non_selection_nodes + selection_nodes
group_node = Node.construct_from(GroupingOp())
all_node = non_selection_nodes + selection_nodes
for node in all_node:
group_node.add_parent(node)
return group_node

else:
raise TypeError(
Expand Down Expand Up @@ -690,7 +678,9 @@ def _combine_selectors(elements):
combined = ColumnSelector()
for elem in elements:
if isinstance(elem, Node):
if elem.selector:
if isinstance(elem.op, GroupingOp):
selector = elem.selector
elif elem.selector:
selector = elem.op.output_column_names(elem.selector)
elif elem.output_schema:
selector = ColumnSelector(elem.output_schema.column_names)
Expand All @@ -705,8 +695,6 @@ def _combine_selectors(elements):
combined += elem
elif isinstance(elem, str):
combined += ColumnSelector(elem)
elif isinstance(elem, list):
combined += ColumnSelector(subgroups=_combine_selectors(elem))
return combined


Expand Down
1 change: 1 addition & 0 deletions merlin/dag/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TagAsUserID,
)
from merlin.dag.ops.concat_columns import ConcatColumns
from merlin.dag.ops.grouping import GroupingOp
from merlin.dag.ops.rename import Rename
from merlin.dag.ops.selection import SelectionOp
from merlin.dag.ops.subset_columns import SubsetColumns
Expand Down
12 changes: 8 additions & 4 deletions merlin/dag/ops/concat_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@ def compute_selector(
ColumnSelector
Combined column selectors of parent and dependency nodes
"""
selector = super().compute_selector(
input_schema,
parents_selector + dependencies_selector,
)
upstream_selector = parents_selector + dependencies_selector
if upstream_selector.subgroups:
selector = super().compute_selector(
input_schema,
upstream_selector,
)
else:
selector = ColumnSelector(input_schema.column_names)
return selector

def compute_input_schema(
Expand Down
22 changes: 22 additions & 0 deletions merlin/dag/ops/grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Optional

from merlin.dag.ops.selection import SelectionOp
from merlin.dag.selector import ColumnSelector
from merlin.schema import Schema


class GroupingOp(SelectionOp):
    """Selection operator that wraps its upstream columns in a single subgroup."""

    def compute_selector(
        self,
        input_schema: Schema,
        selector: ColumnSelector,
        parents_selector: Optional[ColumnSelector] = None,
        dependencies_selector: Optional[ColumnSelector] = None,
    ) -> ColumnSelector:
        """Build this node's selector from its upstream selectors.

        The parent and dependency selectors are combined, wrapped as the
        ``subgroups`` of a fresh ``ColumnSelector``, and passed on to
        ``SelectionOp.compute_selector``.

        Parameters
        ----------
        input_schema : Schema
            Schema of the data entering this node.
        selector : ColumnSelector
            Accepted for interface compatibility; it is not read, the
            grouped upstream selector is used in its place.
        parents_selector : ColumnSelector, optional
            Combined selector of the parent nodes.
        dependencies_selector : ColumnSelector, optional
            Combined selector of the dependency nodes.

        Returns
        -------
        ColumnSelector
            The selector returned by the parent class for the grouped
            upstream selector.
        """
        # Both upstream selectors default to None, and `None + None` raises
        # TypeError, so resolve the missing cases explicitly before combining.
        if parents_selector is None and dependencies_selector is None:
            upstream_selector = ColumnSelector()
        elif parents_selector is None:
            upstream_selector = dependencies_selector
        elif dependencies_selector is None:
            upstream_selector = parents_selector
        else:
            upstream_selector = parents_selector + dependencies_selector

        grouped_selector = ColumnSelector(subgroups=upstream_selector)
        return super().compute_selector(
            input_schema,
            grouped_selector,
        )
42 changes: 42 additions & 0 deletions tests/unit/dag/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
#
import pytest

from merlin.core.dispatch import make_df
from merlin.dag import Graph, Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
from merlin.dag.base_operator import BaseOperator
from merlin.dag.executors import LocalExecutor
from merlin.dag.ops.subgraph import Subgraph
from merlin.dag.ops.udf import UDF
from merlin.dag.selector import ColumnSelector
from merlin.schema.schema import ColumnSchema, Schema

Expand Down Expand Up @@ -95,3 +98,42 @@ def test_subgraph_with_summed_subgraphs():
assert post_len == pre_len
assert iter_len == post_len
assert iter_len == pre_len


def test_concat_prefers_rhs_with_seen_root_output():
    """Root column "a" on the right of ``+`` wins over the UDF's output for "a"."""
    # Build and type the graph first: both columns go through the UDF, then
    # the raw root column "a" is concatenated on the right-hand side.
    schema = Schema(["a", "b"])
    graph = Graph((["a", "b"] >> UDF(lambda x: x + 1)) + ["a"])
    graph.construct_schema(schema)

    df = make_df({"a": [1, 1, 1, 1, 1, 1], "b": [1, 1, 1, 1, 1, 1]})
    transformed = LocalExecutor().transform(df, graph)

    # "b" only exists on the left, so it carries the UDF result; "a" comes
    # from the right-hand side and is therefore untouched.
    assert transformed["b"].to_numpy().tolist() == [2, 2, 2, 2, 2, 2]
    assert transformed["a"].to_numpy().tolist() == [1, 1, 1, 1, 1, 1]


def test_concat_prefers_rhs_with_unseen_root_output():
    """A root column not produced on the left is appended unchanged by ``+``."""
    # Only "a" goes through the UDF; "b" is pulled straight from the root
    # data by the right-hand side of the concat.
    schema = Schema(["a", "b"])
    graph = Graph((["a"] >> UDF(lambda x: x + 1)) + ["b"])
    graph.construct_schema(schema)

    df = make_df({"a": [1, 1, 1, 1, 1, 1], "b": [1, 1, 1, 1, 1, 1]})
    transformed = LocalExecutor().transform(df, graph)

    assert transformed["b"].to_numpy().tolist() == [1, 1, 1, 1, 1, 1]
    assert transformed["a"].to_numpy().tolist() == [2, 2, 2, 2, 2, 2]


def test_concat_prefers_rhs_with_seen_and_unseen_root_output():
    """Right-hand root columns win whether or not the left side produced them."""
    # "a" is produced by the UDF on the left and also selected from the root
    # on the right; "b" appears only on the right.
    schema = Schema(["a", "b"])
    graph = Graph((["a"] >> UDF(lambda x: x + 1)) + ["a", "b"])
    graph.construct_schema(schema)

    df = make_df({"a": [1, 1, 1, 1, 1, 1], "b": [1, 1, 1, 1, 1, 1]})
    transformed = LocalExecutor().transform(df, graph)

    # Both columns are expected to hold the original root values.
    assert transformed["b"].to_numpy().tolist() == [1, 1, 1, 1, 1, 1]
    assert transformed["a"].to_numpy().tolist() == [1, 1, 1, 1, 1, 1]