diff --git a/.gitignore b/.gitignore index 64a683331..d3f5f24f8 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,6 @@ sh/ *.txt .vscode/ external/ -playground/ !requirements.txt !docs/requirements-docs.txt .DS_Store diff --git a/config/evaluator/base.yaml b/config/evaluator/base.yaml index c429fe808..e0eb852af 100644 --- a/config/evaluator/base.yaml +++ b/config/evaluator/base.yaml @@ -1,5 +1,6 @@ _target_: gflownet.evaluator.base.BaseEvaluator +reward_sampling_method: rejection # config formerly from logger.test first_it: True period: 100 diff --git a/gflownet/envs/README.md b/gflownet/envs/README.md index b4a17bf4f..6bf20a223 100644 --- a/gflownet/envs/README.md +++ b/gflownet/envs/README.md @@ -14,7 +14,7 @@ Note that the mask of invalid actions indeed flags _invalid_ actions, as opposed ## Buffer, train data and test data -A train and a test set can be created at the beginning of training. The train set may be used to sample offline (backward) trajectories. The test set may be used to compute metrics during and after training. These sets may be created in different ways, specificied by the configuration variables `env.buffer.train.type` and `env.buffer.test.type`. Options for the data set `type` are +A train and a test set can be created at the beginning of training. The train set may be used to sample offline (backward) trajectories. The test set may be used to compute metrics during and after training (e.g. JSD, correlation). These sets may be created in different ways, specificied by the configuration variables `env.buffer.train.type` and `env.buffer.test.type`. Options for the data set `type` are - `all`: all terminating states in the output space $\mathcal{X}$ will be added - Convenient but only feasible for small, synthetic environments like the hyper-grid. - `grid`: a grid of points in the output space $\mathcal{X}$ - Only available in certain environments where obtaining a grid of points is meaningful. This mode also requires specifying the number of points via `env.buffer..n`. @@ -52,3 +52,12 @@ To use the replay buffer (once enabled) for backward sampling, one can specify ` :::{tip} You can use [MyST](https://myst-parser.readthedocs.io/en/latest/syntax/admonitions.html) in the documentation. This is expected to fail on Github. ::: + +## Evaluator +Evaluator's parameters define which method is used for sampling from reward (`nested` or `rejection` sampling), and how many points will be sampled `evaluator.n`. evaluator.n_grid is used only for plotting and defines, if applicable, the number of grid points for visualizing KDEs. +```yaml +evaluator: + reward_sampling_method: nested + n_grid: 1000 # number of grid points to visualize KDEs + n: 1000 # number of samples from reward and from gfn + ``` \ No newline at end of file diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index 697c15b1a..d32ddb957 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -7,9 +7,9 @@ from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus -from gflownet.utils.molecule import constants -from gflownet.utils.molecule.atom_positions_dataset import AtomPositionsDataset -from gflownet.utils.molecule.conformer_base import ConformerBase +from gflownet.utils.molecule.constants import AD_FREE_TAS, AD_SMILES +from gflownet.utils.molecule.datasets import AtomPositionsDataset +from gflownet.utils.molecule.rdkit_conformer import RDKitConformer class AlanineDipeptide(ContinuousTorus): @@ -26,9 +26,7 @@ def __init__( path_to_dataset, url_to_dataset ) atom_positions = self.atom_positions_dataset.sample() - self.conformer = ConformerBase( - atom_positions, constants.ad_smiles, constants.ad_free_tas - ) + self.conformer = RDKitConformer(atom_positions, AD_SMILES, AD_FREE_TAS) n_dim = len(self.conformer.freely_rotatable_tas) super().__init__(**kwargs) self.sync_conformer_with_state() @@ -62,7 +60,7 @@ def states2proxy( ------- A numpy array containing all the states in the batch. """ - return super().states2proxy(states).numpy() + return super().states2proxy(states).detach().cpu().numpy() if __name__ == "__main__": diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 3a1d0a3f9..b3e2b3f1d 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -16,7 +16,14 @@ from torch.distributions import Categorical from torchtyping import TensorType -from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat +from gflownet.utils.common import ( + copy, + set_device, + set_float_precision, + tbool, + tfloat, + torch2np, +) CMAP = mpl.colormaps["cividis"] """ @@ -48,7 +55,7 @@ def __init__( # Call reset() to set initial state, done, n_actions self.reset() # Device - self.device = set_device(device) + self.set_device(set_device(device)) # Float precision self.float = set_float_precision(float_precision) # Flag to skip checking if action is valid (computing mask) before step @@ -72,6 +79,17 @@ def __init__( self.policy_output_dim = len(self.fixed_policy_output) self.policy_input_dim = len(self.state2policy()) + def set_device(self, device: torch.device): + """ + Set the device of the environment. + + Parameters + ---------- + device : torch.device + The device to set the environment to. + """ + self.device = device + @abstractmethod def get_action_space(self): """ @@ -757,6 +775,15 @@ def traj2readable(self, traj=None): """ return str(traj).replace("(", "[").replace(")", "]").replace(",", "") + def states2kde( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> Union[List, npt.NDArray, TensorType["batch", "kde_dim"]]: + """ + Converts a batch of states into a batch of states suitable for the KDE computations. + """ + states_kde = self.states2proxy(states) + return torch2np(states_kde) + def reset(self, env_id: Union[int, str] = None): """ Resets the environment. @@ -1249,6 +1276,7 @@ def top_k_metrics_and_plots( return metrics, figs, fig_names + @torch.no_grad() def plot_reward_distribution( self, states=None, scores=None, ax=None, title=None, proxy=None, **kwargs ): @@ -1269,7 +1297,7 @@ def plot_reward_distribution( states_proxy = self.states2proxy(states) scores = self.proxy(states_proxy) if isinstance(scores, TensorType): - scores = scores.cpu().detach().numpy() + scores = scores.detach().cpu().numpy() ax.hist(scores) ax.set_title(title) ax.set_ylabel("Number of Samples") diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index 27f7cc266..b81fe881f 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -1449,7 +1449,6 @@ def fit_kde( bandwidth : float The bandwidth of the kernel. """ - samples = torch2np(samples) return KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples) def plot_reward_samples( @@ -1489,8 +1488,6 @@ def plot_reward_samples( """ if self.n_dim != 2: return None - samples = torch2np(samples) - samples_reward = torch2np(samples_reward) rewards = torch2np(rewards) # Create mesh grid from samples_reward n_per_dim = int(np.sqrt(samples_reward.shape[0])) @@ -1543,7 +1540,6 @@ def plot_kde( """ if self.n_dim != 2: return None - samples = torch2np(samples) # Create mesh grid from samples n_per_dim = int(np.sqrt(samples.shape[0])) assert n_per_dim**2 == samples.shape[0] diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index f5fd32fbf..396a7708d 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -54,6 +54,7 @@ def __init__( "vonmises_mean": 0.0, "vonmises_concentration": 0.001, }, + reward_sampling_method="rejection", **kwargs, ): assert n_dim > 0 @@ -74,6 +75,9 @@ def __init__( self.source = self.source_angles + [0] # End-of-sequence action: (n_dim, 0) self.eos = (self.n_dim, 0) + + self.reward_sampling_method = reward_sampling_method + # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -556,7 +560,6 @@ def fit_kde( bandwidth : float The bandwidth of the kernel. """ - samples = torch2np(samples) samples_aug = self.augment_samples(samples) kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples_aug) return kde @@ -606,7 +609,6 @@ def plot_reward_samples( """ if self.n_dim != 2: return None - samples = torch2np(samples) rewards = torch2np(rewards) n_per_dim = int(np.sqrt(rewards.shape[0])) assert n_per_dim**2 == rewards.shape[0] @@ -677,7 +679,6 @@ def plot_kde( """ if self.n_dim != 2: return None - samples = torch2np(samples) # Create mesh grid from samples n_per_dim = int(np.sqrt(samples.shape[0])) assert n_per_dim**2 == samples.shape[0] diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index bf927de2b..5a646836e 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -317,7 +317,7 @@ def state2readable(self, state: Optional[TensorType["height", "width"]] = None): if isinstance(state, tuple): readable = str(np.stack(state)) else: - readable = str(state.cpu().numpy()) + readable = str(state.detach().cpu().numpy()) readable = readable.replace("[[", "[").replace("]]", "]").replace("\n ", "\n") return readable @@ -581,7 +581,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2): linewidth : int The width of the separation between cells, in pixels. """ - board = board.clone().numpy() + board = board.detach().clone().numpy() height = board.shape[0] * cellsize width = board.shape[1] * cellsize board_img = 128 * np.ones( diff --git a/gflownet/envs/tree.py b/gflownet/envs/tree.py index 3895f336c..fbc081b51 100644 --- a/gflownet/envs/tree.py +++ b/gflownet/envs/tree.py @@ -882,7 +882,7 @@ def state2readable(self, state=None): """ if state is None: state = self.state.clone().detach() - state = state.cpu().numpy() + state = state.detach().cpu().numpy() readable = "" for idx in range(self.n_nodes): attributes = self._attributes_to_readable(state[idx]) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index adc633c73..05221329c 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -34,8 +34,7 @@ class methods to instantiate an evaluator. class BaseEvaluator(AbstractEvaluator): - - def __init__(self, gfn_agent=None, **config): + def __init__(self, gfn_agent=None, reward_sampling_method="rejection", **config): """ Base evaluator class for GFlowNetAgent. @@ -56,6 +55,7 @@ def __init__(self, gfn_agent=None, **config): details about other methods and attributes, including the :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.__init__`. """ + self.reward_sampling_method = reward_sampling_method super().__init__(gfn_agent, **config) def define_new_metrics(self): @@ -223,6 +223,7 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): "summary": summary, } + @torch.no_grad() def compute_log_prob_metrics(self, x_tt, metrics=None): """ Compute log-probability metrics for the given test data. @@ -275,10 +276,12 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): if "reward_batch" in reqs: rewards_x_tt = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt)) + if torch.is_tensor(rewards_x_tt): + rewards_x_tt = rewards_x_tt.detach().cpu().numpy() if "corr_prob_traj_rewards" in metrics: lp_metrics["corr_prob_traj_rewards"] = np.corrcoef( - np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt + np.exp(logprobs_x_tt.detach().cpu().numpy()), rewards_x_tt )[0, 1] if "var_logrewards_logp" in metrics: @@ -304,6 +307,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): "metrics": lp_metrics, } + @torch.no_grad() def compute_density_metrics(self, x_tt, dict_tt, metrics=None): """ Compute density metrics for the given test data. @@ -371,9 +375,9 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): elif self.gfn.continuous and hasattr(self.gfn.env, "fit_kde"): batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False) assert batch.is_valid() - x_sampled = batch.get_terminating_states(proxy=True) + x_sampled = self.gfn.env.states2kde(batch.get_terminating_states()) # TODO make it work with conditional env - x_tt = torch2np(self.gfn.env.states2proxy(x_tt)) + x_tt = self.gfn.env.states2kde(x_tt) kde_pred = self.gfn.env.fit_kde( x_sampled, kernel=self.config.kde.kernel, @@ -384,8 +388,10 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): kde_true = dict_tt["kde_true"] else: # Sample from reward via rejection sampling - x_from_reward = self.gfn.env.states2proxy( - self.gfn.sample_from_reward(n_samples=self.config.n) + x_from_reward = self.gfn.env.states2kde( + self.gfn.sample_from_reward( + n_samples=self.config.n, method=self.reward_sampling_method + ) ) # Fit KDE with samples from reward kde_true = self.gfn.env.fit_kde( @@ -444,6 +450,7 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): "data": density_data, } + @torch.no_grad() def eval(self, metrics=None, **plot_kwargs): """ Evaluate the GFlowNetAgent and compute metrics and plots. @@ -560,9 +567,11 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): fig_kde_pred = fig_kde_true = fig_reward_samples = fig_samples_topk = None if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None: - (sample_space_batch, rewards_sample_space) = ( - self.gfn.get_sample_space_and_reward() - ) + ( + sample_space_batch, + rewards_sample_space, + ) = self.gfn.get_sample_space_and_reward(return_states_proxy=False) + sample_space_batch = self.gfn.env.states2kde(sample_space_batch) fig_reward_samples = self.gfn.env.plot_reward_samples( x_sampled, sample_space_batch, @@ -571,7 +580,8 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): ) if hasattr(self.gfn.env, "plot_kde"): - sample_space_batch, _ = self.gfn.get_sample_space_and_reward() + sample_space_batch = self.gfn.get_sample_space() + sample_space_batch = self.gfn.env.states2kde(sample_space_batch) if kde_pred is not None: fig_kde_pred = self.gfn.env.plot_kde( sample_space_batch, kde_pred, **plot_kwargs diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 715596bb7..dfa9ad625 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1167,96 +1167,101 @@ def train(self): self.opt.zero_grad() all_losses.append([i.item() for i in losses]) # Buffer - t0_buffer = time.time() - states_term = batch.get_terminating_states(sort_by="trajectory") - proxy_vals = batch.get_terminating_proxy_values(sort_by="trajectory") - proxy_vals = proxy_vals.tolist() - # The batch will typically have the log-rewards available, since they are - # used to compute the losses. In order to avoid recalculating the proxy - # values, the natural rewards are computed by taking the exponential of the - # log-rewards. In case the rewards are available in the batch but not the - # log-rewards, the latter are computed by taking the log of the rewards. - # Numerical issues are not critical in this case, since the derived values - # are only used for reporting purposes. - if batch.rewards_available(log=False): - rewards = batch.get_terminating_rewards(sort_by="trajectory") - if batch.rewards_available(log=True): - logrewards = batch.get_terminating_rewards( - sort_by="trajectory", log=True + with torch.no_grad(): + t0_buffer = time.time() + states_term = batch.get_terminating_states(sort_by="trajectory") + proxy_vals = batch.get_terminating_proxy_values(sort_by="trajectory") + proxy_vals = proxy_vals.tolist() + # The batch will typically have the log-rewards available, since they are + # used to compute the losses. In order to avoid recalculating the proxy + # values, the natural rewards are computed by taking the exponential of the + # log-rewards. In case the rewards are available in the batch but not the + # log-rewards, the latter are computed by taking the log of the rewards. + # Numerical issues are not critical in this case, since the derived values + # are only used for reporting purposes. + if batch.rewards_available(log=False): + rewards = batch.get_terminating_rewards(sort_by="trajectory") + if batch.rewards_available(log=True): + logrewards = batch.get_terminating_rewards( + sort_by="trajectory", log=True + ) + if not batch.rewards_available(log=False): + assert batch.rewards_available(log=True) + rewards = torch.exp(logrewards) + if not batch.rewards_available(log=True): + assert batch.rewards_available(log=False) + logrewards = torch.log(rewards) + rewards = rewards.tolist() + logrewards = logrewards.tolist() + actions_trajectories = batch.get_actions_trajectories() + self.buffer.add(states_term, actions_trajectories, logrewards, it) + self.buffer.add( + states_term, actions_trajectories, logrewards, it, buffer="replay" ) - if not batch.rewards_available(log=False): - assert batch.rewards_available(log=True) - rewards = torch.exp(logrewards) - if not batch.rewards_available(log=True): - assert batch.rewards_available(log=False) - logrewards = torch.log(rewards) - rewards = rewards.tolist() - logrewards = logrewards.tolist() - actions_trajectories = batch.get_actions_trajectories() - self.buffer.add(states_term, actions_trajectories, logrewards, it) - self.buffer.add( - states_term, actions_trajectories, logrewards, it, buffer="replay" - ) - t1_buffer = time.time() - times.update({"buffer": t1_buffer - t0_buffer}) - # Log - if self.logger.lightweight: - all_losses = all_losses[-100:] - else: - all_visited.extend(states_term) - # Progress bar - self.logger.progressbar_update( - pbar, all_losses, rewards, self.jsd, it, self.use_context - ) - # Train logs - t0_log = time.time() - if self.evaluator.should_log_train(it): - self.logger.log_train( - losses=losses, - rewards=rewards, - logrewards=logrewards, - proxy_vals=proxy_vals, - states_term=states_term, - batch_size=len(batch), - logz=self.logZ, - learning_rates=self.lr_scheduler.get_last_lr(), - step=it, - use_context=self.use_context, - ) - t1_log = time.time() - times.update({"log": t1_log - t0_log}) - # Save intermediate models - t0_model = time.time() - if self.evaluator.should_checkpoint(it): - self.logger.save_models( - self.forward_policy, self.backward_policy, self.state_flow, step=it - ) - t1_model = time.time() - times.update({"save_interim_model": t1_model - t0_model}) - - # Moving average of the loss for early stopping - if loss_term_ema and loss_flow_ema: - loss_term_ema = ( - self.ema_alpha * losses[1].item() - + (1.0 - self.ema_alpha) * loss_term_ema - ) - loss_flow_ema = ( - self.ema_alpha * losses[2].item() - + (1.0 - self.ema_alpha) * loss_flow_ema + t1_buffer = time.time() + times.update({"buffer": t1_buffer - t0_buffer}) + # Log + if self.logger.lightweight: + all_losses = all_losses[-100:] + else: + all_visited.extend(states_term) + # Progress bar + self.logger.progressbar_update( + pbar, all_losses, rewards, self.jsd, it, self.use_context ) - if ( - loss_term_ema < self.early_stopping - and loss_flow_ema < self.early_stopping - ): - break - else: - loss_term_ema = losses[1].item() - loss_flow_ema = losses[2].item() + # Train logs + t0_log = time.time() + if self.evaluator.should_log_train(it): + self.logger.log_train( + # convert losses to numbers before logging + losses=[i.item() for i in losses], + rewards=rewards, + logrewards=logrewards, + proxy_vals=proxy_vals, + states_term=states_term, + batch_size=len(batch), + logz=self.logZ, + learning_rates=self.lr_scheduler.get_last_lr(), + step=it, + use_context=self.use_context, + ) + t1_log = time.time() + times.update({"log": t1_log - t0_log}) + # Save intermediate models + t0_model = time.time() + if self.evaluator.should_checkpoint(it): + self.logger.save_models( + self.forward_policy, + self.backward_policy, + self.state_flow, + step=it, + ) + t1_model = time.time() + times.update({"save_interim_model": t1_model - t0_model}) + + # Moving average of the loss for early stopping + if loss_term_ema and loss_flow_ema: + loss_term_ema = ( + self.ema_alpha * losses[1].item() + + (1.0 - self.ema_alpha) * loss_term_ema + ) + loss_flow_ema = ( + self.ema_alpha * losses[2].item() + + (1.0 - self.ema_alpha) * loss_flow_ema + ) + if ( + loss_term_ema < self.early_stopping + and loss_flow_ema < self.early_stopping + ): + break + else: + loss_term_ema = losses[1].item() + loss_flow_ema = losses[2].item() - # Log times - t1_iter = time.time() - times.update({"iter": t1_iter - t0_iter}) - self.logger.log_time(times, use_context=self.use_context) + # Log times + t1_iter = time.time() + times.update({"iter": t1_iter - t0_iter}) + self.logger.log_time(times, use_context=self.use_context) # Save final model self.logger.save_models( @@ -1266,16 +1271,17 @@ def train(self): if self.use_context is False: self.logger.end() - def get_sample_space_and_reward(self): + def get_sample_space(self): """ - Returns samples representative of the env state space with their rewards + Obtains and returns samples representative of the env state space, in + environment format. + + This method sets self.sample_space_batch. Returns ------- - sample_space_batch : tensor - Repressentative terminating states for the environment - rewards_sample_space : tensor - Rewards associated with the tates in sample_space_batch + sample_space_batch : list, tensor, array + Representative terminating states (in environment format) for the environment. """ if not hasattr(self, "sample_space_batch"): if hasattr(self.env, "get_all_terminating_states"): @@ -1290,25 +1296,86 @@ def get_sample_space_and_reward(self): "environment must implement either get_all_terminating_states() " "or get_grid_terminating_states()" ) - self.sample_space_batch = self.env.states2proxy(self.sample_space_batch) + return self.sample_space_batch + + def get_sample_space_and_reward(self, return_states_proxy: bool = False): + """ + Returns samples representative of the env state space with their rewards. + + Parameters + ---------- + return_states_proxy : bool + If True, returns the states in proxy format. + + Returns + ------- + sample_space_batch : list, tensor, array + Representative terminating states for the environment. If + return_states_proxy, the format returned will be the proxy format. + Otherwise, states will be returned in environment fomat. + rewards_sample_space : tensor + Rewards associated with the tates in sample_space_batch + """ + if return_states_proxy or not hasattr(self, "rewards_sample_space"): + sample_space_proxy = self.env.states2proxy(self.get_sample_space()) if not hasattr(self, "rewards_sample_space"): - self.rewards_sample_space = self.proxy.rewards(self.sample_space_batch) + self.rewards_sample_space = self.proxy.rewards(sample_space_proxy) + if return_states_proxy: + return sample_space_proxy, self.rewards_sample_space + else: + return self.sample_space_batch, self.rewards_sample_space + + def sample_from_reward( + self, + n_samples: int, + proposal_distribution: str = "uniform", + epsilon: float = 1e-4, + method: str = "rejection", + ) -> Union[List, Dict, TensorType["n_samples", "state_dim"]]: + """ + Sampling from reward using rejection sampling or nested sampling. + + Returns a tensor in GFlowNet (state) format. + + Parameters + ---------- + n_samples : int + The number of samples to draw from the reward distribution. + proposal_distribution : str + Identifier of the proposal distribution for rejection sampling. Currently only `uniform` is + implemented. + epsilon : float + Small epsilon parameter for rejection sampling. + method : str + Identifier of the sampling method. Currently only `rejection` and `nested` are + implemented. - return self.sample_space_batch, self.rewards_sample_space + Returns + ------- + samples_final : list + The list of samples drawn from the reward distribution in environment + format. + """ + if method == "rejection": + return self.sample_from_reward_rejection( + n_samples, proposal_distribution, epsilon + ) + elif method == "nested": + return self.sample_from_reward_nested(n_samples) # TODO: implement other proposal distributions # TODO: rethink whether it is needed to convert to reward - def sample_from_reward( + def sample_from_reward_rejection( self, n_samples: int, proposal_distribution: str = "uniform", - epsilon=1e-4, + epsilon: float = 1e-4, ) -> Union[List, Dict, TensorType["n_samples", "state_dim"]]: """ Rejection sampling with proposal the uniform distribution defined over the sample space. - Returns a tensor in GFloNet (state) format. + Returns a tensor in GFlowNet (state) format. Parameters ---------- @@ -1350,6 +1417,77 @@ def sample_from_reward( samples_final.extend(samples_accepted[-(n_samples - len(samples_final)) :]) return samples_final + def sample_from_reward_nested( + self, n_samples: int + ) -> Union[List, Dict, TensorType["n_samples", "state_dim"]]: + """ + Nested sampling from reward, using ultranest. + + Returns a tensor in GFlowNet (state) format. + + Parameters + ---------- + n_samples : int + The number of samples to draw from the reward distribution. + + Returns + ------- + samples_final : list + The list of samples drawn from the reward distribution in environment + format. + """ + import ultranest + from wurlitzer import pipes + + def reward_func(angles): + # angles here is np array + states = np.concatenate( + [angles, np.ones((angles.shape[0], 1))], axis=1 + ).tolist() + rewards = ( + self.proxy.proxy2reward(self.proxy(self.env.states2proxy(states))) + .detach() + .cpu() + .numpy() + ) + return np.log(rewards) + + def prior_transform(cube): + params = cube.copy() + # transform location parameter: uniform prior + low = 0 + high = 2 * np.pi + for idx, elem in enumerate(cube): + params[idx] = elem * (high - low) + low + return params + + samples = [] + n_sampled = 0 + iteration = 0 + print(f"Running nested sampling (until {n_samples} samples are obtained)...") + while n_sampled < n_samples: + param_names = [f"theta_{i}" for i in range(self.env.n_dim)] + + with pipes(): + sampler = ultranest.ReactiveNestedSampler( + param_names, + reward_func, + prior_transform, + vectorized=True, + ndraw_min=1000, + ) + result = sampler.run() + + samples.append(result["samples"]) + n_sampled += result["samples"].shape[0] + print(f"Total samples (iteration #{iteration}): {n_sampled}.") + iteration += 1 + samples = np.concatenate(samples, axis=0) + # add dummy step dimension + samples = np.concatenate([samples, np.ones((samples.shape[0], 1))], axis=1) + np.random.shuffle(samples) + return torch.Tensor(samples[:n_samples]) + def make_opt(params, logZ, config): """ diff --git a/gflownet/policy/multihead_tree.py b/gflownet/policy/multihead_tree.py index 8c3306509..c8b566eda 100644 --- a/gflownet/policy/multihead_tree.py +++ b/gflownet/policy/multihead_tree.py @@ -357,9 +357,9 @@ def forward(self, x): logits[indices, self.leaf_index : self.feature_index] = y_leaf logits[indices, self.eos_index] = y_eos elif stage == Stage.LEAF: - logits[indices, self.feature_index : self.threshold_index] = ( - self.feature_head(batch) - ) + logits[ + indices, self.feature_index : self.threshold_index + ] = self.feature_head(batch) else: ks = [Tree.find_active(state) for state in states] feature_index = torch.Tensor( @@ -374,9 +374,9 @@ def forward(self, x): if self.continuous: logits[indices, (self.eos_index + 1) :] = head_output else: - logits[indices, self.threshold_index : self.operator_index] = ( - head_output - ) + logits[ + indices, self.threshold_index : self.operator_index + ] = head_output elif stage == Stage.THRESHOLD: threshold = torch.Tensor( [ @@ -464,14 +464,14 @@ def forward(self, x): ) if stage == Stage.COMPLETE: - logits[indices, self.operator_index : self.eos_index] = ( - self.complete_stage_head(batch) - ) + logits[ + indices, self.operator_index : self.eos_index + ] = self.complete_stage_head(batch) logits[indices, self.eos_index] = 1.0 elif stage == Stage.LEAF: - logits[indices, self.leaf_index : self.feature_index] = ( - self.leaf_stage_head(batch) - ) + logits[ + indices, self.leaf_index : self.feature_index + ] = self.leaf_stage_head(batch) elif stage == Stage.FEATURE: logits[indices, self.feature_index : self.threshold_index] = 1.0 elif stage == Stage.THRESHOLD: diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 83fc83aa4..0082b2f56 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -679,20 +679,36 @@ def _compute_parents_policy(self): self._parents_policy_available is set to True. """ self.states_policy = self.get_states(policy=True) - self.parents_policy = torch.zeros_like(self.states_policy) - # Iterate over the trajectories to obtain the parents from the states - for traj_idx, batch_indices in self.trajectories.items(): - # parent is source - self.parents_policy[batch_indices[0]] = tfloat( - self.envs[traj_idx].state2policy(self.envs[traj_idx].source), - device=self.device, - float_type=self.float, - ) - # parent is not source - self.parents_policy[batch_indices[1:]] = self.states_policy[ - batch_indices[:-1] - ] - self._parents_policy_available = True + # hacky way to check whether it is MoleculeGraph without importing MoleculeGraph here + if hasattr(self.states_policy, "num_nodes"): + parents_policy = [None] * len(self.states_policy) + # Iterate over the trajectories to obtain the parents from the states + for traj_idx, batch_indices in self.trajectories.items(): + # parent is source + parents_policy[batch_indices[0]] = tfloat( + self.envs[traj_idx].state2policy(self.envs[traj_idx].source), + device=self.device, + float_type=self.float, + ) + # parent is not source + for idx, previous_idx in zip(batch_indices[1:], batch_indices[:-1]): + parents_policy[idx] = self.states_policy[previous_idx] + self.parents_policy = self.env.collate_policy_states(parents_policy) + else: + self.parents_policy = torch.zeros_like(self.states_policy) + # Iterate over the trajectories to obtain the parents from the states + for traj_idx, batch_indices in self.trajectories.items(): + # parent is source + self.parents_policy[batch_indices[0]] = tfloat( + self.envs[traj_idx].state2policy(self.envs[traj_idx].source), + device=self.device, + float_type=self.float, + ) + # parent is not source + self.parents_policy[batch_indices[1:]] = self.states_policy[ + batch_indices[:-1] + ] + self.parents_policy_available = True def get_parents_all( self, policy: bool = False, force_recompute: bool = False @@ -857,9 +873,9 @@ def get_masks_forward( masks_invalid_actions_forward_parents[parents_indices == -1] = self.source[ "mask_forward" ] - masks_invalid_actions_forward_parents[parents_indices != -1] = ( - masks_invalid_actions_forward[parents_indices[parents_indices != -1]] - ) + masks_invalid_actions_forward_parents[ + parents_indices != -1 + ] = masks_invalid_actions_forward[parents_indices[parents_indices != -1]] return masks_invalid_actions_forward_parents return masks_invalid_actions_forward diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 5c29bc238..9fb1b236e 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -42,7 +42,7 @@ def set_device(device: Union[str, torch.device]): """ if isinstance(device, torch.device): return device - if device.lower() == "cuda" and torch.cuda.is_available(): + if device.lower() in ["cuda", "gpu"] and torch.cuda.is_available(): return torch.device("cuda") else: return torch.device("cpu") @@ -140,8 +140,8 @@ def torch2np(x): np.ndarray Converted data. """ - if hasattr(x, "is_cuda") and x.is_cuda: - x = x.detach().cpu() + if torch.is_tensor(x): + return np.array(x.detach().cpu()) return np.array(x) @@ -462,6 +462,9 @@ def tfloat(x, device, float_type): return torch.stack(x).to(device=device, dtype=float_type) if torch.is_tensor(x): return x.to(device=device, dtype=float_type) + elif hasattr(x, "tfloat"): + x = x.tfloat(float_type) + return x.to(device=device) else: return torch.tensor(x, dtype=float_type, device=device) @@ -579,7 +582,7 @@ def concat_items(list_of_items, indices=None): result = np.concatenate(list_of_items) if indices is not None: if torch.is_tensor(indices[0]): - indices = indices.cpu().numpy() + indices = indices.detach().cpu().numpy() result = result[indices] elif torch.is_tensor(list_of_items[0]): result = torch.cat(list_of_items) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index ffe9c21ec..3d02bfecc 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -250,7 +250,7 @@ def log_train( loss_metrics = dict( zip( ["Loss", "Loss (terminating)", "Loss (non-term.)"], - [loss.item() for loss in losses], + losses, ) ) self.log_metrics( diff --git a/gflownet/utils/molecule/atom_positions_dataset.py b/gflownet/utils/molecule/atom_positions_dataset.py index 0b66f4363..738c15371 100644 --- a/gflownet/utils/molecule/atom_positions_dataset.py +++ b/gflownet/utils/molecule/atom_positions_dataset.py @@ -4,9 +4,12 @@ class AtomPositionsDataset: - def __init__(self, path_to_data, url_to_data): + def __init__(self, smiles: str, path_to_data: str, url_to_data: str): path_to_data = download_file_if_not_exists(path_to_data, url_to_data) - self.positions = np.load(path_to_data) + conformers = np.load(path_to_data, allow_pickle=True).item() + + self.positions = conformers[smiles]["conformers"] + self.torsion_angles = conformers[smiles]["torsion_angles"] def __getitem__(self, i): return self.positions[i] @@ -17,3 +20,6 @@ def __len__(self): def sample(self, size=None): idx = np.random.randint(0, len(self), size=size) return self.positions[idx] + + def first(self): + return self[0] diff --git a/gflownet/utils/molecule/conformer_base.py b/gflownet/utils/molecule/conformer_base.py index 7282010e4..9966ad078 100644 --- a/gflownet/utils/molecule/conformer_base.py +++ b/gflownet/utils/molecule/conformer_base.py @@ -3,7 +3,7 @@ from rdkit.Chem import AllChem, TorsionFingerprints, rdMolTransforms from rdkit.Geometry.rdGeometry import Point3D -from gflownet.utils.molecule import constants +from gflownet.utils.molecule.constants import AD_ATOM_TYPES, AD_FREE_TAS, AD_SMILES def get_torsion_angles_atoms_list(mol): @@ -24,20 +24,20 @@ def get_all_torsion_angles(mol, conf): def get_dummy_ad_atom_positions(): - rmol = Chem.MolFromSmiles(constants.ad_smiles) + rmol = Chem.MolFromSmiles(AD_SMILES) rmol = Chem.AddHs(rmol) AllChem.EmbedMolecule(rmol) rconf = rmol.GetConformer() return rconf.GetPositions() -def get_dummy_ad_conf_base(): +def get_dummy_ad_rdkconf(): pos = get_dummy_ad_atom_positions() - conf = ConformerBase(pos, constants.ad_smiles, constants.ad_free_tas) + conf = RDKitConformer(pos, AD_SMILES, AD_FREE_TAS) return conf -class ConformerBase: +class RDKitConformer: def __init__(self, atom_positions, smiles, freely_rotatable_tas=None): """ :param atom_positions: numpy.ndarray of shape [num_atoms, 3] of dtype float64 @@ -134,16 +134,14 @@ def increment_torsion_angle(self, torsion_angle, increment): if __name__ == "__main__": from tabulate import tabulate - rmol = Chem.MolFromSmiles(constants.ad_smiles) + rmol = Chem.MolFromSmiles(AD_SMILES) rmol = Chem.AddHs(rmol) AllChem.EmbedMolecule(rmol) rconf = rmol.GetConformer() test_pos = rconf.GetPositions() initial_tas = get_all_torsion_angles(rmol, rconf) - conf = ConformerBase( - test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas - ) + conf = RDKitConformer(test_pos, AD_SMILES, AD_ATOM_TYPES, AD_FREE_TAS) # check torsion angles randomisation conf.randomize_freely_rotatable_tas() conf_tas = conf.get_all_torsion_angles() diff --git a/gflownet/utils/molecule/constants.py b/gflownet/utils/molecule/constants.py index d2786b6cb..e2900e357 100644 --- a/gflownet/utils/molecule/constants.py +++ b/gflownet/utils/molecule/constants.py @@ -1,20 +1,145 @@ from rdkit import Chem -# Edge and node feature names in DGL graph -atom_position_name = "pos" -atom_feature_name = "atom_features" -edge_feature_name = "edge_features" +# Edge and node feature names in a graph step_feature_name = "step" atomic_numbers_name = "atomic_numbers" +rotatable_edges_mask_name = "rotatable_edges" +rotation_affected_nodes_mask_name = "rotation_affected_nodes" +rotation_signs_name = "rotation_signs" # Options for atoms featurization -ad_atom_types = ("H", "C", "N", "O") -atom_degrees = tuple(range(1, 7)) -atom_hybridizations = tuple(list(Chem.rdchem.HybridizationType.names.values())) -bond_types = tuple([Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE]) +AD_ATOM_TYPES = ("H", "C", "N", "O", "F", "S", "Cl") +ATOM_DEGREES = tuple(range(1, 7)) +ATOM_HYBRIDIZATIONS = tuple(list(Chem.rdchem.HybridizationType.names.values())) +BOND_TYPES = tuple( + [ + "FAKE", + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC, + ] +) # SMILES strings -ad_smiles = "CC(C(=O)NC)NC(=O)C" +AD_SMILES = "CC(C(=O)NC)NC(=O)C" +KETOROLAC_SMILES = "OC(=O)C1CCn2c1ccc2C(=O)c1ccccc1" +IBUPROFEN_SMILES = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O" + # Freely rotatable torsion angles -ad_free_tas = ((0, 1, 2, 3), (0, 1, 6, 7)) +AD_FREE_TAS = ((0, 1, 2, 3), (0, 1, 6, 7)) + +# some selected SMILES strings used in the paper https://pubs.rsc.org/en/content/articlepdf/2024/DD/D4DD00023D +SELECTED_SMILES = [ + "O=C(c1ccccc1)c1ccc2c(c1)OCCOCCOCCOCCO2", + "O=S(=O)(NN=C1CCCCCC1)c1ccc(Cl)cc1", + "O=C(NC1CCCCC1)N1CCN(C2CCCCC2)CC1", + "O=C(COc1ccc(Cl)cc1[N+](=O)[O-])N1CCCCCC1", + "O=C(Nc1ccc(N2CCN(C(=O)c3ccccc3)CC2)cc1)c1cccs1", + "O=[N+]([O-])/C(C(=C(Cl)Cl)N1CCN(Cc2ccccc2)CC1)=C1\\NCCN1Cc1ccc(Cl)nc1", + "O=C(CSc1nnc(C2CC2)n1-c1ccccc1)Nc1ccc(N2CCOCC2)cc1", + "O=C(Nc1ccccn1)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(=O)Nc1ccccn1", + "O=C(NCCc1nnc2ccc(NCCCN3CCOCC3)nn12)c1ccccc1F", + "O=C(CSc1cn(CCNC(=O)c2cccs2)c2ccccc12)NCc1ccccc1", + "O=C(CCC(=O)N(CC(=O)NC1CCCCC1)Cc1cccs1)Nc1ccccn1", + "S=C(c1ccc2c(c1)OCO2)N1CCOCC1", + "O=Cc1cnc(N2CCN(c3ccccc3)CC2)s1", + "O=[N+]([O-])c1ccc(NCc2ccc(Cl)cc2)nc1", + "O=C(Nc1cc(C(F)(F)F)ccc1N1CCCCC1)c1ccncc1", + "O=C(Nc1nnc(-c2ccccc2Cl)s1)C1CCN(S(=O)(=O)c2ccc(Cl)cc2)CC1", + "O=C(CCNC(=O)c1ccccc1Cl)Nc1nc2ccccc2s1", + "O=C(CCC(=O)Nc1ccccc1Cl)N/N=C/c1ccccc1", + "O=C(COc1ccccc1)Nc1ccccc1C(=O)NCc1ccco1", + "O=C(CNC(=O)c1cccs1)NCC(=O)OCc1ccc(Cl)cc1Cl", + "O=C(NCc1ccccc1)c1onc(CSc2ccc(Cl)cc2)c1C(=O)NCC1CC1", + "O=C(CN(C(=O)CCC(=O)Nc1ccccn1)c1ccc2c(c1)OCO2)NCc1ccco1", + "O=[N+]([O-])c1ccc(N2CCNCC2)c(Cl)c1", + "O=[N+]([O-])c1ccccc1S(=O)(=O)N1CCCCC1", + "N#C/C(=C\\N1CCN(Cc2ccc3c(c2)OCO3)CC1)c1nc2ccccc2s1", + "O=C(NNc1ccc([N+](=O)[O-])cc1)c1ccccc1Cl", + "O=C(OCc1ccccc1Cl)c1ccccc1C(=O)c1ccccc1", + "O=C(CN(c1cccc(C(F)(F)F)c1)S(=O)(=O)c1ccccc1)N1CCOCC1", + "O=c1[nH]c2cc3c(cc2cc1CN(Cc1cccnc1)Cc1nnnn1Cc1ccco1)OCCO3", + "O=C(CCNC(=O)C1CCN(S(=O)(=O)c2ccccc2)CC1)NC1CC1", + "O=C(CN(c1ccc(F)cc1)S(=O)(=O)c1ccccc1)NCCSC1CCCCC1", + "C=CCn1c(CSCc2ccccc2)nnc1SCC(=O)N1CCN(c2ccccc2)CC1", + "C=COCCNC(=S)N1CCOCCOCCN(C(=S)NCCOC=C)CCOCC1", + "O=S(=O)(c1cccc(Cl)c1Cl)N1CCCCC1", + "O=C(Cn1ccccc1=O)c1cccs1", + "FC(F)Sc1ccc(Nc2ncnc3nc[nH]c23)cc1", + "O=c1[nH]c(SCCOc2ccccc2F)nc2ccccc12", + "O=C(/C=C/c1ccccc1Cl)NCCN1CCOCC1", + "N#CCCN1CCN(S(=O)(=O)c2ccc(S(=O)(=O)NC3CC3)cc2)CC1", + "O=C(CSc1nnc(Cc2cccs2)n1-c1ccccc1)Nc1ccc2c(c1)OCCO2", + "O=C(CNC(=O)c1ccc(N2C(=O)c3ccccc3C2=O)cc1)OCC(=O)c1ccccc1", + "O=C(CSc1nnc(CNC(=O)c2c(F)cccc2Cl)o1)NCc1ccc2c(c1)OCO2", + "O=C(CNC(=S)N(Cc1ccc(F)cc1)C1CCCCC1)NCCN1CCOCC1", + "C=CCn1c(CCNC(=O)c2cccs2)nnc1SCC(=O)Nc1cccc(F)c1", + "Clc1cccc(CN2CCCCCC2)c1Cl", + "O=C(Nc1cc(=O)c2ccccc2o1)N1CCCCC1", + "O=S(=O)(c1ccccc1)c1nc(-c2ccco2)oc1N1CCOCC1", + "O=C(CC1CCCCC1)NCc1ccco1", + "O=C(c1cccc([N+](=O)[O-])c1)N1CCCN(C(=O)c2cccc([N+](=O)[O-])c2)CC1", + "Fc1ccccc1OCCCCCN1CCCC1", + "O=C(c1ccc(S(=O)(=O)NCc2ccco2)cc1)N1CCN(Cc2ccc3c(c2)OCO3)CC1", + "N#Cc1ccc(NC(=O)COC(=O)CNC(=O)C2CCCCC2)cc1", + "O=C(CCC(=O)OCC(=O)c1ccc(-c2ccccc2)cc1)Nc1cccc(Cl)c1", + "O=C(COC(=O)CCC(=O)c1cccs1)Nc1ccc(S(=O)(=O)N2CCOCC2)cc1", + "O=C(COC(=O)CCCNC(=O)c1ccc(Cl)cc1)NCc1ccccc1Cl", + "C1CCC([NH2+]C2=NCCC2)CC1", + "N#Cc1ccccc1S(=O)(=O)Nc1ccc2c(c1)OCCO2", + "O=[N+]([O-])c1ccccc1S(=O)(=O)N1CCN(c2ccccc2)CC1", + "O=S(=O)(NCc1ccccc1Cl)c1ccc(-n2cccn2)cc1", + "O=C(CNS(=O)(=O)c1cccc2nsnc12)NC1CCCCC1", + "O=C(c1cccc([N+](=O)[O-])c1)n1nc(-c2ccccc2)nc1NCc1ccccc1", + "O=C(CN1CCN(c2ccccc2)CC1)NC(=O)NCc1ccco1", + "O=C(CCCn1c(=O)c2ccccc2n(Cc2ccccc2)c1=O)NCc1ccco1", + "O=C(COc1ccc(Cl)cc1)NCc1nnc(SCC(=O)N2CCCCCC2)o1", + "O=C(NCc1ccccc1)c1onc(CSc2ccccn2)c1C(=O)NCc1ccccc1", + "C=CCN(c1cccc(C(F)(F)F)c1)S(=O)(=O)c1cccc(C(=O)OCC(=O)Nc2ccccc2)c1", + "O=S(=O)(N1CCCCCC1)N1CC[NH2+]CC1", + "O=C1c2ccccc2C(=O)N1Cc1nn2c(-c3ccc(Cl)cc3)nnc2s1", + "O=C(CN1C(=O)NC2(CCCC2)C1=O)Nc1ccc(F)c(F)c1F", + "O=C(Cc1n[nH]c(=O)[nH]c1=O)N/N=C/c1ccccc1", + "O=C(NCCSc1ccc(Cl)cc1)c1ccco1", + "O=C(CN1CCN(Cc2ccccc2Cl)CC1)N/N=C/c1ccco1", + "O=C(Nc1ccc(-c2csc(Nc3cccc(C(F)(F)F)c3)n2)cc1)c1cccc(C(F)(F)F)c1", + "O=C(CCCCCN1C(=O)c2cccc3cccc(c23)C1=O)NCc1ccco1", + "O=C(NCCCn1ccnc1)/C(=C\\c1cccs1)NC(=O)c1cccs1", + "O=C(COC(=O)COc1ccccc1[N+](=O)[O-])Nc1ccc(S(=O)(=O)N2CCCCC2)cc1", + "O=C(NCCCN1CCCC1=O)c1cc(NS(=O)(=O)c2ccc(F)cc2)cc(NS(=O)(=O)c2ccc(F)cc2)c1", + "Clc1ccc(N2CCN(c3ncnc4c3oc3ccccc34)CC2)cc1", + "O=C(Nc1c(Cl)ccc2nsnc12)c1cccnc1", + "c1nc(COc2nsnc2N2CCOCC2)cs1", + "O=C(C1CC1)N1CCN=C1SCc1ccccc1", + "O=C(Cc1ccccc1)OCC[NH+]1CCOCC1", + "O=C(CCSc1ccccc1)NCc1cccnc1", + "O=C(CNC(=O)c1ccc(F)cc1)N/N=C/c1cn[nH]c1-c1ccccc1", + "O=C1CCN(CCc2ccccc2)CCN1[C@H](CSc1ccccc1)Cc1ccccc1", + "O=C(CCCCCn1c(=S)[nH]c2ccc(N3CCOCC3)cc2c1=O)NCc1ccc(Cl)cc1", + "O=C(CCC(=O)OCCCC(F)(F)C(F)(F)F)NC1CCCCC1", + "O=C(CN(Cc1ccco1)C(=O)CNS(=O)(=O)c1ccc(F)cc1)NCc1ccco1", + "O=c1[nH]c(N2CCN(c3ccccc3)CC2)nc2c1CCC2", + "O=C(Nc1ccc2c(c1)OCO2)c1cccs1", + "O=C(Nc1cccc2ccccc12)N1CCN(c2ccccc2)CC1", + "O=C(NC(=S)Nc1ccccn1)c1ccccc1", + "O=C(Nc1ccc(Cl)cc1)c1cccc(S(=O)(=O)Nc2ccccn2)c1", + "O=C(COc1cnc2ccccc2n1)NCCC1=CCCCC1", + "O=C(NCCN1CCOCC1)c1ccc(/C=C2\\Sc3ccccc3N(Cc3ccc(F)cc3)C2=O)cc1", + "O=C(NCCc1cccc(Cl)c1)c1ccc(OC2CCN(Cc3ccccn3)CC2)cc1", + "C=CC[NH2+]CCOCCOc1ccccc1-c1ccccc1", + "O=C(COC(=O)c1ccccc1NC(=O)c1ccco1)NCCC1=CCCCC1", + "O=C(CNC(=S)N(Cc1ccccc1Cl)C1CCCC1)NCCCN1CCOCC1", + "O=c1c2ccccc2nnn1Cc1ccccc1Cl", + "S=C(Nc1ccccc1)N1CCCCCCC1", + "O=C(Cn1ccc([N+](=O)[O-])n1)N1CCCc2ccccc21", + "O=C(NS(=O)(=O)N1CCOCC1)C1=C(N2CCCC2)COC1=O", + "O=C(CCCn1c(=O)[nH]c2ccsc2c1=O)NC1CCCCC1", + "O=C(Cc1ccc(Cl)cc1)Nc1ccc(S(=O)(=O)Nc2ncccn2)cc1", + "O=C1COc2ccc(C(=O)COC(=O)CCSc3ccccc3)cc2N1", + "O=C(Nc1cc(F)cc(F)c1)c1ccc(NCCC[NH+]2CCCCCC2)c([N+](=O)[O-])c1", + "O=C(CCCn1c(=O)[nH]c2cc(Cl)ccc2c1=O)NCCCN1CCN(c2ccc(F)cc2)CC1", + "O=C(NCCN1CCN(C(=O)C(c2ccccc2)c2ccccc2)CC1)C(=O)Nc1ccccc1", + "O=C(NCCCN1CCN(CCCNC(=O)c2ccc3c(c2)OCO3)CC1)c1ccc2c(c1)OCO2", +] diff --git a/tests/gflownet/test_sample_from_reward.py b/tests/gflownet/test_sample_from_reward.py new file mode 100644 index 000000000..cf30492af --- /dev/null +++ b/tests/gflownet/test_sample_from_reward.py @@ -0,0 +1,31 @@ +import os +from pathlib import Path + +import numpy as np +import yaml +from hydra import compose, initialize +from omegaconf import OmegaConf + +from gflownet.utils.common import gflownet_from_config + + +def test_nested_sampling_simple_check(): + ROOT = Path(__file__).resolve().parent.parent.parent + command = "+experiments=icml23/ctorus logger.do.online=False" + overrides = command.split() + with initialize( + version_base="1.1", + config_path=os.path.relpath( + str(ROOT / "config"), start=str(Path(__file__).parent) + ), + job_name="xxx", + ): + config = compose(config_name="main", overrides=overrides) + + gfn = gflownet_from_config(config) + samples = gfn.sample_from_reward(100, method="nested") + assert samples.shape[0] == 100 + assert (samples[:, 0] < np.pi * 2).all() + assert (samples[:, 1] < np.pi * 2).all() + assert (samples[:, 0] > 0).all() + assert (samples[:, 1] > 0).all()