From 6811a8239398c4762de40ce2036dfc4efb9af8bd Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 20 Feb 2024 17:00:15 -0500 Subject: [PATCH 01/20] added ising example --- tutorials/examples/train_ising.py | 134 ++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 tutorials/examples/train_ising.py diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py new file mode 100644 index 00000000..6401dca2 --- /dev/null +++ b/tutorials/examples/train_ising.py @@ -0,0 +1,134 @@ +from argparse import ArgumentParser + +import torch +import wandb +from tqdm import tqdm + +from gfn.gflownet import FMGFlowNet +from gfn.gym import DiscreteEBM +from gfn.gym.discrete_ebm import IsingModel +from gfn.modules import DiscretePolicyEstimator +from gfn.utils.modules import NeuralNet +from gfn.utils.training import validate + + +def main(args): + # Configs + + use_wandb = len(args.wandb_project) > 0 + if use_wandb: + wandb.init(project=args.wandb_project) + wandb.config.update(args) + + device = "cpu" + torch.set_num_threads(args.n_threads) + hidden_dim = 512 + + n_hidden = 2 + acc_fn = "relu" + lr = 0.001 + lr_Z = 0.01 + validation_samples = 1000 + + def make_J(L, coupling_constant): + """Ising model parameters.""" + + def ising_n_to_ij(L, n): + i = n // L + j = n - i * L + return (i, j) + + N = L**2 + J = torch.zeros((N, N), device=torch.device(device)) + for k in range(N): + for m in range(k): + x1, y1 = ising_n_to_ij(L, k) + x2, y2 = ising_n_to_ij(L, m) + if x1 == x2 and abs(y2 - y1) == 1: + J[k][m] = 1 + J[m][k] = 1 + elif y1 == y2 and abs(x2 - x1) == 1: + J[k][m] = 1 + J[m][k] = 1 + + for k in range(L): + J[k * L][(k + 1) * L - 1] = 1 + J[(k + 1) * L - 1][k * L] = 1 + J[k][k + N - L] = 1 + J[k + N - L][k] = 1 + + return coupling_constant * J + + # Ising model env + N = args.L**2 + J = make_J(args.L, args.J) + ising_energy = IsingModel(J) + env = DiscreteEBM(N, alpha=1, energy=ising_energy, device_str=device) + + # 
Parametrization and losses + pf_module = NeuralNet( + input_dim=env.preprocessor.output_dim, + output_dim=env.n_actions, + hidden_dim=hidden_dim, + n_hidden_layers=n_hidden, + activation_fn=acc_fn, + ) + + pf_estimator = DiscretePolicyEstimator( + pf_module, env.n_actions, env.preprocessor, is_backward=False + ) + gflownet = FMGFlowNet(pf_estimator) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3) + + # Learning + visited_terminating_states = env.States.from_batch_shape((0,)) + states_visited = 0 + for i in (pbar := tqdm(range(10000))): + trajectories = gflownet.sample_trajectories(env, n_samples=8, off_policy=False) + training_samples = gflownet.to_training_samples(trajectories) + optimizer.zero_grad() + loss = gflownet.loss(env, training_samples) + loss.backward() + optimizer.step() + + states_visited += len(trajectories) + to_log = {"loss": loss.item(), "states_visited": states_visited} + + if i % 25 == 0: + tqdm.write(f"{i}: {to_log}") + + +if __name__ == "__main__": + # Command-line arguments + parser = ArgumentParser() + + parser.add_argument( + "--n_threads", + type=int, + default=4, + help="Number of threads used by PyTorch", + ) + + parser.add_argument( + "-L", + type=int, + default=16, + help="Lentgh of the grid", + ) + + parser.add_argument( + "-J", + type=float, + default=0.44, + help="J (Magnetic coupling constant)", + ) + + parser.add_argument( + "--wandb_project", + type=str, + default="", + help="Name of the wandb project. 
If empty, don't use wandb", + ) + + args = parser.parse_args() + main(args) From 89027ddb8d3edfd5c91de977991e6281c1d5a0a6 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 22 Feb 2024 12:55:50 -0500 Subject: [PATCH 02/20] function to stack a list of states --- src/gfn/states.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 53492861..79f1100b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from math import prod -from typing import Callable, ClassVar, Optional, Sequence, cast +from typing import Callable, ClassVar, Optional, Sequence, List, cast import torch from torchtyping import TensorType as TT @@ -446,3 +446,20 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.ones(shape).bool() else: self.forward_masks = torch.zeros(shape).bool() + + +def stack_states(states: List[States]): + """Given a list of states, stacks them along a new dimension (0).""" + state_example = states[0] # We assume all elems of `states` are the same. + + stacked_states = state_example.from_batch_shape((0, 0)) # Empty. + stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) + if state_example._log_rewards: + stacked_states._log_rewards = torch.stack([s._log_rewards for s in states], dim=0) + stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0) + stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0) + + # Adds the trajectory dimension. 
+ stacked_states.batch_shape = (stacked_states.tensor.shape[0],) + state_example.batch_shape + + return stacked_states From 2afb00ee37ad9daa4e3d1eed3799909fb4bf711c Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 22 Feb 2024 12:56:26 -0500 Subject: [PATCH 03/20] added notes for bug --- src/gfn/gflownet/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e38bb10a..b5aa2929 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -153,8 +153,8 @@ def get_pfs_and_pbs( if self.off_policy: # We re-use the values calculated in .sample_trajectories(). if trajectories.estimator_outputs is not None: - estimator_outputs = trajectories.estimator_outputs[ - ~trajectories.actions.is_dummy + estimator_outputs = trajectories.estimator_outputs[ # TODO: This contains `inf` when we use the new `stack_states` method in `samplers.py`! + ~trajectories.actions.is_dummy # And this causes later failures (p_f is not finite). 
] else: raise Exception( From be2fee1efb4cc8aeaf10a35630068b3f7b9eb414 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 22 Feb 2024 12:58:58 -0500 Subject: [PATCH 04/20] NOT WORKING: this commit contains trajectories_states_b which is the proposed new method for stacking a list of states into a trajectory, but as the assert statements show, the tensor is correct, but the forward_masks are not --- src/gfn/samplers.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 68b052a6..f22fa06e 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -7,7 +7,7 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.modules import GFNModule -from gfn.states import States +from gfn.states import States, stack_states class Sampler: @@ -140,6 +140,8 @@ def sample_trajectories( else states.is_sink_state ) + trajectories_states_b: List[States] = [states] + trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [ states.tensor ] @@ -220,9 +222,18 @@ def sample_trajectories( dones = dones | new_dones trajectories_states += [states.tensor] + trajectories_states_b += [states] + + # New Method + trajectories_states_b = stack_states(trajectories_states_b) + + # Old Method + trajectories_states = env.states_from_tensor( + torch.stack(trajectories_states, dim=0)) + + assert (trajectories_states_b.tensor == trajectories_states.tensor).sum() == trajectories_states.tensor.numel() + assert (trajectories_states_b.forward_masks == trajectories_states.forward_masks).sum() == trajectories_states.forward_masks.numel() - trajectories_states = torch.stack(trajectories_states, dim=0) - trajectories_states = env.states_from_tensor(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0) From 7b536a24276272cdde9ea3b8ff374088f640c818 Mon Sep 17 00:00:00 2001 From: Joseph 
Viviano Date: Sat, 24 Feb 2024 15:33:20 -0500 Subject: [PATCH 05/20] using stack_states to prevent recomputation of masks --- src/gfn/samplers.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index f22fa06e..475d4849 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,4 +1,5 @@ from typing import List, Optional, Tuple +from copy import deepcopy import torch from torchtyping import TensorType as TT @@ -140,11 +141,7 @@ def sample_trajectories( else states.is_sink_state ) - trajectories_states_b: List[States] = [states] - - trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [ - states.tensor - ] + trajectories_states: List[States] = [deepcopy(states)] trajectories_actions: List[TT["n_trajectories", torch.long]] = [] trajectories_logprobs: List[TT["n_trajectories", torch.float]] = [] trajectories_dones = torch.zeros( @@ -221,19 +218,9 @@ def sample_trajectories( states = new_states dones = dones | new_dones - trajectories_states += [states.tensor] - trajectories_states_b += [states] - - # New Method - trajectories_states_b = stack_states(trajectories_states_b) - - # Old Method - trajectories_states = env.states_from_tensor( - torch.stack(trajectories_states, dim=0)) - - assert (trajectories_states_b.tensor == trajectories_states.tensor).sum() == trajectories_states.tensor.numel() - assert (trajectories_states_b.forward_masks == trajectories_states.forward_masks).sum() == trajectories_states.forward_masks.numel() + trajectories_states += [deepcopy(states)] + trajectories_states = stack_states(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0) From 77e7e1b524a0a640f51b30a82a73a5ea8fee9e90 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 15:34:13 -0500 Subject: [PATCH 06/20] stack_states now ignores masks for non-discrete states, 
and fixed bug in mask updating behaviour to prevent accumulation of errors. --- src/gfn/states.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 79f1100b..86eeabae 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -409,12 +409,15 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool): allow_exit: sets whether exiting can happen at any point in the trajectory - if so, it should be set to True. """ + # Resets masks in place to prevent side-effects across steps. + self.forward_masks[:] = True if allow_exit: exit_idx = torch.zeros(self.batch_shape + (1,)).to(cond.device) else: exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device) self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False + def set_exit_masks(self, batch_idx): """Sets forward masks such that the only allowable next action is to exit. @@ -456,8 +459,11 @@ def stack_states(states: List[States]): stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) if state_example._log_rewards: stacked_states._log_rewards = torch.stack([s._log_rewards for s in states], dim=0) - stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0) - stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0) + + # We are dealing with a list of DiscreteStates instances. + if hasattr(state_example, "forward_masks"): + stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0) + stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0) # Adds the trajectory dimension. 
stacked_states.batch_shape = (stacked_states.tensor.shape[0],) + state_example.batch_shape From 4e364d389e372810d9a91a9e2b5df237f9a64de9 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 15:34:41 -0500 Subject: [PATCH 07/20] black --- src/gfn/states.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 86eeabae..0e774b3b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -417,7 +417,6 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool): exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device) self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False - def set_exit_masks(self, batch_idx): """Sets forward masks such that the only allowable next action is to exit. @@ -458,14 +457,22 @@ def stack_states(states: List[States]): stacked_states = state_example.from_batch_shape((0, 0)) # Empty. stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) if state_example._log_rewards: - stacked_states._log_rewards = torch.stack([s._log_rewards for s in states], dim=0) + stacked_states._log_rewards = torch.stack( + [s._log_rewards for s in states], dim=0 + ) # We are dealing with a list of DiscreteStates instances. if hasattr(state_example, "forward_masks"): - stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0) - stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0) + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) # Adds the trajectory dimension. 
- stacked_states.batch_shape = (stacked_states.tensor.shape[0],) + state_example.batch_shape + stacked_states.batch_shape = ( + stacked_states.tensor.shape[0], + ) + state_example.batch_shape return stacked_states From 26dda4b9ce919d223e5857b2218dd312498a2d66 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 15:35:03 -0500 Subject: [PATCH 08/20] isort --- src/gfn/samplers.py | 2 +- src/gfn/states.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 475d4849..a2f810b6 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,5 +1,5 @@ -from typing import List, Optional, Tuple from copy import deepcopy +from typing import List, Optional, Tuple import torch from torchtyping import TensorType as TT diff --git a/src/gfn/states.py b/src/gfn/states.py index 0e774b3b..cb48b130 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from math import prod -from typing import Callable, ClassVar, Optional, Sequence, List, cast +from typing import Callable, ClassVar, List, Optional, Sequence, cast import torch from torchtyping import TensorType as TT From 1a6e768fd4981cb27af854e90558272a0b6cde71 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 15:36:49 -0500 Subject: [PATCH 09/20] removed comment --- src/gfn/gflownet/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index b5aa2929..e38bb10a 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -153,8 +153,8 @@ def get_pfs_and_pbs( if self.off_policy: # We re-use the values calculated in .sample_trajectories(). if trajectories.estimator_outputs is not None: - estimator_outputs = trajectories.estimator_outputs[ # TODO: This contains `inf` when we use the new `stack_states` method in `samplers.py`! 
- ~trajectories.actions.is_dummy # And this causes later failures (p_f is not finite). + estimator_outputs = trajectories.estimator_outputs[ + ~trajectories.actions.is_dummy ] else: raise Exception( From 45d9893962901b876304758d2df9f39f24b4eeae Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 15:49:44 -0500 Subject: [PATCH 10/20] black --- src/gfn/gflownet/base.py | 1 + testing/test_environments.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e38bb10a..9bd216cf 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -24,6 +24,7 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266). """ + log_reward_clip_min = float("-inf") # Default off. @abstractmethod diff --git a/testing/test_environments.py b/testing/test_environments.py index b110baac..5dbd4cc6 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -209,7 +209,9 @@ def test_box_fwd_step(delta: float): ] for failing_actions_list in failing_actions_lists_at_s0: - actions = env.actions_from_tensor(format_tensor(failing_actions_list, discrete=False)) + actions = env.actions_from_tensor( + format_tensor(failing_actions_list, discrete=False) + ) with pytest.raises(NonValidActionsError): states = env._step(states, actions) From 1e72273edaad50e6aa90aa5cfd14788b8010617d Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:10:41 -0500 Subject: [PATCH 11/20] default value reduced for grid size --- tutorials/examples/train_ising.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py index 6401dca2..26ca2864 100644 --- a/tutorials/examples/train_ising.py +++ b/tutorials/examples/train_ising.py @@ -112,7 +112,7 @@ def ising_n_to_ij(L, n): parser.add_argument( 
"-L", type=int, - default=16, + default=6, help="Lentgh of the grid", ) From c8cf89c64656b7d4c9a7258a7cf5cc59f31b4706 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:10:53 -0500 Subject: [PATCH 12/20] typo --- tutorials/examples/train_ising.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py index 26ca2864..1ca2c656 100644 --- a/tutorials/examples/train_ising.py +++ b/tutorials/examples/train_ising.py @@ -113,7 +113,7 @@ def ising_n_to_ij(L, n): "-L", type=int, default=6, - help="Lentgh of the grid", + help="Length of the grid", ) parser.add_argument( From 687136ca9025aee7eff44c3a9dc5578c2c9fb237 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:34:52 -0500 Subject: [PATCH 13/20] black --- src/gfn/gflownet/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e38bb10a..9bd216cf 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -24,6 +24,7 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266). """ + log_reward_clip_min = float("-inf") # Default off. @abstractmethod From 1846da1c9400c7526e2e26707039a97227400d6e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:38:32 -0500 Subject: [PATCH 14/20] black upgrade --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 957a60ce..539e3cb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ torch = ">=1.9.0" torchtyping = ">=0.1.4" # dev dependencies. 
-black = { version = "*", optional = true } +black = { version = "24.2", optional = true } flake8 = { version = "*", optional = true } gitmopy = { version = "*", optional = true } myst-parser = { version = "*", optional = true } From 552e010bfe2c0088c15c143cb68614bd758b2c14 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:42:09 -0500 Subject: [PATCH 15/20] upgrade black --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 539e3cb6..3d05ed8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ all = [ "Bug Tracker" = "https://github.com/saleml/gfn/issues" [tool.black] -py36 = true +target_version = ["py310"] include = '\.pyi?$' exclude = '''/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|build)/g''' From 21b845d48aafa1ec202a1c21feb2fd1bf8409824 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:42:42 -0500 Subject: [PATCH 16/20] black --- src/gfn/gflownet/base.py | 4 +--- src/gfn/gflownet/detailed_balance.py | 4 +--- src/gfn/gym/helpers/box_utils.py | 1 + src/gfn/gym/hypergrid.py | 1 + 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 9bd216cf..ece89bc3 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -201,9 +201,7 @@ def get_pfs_and_pbs( return log_pf_trajectories, log_pb_trajectories - def get_trajectories_scores( - self, trajectories: Trajectories - ) -> Tuple[ + def get_trajectories_scores(self, trajectories: Trajectories) -> Tuple[ TT["n_trajectories", torch.float], TT["n_trajectories", torch.float], TT["n_trajectories", torch.float], diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 4cb4e6e2..2c9cc723 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -42,9 +42,7 @@ def __init__( self.forward_looking = forward_looking self.log_reward_clip_min = log_reward_clip_min 
- def get_scores( - self, env: Env, transitions: Transitions - ) -> Tuple[ + def get_scores(self, env: Env, transitions: Transitions) -> Tuple[ TT["n_transitions", float], TT["n_transitions", float], TT["n_transitions", float], diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index c6342c75..bc5b18f2 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -1,4 +1,5 @@ """This file contains utilitary functions for the Box environment.""" + from typing import Tuple import numpy as np diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index b8bf27d1..9d6d7d0f 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -1,6 +1,7 @@ """ Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial """ + from typing import Literal, Tuple import torch From 1a5461529608da61c81cb9d1d37cccd37cb9cd53 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:43:46 -0500 Subject: [PATCH 17/20] black upgrade --- tutorials/examples/train_box.py | 9 ++++++--- tutorials/examples/train_discreteebm.py | 1 + tutorials/examples/train_hypergrid.py | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 5a3cf8dd..632d5b78 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -6,6 +6,7 @@ python train_box.py --delta {0.1, 0.25} --tied {--uniform_pb} --loss {TB, DB} """ + from argparse import ArgumentParser import numpy as np @@ -189,9 +190,11 @@ def main(args): # noqa: C901 if not args.uniform_pb: optimizer.add_param_group( { - "params": pb_module.last_layer.parameters() - if args.tied - else pb_module.parameters(), + "params": ( + pb_module.last_layer.parameters() + if args.tied + else pb_module.parameters() + ), "lr": args.lr, } ) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 3574fa2d..562bb2b4 100644 --- 
a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -10,6 +10,7 @@ [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782) python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB} """ + from argparse import ArgumentParser import torch diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 4d4e3a25..f52932e9 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -10,6 +10,7 @@ [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782) python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB} """ + from argparse import ArgumentParser import torch From 6aa1659d6238d1a366222af884dfbec2ee4b40cd Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:46:17 -0500 Subject: [PATCH 18/20] black formatting update --- tutorials/examples/train_box.py | 9 ++++++--- tutorials/examples/train_discreteebm.py | 1 + tutorials/examples/train_hypergrid.py | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 5a3cf8dd..632d5b78 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -6,6 +6,7 @@ python train_box.py --delta {0.1, 0.25} --tied {--uniform_pb} --loss {TB, DB} """ + from argparse import ArgumentParser import numpy as np @@ -189,9 +190,11 @@ def main(args): # noqa: C901 if not args.uniform_pb: optimizer.add_param_group( { - "params": pb_module.last_layer.parameters() - if args.tied - else pb_module.parameters(), + "params": ( + pb_module.last_layer.parameters() + if args.tied + else pb_module.parameters() + ), "lr": args.lr, } ) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 
3574fa2d..562bb2b4 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -10,6 +10,7 @@ [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782) python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB} """ + from argparse import ArgumentParser import torch diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 4d4e3a25..f52932e9 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -10,6 +10,7 @@ [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782) python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB} """ + from argparse import ArgumentParser import torch From f1a5c7f016b36d1c2ba2809da7eeb7add63c6b1e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:53:02 -0500 Subject: [PATCH 19/20] extended excludes --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3d05ed8c..36947af0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ all = [ [tool.black] target_version = ["py310"] include = '\.pyi?$' -exclude = '''/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|build)/g''' +extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g''' [tool.tox] legacy_tox_ini = ''' From 1a5ad2c98aa42eeb6ffcc76bf4bcccfe19936cc7 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 17:01:00 -0500 Subject: [PATCH 20/20] checks whether user-defined function returns the expected type --- src/gfn/env.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gfn/env.py b/src/gfn/env.py index 9b045ca3..510d3820 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -219,6 +219,10 @@ def _step( not_done_actions = 
actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) + if not isinstance(new_not_done_states_tensor, torch.Tensor): + raise Exception( + "User implemented env.step function *must* return a torch.Tensor!" + ) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor