Skip to content

Commit

Permalink
remove upstream dependencies that have no outputs (#107)
Browse files Browse the repository at this point in the history
* remove upstream dependencies that have no outputs

* add regression test for Graph dependencies
  • Loading branch information
nv-alaiacano authored Jul 25, 2022
1 parent 99ce749 commit 1354dcf
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
10 changes: 8 additions & 2 deletions merlin/dag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,22 @@ def remove_inputs(self, to_remove):

while nodes_to_process:
node, columns_to_remove = nodes_to_process.popleft()

if node.input_schema and len(node.input_schema):
output_columns_to_remove = node.remove_inputs(columns_to_remove)

for child in node.children:
nodes_to_process.append((child, to_remove + output_columns_to_remove))
nodes_to_process.append(
(child, list(set(to_remove + output_columns_to_remove)))
)

if not len(node.input_schema):
node.remove_child(child)

# remove any dependencies that do not have an output schema
node.dependencies = [
dep for dep in node.dependencies if dep.output_schema and len(dep.output_schema)
]

if not node.input_schema or not len(node.input_schema):
for parent in node.parents:
parent.remove_child(node)
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/dag/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from merlin.dag import Graph, Node
from merlin.dag.selector import ColumnSelector
from merlin.schema.schema import ColumnSchema, Schema


def test_remove_dependencies():
    """Regression test: removing input "y" shrinks the graph's leaf nodes from 2 to 1.

    Three nodes are wired together:

    * ``y_only``  selects ["y"];      input schema (y),    output schema (y)
    * ``xy_to_z`` selects ["x", "y"]; input schema (x, y), output schema (z)
    * ``sink``    selects ["z", "y"]; input schema (y, z), output schema (o)

    ``sink`` has ``y_only`` and ``xy_to_z`` as parents (added in that order)
    and additionally registers ``y_only`` as a dependency.
    """
    # Node that both consumes and produces only the "y" column.
    y_only = Node(selector=ColumnSelector(["y"]))
    y_only.input_schema = Schema([ColumnSchema("y")])
    y_only.output_schema = Schema([ColumnSchema("y")])

    # Node that consumes "x" and "y" and produces "z".
    xy_to_z = Node(selector=ColumnSelector(["x", "y"]))
    xy_to_z.input_schema = Schema([ColumnSchema("x"), ColumnSchema("y")])
    xy_to_z.output_schema = Schema([ColumnSchema("z")])

    # Final node; it is the output node handed to Graph below.
    sink = Node(selector=ColumnSelector(["z", "y"]))
    sink.input_schema = Schema([ColumnSchema("y"), ColumnSchema("z")])
    sink.output_schema = Schema([ColumnSchema("o")])
    for upstream in (y_only, xy_to_z):
        sink.add_parent(upstream)
    sink.add_dependency(y_only)

    graph = Graph(sink)

    # Before removal, two leaf nodes are reported.
    leaves_before = graph.leaf_nodes
    assert len(leaves_before) == 2

    # After removing "y", only one leaf node should remain.
    graph.remove_inputs(["y"])
    leaves_after = graph.leaf_nodes
    assert len(leaves_after) == 1

0 comments on commit 1354dcf

Please sign in to comment.