From fd976763ce6347a5c306150ea025f6c0bab8d5eb Mon Sep 17 00:00:00 2001 From: tigranfah Date: Sat, 27 Jul 2024 11:18:58 +0400 Subject: [PATCH] refactor mol_opt --- .../mol_opt/chemlactica_125m_hparams.yaml | 22 +- chemlactica/mol_opt/no_slurm_hparam_search.py | 119 --------- chemlactica/mol_opt/optimization.py | 238 ------------------ ..._search.py => optimization_run_example.py} | 32 +-- chemlactica/mol_opt/oracle_estimators.py | 96 ------- chemlactica/mol_opt/utils.py | 2 +- 6 files changed, 18 insertions(+), 491 deletions(-) delete mode 100644 chemlactica/mol_opt/no_slurm_hparam_search.py delete mode 100644 chemlactica/mol_opt/optimization.py rename chemlactica/mol_opt/{hparam_search.py => optimization_run_example.py} (57%) delete mode 100644 chemlactica/mol_opt/oracle_estimators.py diff --git a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml index f9b8ad3..04227b4 100644 --- a/chemlactica/mol_opt/chemlactica_125m_hparams.yaml +++ b/chemlactica/mol_opt/chemlactica_125m_hparams.yaml @@ -1,15 +1,12 @@ -# checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-20480 -# checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-12288 -checkpoint_path: /home/admin/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000 -tokenizer_path: /home/admin/tigran/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66 -pool_size: 50 +checkpoint_path: /nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000 +tokenizer_path: ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66 +pool_size: 10 validation_perc: 0.2 num_mols: 0 -num_similars: 1 +num_similars: 5 num_gens_per_iter: 200 device: cuda:0 -sim_range: [0.8, 0.9] -# qed_range: [0.5, 0.9] +sim_range: [0.4, 0.9] num_processes: 8 generation_batch_size: 200 eos_token: "" @@ -21,18 +18,19 @@ generation_config: do_sample: true eos_token_id: 20 -strategy: [default] +strategy: [rej-sample-v2] rej_sample_config: train_tol_level: 3 - checkpoints_dir: ./ - max_learning_rate: 0.00001 + checkpoints_dir: checkpoints + max_learning_rate: 0.0001 + lr_end: 0 train_batch_size: 2 gradient_accumulation_steps: 8 weight_decay: 0.1 adam_beta1: 0.9 adam_beta2: 0.999 - warmup_steps: 0 + warmup_steps: 10 global_gradient_norm: 1.0 dataloader_num_workers: 1 max_seq_length: 2048 diff --git a/chemlactica/mol_opt/no_slurm_hparam_search.py b/chemlactica/mol_opt/no_slurm_hparam_search.py deleted file mode 100644 index 219c0c9..0000000 --- a/chemlactica/mol_opt/no_slurm_hparam_search.py +++ /dev/null @@ -1,119 +0,0 @@ -import submitit -import subprocess -import itertools as it -import datetime -import yaml -import os -import copy -import time -import torch - - -def is_gpu_being_used(gpu_id): - try: - # Run the nvidia-smi command - cmd = ['nvidia-smi','-i',f"{gpu_id}"] - output = subprocess.check_output(cmd) - output = output.decode('utf-8') - if "No running processes found" in output: - return False - else: - return True - - except subprocess.CalledProcessError as e: - print(f"Error executing nvidia-smi command: {e}") - - -def create_hparam_configs(config_file_path): - config_tune = yaml.safe_load(open("hparams_tune.yaml")) - config_merged = {} - for key, value in config_tune["parameters"].items(): - if type(value) == list: - config_merged[key] = value - else: - for k, v in value.items(): - config_merged[key+'+'+k] = v - - config_default = 
yaml.safe_load(open(config_file_path)) - hparam_names = list(config_merged.keys()) - all_configs = [] - for params in it.product(*config_merged.values()): - # pprint(params) - # pprint(hparam_names) - config = copy.deepcopy(config_default) - for i, p in enumerate(params): - if '+' in hparam_names[i]: - a, b = hparam_names[i].split("+") - config[a][b] = p - else: - config[hparam_names[i]] = p - # pprint(params) - # pprint(config) - all_configs.append(config) - # print(config) - return all_configs - - -if __name__ == "__main__": - n_runs = 3 - - config_file_path = "chemlactica_125m_hparams.yaml" - # config_file_path = "main/chemlactica/chemma_2b_hparams.yaml" - hparam_configs = create_hparam_configs(config_file_path) - # infer_config = [yaml.safe_load(open(config_file_path))] - model_name = "-".join(config_file_path.split("/")[-1].split("_")[:2]) - gpu_indices = [0, 1, 2, 3, 4, 5, 6, 7] - - index = 0 - while index < len(hparam_configs): - free_gpu_index = None - for gpu_index in gpu_indices: - gpu_is_free = True - print(f"Checking gpu: {gpu_index}") - for _ in range(10): - if is_gpu_being_used(gpu_index): - gpu_is_free = False - break - time.sleep(1) - if gpu_is_free: - free_gpu_index = gpu_index - print(f"gpu: {gpu_index} is free") - break - else: - print(f"gpu: {gpu_index} is being used") - if free_gpu_index is not None: - print(f"found a free gpu {free_gpu_index}, putting a job") - executor = submitit.LocalExecutor(folder="/home/admin/tigran/slurm_jobs/PMO/job_%j") - executor.update_parameters( - name="chemlactica-pmo", timeout_min=n_runs * 12 * 60, - visible_gpus=[free_gpu_index], - gpus_per_node=1, nodes=1, mem_gb=80, cpus_per_task=8, - slurm_array_parallelism=10 - ) - jobs = [] - with executor.batch(): - current_hparams = [hparam_configs[index]] - for config in current_hparams: - formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d") - base = f"results/{formatted_date_time}" - os.makedirs(base, exist_ok=True) - v = 0 - name = model_name + "-" + "+".join(config["strategy"]) - while os.path.exists(os.path.join(base, f"{name}-{v}-hparam-search")): - v += 1 - output_dir = os.path.join(base, f"{name}-{v}-hparam-search") - os.makedirs(output_dir, exist_ok=True) - yaml.safe_dump(config, open(os.path.join(output_dir, "hparams.yaml"), "w")) - function = submitit.helpers.CommandFunction([ - 'python3', 'hparam_search.py', - '--config_default', os.path.join(output_dir, "hparams.yaml"), - '--output_dir', output_dir, - '--n_runs', str(n_runs), - ]) - print(' '.join(function.command)) - job = executor.submit(function) - jobs.append(job) - for job in jobs: - print(job.job_id) - index += 1 - free_gpu_index = None \ No newline at end of file diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py deleted file mode 100644 index 9cb2897..0000000 --- a/chemlactica/mol_opt/optimization.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import List -import torch -from datasets import Dataset -import gc -import shutil -from trl import SFTTrainer -from transformers import OPTForCausalLM -from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool -from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback, CustomModelSelectionCallback - - -def create_similar_mol_entries(pool, mol_entry, num_similars): - similar_entries = [e.last_entry for e in pool.random_subset(num_similars)] - count = 0 - valid_similar_entries = [] - for similar_entry in similar_entries: - if count >= num_similars: - break - if 
similar_entry == mol_entry: - continue - valid_similar_entries.append(similar_entry) - count += 1 - return valid_similar_entries - - -def create_optimization_entries(num_entries, pool, config): - optim_entries = [] - for i in range(num_entries): - mol_entries = [e.last_entry for e in pool.random_subset(config["num_mols"])] - entries = [] - for mol_entry in mol_entries: - similar_mol_entries = create_similar_mol_entries(pool, mol_entry, num_similars=config["num_similars"]) - mol_entry.similar_mol_entries = similar_mol_entries - entries.append(mol_entry) - optim_entries.append(OptimEntry(None, entries)) - return optim_entries - - -def create_molecule_entry(output_text): - start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]" - start_ind = output_text.rfind(start_smiles_tag) - end_ind = output_text.rfind(end_smiles_tag) - if start_ind == -1 or end_ind == -1: - return None - generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind] - if len(generated_smiles) == 0: - return None - - try: - molecule = MoleculeEntry( - smiles=generated_smiles, - ) - return molecule - except: - return None - - -def optimize( - model, tokenizer, - oracle, config, - additional_properties={} - ): - file = open(config["log_dir"], "w") - print("config", config) - # print("molecule generation arguments", config["generation_config"]) - pool = Pool(config["pool_size"], validation_perc=config["validation_perc"]) - - config["generation_config"]["temperature"] = config["generation_temperature"][0] - - if "rej-sample-v2" in config["strategy"]: - training_args = get_training_arguments(config["rej_sample_config"]) - effective_batch_size = config["rej_sample_config"]["gradient_accumulation_steps"] * config["rej_sample_config"]["train_batch_size"] - num_single_train_steps = config["rej_sample_config"]["num_train_epochs"] * ((1 - config["validation_perc"]) * config["pool_size"] / effective_batch_size) - max_num_trains = oracle.max_oracle_calls / (config["rej_sample_config"]["train_tol_level"] * config["num_gens_per_iter"]) - max_num_train_steps = int(max_num_trains * num_single_train_steps) - optimizer, lr_scheduler = get_optimizer_and_lr_scheduler(model, config["rej_sample_config"], max_num_train_steps) - max_score = 0 - tol_level = 0 - num_iter = 0 - prev_train_iter = 0 - while True: - model.eval() - new_best_molecule_generated = False - iter_unique_optim_entries: List[OptimEntry] = {} - while len(iter_unique_optim_entries) < config["num_gens_per_iter"]: - optim_entries = create_optimization_entries( - config["generation_batch_size"], pool, - config=config - ) - for i in range(len(optim_entries)): - last_entry = MoleculeEntry(smiles="") - last_entry.similar_mol_entries = create_similar_mol_entries( - pool, last_entry, config["num_similars"] - ) - for prop_name, prop_spec in additional_properties.items(): - last_entry.add_props[prop_name] = prop_spec - optim_entries[i].last_entry = last_entry - - prompts = [ - optim_entry.to_prompt( - is_generation=True, include_oracle_score=prev_train_iter != 0, - config=config, max_score=max_score - ) - for optim_entry in optim_entries - ] - output_texts = [] - data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - if type(model) == OPTForCausalLM: - del data["token_type_ids"] - for key, value in data.items(): - data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:] - output = model.generate( - **data, - **config["generation_config"] - ) - gc.collect() - torch.cuda.empty_cache() - 
output_texts.extend(tokenizer.batch_decode(output)) - - current_unique_optim_entries = {} - # with multiprocessing.Pool(processes=config["num_processes"]) as pol: - for i, molecule in enumerate(map(create_molecule_entry, output_texts)): - if molecule and not optim_entries[i].contains_entry(molecule): - if molecule.smiles not in oracle.mol_buffer and molecule.smiles not in current_unique_optim_entries: - molecule.similar_mol_entries = optim_entries[i].last_entry.similar_mol_entries - for prop_name, prop_spec in additional_properties.items(): - molecule.add_props[prop_name] = prop_spec - molecule.add_props[prop_name]["value"] = molecule.add_props[prop_name]["calculate_value"](molecule) - optim_entries[i].last_entry = molecule - current_unique_optim_entries[molecule.smiles] = optim_entries[i] - - num_of_molecules_to_score = min(len(current_unique_optim_entries), config["num_gens_per_iter"] - len(iter_unique_optim_entries)) - current_unique_smiles_list = list(current_unique_optim_entries.keys())[:num_of_molecules_to_score] - current_unique_optim_entries = {smiles: current_unique_optim_entries[smiles] for smiles in current_unique_smiles_list} - - if getattr(oracle, "takes_entry", False): - oracle_scores = oracle([current_unique_optim_entries[smiles].last_entry for smiles in current_unique_smiles_list]) - else: - oracle_scores = oracle(current_unique_smiles_list) - - for smiles, oracle_score in zip(current_unique_smiles_list, oracle_scores): - current_unique_optim_entries[smiles].last_entry.score = oracle_score - iter_unique_optim_entries[smiles] = current_unique_optim_entries[smiles] - file.write(f"generated smiles: {smiles}, score: {current_unique_optim_entries[smiles].last_entry.score:.4f}\n") - if max_score >= config["max_possible_oracle_score"] - 1e-2 or current_unique_optim_entries[smiles].last_entry.score > max_score: - max_score = max(max_score, current_unique_optim_entries[smiles].last_entry.score) - new_best_molecule_generated = True - - print(f"Iter unique optim entries: {len(iter_unique_optim_entries)}, budget: {len(oracle)}") - - if oracle.finish: - break - - if oracle.finish: - break - initial_num_iter = num_iter - num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"] - if num_iter > initial_num_iter: - tol_level += 1 - - if new_best_molecule_generated: - tol_level = 0 - - print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}") - if num_iter > initial_num_iter: - config["generation_config"]["temperature"] += config["num_gens_per_iter"] / (oracle.budget - config["num_gens_per_iter"]) * (config["generation_temperature"][1] - config["generation_temperature"][0]) - print(f"Generation temperature: {config['generation_config']['temperature']}") - - # diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10)) - pool.add(list(iter_unique_optim_entries.values())) - file.write("Pool\n") - for i, optim_entry in enumerate(pool.optim_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - - if "rej-sample-v2" in config["strategy"]: - # round_entries.extend(current_entries) - # round_entries = list(np.unique(round_entries))[::-1] - # top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"]) - # if top_k >= config["rej_sample_config"]["num_samples_per_round"]: - if tol_level >= config["rej_sample_config"]["train_tol_level"]: - train_entries, validation_entries = pool.get_train_valid_entries() - print(f"Num of training examples: {len(train_entries)}, num of 
validation examples: {len(validation_entries)}.") - file.write("Training entries\n") - for i, optim_entry in enumerate(train_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - file.write("Validation entries\n") - for i, optim_entry in enumerate(validation_entries): - file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n") - - train_dataset = Dataset.from_dict({ - "sample": [ - optim_entry.to_prompt( - is_generation=False, include_oracle_score=True, - config=config, max_score=config["max_possible_oracle_score"] - ) - for optim_entry in train_entries - ] - }) - validation_dataset = Dataset.from_dict({ - "sample": [ - optim_entry.to_prompt( - is_generation=False, include_oracle_score=True, - config=config, max_score=config["max_possible_oracle_score"] - ) - for optim_entry in validation_entries - ] - }) - train_dataset.shuffle(seed=42) - validation_dataset.shuffle(seed=42) - - # early_stopping_callback = CustomEarlyStopCallback( - # early_stopping_patience=1, - # early_stopping_threshold=0.0001 - # ) - model_selection_callback = CustomModelSelectionCallback() - - model.train() - trainer = SFTTrainer( - model=model, - train_dataset=train_dataset, - eval_dataset=validation_dataset, - formatting_func=lambda x: x["sample"], - args=training_args, - packing=config["rej_sample_config"]["packing"], - tokenizer=tokenizer, - max_seq_length=config["rej_sample_config"]["max_seq_length"], - # data_collator=collator, - callbacks=[model_selection_callback], - optimizers=[optimizer, lr_scheduler], - ) - trainer.train() - print(f"Loading the best model state dict with validation loss {model_selection_callback.best_validation_loss}") - model.load_state_dict(model_selection_callback.best_model_state_dict) - del model_selection_callback.best_model_state_dict - gc.collect() - torch.cuda.empty_cache() - tol_level = 0 - prev_train_iter = num_iter \ No newline at end of file diff --git a/chemlactica/mol_opt/hparam_search.py b/chemlactica/mol_opt/optimization_run_example.py similarity index 57% rename from chemlactica/mol_opt/hparam_search.py rename to chemlactica/mol_opt/optimization_run_example.py index 4730f78..d295736 100644 --- a/chemlactica/mol_opt/hparam_search.py +++ b/chemlactica/mol_opt/optimization_run_example.py @@ -4,28 +4,13 @@ import datetime import argparse import os -from utils import ConstraedTPSAOracle +from utils import ConstrainedTPSAOracle from typing import List from chemlactica.mol_opt.optimization import optimize os.environ["TOKENIZERS_PARALLELISM"] = "true" -def default_train_condition(num_iter, tol_level, prev_train_iter): - return num_iter - prev_train_iter >= 3 - - -def tolerance_train_condition(cur_tol_level, train_tol_level): - return cur_tol_level >= train_tol_level - - -def choose_train_condition(name): - return { - "default" : default_train_condition, - "tolerance": tolerance_train_condition - }[name] - - def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--run_name", type=str, required=False) @@ -39,17 +24,14 @@ def parse_arguments(): if __name__ == "__main__": args = parse_arguments() config = yaml.safe_load(open(args.config_default)) - print(config) model = AutoModelForCausalLM.from_pretrained(config["checkpoint_path"], torch_dtype=torch.bfloat16).to(config["device"]) tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left") seeds = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] - oracle = 
ConstraedTPSAOracle(max_oracle_calls=15000) - for seed in seeds[:args.n_runs]: - config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log") - config["rej_sample_config"]["should_train"] = choose_train_condition("tolerance") - optimize( - model, tokenizer, - oracle, config - ) \ No newline at end of file + oracle = ConstrainedTPSAOracle(max_oracle_calls=5000) + config["log_dir"] = os.path.join(args.output_dir, "results_tpsa+weight+num_rungs.log") + optimize( + model, tokenizer, + oracle, config + ) \ No newline at end of file diff --git a/chemlactica/mol_opt/oracle_estimators.py b/chemlactica/mol_opt/oracle_estimators.py deleted file mode 100644 index 3d64477..0000000 --- a/chemlactica/mol_opt/oracle_estimators.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import List -import time -from trl import SFTTrainer, DataCollatorForCompletionOnlyLM -from transformers import AutoModelForCausalLM, PreTrainedModel, AutoModel, AutoConfig, AutoTokenizer -import torch -import torch.nn as nn -import numpy as np -from chemlactica.mol_opt.utils import MoleculeEntry -from sklearn.linear_model import Ridge - - -def find_second_eos_token_indices(sequences, eos_token_id): - return torch.where(sequences[:, 1:] == eos_token_id) - - -def init_linear_layer(layer, emb_length): - torch.nn.init.normal_( - layer.weight, - mean=0.0, std=1 / np.sqrt(emb_length + 1) - ) - torch.nn.init.constant_(layer.bias, val=0.0) - return layer - - -class ScalarHeadLM(PreTrainedModel): - - def __init__(self, config): - super().__init__(config) - self.config = config - self.lm_backbone = AutoModel.from_pretrained( - config._name_or_path, - config=config - ) - self.scalar_head = nn.Linear(config.hidden_size, 1) - init_linear_layer(self.scalar_head) - - def forward(self, **kwargs): - output = self.lm_backbone(**kwargs) - return self.scalar_head(output.last_hidden_state) - - -class LinearFingerprintModel: - - def __init__(self): - self.emb_length = 2048 - self.linear = Ridge() - self.all_entries = [] - self.is_fit = False - - def __call__(self, mol_entries: List[MoleculeEntry]): - mol_embs = np.array([entry.fingerprint for entry in mol_entries]) - return self.linear.predict(mol_embs) - - def fit(self, mol_entries: List[MoleculeEntry]): - self.is_fit = True - start_time = time.time() - self.all_entries.extend(mol_entries) - mol_embs = np.array([entry.fingerprint for entry in self.all_entries]) - scores = np.array([entry.score for entry in self.all_entries]) - self.linear.fit(mol_embs, scores) - print(f"Fit time {time.time() - start_time:.4f}s") - - -class ScalarOracleApproximator: - - def __init__(self, config, tokenizer): - self.scalar_head_lm = ScalarHeadLM(config) - self.tokenizer = tokenizer - - def __call__(self, mol_entries): - prompts = [f"[START_SMILES]{e.smiles}[END_SMILES]" for e in mol_entries] - data = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.scalar_head_lm.device) - del data["token_type_ids"] - outputs = self.scalar_head_lm( - **data - ) - print(outputs) - - -class SFTOracleApproximator: - - def __init__(self, config, tokenizer, device): - self.ml = AutoModelForCausalLM.from_pretrained( - config._name_or_path, - config=config - ).to(device) - self.tokenizer = tokenizer - - -if __name__ == "__main__": - config = AutoConfig.from_pretrained("/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/26d322857a184fcbafda5d4a/checkpoint-118784") - tokenizer = AutoTokenizer.from_pretrained("chemlactica/tokenizer/ChemLacticaTokenizer66", padding_side="left") - scalar_oracle_approx = 
ScalarOracleApproximator(config, tokenizer) - - mol_entries = [MoleculeEntry("CCC" + i * "C") for i in range(10)] - scalar_oracle_approx(mol_entries) \ No newline at end of file diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index ec6d2af..4bf89ea 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -86,7 +86,7 @@ def __hash__(self): return hash(self.smiles) -class ConstraedTPSAOracle: +class ConstrainedTPSAOracle: def __init__(self, max_oracle_calls: int): self.max_oracle_calls = max_oracle_calls self.freq_log = 100
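
Note: for orientation, after this patch a single optimization run reduces to what the renamed optimization_run_example.py does: load the updated YAML hyperparameters, the model and tokenizer, and the ConstrainedTPSAOracle, then call optimize once. The condensed sketch below mirrors that file; the config path and the log-file directory are placeholders standing in for the --config_default and --output_dir arguments the script actually parses, and it assumes the script is run from chemlactica/mol_opt so the relative "utils" import resolves.

    import os
    import yaml
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from utils import ConstrainedTPSAOracle
    from chemlactica.mol_opt.optimization import optimize

    # Load the hyperparameters shipped with this patch (placeholder path).
    config = yaml.safe_load(open("chemlactica_125m_hparams.yaml"))

    # Model and tokenizer locations come from checkpoint_path / tokenizer_path in the config.
    model = AutoModelForCausalLM.from_pretrained(
        config["checkpoint_path"], torch_dtype=torch.bfloat16
    ).to(config["device"])
    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"], padding_side="left")

    # Oracle budget matches the value used in optimization_run_example.py after this patch.
    oracle = ConstrainedTPSAOracle(max_oracle_calls=5000)

    # Log file goes under a placeholder output directory.
    config["log_dir"] = os.path.join("results", "results_tpsa+weight+num_rungs.log")

    optimize(model, tokenizer, oracle, config)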