Updates from conformer #338

Open · wants to merge 7 commits into base: main
1 change: 0 additions & 1 deletion .gitignore
@@ -10,7 +10,6 @@ sh/
*.txt
.vscode/
external/
playground/
!requirements.txt
!docs/requirements-docs.txt
.DS_Store
1 change: 1 addition & 0 deletions config/evaluator/base.yaml
@@ -1,5 +1,6 @@
_target_: gflownet.evaluator.base.BaseEvaluator

reward_sampling_method: rejection
# config formerly from logger.test
first_it: True
period: 100
11 changes: 10 additions & 1 deletion gflownet/envs/README.md
@@ -14,7 +14,7 @@ Note that the mask of invalid actions indeed flags _invalid_ actions, as opposed

## Buffer, train data and test data

A train and a test set can be created at the beginning of training. The train set may be used to sample offline (backward) trajectories. The test set may be used to compute metrics during and after training. These sets may be created in different ways, specified by the configuration variables `env.buffer.train.type` and `env.buffer.test.type`. Options for the data set `type` are
A train and a test set can be created at the beginning of training. The train set may be used to sample offline (backward) trajectories. The test set may be used to compute metrics during and after training (e.g. JSD, correlation). These sets may be created in different ways, specified by the configuration variables `env.buffer.train.type` and `env.buffer.test.type`. Options for the data set `type` are

- `all`: all terminating states in the output space $\mathcal{X}$ will be added - Convenient but only feasible for small, synthetic environments like the hyper-grid.
- `grid`: a grid of points in the output space $\mathcal{X}$ - Only available in certain environments where obtaining a grid of points is meaningful. This mode also requires specifying the number of points via `env.buffer.<train/test>.n`.
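
For illustration, a buffer configuration combining these options might look like the sketch below (the key layout is assumed from the `env.buffer.*` variables above; the values are arbitrary):

```yaml
env:
  buffer:
    train:
      type: grid  # offline trajectories sampled from a grid of points
      n: 1000     # number of grid points, required when type is grid
    test:
      type: all   # all terminating states; feasible only for small envs
```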
@@ -52,3 +52,12 @@ To use the replay buffer (once enabled) for backward sampling, one can specify `
:::{tip}
You can use [MyST](https://myst-parser.readthedocs.io/en/latest/syntax/admonitions.html) in the documentation. This is expected to fail on GitHub.
:::

## Evaluator
The evaluator's parameters define which method is used for sampling from the reward (`nested` or `rejection` sampling) and how many points will be sampled (`evaluator.n`). `evaluator.n_grid` is used only for plotting and defines, if applicable, the number of grid points for visualizing KDEs.
```yaml
evaluator:
reward_sampling_method: nested
n_grid: 1000 # number of grid points to visualize KDEs
n: 1000 # number of samples from reward and from gfn
```
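
Of the two methods, `rejection` is the default (see `config/evaluator/base.yaml` above). As a rough, generic sketch of what rejection sampling from an unnormalized reward involves (not the repository's actual `sample_from_reward` implementation; `reward_fn`, `propose` and `reward_max` are hypothetical stand-ins):

```python
import numpy as np

def rejection_sample(reward_fn, propose, reward_max, n_samples):
    """Draw samples with density proportional to reward_fn.

    Assuming propose() draws uniformly from the sample space, each
    candidate is accepted with probability reward_fn(x) / reward_max,
    so accepted points follow the normalized reward distribution.
    """
    samples = []
    while len(samples) < n_samples:
        x = propose()  # candidate from the uniform proposal
        if np.random.uniform() < reward_fn(x) / reward_max:
            samples.append(x)
    return np.stack(samples)
```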
12 changes: 5 additions & 7 deletions gflownet/envs/alaninedipeptide.py
@@ -7,9 +7,9 @@
from torchtyping import TensorType

from gflownet.envs.ctorus import ContinuousTorus
from gflownet.utils.molecule import constants
from gflownet.utils.molecule.atom_positions_dataset import AtomPositionsDataset
from gflownet.utils.molecule.conformer_base import ConformerBase
from gflownet.utils.molecule.constants import AD_FREE_TAS, AD_SMILES
from gflownet.utils.molecule.datasets import AtomPositionsDataset
from gflownet.utils.molecule.rdkit_conformer import RDKitConformer


class AlanineDipeptide(ContinuousTorus):
@@ -26,9 +26,7 @@ def __init__(
path_to_dataset, url_to_dataset
)
atom_positions = self.atom_positions_dataset.sample()
self.conformer = ConformerBase(
atom_positions, constants.ad_smiles, constants.ad_free_tas
)
self.conformer = RDKitConformer(atom_positions, AD_SMILES, AD_FREE_TAS)
n_dim = len(self.conformer.freely_rotatable_tas)
super().__init__(**kwargs)
self.sync_conformer_with_state()
@@ -62,7 +60,7 @@ def states2proxy(
-------
A numpy array containing all the states in the batch.
"""
return super().states2proxy(states).numpy()
return super().states2proxy(states).detach().cpu().numpy()


if __name__ == "__main__":
34 changes: 31 additions & 3 deletions gflownet/envs/base.py
@@ -16,7 +16,14 @@
from torch.distributions import Categorical
from torchtyping import TensorType

from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat
from gflownet.utils.common import (
copy,
set_device,
set_float_precision,
tbool,
tfloat,
torch2np,
)

CMAP = mpl.colormaps["cividis"]
"""
@@ -48,7 +55,7 @@ def __init__(
# Call reset() to set initial state, done, n_actions
self.reset()
# Device
self.device = set_device(device)
self.set_device(set_device(device))

[Review comment] Why this change? Is this correct? :/

# Float precision
self.float = set_float_precision(float_precision)
# Flag to skip checking if action is valid (computing mask) before step
@@ -72,6 +79,17 @@
self.policy_output_dim = len(self.fixed_policy_output)
self.policy_input_dim = len(self.state2policy())

def set_device(self, device: torch.device):
"""
Set the device of the environment.

Parameters
----------
device : torch.device
The device to set the environment to.
"""
self.device = device

@abstractmethod
def get_action_space(self):
"""
@@ -757,6 +775,15 @@ def traj2readable(self, traj=None):
"""
return str(traj).replace("(", "[").replace(")", "]").replace(",", "")

def states2kde(
self, states: Union[List, TensorType["batch", "state_dim"]]
) -> Union[List, npt.NDArray, TensorType["batch", "kde_dim"]]:

[Review comment] I believe the return type is always npt.NDArray isn't it?

"""
Converts a batch of states into a batch of states suitable for the KDE computations.
"""
states_kde = self.states2proxy(states)
return torch2np(states_kde)
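
As context for the review thread above, a sketch of how `states2kde` is consumed downstream (mirroring the evaluator changes later in this diff; `env`, `batch` and the kernel settings are illustrative stand-ins):

```python
# states2kde applies states2proxy and then torch2np, so KDE utilities
# such as fit_kde receive numpy arrays and need no conversion of their own.
x_sampled = env.states2kde(batch.get_terminating_states())
kde_pred = env.fit_kde(x_sampled, kernel="gaussian", bandwidth=0.1)
```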

def reset(self, env_id: Union[int, str] = None):
"""
Resets the environment.
@@ -1249,6 +1276,7 @@ def top_k_metrics_and_plots(

return metrics, figs, fig_names

@torch.no_grad()
def plot_reward_distribution(
self, states=None, scores=None, ax=None, title=None, proxy=None, **kwargs
):
@@ -1269,7 +1297,7 @@
states_proxy = self.states2proxy(states)
scores = self.proxy(states_proxy)
if isinstance(scores, TensorType):
scores = scores.cpu().detach().numpy()
scores = scores.detach().cpu().numpy()
ax.hist(scores)
ax.set_title(title)
ax.set_ylabel("Number of Samples")
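A pattern recurring throughout this PR is normalizing tensor-to-numpy conversion to `tensor.detach().cpu().numpy()`. A minimal, standalone illustration of why the order matters (generic PyTorch, not repository code):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2

# y.numpy() would raise a RuntimeError: tensors that require grad must
# be detached first. detach() drops the autograd graph, cpu() moves the
# data to host memory (a no-op for CPU tensors), and numpy() then
# returns an ndarray sharing that memory.
arr = y.detach().cpu().numpy()
print(arr)
```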
4 changes: 0 additions & 4 deletions gflownet/envs/cube.py
@@ -1449,7 +1449,6 @@ def fit_kde(
bandwidth : float
The bandwidth of the kernel.
"""
samples = torch2np(samples)
[Review comment, Collaborator] Is this change because sklearn supports tensor types when fitting the KernelDensity?

[Reply, Collaborator Author] This is because `env.states2kde(states)` is always called before calling `fit_kde`, and `torch2np` happens there.
return KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples)

def plot_reward_samples(
@@ -1489,8 +1488,6 @@ def plot_reward_samples(
"""
if self.n_dim != 2:
return None
samples = torch2np(samples)
samples_reward = torch2np(samples_reward)
rewards = torch2np(rewards)
# Create mesh grid from samples_reward
n_per_dim = int(np.sqrt(samples_reward.shape[0]))
@@ -1543,7 +1540,6 @@ def plot_kde(
"""
if self.n_dim != 2:
return None
samples = torch2np(samples)
# Create mesh grid from samples
n_per_dim = int(np.sqrt(samples.shape[0]))
assert n_per_dim**2 == samples.shape[0]
7 changes: 4 additions & 3 deletions gflownet/envs/htorus.py
@@ -54,6 +54,7 @@ def __init__(
"vonmises_mean": 0.0,
"vonmises_concentration": 0.001,
},
reward_sampling_method="rejection",
**kwargs,
):
assert n_dim > 0
@@ -74,6 +75,9 @@
self.source = self.source_angles + [0]
# End-of-sequence action: (n_dim, 0)
self.eos = (self.n_dim, 0)

self.reward_sampling_method = reward_sampling_method

# Base class init
super().__init__(
fixed_distr_params=fixed_distr_params,
@@ -556,7 +560,6 @@ def fit_kde(
bandwidth : float
The bandwidth of the kernel.
"""
samples = torch2np(samples)
samples_aug = self.augment_samples(samples)
kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples_aug)
return kde
@@ -606,7 +609,6 @@ def plot_reward_samples(
"""
if self.n_dim != 2:
return None
samples = torch2np(samples)
rewards = torch2np(rewards)
n_per_dim = int(np.sqrt(rewards.shape[0]))
assert n_per_dim**2 == rewards.shape[0]
@@ -677,7 +679,6 @@ def plot_kde(
"""
if self.n_dim != 2:
return None
samples = torch2np(samples)
# Create mesh grid from samples
n_per_dim = int(np.sqrt(samples.shape[0]))
assert n_per_dim**2 == samples.shape[0]
4 changes: 2 additions & 2 deletions gflownet/envs/tetris.py
@@ -317,7 +317,7 @@ def state2readable(self, state: Optional[TensorType["height", "width"]] = None):
if isinstance(state, tuple):
readable = str(np.stack(state))
else:
readable = str(state.cpu().numpy())
readable = str(state.detach().cpu().numpy())
readable = readable.replace("[[", "[").replace("]]", "]").replace("\n ", "\n")
return readable

@@ -581,7 +581,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2):
linewidth : int
The width of the separation between cells, in pixels.
"""
board = board.clone().numpy()
board = board.detach().clone().numpy()
height = board.shape[0] * cellsize
width = board.shape[1] * cellsize
board_img = 128 * np.ones(
2 changes: 1 addition & 1 deletion gflownet/envs/tree.py
@@ -882,7 +882,7 @@ def state2readable(self, state=None):
"""
if state is None:
state = self.state.clone().detach()
state = state.cpu().numpy()
state = state.detach().cpu().numpy()
readable = ""
for idx in range(self.n_nodes):
attributes = self._attributes_to_readable(state[idx])
32 changes: 21 additions & 11 deletions gflownet/evaluator/base.py
@@ -34,8 +34,7 @@ class methods to instantiate an evaluator.


class BaseEvaluator(AbstractEvaluator):

def __init__(self, gfn_agent=None, **config):
def __init__(self, gfn_agent=None, reward_sampling_method="rejection", **config):
"""
Base evaluator class for GFlowNetAgent.

@@ -56,6 +55,7 @@ def __init__(self, gfn_agent=None, **config):
details about other methods and attributes, including the
:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.__init__`.
"""
self.reward_sampling_method = reward_sampling_method
super().__init__(gfn_agent, **config)

def define_new_metrics(self):
@@ -223,6 +223,7 @@ def eval_top_k(self, it, gfn_states=None, random_states=None):
"summary": summary,
}

@torch.no_grad()
def compute_log_prob_metrics(self, x_tt, metrics=None):
"""
Compute log-probability metrics for the given test data.
@@ -275,10 +276,12 @@ def compute_log_prob_metrics(self, x_tt, metrics=None):

if "reward_batch" in reqs:
rewards_x_tt = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt))
if torch.is_tensor(rewards_x_tt):
rewards_x_tt = rewards_x_tt.detach().cpu().numpy()

if "corr_prob_traj_rewards" in metrics:
lp_metrics["corr_prob_traj_rewards"] = np.corrcoef(
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt
np.exp(logprobs_x_tt.detach().cpu().numpy()), rewards_x_tt
)[0, 1]

if "var_logrewards_logp" in metrics:
@@ -304,6 +307,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None):
"metrics": lp_metrics,
}

@torch.no_grad()
def compute_density_metrics(self, x_tt, dict_tt, metrics=None):
"""
Compute density metrics for the given test data.
@@ -371,9 +375,9 @@
elif self.gfn.continuous and hasattr(self.gfn.env, "fit_kde"):
batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False)
assert batch.is_valid()
x_sampled = batch.get_terminating_states(proxy=True)
x_sampled = self.gfn.env.states2kde(batch.get_terminating_states())
# TODO make it work with conditional env
x_tt = torch2np(self.gfn.env.states2proxy(x_tt))
x_tt = self.gfn.env.states2kde(x_tt)
kde_pred = self.gfn.env.fit_kde(
x_sampled,
kernel=self.config.kde.kernel,
@@ -384,8 +388,10 @@
kde_true = dict_tt["kde_true"]
else:
# Sample from reward via rejection sampling
x_from_reward = self.gfn.env.states2proxy(
self.gfn.sample_from_reward(n_samples=self.config.n)
x_from_reward = self.gfn.env.states2kde(
self.gfn.sample_from_reward(
n_samples=self.config.n, method=self.reward_sampling_method
)
)
# Fit KDE with samples from reward
kde_true = self.gfn.env.fit_kde(
@@ -444,6 +450,7 @@
"data": density_data,
}

@torch.no_grad()
def eval(self, metrics=None, **plot_kwargs):
"""
Evaluate the GFlowNetAgent and compute metrics and plots.
@@ -560,9 +567,11 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
fig_kde_pred = fig_kde_true = fig_reward_samples = fig_samples_topk = None

if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None:
(sample_space_batch, rewards_sample_space) = (
self.gfn.get_sample_space_and_reward()
)
(
sample_space_batch,
rewards_sample_space,
) = self.gfn.get_sample_space_and_reward(return_states_proxy=False)
sample_space_batch = self.gfn.env.states2kde(sample_space_batch)
fig_reward_samples = self.gfn.env.plot_reward_samples(
x_sampled,
sample_space_batch,
Expand All @@ -571,7 +580,8 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
)

if hasattr(self.gfn.env, "plot_kde"):
sample_space_batch, _ = self.gfn.get_sample_space_and_reward()
sample_space_batch = self.gfn.get_sample_space()
sample_space_batch = self.gfn.env.states2kde(sample_space_batch)
if kde_pred is not None:
fig_kde_pred = self.gfn.env.plot_kde(
sample_space_batch, kde_pred, **plot_kwargs