Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix delete_node to allow breaking cycles in graphs. #75

Merged
merged 6 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions golem/core/dag/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from os import PathLike
from typing import Dict, List, Optional, Sequence, Union, Tuple, TypeVar

Expand All @@ -8,6 +9,13 @@
NodeType = TypeVar('NodeType', bound=GraphNode, covariant=False, contravariant=False)


class ReconnectType(Enum):
"""Defines allowed kinds of removals in Graph. Used by mutations."""
none = 'none' # do not reconnect predecessors
single = 'single' # reconnect a predecessor only if it's single
all = 'all' # reconnect all predecessors to all successors


class Graph(ABC):
"""Defines abstract graph interface that's required by graph optimisation process.
"""
Expand Down Expand Up @@ -41,12 +49,13 @@ def update_subtree(self, old_subtree: GraphNode, new_subtree: GraphNode):
raise NotImplementedError()

@abstractmethod
def delete_node(self, node: GraphNode):
def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single):
"""Removes ``node`` from the graph.
If ``node`` has only one child, then connects all of the ``node`` parents to it.

Args:
node: node of the graph to be deleted
reconnect: defines how to treat left edges between parents and children
"""
raise NotImplementedError()

Expand Down Expand Up @@ -84,7 +93,7 @@ def connect_nodes(self, node_parent: GraphNode, node_child: GraphNode):

@abstractmethod
def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode,
clean_up_leftovers: bool = True):
clean_up_leftovers: bool = False):
"""Removes an edge between two nodes

Args:
Expand Down
8 changes: 4 additions & 4 deletions golem/core/dag/graph_delegate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union, Sequence, List, Optional, Tuple, Type

from golem.core.dag.graph import Graph
from golem.core.dag.graph import Graph, ReconnectType
from golem.core.dag.graph_node import GraphNode
from golem.core.dag.linked_graph import LinkedGraph

Expand All @@ -26,8 +26,8 @@ def update_node(self, old_node: GraphNode, new_node: GraphNode):
def update_subtree(self, old_subtree: GraphNode, new_subtree: GraphNode):
self.operator.update_subtree(old_subtree, new_subtree)

def delete_node(self, node: GraphNode):
self.operator.delete_node(node)
def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single):
self.operator.delete_node(node, reconnect)

def delete_subtree(self, subtree: GraphNode):
self.operator.delete_subtree(subtree)
Expand All @@ -39,7 +39,7 @@ def connect_nodes(self, node_parent: GraphNode, node_child: GraphNode):
self.operator.connect_nodes(node_parent, node_child)

def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode,
clean_up_leftovers: bool = True):
clean_up_leftovers: bool = False):
self.operator.disconnect_nodes(node_parent, node_child, clean_up_leftovers)

def get_edges(self) -> Sequence[Tuple[GraphNode, GraphNode]]:
Expand Down
24 changes: 15 additions & 9 deletions golem/core/dag/linked_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from networkx import graph_edit_distance, set_node_attributes

from golem.core.dag.graph import Graph
from golem.core.dag.graph import Graph, ReconnectType
from golem.core.dag.graph_node import GraphNode
from golem.core.dag.graph_utils import ordered_subnodes_hierarchy, node_depth
from golem.core.dag.convert import graph_structure_as_nx_graph
Expand Down Expand Up @@ -34,19 +34,25 @@ def _empty_postprocess(*args):
pass

@copy_doc(Graph.delete_node)
def delete_node(self, node: GraphNode):
def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single) -> object:
node_children_cached = self.node_children(node)

self._nodes.remove(node)
for node_child in node_children_cached:
node_child.nodes_from.remove(node)

# if removed node had a single child
# then reconnect it to preceding parent nodes.
if node.nodes_from and len(node_children_cached) == 1:
child = node_children_cached[0]
for node_from in node.nodes_from:
child.nodes_from.append(node_from)
if reconnect == ReconnectType.single:
# if removed node had a single child
# then reconnect it to preceding parent nodes.
if node.nodes_from and len(node_children_cached) == 1:
child = node_children_cached[0]
child.nodes_from.extend(node.nodes_from)
elif reconnect == ReconnectType.all:
if node.nodes_from:
for child in node_children_cached:
child.nodes_from.extend(node.nodes_from)
elif reconnect == ReconnectType.none:
pass

self._postprocess_nodes(self, self._nodes)

Expand Down Expand Up @@ -126,7 +132,7 @@ def _clean_up_leftovers(self, node: GraphNode):

@copy_doc(Graph.disconnect_nodes)
def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode,
clean_up_leftovers: bool = True):
clean_up_leftovers: bool = False):
if node_parent not in node_child.nodes_from:
return
if node_parent not in self._nodes or node_child not in self._nodes:
Expand Down
5 changes: 3 additions & 2 deletions golem/core/optimisers/advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

class RemoveType(Enum):
"""Defines allowed kinds of removals in Graph. Used by mutations."""
forbidden = 'forbidden'
node_only = 'node_only'
node_rewire = 'node_rewire'
with_direct_children = 'with_direct_children'
with_parents = 'with_parents'
forbidden = 'forbidden'


class DefaultChangeAdvisor:
Expand All @@ -24,7 +25,7 @@ def propose_change(self, node: OptNode, possible_operations: List[Any]) -> List[
return possible_operations

def can_be_removed(self, node: OptNode) -> RemoveType:
return RemoveType.node_only
return RemoveType.node_rewire

def propose_parent(self, node: OptNode, possible_operations: List[Any]) -> List[Any]:
return possible_operations
23 changes: 12 additions & 11 deletions golem/core/optimisers/genetic/operators/base_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING

from golem.core.adapter import register_native
from golem.core.dag.graph import ReconnectType
from golem.core.dag.graph_node import GraphNode
from golem.core.dag.graph_utils import distance_to_root_level, ordered_subnodes_hierarchy, distance_to_primary_level
from golem.core.optimisers.advisor import RemoveType
Expand Down Expand Up @@ -171,7 +172,8 @@ def add_as_child(graph: OptGraph,
graph.connect_nodes(node_parent=node_to_mutate, node_child=new_node)
if new_node_child:
graph.connect_nodes(node_parent=new_node, node_child=new_node_child)
graph.disconnect_nodes(node_parent=node_to_mutate, node_child=new_node_child)
graph.disconnect_nodes(node_parent=node_to_mutate, node_child=new_node_child,
clean_up_leftovers=True)

return graph

Expand Down Expand Up @@ -246,18 +248,17 @@ def single_drop_mutation(graph: OptGraph,
if n.descriptive_id.count('data_source') == 1
and node_name in n.descriptive_id]
for child_node in nodes_to_delete:
graph.delete_node(child_node)
graph.delete_node(child_node, reconnect=ReconnectType.all)
elif removal_type == RemoveType.with_parents:
graph.delete_subtree(node_to_del)
elif removal_type != RemoveType.forbidden:
graph.delete_node(node_to_del)
if node_to_del.nodes_from:
children = graph.node_children(node_to_del)
for child in children:
if child.nodes_from:
child.nodes_from.extend(node_to_del.nodes_from)
else:
child.nodes_from = node_to_del.nodes_from
elif removal_type == RemoveType.node_rewire:
graph.delete_node(node_to_del, reconnect=ReconnectType.all)
elif removal_type == RemoveType.node_only:
graph.delete_node(node_to_del, reconnect=ReconnectType.none)
elif removal_type == RemoveType.forbidden:
pass
else:
raise ValueError("Unknown advice (RemoveType) returned by Advisor ")
return graph


Expand Down
36 changes: 35 additions & 1 deletion test/unit/dag/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pytest

from golem.core.dag.graph import Graph
from golem.core.dag.graph import Graph, ReconnectType
from golem.core.dag.graph_delegate import GraphDelegate
from golem.core.dag.linked_graph import LinkedGraph
from golem.core.dag.linked_graph_node import LinkedGraphNode
Expand Down Expand Up @@ -118,6 +118,40 @@ def test_delete_intermediate_node():
assert graph.depth == 2


def test_delete_leave_cycle():
first = GraphNode(content='n1')
second = GraphNode(content='n2', nodes_from=[first])
third = GraphNode(content='n3', nodes_from=[second])
final = GraphNode(content='n4', nodes_from=[third])
graph = GraphImpl(final)
graph.connect_nodes(final, first)
maypink marked this conversation as resolved.
Show resolved Hide resolved

assert len(graph.get_edges()) == 4

graph.delete_node(third, reconnect=ReconnectType.single)

assert third not in graph.nodes
assert len(graph.get_edges()) == 3
assert (second, final) in graph.get_edges()


def test_delete_break_cycle():
first = GraphNode(content='n1')
second = GraphNode(content='n2', nodes_from=[first])
third = GraphNode(content='n3', nodes_from=[second])
final = GraphNode(content='n4', nodes_from=[third])
graph = GraphImpl(final)
graph.connect_nodes(final, first)

assert len(graph.get_edges()) == 4

graph.delete_node(third, reconnect=ReconnectType.none)

assert third not in graph.nodes
assert len(graph.get_edges()) == 2
assert not final.nodes_from


def test_delete_node_with_duplicated_edges():
ok_primary_node = GraphNode('n1')
bad_primary_node = GraphNode('n2')
Expand Down
10 changes: 5 additions & 5 deletions test/unit/dag/test_graph_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_disconnect_nodes_method_first():
node_e = graph.nodes[4]
node_e_root = graph.nodes[0]

graph.disconnect_nodes(node_e, node_e_root)
graph.disconnect_nodes(node_e, node_e_root, clean_up_leftovers=True)

assert res_graph == graph

Expand All @@ -197,7 +197,7 @@ def test_disconnect_nodes_method_second():
node_b = graph.nodes[5]
node_e = graph.nodes[4]

graph.disconnect_nodes(node_b, node_e)
graph.disconnect_nodes(node_b, node_e, clean_up_leftovers=True)

assert res_graph == graph

Expand All @@ -210,7 +210,7 @@ def test_disconnect_nodes_method_third():
node_d = graph.nodes[1]
root_node_e = graph.nodes[0]

graph.disconnect_nodes(node_d, root_node_e)
graph.disconnect_nodes(node_d, root_node_e, clean_up_leftovers=True)

assert res_graph == graph

Expand All @@ -224,7 +224,7 @@ def test_disconnect_nodes_method_fourth():
node_c = res_graph.nodes[2]
root_node_e = res_graph.nodes[0]

res_graph.disconnect_nodes(node_c, root_node_e)
res_graph.disconnect_nodes(node_c, root_node_e, clean_up_leftovers=True)
assert res_graph == graph


Expand All @@ -237,7 +237,7 @@ def test_disconnect_nodes_method_fifth():
node_k = LinkedGraphNode('k')
node_m = LinkedGraphNode('m', nodes_from=[node_k])

res_graph.disconnect_nodes(node_k, node_m)
res_graph.disconnect_nodes(node_k, node_m, clean_up_leftovers=True)
assert res_graph == graph


Expand Down