diff --git a/README.md b/README.md index 5f3dc65c..3b7a143d 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera ## Repo overview -- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations (only [Trajectory Balance](https://arxiv.org/abs/2201.13259) for now), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. +- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. - [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities. - [envs](src/gflownet/envs), contains environment classes; a graph-building environment base, and a molecular graph context class. The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data. - [examples](docs/examples), contains simple example implementations of GFlowNet. @@ -30,8 +30,11 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera - [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward. - [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein - [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives). -- [utils](src/gflownet/utils), contains utilities (multiprocessing). -- [`train.py`](src/gflownet/train.py), defines a general harness for training GFlowNet models. +- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning). +- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models. +- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop. + +See [implementation notes](docs/implementation_notes.md) for more. ## Getting started @@ -57,6 +60,8 @@ To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository- pip install git+https://github.com/recursionpharma/gflownet.git@v0.0.10 --find-links ... ``` +If package dependencies seem not to work, you may need to install the exact frozen versions listed `requirements/`, i.e. `pip install -r requirements/main_3.9.txt`. + ## Developing & Contributing TODO: Write Contributing.md. diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md new file mode 100644 index 00000000..598e570c --- /dev/null +++ b/docs/implementation_notes.md @@ -0,0 +1,18 @@ +# Implementation notes + +This repo is centered around training GFlowNets that produce graphs. While we intend to specialize towards building molecules, we've tried to keep the implementation moderately agnostic to that fact, which makes it able to support other graph-generation environments. + +## Environment, Context, Task, Trainers + +We separate experiment concerns in four categories: +- The Environment is the graph abstraction that is common to all; think of it as the base definition of the MDP. +- The Context provides an interface between the agent and the environment, it + - maps graphs to torch_geometric `Data` + instances + - maps GraphActions to action indices + - produces action masks + - communicates to the model what inputs it should expect +- The Task class is responsible for computing the reward of a state, and for sampling conditioning information +- The Trainer class is responsible for instanciating everything, and running the training & testing loop + +Typically one would setup a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`. diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 121ab5f4..bd0ce3de 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -91,6 +91,8 @@ class AlgoConfig: offline_ratio: float The ratio of samples drawn from `self.training_data` during training. The rest is drawn from `self.sampling_model` + valid_offline_ratio: float + Idem but for validation, and `self.test_data`. train_random_action_prob : float The probability of taking a random action during training valid_random_action_prob : float @@ -108,6 +110,7 @@ class AlgoConfig: max_edges: int = 128 illegal_action_logreward: float = -100 offline_ratio: float = 0.5 + valid_offline_ratio: float = 1 train_random_action_prob: float = 0.0 valid_random_action_prob: float = 0.0 valid_sample_cond_info: bool = True diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index f9fba9d2..4d694ae2 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -14,7 +14,7 @@ generate_forward_trajectory, ) from gflownet.models.graph_transformer import GraphTransformer, mlp -from gflownet.train import GFNTask +from gflownet.trainer import GFNTask from .graph_sampling import GraphSampler diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 57af6056..22fe655e 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -19,7 +19,7 @@ GraphBuildingEnvContext, generate_forward_trajectory, ) -from gflownet.train import GFNAlgorithm +from gflownet.trainer import GFNAlgorithm class TrajectoryBalanceModel(nn.Module): @@ -363,7 +363,7 @@ def compute_batch_losses( traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens) # The position of the first graph of each trajectory first_graph_idx = torch.zeros_like(batch.traj_lens) - first_graph_idx = torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) log_Z = per_graph_out[first_graph_idx, 0] else: # Compute log numerator and denominator of the TB objective diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 8de0ebeb..be4fa879 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -7,6 +7,7 @@ from gflownet.data.config import ReplayConfig from gflownet.models.config import ModelConfig from gflownet.tasks.config import TasksConfig +from gflownet.utils.config import ConditionalsConfig @dataclass @@ -51,12 +52,16 @@ class Config: ---------- log_dir : str The directory where to store logs, checkpoints, and samples. + device : str + The device to use for training (either "cpu" or "cuda[:]") seed : int The random seed validate_every : int The number of training steps after which to validate the model checkpoint_every : Optional[int] The number of training steps after which to checkpoint the model + print_every : int + The number of training steps after which to print the training loss start_at_step : int The training step to start at (default: 0) num_final_gen_steps : Optional[int] @@ -76,9 +81,11 @@ class Config: """ log_dir: str = MISSING + device: str = "cuda" seed: int = 0 validate_every: int = 1000 checkpoint_every: Optional[int] = None + print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None num_training_steps: int = 10_000 @@ -92,3 +99,4 @@ class Config: opt: OptimizerConfig = OptimizerConfig() replay: ReplayConfig = ReplayConfig() task: TasksConfig = TasksConfig() + cond: ConditionalsConfig = ConditionalsConfig() diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index c68ead98..b26c29d2 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import rdkit.Chem as Chem +import torch from torch.utils.data import Dataset @@ -39,7 +40,10 @@ def __len__(self): return len(self.idcs) def __getitem__(self, idx): - return (Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), self.df[self.target][self.idcs[idx]]) + return ( + Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), + torch.tensor([self.df[self.target][self.idcs[idx]]]).float(), + ) def convert_h5(): diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index b69591cc..90b8b4db 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -12,6 +12,7 @@ from torch.utils.data import Dataset, IterableDataset from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.envs.graph_building_env import GraphActionCategorical class SamplingIterator(IterableDataset): @@ -98,6 +99,7 @@ def __init__( self.random_action_prob = random_action_prob self.hindsight_ratio = hindsight_ratio self.train_it = init_train_iter + self.do_validate_batch = False # Turn this on for debugging self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") # TODO: make this a proper flag # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) @@ -130,6 +132,9 @@ def _idx_iterator(self): if n == 0: yield np.arange(0, 0) return + assert ( + self.offline_batch_size > 0 + ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" if worker_info is None: # no multi-processing start, end, wid = 0, n, -1 else: # split the data into chunks (per-worker) @@ -237,6 +242,7 @@ def __iter__(self): log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward # Computes some metrics + extra_info = {} if not self.sample_cond_info: # If we're using a dataset of preferences, the user may want to know the id of the preference for i, j in zip(trajs, idcs): @@ -253,7 +259,6 @@ def __iter__(self): {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, ) if num_online > 0: - extra_info = {} for hook in self.log_hooks: extra_info.update( hook( @@ -318,9 +323,37 @@ def __iter__(self): # TODO: we could very well just pass the cond_info dict to construct_batch above, # and the algo can decide what it wants to put in the batch object + # Only activate for debugging your environment or dataset (e.g. the dataset could be + # generating trajectories with illegal actions) + if self.do_validate_batch: + self.validate_batch(batch, trajs) + self.train_it += worker_info.num_workers if worker_info is not None else 1 yield batch + def validate_batch(self, batch, trajs): + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [self.model._action_type_to_mask(t, batch) for t in atypes], + [self.model._action_type_to_key[t] for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) + def log_generated(self, trajs, rewards, flat_rewards, cond_info): if self.log_molecule_smis: mols = [ diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 44c99e4a..925e3456 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -28,6 +28,7 @@ def __init__( charges=[0, 1, -1], expl_H_range=[0, 1], allow_explicitly_aromatic=False, + allow_5_valence_nitrogen=False, num_rw_feat=8, max_nodes=None, max_edges=None, @@ -118,9 +119,11 @@ def __init__( BondType.AROMATIC: 1.5, } pt = Chem.GetPeriodicTable() + self.allow_5_valence_nitrogen = allow_5_valence_nitrogen self._max_atom_valence = { **{a: max(pt.GetValenceList(a)) for a in atoms}, - "N": 3, # We'll handle nitrogen valence later explicitly in graph_to_Data + # We'll handle nitrogen valence later explicitly in graph_to_Data + "N": 3 if not allow_5_valence_nitrogen else 5, "*": 0, # wildcard atoms have 0 valence until filled in } @@ -231,9 +234,10 @@ def graph_to_Data(self, g: Graph) -> gd.Data: max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]] # Special rule for Nitrogen if ad["v"] == "N" and ad.get("charge", 0) == 1: - # This is definitely a heuristic, but to keep things simple we'll limit Nitrogen's valence to 3 (as + # This is definitely a heuristic, but to keep things simple we'll limit* Nitrogen's valence to 3 (as # per self._max_atom_valence) unless it is charged, then we make it 5. # This keeps RDKit happy (and is probably a good idea anyway). + # (* unless allow_5_valence_nitrogen is True, then it's just always 5) max_atom_valence = 5 max_valence[n] = max_atom_valence - abs(ad.get("charge", 0)) - ad.get("expl_H", 0) # Compute explicitly defined valence: @@ -281,14 +285,15 @@ def is_ok_non_edge(e): non_edge_index = torch.zeros((2, 0), dtype=torch.long) else: gc = nx.complement(g) - non_edge_index = torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).T.reshape( - (2, -1) + non_edge_index = ( + torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).reshape((-1, 2)).T ) data = gd.Data( x, edge_index, edge_attr, non_edge_index=non_edge_index, + stop_mask=torch.ones(1, 1) if len(g) > 0 else torch.zeros(1, 1), add_node_mask=add_node_mask, set_node_attr_mask=set_node_attr_mask, add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), # Already filtered by is_ok_non_edge diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 097e7e91..05f9b0e4 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -145,6 +145,27 @@ class GraphTransformerGFN(nn.Module): Outputs logits corresponding to the action types used by the env_ctx argument. """ + # The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the + # embeddings to the right MLP + _action_type_to_graph_part = { + GraphActionType.Stop: "graph", + GraphActionType.AddNode: "node", + GraphActionType.SetNodeAttr: "node", + GraphActionType.AddEdge: "non_edge", + GraphActionType.SetEdgeAttr: "edge", + GraphActionType.RemoveNode: "node", + GraphActionType.RemoveNodeAttr: "node", + GraphActionType.RemoveEdge: "edge", + GraphActionType.RemoveEdgeAttr: "edge", + } + # The torch_geometric batch key each graph part corresponds to + _graph_part_to_key = { + "graph": None, + "node": "x", + "non_edge": "non_edge_index", + "edge": "edge_index", + } + def __init__( self, env_ctx, @@ -184,25 +205,8 @@ def __init__( GraphActionType.RemoveEdge: (num_edge_feat, 1), GraphActionType.RemoveEdgeAttr: (num_edge_feat, env_ctx.num_edge_attrs), } - # The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the - # embeddings to the right MLP - self._action_type_to_graph_part = { - GraphActionType.Stop: "graph", - GraphActionType.AddNode: "node", - GraphActionType.SetNodeAttr: "node", - GraphActionType.AddEdge: "non_edge", - GraphActionType.SetEdgeAttr: "edge", - GraphActionType.RemoveNode: "node", - GraphActionType.RemoveNodeAttr: "node", - GraphActionType.RemoveEdge: "edge", - GraphActionType.RemoveEdgeAttr: "edge", - } - # The torch_geometric batch key each graph part corresponds to - self._graph_part_to_key = { - "graph": None, - "node": "x", - "non_edge": "non_edge_index", - "edge": "edge_index", + self._action_type_to_key = { + at: self._graph_part_to_key[self._action_type_to_graph_part[at]] for at in self._action_type_to_graph_part } # Here we create only the embedding -> logit mapping MLPs that are required by the environment @@ -229,13 +233,14 @@ def _action_type_to_logit(self, t, emb, g): def _mask(self, x, m): # mask logit vector x with binary mask m, -1000 is a tiny log-value + # Note to self: we can't use torch.inf here, because inf * 0 is nan (but also see issue #99) return x * m + -1000 * (1 - m) def _make_cat(self, g, emb, action_types): return GraphActionCategorical( g, logits=[self._action_type_to_logit(t, emb, g) for t in action_types], - keys=[self._graph_part_to_key[self._action_type_to_graph_part[t]] for t in action_types], + keys=[self._action_type_to_key[t] for t in action_types], masks=[self._action_type_to_mask(t, g) for t in action_types], types=action_types, ) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py new file mode 100644 index 00000000..98791be5 --- /dev/null +++ b/src/gflownet/online_trainer.py @@ -0,0 +1,107 @@ +import copy +import os +import pathlib + +import git +import torch +from omegaconf import OmegaConf +from torch import Tensor + +from gflownet.algo.advantage_actor_critic import A2C +from gflownet.algo.flow_matching import FlowMatching +from gflownet.algo.soft_q_learning import SoftQLearning +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.models.graph_transformer import GraphTransformerGFN + +from .trainer import GFNTrainer + + +class StandardOnlineTrainer(GFNTrainer): + def setup_model(self): + self.model = GraphTransformerGFN( + self.ctx, + self.cfg, + do_bck=self.cfg.algo.tb.do_parameterize_p_b, + ) + + def setup_algo(self): + algo = self.cfg.algo.method + if algo == "TB": + algo = TrajectoryBalance + elif algo == "FM": + algo = FlowMatching + elif algo == "A2C": + algo = A2C + elif algo == "SQL": + algo = SoftQLearning + else: + raise ValueError(algo) + self.algo = algo(self.env, self.ctx, self.rng, self.cfg) + + def setup_data(self): + self.training_data = [] + self.test_data = [] + + def setup(self): + super().setup() + self.offline_ratio = 0 + self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None + + # Separate Z parameters from non-Z to allow for LR decay on the former + if hasattr(self.model, "logZ"): + Z_params = list(self.model.logZ.parameters()) + non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] + else: + Z_params = [] + non_Z_params = list(self.model.parameters()) + self.opt = torch.optim.Adam( + non_Z_params, + self.cfg.opt.learning_rate, + (self.cfg.opt.momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) + self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) + ) + + self.sampling_tau = self.cfg.algo.sampling_tau + if self.sampling_tau > 0: + self.sampling_model = copy.deepcopy(self.model) + else: + self.sampling_model = self.model + + self.mb_size = self.cfg.algo.global_batch_size + self.clip_grad_callback = { + "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), + "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), + "none": lambda x: None, + }[self.cfg.opt.clip_grad_type] + + # saving hyperparameters + git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] + self.cfg.git_hash = git_hash + + os.makedirs(self.cfg.log_dir, exist_ok=True) + print("\n\nHyperparameters:\n") + yaml = OmegaConf.to_yaml(self.cfg) + print(yaml) + with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: + f.write(yaml) + + def step(self, loss: Tensor): + loss.backward() + for i in self.model.parameters(): + self.clip_grad_callback(i) + self.opt.step() + self.opt.zero_grad() + self.opt_Z.step() + self.opt_Z.zero_grad() + self.lr_sched.step() + self.lr_sched_Z.step() + if self.sampling_tau > 0: + for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): + b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index bc0e31d0..a9f6ac3f 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -1,30 +1,10 @@ from dataclasses import dataclass, field -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple @dataclass class SEHTaskConfig: - """Config for the SEHTask - - Attributes - ---------- - - temperature_sample_dist : str - The distribution to sample the inverse temperature from. Can be one of: - - "uniform": uniform distribution - - "loguniform": log-uniform distribution - - "gamma": gamma distribution - - "constant": constant temperature - temperature_dist_params : List[Any] - The parameters of the temperature distribution. E.g. for the "uniform" distribution, this is the range. - num_thermometer_dim : int - The number of thermometer encoding dimensions to use. - """ - - # TODO: a proper class for temperature-conditional sampling - temperature_sample_dist: str = "uniform" - temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) - num_thermometer_dim: int = 32 + pass # SEH just uses a temperature conditional @dataclass @@ -57,10 +37,6 @@ class SEHMOOTaskConfig: The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"]. """ - # TODO: a proper class for temperature-conditional sampling - temperature_sample_dist: str = "uniform" - temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) - num_thermometer_dim: int = 32 use_steer_thermometer: bool = False preference_type: Optional[str] = "dirichlet" focus_type: Optional[str] = None @@ -72,16 +48,13 @@ class SEHMOOTaskConfig: max_train_it: Optional[int] = None n_valid: int = 15 n_valid_repeats: int = 128 - objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "wt"]) + objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) @dataclass class QM9TaskConfig: - # TODO: a proper class for temperature-conditional sampling - temperature_sample_dist: str = "uniform" - temperature_dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) - num_thermometer_dim: int = 32 - h5_path = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py + h5_path: str = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py + model_path: str = "./data/qm9/qm9_model.pt" @dataclass diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 5405d867..e5b1d29a 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,27 +1,22 @@ -import copy import os from typing import Callable, Dict, List, Tuple, Union import numpy as np -import scipy.stats as stats import torch import torch.nn as nn import torch_geometric.data as gd -from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from ruamel.yaml import YAML from torch import Tensor from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet -from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset -from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.envs.mol_building_env import MolBuildingEnvContext -from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar -from gflownet.utils.transforms import thermometer +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from gflownet.utils.conditioning import TemperatureConditional class QM9GapTask(GFNTask): @@ -30,19 +25,15 @@ class QM9GapTask(GFNTask): def __init__( self, dataset: Dataset, - temperature_distribution: str, - temperature_parameters: List[float], - num_thermometer_dim: int, + cfg: Config, rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): self._wrap_model = wrap_model self.rng = rng - self.models = self.load_task_models() + self.models = self.load_task_models(cfg.task.qm9.model_path) self.dataset = dataset - self.temperature_sample_dist = temperature_distribution - self.temperature_dist_params = temperature_parameters - self.num_thermometer_dim = num_thermometer_dim + self.temperature_conditional = TemperatureConditional(cfg, rng) # TODO: fix interface self._min, self._max, self._percentile_95 = self.dataset.get_stats(percentile=0.05) # type: ignore self._width = self._max - self._min @@ -69,51 +60,20 @@ def inverse_flat_reward_transform(self, rp): elif self._rtrans == "unit+95p": return (1 - rp + (1 - self._percentile_95)) * self._width + self._min - def load_task_models(self): + def load_task_models(self, path): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? - state_dict = torch.load("/data/chem/qm9/mxmnet_gap_model.pt") + state_dict = torch.load(path) gap_model.load_state_dict(state_dict) gap_model.cuda() gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) return {"mxmnet_gap": gap_model} - def sample_conditional_information(self, n: int) -> Dict[str, Tensor]: - beta = None - if self.temperature_sample_dist == "constant": - assert type(self.temperature_dist_params) in [float, int] - beta = np.array(self.temperature_dist_params).repeat(n).astype(np.float32) - beta_enc = torch.zeros((n, self.num_thermometer_dim)) - else: - if self.temperature_sample_dist == "gamma": - loc, scale = self.temperature_dist_params - beta = self.rng.gamma(loc, scale, n).astype(np.float32) - upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) - elif self.temperature_sample_dist == "uniform": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.uniform(a, b, n).astype(np.float32) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "loguniform": - low, high = np.log(self.temperature_dist_params) - beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "beta": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.beta(a, b, n).astype(np.float32) - upper_bound = 1 - beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) - - assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": torch.tensor(beta), "encoding": beta_enc} + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - if isinstance(flat_reward, list): - flat_reward = torch.tensor(flat_reward) - scalar_logreward = flat_reward.squeeze().clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return RewardScalar(scalar_logreward * cond_info["beta"]) + return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] @@ -128,7 +88,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(preds), is_valid -class QM9GapTrainer(GFNTrainer): +class QM9GapTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.num_workers = 8 cfg.num_training_steps = 100000 @@ -139,86 +99,48 @@ def set_default_hps(self, cfg: Config): cfg.opt.lr_decay = 20000 cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 + cfg.algo.max_nodes = 9 cfg.algo.global_batch_size = 64 cfg.algo.train_random_action_prob = 0.001 cfg.algo.illegal_action_logreward = -75 cfg.algo.sampling_tau = 0.0 cfg.model.num_emb = 128 cfg.model.num_layers = 4 - cfg.task.qm9.temperature_sample_dist = "uniform" - cfg.task.qm9.temperature_dist_params = [0.5, 32.0] - cfg.task.qm9.num_thermometer_dim = 32 - - def setup(self): - RDLogger.DisableLog("rdApp.*") - self.rng = np.random.default_rng(142857) - self.env = GraphBuildingEnv() - self.ctx = MolBuildingEnvContext(["H", "C", "N", "F", "O"], num_cond_dim=32) - self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") - self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") + cfg.cond.temperature.sample_dist = "uniform" + cfg.cond.temperature.dist_params = [0.5, 32.0] + cfg.cond.temperature.num_thermometer_dim = 32 - model = GraphTransformerGFN(self.ctx, self.cfg) - self.model = model - # Separate Z parameters from non-Z to allow for LR decay on the former - Z_params = list(model.logZ.parameters()) - non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] - self.opt = torch.optim.Adam( - non_Z_params, - self.cfg.opt.learning_rate, - (self.cfg.opt.momentum, 0.999), - weight_decay=self.cfg.opt.weight_decay, - eps=self.cfg.opt.adam_eps, - ) - self.opt_Z = torch.optim.Adam(Z_params, self.cfg.opt.learning_rate, (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay) + def setup_env_context(self): + self.ctx = MolBuildingEnvContext( + ["C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True ) + # Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories + # from the dataset. For example, consider tue Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action + # for setting the explicit hydrogen is used before the positive charge is set, it will be considered + # an invalid action. However, generate_forward_trajectory does not consider this implementation detail, + # it assumes that attribute-setting will always be valid. For the molecular environment, as of writing + # (PR #98) this edge case is the only case where the ordering in which attributes are set can matter. + + def setup_data(self): + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") - self.sampling_tau = self.cfg.algo.sampling_tau - if self.sampling_tau > 0: - self.sampling_model = copy.deepcopy(model) - else: - self.sampling_model = self.model - self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg) - + def setup_task(self): self.task = QM9GapTask( dataset=self.training_data, - temperature_distribution=self.cfg.task.qm9.temperature_sample_dist, - temperature_parameters=self.cfg.task.qm9.temperature_dist_params, - num_thermometer_dim=self.cfg.task.qm9.num_thermometer_dim, + cfg=self.cfg, + rng=self.rng, wrap_model=self._wrap_for_mp, ) - self.mb_size = self.cfg.algo.global_batch_size - self.clip_grad_param = self.cfg.opt.clip_grad_param - self.clip_grad_callback = { - "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param), - "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param), - "none": lambda x: None, - }[self.cfg.opt.clip_grad_type] - - def step(self, loss: Tensor): - loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) - self.opt.step() - self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() - self.lr_sched.step() - self.lr_sched_Z.step() - if self.sampling_tau > 0: - for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): - b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) def main(): - """Example of how this model can be run outside of Determined""" + """Example of how this model can be run.""" yaml = YAML(typ="safe", pure=True) config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.yaml") with open(config_file, "r") as f: hps = yaml.load(f) - trial = QM9GapTrainer(hps, torch.device("cpu")) + trial = QM9GapTrainer(hps) trial.run() diff --git a/src/gflownet/tasks/qm9/qm9.yaml b/src/gflownet/tasks/qm9/qm9.yaml index 01ea17f4..19701fac 100644 --- a/src/gflownet/tasks/qm9/qm9.yaml +++ b/src/gflownet/tasks/qm9/qm9.yaml @@ -1,5 +1,10 @@ -lr_decay: 10000 -qm9_h5_path: /data/chem/qm9/qm9.h5 -log_dir: /scratch/logs/qm9_gap_mxmnet +opt: + lr_decay: 10000 +task: + qm9: + h5_path: /rxrx/data/chem/qm9/qm9.h5 + model_path: /rxrx/data/chem/qm9/mxmnet_gap_model.pt num_training_steps: 100000 validate_every: 100 +log_dir: ./logs/debug_qm9 +num_workers: 0 diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 10fa142c..4d0cc624 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,34 +1,22 @@ -import copy import os -import pathlib import shutil import socket from typing import Callable, Dict, List, Tuple, Union -import git import numpy as np -import scipy.stats as stats import torch import torch.nn as nn import torch_geometric.data as gd -from omegaconf import OmegaConf -from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset -from gflownet.algo.advantage_actor_critic import A2C -from gflownet.algo.flow_matching import FlowMatching -from gflownet.algo.soft_q_learning import SoftQLearning -from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config -from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext -from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.models import bengio2021flow -from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.train import FlatRewards, GFNTask, GFNTrainer, RewardScalar -from gflownet.utils.transforms import thermometer +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from gflownet.utils.conditioning import TemperatureConditional class SEHTask(GFNTask): @@ -52,9 +40,8 @@ def __init__( self.rng = rng self.models = self._load_task_models() self.dataset = dataset - self.temperature_sample_dist = cfg.task.seh.temperature_sample_dist - self.temperature_dist_params = cfg.task.seh.temperature_dist_params - self.num_thermometer_dim = cfg.task.seh.num_thermometer_dim + self.temperature_conditional = TemperatureConditional(cfg, rng) + self.num_cond_dim = self.temperature_conditional.encoding_size() def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -68,41 +55,10 @@ def _load_task_models(self): return {"seh": model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: - beta = None - if self.temperature_sample_dist == "constant": - assert type(self.temperature_dist_params) is float - beta = np.array(self.temperature_dist_params).repeat(n).astype(np.float32) - beta_enc = torch.zeros((n, self.num_thermometer_dim)) - else: - if self.temperature_sample_dist == "gamma": - loc, scale = self.temperature_dist_params - beta = self.rng.gamma(loc, scale, n).astype(np.float32) - upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) - elif self.temperature_sample_dist == "uniform": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.uniform(a, b, n).astype(np.float32) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "loguniform": - low, high = np.log(self.temperature_dist_params) - beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) - upper_bound = self.temperature_dist_params[1] - elif self.temperature_sample_dist == "beta": - a, b = float(self.temperature_dist_params[0]), float(self.temperature_dist_params[1]) - beta = self.rng.beta(a, b, n).astype(np.float32) - upper_bound = 1 - beta_enc = thermometer(torch.tensor(beta), self.num_thermometer_dim, 0, upper_bound) - - assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" - return {"beta": torch.tensor(beta), "encoding": beta_enc} + return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - if isinstance(flat_reward, list): - flat_reward = torch.tensor(flat_reward) - scalar_logreward = flat_reward.squeeze().clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return RewardScalar(scalar_logreward * cond_info["beta"]) + return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -117,7 +73,9 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: return FlatRewards(preds), is_valid -class SEHFragTrainer(GFNTrainer): +class SEHFragTrainer(StandardOnlineTrainer): + task: SEHTask + def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False @@ -140,6 +98,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.illegal_action_logreward = -75 cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 + cfg.algo.valid_offline_ratio = 0 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False cfg.algo.tb.Z_learning_rate = 1e-3 @@ -150,20 +109,6 @@ def set_default_hps(self, cfg: Config): cfg.replay.capacity = 10_000 cfg.replay.warmup = 1_000 - def setup_algo(self): - algo = self.cfg.algo.method - if algo == "TB": - algo = TrajectoryBalance - elif algo == "FM": - algo = FlowMatching - elif algo == "A2C": - algo = A2C - elif algo == "SQL": - algo = SoftQLearning - else: - raise ValueError(algo) - self.algo = algo(self.env, self.ctx, self.rng, self.cfg) - def setup_task(self): self.task = SEHTask( dataset=self.training_data, @@ -172,98 +117,30 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) - def setup_model(self): - model = GraphTransformerGFN( - self.ctx, - self.cfg, - do_bck=self.cfg.algo.tb.do_parameterize_p_b, - ) - self.model = model - def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext( - max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.cfg.task.seh.num_thermometer_dim - ) - - def setup(self): - RDLogger.DisableLog("rdApp.*") - self.rng = np.random.default_rng(142857) - self.env = GraphBuildingEnv() - self.training_data = [] - self.test_data = [] - self.offline_ratio = 0 - self.valid_offline_ratio = 0 - self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None - self.setup_env_context() - self.setup_algo() - self.setup_task() - self.setup_model() - - # Separate Z parameters from non-Z to allow for LR decay on the former - Z_params = list(self.model.logZ.parameters()) - non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] - self.opt = torch.optim.Adam( - non_Z_params, - self.cfg.opt.learning_rate, - (self.cfg.opt.momentum, 0.999), - weight_decay=self.cfg.opt.weight_decay, - eps=self.cfg.opt.adam_eps, - ) - self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) - self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) - ) - - self.sampling_tau = self.cfg.algo.sampling_tau - if self.sampling_tau > 0: - self.sampling_model = copy.deepcopy(self.model) - else: - self.sampling_model = self.model - - self.mb_size = self.cfg.algo.global_batch_size - self.clip_grad_callback = { - "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), - "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), - "none": lambda x: None, - }[self.cfg.opt.clip_grad_type] - - # saving hyperparameters - git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] - self.cfg.git_hash = git_hash - - os.makedirs(self.cfg.log_dir, exist_ok=True) - print("\n\nHyperparameters:\n") - yaml = OmegaConf.to_yaml(self.cfg) - print(yaml) - with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: - f.write(yaml) - - def step(self, loss: Tensor): - loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) - self.opt.step() - self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() - self.lr_sched.step() - self.lr_sched_Z.step() - if self.sampling_tau > 0: - for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): - b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) def main(): """Example of how this model can be run outside of Determined""" hps = { - "log_dir": "./logs/debug_run", + "log_dir": "./logs/debug_run_seh_frag", + "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), "overwrite_existing_exp": True, "num_training_steps": 10_000, "num_workers": 8, - "opt.lr_decay": 20000, - "algo.sampling_tau": 0.99, - "task.seh.temperature_dist_params": (0.0, 64.0), + "opt": { + "lr_decay": 20000, + }, + "algo": { + "sampling_tau": 0.99, + }, + "cond": { + "temperature": { + "sample_dist": "uniform", + "dist_params": [0, 64.0], + } + }, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: @@ -272,7 +149,7 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) - trial = SEHFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) + trial = SEHFragTrainer(hps) trial.print_every = 1 trial.run() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index c515bc6e..8ad73320 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,8 +1,7 @@ import os import pathlib import shutil -from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np import torch @@ -11,7 +10,6 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.distributions.dirichlet import Dirichlet from torch.utils.data import Dataset from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL @@ -20,11 +18,10 @@ from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask -from gflownet.train import FlatRewards, RewardScalar +from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore -from gflownet.utils.focus_model import FocusModel, TabularFocusModel +from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook -from gflownet.utils.transforms import thermometer class SEHMOOTask(SEHTask): @@ -42,66 +39,27 @@ def __init__( dataset: Dataset, cfg: Config, rng: np.random.Generator = None, - focus_model: Optional[FocusModel] = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): - self._wrap_model = wrap_model + super().__init__(dataset, cfg, rng, wrap_model) self.cfg = cfg mcfg = self.cfg.task.seh_moo - self.rng = rng - self.models = self._load_task_models() self.objectives = cfg.task.seh_moo.objectives self.dataset = dataset - self.temperature_sample_dist = mcfg.temperature_sample_dist - self.temperature_dist_params = mcfg.temperature_dist_params - self.num_thermometer_dim = mcfg.num_thermometer_dim - self.use_steer_thermometer = mcfg.use_steer_thermometer - self.preference_type = mcfg.preference_type - self.seeded_preference = None - self.experimental_dirichlet = False - self.focus_type = mcfg.focus_type - self.focus_cosim = mcfg.focus_cosim - self.focus_limit_coef = mcfg.focus_limit_coef - self.focus_model = focus_model - self.illegal_action_logreward = cfg.algo.illegal_action_logreward - self.focus_model_training_limits = mcfg.focus_model_training_limits - self.max_train_it = mcfg.max_train_it - self.setup_focus_regions() - assert set(self.objectives) <= {"seh", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) - - def setup_focus_regions(self): - mcfg = self.cfg.task.seh_moo - n_valid = mcfg.n_valid - n_obj = len(self.objectives) - # focus regions - if mcfg.focus_type is None: - valid_focus_dirs = np.zeros((n_valid, n_obj)) - self.fixed_focus_dirs = valid_focus_dirs - elif mcfg.focus_type == "centered": - valid_focus_dirs = np.ones((n_valid, n_obj)) - self.fixed_focus_dirs = valid_focus_dirs - elif mcfg.focus_type == "partitioned": - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") - self.fixed_focus_dirs = valid_focus_dirs - elif mcfg.focus_type in ["dirichlet", "learned-gfn"]: - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") - self.fixed_focus_dirs = None - elif mcfg.focus_type in ["hyperspherical", "learned-tabular"]: - valid_focus_dirs = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l2") - self.fixed_focus_dirs = None - elif mcfg.focus_type == "listed": - if len(mcfg.focus_type) == 1: - valid_focus_dirs = np.array([mcfg.focus_dirs_listed[0]] * n_valid) - self.fixed_focus_dirs = valid_focus_dirs - else: - valid_focus_dirs = np.array(mcfg.focus_dirs_listed) - self.fixed_focus_dirs = valid_focus_dirs + if self.cfg.cond.focus_region.focus_type is not None: + self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) else: - raise NotImplementedError( - f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " - f"focus_type, but here: {mcfg.focus_type}" - ) - self.valid_focus_dirs = valid_focus_dirs + self.focus_cond = None + self.pref_cond = MultiObjectiveWeightedPreferences(self.cfg) + self.temperature_sample_dist = cfg.cond.temperature.sample_dist + self.temperature_dist_params = cfg.cond.temperature.dist_params + self.num_thermometer_dim = cfg.cond.temperature.num_thermometer_dim + self.num_cond_dim = ( + self.temperature_conditional.encoding_size() + + self.pref_cond.encoding_size() + + (self.focus_cond.encoding_size() if self.focus_cond is not None else 0) + ) + assert set(self.objectives) <= {"seh", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y)) @@ -109,62 +67,18 @@ def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: def inverse_flat_reward_transform(self, rp): return rp - def _load_task_models(self): - model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model, send_to_device=True) - return {"seh": model} - - def get_steer_encodings(self, preferences, focus_dirs): - n = len(preferences) - if self.use_steer_thermometer: - pref_enc = thermometer(preferences, self.num_thermometer_dim, 0, 1).reshape(n, -1) - focus_enc = thermometer(focus_dirs, self.num_thermometer_dim, 0, 1).reshape(n, -1) - else: - pref_enc = preferences - focus_enc = focus_dirs - return pref_enc, focus_enc - def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: cond_info = super().sample_conditional_information(n, train_it) - - if self.preference_type is None: - preferences = torch.ones((n, len(self.objectives))) - else: - if self.seeded_preference is not None: - preferences = torch.tensor([self.seeded_preference] * n).float() - elif self.experimental_dirichlet: - a = np.random.dirichlet([1] * len(self.objectives), n) - b = np.random.exponential(1, n)[:, None] - preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() - else: - m = Dirichlet(torch.FloatTensor([1.0] * len(self.objectives))) - preferences = m.sample([n]) - - if self.fixed_focus_dirs is not None: - focus_dir = torch.tensor( - np.array(self.fixed_focus_dirs)[self.rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) - ) - elif self.focus_type == "dirichlet": - m = Dirichlet(torch.FloatTensor([1.0] * len(self.objectives))) - focus_dir = m.sample([n]) - elif self.focus_type == "hyperspherical": - focus_dir = torch.tensor( - metrics.sample_positiveQuadrant_ndim_sphere(n, len(self.objectives), normalisation="l2") - ).float() - elif self.focus_type is not None and "learned" in self.focus_type: - if self.focus_model is not None and train_it >= self.focus_model_training_limits[0] * self.max_train_it: - focus_dir = self.focus_model.sample_focus_directions(n) - else: - focus_dir = torch.tensor( - metrics.sample_positiveQuadrant_ndim_sphere(n, len(self.objectives), normalisation="l2") - ).float() - else: - raise NotImplementedError(f"Unsupported focus_type={type(self.focus_type)}") - - preferences_enc, focus_enc = self.get_steer_encodings(preferences, focus_dir) - cond_info["encoding"] = torch.cat([cond_info["encoding"], preferences_enc, focus_enc], 1) - cond_info["preferences"] = preferences - cond_info["focus_dir"] = focus_dir + pref_ci = self.pref_cond.sample(n) + focus_ci = ( + self.focus_cond.sample(n, train_it) if self.focus_cond is not None else {"encoding": torch.zeros(n, 0)} + ) + cond_info = { + **cond_info, + **pref_ci, + **focus_ci, + "encoding": torch.cat([cond_info["encoding"], pref_ci["encoding"], focus_ci["encoding"]], dim=1), + } return cond_info def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor]: @@ -179,7 +93,7 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor """ n = len(steer_info) if self.temperature_sample_dist == "constant": - beta = torch.ones(n) * self.temperature_dist_params + beta = torch.ones(n) * self.temperature_dist_params[0] beta_enc = torch.zeros((n, self.num_thermometer_dim)) else: beta = torch.ones(n) * self.temperature_dist_params[-1] @@ -191,9 +105,12 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor preferences = steer_info[:, : len(self.objectives)].float() focus_dir = steer_info[:, len(self.objectives) :].float() - preferences_enc, focus_enc = self.get_steer_encodings(preferences, focus_dir) - encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() - + preferences_enc = self.pref_cond.encode(preferences) + if self.focus_cond is not None: + focus_enc = self.focus_cond.encode(focus_dir) + encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() + else: + encoding = torch.cat([beta_enc, preferences_enc], 1).float() return { "beta": beta, "encoding": encoding, @@ -204,17 +121,23 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor def relabel_condinfo_and_logrewards( self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor ): - if self.focus_type is None: + # TODO: we seem to be relabeling tensors in place, could that cause a problem? + if self.focus_cond is None: + raise NotImplementedError("Hindsight relabeling only implemented for focus conditioning") + if self.focus_cond.cfg.focus_type is None: return cond_info, log_rewards # only keep hindsight_idxs that actually correspond to a violated constraint - _, in_focus_mask = metrics.compute_focus_coef(flat_rewards, cond_info["focus_dir"], self.focus_cosim) + _, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.focus_cond.cfg.focus_cosim + ) out_focus_mask = torch.logical_not(in_focus_mask) hindsight_idxs = hindsight_idxs[out_focus_mask[hindsight_idxs]] # relabels the focus_dirs and log_rewards cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(flat_rewards[hindsight_idxs], dim=1) - preferences_enc, focus_enc = self.get_steer_encodings(cond_info["preferences"], cond_info["focus_dir"]) + preferences_enc = self.pref_cond.encode(cond_info["preferences"]) + focus_enc = self.focus_cond.encode(cond_info["focus_dir"]) cond_info["encoding"] = torch.cat( [cond_info["encoding"][:, : self.num_thermometer_dim], preferences_enc, focus_enc], 1 ) @@ -228,19 +151,15 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat flat_reward = torch.stack(flat_reward) else: flat_reward = torch.tensor(flat_reward) - scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( - cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - - if self.focus_type is not None: - focus_coef, in_focus_mask = metrics.compute_focus_coef( - flat_reward, cond_info["focus_dir"], self.focus_cosim, self.focus_limit_coef - ) - scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) - scalar_logreward[~in_focus_mask] = self.illegal_action_logreward - return RewardScalar(scalar_logreward * cond_info["beta"]) + scalarized_reward = self.pref_cond.transform(cond_info, flat_reward) + focused_reward = ( + self.focus_cond.transform(cond_info, flat_reward, scalarized_reward) + if self.focus_cond is not None + else scalarized_reward + ) + tempered_reward = self.temperature_conditional.transform(cond_info, focused_reward) + return RewardScalar(tempered_reward) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -283,11 +202,15 @@ def safe(f, x, default): class SEHMOOFragTrainer(SEHFragTrainer): task: SEHMOOTask + ctx: FragMolBuildingEnvContext def set_default_hps(self, cfg: Config): super().set_default_hps(cfg) cfg.algo.sampling_tau = 0.95 + # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) + # sampling and set the offline ratio to 1 cfg.algo.valid_sample_cond_info = False + cfg.algo.valid_offline_ratio = 1 def setup_algo(self): algo = self.cfg.algo.method @@ -298,35 +221,16 @@ def setup_algo(self): else: super().setup_algo() - focus_type = self.cfg.task.seh_moo.focus_type - if focus_type is not None and "learned" in focus_type: - if focus_type == "learned-tabular": - self.focus_model = TabularFocusModel( - device=self.device, - n_objectives=len(self.cfg.task.seh_moo.objectives), - state_space_res=self.cfg.task.seh_moo.focus_model_state_space_res, - ) - else: - raise NotImplementedError("Unknown focus model type {self.focus_type}") - else: - self.focus_model = None - def setup_task(self): self.task = SEHMOOTask( dataset=self.training_data, cfg=self.cfg, - focus_model=self.focus_model, rng=self.rng, wrap_model=self._wrap_for_mp, ) def setup_env_context(self): - if self.cfg.task.seh_moo.use_steer_thermometer: - ncd = self.cfg.task.seh_moo.num_thermometer_dim * (1 + 2 * len(self.cfg.task.seh_moo.objectives)) - else: - # 1 for prefs and 1 for focus region - ncd = self.cfg.task.seh_moo.num_thermometer_dim + 2 * len(self.cfg.task.seh_moo.objectives) - self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=ncd) + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) def setup_model(self): if self.cfg.algo.method == "MOQL": @@ -391,16 +295,22 @@ def setup(self): # hps["fixed_focus_dirs"] = ( # np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None # ) - assert self.task.valid_focus_dirs.shape == ( - n_valid, - n_obj, - ), f"Invalid shape for valid_preferences, {self.task.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + if self.task.focus_cond is not None: + assert self.task.focus_cond.valid_focus_dirs.shape == ( + n_valid, + n_obj, + ), ( + "Invalid shape for valid_preferences, " + f"{self.task.focus_cond.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + ) - # combine preferences and focus directions (fixed focus cosim) since they could be used together (not either/or) - # TODO: this relies on positional assumptions, should have something cleaner - valid_cond_vector = np.concatenate([valid_preferences, self.task.valid_focus_dirs], axis=1) + # combine preferences and focus directions (fixed focus cosim) since they could be used together + # (not either/or). TODO: this relies on positional assumptions, should have something cleaner + valid_cond_vector = np.concatenate([valid_preferences, self.task.focus_cond.valid_focus_dirs], axis=1) + else: + valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, len(valid_cond_vector)) + self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid) self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) @@ -421,19 +331,13 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: - focus_model_training_limits = self.cfg.task.seh_moo.focus_model_training_limits - max_train_it = self.cfg.num_training_steps - if ( - self.focus_model is not None - and train_it >= focus_model_training_limits[0] * max_train_it - and train_it <= focus_model_training_limits[1] * max_train_it - ): - self.focus_model.update_belief(deepcopy(batch.focus_dir), deepcopy(batch.flat_rewards)) + if self.task.focus_cond is not None: + self.task.focus_cond.step_focus_model(batch, train_it) return super().train_batch(batch, epoch_idx, batch_idx, train_it) def _save_state(self, it): - if self.focus_model is not None: - self.focus_model.save(pathlib.Path(self.cfg.log_dir)) + if self.task.focus_cond is not None and self.task.focus_cond.focus_model is not None: + self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) @@ -451,43 +355,65 @@ def __getitem__(self, idx): def main(): - """Example of how this model can be run outside of Determined""" + """Example of how this model can be run.""" hps = { - "log_dir": "./logs/debug_run", + "log_dir": "./logs/debug_run_sfm", + "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), "pickle_mp_messages": True, "overwrite_existing_exp": True, "seed": 0, - "num_training_steps": 20_000, - "num_final_gen_steps": 500, - "validate_every": 500, + "num_training_steps": 500, + "num_final_gen_steps": 50, + "validate_every": 100, "num_workers": 0, - "algo.global_batch_size": 64, - "algo.method": "TB", - "model.num_layers": 2, - "model.num_emb": 256, - "task.seh_moo.objectives": ["seh", "qed"], - "opt.learning_rate": 1e-4, - "algo.tb.Z_learning_rate": 1e-3, - "opt.lr_decay": 20000, - "algo.tb.Z_lr_decay": 50000, - "algo.sampling_tau": 0.95, - "algo.train_random_action_prob": 0.01, - "task.seh_moo.temperature_sample_dist": "constant", - "task.seh_moo.temperature_dist_params": 60.0, - "task.seh_moo.num_thermometer_dim": 32, - "task.seh_moo.use_steer_thermometer": False, - "task.seh_moo.preference_type": None, - "task.seh_moo.focus_type": "learned-tabular", - "task.seh_moo.focus_cosim": 0.98, - "task.seh_moo.focus_limit_coef": 1e-1, - "task.seh_moo.n_valid": 15, - "task.seh_moo.n_valid_repeats": 128, - "replay.use": True, - "replay.warmup": 1000, - "replay.hindsight_ratio": 0.3, - "task.seh_moo.focus_model_training_limits": [0.25, 0.75], - "task.seh_moo.focus_model_state_space_res": 30, - "task.seh_moo.max_train_it": 20_000, + "algo": { + "global_batch_size": 64, + "method": "TB", + "sampling_tau": 0.95, + "train_random_action_prob": 0.01, + "tb": { + "Z_learning_rate": 1e-3, + "Z_lr_decay": 50000, + }, + }, + "model": { + "num_layers": 2, + "num_emb": 256, + }, + "task": { + "seh_moo": { + "objectives": ["seh", "qed"], + "n_valid": 15, + "n_valid_repeats": 128, + }, + }, + "opt": { + "learning_rate": 1e-4, + "lr_decay": 20000, + }, + "cond": { + "temperature": { + "sample_dist": "constant", + "dist_params": [60.0], + "num_thermometer_dim": 32, + }, + "weighted_prefs": { + "preference_type": "dirichlet", + }, + "focus_region": { + "focus_type": None, # "learned-tabular", + "focus_cosim": 0.98, + "focus_limit_coef": 1e-1, + "focus_model_training_limits": (0.25, 0.75), + "focus_model_state_space_res": 30, + "max_train_it": 5_000, + }, + }, + "replay": { + "use": False, + "warmup": 1000, + "hindsight_ratio": 0.0, + }, } if os.path.exists(hps["log_dir"]): if hps["overwrite_existing_exp"]: @@ -496,7 +422,7 @@ def main(): raise ValueError(f"Log dir {hps['log_dir']} already exists. Set overwrite_existing_exp=True to delete it.") os.makedirs(hps["log_dir"]) - trial = SEHMOOFragTrainer(hps, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) + trial = SEHMOOFragTrainer(hps) trial.print_every = 1 trial.run() diff --git a/src/gflownet/train.py b/src/gflownet/trainer.py similarity index 93% rename from src/gflownet/train.py rename to src/gflownet/trainer.py index bb9c6ef3..93e0e0a5 100644 --- a/src/gflownet/train.py +++ b/src/gflownet/trainer.py @@ -2,11 +2,13 @@ import pathlib from typing import Any, Callable, Dict, List, NewType, Optional, Tuple +import numpy as np import torch import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd from omegaconf import OmegaConf +from rdkit import RDLogger from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -88,7 +90,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class GFNTrainer: - def __init__(self, hps: Dict[str, Any], device: torch.device): + def __init__(self, hps: Dict[str, Any]): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters @@ -119,13 +121,12 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): # The final config is obtained by merging the three sources self.cfg: Config = OmegaConf.structured(Config()) self.set_default_hps(self.cfg) - OmegaConf.merge(self.cfg, hps) + # OmegaConf returns a fancy object but we can still pretend it's a Config instance + self.cfg = OmegaConf.merge(self.cfg, hps) # type: ignore - self.device = device - # idem, but from `self.test_data` during validation. - self.valid_offline_ratio = 1 + self.device = torch.device(self.cfg.device) # Print the loss every `self.print_every` iterations - self.print_every = 1000 + self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data self.sampling_hooks: List[Callable] = [] self.valid_sampling_hooks: List[Callable] = [] @@ -137,12 +138,34 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): def set_default_hps(self, base: Config): raise NotImplementedError() - def setup(self): + def setup_env_context(self): + raise NotImplementedError() + + def setup_task(self): + raise NotImplementedError() + + def setup_model(self): raise NotImplementedError() + def setup_algo(self): + raise NotImplementedError() + + def setup_data(self): + pass + def step(self, loss: Tensor): raise NotImplementedError() + def setup(self): + RDLogger.DisableLog("rdApp.*") + self.rng = np.random.default_rng(142857) + self.env = GraphBuildingEnv() + self.setup_data() + self.setup_task() + self.setup_env_context() + self.setup_algo() + self.setup_model() + def _wrap_for_mp(self, obj, send_to_device=False): """Wraps an object in a placeholder whose reference can be sent to a data worker process (only if the number of workers is non-zero).""" @@ -203,7 +226,7 @@ def build_validation_data_loader(self) -> DataLoader: dev, batch_size=self.cfg.algo.global_batch_size, illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - ratio=self.valid_offline_ratio, + ratio=self.cfg.algo.valid_offline_ratio, log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), sample_cond_info=self.cfg.algo.valid_sample_cond_info, stream=False, diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py new file mode 100644 index 00000000..5893ef3b --- /dev/null +++ b/src/gflownet/utils/conditioning.py @@ -0,0 +1,246 @@ +import abc +from copy import deepcopy +from typing import Dict + +import numpy as np +import torch +from scipy import stats +from torch import Tensor +from torch.distributions.dirichlet import Dirichlet +from torch_geometric import data as gd + +from gflownet.config import Config +from gflownet.utils import metrics +from gflownet.utils.focus_model import TabularFocusModel +from gflownet.utils.transforms import thermometer + + +class Conditional(abc.ABC): + def sample(self, n): + raise NotImplementedError() + + @abc.abstractmethod + def transform(self, cond_info: Dict[str, Tensor], properties: Tensor) -> Tensor: + raise NotImplementedError() + + def encoding_size(self): + raise NotImplementedError() + + def encode(self, conditional: Tensor) -> Tensor: + raise NotImplementedError() + + +class TemperatureConditional(Conditional): + def __init__(self, cfg: Config, rng: np.random.Generator): + self.cfg = cfg + tmp_cfg = self.cfg.cond.temperature + self.rng = rng + self.upper_bound = 1024 + if tmp_cfg.sample_dist == "gamma": + loc, scale = tmp_cfg.dist_params + self.upper_bound = stats.gamma.ppf(0.95, loc, scale=scale) + elif tmp_cfg.sample_dist == "uniform": + self.upper_bound = tmp_cfg.dist_params[1] + elif tmp_cfg.sample_dist == "loguniform": + self.upper_bound = tmp_cfg.dist_params[1] + elif tmp_cfg.sample_dist == "beta": + self.upper_bound = 1 + + def encoding_size(self): + return self.cfg.cond.temperature.num_thermometer_dim + + def sample(self, n): + cfg = self.cfg.cond.temperature + beta = None + if cfg.sample_dist == "constant": + assert type(cfg.dist_params[0]) is float + beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32) + beta_enc = torch.zeros((n, cfg.num_thermometer_dim)) + else: + if cfg.sample_dist == "gamma": + loc, scale = cfg.dist_params + beta = self.rng.gamma(loc, scale, n).astype(np.float32) + elif cfg.sample_dist == "uniform": + a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) + beta = self.rng.uniform(a, b, n).astype(np.float32) + elif cfg.sample_dist == "loguniform": + low, high = np.log(cfg.dist_params) + beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) + elif cfg.sample_dist == "beta": + a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) + beta = self.rng.beta(a, b, n).astype(np.float32) + beta_enc = thermometer(torch.tensor(beta), cfg.num_thermometer_dim, 0, self.upper_bound) + + assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" + return {"beta": torch.tensor(beta), "encoding": beta_enc} + + def transform(self, cond_info: Dict[str, Tensor], linear_reward: Tensor) -> Tensor: + scalar_logreward = linear_reward.squeeze().clamp(min=1e-30).log() + assert len(scalar_logreward.shape) == len( + cond_info["beta"].shape + ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" + return scalar_logreward * cond_info["beta"] + + def encode(self, conditional: Tensor) -> Tensor: + cfg = self.cfg.cond.temperature + if cfg.sample_dist == "constant": + return torch.zeros((conditional.shape[0], cfg.num_thermometer_dim)) + return thermometer(torch.tensor(conditional), cfg.num_thermometer_dim, 0, self.upper_bound) + + +class MultiObjectiveWeightedPreferences(Conditional): + def __init__(self, cfg: Config): + self.cfg = cfg.cond.weighted_prefs + self.num_objectives = cfg.cond.moo.num_objectives + self.num_thermometer_dim = cfg.cond.moo.num_thermometer_dim + if self.cfg.preference_type == "seeded": + self.seeded_prefs = np.random.default_rng(142857 + int(cfg.seed)).dirichlet([1] * self.num_objectives) + + def sample(self, n): + if self.cfg.preference_type is None: + preferences = torch.ones((n, self.num_objectives)) + elif self.cfg.preference_type == "seeded": + preferences = torch.tensor(self.seeded_prefs).float().repeat(n, 1) + elif self.cfg.preference_type == "dirichlet_exponential": + a = np.random.dirichlet([1] * self.num_objectives, n) + b = np.random.exponential(1, n)[:, None] + preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() + elif self.cfg.preference_type == "dirichlet": + m = Dirichlet(torch.FloatTensor([1.0] * self.num_objectives)) + preferences = m.sample([n]) + else: + raise ValueError(f"Unknown preference type {self.cfg.preference_type}") + preferences = torch.as_tensor(preferences).float() + return {"preferences": preferences, "encoding": self.encode(preferences)} + + def transform(self, cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor: + scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() + assert len(scalar_logreward.shape) == 1, f"scalar_logreward should be a 1D array, got {scalar_logreward.shape}" + return scalar_logreward + + def encoding_size(self): + return max(1, self.num_thermometer_dim * self.num_objectives) + + def encode(self, conditional: Tensor) -> Tensor: + if self.num_thermometer_dim > 0: + return thermometer(conditional, self.num_thermometer_dim, 0, 1).reshape(conditional.shape[0], -1) + else: + return conditional.unsqueeze(1) + + +class FocusRegionConditional(Conditional): + def __init__(self, cfg: Config, n_valid: int, rng: np.random.Generator): + self.cfg = cfg.cond.focus_region + self.n_valid = n_valid + self.n_objectives = cfg.cond.moo.num_objectives + self.ocfg = cfg + self.rng = rng + self.num_thermometer_dim = cfg.cond.moo.num_thermometer_dim if self.cfg.use_steer_thermomether else 0 + + focus_type = self.cfg.focus_type + if focus_type is not None and "learned" in focus_type: + if focus_type == "learned-tabular": + self.focus_model = TabularFocusModel( + # TODO: proper device propagation + device=torch.device("cpu"), + n_objectives=cfg.cond.moo.num_objectives, + state_space_res=self.cfg.focus_model_state_space_res, + ) + else: + raise NotImplementedError("Unknown focus model type {self.focus_type}") + else: + self.focus_model = None + self.setup_focus_regions() + + def encoding_size(self): + if self.num_thermometer_dim > 0: + return self.num_thermometer_dim * self.n_objectives + return self.n_objectives + + def setup_focus_regions(self): + # focus regions + if self.cfg.focus_type is None: + valid_focus_dirs = np.zeros((self.n_valid, self.n_objectives)) + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type == "centered": + valid_focus_dirs = np.ones((self.n_valid, self.n_objectives)) + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type == "partitioned": + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l2") + self.fixed_focus_dirs = valid_focus_dirs + elif self.cfg.focus_type in ["dirichlet", "learned-gfn"]: + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l1") + self.fixed_focus_dirs = None + elif self.cfg.focus_type in ["hyperspherical", "learned-tabular"]: + valid_focus_dirs = metrics.partition_hypersphere(d=self.n_objectives, k=self.n_valid, normalisation="l2") + self.fixed_focus_dirs = None + elif type(self.cfg.focus_type) is list: + if len(self.cfg.focus_type) == 1: + valid_focus_dirs = np.array([self.cfg.focus_type[0]] * self.n_valid) + self.fixed_focus_dirs = valid_focus_dirs + else: + valid_focus_dirs = np.array(self.cfg.focus_type) + self.fixed_focus_dirs = valid_focus_dirs + else: + raise NotImplementedError( + f"focus_type should be None, a list of fixed_focus_dirs, or a string describing one of the supported " + f"focus_type, but here: {self.cfg.focus_type}" + ) + self.valid_focus_dirs = valid_focus_dirs + + def sample(self, n: int, train_it: int = None): + train_it = train_it or 0 + if self.fixed_focus_dirs is not None: + focus_dir = torch.tensor( + np.array(self.fixed_focus_dirs)[self.rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) + ) + elif self.cfg.focus_type == "dirichlet": + m = Dirichlet(torch.FloatTensor([1.0] * self.n_objectives)) + focus_dir = m.sample([n]) + elif self.cfg.focus_type == "hyperspherical": + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, self.n_objectives, normalisation="l2") + ).float() + elif self.cfg.focus_type is not None and "learned" in self.cfg.focus_type: + if ( + self.focus_model is not None + and train_it >= self.cfg.focus_model_training_limits[0] * self.cfg.max_train_it + ): + focus_dir = self.focus_model.sample_focus_directions(n) + else: + focus_dir = torch.tensor( + metrics.sample_positiveQuadrant_ndim_sphere(n, self.n_objectives, normalisation="l2") + ).float() + else: + raise NotImplementedError(f"Unsupported focus_type={type(self.cfg.focus_type)}") + + return {"focus_dir": focus_dir, "encoding": self.encode(focus_dir)} + + def encode(self, conditional: Tensor) -> Tensor: + return ( + thermometer(conditional, self.ocfg.cond.moo.num_thermometer_dim, 0, 1).reshape(conditional.shape[0], -1) + if self.cfg.use_steer_thermomether + else conditional + ) + + def transform(self, cond_info: Dict[str, Tensor], flat_rewards: Tensor, scalar_logreward: Tensor = None) -> Tensor: + focus_coef, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.cfg.focus_cosim, self.cfg.focus_limit_coef + ) + if scalar_logreward is None: + scalar_logreward = torch.log(focus_coef) + else: + scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) + scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward + + return scalar_logreward + + def step_focus_model(self, batch: gd.Batch, train_it: int): + focus_model_training_limits = self.cfg.focus_model_training_limits + max_train_it = self.ocfg.num_training_steps + if ( + self.focus_model is not None + and train_it >= focus_model_training_limits[0] * max_train_it + and train_it <= focus_model_training_limits[1] * max_train_it + ): + self.focus_model.update_belief(deepcopy(batch.focus_dir), deepcopy(batch.flat_rewards)) diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py new file mode 100644 index 00000000..db3d3905 --- /dev/null +++ b/src/gflownet/utils/config.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass, field +from typing import Any, List, Optional + + +@dataclass +class TempCondConfig: + """Config for the temperature conditional. + + Attributes + ---------- + + sample_dist : str + The distribution to sample the inverse temperature from. Can be one of: + - "uniform": uniform distribution + - "loguniform": log-uniform distribution + - "gamma": gamma distribution + - "constant": constant temperature + - "beta": beta distribution + dist_params : List[Any] + The parameters of the temperature distribution. E.g. for the "uniform" distribution, this is the range. + num_thermometer_dim : int + The number of thermometer encoding dimensions to use. + """ + + sample_dist: str = "uniform" + dist_params: List[Any] = field(default_factory=lambda: [0.5, 32]) + num_thermometer_dim: int = 32 + + +@dataclass +class MultiObjectiveConfig: + num_objectives: int = 2 + num_thermometer_dim: int = 16 + + +@dataclass +class WeightedPreferencesConfig: + """Config for the weighted preferences conditional. + + Attributes + ---------- + preference_type : str + The preference sampling distribution, defaults to "dirichlet". Can be one of: + - "dirichlet": Dirichlet distribution + - "dirichlet_exponential": Dirichlet distribution with exponential temperature + - "seeded": Enumerated preferences + - None: All rewards equally weighted""" + + preference_type: Optional[str] = "dirichlet" + + +@dataclass +class FocusRegionConfig: + """Config for the focus region conditional. + + Attributes + ---------- + focus_type : str + The type of focus distribtuion used, see FocusRegionConditon.setup_focus_regions. Can be one of: + [None, "centered", "partitioned", "dirichlet", "hyperspherical", "learned-gfn", "learned-tabular"] + """ + + focus_type: Optional[str] = "learned-tabular" + use_steer_thermomether: bool = False + focus_cosim: float = 0.98 + focus_limit_coef: float = 0.1 + focus_model_training_limits: tuple[float, float] = (0.25, 0.75) + focus_model_state_space_res: int = 30 + max_train_it: int = 20_000 + + +@dataclass +class ConditionalsConfig: + temperature: TempCondConfig = TempCondConfig() + moo: MultiObjectiveConfig = MultiObjectiveConfig() + weighted_prefs: WeightedPreferencesConfig = WeightedPreferencesConfig() + focus_region: FocusRegionConfig = FocusRegionConfig() diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py index 3b118b44..cc37c127 100644 --- a/src/gflownet/utils/metrics.py +++ b/src/gflownet/utils/metrics.py @@ -33,7 +33,7 @@ def compute_focus_coef( assert ( focus_limit_coef > 0.0 and focus_limit_coef <= 1.0 ), f"focus_limit_coef must be in (0, 1], now {focus_limit_coef}" - focus_gamma_param = np.log(focus_limit_coef) / np.log(focus_cosim) + focus_gamma_param = torch.tensor(np.log(focus_limit_coef) / np.log(focus_cosim)).float() cosim = nn.functional.cosine_similarity(flat_rewards, focus_dirs, dim=1) in_focus_mask = cosim >= focus_cosim focus_coef = torch.where(in_focus_mask, cosim**focus_gamma_param, 0.0) diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index d1ac0a2a..4862c6c7 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -120,7 +120,8 @@ def _run_pareto_accumulation(self): def __call__(self, trajs, rewards, flat_rewards, cond_info): # locally (in-process) accumulate flat rewards to build a better pareto estimate self.all_flat_rewards = self.all_flat_rewards + list(flat_rewards) - self.all_focus_dirs = self.all_focus_dirs + list(cond_info["focus_dir"]) + if self.compute_focus_accuracy: + self.all_focus_dirs = self.all_focus_dirs + list(cond_info["focus_dir"]) self.all_smi = self.all_smi + list([i.get("smi", None) for i in trajs]) if len(self.all_flat_rewards) > self.num_to_keep: self.all_flat_rewards = self.all_flat_rewards[-self.num_to_keep :] @@ -128,7 +129,8 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): self.all_smi = self.all_smi[-self.num_to_keep :] flat_rewards = torch.stack(self.all_flat_rewards).numpy() - focus_dirs = torch.stack(self.all_focus_dirs).numpy() + if self.compute_focus_accuracy: + focus_dirs = torch.stack(self.all_focus_dirs).numpy() # collects empirical pareto front from in-process samples pareto_idces = metrics.is_pareto_efficient(-flat_rewards, return_mask=False)