Skip to content

Commit

Permalink
Molecule search example (#83)
Browse files Browse the repository at this point in the history
* Add molecular graph, adapter and some metrics

* Fix imports

* Add CL score

* Add draft add_atom mutation

* Add draft delete_atom, replace_bond mutation

* Add draft delete_bond, replace_atom mutation

* Add experiment draft

* Fix adapter

* Fix experiment, add metrics description

* Small mutations fix

* Minor changes

* Add cut_atom mutation

* Refactor MolAdvisor

* Add insert_carbon

* Add remove_group

* Add move_group

* Move utils to separate file

* PEP 8 and docstrings

* Add zinc normalized logp

* Minor

* Fix extending initial population

* Add guacamol benchmark

* Refactor guacamol experiment

* Add statistics and visualization for guacamol

* pep

* Add mutations list

* Add CLScorer

* Add download from github

* Review fixes

* Add integration test
  • Loading branch information
YamLyubov authored Jun 9, 2023
1 parent e231edf commit 3cc87ea
Show file tree
Hide file tree
Showing 24 changed files with 1,197 additions and 36 deletions.
Empty file.
10 changes: 10 additions & 0 deletions examples/molecule_search/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# NOTE(review): presumably the valence assumed for sulfur atoms -- confirm against its usage
SULFUR_DEFAULT_VALENCE = 6
# NOTE(review): presumably the ring size from which a cycle is treated as "long" -- confirm against its usage
MIN_LONG_CYCLE_SIZE = 6

# normalization constants, statistics from 250k_rndm_zinc_drugs_clean.smi dataset
# Each quantity (logP, SA, cycle term -- judging by the names) comes as a MEAN / STD pair.
ZINC_LOGP_MEAN = 2.4570953396190123
ZINC_LOGP_STD = 1.434324401111988
ZINC_SA_MEAN = -3.0525811293166134
ZINC_SA_STD = 0.8335207024513095
ZINC_CYCLE_MEAN = -0.0485696876403053
ZINC_CYCLE_STD = 0.2860212110245455
Binary file not shown.
Binary file not shown.
185 changes: 185 additions & 0 deletions examples/molecule_search/experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import os.path
from datetime import timedelta
from io import StringIO
from pathlib import Path
from typing import Type, Optional, Sequence, List, Iterable, Callable, Dict

import numpy as np
from rdkit.Chem import Draw
from rdkit.Chem.rdchem import BondType

from examples.molecule_search.mol_adapter import MolAdapter
from examples.molecule_search.mol_advisor import MolChangeAdvisor
from examples.molecule_search.mol_graph import MolGraph
from examples.molecule_search.mol_graph_parameters import MolGraphRequirements
from examples.molecule_search.mol_mutations import CHEMICAL_MUTATIONS
from examples.molecule_search.mol_metrics import normalized_sa_score, penalised_logp, qed_score, \
normalized_logp, CLScorer
from golem.core.dag.verification_rules import has_no_self_cycled_nodes, has_no_isolated_components, \
has_no_isolated_nodes
from golem.core.optimisers.adaptive.operator_agent import MutationAgentTypeEnum
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer
from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters
from golem.core.optimisers.genetic.operators.crossover import CrossoverTypesEnum
from golem.core.optimisers.genetic.operators.elitism import ElitismTypesEnum
from golem.core.optimisers.genetic.operators.inheritance import GeneticSchemeTypesEnum
from golem.core.optimisers.objective import Objective
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory
from golem.core.optimisers.optimizer import GraphGenerationParams, GraphOptimizer
from golem.visualisation.opt_viz import PlotTypesEnum, OptHistoryVisualizer
from golem.visualisation.opt_viz_extra import visualise_pareto


def get_methane() -> MolGraph:
    """Build the molecule graph for the single-carbon SMILES string ``'C'``."""
    return MolGraph.from_smiles('C')


def get_all_mol_metrics() -> Dict[str, Callable]:
    """Return the registry that maps metric names to the callables computing them.

    A new ``CLScorer`` instance is created on every call.
    """
    return dict(
        qed_score=qed_score,
        cl_score=CLScorer(),
        norm_sa_score=normalized_sa_score,
        penalised_logp=penalised_logp,
        norm_log_p=normalized_logp,
    )


def molecule_search_setup(optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer,
                          max_heavy_atoms: int = 50,
                          atom_types: Optional[List[str]] = None,
                          bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
                          timeout: Optional[timedelta] = None,
                          num_iterations: Optional[int] = None,
                          pop_size: int = 20,
                          metrics: Optional[List[str]] = None,
                          initial_molecules: Optional[Sequence[MolGraph]] = None):
    """Assemble an optimizer and its objective for the molecule search task.

    :param optimizer_cls: optimizer class to instantiate
    :param max_heavy_atoms: forwarded to ``MolGraphRequirements``
    :param atom_types: allowed atom symbols; a built-in list of eight symbols is used when falsy
    :param bond_types: allowed bond types
    :param timeout: forwarded to the requirements as ``timeout``
    :param num_iterations: forwarded to the requirements as ``num_of_generations``
    :param pop_size: used both as the population size and the maximal population size
    :param metrics: names of metrics from ``get_all_mol_metrics``; ``['qed_score']`` when falsy
    :param initial_molecules: initial population; a single methane molecule when falsy
    :return: tuple ``(optimiser, objective)``
    """
    default_atom_types = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br']
    history_dir = os.path.join(os.path.curdir, 'history')

    requirements = MolGraphRequirements(
        max_heavy_atoms=max_heavy_atoms,
        available_atom_types=atom_types if atom_types else default_atom_types,
        bond_types=bond_types,
        early_stopping_timeout=np.inf,
        early_stopping_iterations=np.inf,
        keep_n_best=4,
        timeout=timeout,
        num_of_generations=num_iterations,
        keep_history=True,
        n_jobs=1,
        history_dir=history_dir,
    )

    gp_params = GPAlgorithmParameters(
        pop_size=pop_size,
        max_pop_size=pop_size,
        multi_objective=True,
        genetic_scheme_type=GeneticSchemeTypesEnum.steady_state,
        elitism_type=ElitismTypesEnum.replace_worst,
        mutation_types=CHEMICAL_MUTATIONS,
        crossover_types=[CrossoverTypesEnum.none],
        adaptive_mutation_type=MutationAgentTypeEnum.bandit,
    )

    structural_rules = [has_no_self_cycled_nodes, has_no_isolated_components, has_no_isolated_nodes]
    graph_gen_params = GraphGenerationParams(
        adapter=MolAdapter(),
        rules_for_constraint=structural_rules,
        advisor=MolChangeAdvisor(),
    )

    if not metrics:
        metrics = ['qed_score']
    available_metrics = get_all_mol_metrics()
    selected_metrics = {}
    for metric_name in metrics:
        selected_metrics[metric_name] = available_metrics[metric_name]
    objective = Objective(quality_metrics=selected_metrics,
                          is_multi_objective=len(metrics) > 1)

    start_molecules = initial_molecules if initial_molecules else [get_methane()]
    initial_graphs = graph_gen_params.adapter.adapt(start_molecules)

    # Build the optimizer
    optimiser = optimizer_cls(objective, initial_graphs, requirements, graph_gen_params, gp_params)
    return optimiser, objective


def visualize_results(molecules: Iterable[MolGraph],
                      objective: Objective,
                      history: OptHistory,
                      save_path: Optional[str] = None):
    """Save optimisation plots and a grid image of the found molecules.

    Writes ``fitness_line.png`` and ``best_molecules.png`` into ``save_path``
    (for multi-objective runs ``visualise_pareto`` is also given this folder)
    and shows the molecule grid via ``image.show()``.

    :param molecules: found molecules; duplicates are dropped before drawing
    :param objective: objective used to evaluate molecules for the image legends
    :param history: optimisation history to plot
    :param save_path: output directory, ``./visualisations`` when not given
    """
    save_path = save_path or os.path.join(os.path.curdir, 'visualisations')
    # parents=True: callers may pass a nested per-trial directory whose parent does not exist yet
    Path(save_path).mkdir(parents=True, exist_ok=True)

    if objective.is_multi_objective:
        visualise_pareto(history.archive_history[-1], objectives_names=objective.metric_names[:2], folder=save_path)

    visualizer = OptHistoryVisualizer(history)
    visualization = PlotTypesEnum.fitness_line.value(visualizer.history, visualizer.visuals_params)
    visualization.visualize(dpi=100, save_path=os.path.join(save_path, 'fitness_line.png'))

    # Materialise the unique molecules once: this keeps drawings and legends aligned
    # and also supports one-shot iterables (e.g. generators).
    unique_molecules = list(set(molecules))
    if not unique_molecules:
        # Nothing to draw; an empty grid would request zero molecules per row.
        return

    rw_molecules = [mol.get_rw_molecule() for mol in unique_molecules]
    objectives = [objective.format_fitness(objective(mol)) for mol in unique_molecules]
    image = Draw.MolsToGridImage(rw_molecules,
                                 legends=objectives,
                                 molsPerRow=min(4, len(rw_molecules)),
                                 subImgSize=(1000, 1000),
                                 legendFontSize=50)
    image.show()
    image.save(os.path.join(save_path, 'best_molecules.png'))


def run_experiment(optimizer_setup: Callable,
                   optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer,
                   max_heavy_atoms: int = 50,
                   atom_types: Optional[List[str]] = None,
                   bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
                   initial_molecules: Optional[Sequence[MolGraph]] = None,
                   pop_size: int = 20,
                   metrics: Optional[List[str]] = None,
                   num_trials: int = 1,
                   trial_timeout: Optional[int] = None,
                   trial_iterations: Optional[int] = None,
                   visualize: bool = False
                   ):
    """Run several optimisation trials and report the mean and std of the final metrics.

    Every trial saves its history to ``./results/<run name>.json``; with
    ``visualize=True`` plots are saved to ``./visualisations/<run name>``.

    :param optimizer_setup: callable with the positional signature of ``molecule_search_setup``
        that returns ``(optimizer, objective)``
    :param optimizer_cls: optimizer class forwarded to ``optimizer_setup``
    :param max_heavy_atoms: forwarded to ``optimizer_setup``
    :param atom_types: allowed atom symbols; a built-in list of eight symbols is used when falsy
    :param bond_types: allowed bond types
    :param initial_molecules: initial population forwarded to ``optimizer_setup``
    :param pop_size: population size
    :param metrics: metric names; ``['qed_score']`` when falsy
    :param num_trials: number of independent trials, must be at least 1
    :param trial_timeout: timeout of a single trial; a plain number is interpreted as minutes,
        a ``timedelta`` is passed through unchanged, ``None`` means no timeout
    :param trial_iterations: number of iterations of a single trial
    :param visualize: whether to draw and save the results of every trial
    :return: the text of the final report (it is also printed)
    :raises ValueError: if ``num_trials`` is less than 1
    """
    if num_trials < 1:
        raise ValueError(f'num_trials must be a positive integer, got {num_trials}')

    # The setup function is annotated to take a timedelta, so plain numbers are converted here.
    if trial_timeout is None or isinstance(trial_timeout, timedelta):
        timeout = trial_timeout
    else:
        timeout = timedelta(minutes=trial_timeout)

    log = StringIO()
    atom_types = atom_types or ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br']
    metrics = metrics or ['qed_score']
    trial_results = []
    experiment_id = f'Experiment [metrics={", ".join(metrics)} pop_size={pop_size}]\n'
    for trial in range(num_trials):
        run_name = f'trial_{trial}_pop_size_{pop_size}_{"_".join(metrics)}'
        optimizer, objective = optimizer_setup(optimizer_cls,
                                               max_heavy_atoms,
                                               atom_types,
                                               bond_types,
                                               timeout,
                                               trial_iterations,
                                               pop_size,
                                               metrics,
                                               initial_molecules)
        found_graphs = optimizer.optimise(objective)
        history = optimizer.history
        if visualize:
            adapter = MolAdapter()
            molecules = {adapter.restore(graph) for graph in found_graphs}
            save_path = os.path.join(os.path.curdir, 'visualisations', run_name)
            # The per-trial directory is nested, so its parents must be created as well.
            Path(save_path).mkdir(parents=True, exist_ok=True)
            visualize_results(molecules, objective, history, save_path)
        Path("results").mkdir(exist_ok=True)
        history.save(f'./results/{run_name}.json')
        trial_results.extend(history.final_choices)

    # Compute mean & std for metrics of trials
    ff = objective.format_fitness
    trial_metrics = np.array([ind.fitness.values for ind in trial_results])
    trial_metrics_mean = trial_metrics.mean(axis=0)
    trial_metrics_std = trial_metrics.std(axis=0)
    print(f'{experiment_id} finished with metrics:\n'
          f'mean={ff(trial_metrics_mean)}\n'
          f' std={ff(trial_metrics_std)}',
          file=log)
    report = log.getvalue()
    print(report)
    return report


if __name__ == '__main__':
    # Settings of the demo launch, passed to run_experiment as keyword arguments.
    experiment_settings = dict(
        max_heavy_atoms=38,
        trial_iterations=100,
        pop_size=100,
        metrics=['cl_score'],
        visualize=True,
        num_trials=10,
    )
    run_experiment(molecule_search_setup, **experiment_settings)
168 changes: 168 additions & 0 deletions examples/molecule_search/guacamol_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import json
from typing import Optional, List

import joblib
import numpy as np
import pandas as pd
from guacamol.assess_goal_directed_generation import assess_goal_directed_generation
from guacamol.goal_directed_generator import GoalDirectedGenerator
from guacamol.scoring_function import ScoringFunction
from joblib import delayed
from rdkit.Chem import Draw, MolFromSmiles
from rdkit.Chem.rdchem import BondType

from examples.molecule_search.mol_adapter import MolAdapter
from examples.molecule_search.mol_advisor import MolChangeAdvisor
from examples.molecule_search.mol_graph import MolGraph
from examples.molecule_search.mol_graph_parameters import MolGraphRequirements
from examples.molecule_search.mol_mutations import CHEMICAL_MUTATIONS
from golem.core.optimisers.adaptive.operator_agent import MutationAgentTypeEnum
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer
from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters
from golem.core.optimisers.genetic.operators.crossover import CrossoverTypesEnum
from golem.core.optimisers.genetic.operators.elitism import ElitismTypesEnum
from golem.core.optimisers.genetic.operators.inheritance import GeneticSchemeTypesEnum
from golem.core.optimisers.objective import Objective
from golem.core.optimisers.optimizer import GraphGenerationParams


def load_init_population(scoring_function: ScoringFunction,
                         n_jobs: int = -1,
                         path="./data/guacamol_v1_all.smiles",
                         top_n: int = 100):
    """Pick the best-scoring molecules from a SMILES file as the initial population.

    Original code:
    https://github.com/BenevolentAI/guacamol_baselines/blob/master/graph_ga/goal_directed_generation.py

    :param scoring_function: object whose ``score(smiles)`` method rates a molecule, higher is better
    :param n_jobs: number of parallel jobs passed to ``joblib.Parallel``
    :param path: text file with one SMILES string per line; forward slashes keep
        the default usable on both POSIX systems and Windows
    :param top_n: how many of the best-scoring molecules to keep
    :return: list of ``MolGraph`` built from the ``top_n`` best-scoring SMILES
    """
    with open(path, "r") as f:
        # Drop line terminators and blank lines so that the scorer and the SMILES parser get clean strings.
        smiles_list = [line.strip() for line in f]
    smiles_list = [smile for smile in smiles_list if smile]

    joblist = [delayed(scoring_function.score)(smile) for smile in smiles_list]
    scores = joblib.Parallel(n_jobs=n_jobs)(joblist)

    # Sort by score only, so ties keep the order of the file.
    scored_smiles = sorted(zip(scores, smiles_list), key=lambda x: x[0], reverse=True)
    best_smiles = [smile for _, smile in scored_smiles[:top_n]]
    return [MolGraph.from_smiles(smile) for smile in best_smiles]


class GolemMoleculeGenerator(GoalDirectedGenerator):
    """Goal-directed molecule generator for the Guacamol benchmark based on ``EvoGraphOptimizer``.

    You need to download Guacamol all smiles dataset from https://figshare.com/projects/GuacaMol/56639
    It is read by ``load_init_population`` whenever no starting population is passed.
    """
    def __init__(self,
                 requirements: Optional[MolGraphRequirements] = None,
                 graph_gen_params: Optional[GraphGenerationParams] = None,
                 gp_params: Optional[GPAlgorithmParameters] = None):
        """Store the optimizer settings, falling back to defaults for every omitted argument.

        :param requirements: graph requirements for the optimizer
        :param graph_gen_params: graph generation parameters (adapter and advisor)
        :param gp_params: parameters of the evolutionary algorithm; ``pop_size`` and
            ``max_pop_size`` are overwritten in ``generate_optimized_molecules``
        """
        self.requirements = requirements or MolGraphRequirements(
            max_heavy_atoms=50,
            available_atom_types=['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br'],
            bond_types=(BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
            early_stopping_timeout=np.inf,
            early_stopping_iterations=50,
            keep_n_best=4,
            timeout=None,
            num_of_generations=500,
            keep_history=True,
            n_jobs=-1,
            history_dir=None)

        self.graph_gen_params = graph_gen_params or GraphGenerationParams(
            adapter=MolAdapter(),
            advisor=MolChangeAdvisor())

        self.gp_params = gp_params or GPAlgorithmParameters(
            pop_size=2,
            max_pop_size=2,
            multi_objective=False,
            genetic_scheme_type=GeneticSchemeTypesEnum.steady_state,
            elitism_type=ElitismTypesEnum.replace_worst,
            mutation_types=CHEMICAL_MUTATIONS,
            crossover_types=[CrossoverTypesEnum.none],
            adaptive_mutation_type=MutationAgentTypeEnum.bandit)

    def generate_optimized_molecules(self, scoring_function: ScoringFunction, number_molecules: int,
                                     starting_population: Optional[List[str]] = None) -> List[str]:
        """Optimise molecules against ``scoring_function`` and return the best ones as SMILES.

        :param scoring_function: Guacamol scoring function; its negated score is used as the objective
        :param number_molecules: maximal number of molecules to return
        :param starting_population: SMILES strings to start from; when empty or ``None``
            the initial population is selected by ``load_init_population``
        :return: SMILES strings of the top individuals found during the optimisation
        """
        objective = Objective(
            quality_metrics=lambda mol: -scoring_function.score(mol.get_smiles(aromatic=True)),
            is_multi_objective=False
        )
        self.gp_params.pop_size = max(number_molecules // 2, 100)
        self.gp_params.max_pop_size = min(self.gp_params.pop_size * 10, 1000)

        # Use one adapter for both directions so that custom graph_gen_params are respected.
        adapter = self.graph_gen_params.adapter

        if starting_population:
            initial_graphs = [MolGraph.from_smiles(smiles) for smiles in starting_population]
        else:
            initial_graphs = load_init_population(scoring_function)
        initial_graphs = adapter.adapt(initial_graphs)

        # Build the optimizer
        optimiser = EvoGraphOptimizer(objective,
                                      initial_graphs,
                                      self.requirements,
                                      self.graph_gen_params,
                                      self.gp_params)
        optimiser.optimise(objective)
        history = optimiser.history

        # Keep one individual per distinct molecule hash; later assignments overwrite earlier ones.
        unique_individuals = {}
        for gen in history.individuals:
            for ind in reversed(list(gen)):
                unique_individuals[hash(adapter.restore(ind.graph))] = ind
        individuals = list(unique_individuals.values())

        top_individuals = sorted(individuals,
                                 key=lambda pos_ind: pos_ind.fitness, reverse=True)[:number_molecules]
        top_smiles = [adapter.restore(ind.graph).get_smiles(aromatic=True) for ind in top_individuals]
        return top_smiles


def visualize(path: str):
    """Draw the best molecules of every benchmark from a Guacamol results file.

    Prints the versions stored in the file, then for each benchmark shows a grid
    image of up to 12 optimized molecules and saves it as
    ``<benchmark_name>_results.png`` in the current working directory.
    Benchmarks without optimized molecules are skipped.

    :param path: path to the json file with Guacamol assessment results
    """
    with open(path) as json_file:
        results = json.load(json_file)
    print(f"Guacamol version: {results['guacamol_version']} \n"
          f"Benchmark version: {results['benchmark_suite_version']} \n")

    for result in results['results']:
        benchmark_name = result['benchmark_name']
        top_molecules = result['optimized_molecules'][:12]
        if not top_molecules:
            # An empty grid would request zero molecules per row, which cannot be drawn.
            continue

        generated_molecules = [MolFromSmiles(smile) for smile, _ in top_molecules]
        scores = [f"{benchmark_name} : {round(score, 3)}" for _, score in top_molecules]
        image = Draw.MolsToGridImage(generated_molecules,
                                     legends=scores,
                                     molsPerRow=min(4, len(generated_molecules)),
                                     subImgSize=(1000, 1000),
                                     legendFontSize=50)
        image.show()
        image.save(f'{benchmark_name}_results.png')


def get_launch_statistics(paths: List[str]):
    """Aggregate benchmark scores and execution times over several launches.

    Every file is expected to hold a ``'results'`` list whose items contain
    ``'benchmark_name'``, ``'score'`` and ``'execution_time'``. Benchmarks are
    matched by their position in that list; names are taken from the first file.
    The resulting table is printed with all columns shown.

    :param paths: paths to the json result files, one per launch
    :return: DataFrame with columns ``benchmark, mean, std, min, max, mean_time``,
        one row per benchmark (empty when ``paths`` is empty)
    """
    results = []
    for path in paths:
        with open(path) as json_file:
            results.append(json.load(json_file)['results'])

    column_names = ['benchmark', 'mean', 'std', 'min', 'max', 'mean_time']

    # Use only the benchmarks present in every launch instead of a hard-coded count.
    num_benchmarks = min((len(launch) for launch in results), default=0)

    rows = []
    for bench_num in range(num_benchmarks):
        scores = [launch[bench_num]['score'] for launch in results]
        time_spent = [launch[bench_num]['execution_time'] for launch in results]
        rows.append([results[0][bench_num]['benchmark_name'],
                     np.mean(scores),
                     np.std(scores),
                     np.min(scores),
                     np.max(scores),
                     np.mean(time_spent)])

    # Build the frame once rather than concatenating onto an empty frame in the loop.
    df = pd.DataFrame(rows, columns=column_names)
    # Show every column for this print only, without changing the global pandas options.
    with pd.option_context('display.max_columns', None):
        print(df)
    return df


if __name__ == '__main__':
    # one launch takes more than 24h
    num_launches = 10
    for launch in range(num_launches):
        print(f'\nLaunch_num {launch}\n')
        output_file = f'output_goal_directed_{launch}.json'
        assess_goal_directed_generation(GolemMoleculeGenerator(),
                                        benchmark_version='v2',
                                        json_output_file=output_file)

    # Draw the molecules of the second launch and summarise the first four launches.
    visualize('output_goal_directed_1.json')
    statistics_files = [f'output_goal_directed_{launch_num}.json' for launch_num in range(4)]
    get_launch_statistics(statistics_files)
Loading

0 comments on commit 3cc87ea

Please sign in to comment.