Refactor tasks (#98)
This PR refactors tasks into reusable components that will make it easier to define new tasks and methods.

The PR is pretty big, but a lot of it is moving code around (or deleting repeated code):

- Creates a notion of `Conditional` and implements `TemperatureConditional`, `MultiObjectiveWeightedPreferences`, and `FocusRegionConditional` (see the sketch after this list)
- Separates the commonalities between `seh_frag` and `seh_frag_moo` into a more generic `StandardOnlineTrainer` that is meant to be easily subclassed for new tasks.
- Adds some implementation notes and comments
- Adds a `validate_batch` routine that's useful for debugging, e.g. when developing new environments and datasets
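
As a rough sketch of the `Conditional` idea (illustrative only; the real interface lives in the new conditioning utilities under `gflownet/utils`, and the names/signatures below are hypothetical):

```python
# Illustrative sketch only -- not the repo's actual API.
import torch


class SketchTemperatureConditional:
    def __init__(self, beta_min=0.0, beta_max=32.0):
        self.beta_min, self.beta_max = beta_min, beta_max

    def sample(self, n):
        # One inverse-temperature per trajectory, plus an encoding fed to the model
        beta = torch.rand(n) * (self.beta_max - self.beta_min) + self.beta_min
        return {"beta": beta, "encoding": beta.unsqueeze(1)}

    def transform(self, cond_info, flat_reward):
        # Temper the reward: log R(x | beta) = beta * log R(x)
        return cond_info["beta"] * flat_reward.clamp(min=1e-30).log()
```

`MultiObjectiveWeightedPreferences` and `FocusRegionConditional` follow the same pattern, conditioning on preference weights or a focus region instead of a temperature.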

Also fixes some bugs:
- Makes `valid_offline_ratio` a flag and sets it explicitly in tasks where it wasn't properly set
- `first_graph_idx` was incorrectly calculated in SubTB (this affected the logging of logZ values); see the snippet after this list
- QM9Dataset was returning the wrong shape for its rewards
- Adds an `allow_5_valence_nitrogen` flag to `MolBuildingEnvContext`; this is needed in some cases, see `tasks/qm9.py`.
- Adds an explicit `stop_mask` to `MolBuildingEnvContext.graph_to_Data`
- Fixes incorrect default objective name in `seh_frag_moo`
- Fixes the default configurations in the tasks' `main` that hadn't been updated
- Fixes a number of routines where `focus_cond` was assumed to exist (but we can now turn it off).
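
For context on the SubTB fix above, a minimal illustration of why the old `first_graph_idx` computation was wrong (the trajectory lengths here are made up):

```python
import torch

traj_lens = torch.tensor([3, 2, 4])  # hypothetical trajectory lengths
first_graph_idx = torch.zeros_like(traj_lens)
# Correct: write the cumulative sum into a slice view; positions 1..N-1 are
# filled in place and first_graph_idx[0] stays 0 -> tensor([0, 3, 5])
torch.cumsum(traj_lens[:-1], 0, out=first_graph_idx[1:])
# The old code reassigned the return value, i.e.
#   first_graph_idx = torch.cumsum(traj_lens[:-1], 0, out=first_graph_idx[1:])
# which rebinds the name to the length-(N-1) slice tensor([3, 5]), so logZ was
# then read from the wrong per-graph outputs.
```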

commits:

* little test

* new config structure - in progress

* trying config by names

* further refactor progress

* better pyi sort + fix moo example

* tox

* import fixes

* add SQL + fix n_valid

* tox test

* fix mypy hook and convert qm9 to new cfg

* better config generation

* use generated config.py

* use generated config.py

* fix rng call types

* fix test + tox

* better config doc

* fix deps

* tox

* re-fix deps

* minor fixes for seh_frag_moo

* tox

* beginning of refactor + impl notes

* multiobject weighted prefs

* focus conditional in progress

* switch to OmegaConf

* switch to omegaconf

* fix pre-commit-config

* add omegaconf dep

* fix list defaults to fields

* remove comment

* finish focus conditional

* various fixes + switch to new config

* tox

* update README

* make string configs into Literals

* switch task construction order

* OmegaConf does not support Literal :(

* tox

* remove dead code + guard against no focus used

* fix for no replay

* refactor qm9 + remove unused configs

* many fixes to QM9, some debugging code and other various fixes

* explicit valid_offline_ratio flag

* do_validate_batch off by default

* addressing PR comments

* made device configurable
bengioe committed Aug 1, 2023
1 parent ffabcfd commit 152b18f
Showing 21 changed files with 782 additions and 543 deletions.
11 changes: 8 additions & 3 deletions README.md
@@ -21,7 +21,7 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera

## Repo overview

- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations (only [Trajectory Balance](https://arxiv.org/abs/2201.13259) for now), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories.
- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories.
- [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities.
- [envs](src/gflownet/envs), contains environment classes; a graph-building environment base, and a molecular graph context class. The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data.
- [examples](docs/examples), contains simple example implementations of GFlowNet.
@@ -30,8 +30,11 @@ The GNN model can be trained on a mix of existing data (offline) and self-genera
- [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward.
- [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein
- [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives).
- [utils](src/gflownet/utils), contains utilities (multiprocessing).
- [`train.py`](src/gflownet/train.py), defines a general harness for training GFlowNet models.
- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning).
- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models.
- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop.

See [implementation notes](docs/implementation_notes.md) for more.

## Getting started

@@ -57,6 +60,8 @@ To install or [depend on](https://matiascodesal.com/blog/how-use-git-repository-
pip install git+https://github.com/recursionpharma/gflownet.git@v0.0.10 --find-links ...
```

If package dependencies seem not to work, you may need to install the exact frozen versions listed in `requirements/`, i.e. `pip install -r requirements/main_3.9.txt`.

## Developing & Contributing

TODO: Write Contributing.md.
18 changes: 18 additions & 0 deletions docs/implementation_notes.md
@@ -0,0 +1,18 @@
# Implementation notes

This repo is centered around training GFlowNets that produce graphs. While we intend to specialize towards building molecules, we've tried to keep the implementation moderately agnostic to that fact, which makes it able to support other graph-generation environments.

## Environment, Context, Task, Trainers

We separate experiment concerns into four categories:
- The Environment is the graph abstraction that is common to all; think of it as the base definition of the MDP.
- The Context provides an interface between the agent and the environment; it
  - maps graphs to torch_geometric `Data` instances
  - maps GraphActions to action indices
  - produces action masks
  - communicates to the model what inputs it should expect
- The Task class is responsible for computing the reward of a state, and for sampling conditioning information
- The Trainer class is responsible for instantiating everything, and running the training & testing loop

Typically one would set up a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`.
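
A schematic of what that setup looks like (a sketch only; the method names mirror the existing `seh_frag` task and the new `StandardOnlineTrainer`, and may not match the exact base-class signatures):

```python
# Schematic sketch; see the tasks under src/gflownet/tasks/ for working examples.
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import GFNTask


class MyTask(GFNTask):
    """Scores generated graphs and samples conditioning information."""

    def sample_conditional_information(self, n):
        # e.g. delegate to a TemperatureConditional instance
        ...

    def cond_info_to_logreward(self, cond_info, flat_reward):
        # Combine conditioning (e.g. inverse temperature) with the flat reward
        ...

    def compute_flat_rewards(self, objs):
        # Turn generated objects (e.g. molecules) into rewards + validity flags
        ...


class MyTrainer(StandardOnlineTrainer):
    def setup_task(self):
        self.task = MyTask()

    def setup_env_context(self):
        # For a new MDP, this would be a new GraphBuildingEnvContext subclass
        self.ctx = ...
```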
3 changes: 3 additions & 0 deletions src/gflownet/algo/config.py
@@ -91,6 +91,8 @@ class AlgoConfig:
offline_ratio: float
The ratio of samples drawn from `self.training_data` during training. The rest is drawn from
`self.sampling_model`
valid_offline_ratio: float
Idem but for validation, and `self.test_data`.
train_random_action_prob : float
The probability of taking a random action during training
valid_random_action_prob : float
@@ -108,6 +110,7 @@ class AlgoConfig:
max_edges: int = 128
illegal_action_logreward: float = -100
offline_ratio: float = 0.5
valid_offline_ratio: float = 1
train_random_action_prob: float = 0.0
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
2 changes: 1 addition & 1 deletion src/gflownet/algo/envelope_q_learning.py
@@ -14,7 +14,7 @@
generate_forward_trajectory,
)
from gflownet.models.graph_transformer import GraphTransformer, mlp
from gflownet.train import GFNTask
from gflownet.trainer import GFNTask

from .graph_sampling import GraphSampler

4 changes: 2 additions & 2 deletions src/gflownet/algo/trajectory_balance.py
@@ -19,7 +19,7 @@
GraphBuildingEnvContext,
generate_forward_trajectory,
)
from gflownet.train import GFNAlgorithm
from gflownet.trainer import GFNAlgorithm


class TrajectoryBalanceModel(nn.Module):
@@ -363,7 +363,7 @@ def compute_batch_losses(
traj_losses = self.subtb_loss_fast(log_p_F, log_p_B, per_graph_out[:, 0], clip_log_R, batch.traj_lens)
# The position of the first graph of each trajectory
first_graph_idx = torch.zeros_like(batch.traj_lens)
first_graph_idx = torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
log_Z = per_graph_out[first_graph_idx, 0]
else:
# Compute log numerator and denominator of the TB objective
8 changes: 8 additions & 0 deletions src/gflownet/config.py
@@ -7,6 +7,7 @@
from gflownet.data.config import ReplayConfig
from gflownet.models.config import ModelConfig
from gflownet.tasks.config import TasksConfig
from gflownet.utils.config import ConditionalsConfig


@dataclass
@@ -51,12 +52,16 @@ class Config:
----------
log_dir : str
The directory where to store logs, checkpoints, and samples.
device : str
The device to use for training (either "cpu" or "cuda[:<device_id>]")
seed : int
The random seed
validate_every : int
The number of training steps after which to validate the model
checkpoint_every : Optional[int]
The number of training steps after which to checkpoint the model
print_every : int
The number of training steps after which to print the training loss
start_at_step : int
The training step to start at (default: 0)
num_final_gen_steps : Optional[int]
@@ -76,9 +81,11 @@
"""

log_dir: str = MISSING
device: str = "cuda"
seed: int = 0
validate_every: int = 1000
checkpoint_every: Optional[int] = None
print_every: int = 100
start_at_step: int = 0
num_final_gen_steps: Optional[int] = None
num_training_steps: int = 10_000
@@ -92,3 +99,4 @@ class Config:
opt: OptimizerConfig = OptimizerConfig()
replay: ReplayConfig = ReplayConfig()
task: TasksConfig = TasksConfig()
cond: ConditionalsConfig = ConditionalsConfig()
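
Since the configuration is now built from these dataclasses (the PR also switches to OmegaConf, per the commit messages above), here is a hedged sketch of how the new fields might be set from code; the exact entry points used by the trainers may differ:

```python
# Illustrative only: composing the dataclass config with OmegaConf overrides.
from omegaconf import OmegaConf

from gflownet.config import Config

cfg = OmegaConf.structured(Config())
cfg.log_dir = "./logs/debug_run"  # required field (MISSING by default)
cfg.device = "cpu"                # new flag added in this PR
cfg.print_every = 10              # new flag added in this PR
print(OmegaConf.to_yaml(cfg))
```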
6 changes: 5 additions & 1 deletion src/gflownet/data/qm9.py
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import rdkit.Chem as Chem
import torch
from torch.utils.data import Dataset


@@ -39,7 +40,10 @@ def __len__(self):
return len(self.idcs)

def __getitem__(self, idx):
return (Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), self.df[self.target][self.idcs[idx]])
return (
Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]),
torch.tensor([self.df[self.target][self.idcs[idx]]]).float(),
)


def convert_h5():
35 changes: 34 additions & 1 deletion src/gflownet/data/sampling_iterator.py
@@ -12,6 +12,7 @@
from torch.utils.data import Dataset, IterableDataset

from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.envs.graph_building_env import GraphActionCategorical


class SamplingIterator(IterableDataset):
@@ -98,6 +99,7 @@ def __init__(
self.random_action_prob = random_action_prob
self.hindsight_ratio = hindsight_ratio
self.train_it = init_train_iter
self.do_validate_batch = False # Turn this on for debugging
self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") # TODO: make this a proper flag

# Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
@@ -130,6 +132,9 @@ def _idx_iterator(self):
if n == 0:
yield np.arange(0, 0)
return
assert (
self.offline_batch_size > 0
), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)"
if worker_info is None: # no multi-processing
start, end, wid = 0, n, -1
else: # split the data into chunks (per-worker)
@@ -237,6 +242,7 @@ def __iter__(self):
log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

# Computes some metrics
extra_info = {}
if not self.sample_cond_info:
# If we're using a dataset of preferences, the user may want to know the id of the preference
for i, j in zip(trajs, idcs):
@@ -253,7 +259,6 @@
{k: v[num_offline:] for k, v in deepcopy(cond_info).items()},
)
if num_online > 0:
extra_info = {}
for hook in self.log_hooks:
extra_info.update(
hook(
@@ -318,9 +323,37 @@ def __iter__(self):
# TODO: we could very well just pass the cond_info dict to construct_batch above,
# and the algo can decide what it wants to put in the batch object

# Only activate for debugging your environment or dataset (e.g. the dataset could be
# generating trajectories with illegal actions)
if self.do_validate_batch:
self.validate_batch(batch, trajs)

self.train_it += worker_info.num_workers if worker_info is not None else 1
yield batch

def validate_batch(self, batch, trajs):
for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + (
[(batch.bck_actions, self.ctx.bck_action_type_order)]
if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order")
else []
):
mask_cat = GraphActionCategorical(
batch,
[self.model._action_type_to_mask(t, batch) for t in atypes],
[self.model._action_type_to_key[t] for t in atypes],
[None for _ in atypes],
)
masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits)
num_trajs = len(trajs)
batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens)
first_graph_idx = torch.zeros_like(batch.traj_lens)
torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
if masked_action_is_used.sum() != 0:
invalid_idx = masked_action_is_used.argmax().item()
traj_idx = batch_idx[invalid_idx].item()
timestep = invalid_idx - first_graph_idx[traj_idx].item()
raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep])

def log_generated(self, trajs, rewards, flat_rewards, cond_info):
if self.log_molecule_smis:
mols = [
13 changes: 9 additions & 4 deletions src/gflownet/envs/mol_building_env.py
@@ -28,6 +28,7 @@ def __init__(
charges=[0, 1, -1],
expl_H_range=[0, 1],
allow_explicitly_aromatic=False,
allow_5_valence_nitrogen=False,
num_rw_feat=8,
max_nodes=None,
max_edges=None,
@@ -118,9 +119,11 @@ def __init__(
BondType.AROMATIC: 1.5,
}
pt = Chem.GetPeriodicTable()
self.allow_5_valence_nitrogen = allow_5_valence_nitrogen
self._max_atom_valence = {
**{a: max(pt.GetValenceList(a)) for a in atoms},
"N": 3, # We'll handle nitrogen valence later explicitly in graph_to_Data
# We'll handle nitrogen valence later explicitly in graph_to_Data
"N": 3 if not allow_5_valence_nitrogen else 5,
"*": 0, # wildcard atoms have 0 valence until filled in
}

@@ -231,9 +234,10 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]]
# Special rule for Nitrogen
if ad["v"] == "N" and ad.get("charge", 0) == 1:
# This is definitely a heuristic, but to keep things simple we'll limit Nitrogen's valence to 3 (as
# This is definitely a heuristic, but to keep things simple we'll limit* Nitrogen's valence to 3 (as
# per self._max_atom_valence) unless it is charged, then we make it 5.
# This keeps RDKit happy (and is probably a good idea anyway).
# (* unless allow_5_valence_nitrogen is True, then it's just always 5)
max_atom_valence = 5
max_valence[n] = max_atom_valence - abs(ad.get("charge", 0)) - ad.get("expl_H", 0)
# Compute explicitly defined valence:
@@ -281,14 +285,15 @@ def is_ok_non_edge(e):
non_edge_index = torch.zeros((2, 0), dtype=torch.long)
else:
gc = nx.complement(g)
non_edge_index = torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).T.reshape(
(2, -1)
non_edge_index = (
torch.tensor([i for i in gc.edges if is_ok_non_edge(i)], dtype=torch.long).reshape((-1, 2)).T
)
data = gd.Data(
x,
edge_index,
edge_attr,
non_edge_index=non_edge_index,
stop_mask=torch.ones(1, 1) if len(g) > 0 else torch.zeros(1, 1),
add_node_mask=add_node_mask,
set_node_attr_mask=set_node_attr_mask,
add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), # Already filtered by is_ok_non_edge
45 changes: 25 additions & 20 deletions src/gflownet/models/graph_transformer.py
@@ -145,6 +145,27 @@ class GraphTransformerGFN(nn.Module):
Outputs logits corresponding to the action types used by the env_ctx argument.
"""

# The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the
# embeddings to the right MLP
_action_type_to_graph_part = {
GraphActionType.Stop: "graph",
GraphActionType.AddNode: "node",
GraphActionType.SetNodeAttr: "node",
GraphActionType.AddEdge: "non_edge",
GraphActionType.SetEdgeAttr: "edge",
GraphActionType.RemoveNode: "node",
GraphActionType.RemoveNodeAttr: "node",
GraphActionType.RemoveEdge: "edge",
GraphActionType.RemoveEdgeAttr: "edge",
}
# The torch_geometric batch key each graph part corresponds to
_graph_part_to_key = {
"graph": None,
"node": "x",
"non_edge": "non_edge_index",
"edge": "edge_index",
}

def __init__(
self,
env_ctx,
@@ -184,25 +205,8 @@ def __init__(
GraphActionType.RemoveEdge: (num_edge_feat, 1),
GraphActionType.RemoveEdgeAttr: (num_edge_feat, env_ctx.num_edge_attrs),
}
# The GraphTransformer outputs per-node, per-edge, and per-graph embeddings, this routes the
# embeddings to the right MLP
self._action_type_to_graph_part = {
GraphActionType.Stop: "graph",
GraphActionType.AddNode: "node",
GraphActionType.SetNodeAttr: "node",
GraphActionType.AddEdge: "non_edge",
GraphActionType.SetEdgeAttr: "edge",
GraphActionType.RemoveNode: "node",
GraphActionType.RemoveNodeAttr: "node",
GraphActionType.RemoveEdge: "edge",
GraphActionType.RemoveEdgeAttr: "edge",
}
# The torch_geometric batch key each graph part corresponds to
self._graph_part_to_key = {
"graph": None,
"node": "x",
"non_edge": "non_edge_index",
"edge": "edge_index",
self._action_type_to_key = {
at: self._graph_part_to_key[self._action_type_to_graph_part[at]] for at in self._action_type_to_graph_part
}

# Here we create only the embedding -> logit mapping MLPs that are required by the environment
@@ -229,13 +233,14 @@ def _action_type_to_logit(self, t, emb, g):

def _mask(self, x, m):
# mask logit vector x with binary mask m, -1000 is a tiny log-value
# Note to self: we can't use torch.inf here, because inf * 0 is nan (but also see issue #99)
return x * m + -1000 * (1 - m)

def _make_cat(self, g, emb, action_types):
return GraphActionCategorical(
g,
logits=[self._action_type_to_logit(t, emb, g) for t in action_types],
keys=[self._graph_part_to_key[self._action_type_to_graph_part[t]] for t in action_types],
keys=[self._action_type_to_key[t] for t in action_types],
masks=[self._action_type_to_mask(t, g) for t in action_types],
types=action_types,
)