more comments and implementation notes
bengioe committed Aug 4, 2023
1 parent ef0ae83 commit 35466a7
Showing 2 changed files with 31 additions and 9 deletions.
18 changes: 18 additions & 0 deletions docs/implementation_notes.md
@@ -16,3 +16,21 @@ We separate experiment concerns in four categories:
- The Trainer class is responsible for instantiating everything, and running the training & testing loop

Typically one would set up a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`.


## Graphs

This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations.

Some notes:
- Graphs are (for now) assumed to be _undirected_. For `torch_geometric`, this is encoded by duplicating each edge in both directions, with the two copies stored contiguously. Models still produce only one logit (row) per edge, so the policy is still assumed to operate on undirected graphs.
- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding.
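
The duplicated-edge layout described in these notes can be illustrated with a minimal sketch. It uses plain Python lists in place of the `torch_geometric` tensors, and the helper names (`duplicate_edges`, `logit_row_to_edge`, `edge_to_logit_row`) are hypothetical, not part of the library. It assumes each edge is immediately followed by its reverse, as in `[[a,b], [b,a], [c,d], [d,c], ...].T`.

```python
from typing import List, Sequence, Tuple


def duplicate_edges(edges: Sequence[Tuple[int, int]]) -> List[List[int]]:
    """Build a 2 x (2E) edge index where each edge is followed by its reverse."""
    src: List[int] = []
    dst: List[int] = []
    for a, b in edges:
        src += [a, b]
        dst += [b, a]
    return [src, dst]


def logit_row_to_edge(edge_index: List[List[int]], row: int) -> Tuple[int, int]:
    """Edge logits have one row per undirected edge, so multiply by two."""
    col = row * 2
    return edge_index[0][col], edge_index[1][col]


def edge_to_logit_row(edge_index: List[List[int]], a: int, b: int) -> int:
    """Find the column holding (a, b) in either direction, then floor-divide by two."""
    pairs = list(zip(edge_index[0], edge_index[1]))
    return pairs.index((a, b)) // 2


# Hypothetical example graph: a path 0 - 1 - 2 - 3.
edge_index = duplicate_edges([(0, 1), (1, 2), (2, 3)])
first_edge = logit_row_to_edge(edge_index, 1)
row_forward = edge_to_logit_row(edge_index, 1, 2)
row_reverse = edge_to_logit_row(edge_index, 2, 1)
```

Because both directions of an edge sit in adjacent columns, looking up either `(1, 2)` or `(2, 1)` lands on the same logit row after the floor division.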


### Graph policies & graph action categoricals

The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains logic to sample from concatenated sets of logits across a minibatch.

Consider, for example, the `AddNode` and `SetEdgeAttr` actions: one applies to nodes and the other to edges. An efficient way to produce logits for these actions is to take the node and edge embeddings and project them (e.g. via an MLP) to tensors of shape `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` respectively. We thus obtain a list of tensors representing the logits of the different actions, but the rows of each tensor are mixed between the graphs in the minibatch, so one cannot simply apply a `softmax` operator to a tensor.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
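
To make the normalization problem concrete, here is a minimal sketch of a per-graph log-softmax over several logit tables. It is not the `GraphActionCategorical` implementation: it uses plain Python lists instead of tensors, and the function name `graph_log_softmax` and the `batches` argument (one graph index per logit row) are assumptions made for illustration.

```python
import math
from typing import List


def graph_log_softmax(
    logits: List[List[List[float]]],
    batches: List[List[int]],
    num_graphs: int,
) -> List[List[List[float]]]:
    """Normalize logits per graph, across all action types at once.

    logits[k] is a 2D table for action type k; batches[k][i] is the index
    of the graph that row i of logits[k] belongs to.
    """
    # Per-graph maximum, for numerical stability.
    maxes = [-math.inf] * num_graphs
    for rows, batch in zip(logits, batches):
        for row, g in zip(rows, batch):
            maxes[g] = max(maxes[g], max(row))
    # Per-graph sum of exponentials over every action type.
    sums = [0.0] * num_graphs
    for rows, batch in zip(logits, batches):
        for row, g in zip(rows, batch):
            sums[g] += sum(math.exp(x - maxes[g]) for x in row)
    # Graphs without any logits get a dummy normalizer of 0.
    log_z = [m + math.log(s) if s > 0 else 0.0 for m, s in zip(maxes, sums)]
    return [
        [[x - log_z[g] for x in row] for row, g in zip(rows, batch)]
        for rows, batch in zip(logits, batches)
    ]


# Hypothetical minibatch: graph 0 has two nodes and one edge,
# graph 1 has one node and one edge.
node_logits = [[0.0, 1.0], [2.0, 0.5], [1.0, 1.0]]  # (n_nodes, n_node_actions)
edge_logits = [[0.1, 0.2, 0.3], [0.0, 0.0, 0.0]]    # (n_edges, n_edge_actions)
node_batch = [0, 0, 1]
edge_batch = [0, 1]

log_probs = graph_log_softmax(
    [node_logits, edge_logits], [node_batch, edge_batch], num_graphs=2
)
```

The returned tables keep the shapes of the inputs, and for each graph the probabilities of all its node and edge actions together sum to one.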
22 changes: 13 additions & 9 deletions src/gflownet/envs/mol_building_env.py
@@ -182,7 +182,10 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd:
a, b = g.non_edge_index[:, act_row]
return GraphAction(t, source=a.item(), target=b.item())
elif t is GraphActionType.SetEdgeAttr:
- a, b = g.edge_index[:, act_row * 2] # Edges are duplicated to get undirected GNN, deduplicated for logits
+ # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e.
+ # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one
+ # to another we can safely divide or multiply by two.
+ a, b = g.edge_index[:, act_row * 2]
attr, val = self.bond_attr_logit_map[act_col]
return GraphAction(t, source=a.item(), target=b.item(), attr=attr, value=val)
elif t is GraphActionType.RemoveNode:
@@ -191,10 +194,10 @@ def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd:
attr = self.settable_atom_attrs[act_col]
return GraphAction(t, source=act_row, attr=attr)
elif t is GraphActionType.RemoveEdge:
- a, b = g.edge_index[:, act_row * 2]
+ a, b = g.edge_index[:, act_row * 2] # see note above about edge_index
return GraphAction(t, source=a.item(), target=b.item())
elif t is GraphActionType.RemoveEdgeAttr:
- a, b = g.edge_index[:, act_row * 2]
+ a, b = g.edge_index[:, act_row * 2] # see note above about edge_index
attr = self.bond_attrs[act_col]
return GraphAction(t, source=a.item(), target=b.item(), attr=attr)

@@ -228,12 +231,10 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
).argmax()
col = 0
elif action.action is GraphActionType.SetEdgeAttr:
- # Here the edges are duplicated, both (i,j) and (j,i) are in edge_index
- # so no need for a double check.
- # row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1) +
- # (g.edge_index.T == torch.tensor([(action.target, action.source)])).prod(1)).argmax()
+ # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e.
+ # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one
+ # to another we can safely divide or multiply by two.
row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax()
- # Because edges are duplicated but logits aren't, divide by two
row = row.div(2, rounding_mode="floor") # type: ignore
col = (
self.bond_attr_values[action.attr].index(action.value) - 1 + self.bond_attr_logit_slice[action.attr][0]
@@ -246,7 +247,10 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
col = self.settable_atom_attrs.index(action.attr)
elif action.action is GraphActionType.RemoveEdge:
row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax()
- row = int(row) // 2 # edges are duplicated, but edge logits are not
+ # In order to form an undirected graph for torch_geometric, edges are duplicated, in order (i.e.
+ # g.edge_index = [[a,b], [b,a], [c,d], [d,c], ...].T), but edge logits are not. So to go from one
+ # to another we can safely divide or multiply by two.
+ row = int(row) // 2
col = 0
elif action.action is GraphActionType.RemoveEdgeAttr:
row = (g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1).argmax()