-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix equivalent subtree, add test * Fix crossover * Fix mol adapter * Minor * Review fixes
- Loading branch information
Showing
8 changed files
with
93 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from functools import partial | ||
|
||
import networkx as nx | ||
import pytest | ||
|
||
from examples.synthetic_graph_evolution.generators import generate_labeled_graph | ||
from golem.core.adapter.nx_adapter import BaseNetworkxAdapter | ||
from golem.core.dag.verification_rules import DEFAULT_DAG_RULES | ||
from golem.core.log import Log | ||
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer | ||
from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters | ||
from golem.core.optimisers.genetic.operators.base_mutations import MutationTypesEnum | ||
from golem.core.optimisers.genetic.operators.crossover import CrossoverTypesEnum | ||
from golem.core.optimisers.objective import Objective | ||
from golem.core.optimisers.optimization_parameters import GraphRequirements | ||
from golem.core.optimisers.optimizer import GraphGenerationParams | ||
from golem.metrics.graph_metrics import spectral_dist | ||
|
||
|
||
@pytest.mark.parametrize('graph_type', ['tree', 'dag']) | ||
def test_evolution_with_crossover(graph_type): | ||
Log().reset_logging_level(10) | ||
target_graph = generate_labeled_graph(graph_type, 50) | ||
num_iterations = 100 | ||
objective = Objective(partial(spectral_dist, target_graph)) | ||
|
||
requirements = GraphRequirements( | ||
early_stopping_iterations=num_iterations, | ||
num_of_generations=num_iterations, | ||
n_jobs=-1, | ||
history_dir=None | ||
) | ||
gp_params = GPAlgorithmParameters( | ||
pop_size=30, | ||
mutation_types=[ | ||
MutationTypesEnum.single_edge, | ||
MutationTypesEnum.single_add, | ||
MutationTypesEnum.single_drop, | ||
MutationTypesEnum.simple, | ||
MutationTypesEnum.single_change | ||
], | ||
crossover_types=[CrossoverTypesEnum.subtree, CrossoverTypesEnum.one_point] | ||
) | ||
graph_gen_params = GraphGenerationParams( | ||
adapter=BaseNetworkxAdapter(), | ||
rules_for_constraint=DEFAULT_DAG_RULES, | ||
available_node_types=['x'], | ||
) | ||
|
||
# Generate simple initial population with cyclic graphs | ||
initial_graphs = [generate_labeled_graph(graph_type, i) for i in range(4, 20)] | ||
|
||
optimiser = EvoGraphOptimizer(objective, initial_graphs, requirements, graph_gen_params, gp_params) | ||
found_graphs = optimiser.optimise(objective) | ||
found_graph: nx.DiGraph = graph_gen_params.adapter.restore(found_graphs[0]) | ||
assert found_graph is not None | ||
assert len(found_graph.nodes) > 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters