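"""Toy GFlowNet example: build carbon-only molecular graphs where the reward is
the number of rings in the molecule (see MakeRingsTask below)."""
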
import copy
import os
import pathlib
import socket
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor
from torch.utils.data import Dataset

from gflownet.algo.trajectory_balance import TrajectoryBalance
from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.envs.graph_building_env import GraphBuildingEnv
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.models.graph_transformer import GraphTransformerGFN
from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar


class MakeRingsTask(GFNTask):
    """A toy task where the reward is the number of rings in the molecule."""

    def __init__(
        self,
        dataset: Dataset,
        cfg: Config,
        rng: Optional[np.random.Generator] = None,
        wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        self._wrap_model = wrap_model
        self.rng = rng
        self.dataset = dataset

    def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
        return FlatRewards(y)

    def inverse_flat_reward_transform(self, rp):
        return rp

    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
        # Conditioning is trivial in this task: a constant beta = 1 and a dummy encoding.
        return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}

    def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
        # Clamp before the log so that ring-free molecules (reward 0) map to a
        # large negative log-reward instead of -inf.
        scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log()
        return RewardScalar(scalar_logreward.flatten())

    def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
        rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float()
        return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool()


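# A quick sanity-check sketch (illustrative addition, not called by the
# trainer): what the reward computation above yields for a few example
# molecules written as SMILES.
def _reward_sketch() -> Tensor:
    mols = [Chem.MolFromSmiles(s) for s in ("CCC", "C1CC1", "C1CC1C1CC1")]
    # Propane has 0 rings, cyclopropane 1, and bicyclopropyl 2, so this returns
    # tensor([[0.], [1.], [2.]]), the same shape compute_flat_rewards uses.
    return torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float().reshape((-1, 1))

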
class MakeRingsTrainer(GFNTrainer):
    def set_default_hps(self, cfg: Config):
        cfg.hostname = socket.gethostname()
        cfg.num_workers = 8
        cfg.opt.learning_rate = 1e-4
        cfg.opt.weight_decay = 1e-8
        cfg.opt.momentum = 0.9
        cfg.opt.adam_eps = 1e-8
        cfg.opt.lr_decay = 20_000
        cfg.opt.clip_grad_type = "norm"
        cfg.opt.clip_grad_param = 10
        cfg.algo.global_batch_size = 64
        cfg.algo.offline_ratio = 0  # purely online: all trajectories are sampled
        cfg.model.num_emb = 128
        cfg.model.num_layers = 4

        cfg.algo.method = "TB"  # Trajectory Balance
        cfg.algo.max_nodes = 6
        cfg.algo.sampling_tau = 0.9
        cfg.algo.illegal_action_logreward = -75
        cfg.algo.train_random_action_prob = 0.0
        cfg.algo.valid_random_action_prob = 0.0
        cfg.algo.tb.do_parameterize_p_b = True  # learn the backward policy P_B

        cfg.replay.use = False

    def setup_algo(self):
        assert self.cfg.algo.method == "TB"
        self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg)

    def setup_task(self):
        self.task = MakeRingsTask(
            dataset=self.training_data,
            cfg=self.cfg,
            rng=self.rng,
            wrap_model=self._wrap_for_mp,
        )

    def setup_model(self):
        model = GraphTransformerGFN(
            self.ctx,
            self.cfg,
            do_bck=self.cfg.algo.tb.do_parameterize_p_b,
        )
        self.model = model

    def setup_env_context(self):
        self.ctx = MolBuildingEnvContext(
            ["C"],  # carbon-only molecules
            charges=[0],  # disable charge
            chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED],  # disable chirality
            num_rw_feat=0,
            max_nodes=self.cfg.algo.max_nodes,
            num_cond_dim=1,
        )

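    # An added observation on the context above: for a connected carbon-only
    # molecule with at most max_nodes = 6 atoms, RDKit's SSSR ring count is
    # bonds - atoms + 1, and valence 4 caps the bond count at 12, so the
    # reward in this toy task is bounded above by 12 - 6 + 1 = 7.
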
    def setup(self):
        RDLogger.DisableLog("rdApp.*")
        self.rng = np.random.default_rng(142857)
        self.env = GraphBuildingEnv()
        self.training_data = []
        self.test_data = []
        self.offline_ratio = 0
        self.valid_offline_ratio = 0
        self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None
        self.setup_env_context()
        self.setup_algo()
        self.setup_task()
        self.setup_model()

        # Separate the logZ parameters from the rest so they can get their own
        # learning rate and decay schedule.
        Z_params = list(self.model.logZ.parameters())
        non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)]
        self.opt = torch.optim.Adam(
            non_Z_params,
            self.cfg.opt.learning_rate,
            (self.cfg.opt.momentum, 0.999),
            weight_decay=self.cfg.opt.weight_decay,
            eps=self.cfg.opt.adam_eps,
        )
        self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999))
        # Exponential decay: the multiplier 2 ** (-steps / lr_decay) halves the
        # learning rate every lr_decay steps.
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay))
        self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(
            self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay)
        )

        # If sampling_tau > 0, sample trajectories from an exponential moving
        # average of the model; otherwise sample from the model itself.
        self.sampling_tau = self.cfg.algo.sampling_tau
        if self.sampling_tau > 0:
            self.sampling_model = copy.deepcopy(self.model)
        else:
            self.sampling_model = self.model

        self.mb_size = self.cfg.algo.global_batch_size
        self.clip_grad_callback = {
            "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param),
            "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param),
            "none": lambda x: None,
        }[self.cfg.opt.clip_grad_type]

        # Save the hyperparameters for this run.
        print("\n\nHyperparameters:\n")
        yaml = OmegaConf.to_yaml(self.cfg)
        print(yaml)
        with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f:
            f.write(yaml)

    def step(self, loss: Tensor):
        loss.backward()
        for i in self.model.parameters():
            self.clip_grad_callback(i)
        self.opt.step()
        self.opt.zero_grad()
        self.opt_Z.step()
        self.opt_Z.zero_grad()
        self.lr_sched.step()
        self.lr_sched_Z.step()
        if self.sampling_tau > 0:
            # Polyak averaging: pull the sampling model towards the trained
            # model, b <- tau * b + (1 - tau) * a.
            for a, b in zip(self.model.parameters(), self.sampling_model.parameters()):
                b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau))


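# A minimal sketch of the Polyak/EMA update used in step above (illustrative
# addition, nothing in the trainer calls this): with tau = 0.9 the sampling
# model moves 10% of the way towards the online model each step.
def _polyak_update_sketch(tau: float = 0.9) -> Tensor:
    online = torch.tensor([1.0])  # stands in for a trained-model parameter
    target = torch.tensor([0.0])  # stands in for the sampling-model copy
    target.mul_(tau).add_(online * (1 - tau))
    return target  # tensor([0.1000])

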
def main():
    """Example of how this model can be run outside of Determined."""
    hps = {
        "log_dir": "./logs/debug_run_mr3",
        "num_training_steps": 10_000,
        "num_workers": 8,
        "algo": {"tb": {"do_parameterize_p_b": False}},
    }
    os.makedirs(hps["log_dir"], exist_ok=True)

    trial = MakeRingsTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    trial.print_every = 1
    trial.run()


if __name__ == "__main__":
    main()