diff --git a/pyproject.toml b/pyproject.toml index 957a60ce..36947af0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ torch = ">=1.9.0" torchtyping = ">=0.1.4" # dev dependencies. -black = { version = "*", optional = true } +black = { version = "24.2", optional = true } flake8 = { version = "*", optional = true } gitmopy = { version = "*", optional = true } myst-parser = { version = "*", optional = true } @@ -86,9 +86,9 @@ all = [ "Bug Tracker" = "https://github.com/saleml/gfn/issues" [tool.black] -py36 = true +target_version = ["py310"] include = '\.pyi?$' -exclude = '''/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|build)/g''' +extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g''' [tool.tox] legacy_tox_ini = ''' diff --git a/src/gfn/env.py b/src/gfn/env.py index 9b045ca3..510d3820 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -219,6 +219,10 @@ def _step( not_done_actions = actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) + if not isinstance(new_not_done_states_tensor, torch.Tensor): + raise Exception( + "User implemented env.step function *must* return a torch.Tensor!" + ) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e38bb10a..ece89bc3 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -24,6 +24,7 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266). """ + log_reward_clip_min = float("-inf") # Default off. 
@abstractmethod @@ -200,9 +201,7 @@ def get_pfs_and_pbs( return log_pf_trajectories, log_pb_trajectories - def get_trajectories_scores( - self, trajectories: Trajectories - ) -> Tuple[ + def get_trajectories_scores(self, trajectories: Trajectories) -> Tuple[ TT["n_trajectories", torch.float], TT["n_trajectories", torch.float], TT["n_trajectories", torch.float], diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 4cb4e6e2..2c9cc723 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -42,9 +42,7 @@ def __init__( self.forward_looking = forward_looking self.log_reward_clip_min = log_reward_clip_min - def get_scores( - self, env: Env, transitions: Transitions - ) -> Tuple[ + def get_scores(self, env: Env, transitions: Transitions) -> Tuple[ TT["n_transitions", float], TT["n_transitions", float], TT["n_transitions", float], diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index c6342c75..bc5b18f2 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -1,4 +1,5 @@ """This file contains utilitary functions for the Box environment.""" + from typing import Tuple import numpy as np diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index b8bf27d1..9d6d7d0f 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -1,6 +1,7 @@ """ Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial """ + from typing import Literal, Tuple import torch diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 68b052a6..a2f810b6 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import List, Optional, Tuple import torch @@ -7,7 +8,7 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.modules import GFNModule -from gfn.states import States +from gfn.states import States, stack_states class Sampler: @@ -140,9 
+141,7 @@ def sample_trajectories( else states.is_sink_state ) - trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [ - states.tensor - ] + trajectories_states: List[States] = [deepcopy(states)] trajectories_actions: List[TT["n_trajectories", torch.long]] = [] trajectories_logprobs: List[TT["n_trajectories", torch.float]] = [] trajectories_dones = torch.zeros( @@ -219,10 +218,9 @@ def sample_trajectories( states = new_states dones = dones | new_dones - trajectories_states += [states.tensor] + trajectories_states += [deepcopy(states)] - trajectories_states = torch.stack(trajectories_states, dim=0) - trajectories_states = env.states_from_tensor(trajectories_states) + trajectories_states = stack_states(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0) diff --git a/src/gfn/states.py b/src/gfn/states.py index 53492861..cb48b130 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from math import prod -from typing import Callable, ClassVar, Optional, Sequence, cast +from typing import Callable, ClassVar, List, Optional, Sequence, cast import torch from torchtyping import TensorType as TT @@ -409,6 +409,8 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool): allow_exit: sets whether exiting can happen at any point in the trajectory - if so, it should be set to True. """ + # Resets masks in place to prevent side-effects across steps. 
def stack_states(states: List["States"]) -> "States":
    """Stack a list of states along a new leading dimension (dim 0).

    All elements of `states` are assumed to be instances of the same class
    with identical batch shapes; the first element is used as the template.

    Args:
        states: non-empty list of states objects to stack.

    Returns:
        A new states object, created with `from_batch_shape`, whose `tensor`
        (and, when present, `_log_rewards`, `forward_masks` and
        `backward_masks`) is the stack of the corresponding attributes of
        every element, and whose `batch_shape` is
        `(len(states),) + states[0].batch_shape`.

    Raises:
        ValueError: if `states` is empty.
    """
    if not states:
        raise ValueError("stack_states requires at least one States object.")

    state_example = states[0]  # We assume all elems of `states` are the same.

    stacked_states = state_example.from_batch_shape((0, 0))  # Empty.
    stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)

    # `_log_rewards` is either None or a tensor. Truth-testing a tensor with
    # more than one element raises a RuntimeError, so compare against None
    # explicitly, and only stack when every element carries log-rewards.
    if all(s._log_rewards is not None for s in states):
        stacked_states._log_rewards = torch.stack(
            [s._log_rewards for s in states], dim=0
        )

    # States that expose masks (e.g. discrete states) get them stacked too.
    if hasattr(state_example, "forward_masks"):
        stacked_states.forward_masks = torch.stack(
            [s.forward_masks for s in states], dim=0
        )
        stacked_states.backward_masks = torch.stack(
            [s.backward_masks for s in states], dim=0
        )

    # Adds the trajectory dimension.
    stacked_states.batch_shape = (
        stacked_states.tensor.shape[0],
    ) + state_example.batch_shape

    return stacked_states
def make_J(L, coupling_constant, device="cpu"):
    """Build the coupling matrix of a 2D Ising model on an L x L periodic grid.

    Site `(i, j)` is flattened to the index `n = i * L + j`. Every site is
    coupled to its nearest neighbours, with wrap-around at the grid borders.

    Args:
        L: side length of the square grid.
        coupling_constant: value given to every nearest-neighbour coupling.
        device: device on which the matrix is allocated.

    Returns:
        A symmetric `(L**2, L**2)` tensor whose entry `[n, m]` equals
        `coupling_constant` when sites `n` and `m` are neighbours, else 0.
    """
    N = L**2
    J = torch.zeros((N, N), device=torch.device(device))
    for i in range(L):
        for j in range(L):
            site = i * L + j
            # Linking each site to its right and lower neighbour (modulo L)
            # covers every bond of the periodic grid exactly once.
            right = i * L + (j + 1) % L
            down = ((i + 1) % L) * L + j
            for neighbour in (right, down):
                J[site, neighbour] = 1
                J[neighbour, site] = 1

    return coupling_constant * J


def main(args):
    """Train an `FMGFlowNet` on a `DiscreteEBM` environment with Ising energy.

    Args:
        args: parsed command-line namespace providing `wandb_project`,
            `n_threads`, `L` (grid side length) and `J` (coupling constant).
    """
    # Configs
    use_wandb = len(args.wandb_project) > 0
    if use_wandb:
        wandb.init(project=args.wandb_project)
        wandb.config.update(args)

    device = "cpu"
    torch.set_num_threads(args.n_threads)

    hidden_dim = 512
    n_hidden = 2
    acc_fn = "relu"
    lr = 0.001
    n_iterations = 10000
    batch_size = 8
    log_every = 25

    # Ising model env
    N = args.L**2
    J = make_J(args.L, args.J, device=device)
    ising_energy = IsingModel(J)
    env = DiscreteEBM(N, alpha=1, energy=ising_energy, device_str=device)

    # Parametrization and losses
    pf_module = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions,
        hidden_dim=hidden_dim,
        n_hidden_layers=n_hidden,
        activation_fn=acc_fn,
    )

    pf_estimator = DiscretePolicyEstimator(
        pf_module, env.n_actions, env.preprocessor, is_backward=False
    )
    gflownet = FMGFlowNet(pf_estimator)
    optimizer = torch.optim.Adam(gflownet.parameters(), lr=lr)

    # Learning
    states_visited = 0
    for i in tqdm(range(n_iterations)):
        trajectories = gflownet.sample_trajectories(
            env, n_samples=batch_size, off_policy=False
        )
        training_samples = gflownet.to_training_samples(trajectories)
        optimizer.zero_grad()
        loss = gflownet.loss(env, training_samples)
        loss.backward()
        optimizer.step()

        states_visited += len(trajectories)
        to_log = {"loss": loss.item(), "states_visited": states_visited}

        # Forward the metrics to wandb when a project was requested; without
        # this the run is created but stays empty.
        if use_wandb:
            wandb.log(to_log, step=i)

        if i % log_every == 0:
            tqdm.write(f"{i}: {to_log}")


if __name__ == "__main__":
    # Command-line arguments
    parser = ArgumentParser()

    parser.add_argument(
        "--n_threads",
        type=int,
        default=4,
        help="Number of threads used by PyTorch",
    )

    parser.add_argument(
        "-L",
        type=int,
        default=6,
        help="Length of the grid",
    )

    parser.add_argument(
        "-J",
        type=float,
        default=0.44,
        help="J (Magnetic coupling constant)",
    )

    parser.add_argument(
        "--wandb_project",
        type=str,
        default="",
        help="Name of the wandb project. If empty, don't use wandb",
    )

    args = parser.parse_args()
    main(args)