Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the column merging during graph traversal and concat selector #353

Merged
merged 4 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,19 @@ def _merge_upstream_columns(self, upstream_outputs, merge_fn=concat_columns):
combined_outputs = upstream_output
seen_columns = upstream_columns
else:
old_columns = seen_columns - upstream_columns
overlap_columns = seen_columns.intersection(upstream_columns)
new_columns = upstream_columns - seen_columns
merge_columns = []
if old_columns:
merge_columns.append(combined_outputs[list(old_columns)])
if overlap_columns:
merge_columns.append(upstream_output[list(overlap_columns)])
if new_columns:
combined_outputs = merge_fn(
[combined_outputs, upstream_output[list(new_columns)]]
)
merge_columns.append(upstream_output[list(new_columns)])
seen_columns.update(new_columns)
if merge_columns:
combined_outputs = merge_fn(merge_columns)
return combined_outputs

def _run_node_transform(self, node, input_data, capture_dtypes=False, strict=False):
Expand Down
6 changes: 3 additions & 3 deletions merlin/dag/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def compute_schemas(self, root_schema: Schema, preserve_dtypes: bool = False):
self.input_schema = self.op.compute_input_schema(
root_schema, parents_schema, deps_schema, self.selector
)

self.selector = self.op.compute_selector(
self.input_schema, self.selector, parents_selector, dependencies_selector
)
Expand Down Expand Up @@ -345,8 +344,9 @@ def __add__(self, other):

return child

# handle the "column_name" + Node case
__radd__ = __add__
def __radd__(self, other):
    """Reflected addition: handles ``other + self`` when ``other`` is not a Node.

    ``other`` is first converted with ``Node.construct_from`` and the
    resulting node's ``__add__`` is then called with ``self`` as its
    argument, so the left-hand operand stays on the left of the addition.
    """
    return Node.construct_from(other).__add__(self)

def __sub__(self, other):
"""Removes columns from this Node with another to return a new Node
Expand Down
13 changes: 9 additions & 4 deletions merlin/dag/ops/concat_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ def compute_selector(
ColumnSelector
Combined column selectors of parent and dependency nodes
"""
selector = super().compute_selector(
input_schema,
parents_selector + dependencies_selector,
)
upstream_selector = parents_selector and dependencies_selector
if upstream_selector.subgroups:
upstream_selector = parents_selector + dependencies_selector
selector = super().compute_selector(
input_schema,
upstream_selector,
)
else:
selector = ColumnSelector(input_schema.column_names)
return selector

def compute_input_schema(
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/dag/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
#
import pytest

from merlin.core.dispatch import make_df
from merlin.dag import Graph, Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
from merlin.dag.base_operator import BaseOperator
from merlin.dag.executors import LocalExecutor
from merlin.dag.ops.subgraph import Subgraph
from merlin.dag.ops.udf import UDF
from merlin.dag.selector import ColumnSelector
from merlin.schema.schema import ColumnSchema, Schema

Expand Down Expand Up @@ -95,3 +98,42 @@ def test_subgraph_with_summed_subgraphs():
assert post_len == pre_len
assert iter_len == post_len
assert iter_len == pre_len


def test_selector_concat_parents():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this tests that, when a column name overlaps, the version from the right-hand side is preferred. I'm not quite sure how to distill that into a good test name, though... maybe `test_concat_prefers_rhs`?

df = make_df({"a": [1, 1, 1, 1, 1, 1], "b": [1, 1, 1, 1, 1, 1]})

graph = Graph((["a", "b"] >> UDF(lambda x: x + 1)) + ["a"])

schema = Schema(["a", "b"])

graph.construct_schema(schema)
result2 = LocalExecutor().transform(df, graph)
assert (result2["b"] == [2, 2, 2, 2, 2, 2]).all()
assert (result2["a"] == [1, 1, 1, 1, 1, 1]).all()


def test_selector_concat_parents_inverted():
    """Concatenate a transformed ``"a"`` with a plain ``["a", "b"]`` selection.

    Column ``"a"`` appears on both sides of the ``+``; the test expects both
    ``"a"`` and ``"b"`` to come out of the graph with their original values
    (all ones).
    """
    input_frame = make_df({"a": [1] * 6, "b": [1] * 6})

    incremented = ["a"] >> UDF(lambda col: col + 1)
    dag = Graph(incremented + ["a", "b"])
    dag.construct_schema(Schema(["a", "b"]))

    output = LocalExecutor().transform(input_frame, dag)

    # Same check order as before: "b" first, then "a".
    for column in ("b", "a"):
        assert (output[column] == [1] * 6).all()


def test_selector_concat_parents_open():
    """Concatenate a transformed ``"a"`` with an untouched ``"b"`` selection.

    The two sides of the ``+`` share no column names; the test expects
    ``"b"`` to keep its original values (all ones) and ``"a"`` to carry the
    incremented values (all twos).
    """
    input_frame = make_df({"a": [1] * 6, "b": [1] * 6})

    dag = Graph((["a"] >> UDF(lambda col: col + 1)) + ["b"])
    dag.construct_schema(Schema(["a", "b"]))

    output = LocalExecutor().transform(input_frame, dag)

    # Dict order keeps the original assertion order: "b" first, then "a".
    expected = {"b": [1] * 6, "a": [2] * 6}
    for column, values in expected.items():
        assert (output[column] == values).all()