diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py
index a3df87bf..fa7b284b 100644
--- a/src/gflownet/envs/graph_building_env.py
+++ b/src/gflownet/envs/graph_building_env.py
@@ -271,9 +271,9 @@ def add_parent(a, new_g):
                 GraphAction(GraphActionType.AddNode, source=anchor, value=g.nodes[i]["v"]),
                 new_g,
             )
-            if len(g.nodes) == 1:
+            if len(g.nodes) == 1 and len(g.nodes[i]) == 1:
                 # The final node is degree 0, need this special case to remove it
-                # and end up with S0, the empty graph root
+                # and end up with S0, the empty graph root (but only if it has no attrs except 'v')
                 add_parent(
                     GraphAction(GraphActionType.AddNode, source=0, value=g.nodes[i]["v"]),
                     graph_without_node(g, i),
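The extra `len(g.nodes[i]) == 1` guard above makes sure the empty graph S0 is only proposed as a parent when the last remaining node carries nothing but its mandatory 'v' attribute: an AddNode from S0 cannot replay a SetNodeAttr, so a node that still has extra attributes (charge, explicit Hs, ...) must shed them first. A minimal sketch of the guard's intent, assuming a networkx-style Graph whose node data dicts always contain 'v' (hypothetical helper, not part of the patch):

    # Hypothetical helper illustrating the guard: the last node may be removed
    # down to S0 only if its data dict holds just the 'v' (node value) key.
    def last_node_removable_to_s0(g, i) -> bool:
        return len(g.nodes) == 1 and len(g.nodes[i]) == 1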
diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py
index 925e3456..21030200 100644
--- a/src/gflownet/envs/mol_building_env.py
+++ b/src/gflownet/envs/mol_building_env.py
@@ -8,7 +8,13 @@
 from rdkit.Chem import Mol
 from rdkit.Chem.rdchem import BondType, ChiralType
 
-from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext
+from gflownet.envs.graph_building_env import (
+    Graph,
+    GraphAction,
+    GraphActionType,
+    GraphBuildingEnvContext,
+    graph_without_edge,
+)
 from gflownet.utils.graphs import random_walk_probs
 
 DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW]
@@ -77,19 +83,22 @@ def __init__(
         # The size of the input vector for each atom
         self.atom_attr_size = sum(len(i) for i in self.atom_attr_values.values())
         self.atom_attrs = sorted(self.atom_attr_values.keys())
+        # 'v' is set separately when creating the node, so there's no point in having a SetNodeAttr logit for it
+        self.settable_atom_attrs = [i for i in self.atom_attrs if i != "v"]
         # The beginning position within the input vector of each attribute
         self.atom_attr_slice = [0] + list(np.cumsum([len(self.atom_attr_values[i]) for i in self.atom_attrs]))
         # The beginning position within the logit vector of each attribute
-        num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.atom_attrs]
+        num_atom_logits = [len(self.atom_attr_values[i]) - 1 for i in self.settable_atom_attrs]
         self.atom_attr_logit_slice = {
             k: (s, e)
-            for k, s, e in zip(self.atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits))
+            for k, s, e in zip(
+                self.settable_atom_attrs, [0] + list(np.cumsum(num_atom_logits)), np.cumsum(num_atom_logits)
+            )
         }
         # The attribute and value each logit dimension maps back to
         self.atom_attr_logit_map = [
             (k, v)
-            for k in self.atom_attrs
-            if k != "v"
+            for k in self.settable_atom_attrs
             # index 0 is skipped because it is the default value
             for v in self.atom_attr_values[k][1:]
         ]
@@ -147,12 +156,21 @@ def __init__(
             GraphActionType.AddEdge,
             GraphActionType.SetEdgeAttr,
         ]
+        self.bck_action_type_order = [
+            GraphActionType.RemoveNode,
+            GraphActionType.RemoveNodeAttr,
+            GraphActionType.RemoveEdge,
+            GraphActionType.RemoveEdgeAttr,
+        ]
         self.device = torch.device("cpu")
 
     def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True):
         """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction"""
         act_type, act_row, act_col = [int(i) for i in action_idx]
-        t = self.action_type_order[act_type]
+        if fwd:
+            t = self.action_type_order[act_type]
+        else:
+            t = self.bck_action_type_order[act_type]
         if t is GraphActionType.Stop:
             return GraphAction(t)
         elif t is GraphActionType.AddNode:
@@ -167,9 +185,28 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd:
             a, b = g.edge_index[:, act_row * 2]  # Edges are duplicated to get undirected GNN, deduplicated for logits
             attr, val = self.bond_attr_logit_map[act_col]
             return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val)
+        elif t is GraphActionType.RemoveNode:
+            return GraphAction(t, source=act_row)
+        elif t is GraphActionType.RemoveNodeAttr:
+            attr = self.settable_atom_attrs[act_col]
+            return GraphAction(t, source=act_row, attr=attr)
+        elif t is GraphActionType.RemoveEdge:
+            a, b = g.edge_index[:, act_row * 2]
+            return GraphAction(t, source=a.item(), target=b.item())
+        elif t is GraphActionType.RemoveEdgeAttr:
+            a, b = g.edge_index[:, act_row * 2]
+            attr = self.bond_attrs[act_col]
+            return GraphAction(t, source=a.item(), target=b.item(), attr=attr)
 
     def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]:
         """Translate a GraphAction to an index tuple"""
+        for u in [self.action_type_order, self.bck_action_type_order]:
+            if action.action in u:
+                type_idx = u.index(action.action)
+                break
+        else:
+            raise ValueError(f"Unknown action type {action.action}")
+
         if action.action is GraphActionType.Stop:
             row = col = 0
         elif action.action is GraphActionType.AddNode:
@@ -201,7 +238,22 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
             col = (
                 self.bond_attr_values[action.attr].index(action.value) - 1 + self.bond_attr_logit_slice[action.attr][0]
             )
-        type_idx = self.action_type_order.index(action.action)
+        elif action.action is GraphActionType.RemoveNode:
+            row = action.source
+            col = 0
+        elif action.action is GraphActionType.RemoveNodeAttr:
+            row = action.source
+            col = self.settable_atom_attrs.index(action.attr)
+        elif action.action is GraphActionType.RemoveEdge:
+            row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax()
+            row = int(row) // 2  # edges are duplicated, but edge logits are not
+            col = 0
+        elif action.action is GraphActionType.RemoveEdgeAttr:
+            row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax()
+            row = row.div(2, rounding_mode="floor")  # type: ignore
+            col = self.bond_attrs.index(action.attr)
+        else:
+            raise ValueError(f"Unknown action type {action.action}")
         return (type_idx, int(row), int(col))
 
     def graph_to_Data(self, g: Graph) -> gd.Data:
@@ -211,6 +263,9 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
         add_node_mask = torch.ones((x.shape[0], self.num_new_node_values))
         if self.max_nodes is not None and len(g.nodes) >= self.max_nodes:
             add_node_mask *= 0
+        remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0)
+        remove_node_attr_mask = torch.zeros((x.shape[0], len(self.settable_atom_attrs)))
+
         explicit_valence = {}
         max_valence = {}
         set_node_attr_mask = torch.ones((x.shape[0], self.num_node_attr_logits))
@@ -218,18 +273,33 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
             set_node_attr_mask *= 0
         for i, n in enumerate(g.nodes):
             ad = g.nodes[n]
+            if g.degree(n) <= 1 and len(ad) == 1 and all([len(g[n][neigh]) == 0 for neigh in g.neighbors(n)]):
+                # If there's only the 'v' key left and the node is a leaf, and the edges that connect to the node have
+                # no attributes set, we can remove it
+                remove_node_mask[i] = 1
             for k, sl in zip(self.atom_attrs, self.atom_attr_slice):
+                # idx > 0 means that the attribute is not the default value
                 idx = self.atom_attr_values[k].index(ad[k]) if k in ad else 0
                 x[i, sl + idx] = 1
-                # If the attribute is already there, mask out logits
-                # (or if the attribute is a negative attribute and has been filled)
+                if k == "v":
+                    continue
+                # If the attribute
+                # - is already there (idx > 0),
+                # - or the attribute is a negative attribute and has been filled
+                # - or the attribute is a negative attribute and is not fillable (i.e. not a key of ad)
+                # then mask forward logits.
+                # For backward logits, positively mask if the attribute is there (idx > 0).
                 if k in self.negative_attrs:
                     if k in ad and idx > 0 or k not in ad:
                         s, e = self.atom_attr_logit_slice[k]
                         set_node_attr_mask[i, s:e] = 0
+                    # We don't want to make the attribute removable if it's not fillable (i.e. not a key of ad)
+                    if k in ad:
+                        remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1
                 elif k in ad:
                     s, e = self.atom_attr_logit_slice[k]
                     set_node_attr_mask[i, s:e] = 0
+                    remove_node_attr_mask[i, self.settable_atom_attrs.index(k)] = 1
             # Account for charge and explicit Hs in atom as limiting the total valence
             max_atom_valence = self._max_atom_valence[ad.get("fill_wildcard", None) or ad["v"]]
             # Special rule for Nitrogen
@@ -256,8 +326,14 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
                 s, e = self.atom_attr_logit_slice["expl_H"]
                 set_node_attr_mask[i, s:e] = 0
 
+        remove_edge_mask = torch.zeros((len(g.edges), 1))
+        for i, (u, v) in enumerate(g.edges):
+            if g.degree(u) > 1 and g.degree(v) > 1:
+                if nx.algorithms.is_connected(graph_without_edge(g, (u, v))):
+                    remove_edge_mask[i] = 1
         edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))
         set_edge_attr_mask = torch.zeros((len(g.edges), self.num_edge_attr_logits))
+        remove_edge_attr_mask = torch.zeros((len(g.edges), len(self.bond_attrs)))
         for i, e in enumerate(g.edges):
             ad = g.edges[e]
             for k, sl in zip(self.bond_attrs, self.bond_attr_slice):
@@ -267,6 +343,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
                 if k in ad:  # If the attribute is already there, mask out logits
                     s, e = self.bond_attr_logit_slice[k]
                     set_edge_attr_mask[i, s:e] = 0
+                    remove_edge_attr_mask[i, self.bond_attrs.index(k)] = 1
             # Check which bonds don't bust the valence of their atoms
             if "type" not in ad:  # Only if type isn't already set
                 sl, _ = self.bond_attr_logit_slice["type"]
@@ -293,11 +370,15 @@ def is_ok_non_edge(e):
             edge_index,
             edge_attr,
             non_edge_index=non_edge_index,
-            stop_mask=torch.ones(1, 1) if len(g) > 0 else torch.zeros(1, 1),
+            stop_mask=torch.ones((1, 1)) * (len(g.nodes) > 0),  # Can only stop if there's at least a node
            add_node_mask=add_node_mask,
             set_node_attr_mask=set_node_attr_mask,
             add_edge_mask=torch.ones((non_edge_index.shape[1], 1)),  # Already filtered by is_ok_non_edge
             set_edge_attr_mask=set_edge_attr_mask,
+            remove_node_mask=remove_node_mask,
+            remove_node_attr_mask=remove_node_attr_mask,
+            remove_edge_mask=remove_edge_mask,
+            remove_edge_attr_mask=remove_edge_attr_mask,
         )
         if self.num_rw_feat > 0:
             data.x = torch.cat([data.x, random_walk_probs(data, self.num_rw_feat, skip_odd=True)], 1)
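The contract in this file is that forward and backward actions index into separate orderings: `aidx_to_GraphAction(..., fwd=False)` reads from `bck_action_type_order`, while `GraphAction_to_aidx` searches both orders for the action's type, so each direction stays a bijection between `(type, row, col)` triples and `GraphAction`s. A property sketch of that round trip (hypothetical check, not part of the patch; `ctx` is a `MolBuildingEnvContext` and `data` the `gd.Data` of some state):

    # Hypothetical round-trip property: a valid action index should survive
    # conversion to a GraphAction and back, in either direction.
    def check_aidx_roundtrip(ctx, data, aidx, fwd=True):
        action = ctx.aidx_to_GraphAction(data, aidx, fwd=fwd)
        assert ctx.GraphAction_to_aidx(data, action) == aidx

Note also that `remove_edge_mask` only unlocks an edge when neither endpoint would be orphaned (both degrees > 1) and the graph stays connected without it, mirroring the forward invariant that every state is a connected graph.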
diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py
new file mode 100644
index 00000000..c3e8d0f9
--- /dev/null
+++ b/src/gflownet/tasks/make_rings.py
@@ -0,0 +1,90 @@
+import os
+import socket
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import torch
+from rdkit import Chem
+from rdkit.Chem.rdchem import Mol as RDMol
+from torch import Tensor
+
+from gflownet.config import Config
+from gflownet.envs.mol_building_env import MolBuildingEnvContext
+from gflownet.online_trainer import StandardOnlineTrainer
+from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
+
+
+class MakeRingsTask(GFNTask):
+    """A toy task where the reward is the number of rings in the molecule."""
+
+    def __init__(
+        self,
+        rng: np.random.Generator,
+    ):
+        self.rng = rng
+
+    def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
+        return FlatRewards(y)
+
+    def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
+        return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)}
+
+    def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
+        scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log()
+        return RewardScalar(scalar_logreward.flatten())
+
+    def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
+        rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float()
+        return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool()
+
+
+class MakeRingsTrainer(StandardOnlineTrainer):
+    def set_default_hps(self, cfg: Config):
+        cfg.hostname = socket.gethostname()
+        cfg.num_workers = 8
+        cfg.algo.global_batch_size = 64
+        cfg.algo.offline_ratio = 0
+        cfg.model.num_emb = 128
+        cfg.model.num_layers = 4
+
+        cfg.algo.method = "TB"
+        cfg.algo.max_nodes = 6
+        cfg.algo.sampling_tau = 0.9
+        cfg.algo.illegal_action_logreward = -75
+        cfg.algo.train_random_action_prob = 0.0
+        cfg.algo.valid_random_action_prob = 0.0
+        cfg.algo.tb.do_parameterize_p_b = True
+
+        cfg.replay.use = False
+
+    def setup_task(self):
+        self.task = MakeRingsTask(rng=self.rng)
+
+    def setup_env_context(self):
+        self.ctx = MolBuildingEnvContext(
+            ["C"],
+            charges=[0],  # disable charge
+            chiral_types=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED],  # disable chirality
+            num_rw_feat=0,
+            max_nodes=self.cfg.algo.max_nodes,
+            num_cond_dim=1,
+        )
+
+
+def main():
+    hps = {
+        "log_dir": "./logs/debug_run_mr4",
+        "device": "cuda",
+        "num_training_steps": 10_000,
+        "num_workers": 8,
+        "algo": {"tb": {"do_parameterize_p_b": True}},
+    }
+    os.makedirs(hps["log_dir"], exist_ok=True)
+
+    trial = MakeRingsTrainer(hps)
+    trial.print_every = 1
+    trial.run()
+
+
+if __name__ == "__main__":
+    main()
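make_rings is a deliberately tiny end-to-end exercise for the backward machinery: a carbon-only environment capped at 6 atoms, with `do_parameterize_p_b = True` so the model must produce the backward logits introduced above, and a reward that is simply RDKit's ring count. A reward sanity check mirroring `compute_flat_rewards`:

    # Mirrors compute_flat_rewards: the reward is RDKit's ring count.
    from rdkit import Chem

    benzene = Chem.MolFromSmiles("c1ccccc1")
    assert benzene.GetRingInfo().NumRings() == 1  # flat reward would be 1.0

Since the module guards `main()` with `if __name__ == "__main__"`, it can also be run directly as a script once the package is installed.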
diff --git a/tests/test_frag_env.py b/tests/test_envs.py
similarity index 76%
rename from tests/test_frag_env.py
rename to tests/test_envs.py
index deda34e2..204a17cb 100644
--- a/tests/test_frag_env.py
+++ b/tests/test_envs.py
@@ -9,9 +9,11 @@
 from gflownet.config import Config
 from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext
 from gflownet.envs.graph_building_env import GraphBuildingEnv
+from gflownet.envs.mol_building_env import MolBuildingEnvContext
+from gflownet.models import bengio2021flow
 
 
-def build_two_node_states():
+def build_two_node_states(ctx):
     # TODO: This is actually fairly generic code that will probably be reused by other tests in the future.
     # Having a proper class to handle graph-indexed hash maps would probably be good.
     graph_cache = {}
@@ -21,7 +23,6 @@ def build_two_node_states():
     # We're enumerating all states of length two, but we could've just as well randomly sampled
     # some states.
     env = GraphBuildingEnv()
-    ctx = FragMolBuildingEnvContext(max_frags=2)
 
     def g2h(g):
         gc = g.to_directed()
@@ -73,11 +74,19 @@ def expand(s, idx):
     return [graph_by_idx[i] for i in list(nx.topological_sort(mdp_graph))]
 
 
+def get_frag_env_ctx():
+    return FragMolBuildingEnvContext(max_frags=2, fragments=bengio2021flow.FRAGMENTS[:20])
+
+
+def get_atom_env_ctx():
+    return MolBuildingEnvContext(atoms=["C", "N"], expl_H_range=[0], charges=[0], max_nodes=2)
+
+
 @pytest.fixture
-def two_node_states(request):
+def two_node_states_frags(request):
     data = request.config.cache.get("frag_env/two_node_states", None)
     if data is None:
-        data = build_two_node_states()
+        data = build_two_node_states(get_frag_env_ctx())
         # pytest caches through JSON so we have to make a clean enough string
         request.config.cache.set("frag_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode())
     else:
@@ -85,13 +94,24 @@ def two_node_states(request):
     return data
 
 
-def test_backwards_mask_equivalence(two_node_states):
+@pytest.fixture
+def two_node_states_atoms(request):
+    data = request.config.cache.get("atom_env/two_node_states", None)
+    if data is None:
+        data = build_two_node_states(get_atom_env_ctx())
+        # pytest caches through JSON so we have to make a clean enough string
+        request.config.cache.set("atom_env/two_node_states", base64.b64encode(pickle.dumps(data)).decode())
+    else:
+        data = pickle.loads(base64.b64decode(data))
+    return data
+
+
+def _test_backwards_mask_equivalence(two_node_states, ctx):
    """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats
     GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is
     a different number of actions leading to the parents of any state.
     """
     env = GraphBuildingEnv()
-    ctx = FragMolBuildingEnvContext(max_frags=2)
     for i in range(1, len(two_node_states)):
         g = two_node_states[i]
         n = env.count_backward_transitions(g, check_idempotent=False)
@@ -104,7 +124,7 @@
             raise ValueError()
 
 
-def test_backwards_mask_equivalence_ipa(two_node_states):
+def _test_backwards_mask_equivalence_ipa(two_node_states, ctx):
     """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats
     GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is
     a different number of actions leading to the parents of any state.
@@ -112,7 +132,6 @@ def test_backwards_mask_equivalence_ipa(two_node_states):
     This test also accounts for idempotent actions.
     """
     env = GraphBuildingEnv()
-    ctx = FragMolBuildingEnvContext(max_frags=2)
     cfg = OmegaConf.structured(Config)
     cfg.algo.max_nodes = 2
     algo = TrajectoryBalance(env, ctx, None, cfg)
@@ -141,3 +160,19 @@ def test_backwards_mask_equivalence_ipa(two_node_states):
             equivalence_classes.append(ipa)
     if n != len(equivalence_classes):
         raise ValueError()
+
+
+def test_backwards_mask_equivalence_frag(two_node_states_frags):
+    _test_backwards_mask_equivalence(two_node_states_frags, get_frag_env_ctx())
+
+
+def test_backwards_mask_equivalence_ipa_frag(two_node_states_frags):
+    _test_backwards_mask_equivalence_ipa(two_node_states_frags, get_frag_env_ctx())
+
+
+def test_backwards_mask_equivalence_atom(two_node_states_atoms):
+    _test_backwards_mask_equivalence(two_node_states_atoms, get_atom_env_ctx())
+
+
+def test_backwards_mask_equivalence_ipa_atom(two_node_states_atoms):
+    _test_backwards_mask_equivalence_ipa(two_node_states_atoms, get_atom_env_ctx())
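The rename to test_envs.py turns the fragment-specific tests into shared helpers: the `_test_backwards_mask_equivalence*` bodies take the context as an argument, and thin `test_*` wrappers instantiate them for both the fragment and the new atom environment, each with its own cached fixture. Covering a further environment would only need a fixture plus two wrappers in the same pattern (hypothetical sketch, not part of the patch):

    # Hypothetical: extending coverage to another context follows the same shape.
    def test_backwards_mask_equivalence_my_env(two_node_states_my_env):
        _test_backwards_mask_equivalence(two_node_states_my_env, get_my_env_ctx())

Running `pytest tests/test_envs.py` exercises all four parameterizations.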