
add backward mask support + small ring task
bengioe committed Aug 2, 2023
1 parent 152b18f commit ef0ae83
Showing 4 changed files with 226 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/gflownet/envs/graph_building_env.py
@@ -271,9 +271,9 @@ def add_parent(a, new_g):
GraphAction(GraphActionType.AddNode, source=anchor, value=g.nodes[i]["v"]),
new_g,
)
if len(g.nodes) == 1:
if len(g.nodes) == 1 and len(g.nodes[i]) == 1:
# The final node is degree 0, need this special case to remove it
# and end up with S0, the empty graph root
# and end up with S0, the empty graph root (but only if it has no attrs except 'v')
add_parent(
GraphAction(GraphActionType.AddNode, source=0, value=g.nodes[i]["v"]),
graph_without_node(g, i),
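Note on the new guard: `len(g.nodes[i]) == 1` checks that the lone remaining node carries nothing but the `'v'` value it was created with, since `g.nodes[i]` is the node's attribute dict. A minimal sketch of the condition, assuming `Graph` behaves like `networkx.Graph` (as it does in this repo):

```python
# Minimal sketch of the S0 guard; assumes Graph behaves like networkx.Graph.
import networkx as nx

g = nx.Graph()
g.add_node(0, v="C")  # attribute dict is {"v": "C"}
print(len(g.nodes) == 1 and len(g.nodes[0]) == 1)  # True: parent can be S0

g.nodes[0]["charge"] = 1  # an extra attribute is set
print(len(g.nodes) == 1 and len(g.nodes[0]) == 1)  # False: remove the attr first
```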
101 changes: 91 additions & 10 deletions src/gflownet/envs/mol_building_env.py
@@ -8,7 +8,13 @@
from rdkit.Chem import Mol
from rdkit.Chem.rdchem import BondType, ChiralType

from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext
from gflownet.envs.graph_building_env import (
Graph,
GraphAction,
GraphActionType,
GraphBuildingEnvContext,
graph_without_edge,
)
from gflownet.utils.graphs import random_walk_probs

DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW]
Expand Down Expand Up @@ -77,19 +83,22 @@ def __init__(
# The size of the input vector for each atom
self.atom_attr_size = sum(len(i) for i in self.atom_attr_values.values())
self.atom_attrs = sorted(self.atom_attr_values.keys())
# 'v' is set separately when creating the node, so there's no point in having a SetNodeAttr logit for it
self.settable_atom_attrs = [i for i in self.atom_attrs if i != "v"]
# The beginning position within the input vector of each attribute
self.atom_attr_slice = [0] + list(np.cumsum([len(self.atom_attr_values[i]) for i in self.atom_attrs]))
# The beginning position within the logit vector of each attribute
num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.atom_attrs]
num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.settable_atom_attrs]
self.atom_attr_logit_slice = {
k: (s, e)
for k, s, e in zip(self.atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits))
for k, s, e in zip(
self.settable_atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits)
)
}
# The attribute and value each logit dimension maps back to
self.atom_attr_logit_map = [
(k, v)
for k in self.atom_attrs
if k != "v"
for k in self.settable_atom_attrs
# index 0 is skipped because it is the default value
for v in self.atom_attr_values[k][1:]
]
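As a standalone illustration of the slice bookkeeping above: each settable attribute gets `len(values) - 1` logits (index 0 is the default value and gets none), and the cumulative sums give each attribute's `(start, end)` range in the logit vector. The attribute values below are made up for the sketch:

```python
# Hedged sketch of atom_attr_logit_slice; attribute values are illustrative.
import numpy as np

atom_attr_values = {"v": ["C", "N", "O"], "charge": [0, -1, 1], "expl_H": [0, 1]}
settable = [k for k in sorted(atom_attr_values) if k != "v"]   # ['charge', 'expl_H']
num_logits = [len(atom_attr_values[k]) - 1 for k in settable]  # [2, 1]
logit_slice = {
    k: (s, e)
    for k, s, e in zip(settable, [0] + list(np.cumsum(num_logits)), np.cumsum(num_logits))
}
print(logit_slice)  # charge -> (0, 2), expl_H -> (2, 3)
```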
@@ -147,12 +156,21 @@ def __init__(
GraphActionType.AddEdge,
GraphActionType.SetEdgeAttr,
]
self.bck_action_type_order = [
GraphActionType.RemoveNode,
GraphActionType.RemoveNodeAttr,
GraphActionType.RemoveEdge,
GraphActionType.RemoveEdgeAttr,
]
self.device = torch.device("cpu")

def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True):
"""Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction"""
act_type, act_row, act_col = [int(i) for i in action_idx]
t = self.action_type_order[act_type]
if fwd:
t = self.action_type_order[act_type]
else:
t = self.bck_action_type_order[act_type]
if t is GraphActionType.Stop:
return GraphAction(t)
elif t is GraphActionType.AddNode:
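The backward order gives the model a second categorical for the backward policy (used when `do_parameterize_p_b` is on, as in the new task below), so the same `(type, row, col)` triple decodes to different action types depending on `fwd`. A toy sketch of the dispatch, with strings standing in for `GraphActionType` members:

```python
# Toy sketch of the fwd/bck dispatch; strings stand in for GraphActionType.
fwd_order = ["Stop", "AddNode", "SetNodeAttr", "AddEdge", "SetEdgeAttr"]
bck_order = ["RemoveNode", "RemoveNodeAttr", "RemoveEdge", "RemoveEdgeAttr"]

def decode(act_type: int, fwd: bool = True) -> str:
    return (fwd_order if fwd else bck_order)[act_type]

print(decode(2))             # SetNodeAttr
print(decode(2, fwd=False))  # RemoveEdge
```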
@@ -167,9 +185,28 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd:
a, b = g.edge_index[:, act_row * 2]  # Edges are duplicated for the undirected GNN, deduplicated for logits
attr, val = self.bond_attr_logit_map[act_col]
return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val)
elif t is GraphActionType.RemoveNode:
return GraphAction(t, source=act_row)
elif t is GraphActionType.RemoveNodeAttr:
attr = self.settable_atom_attrs[act_col]
return GraphAction(t, source=act_row, attr=attr)
elif t is GraphActionType.RemoveEdge:
a, b = g.edge_index[:, act_row * 2]
return GraphAction(t, source=a.item(), target=b.item())
elif t is GraphActionType.RemoveEdgeAttr:
a, b = g.edge_index[:, act_row * 2]
attr = self.bond_attrs[act_col]
return GraphAction(t, source=a.item(), target=b.item(), attr=attr)

def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]:
"""Translate a GraphAction to an index tuple"""
for u in [self.action_type_order, self.bck_action_type_order]:
if action.action in u:
type_idx = u.index(action.action)
break
else:
raise ValueError(f"Unknown action type {action.action}")

if action.action is GraphActionType.Stop:
row = col = 0
elif action.action is GraphActionType.AddNode:
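All the edge-indexed cases rely on the same convention: `edge_index` stores both directions of every bond, while logit rows exist once per bond, hence `act_row * 2` when decoding and the floor division by 2 when encoding. A hedged standalone sketch of that round trip (the `.long()` cast is added here for portability):

```python
# Hedged sketch of the edge (de)duplication round trip.
import torch

# Two undirected edges, each stored twice (once per direction).
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

logit_row = 1                        # logit row for the bond (1, 2)
a, b = edge_index[:, logit_row * 2]  # decode endpoints -> tensor(1), tensor(2)

# Encode back: find the matching column, then halve to deduplicate.
match = (edge_index.T == torch.tensor([1, 2])).long()
row = int(match.prod(1).argmax()) // 2
print(a.item(), b.item(), row)       # 1 2 1
```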
@@ -201,7 +238,22 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
col = (
self.bond_attr_values[action.attr].index(action.value) - 1 + self.bond_attr_logit_slice[action.attr][0]
)
type_idx = self.action_type_order.index(action.action)
elif action.action is GraphActionType.RemoveNode:
row = action.source
col = 0
elif action.action is GraphActionType.RemoveNodeAttr:
row = action.source
col = self.settable_atom_attrs.index(action.attr)
elif action.action is GraphActionType.RemoveEdge:
row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax()
row = int(row) // 2 # edges are duplicated, but edge logits are not
col = 0
elif action.action is GraphActionType.RemoveEdgeAttr:
row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax()
row = row.div(2, rounding_mode="floor") # type: ignore
col = self.bond_attrs.index(action.attr)
else:
raise ValueError(f"Unknown action type {action.action}")
return (type_idx, int(row), int(col))
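The lookup at the top of `GraphAction_to_aidx` uses Python's `for`/`else`: the `else` clause only runs when the loop finishes without `break`, making it a compact "search or raise" idiom. A toy sketch:

```python
# Toy sketch of the for/else "search or raise" idiom.
fwd_order = ["Stop", "AddNode"]
bck_order = ["RemoveNode", "RemoveEdge"]

def type_index(action: str) -> int:
    for order in (fwd_order, bck_order):
        if action in order:
            idx = order.index(action)
            break
    else:  # no break: the action is in neither order
        raise ValueError(f"Unknown action type {action}")
    return idx

print(type_index("RemoveEdge"))  # 1, its index within the backward order
```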

def graph_to_Data(self, g: Graph) -> gd.Data:
@@ -211,25 +263,43 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
add_node_mask = torch.ones((x.shape[0], self.num_new_node_values))
if self.max_nodes is not None and len(g.nodes) >= self.max_nodes:
add_node_mask *= 0
remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0)
remove_node_attr_mask = torch.zeros((x.shape[0], len(self.settable_atom_attrs)))

explicit_valence = {}
max_valence = {}
set_node_attr_mask = torch.ones((x.shape[0], self.num_node_attr_logits))
if not len(g.nodes):
set_node_attr_mask *= 0
for i, n in enumerate(g.nodes):
ad = g.nodes[n]
if g.degree(n) <= 1 and len(ad) == 1 and all([len(g[n][neigh]) == 0 for neigh in g.neighbors(n)]):
# If there's only the 'v' key left, the node is a leaf, and the edge connecting to it (if any) has
# no attributes set, we can remove the node
remove_node_mask[i] = 1
for k, sl in zip(self.atom_attrs, self.atom_attr_slice):
# idx > 0 means that the attribute is not the default value
idx = self.atom_attr_values[k].index(ad[k]) if k in ad else 0
x[i, sl + idx] = 1
# If the attribute is already there, mask out logits
# (or if the attribute is a negative attribute and has been filled)
if k == "v":
continue
# If the attribute
# - is already there (idx > 0),
# - or the attribute is a negative attribute and has been filled
# - or the attribute is a negative attribute and is not fillable (i.e. not a key of ad)
# then mask forward logits.
# For backward logits, positively mask if the attribute is there (idx > 0).
if k in self.negative_attrs:
if k in ad and idx > 0 or k not in ad:
s, e = self.atom_attr_logit_slice[k]
set_node_attr_mask[i, s:e] = 0
# We don't want to make the attribute removable if it's not fillable (i.e. not a key of ad)
if k in ad:
remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1
elif k in ad:
s, e = self.atom_attr_logit_slice[k]
set_node_attr_mask[i, s:e] = 0
remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1
# Account for charge and explicit Hs in atom as limiting the total valence
max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]]
# Special rule for Nitrogen
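The node-attribute masks are built to be complementary: setting an attribute removes it from the forward `SetNodeAttr` logits (a whole slice of value logits) and unmasks the single matching backward `RemoveNodeAttr` logit. A toy sketch of that invariant, with illustrative attribute names and one logit per attribute for brevity:

```python
# Toy sketch of the set/remove mask duality; names are illustrative and the
# real forward mask covers a slice of value logits per attribute.
import torch

attrs = ["charge", "expl_H"]
set_mask = torch.ones(1, len(attrs))      # everything settable at first
remove_mask = torch.zeros(1, len(attrs))  # nothing removable at first

i = attrs.index("charge")                 # suppose 'charge' gets set
set_mask[0, i] = 0                        # forward: cannot set it again
remove_mask[0, i] = 1                     # backward: can now remove it
```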
@@ -256,8 +326,14 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
s, e = self.atom_attr_logit_slice["expl_H"]
set_node_attr_mask[i, s:e] = 0

remove_edge_mask = torch.zeros((len(g.edges), 1))
for i, (u, v) in enumerate(g.edges):
if g.degree(u) > 1 and g.degree(v) > 1:
if nx.algorithms.is_connected(graph_without_edge(g, (u, v))):
remove_edge_mask[i] = 1
edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))
set_edge_attr_mask = torch.zeros((len(g.edges), self.num_edge_attr_logits))
remove_edge_attr_mask = torch.zeros((len(g.edges), len(self.bond_attrs)))
for i, e in enumerate(g.edges):
ad = g.edges[e]
for k, sl in zip(self.bond_attrs, self.bond_attr_slice):
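The `RemoveEdge` mask above only allows deleting a bond when both endpoints keep another connection and the remaining graph stays connected, i.e. bridges are never removable. A small sketch of the same test in plain networkx, with `graph_without_edge` emulated by a copy:

```python
# Sketch of the RemoveEdge rule; graph_without_edge is emulated with copy().
import networkx as nx

g = nx.Graph([(0, 1), (1, 2), (2, 0), (2, 3)])
for u, v in g.edges:
    h = g.copy()
    h.remove_edge(u, v)
    ok = g.degree(u) > 1 and g.degree(v) > 1 and nx.is_connected(h)
    print((u, v), ok)  # triangle edges: True; the bridge (2, 3): False
```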
@@ -267,6 +343,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
if k in ad: # If the attribute is already there, mask out logits
s, e = self.bond_attr_logit_slice[k]
set_edge_attr_mask[i, s:e] = 0
remove_edge_attr_mask[i, self.bond_attrs.index(k)] = 1
# Check which bonds don't bust the valence of their atoms
if "type" not in ad: # Only if type isn't already set
sl, _ = self.bond_attr_logit_slice["type"]
@@ -293,11 +370,15 @@ def is_ok_non_edge(e):
edge_index,
edge_attr,
non_edge_index=non_edge_index,
stop_mask=torch.ones(1, 1) if len(g) > 0 else torch.zeros(1, 1),
stop_mask=torch.ones((1, 1)) * (len(g.nodes) > 0),  # Can only stop if there's at least one node
add_node_mask=add_node_mask,
set_node_attr_mask=set_node_attr_mask,
add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), # Already filtered by is_ok_non_edge
set_edge_attr_mask=set_edge_attr_mask,
remove_node_mask=remove_node_mask,
remove_node_attr_mask=remove_node_attr_mask,
remove_edge_mask=remove_edge_mask,
remove_edge_attr_mask=remove_edge_attr_mask,
)
if self.num_rw_feat > 0:
data.x = torch.cat([data.x, random_walk_probs(data, self.num_rw_feat, skip_odd=True)], 1)
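All of these per-action masks travel on the `torch_geometric` Data object and are consumed by the repo's categorical machinery downstream. As a hedged illustration of the general idea (not the repo's exact mechanism), a 0/1 mask suppresses illegal logits before sampling:

```python
# Hedged illustration, not the repo's exact mechanism: pushing masked
# logits far negative gives masked actions ~0 probability.
import torch

logits = torch.randn(1, 4)
mask = torch.tensor([[1.0, 0.0, 1.0, 1.0]])  # 0 marks an illegal action
masked = logits * mask + (1 - mask) * -1e9
print(masked.softmax(dim=1))  # ~0 probability on the second entry
```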
90 changes: 90 additions & 0 deletions src/gflownet/tasks/make_rings.py
@@ -0,0 +1,90 @@
import os
import socket
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor

from gflownet.config import Config
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar


class MakeRingsTask(GFNTask):
"""A toy task where the reward is the number of rings in the molecule."""

def __init__(
self,
rng: np.random.Generator,
):
self.rng = rng

def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
return FlatRewards(y)

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log()
return RewardScalar(scalar_logreward.flatten())

def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float()
return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool()


class MakeRingsTrainer(StandardOnlineTrainer):
def set_default_hps(self, cfg: Config):
cfg.hostname = socket.gethostname()
cfg.num_workers = 8
cfg.algo.global_batch_size = 64
cfg.algo.offline_ratio = 0
cfg.model.num_emb = 128
cfg.model.num_layers = 4

cfg.algo.method = "TB"
cfg.algo.max_nodes = 6
cfg.algo.sampling_tau = 0.9
cfg.algo.illegal_action_logreward = -75
cfg.algo.train_random_action_prob = 0.0
cfg.algo.valid_random_action_prob = 0.0
cfg.algo.tb.do_parameterize_p_b = True

cfg.replay.use = False

def setup_task(self):
self.task = MakeRingsTask(rng=self.rng)

def setup_env_context(self):
self.ctx = MolBuildingEnvContext(
["C"],
charges=[0], # disable charge
chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED], # disable chirality
num_rw_feat=0,
max_nodes=self.cfg.algo.max_nodes,
num_cond_dim=1,
)


def main():
hps = {
"log_dir": "./logs/debug_run_mr4",
"device": "cuda",
"num_training_steps": 10_000,
"num_workers": 8,
"algo": {"tb": {"do_parameterize_p_b": True}},
}
os.makedirs(hps["log_dir"], exist_ok=True)

trial = MakeRingsTrainer(hps)
trial.print_every = 1
trial.run()


if __name__ == "__main__":
main()
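For a quick sanity check of the reward definition: `compute_flat_rewards` simply counts rings via RDKit's ring info. A standalone sketch with illustrative SMILES:

```python
# Standalone sketch of the ring-count reward; the SMILES are illustrative.
import torch
from rdkit import Chem

mols = [Chem.MolFromSmiles(s) for s in ["C1CC1", "c1ccc2ccccc2c1", "CCO"]]
rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float()
print(rs.reshape((-1, 1)))  # [[1.], [2.], [0.]]: cyclopropane, naphthalene, ethanol
```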
