
Commit

tox

bengioe committed Jul 17, 2023
1 parent 468f770 commit 08b39be
Showing 1 changed file with 194 additions and 0 deletions.
194 changes: 194 additions & 0 deletions src/gflownet/tasks/make_rings.py
@@ -0,0 +1,194 @@
import copy
import os
import pathlib
import socket
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor
from torch.utils.data import Dataset

from gflownet.algo.trajectory_balance import TrajectoryBalance
from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.envs.graph_building_env import GraphBuildingEnv
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.models.graph_transformer import GraphTransformerGFN
from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar


class MakeRingsTask(GFNTask):
"""A toy task where the reward is the number of rings in the molecule."""

def __init__(
self,
dataset: Dataset,
cfg: Config,
        rng: Optional[np.random.Generator] = None,
        wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        self._wrap_model = wrap_model
        self.rng = rng
        self.dataset = dataset

    def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
        return FlatRewards(y)

    def inverse_flat_reward_transform(self, rp):
        return rp

    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
        return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}

    def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
        # Clamp away zero rewards before taking the log, so acyclic molecules get a
        # large negative logreward instead of -inf.
        scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log()
        return RewardScalar(scalar_logreward.flatten())

    def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
        rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float()
        return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool()
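
    # Illustration (added; not part of the original commit): the reward is RDKit's
    # SSSR ring count, e.g.
    #   Chem.MolFromSmiles("C1CCCCC1").GetRingInfo().NumRings()      # 1 (cyclohexane)
    #   Chem.MolFromSmiles("C1CC2CCC1CC2").GetRingInfo().NumRings()  # 2 (bicyclooctane)
    # while an acyclic molecule scores 0, i.e. the minimum logreward after clamping.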


class MakeRingsTrainer(GFNTrainer):
    def set_default_hps(self, cfg: Config):
        cfg.hostname = socket.gethostname()
        cfg.num_workers = 8
        cfg.opt.learning_rate = 1e-4
        cfg.opt.weight_decay = 1e-8
        cfg.opt.momentum = 0.9
        cfg.opt.adam_eps = 1e-8
        cfg.opt.lr_decay = 20_000
        cfg.opt.clip_grad_type = "norm"
        cfg.opt.clip_grad_param = 10
        cfg.algo.global_batch_size = 64
        cfg.algo.offline_ratio = 0
        cfg.model.num_emb = 128
        cfg.model.num_layers = 4

        cfg.algo.method = "TB"
        cfg.algo.max_nodes = 6
        cfg.algo.sampling_tau = 0.9
        cfg.algo.illegal_action_logreward = -75
        cfg.algo.train_random_action_prob = 0.0
        cfg.algo.valid_random_action_prob = 0.0
        cfg.algo.tb.do_parameterize_p_b = True

        cfg.replay.use = False
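
    # Note (added for clarity; not in the original commit): these defaults are applied
    # before the user-supplied hyperparameter dict, so entries like the
    # algo.tb.do_parameterize_p_b override in main() below take precedence.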

    def setup_algo(self):
        assert self.cfg.algo.method == "TB"
        self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg)

    def setup_task(self):
        self.task = MakeRingsTask(
            dataset=self.training_data,
            cfg=self.cfg,
            rng=self.rng,
            wrap_model=self._wrap_for_mp,
        )

    def setup_model(self):
        model = GraphTransformerGFN(
            self.ctx,
            self.cfg,
            do_bck=self.cfg.algo.tb.do_parameterize_p_b,
        )
        self.model = model

    def setup_env_context(self):
        self.ctx = MolBuildingEnvContext(
            ["C"],
            charges=[0],  # disable charge
            chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED],  # disable chirality
            num_rw_feat=0,
            max_nodes=self.cfg.algo.max_nodes,
            num_cond_dim=1,
        )
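
    # Note (added for clarity; not in the original commit): with carbon as the only
    # atom type, charges and chirality disabled, and at most 6 atoms, the sampler can
    # only build small all-carbon graphs, so closing rings is the sole way to earn
    # reward; that keeps this toy task simple enough to sanity-check the training loop.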

    def setup(self):
        RDLogger.DisableLog("rdApp.*")
        self.rng = np.random.default_rng(142857)
        self.env = GraphBuildingEnv()
        self.training_data = []
        self.test_data = []
        self.offline_ratio = 0
        self.valid_offline_ratio = 0
        self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None
        self.setup_env_context()
        self.setup_algo()
        self.setup_task()
        self.setup_model()

        # Separate Z parameters from non-Z to allow for LR decay on the former
        Z_params = list(self.model.logZ.parameters())
        non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)]
        self.opt = torch.optim.Adam(
            non_Z_params,
            self.cfg.opt.learning_rate,
            (self.cfg.opt.momentum, 0.999),
            weight_decay=self.cfg.opt.weight_decay,
            eps=self.cfg.opt.adam_eps,
        )
        self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999))
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay))
        self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(
            self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay)
        )
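
        # Worked example (added for clarity): LambdaLR scales the base LR by
        # 2 ** (-steps / lr_decay). With the defaults above (lr_decay = 20_000,
        # learning_rate = 1e-4) the non-Z learning rate halves every 20k steps:
        # 1e-4 at step 0, 5e-5 at step 20k, 2.5e-5 at step 40k.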

        self.sampling_tau = self.cfg.algo.sampling_tau
        if self.sampling_tau > 0:
            self.sampling_model = copy.deepcopy(self.model)
        else:
            self.sampling_model = self.model

        self.mb_size = self.cfg.algo.global_batch_size
        self.clip_grad_callback = {
            "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param),
            "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param),
            "none": lambda x: None,
        }[self.cfg.opt.clip_grad_type]

        # saving hyperparameters
        print("\n\nHyperparameters:\n")
        yaml = OmegaConf.to_yaml(self.cfg)
        print(yaml)
        with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f:
            f.write(yaml)

    def step(self, loss: Tensor):
        loss.backward()
        for i in self.model.parameters():
            self.clip_grad_callback(i)
        self.opt.step()
        self.opt.zero_grad()
        self.opt_Z.step()
        self.opt_Z.zero_grad()
        self.lr_sched.step()
        self.lr_sched_Z.step()
        if self.sampling_tau > 0:
            for a, b in zip(self.model.parameters(), self.sampling_model.parameters()):
                b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau))
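
        # Sketch of the update above (comment added for clarity): this is a Polyak/EMA
        # update, b <- tau * b + (1 - tau) * a, so with sampling_tau = 0.9 the sampling
        # model tracks a slow-moving average of the trained model's weights, stabilizing
        # the distribution that new trajectories are sampled from.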


def main():
    """Example of how this model can be run outside of Determined"""
    hps = {
        "log_dir": "./logs/debug_run_mr3",
        "num_training_steps": 10_000,
        "num_workers": 8,
        "algo": {"tb": {"do_parameterize_p_b": False}},
    }
    os.makedirs(hps["log_dir"], exist_ok=True)

    trial = MakeRingsTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    trial.print_every = 1
    trial.run()


if __name__ == "__main__":
main()
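
# Usage sketch (added; assumes the path from this commit and a working install):
#   python src/gflownet/tasks/make_rings.py
# With the hps above this trains for 10_000 steps, on CUDA when available, and
# writes hps.yaml plus logs under ./logs/debug_run_mr3.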
