Commit cc12abc

merge & squash to refresh branch

bengioe committed Aug 4, 2023
1 parent ec857a5 commit cc12abc

Showing 12 changed files with 1,175 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/implementation_notes.md
@@ -33,4 +33,4 @@ The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`

Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor.

-The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
+The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
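
As an illustration of the per-graph normalization this class performs, here is a minimal sketch under assumed shapes (not part of the diff; assumes `torch_scatter` is installed):

```python
import torch
from torch_scatter import scatter_logsumexp

# Hypothetical batch: 3 graphs with 7 nodes and 5 edges in total.
node_logits = torch.randn(7, 2)  # (n_nodes, n_node_actions)
edge_logits = torch.randn(5, 1)  # (n_edges, n_edge_actions)
node_batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # graph id of each node
edge_batch = torch.tensor([0, 0, 1, 2, 2])        # graph id of each edge

# logsumexp over actions within each row, then per-graph via scatter, then
# across action types: this yields each graph's log-partition function.
per_type = [
    scatter_logsumexp(logits.logsumexp(dim=1), batch_ids, dim=0, dim_size=3)
    for logits, batch_ids in [(node_logits, node_batch), (edge_logits, edge_batch)]
]
log_Z = torch.stack(per_type).logsumexp(dim=0)  # shape (3,)

# Properly normalized per-graph log-probabilities for every action.
node_log_probs = node_logits - log_Z[node_batch][:, None]
edge_log_probs = edge_logits - log_Z[edge_batch][:, None]
```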
4 changes: 3 additions & 1 deletion src/gflownet/algo/graph_sampling.py
@@ -113,7 +113,9 @@ def not_done(lst):
]
if self.sample_temp != 1:
sample_cat = copy.copy(fwd_cat)
-                sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
+                sample_cat.logits = [
+                    i * m / self.sample_temp - 1000 * (1 - m) for i, m in zip(fwd_cat.logits, fwd_cat.masks)
+                ]
actions = sample_cat.sample()
else:
actions = fwd_cat.sample()
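
The fix above matters because masked (invalid) logits are typically held at a large negative constant; dividing the raw logits by the temperature shrinks that constant too, so invalid actions can regain probability at high temperatures. A small illustrative example (values made up, not from the repo):

```python
import torch

logits = torch.tensor([2.0, 0.5, -1000.0])  # last action is invalid (masked)
mask = torch.tensor([1.0, 1.0, 0.0])
temp = 500.0  # a large exploration temperature

naive = torch.softmax(logits / temp, dim=0)
# -1000 / 500 = -2, so the invalid action gets ~6% probability here.
remasked = torch.softmax(logits * mask / temp - 1000 * (1 - mask), dim=0)
# Re-applying the mask after scaling keeps it at effectively zero.
```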
15 changes: 4 additions & 11 deletions src/gflownet/algo/trajectory_balance.py
@@ -6,7 +6,7 @@
import torch.nn as nn
import torch_geometric.data as gd
from torch import Tensor
-from torch_scatter import scatter, scatter_sum
+from torch_scatter import scatter, scatter_sum, scatter_logsumexp

from gflownet.algo.graph_sampling import GraphSampler
from gflownet.config import Config
@@ -309,22 +309,15 @@ def compute_batch_losses(
# Indicate that the `batch` corresponding to each action is the above
ip_log_prob = fwd_cat.log_prob(batch.ip_actions, batch=ip_batch_idces)
# take the logsumexp (because we want to sum probabilities, not log probabilities)
-            # TODO: numerically stable version:
-            p = scatter(ip_log_prob.exp(), ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum")
-            # As a (reasonable) band-aid, ignore p < 1e-30, this will prevent underflows due to
-            # scatter(small number) = 0 on CUDA
-            log_p_F = p.clamp(1e-30).log()
+            log_p_F = scatter_logsumexp(ip_log_prob, ip_batch_idces, dim=0, dim_size=batch_idx.shape[0])

if self.cfg.do_parameterize_p_b:
# Now we repeat this but for the backward policy
bck_ip_batch_idces = torch.arange(batch.bck_ip_lens.shape[0], device=dev).repeat_interleave(
batch.bck_ip_lens
)
bck_ip_log_prob = bck_cat.log_prob(batch.bck_ip_actions, batch=bck_ip_batch_idces)
-                bck_p = scatter(
-                    bck_ip_log_prob.exp(), bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum"
-                )
-                log_p_B = bck_p.clamp(1e-30).log()
+                log_p_B = scatter_logsumexp(bck_ip_log_prob, bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0])
else:
# Else just naively take the logprob of the actions we took
log_p_F = fwd_cat.log_prob(batch.actions)
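
For context on the `scatter_logsumexp` change above (an illustrative sketch, not part of the commit): working in log-space avoids the float32 underflow that the old clamp band-aid papered over.

```python
import torch
from torch_scatter import scatter, scatter_logsumexp

# Log-probabilities small enough that exp() flushes to zero in float32.
log_prob = torch.tensor([-105.0, -106.0, -3.0, -4.0])
batch = torch.tensor([0, 0, 1, 1])  # trajectory index of each action

# Old approach: exp(-105) == 0.0 in float32, so after the clamp the result
# is floored at log(1e-30) ~ -69.1 instead of the true value ~ -104.7.
p = scatter(log_prob.exp(), batch, dim=0, dim_size=2, reduce="sum")
old = p.clamp(1e-30).log()

# New approach: stays in log-space and returns ~ -104.7 for trajectory 0.
new = scatter_logsumexp(log_prob, batch, dim=0, dim_size=2)
```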
@@ -496,7 +489,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=dev), traj_lengths]), 0).long()
total_loss = torch.zeros(num_trajs, device=dev)
ar = torch.arange(max_len, device=dev)
-        car = torch.cumsum(ar, 0)
+        car = torch.cumsum(ar, 0) if self.length_normalize_losses else torch.ones_like(ar)
F_and_R = torch.cat([F, R])
R_start = F.shape[0]
for ep in range(traj_lengths.shape[0]):
1 change: 1 addition & 0 deletions src/gflownet/config.py
@@ -81,6 +81,7 @@ class Config:
"""

log_dir: str = MISSING
+    log_sampled_data: bool = True
device: str = "cuda"
seed: int = 0
validate_every: int = 1000
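
The `log_sampled_data` flag is added here but not yet consumed; the warning and commented-out guard in `sampling_iterator.py` below suggest the intended wiring is roughly this (a hypothetical sketch, not in the commit):

```python
from gflownet.config import Config

cfg = Config(log_dir="./logs", log_sampled_data=False)
# Hypothetical gating: a None log_dir disables the SQLite logging path.
log_dir = cfg.log_dir if cfg.log_sampled_data else None
```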
25 changes: 20 additions & 5 deletions src/gflownet/data/sampling_iterator.py
@@ -1,5 +1,6 @@
import os
import sqlite3
+import traceback
from collections.abc import Iterable
from copy import deepcopy
from typing import Callable, List
@@ -11,6 +12,7 @@
from rdkit import Chem, RDLogger
from torch.utils.data import Dataset, IterableDataset

+from gflownet.config import Config
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.envs.graph_building_env import GraphActionCategorical

@@ -112,9 +114,14 @@ def __init__(
# don't want to initialize per-worker things just yet, such as where the log the worker writes
# to. This must be done in __iter__, which is called by the DataLoader once this instance
# has been copied into a new python process.
-        self.log_dir = log_dir
+        import warnings
+
+        warnings.warn("Fix dependency on cfg.log_sampled_data")
+        self.log_dir = log_dir  # if cfg.log_sampled_data else None
self.log = SQLiteLog()
self.log_hooks: List[Callable] = []
+        # TODO: make this a proper flag / make a separate class for logging sampled molecules to a SQLite db
+        self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") and self.log_dir is not None

def add_log_hook(self, hook: Callable):
self.log_hooks.append(hook)
@@ -158,6 +165,14 @@ def __len__(self):
return len(self.data)

def __iter__(self):
+        try:
+            for x in self.iterator():
+                yield x
+        except Exception as e:
+            traceback.print_exc()
+            raise e
+
+    def iterator(self):
worker_info = torch.utils.data.get_worker_info()
self._wid = worker_info.id if worker_info is not None else 0
# Now that we know we are in a worker instance, we can initialize per-worker things
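
The `__iter__` wrapper added above follows a common pattern: exceptions raised inside a `DataLoader` worker process often surface only as an opaque crash, so printing the traceback in the worker before re-raising makes failures debuggable. A self-contained sketch of the same pattern (illustrative, simplified):

```python
import traceback

from torch.utils.data import DataLoader, IterableDataset


class SafeIterable(IterableDataset):
    """Re-raise worker exceptions after printing their full traceback."""

    def __iter__(self):
        try:
            yield from self.iterator()
        except Exception as e:
            traceback.print_exc()  # otherwise the worker can die silently
            raise e

    def iterator(self):
        for i in range(3):
            yield i


batch = next(iter(DataLoader(SafeIterable(), batch_size=2, num_workers=0)))
```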
@@ -189,9 +204,7 @@ def __iter__(self):
else: # If we're not sampling the conditionals, then the idcs refer to listed preferences
num_online = num_offline
num_offline = 0
-                cond_info = self.task.encode_conditional_information(
-                    steer_info=torch.stack([self.data[i] for i in idcs])
-                )
+                cond_info = self.task.encode_conditional_information(torch.stack([self.data[i] for i in idcs]))
trajs, flat_rewards = [], []

# Sample some on-policy data
@@ -250,14 +263,16 @@ def __iter__(self):
# note: we convert back into natural rewards for logging purposes
# (allows to take averages and plot in objective space)
# TODO: implement that per-task (in case they don't apply the same beta and log transformations)
-            rewards = torch.exp(log_rewards / cond_info["beta"])
+            rewards = torch.exp(log_rewards / (cond_info["beta"] if "beta" in cond_info else 1.0))
if num_online > 0 and self.log_dir is not None:
self.log_generated(
deepcopy(trajs[num_offline:]),
deepcopy(rewards[num_offline:]),
deepcopy(flat_rewards[num_offline:]),
{k: v[num_offline:] for k, v in deepcopy(cond_info).items()},
)

extra_info = {}
if num_online > 0:
for hook in self.log_hooks:
extra_info.update(
198 changes: 198 additions & 0 deletions src/gflownet/envs/basic_graph_ctx.py
@@ -0,0 +1,198 @@
from typing import Dict, List, Tuple

import networkx as nx
import torch
import torch_geometric.data as gd
from networkx.algorithms.isomorphism import is_isomorphic as nx_is_isomorphic

from gflownet.envs.graph_building_env import (
Graph,
GraphAction,
GraphActionType,
GraphBuildingEnvContext,
graph_without_edge,
)
from gflownet.utils.graphs import random_walk_probs


def hashg(g):
return nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(g, node_attr="v")


def is_isomorphic(u, v):
return nx_is_isomorphic(u, v, lambda a, b: a == b, lambda a, b: a == b)


class BasicGraphContext(GraphBuildingEnvContext):
"""
A basic graph generation context.
This simple environment context is designed to be used to test implementations. It only allows for AddNode and
AddEdge actions, and is meant to be used within the BasicGraphTask to generate graphs of up to 7 nodes with
only two possible node attributes, making the state space a total of ~70k states (which is nicely enumerable
and allows us to compute p_theta(x) exactly for all x in the state space).
"""

def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=False):
self.max_nodes = max_nodes
self.output_gid = output_gid

self.node_attr_values = {
"v": [0, 1], # Imagine this is as colors
}
self._num_rw_feat = 8

self.num_new_node_values = len(self.node_attr_values["v"])
self.num_node_attr_logits = None
self.num_node_dim = self.num_new_node_values + 1 + self._num_rw_feat
self.num_node_attrs = 1
self.num_edge_attr_logits = None
self.num_edge_attrs = 0
self.num_cond_dim = num_cond_dim
self.num_edge_dim = 1
self.edges_are_duplicated = True
self.edges_are_unordered = True

# Order in which models have to output logits
self.action_type_order = [
GraphActionType.Stop,
GraphActionType.AddNode,
GraphActionType.AddEdge,
]
self.bck_action_type_order = [
GraphActionType.RemoveNode,
GraphActionType.RemoveEdge,
]
self.device = torch.device("cpu")
self.graph_data = graph_data
        self.hash_to_graphs: Dict[str, List[Tuple[Graph, int]]] = {}
if graph_data is not None:
states_hash = [hashg(i) for i in graph_data]
for i, h, g in zip(range(len(graph_data)), states_hash, graph_data):
self.hash_to_graphs[h] = self.hash_to_graphs.get(h, list()) + [(g, i)]

def get_graph_idx(self, g, default=None):
h = hashg(g)
if h not in self.hash_to_graphs and default is not None:
return default
bucket = self.hash_to_graphs[h]
if len(bucket) == 1:
return bucket[0][1]
for i in bucket:
if is_isomorphic(i[0], g):
return i[1]
if default is not None:
return default
raise ValueError(g)

def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True):
"""Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction"""
act_type, act_row, act_col = [int(i) for i in action_idx]
if fwd:
t = self.action_type_order[act_type]
else:
t = self.bck_action_type_order[act_type]

if t is GraphActionType.Stop:
return GraphAction(t)
elif t is GraphActionType.AddNode:
return GraphAction(t, source=act_row, value=self.node_attr_values["v"][act_col])
elif t is GraphActionType.AddEdge:
a, b = g.non_edge_index[:, act_row]
return GraphAction(t, source=a.item(), target=b.item())
elif t is GraphActionType.RemoveNode:
return GraphAction(t, source=act_row)
elif t is GraphActionType.RemoveEdge:
a, b = g.edge_index[:, act_row * 2]
return GraphAction(t, source=a.item(), target=b.item())

def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]:
"""Translate a GraphAction to an index tuple"""
if action.action is GraphActionType.Stop:
row = col = 0
type_idx = self.action_type_order.index(action.action)
elif action.action is GraphActionType.AddNode:
row = action.source
col = self.node_attr_values["v"].index(action.value)
type_idx = self.action_type_order.index(action.action)
elif action.action is GraphActionType.AddEdge:
# Here we have to retrieve the index in non_edge_index of an edge (s,t)
# that's also possibly in the reverse order (t,s).
# That's definitely not too efficient, can we do better?
row = (
(g.non_edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)
+ (g.non_edge_index.T == torch.tensor([(action.target, action.source)])).prod(1)
).argmax()
col = 0
type_idx = self.action_type_order.index(action.action)
elif action.action is GraphActionType.RemoveNode:
row = action.source
col = 0
type_idx = self.bck_action_type_order.index(action.action)
elif action.action is GraphActionType.RemoveEdge:
row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax()
row = int(row) // 2 # edges are duplicated, but edge logits are not
col = 0
type_idx = self.bck_action_type_order.index(action.action)
return (type_idx, int(row), int(col))

def graph_to_Data(self, g: Graph) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance"""
x = torch.zeros((max(1, len(g.nodes)), self.num_node_dim - self._num_rw_feat))
x[0, -1] = len(g.nodes) == 0
remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0)
for i, n in enumerate(g.nodes):
ad = g.nodes[n]
x[i, self.node_attr_values["v"].index(ad["v"])] = 1
if g.degree(n) <= 1:
remove_node_mask[i] = 1

remove_edge_mask = torch.zeros((len(g.edges), 1))
for i, (u, v) in enumerate(g.edges):
if g.degree(u) > 1 and g.degree(v) > 1:
if nx.algorithms.is_connected(graph_without_edge(g, (u, v))):
remove_edge_mask[i] = 1
edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))
edge_index = (
torch.tensor([e for i, j in g.edges for e in [(i, j), (j, i)]], dtype=torch.long).reshape((-1, 2)).T
)
gc = nx.complement(g)
non_edge_index = torch.tensor([i for i in gc.edges], dtype=torch.long).reshape((-1, 2)).T
gid = self.get_graph_idx(g) if self.output_gid else 0

return self._preprocess(
gd.Data(
x,
edge_index,
edge_attr,
non_edge_index=non_edge_index,
stop_mask=torch.ones((1, 1)),
add_node_mask=torch.ones((x.shape[0], self.num_new_node_values)) * (len(g) < self.max_nodes),
add_edge_mask=torch.ones((non_edge_index.shape[1], 1)),
remove_node_mask=remove_node_mask,
remove_edge_mask=remove_edge_mask,
gid=gid,
)
)

def _preprocess(self, g: gd.Data) -> gd.Data:
if self._num_rw_feat > 0:
g.x = torch.cat([g.x, random_walk_probs(g, self._num_rw_feat, skip_odd=True)], 1)
return g

def collate(self, graphs: List[gd.Data]):
"""Batch Data instances"""
return gd.Batch.from_data_list(graphs, follow_batch=["edge_index", "non_edge_index"])

def mol_to_graph(self, obj: Graph) -> Graph:
return obj # This is already a graph

def graph_to_mol(self, g: Graph) -> Graph:
# idem
return g

def is_sane(self, g: Graph) -> bool:
return True

def get_object_description(self, g: Graph, is_valid: bool) -> str:
return str(self.get_graph_idx(g, -1))
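
Since `basic_graph_ctx.py` is entirely new, a minimal usage sketch may help (illustrative, assuming `Graph` is the networkx-backed graph class from `graph_building_env`):

```python
from gflownet.envs.basic_graph_ctx import BasicGraphContext
from gflownet.envs.graph_building_env import Graph

ctx = BasicGraphContext(max_nodes=7)

# A two-node graph with the "v" node attribute and a single edge.
g = Graph()
g.add_node(0, v=0)
g.add_node(1, v=1)
g.add_edge(0, 1)

data = ctx.graph_to_Data(g)        # torch_geometric Data with action masks
batch = ctx.collate([data, data])  # batching follows edge/non-edge indices
```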
3 changes: 2 additions & 1 deletion src/gflownet/envs/graph_building_env.py
@@ -331,7 +331,7 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[Graph, GraphAction]]:
# TODO: should this be a method of GraphBuildingEnv? handle set_node_attr flags and so on?
gn = Graph()
# Choose an arbitrary starting point, add to the stack
-    stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)]
+    stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)] if len(g.nodes) > 0 else []
traj = []
# This map keeps track of node labels in gn, since we have to start from 0
relabeling_map: Dict[int, int] = {}
@@ -777,6 +777,7 @@ class GraphBuildingEnvContext:
"""A context class defines what the graphs are, how they map to and from data"""

device: torch.device
+    num_cond_dim: int = 0

def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
"""Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction