-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add molecular graph, adapter and some metrics * Fix imports * Add CL score * Add draft add_atom mutation * Add draft delete_atom, replace_bond mutation * Add draft delete_bond, replace_atom mutation * Add experiment draft * Fix adapter * Fix experiment, add metrics description * Small mutations fix * Minor changes * Add cut_atom mutation * Refactor MolAdvisor * Add insert_carbon * Add remove_group * Add move_group * Move utils to separate file * PEP 8 and docstrings * Add zinc normalized logp * Minor * Fix extending initial population * Add guacamol benchmark * Refactor guacamol experiment * Add statistics and visualization for guacamol * pep * Add mutations list * Add CLScorer * Add download from github * Review fixes * Add integration test
- Loading branch information
Showing
24 changed files
with
1,197 additions
and
36 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Structural constants for molecule graphs.
# NOTE(review): the consumers of these two values are not visible in this file;
# presumably the valence assumed for sulfur atoms and the smallest ring size
# that counts as a "long" cycle - confirm against the mutation/metric code.
SULFUR_DEFAULT_VALENCE = 6
MIN_LONG_CYCLE_SIZE = 6

# Normalization constants: statistics from the 250k_rndm_zinc_drugs_clean.smi dataset.
ZINC_LOGP_MEAN = 2.4570953396190123
ZINC_LOGP_STD = 1.434324401111988
ZINC_SA_MEAN = -3.0525811293166134
ZINC_SA_STD = 0.8335207024513095
ZINC_CYCLE_MEAN = -0.0485696876403053
ZINC_CYCLE_STD = 0.2860212110245455
Binary file added
BIN
+23.5 MB
examples/molecule_search/data/shingles/chembl_24_1_shingle_scores_log10_nrooted_nchir.pkl
Binary file not shown.
Binary file added
BIN
+5.15 MB
...ecule_search/data/shingles/chembl_24_1_shingle_scores_log10_rooted_nchir_min_freq_100.pkl
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
import os.path | ||
from datetime import timedelta | ||
from io import StringIO | ||
from pathlib import Path | ||
from typing import Type, Optional, Sequence, List, Iterable, Callable, Dict | ||
|
||
import numpy as np | ||
from rdkit.Chem import Draw | ||
from rdkit.Chem.rdchem import BondType | ||
|
||
from examples.molecule_search.mol_adapter import MolAdapter | ||
from examples.molecule_search.mol_advisor import MolChangeAdvisor | ||
from examples.molecule_search.mol_graph import MolGraph | ||
from examples.molecule_search.mol_graph_parameters import MolGraphRequirements | ||
from examples.molecule_search.mol_mutations import CHEMICAL_MUTATIONS | ||
from examples.molecule_search.mol_metrics import normalized_sa_score, penalised_logp, qed_score, \ | ||
normalized_logp, CLScorer | ||
from golem.core.dag.verification_rules import has_no_self_cycled_nodes, has_no_isolated_components, \ | ||
has_no_isolated_nodes | ||
from golem.core.optimisers.adaptive.operator_agent import MutationAgentTypeEnum | ||
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer | ||
from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters | ||
from golem.core.optimisers.genetic.operators.crossover import CrossoverTypesEnum | ||
from golem.core.optimisers.genetic.operators.elitism import ElitismTypesEnum | ||
from golem.core.optimisers.genetic.operators.inheritance import GeneticSchemeTypesEnum | ||
from golem.core.optimisers.objective import Objective | ||
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory | ||
from golem.core.optimisers.optimizer import GraphGenerationParams, GraphOptimizer | ||
from golem.visualisation.opt_viz import PlotTypesEnum, OptHistoryVisualizer | ||
from golem.visualisation.opt_viz_extra import visualise_pareto | ||
|
||
|
||
def get_methane() -> MolGraph:
    """Build the molecule graph for methane from its SMILES string ``'C'``."""
    return MolGraph.from_smiles('C')
|
||
|
||
def get_all_mol_metrics() -> Dict[str, Callable]:
    """Return the registry of available molecule metrics, keyed by metric name.

    A new ``CLScorer`` instance is created on every call.
    """
    return dict(
        qed_score=qed_score,
        cl_score=CLScorer(),
        norm_sa_score=normalized_sa_score,
        penalised_logp=penalised_logp,
        norm_log_p=normalized_logp,
    )
|
||
|
||
def molecule_search_setup(optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer,
                          max_heavy_atoms: int = 50,
                          atom_types: Optional[List[str]] = None,
                          bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
                          timeout: Optional[timedelta] = None,
                          num_iterations: Optional[int] = None,
                          pop_size: int = 20,
                          metrics: Optional[List[str]] = None,
                          initial_molecules: Optional[Sequence[MolGraph]] = None):
    """Assemble an optimizer and its objective for a molecule search run.

    :param optimizer_cls: optimizer class, instantiated with
        ``(objective, initial_graphs, requirements, graph_gen_params, gp_params)``
    :param max_heavy_atoms: forwarded to ``MolGraphRequirements.max_heavy_atoms``
    :param atom_types: allowed atom symbols; defaults to
        ``['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br']`` when empty or ``None``
    :param bond_types: forwarded to ``MolGraphRequirements.bond_types``
    :param timeout: forwarded to ``MolGraphRequirements.timeout``
    :param num_iterations: forwarded to ``MolGraphRequirements.num_of_generations``
    :param pop_size: used for both ``pop_size`` and ``max_pop_size`` of the GP parameters
    :param metrics: names of metrics from ``get_all_mol_metrics()``;
        defaults to ``['qed_score']`` when empty or ``None``
    :param initial_molecules: initial population; a single methane molecule when empty or ``None``
    :return: tuple ``(optimizer, objective)``
    :raises KeyError: if a requested metric name is not registered
    """
    metrics = metrics or ['qed_score']
    all_metrics = get_all_mol_metrics()
    unknown_metrics = [metric_name for metric_name in metrics if metric_name not in all_metrics]
    if unknown_metrics:
        # Same exception type as the plain dict lookup would raise, but with a usable message.
        raise KeyError(f'Unknown metrics {unknown_metrics}; available metrics: {sorted(all_metrics)}')
    # Single source of truth so that the objective and the GP parameters always agree.
    is_multi_objective = len(metrics) > 1

    requirements = MolGraphRequirements(
        max_heavy_atoms=max_heavy_atoms,
        available_atom_types=atom_types or ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br'],
        bond_types=bond_types,
        early_stopping_timeout=np.inf,
        early_stopping_iterations=np.inf,
        keep_n_best=4,
        timeout=timeout,
        num_of_generations=num_iterations,
        keep_history=True,
        n_jobs=1,
        history_dir=os.path.join(os.path.curdir, 'history')
    )
    gp_params = GPAlgorithmParameters(
        pop_size=pop_size,
        max_pop_size=pop_size,
        multi_objective=is_multi_objective,
        genetic_scheme_type=GeneticSchemeTypesEnum.steady_state,
        elitism_type=ElitismTypesEnum.replace_worst,
        mutation_types=CHEMICAL_MUTATIONS,
        crossover_types=[CrossoverTypesEnum.none],
        adaptive_mutation_type=MutationAgentTypeEnum.bandit
    )
    graph_gen_params = GraphGenerationParams(
        adapter=MolAdapter(),
        rules_for_constraint=[has_no_self_cycled_nodes, has_no_isolated_components, has_no_isolated_nodes],
        advisor=MolChangeAdvisor(),
    )

    objective = Objective(
        quality_metrics={metric_name: all_metrics[metric_name] for metric_name in metrics},
        is_multi_objective=is_multi_objective
    )

    initial_graphs = initial_molecules or [get_methane()]
    initial_graphs = graph_gen_params.adapter.adapt(initial_graphs)

    # Build the optimizer
    optimiser = optimizer_cls(objective, initial_graphs, requirements, graph_gen_params, gp_params)
    return optimiser, objective
|
||
|
||
def visualize_results(molecules: Iterable[MolGraph],
                      objective: Objective,
                      history: OptHistory,
                      save_path: Optional[str] = None):
    """Save the plots of an optimisation run and a grid image of the found molecules.

    :param molecules: found molecules; duplicates are dropped via ``set``
    :param objective: objective used to evaluate the molecules for the image legends
    :param history: optimisation history used for the pareto front and fitness line plots
    :param save_path: output directory; defaults to ``./visualisations``.
        Missing parent directories are created.
    """
    save_path = save_path or os.path.join(os.path.curdir, 'visualisations')
    # parents=True: callers may pass a nested directory whose parent does not exist yet.
    Path(save_path).mkdir(parents=True, exist_ok=True)

    if objective.is_multi_objective:
        visualise_pareto(history.archive_history[-1], objectives_names=objective.metric_names[:2], folder=save_path)

    visualizer = OptHistoryVisualizer(history)
    visualization = PlotTypesEnum.fitness_line.value(visualizer.history, visualizer.visuals_params)
    visualization.visualize(dpi=100, save_path=os.path.join(save_path, 'fitness_line.png'))

    # Materialise the unique molecules once: a one-shot iterable would be exhausted by
    # a second pass, and a single list keeps molecules and legends aligned.
    unique_molecules = list(set(molecules))
    if not unique_molecules:
        # Nothing to draw; an empty grid would be requested with zero molecules per row.
        return

    rw_molecules = [mol.get_rw_molecule() for mol in unique_molecules]
    objectives = [objective.format_fitness(objective(mol)) for mol in unique_molecules]
    image = Draw.MolsToGridImage(rw_molecules,
                                 legends=objectives,
                                 molsPerRow=min(4, len(rw_molecules)),
                                 subImgSize=(1000, 1000),
                                 legendFontSize=50)
    image.show()
    image.save(os.path.join(save_path, 'best_molecules.png'))
|
||
|
||
def run_experiment(optimizer_setup: Callable,
                   optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer,
                   max_heavy_atoms: int = 50,
                   atom_types: Optional[List[str]] = None,
                   bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
                   initial_molecules: Optional[Sequence[MolGraph]] = None,
                   pop_size: int = 20,
                   metrics: Optional[List[str]] = None,
                   num_trials: int = 1,
                   trial_timeout: Optional[int] = None,
                   trial_iterations: Optional[int] = None,
                   visualize: bool = False
                   ):
    """Run several optimisation trials and report mean/std of the final metrics.

    Each trial builds a fresh optimizer through ``optimizer_setup``, runs it, saves its
    history to ``./results/trial_<n>_pop_size_<pop_size>_<metrics>.json`` and, when
    ``visualize`` is set, stores plots under ``./visualisations/``.

    :param optimizer_setup: callable returning ``(optimizer, objective)``; called positionally with
        ``(optimizer_cls, max_heavy_atoms, atom_types, bond_types, trial_timeout,
        trial_iterations, pop_size, metrics, initial_molecules)``
    :param num_trials: number of independent trials
    :param trial_timeout: passed through unchanged as the setup's timeout argument
    :param trial_iterations: passed through unchanged as the setup's iteration limit
    :param visualize: whether to save visualisations for every trial
    :return: the text of the report (also printed to stdout)
    """
    # NOTE(review): trial_timeout is annotated as int here, while molecule_search_setup
    # annotates its timeout as timedelta - confirm the intended type/units with callers.
    log = StringIO()
    atom_types = atom_types or ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br']
    metrics = metrics or ['qed_score']
    trial_results = []
    objective = None
    experiment_id = f'Experiment [metrics={", ".join(metrics)} pop_size={pop_size}]\n'
    run_suffix = f'pop_size_{pop_size}_{"_".join(metrics)}'
    for trial in range(num_trials):
        optimizer, objective = optimizer_setup(optimizer_cls,
                                               max_heavy_atoms,
                                               atom_types,
                                               bond_types,
                                               trial_timeout,
                                               trial_iterations,
                                               pop_size,
                                               metrics,
                                               initial_molecules)
        found_graphs = optimizer.optimise(objective)
        history = optimizer.history
        if visualize:
            adapter = MolAdapter()
            molecules = {adapter.restore(graph) for graph in found_graphs}
            save_path = os.path.join(os.path.curdir,
                                     'visualisations',
                                     f'trial_{trial}_{run_suffix}')
            visualize_results(molecules, objective, history, save_path)
        Path("results").mkdir(exist_ok=True)
        history.save(f'./results/trial_{trial}_{run_suffix}.json')
        trial_results.extend(history.final_choices)

    if objective is None or not trial_results:
        # No trial ran (e.g. num_trials == 0) or nothing was found: there is nothing to average.
        print(f'{experiment_id} finished without results', file=log)
        print(log.getvalue())
        return log.getvalue()

    # Compute mean & std for metrics of trials
    ff = objective.format_fitness
    trial_metrics = np.array([ind.fitness.values for ind in trial_results])
    trial_metrics_mean = trial_metrics.mean(axis=0)
    trial_metrics_std = trial_metrics.std(axis=0)
    print(f'{experiment_id} finished with metrics:\n'
          f'mean={ff(trial_metrics_mean)}\n'
          f' std={ff(trial_metrics_std)}',
          file=log)
    print(log.getvalue())
    return log.getvalue()
|
||
|
||
if __name__ == '__main__':
    # Example launch: ten trials optimising the CL score, with visualisations enabled.
    experiment_settings = dict(
        max_heavy_atoms=38,
        trial_iterations=100,
        pop_size=100,
        metrics=['cl_score'],
        visualize=True,
        num_trials=10,
    )
    run_experiment(molecule_search_setup, **experiment_settings)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import json | ||
from typing import Optional, List | ||
|
||
import joblib | ||
import numpy as np | ||
import pandas as pd | ||
from guacamol.assess_goal_directed_generation import assess_goal_directed_generation | ||
from guacamol.goal_directed_generator import GoalDirectedGenerator | ||
from guacamol.scoring_function import ScoringFunction | ||
from joblib import delayed | ||
from rdkit.Chem import Draw, MolFromSmiles | ||
from rdkit.Chem.rdchem import BondType | ||
|
||
from examples.molecule_search.mol_adapter import MolAdapter | ||
from examples.molecule_search.mol_advisor import MolChangeAdvisor | ||
from examples.molecule_search.mol_graph import MolGraph | ||
from examples.molecule_search.mol_graph_parameters import MolGraphRequirements | ||
from examples.molecule_search.mol_mutations import CHEMICAL_MUTATIONS | ||
from golem.core.optimisers.adaptive.operator_agent import MutationAgentTypeEnum | ||
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer | ||
from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters | ||
from golem.core.optimisers.genetic.operators.crossover import CrossoverTypesEnum | ||
from golem.core.optimisers.genetic.operators.elitism import ElitismTypesEnum | ||
from golem.core.optimisers.genetic.operators.inheritance import GeneticSchemeTypesEnum | ||
from golem.core.optimisers.objective import Objective | ||
from golem.core.optimisers.optimizer import GraphGenerationParams | ||
|
||
|
||
def load_init_population(scoring_function: ScoringFunction,
                         n_jobs: int = -1,
                         path="./data/guacamol_v1_all.smiles",
                         init_pop_size: int = 100):
    """Select the best-scoring molecules of a SMILES file as the initial population.

    Original code:
    https://github.com/BenevolentAI/guacamol_baselines/blob/master/graph_ga/goal_directed_generation.py

    :param scoring_function: object whose ``score(smiles)`` is used to rank the molecules
    :param n_jobs: number of parallel jobs passed to ``joblib.Parallel``
    :param path: text file with one SMILES string per line; forward slashes are used in the
        default so that it resolves on both Windows and POSIX systems
    :param init_pop_size: how many top-scoring molecules to keep
    :return: list of ``MolGraph`` built from the selected SMILES, best score first
    """
    with open(path, "r") as f:
        # Drop line terminators and skip blank lines so every entry is a bare SMILES string.
        smiles_list = [line.strip() for line in f]
    smiles_list = [smile for smile in smiles_list if smile]

    joblist = [delayed(scoring_function.score)(smile) for smile in smiles_list]
    scores = joblib.Parallel(n_jobs=n_jobs)(joblist)

    scored_smiles = sorted(zip(scores, smiles_list), key=lambda pair: pair[0], reverse=True)
    best_smiles = [smile for _, smile in scored_smiles[:init_pop_size]]
    return [MolGraph.from_smiles(smile) for smile in best_smiles]
|
||
|
||
class GolemMoleculeGenerator(GoalDirectedGenerator):
    """Goal-directed generator that runs ``EvoGraphOptimizer`` for the Guacamol benchmark.

    Unless a starting population is passed to ``generate_optimized_molecules``, the initial
    population is read by ``load_init_population``, so you need to download the Guacamol
    all smiles dataset from https://figshare.com/projects/GuacaMol/56639
    """

    def __init__(self,
                 requirements: Optional[MolGraphRequirements] = None,
                 graph_gen_params: Optional[GraphGenerationParams] = None,
                 gp_params: Optional[GPAlgorithmParameters] = None):
        """Store the optimizer settings, falling back to the defaults below for omitted ones."""
        self.requirements = requirements or MolGraphRequirements(
            max_heavy_atoms=50,
            available_atom_types=['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br'],
            bond_types=(BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
            early_stopping_timeout=np.inf,
            early_stopping_iterations=50,
            keep_n_best=4,
            timeout=None,
            num_of_generations=500,
            keep_history=True,
            n_jobs=-1,
            history_dir=None)

        self.graph_gen_params = graph_gen_params or GraphGenerationParams(
            adapter=MolAdapter(),
            advisor=MolChangeAdvisor())

        # pop_size / max_pop_size are placeholders: they are overwritten on every
        # generate_optimized_molecules call.
        self.gp_params = gp_params or GPAlgorithmParameters(
            pop_size=2,
            max_pop_size=2,
            multi_objective=False,
            genetic_scheme_type=GeneticSchemeTypesEnum.steady_state,
            elitism_type=ElitismTypesEnum.replace_worst,
            mutation_types=CHEMICAL_MUTATIONS,
            crossover_types=[CrossoverTypesEnum.none],
            adaptive_mutation_type=MutationAgentTypeEnum.bandit)

    def generate_optimized_molecules(self, scoring_function: ScoringFunction, number_molecules: int,
                                     starting_population: Optional[List[str]] = None) -> List[str]:
        """Optimise molecules for ``scoring_function`` and return the best ones as SMILES.

        :param scoring_function: Guacamol scoring function; its negated score is the objective
        :param number_molecules: number of molecules to return
        :param starting_population: optional SMILES strings to start from; when empty or ``None``
            the population is loaded with ``load_init_population``
        :return: up to ``number_molecules`` SMILES strings of the top individuals from the history
        """
        adapter = self.graph_gen_params.adapter
        objective = Objective(
            # The score is negated before being handed to the optimizer.
            quality_metrics=lambda mol: -scoring_function.score(mol.get_smiles(aromatic=True)),
            is_multi_objective=False
        )
        self.gp_params.pop_size = max(number_molecules // 2, 100)
        self.gp_params.max_pop_size = min(self.gp_params.pop_size * 10, 1000)

        if starting_population:
            initial_graphs = [MolGraph.from_smiles(smile) for smile in starting_population]
        else:
            initial_graphs = load_init_population(scoring_function)
        initial_graphs = adapter.adapt(initial_graphs)

        # Build the optimizer
        optimiser = EvoGraphOptimizer(objective,
                                      initial_graphs,
                                      self.requirements,
                                      self.graph_gen_params,
                                      self.gp_params)
        optimiser.optimise(objective)
        history = optimiser.history

        # Keep one individual per distinct molecule hash; when a molecule occurs several
        # times in the history, the individual visited last overwrites the earlier ones.
        unique_individuals = {}
        for generation in history.individuals:
            for ind in reversed(list(generation)):
                unique_individuals[hash(adapter.restore(ind.graph))] = ind
        individuals = list(unique_individuals.values())

        top_individuals = sorted(individuals,
                                 key=lambda pos_ind: pos_ind.fitness, reverse=True)[:number_molecules]
        # Restore with the configured adapter so custom graph_gen_params are honoured throughout.
        top_smiles = [adapter.restore(ind.graph).get_smiles(aromatic=True) for ind in top_individuals]
        return top_smiles
|
||
|
||
def visualize(path: str):
    """Draw the top optimized molecules of every benchmark stored in a Guacamol results file.

    Prints the Guacamol and benchmark suite versions, then for each benchmark shows a grid
    image of at most 12 molecules and saves it as ``<benchmark_name>_results.png``.

    :param path: path to the JSON file with the benchmark results
    """
    with open(path) as json_file:
        report = json.load(json_file)
        print(f"Guacamol version: {report['guacamol_version']} \n"
              f"Benchmark version: {report['benchmark_suite_version']} \n")

    for result in report['results']:
        top_entries = result['optimized_molecules'][:12]
        generated_molecules = [MolFromSmiles(smile) for smile, _ in top_entries]
        rounded_scores = [round(score, 3) for _, score in top_entries]
        benchmark_name = result['benchmark_name']
        legends = [f"{benchmark_name} : {score}" for score in rounded_scores]
        image = Draw.MolsToGridImage(generated_molecules,
                                     legends=legends,
                                     molsPerRow=min(4, len(generated_molecules)),
                                     subImgSize=(1000, 1000),
                                     legendFontSize=50)
        image.show()
        image.save(f'{benchmark_name}_results.png')
|
||
|
||
def get_launch_statistics(paths: List[str]):
    """Aggregate per-benchmark statistics over several result files and print them.

    Every file must be a JSON document whose ``'results'`` entry is a list of benchmark
    records with ``'benchmark_name'``, ``'score'`` and ``'execution_time'`` keys. Records
    are matched across files by position; benchmark names are taken from the first file.
    Only positions present in every file are reported.

    :param paths: paths to the result files, one per launch
    :return: DataFrame with columns ``benchmark, mean, std, min, max, mean_time``
        (empty when ``paths`` is empty)
    """
    results = []
    for path in paths:
        with open(path) as json_file:
            results.append(json.load(json_file)['results'])

    column_names = ['benchmark', 'mean', 'std', 'min', 'max', 'mean_time']

    # Derive the benchmark count from the data instead of assuming a fixed suite size.
    num_benchmarks = min((len(launch) for launch in results), default=0)

    rows = []
    for bench_num in range(num_benchmarks):
        benchmark = results[0][bench_num]['benchmark_name']
        scores = [launch[bench_num]['score'] for launch in results]
        time_spent = [launch[bench_num]['execution_time'] for launch in results]
        rows.append([benchmark,
                     np.mean(scores),
                     np.std(scores),
                     np.min(scores),
                     np.max(scores),
                     np.mean(time_spent)])

    # Build the frame once rather than concatenating onto an empty frame in the loop.
    df = pd.DataFrame(rows, columns=column_names)
    pd.set_option('display.max_columns', None)
    print(df)
    return df
|
||
|
||
if __name__ == '__main__':
    # one launch takes more than 24h
    for launch in range(10):
        print(f'\nLaunch_num {launch}\n')
        output_file = f'output_goal_directed_{launch}.json'
        assess_goal_directed_generation(GolemMoleculeGenerator(),
                                        benchmark_version='v2',
                                        json_output_file=output_file)

    # Draw the molecules found in the launch with index 1.
    visualize('output_goal_directed_1.json')

    # NOTE(review): statistics cover only launches 0-3 although 10 launches are run - confirm intent.
    report_paths = [f'output_goal_directed_{launch}.json' for launch in range(4)]
    get_launch_statistics(report_paths)
Oops, something went wrong.