diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75e15a9f..94f5d6a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,6 +45,6 @@ repos: name: pytest-check entry: pytest language: python - pass_filenames: false + files: testing/ types: [python] always_run: true diff --git a/README.md b/README.md index 627f933f..81ee0b60 100644 --- a/README.md +++ b/README.md @@ -48,9 +48,13 @@ pip install . ## About this repo -This repo serves the purpose of fast prototyping [GFlowNet](https://arxiv.org/abs/2111.09266) related algorithms. It decouples the environment definition, the sampling process, and the parametrization of the function approximators used to calculate the GFN loss. +This repo serves the purpose of fast prototyping [GFlowNet](https://arxiv.org/abs/2111.09266) (GFN) related algorithms. It decouples the environment definition, the sampling process, and the parametrization of the function approximators used to calculate the GFN loss. It aims to accompany researchers and engineers in learning about GFlowNets, and in developing new algorithms. -Example scripts and notebooks are provided [here](https://github.com/saleml/torchgfn/tree/master/tutorials/). +Currently, the library is shipped with three environments: two discrete environments (Discrete Energy Based Model and Hyper Grid) and a continuous box environment. The library is designed to allow users to define their own environments. See [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md) for more details. + +### Scripts and notebooks + +Example scripts and notebooks for the three environments are provided [here](https://github.com/saleml/torchgfn/tree/master/tutorials/examples). For the hyper grid and the box environments, the provided scripts are supposed to reproduce published results. 
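As a quick orientation before the full example below, here is a minimal sketch of how the three shipped environments can be instantiated. The `HyperGrid` arguments mirror the standalone example; `DiscreteEBM` and `Box` are shown only with arguments that appear elsewhere in this diff, under the assumption that all three classes are exposed from `gfn.gym`.

```python
from gfn.gym import Box, DiscreteEBM, HyperGrid  # assumption: all three environments are exported from gfn.gym

# A 4-dimensional grid of side 8, with base reward R0 (as in the standalone example below).
hypergrid_env = HyperGrid(ndim=4, height=8, R0=0.01)

# Discrete Energy Based Model environment (arguments as used in tutorials/examples/train_discreteebm.py).
ebm_env = DiscreteEBM(ndim=4, alpha=1.0)

# Continuous box environment; delta is the maximum distance between two successive states.
box_env = Box(delta=0.25)
```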
### Standalone example @@ -61,32 +65,43 @@ This example, which shows how to use the library for a simple discrete environme import torch from tqdm import tqdm -from gfn.gflownet import TBGFlowNet -from gfn.gym import HyperGrid +from gfn.gflownet import TBGFlowNet # We use a GFlowNet with the Trajectory Balance (TB) loss +from gfn.gym import HyperGrid # We use the hyper grid environment from gfn.modules import DiscretePolicyEstimator from gfn.samplers import Sampler -from gfn.utils import NeuralNet +from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron (MLP) if __name__ == "__main__": + # 1 - We define the environment + env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8 + # 2 - We define the needed modules (neural networks) + + # The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator module_PF = NeuralNet( input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - ) + ) # Neural network for the forward policy, with as many outputs as there are actions module_PB = NeuralNet( input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1, - torso=module_PF.torso + torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer ) - pf_estimator = DiscretePolicyEstimator(env, module_PF, forward=True) - pb_estimator = DiscretePolicyEstimator(env, module_PB, forward=False) + # 3 - We define the estimators + + pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor) + pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor) - gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) + # 4 - We define the GFlowNet - sampler = Sampler(estimator=pf_estimator)) + gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) # We initialize logZ to 0 + + # 5 - We define the sampler and the optimizer + + sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy # Policy parameters have their own LR. non_logz_params = [v for k, v in dict(gfn.named_parameters()).items() if k != "logZ"] @@ -94,12 +109,14 @@ if __name__ == "__main__": # Log Z gets dedicated learning rate (typically higher). logz_params = [dict(gfn.named_parameters())["logZ"]] - optimizer.add_param_group({"params": logz_params, "lr": 1e-2}) + optimizer.add_param_group({"params": logz_params, "lr": 1e-1}) + + # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration for i in (pbar := tqdm(range(1000))): - trajectories = sampler.sample_trajectories(n_trajectories=16) + trajectories = sampler.sample_trajectories(env=env, n_trajectories=16) optimizer.zero_grad() - loss = gfn.loss(trajectories) + loss = gfn.loss(env, trajectories) loss.backward() optimizer.step() if i % 25 == 0: @@ -116,7 +133,7 @@ pre-commit install pre-commit run --all-files ``` -Run `pre-commit` after staging, and before committing. Make sure all the tests pass (By running `pytest`). +Run `pre-commit` after staging, and before committing. Make sure all the tests pass (By running `pytest`). Note that the `pytest` hook of `pre-commit` only runs the tests in the `testing/` folder. To run all the tests, which take longer, run `pytest` manually. The codebase uses `black` formatter. To make the docs locally: @@ -145,7 +162,7 @@ The `batch_shape` attribute is required to keep track of the batch dimension. 
Because multiple trajectories can have different lengths, batching requires appending a dummy tensor to trajectories that are shorter than the longest trajectory. The dummy state is the $s_f$ attribute of the environment (e.g. `[-1, ..., -1]`, or `[-inf, ..., -inf]`, etc.), which is never processed and is only used to pad the batch of states. -For discrete environments, the action set is represented with the set $\{0, \dots, n_{actions} - 1\}$, where the $(n_{actions})$-th action always corresponds to the exit or terminate action, i.e. that results in a transition of the type $s \rightarrow s_f$, but not all actions are possible at all states. Each `States` object is endowed with two extra attributes: `forward_masks` and `backward_masks`, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the `DiscreteStates` abstract subclass of `States`. The `forward_masks` tensor is of shape `(*batch_shape, n_{actions})`, and `backward_masks` is of shape `(*batch_shape, n_{actions} - 1)`. Each subclass of `DiscreteStates` needs to implement the `update_masks` function, that uses the environment's logic to define the two tensors. +For discrete environments, the action set is represented with the set $\{0, \dots, n_{actions} - 1\}$, where the $(n_{actions})$-th action always corresponds to the exit or terminate action, i.e. that results in a transition of the type $s \rightarrow s_f$, but not all actions are possible at all states. For discrete environments, each `States` object is endowed with two extra attributes: `forward_masks` and `backward_masks`, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the `DiscreteStates` abstract subclass of `States`. The `forward_masks` tensor is of shape `(*batch_shape, n_{actions})`, and `backward_masks` is of shape `(*batch_shape, n_{actions} - 1)`. Each subclass of `DiscreteStates` needs to implement the `update_masks` function, that uses the environment's logic to define the two tensors. ### Actions Actions should be thought of as internal actions of an agent building a compositional object. They correspond to transitions $s \rightarrow s'$. An abstract `Actions` class is provided. It is automatically subclassed for discrete environments, but needs to be manually subclassed otherwise. @@ -163,32 +180,33 @@ Containers are collections of `States`, along with other information, such as re - [Transitions](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/transitions.py), representing a batch of transitions $s \rightarrow s'$. - [Trajectories](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/trajectories.py), representing a batch of complete trajectories $\tau = s_0 \rightarrow s_1 \rightarrow \dots \rightarrow s_n \rightarrow s_f$. -These containers can either be instantiated using a `States` object, or can be initialized as empty containers that can be populated on the fly, allowing the usage of the[ReplayBuffer](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/replay_buffer.py) class. +These containers can either be instantiated using a `States` object, or can be initialized as empty containers that can be populated on the fly, allowing the usage of the [ReplayBuffer](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/replay_buffer.py) class.
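To make the mask shapes above concrete, here is a small self-contained sketch (plain PyTorch, not the library's `DiscreteStates` class) of how a forward mask of shape `(*batch_shape, n_actions)` turns raw logits into a distribution over allowed actions; the mask values are made up for illustration.

```python
import torch
from torch.distributions import Categorical

batch_shape, n_actions = (4,), 3  # 2 "real" actions + 1 exit action, batch of 4 states
logits = torch.randn(*batch_shape, n_actions)  # raw output of a policy module

# forward_masks: (*batch_shape, n_actions); True means the action is allowed in that state.
forward_masks = torch.tensor(
    [[True, True, True],
     [True, False, True],
     [False, True, True],
     [True, True, True]]
)
# backward_masks would have shape (*batch_shape, n_actions - 1),
# since the exit action has no backward counterpart.

# Forbidden actions get a logit of -inf, i.e. zero probability.
masked_logits = logits.masked_fill(~forward_masks, -float("inf"))
dist = Categorical(logits=masked_logits)
print(dist.sample())  # one action index per state in the batch
```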
They inherit from the base `Container` [class](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/base.py), providing some helpful methods. -In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of states and batch of transitions can be defined using `Trajectories.to_transitions()` and `Trajectories.to_states()`. These exclude meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching. +In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of states and batch of transitions can be defined using `Trajectories.to_transitions()` and `Trajectories.to_states()`, in order to train GFlowNets with losses that are edge-decomposable or state-decomposable. These exclude meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching. ### Modules -Training GFlowNets requires one or multiple estimators, called `GFNModule`s, which is an abstract subclass of `torch.nn.Module`. In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function. They take the environment `env` as an input at initialization. -- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `backward=False`, the required output dimension is `n = env.n_actions`, and when `backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. Additionally, they include exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. Naturally, before defining the Categorical distributions, forbidden actions (that are encoded in the `DiscreteStates`' masks attributes), are given 0 probability by setting the corresponding logit to $-\infty$. +Training GFlowNets requires one or multiple estimators, called `GFNModule`s (an abstract subclass of `torch.nn.Module`). In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function. + +- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. The corresponding `to_probability_distribution` function transforms the logits by masking illegal actions (according to the forward or backward masks), then returns a Categorical distribution. The masking is done by setting the corresponding logit to $-\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. A `DiscretePolicyEstimator` with `is_backward=False` can be used to represent log-edge-flow estimators $\log F(s \rightarrow s')$. - `ScalarModule` is a simple module with required output dimension 1. It is useful to define log-state flows $\log F(s)$.
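To complement the standalone example above, the following sketch shows how forward/backward policy estimators and a log-state-flow estimator fit together under the module-centric signatures used in this diff; the hyperparameters are arbitrary, and `ScalarEstimator` is the class used for scalar outputs such as $\log F(s)$ elsewhere in this diff.

```python
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.utils import NeuralNet

env = HyperGrid(ndim=4, height=8, R0=0.01)

# One module per estimator: PF has n_actions outputs, PB has n_actions - 1 (no exit action), logF has 1.
module_PF = NeuralNet(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions)
module_PB = NeuralNet(input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1)
module_logF = NeuralNet(input_dim=env.preprocessor.output_dim, output_dim=1)

pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)
logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocessor)
```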
For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States` object), should return the batched parameters of a `torch.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, that implements at least the `sample` and `log_prob` functions. An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py), for a square environment in which the forward policy has support either on a quarter disk, or on an arc-circle, such that the angle, and the radius (for the quarter disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and it is not expected that user-defined environments need this level of detail. -In all `GFNModule`s, note that the input of the `forward` function is a `States` object. Meaning that they first need to be transformed to tensors. However, `states.tensor` does not necessarily include the structure that a neural network can used to generalize. It is common in these scenarios to have a function that transforms these raw tensor states to ones where the structure is clearer, via a `Preprocessor` object, that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). The default preprocessor of an environment is the identity preprocessor. The `forward` pass thus first calls the `preprocessor` attribute of the environment on `States`, before performing any transformation. +In all `GFNModule`s, note that the input of the `forward` function is a `States` object, meaning that states first need to be transformed into tensors. However, `states.tensor` does not necessarily include the structure that a neural network can use to generalize. It is common in these scenarios to have a function that transforms these raw tensor states to ones where the structure is clearer, via a `Preprocessor` object, that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). The `forward` pass thus first calls the `preprocessor` on the `States` object, before applying the module. The `preprocessor` is an attribute of the module; if it is not explicitly defined, it is set to the identity preprocessor. -For discrete environments, tabular modules are provided, where a lookup table is used instead of a neural network. Additionally, a `UniformPB` module is provided, implementing a uniform backward policy. +For discrete environments, a `Tabular` module is provided, where a lookup table is used instead of a neural network. Additionally, a `UniformPB` module is provided, implementing a uniform backward policy. These modules are provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py). ### Samplers -A [Sampler](https://github.com/saleml/torchgfn/tree/master/src/gfn/samplers.py) object defines how actions are sampled (`sample_actions()`) at each state, and trajectories (`sample_trajectories()`), which can sample a batch of trajectories starting from a given set of initial states or starting from $s_0$. It requires a `GFNModule` that implements the `to_probability_distribution` function.
+A [Sampler](https://github.com/saleml/torchgfn/tree/master/src/gfn/samplers.py) object defines how actions are sampled (`sample_actions()`) at each state, and trajectories (`sample_trajectories()`), which can sample a batch of trajectories starting from a given set of initial states or starting from $s_0$. It requires a `GFNModule` that implements the `to_probability_distribution` function. For off-policy sampling, the parameters of `to_probability_distribution` can be directly passed when initializing the `Sampler`. ### Losses -GFlowNets can be trained with different losses, each of which requires a different parametrization, which we call in this library a `GFlowNet`. A `GFlowNet` is a `GFNModule` that includes one or multiple `GFNModules`, at least one of which implements a `to_probability_distribution` function. They also need to implement a `loss` function, that takes as input either states, transitions, or trajectories, depending on the loss. +GFlowNets can be trained with different losses, each of which requires a different parametrization, which we call in this library a `GFlowNet`. A `GFlowNet` is a `GFNModule` that includes one or multiple `GFNModule`s, at least one of which implements a `to_probability_distribution` function. They also need to implement a `loss` function, that takes as input either states, transitions, or trajectories, depending on the loss. Currently, the implemented losses are: @@ -197,6 +215,3 @@ Currently, the implemented losses are: - Trajectory Balance - Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py). - Log Partition Variance loss. Introduced [here](https://arxiv.org/abs/2302.05446) - -# Scripts -Example scripts are provided [here](https://github.com/saleml/torchgfn/tree/master/tutorials/examples/). They can be used to reproduce published results in the HyperGrid environment, and the Box environment. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1d584a8e..269e7d1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "torchgfn" packages = [{include = "gfn", from = "src"}] -version = "1.0.0" +version = "1.0.1" description = "A torch implementation of GFlowNets" authors = ["Salem Lahou ", "Joseph Viviano ", "Victor Schmidt "] license = "MIT" diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e1291a8f..b5727486 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -6,6 +6,7 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories +from gfn.env import Env from gfn.modules import GFNModule from gfn.samplers import Sampler from gfn.states import States @@ -18,30 +19,36 @@ class GFlowNet(nn.Module): """ @abstractmethod - def sample_trajectories(self, n_samples: int) -> Trajectories: + def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories: """Sample a specific number of complete trajectories. Args: + env: the environment to sample trajectories from. n_samples: number of trajectories to be sampled. Returns: Trajectories: sampled trajectories object. 
""" - def sample_terminating_states(self, n_samples: int) -> States: + def sample_terminating_states(self, env: Env, n_samples: int) -> States: """Rolls out the parametrization's policy and returns the terminating states. Args: + env: the environment to sample terminating states from. n_samples: number of terminating states to be sampled. Returns: States: sampled terminating states object. """ - trajectories = self.sample_trajectories(n_samples) + trajectories = self.sample_trajectories(env, n_samples) return trajectories.last_states @abstractmethod def to_training_samples(self, trajectories: Trajectories): """Converts trajectories to training samples. The type depends on the GFlowNet.""" + @abstractmethod + def loss(self, env: Env, training_objects): + """Computes the loss given the training objects.""" + class PFBasedGFlowNet(GFlowNet): r"""Base class for gflownets that explicitly uses $P_F$. @@ -57,9 +64,9 @@ def __init__(self, pf: GFNModule, pb: GFNModule, on_policy: bool = False): self.pb = pb self.on_policy = on_policy - def sample_trajectories(self, n_samples: int = 1000) -> Trajectories: + def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories: sampler = Sampler(estimator=self.pf) - trajectories = sampler.sample_trajectories(n_trajectories=n_samples) + trajectories = sampler.sample_trajectories(env, n_trajectories=n_samples) return trajectories diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 02125c77..2310abaa 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -4,8 +4,9 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories, Transitions +from gfn.env import Env from gfn.gflownet.base import PFBasedGFlowNet -from gfn.modules import ScalarEstimator +from gfn.modules import GFNModule, ScalarEstimator class DBGFlowNet(PFBasedGFlowNet): @@ -27,17 +28,18 @@ class DBGFlowNet(PFBasedGFlowNet): def __init__( self, + pf: GFNModule, + pb: GFNModule, logF: ScalarEstimator, + on_policy: bool = False, forward_looking: bool = False, - **kwargs, ): - super().__init__(**kwargs) + super().__init__(pf, pb, on_policy=on_policy) self.logF = logF self.forward_looking = forward_looking - self.env = self.logF.env # TODO We don't want to store env in here... def get_scores( - self, transitions: Transitions + self, env: Env, transitions: Transitions ) -> Tuple[ TT["n_transitions", float], TT["n_transitions", float], @@ -72,7 +74,7 @@ def get_scores( valid_log_F_s = self.logF(states).squeeze(-1) if self.forward_looking: - log_rewards = self.env.log_reward(states) # RM unsqueeze(-1) + log_rewards = env.log_reward(states) # RM unsqueeze(-1) valid_log_F_s = valid_log_F_s + log_rewards preds = valid_log_pf_actions + valid_log_F_s @@ -110,12 +112,12 @@ def get_scores( return (valid_log_pf_actions, log_pb_actions, scores) - def loss(self, transitions: Transitions) -> TT[0, float]: + def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: """Detailed balance loss. 
The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).""" - _, _, scores = self.get_scores(transitions) + _, _, scores = self.get_scores(env, transitions) loss = torch.mean(scores**2) if torch.isnan(loss): @@ -182,7 +184,7 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.flo return scores - def loss(self, transitions: Transitions) -> TT[0, float]: + def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: """Calculates the modified detailed balance loss.""" scores = self.get_scores(transitions) return torch.mean(scores**2) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 998859d9..f9872732 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -4,6 +4,7 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories +from gfn.env import Env from gfn.gflownet.base import GFlowNet from gfn.modules import DiscretePolicyEstimator from gfn.samplers import Sampler @@ -28,23 +29,21 @@ class FMGFlowNet(GFlowNet): def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): super().__init__() - assert not logF.greedy_eps self.logF = logF self.alpha = alpha - self.env = self.logF.env - if not self.env.is_discrete: + + def sample_trajectories(self, env: Env, n_samples: int = 1000) -> Trajectories: + if not env.is_discrete: raise NotImplementedError( "Flow Matching GFlowNet only supports discrete environments for now." ) - - def sample_trajectories(self, n_samples: int = 1000) -> Trajectories: sampler = Sampler(estimator=self.logF) - trajectories = sampler.sample_trajectories(n_trajectories=n_samples) + trajectories = sampler.sample_trajectories(env, n_trajectories=n_samples) return trajectories def flow_matching_loss( - self, states: DiscreteStates + self, env: Env, states: DiscreteStates ) -> TT["n_trajectories", torch.float]: """Computes the FM for the provided states. 
@@ -67,7 +66,7 @@ def flow_matching_loss( states.forward_masks, -float("inf"), dtype=torch.float ) - for action_idx in range(self.env.n_actions - 1): + for action_idx in range(env.n_actions - 1): valid_backward_mask = states.backward_masks[:, action_idx] valid_forward_mask = states.forward_masks[:, action_idx] valid_backward_states = states[valid_backward_mask] @@ -76,9 +75,9 @@ def flow_matching_loss( backward_actions = torch.full_like( valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long ).unsqueeze(-1) - backward_actions = self.env.Actions(backward_actions) + backward_actions = env.Actions(backward_actions) - valid_backward_states_parents = self.env.backward_step( + valid_backward_states_parents = env.backward_step( valid_backward_states, backward_actions ) @@ -101,8 +100,11 @@ def flow_matching_loss( return (log_incoming_flows - log_outgoing_flows).pow(2).mean() - def reward_matching_loss(self, terminating_states: DiscreteStates) -> TT[0, float]: + def reward_matching_loss( + self, env: Env, terminating_states: DiscreteStates + ) -> TT[0, float]: """Calculates the reward matching loss from the terminating states.""" + del env # Unused assert terminating_states.log_rewards is not None log_edge_flows = self.logF(terminating_states) @@ -111,7 +113,9 @@ def reward_matching_loss(self, terminating_states: DiscreteStates) -> TT[0, floa log_rewards = terminating_states.log_rewards return (terminating_log_edge_flows - log_rewards).pow(2).mean() - def loss(self, states_tuple: Tuple[DiscreteStates, DiscreteStates]) -> TT[0, float]: + def loss( + self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates] + ) -> TT[0, float]: """Given a batch of non-terminal and terminal states, compute a loss. Unlike the GFlowNets Foundations paper, we allow more flexibility by passing a @@ -119,8 +123,8 @@ def loss(self, states_tuple: Tuple[DiscreteStates, DiscreteStates]) -> TT[0, flo (i.e. 
non-terminal states), and the second one being the terminal states of the trajectories.""" intermediary_states, terminating_states = states_tuple - fm_loss = self.flow_matching_loss(intermediary_states) - rm_loss = self.reward_matching_loss(terminating_states) + fm_loss = self.flow_matching_loss(env, intermediary_states) + rm_loss = self.reward_matching_loss(env, terminating_states) return fm_loss + self.alpha * rm_loss def to_training_samples( diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index bd83e707..175be074 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -4,8 +4,9 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories -from gfn.gflownet.base import PFBasedGFlowNet, TrajectoryBasedGFlowNet -from gfn.modules import ScalarEstimator +from gfn.env import Env +from gfn.gflownet.base import TrajectoryBasedGFlowNet +from gfn.modules import GFNModule, ScalarEstimator class SubTBGFlowNet(TrajectoryBasedGFlowNet): @@ -43,7 +44,10 @@ class SubTBGFlowNet(TrajectoryBasedGFlowNet): def __init__( self, + pf: GFNModule, + pb: GFNModule, logF: ScalarEstimator, + on_policy: bool = False, weighting: Literal[ "DB", "ModifiedDB", @@ -56,9 +60,8 @@ def __init__( lamda: float = 0.9, log_reward_clip_min: float = -12, # roughly log(1e-5) forward_looking: bool = False, - **kwargs, ): - super().__init__(**kwargs) + super().__init__(pf, pb, on_policy=on_policy) self.logF = logF self.weighting = weighting self.lamda = lamda @@ -89,7 +92,7 @@ def cumulative_logprobs( ) def get_scores( - self, trajectories: Trajectories + self, env: Env, trajectories: Trajectories ) -> Tuple[List[TT[0, float]], List[TT[0, float]]]: """Scores all submitted trajectories. @@ -123,7 +126,7 @@ def get_scores( log_F = self.logF(valid_states).squeeze(-1) if self.forward_looking: - log_rewards = self.logF.env.log_reward(states).unsqueeze(-1) + log_rewards = env.log_reward(states).unsqueeze(-1) log_F = log_F + log_rewards log_state_flows[mask[:-1]] = log_F @@ -188,9 +191,9 @@ def get_scores( flattening_masks, ) - def loss(self, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: # Get all scores and masks from the trajectories. 
- scores, flattening_masks = self.get_scores(trajectories) + scores, flattening_masks = self.get_scores(env, trajectories) flattening_mask = torch.cat(flattening_masks) all_scores = torch.cat(scores, 0) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index b638789d..bceac033 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -8,7 +8,9 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories +from gfn.env import Env from gfn.gflownet.base import TrajectoryBasedGFlowNet +from gfn.modules import GFNModule class TBGFlowNet(TrajectoryBasedGFlowNet): @@ -28,16 +30,18 @@ class TBGFlowNet(TrajectoryBasedGFlowNet): def __init__( self, + pf: GFNModule, + pb: GFNModule, + on_policy: bool = False, init_logZ: float = 0.0, log_reward_clip_min: float = -12, # roughly log(1e-5) - **kwargs, ): - super().__init__(**kwargs) + super().__init__(pf, pb, on_policy=on_policy) self.logZ = nn.Parameter(torch.tensor(init_logZ)) self.log_reward_clip_min = log_reward_clip_min - def loss(self, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: """Trajectory balance loss. The trajectory balance loss is described in 2.3 of @@ -46,6 +50,7 @@ def loss(self, trajectories: Trajectories) -> TT[0, float]: Raises: ValueError: if the loss is NaN. """ + del env # unused _, _, scores = self.get_trajectories_scores(trajectories) loss = (scores + self.logZ).pow(2).mean() if torch.isnan(loss): @@ -64,17 +69,24 @@ class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet): ValueError: if the loss is NaN. """ - def __init__(self, log_reward_clip_min: float = -12, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + pf: GFNModule, + pb: GFNModule, + on_policy: bool = False, + log_reward_clip_min: float = -12, + ): + super().__init__(pf, pb, on_policy=on_policy) self.log_reward_clip_min = log_reward_clip_min # -12 is roughly log(1e-5) - def loss(self, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: """Log Partition Variance loss. This method is described in section 3.2 of [ROBUST SCHEDULING WITH GFLOWNETS](https://arxiv.org/abs/2302.05446)) """ + del env # unused _, _, scores = self.get_trajectories_scores(trajectories) loss = (scores - scores.mean()).pow(2).mean() if torch.isnan(loss): diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 0783b701..2dfd83e8 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -359,8 +359,8 @@ class DistributionWrapper(Distribution): def __init__( self, states: States, - env: Box, delta: float, + epsilon: float, mixture_logits, alpha_r, beta_r, @@ -370,7 +370,6 @@ def __init__( n_components, n_components_s0, ): - self.env = env self.idx_is_initial = torch.where(torch.all(states.tensor == 0, 1))[0] self.idx_not_initial = torch.where(torch.any(states.tensor != 0, 1))[0] self._output_shape = states.tensor.shape @@ -387,13 +386,13 @@ def __init__( self.quarter_circ = None if len(self.idx_not_initial) > 0: self.quarter_circ = QuarterCircleWithExit( - delta=self.env.delta, + delta=delta, centers=states[self.idx_not_initial], # Remove initial states. 
exit_probability=exit_probability[self.idx_not_initial], mixture_logits=mixture_logits[self.idx_not_initial, :n_components], alpha=alpha_theta[self.idx_not_initial, :n_components], beta=beta_theta[self.idx_not_initial, :n_components], - epsilon=self.env.epsilon, + epsilon=epsilon, ) # no sample_shape req as it is stored in centers. def sample(self, sample_shape=()): @@ -472,6 +471,7 @@ def __init__( self.n_components = n_components input_dim = 2 + self.input_dim = input_dim output_dim = 1 + 3 * self.n_components @@ -571,6 +571,7 @@ def __init__( **kwargs: passed to the NeuralNet class. """ input_dim = 2 + self.input_dim = input_dim output_dim = 3 * n_components super().__init__( @@ -619,6 +620,8 @@ class BoxPBUniform(torch.nn.Module): uniform distribution over parents in the south-western part of circle. """ + input_dim = 2 + def forward( self, preprocessed_states: TT["batch_shape", 2, float] ) -> TT["batch_shape", 3]: @@ -680,14 +683,15 @@ def __init__( min_concentration: float = 0.1, max_concentration: float = 2.0, ): - super().__init__(env, module) + super().__init__(module) self._n_comp_max = max(n_components_s0, n_components) self.n_components_s0 = n_components_s0 self.n_components = n_components self.min_concentration = min_concentration self.max_concentration = max_concentration - self.env = env + self.delta = env.delta + self.epsilon = env.epsilon def expected_output_dim(self) -> int: return 1 + 5 * self._n_comp_max @@ -736,8 +740,8 @@ def _normalize(x): return DistributionWrapper( states, - self.env, - self.env.delta, + self.delta, + self.epsilon, mixture_logits, alpha_r, beta_r, @@ -760,13 +764,15 @@ def __init__( min_concentration: float = 0.1, max_concentration: float = 2.0, ): - super().__init__(env, module) + super().__init__(module, is_backward=True) self.module = module self.n_components = n_components self.min_concentration = min_concentration self.max_concentration = max_concentration + self.delta = env.delta + def expected_output_dim(self) -> int: return 3 * self.n_components @@ -789,7 +795,7 @@ def _normalize(x): alpha = _normalize(alpha) beta = _normalize(beta) return QuarterCircle( - delta=self.env.delta, + delta=self.delta, northeastern=False, centers=states, mixture_logits=mixture_logits, diff --git a/src/gfn/gym/helpers/test_box_utils.py b/src/gfn/gym/helpers/test_box_utils.py index bc011d92..fb004140 100644 --- a/src/gfn/gym/helpers/test_box_utils.py +++ b/src/gfn/gym/helpers/test_box_utils.py @@ -1,3 +1,4 @@ +import pytest import torch from gfn.gym import Box @@ -13,14 +14,14 @@ ) -def test_mixed_distributions(): +@pytest.mark.parametrize("n_components", [5, 6]) +@pytest.mark.parametrize("n_components_s0", [5, 6]) +def test_mixed_distributions(n_components: int, n_components_s0: int): """Ensure DistributionWrapper functions correctly.""" delta = 0.1 hidden_dim = 10 n_hidden_layers = 2 - n_components = 5 - n_components_s0 = 6 environment = Box( delta=delta, diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 55f6361f..1dde4fa4 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -5,7 +5,7 @@ from torch.distributions import Categorical, Distribution from torchtyping import TensorType as TT -from gfn.env import DiscreteEnv, Env +from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, States from gfn.utils.distributions import UnsqueezedCategorical @@ -29,12 +29,14 @@ class GFNModule(ABC, nn.Module): Otherwise, one can overwrite and use the to_probability_distribution() method to directly output a 
probability distribution. - The preprocessor is also encapsulated in the estimator via the environment. + The preprocessor is also encapsulated in the estimator. These function estimators implement the `__call__` method, which takes `States` objects as inputs and calls the module on the preprocessed states. Attributes: - env: the environment. + preprocessor: Preprocessor object that transforms raw States objects to tensors + that can be used as input to the module. Optional, defaults to + `IdentityPreprocessor`. module: The module to use. If the module is a Tabular module (from `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. @@ -44,19 +46,31 @@ class GFNModule(ABC, nn.Module): been verified. """ - def __init__(self, env: Env, module: nn.Module) -> None: + def __init__( + self, + module: nn.Module, + preprocessor: Preprocessor | None = None, + is_backward: bool = False, + ) -> None: """Initalize the FunctionEstimator with an environment and a module. Args: - env: the environment. module: The module to use. If the module is a Tabular module (from `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. + preprocessor: Preprocessor object. + is_backward: Flags estimators of probability distributions over parents. """ nn.Module.__init__(self) - self.env = env self.module = module - self.preprocessor = env.preprocessor # TODO: passed explicitly? + if preprocessor is None: + assert hasattr(module, "input_dim"), ( + "Module needs to have an attribute `input_dim` specifying the input " + + "dimension, in order to use the default IdentityPreprocessor." + ) + preprocessor = IdentityPreprocessor(module.input_dim) + self.preprocessor = preprocessor self._output_dim_is_checked = False + self.is_backward = is_backward def forward(self, states: States) -> TT["batch_shape", "output_dim", float]: out = self.module(self.preprocessor(states)) @@ -88,9 +102,12 @@ def to_probability_distribution( self, states: States, module_output: TT["batch_shape", "output_dim", float], + *args, ) -> Distribution: """Transform the output of the module into a probability distribution. + The kwargs modify a base distribution, for example to encourage exploration. + Not all modules must implement this method, but it is required to define a policy from a module's outputs. See `DiscretePolicyEstimator` for an example using a categorical distribution, but note this can be done for all continuous @@ -105,7 +122,7 @@ def expected_output_dim(self) -> int: class DiscretePolicyEstimator(GFNModule): - r"""Container for forward and backward policy estimators. + r"""Container for forward and backward policy estimators for discrete environments. $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$. @@ -113,11 +130,6 @@ class DiscretePolicyEstimator(GFNModule): $s \mapsto (P_B(s' \mid s))_{s' \in Parents(s)}$. - Note that while this class resembles LogEdgeFlowProbabilityEstimator, they have - different semantic meaning. With LogEdgeFlowEstimator, the module output is the log - of the flow from the parent to the child, while with DiscretePFEstimator, the - module output is arbitrary. - Attributes: temperature: scalar to divide the logits by before softmax. 
sf_bias: scalar to subtract from the exit action logit before dividing by @@ -127,60 +139,54 @@ class DiscretePolicyEstimator(GFNModule): def __init__( self, - env: Env, module: nn.Module, - forward: bool, - greedy_eps: float = 0.0, - temperature: float = 1.0, - sf_bias: float = 0.0, - epsilon: float = 0.0, + n_actions: int, + preprocessor: Preprocessor | None, + is_backward: bool = False, ): """Initializes a estimator for P_F for discrete environments. Args: - forward: if True, then this is a forward policy, else backward policy. - greedy_eps: if > 0 , then we go off policy using greedy epsilon exploration. - temperature: scalar to divide the logits by before softmax. Does nothing - if greedy_eps is 0. - sf_bias: scalar to subtract from the exit action logit before dividing by - temperature. Does nothing if greedy_eps is 0. - epsilon: with probability epsilon, a random action is chosen. Does nothing - if greedy_eps is 0. + n_actions: Total number of actions in the Discrete Environment. + is_backward: if False, then this is a forward policy, else backward policy. """ - super().__init__(env, module) - assert greedy_eps >= 0 - self._forward = forward - self._greedy_eps = greedy_eps - self.temperature = temperature - self.sf_bias = sf_bias - self.epsilon = epsilon - - @property - def greedy_eps(self): - return self._greedy_eps + super().__init__(module, preprocessor, is_backward=is_backward) + self.n_actions = n_actions def expected_output_dim(self) -> int: - if self._forward: - return self.env.n_actions + if self.is_backward: + return self.n_actions - 1 else: - return self.env.n_actions - 1 + return self.n_actions def to_probability_distribution( self, states: DiscreteStates, module_output: TT["batch_shape", "output_dim", float], + temperature: float = 1.0, + sf_bias: float = 0.0, + epsilon: float = 0.0, ) -> Categorical: - """Returns a probability distribution given a batch of states and module output.""" - masks = states.forward_masks if self._forward else states.backward_masks + """Returns a probability distribution given a batch of states and module output. + + Args: + temperature: scalar to divide the logits by before softmax. Does nothing + if set to 1.0 (default), in which case it's on policy. + sf_bias: scalar to subtract from the exit action logit before dividing by + temperature. Does nothing if set to 0.0 (default), in which case it's + on policy. + epsilon: with probability epsilon, a random action is chosen. Does nothing + if set to 0.0 (default), in which case it's on policy.""" + masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output logits[~masks] = -float("inf") # Forward policy supports exploration in many implementations. 
- if self._greedy_eps: - logits[:, -1] -= self.sf_bias - probs = torch.softmax(logits / self.temperature, dim=-1) + if temperature != 1.0 or sf_bias != 0.0 or epsilon != 0.0: + logits[:, -1] -= sf_bias + probs = torch.softmax(logits / temperature, dim=-1) uniform_dist_probs = masks.float() / masks.sum(dim=-1, keepdim=True) - probs = (1 - self.epsilon) * probs + self.epsilon * uniform_dist_probs + probs = (1 - epsilon) * probs + epsilon * uniform_dist_probs return UnsqueezedCategorical(probs=probs) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 8868528c..83d98221 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -5,11 +5,11 @@ from gfn.actions import Actions from gfn.containers import Trajectories +from gfn.env import Env from gfn.modules import GFNModule from gfn.states import States -# TODO: Environment should not live inside the estimator and here... needs refactor. class Sampler: """`Sampler is a container for a PolicyEstimator. @@ -18,23 +18,26 @@ class Sampler: Attributes: estimator: the submitted PolicyEstimator. - env: the Environment instance inside the PolicyEstimator. - is_backward: if True, samples trajectories of actions backward (a distribution - over parents). If True, the estimator must be a ProbabilityDistribution - over parents. + probability_distribution_kwargs: keyword arguments to be passed to the `to_probability_distribution` + method of the estimator. For example, for DiscretePolicyEstimators, the kwargs can contain + the `temperature` parameter, `epsilon`, and `sf_bias`. """ - def __init__(self, estimator: GFNModule, is_backward: bool = False) -> None: + def __init__( + self, + estimator: GFNModule, + **probability_distribution_kwargs: Optional[dict], + ) -> None: self.estimator = estimator - self.env = estimator.env - self.is_backward = is_backward # TODO: take directly from estimator. + self.probability_distribution_kwargs = probability_distribution_kwargs def sample_actions( - self, states: States + self, env: Env, states: States ) -> Tuple[Actions, TT["batch_shape", torch.float]]: """Samples actions from the given states. Args: + env: The environment to sample actions from. states (States): A batch of states. Returns: @@ -45,7 +48,9 @@ def sample_actions( states. """ module_output = self.estimator(states) - dist = self.estimator.to_probability_distribution(states, module_output) + dist = self.estimator.to_probability_distribution( + states, module_output, **self.probability_distribution_kwargs + ) with torch.no_grad(): actions = dist.sample() @@ -53,16 +58,18 @@ def sample_actions( if torch.any(torch.isinf(log_probs)): raise RuntimeError("Log probabilities are inf. This should not happen.") - return self.env.Actions(actions), log_probs + return env.Actions(actions), log_probs def sample_trajectories( self, + env: Env, states: Optional[States] = None, n_trajectories: Optional[int] = None, ) -> Trajectories: """Sample trajectories sequentially. Args: + env: The environment to sample trajectories from. states: If given, trajectories would start from such states. Otherwise, trajectories are sampled from $s_o$ and n_trajectories must be provided. 
n_trajectories: If given, a batch of n_trajectories will be sampled all @@ -78,7 +85,7 @@ def sample_trajectories( assert ( n_trajectories is not None ), "Either states or n_trajectories should be specified" - states = self.env.reset(batch_shape=(n_trajectories,)) + states = env.reset(batch_shape=(n_trajectories,)) else: assert ( len(states.batch_shape) == 1 @@ -87,7 +94,11 @@ def sample_trajectories( device = states.tensor.device - dones = states.is_initial_state if self.is_backward else states.is_sink_state + dones = ( + states.is_initial_state + if self.estimator.is_backward + else states.is_sink_state + ) trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [ states.tensor @@ -104,37 +115,37 @@ def sample_trajectories( step = 0 while not all(dones): - actions = self.env.Actions.make_dummy_actions(batch_shape=(n_trajectories,)) + actions = env.Actions.make_dummy_actions(batch_shape=(n_trajectories,)) log_probs = torch.full( (n_trajectories,), fill_value=0, dtype=torch.float, device=device ) - valid_actions, actions_log_probs = self.sample_actions(states[~dones]) + valid_actions, actions_log_probs = self.sample_actions(env, states[~dones]) actions[~dones] = valid_actions log_probs[~dones] = actions_log_probs trajectories_actions += [actions] trajectories_logprobs += [log_probs] - if self.is_backward: - new_states = self.env.backward_step(states, actions) + if self.estimator.is_backward: + new_states = env.backward_step(states, actions) else: - new_states = self.env.step(states, actions) + new_states = env.step(states, actions) sink_states_mask = new_states.is_sink_state step += 1 new_dones = ( - new_states.is_initial_state if self.is_backward else sink_states_mask + new_states.is_initial_state + if self.estimator.is_backward + else sink_states_mask ) & ~dones trajectories_dones[new_dones & ~dones] = step try: - trajectories_log_rewards[new_dones & ~dones] = self.env.log_reward( + trajectories_log_rewards[new_dones & ~dones] = env.log_reward( states[new_dones & ~dones] ) except NotImplementedError: - # print(states[new_dones & ~dones]) - # print(torch.log(self.env.reward(states[new_dones & ~dones]))) trajectories_log_rewards[new_dones & ~dones] = torch.log( - self.env.reward(states[new_dones & ~dones]) + env.reward(states[new_dones & ~dones]) ) states = new_states dones = dones | new_dones @@ -142,16 +153,16 @@ def sample_trajectories( trajectories_states += [states.tensor] trajectories_states = torch.stack(trajectories_states, dim=0) - trajectories_states = self.env.States(tensor=trajectories_states) - trajectories_actions = self.env.Actions.stack(trajectories_actions) + trajectories_states = env.States(tensor=trajectories_states) + trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0) trajectories = Trajectories( - env=self.env, + env=env, states=trajectories_states, actions=trajectories_actions, when_is_done=trajectories_dones, - is_backward=self.is_backward, + is_backward=self.estimator.is_backward, log_rewards=trajectories_log_rewards, log_probs=trajectories_logprobs, ) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 92524a42..7379d276 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1,6 +1,6 @@ """This file contains some examples of modules that can be used with GFN.""" -from typing import Iterator, Literal, Optional, Tuple +from typing import Literal, Optional import torch import torch.nn as nn @@ -139,6 +139,3 @@ def forward( 
preprocessed_states.device ) return out - - def named_parameters(self) -> Iterator[Tuple[str, Parameter]]: - return iter([]) diff --git a/testing/test_parametrizations_and_losses.py b/testing/test_parametrizations_and_losses.py index 773a7583..ac7ffb5d 100644 --- a/testing/test_parametrizations_and_losses.py +++ b/testing/test_parametrizations_and_losses.py @@ -18,7 +18,7 @@ BoxPFEstimator, BoxPFNeuralNet, ) -from gfn.modules import DiscretePolicyEstimator, GFNModule, ScalarEstimator +from gfn.modules import DiscretePolicyEstimator, ScalarEstimator from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular @@ -51,15 +51,15 @@ def test_FM(env_name: int, ndim: int, module_name: str): raise ValueError("Unknown module name") log_F_edge = DiscretePolicyEstimator( - env=env, module=module, - forward=True, + n_actions=env.n_actions, + preprocessor=env.preprocessor, ) gflownet = FMGFlowNet(log_F_edge) # forward looking by default. - trajectories = gflownet.sample_trajectories(n_samples=10) + trajectories = gflownet.sample_trajectories(env, n_samples=10) states_tuple = trajectories.to_non_initial_intermediary_and_terminating_states() - loss = gflownet.loss(states_tuple) + loss = gflownet.loss(env, states_tuple) assert loss >= 0 @@ -174,10 +174,14 @@ def PFBasedGFlowNet_with_return( n_components=ndim + 1 if module_name != "Uniform" else 1, ) else: - pf = DiscretePolicyEstimator(env, pf_module, forward=True) - pb = DiscretePolicyEstimator(env, pb_module, forward=False) + pf = DiscretePolicyEstimator( + pf_module, env.n_actions, preprocessor=env.preprocessor + ) + pb = DiscretePolicyEstimator( + pb_module, env.n_actions, preprocessor=env.preprocessor, is_backward=True + ) - logF = ScalarEstimator(env, module=logF_module) + logF = ScalarEstimator(module=logF_module, preprocessor=env.preprocessor) if gflownet_name == "DB": gflownet = DBGFlowNet( @@ -202,10 +206,10 @@ def PFBasedGFlowNet_with_return( else: raise ValueError(f"Unknown gflownet {gflownet_name}") - trajectories = gflownet.sample_trajectories(10) + trajectories = gflownet.sample_trajectories(env, 10) training_objects = gflownet.to_training_samples(trajectories) - _ = gflownet.loss(training_objects) + _ = gflownet.loss(env, training_objects) if gflownet_name == "TB": assert torch.all( @@ -299,9 +303,11 @@ def test_subTB_vs_TB( zero_logF=True, ) - trajectories = gflownet.sample_trajectories(10) - subtb_loss = gflownet.loss(trajectories) + trajectories = gflownet.sample_trajectories(env, 10) + subtb_loss = gflownet.loss(env, trajectories) if weighting == "TB": - tb_loss = TBGFlowNet(pf=pf, pb=pb).loss(trajectories) # LogZ is default 0.0. + tb_loss = TBGFlowNet(pf=pf, pb=pb).loss( + env, trajectories + ) # LogZ is default 0.0. 
assert (tb_loss - subtb_loss).abs() < 1e-4 diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index cfa65197..a871d940 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -65,16 +65,26 @@ def trajectory_sampling_with_return( pb_module = NeuralNet( input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1 ) - pf_estimator = DiscretePolicyEstimator(env=env, module=pf_module, forward=True) - pb_estimator = DiscretePolicyEstimator(env=env, module=pb_module, forward=False) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=env.preprocessor, + ) + pb_estimator = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=env.preprocessor, + ) sampler = Sampler(estimator=pf_estimator) - trajectories = sampler.sample_trajectories(n_trajectories=5) - trajectories = sampler.sample_trajectories(n_trajectories=10) + trajectories = sampler.sample_trajectories(env, n_trajectories=5) + trajectories = sampler.sample_trajectories(env, n_trajectories=10) states = env.reset(batch_shape=5, random=True) - bw_sampler = Sampler(estimator=pb_estimator, is_backward=True) - bw_trajectories = bw_sampler.sample_trajectories(states) + bw_sampler = Sampler(estimator=pb_estimator) + bw_trajectories = bw_sampler.sample_trajectories(env, states) return trajectories, bw_trajectories, pf_estimator, pb_estimator diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py new file mode 100644 index 00000000..0f63316a --- /dev/null +++ b/tutorials/examples/test_scripts.py @@ -0,0 +1,122 @@ +# This file includes tests for the three examples in the tutorials folder. +# The tests ensure that after a certain number of iterations, the final L1 distance +# or JSD between the learned distribution and the target distribution is below a +# certain threshold. 
+ +from dataclasses import dataclass + +import pytest + +from .train_box import main as train_box_main +from .train_discreteebm import main as train_discreteebm_main +from .train_hypergrid import main as train_hypergrid_main + + +@dataclass +class CommonArgs: + no_cuda: bool = True + seed: int = 1 # We fix the seed for reproducibility + batch_size: int = 16 + replay_buffer_size: int = 0 + loss: str = "TB" + subTB_weighting: str = "geometric_within" + subTB_lambda: float = 0.9 + tabular: bool = False + uniform_pb: bool = False + tied: bool = False + hidden_dim: int = 256 + n_hidden: int = 2 + lr: float = 1e-3 + lr_Z: float = 1e-1 + n_trajectories: int = 32000 + validation_interval: int = 100 + validation_samples: int = 200000 + wandb_project: str = "" + + +@dataclass +class DiscreteEBMArgs(CommonArgs): + ndim: int = 4 + alpha: float = 1.0 + + +@dataclass +class HypergridArgs(CommonArgs): + ndim: int = 2 + height: int = 8 + R0: float = 0.1 + R1: float = 0.5 + R2: float = 2.0 + + +@dataclass +class BoxArgs(CommonArgs): + delta: float = 0.25 + min_concentration: float = 0.1 + max_concentration: float = 5.1 + n_components: int = 2 + n_components_s0: int = 4 + gamma_scheduler: float = 0.5 + scheduler_milestone: int = 2500 + lr_F: float = 1e-2 + + +@pytest.mark.parametrize("ndim", [2, 4]) +@pytest.mark.parametrize("height", [8, 16]) +def test_hypergrid(ndim: int, height: int): + n_trajectories = 32000 if ndim == 2 else 16000 + args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) + final_l1_dist = train_hypergrid_main(args) + if ndim == 2 and height == 8: + assert final_l1_dist < 7.3e-4 + elif ndim == 2 and height == 16: + assert final_l1_dist < 4.8e-4 + elif ndim == 4 and height == 8: + assert final_l1_dist < 1.6e-4 + elif ndim == 4 and height == 16: + assert final_l1_dist < 2.45e-5 + + +@pytest.mark.parametrize("ndim", [2, 4]) +@pytest.mark.parametrize("alpha", [0.1, 1.0]) +def test_discreteebm(ndim: int, alpha: float): + n_trajectories = 16000 + args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) + final_l1_dist = train_discreteebm_main(args) + if ndim == 2 and alpha == 0.1: + assert final_l1_dist < 0.0026 + elif ndim == 2 and alpha == 1.0: + assert final_l1_dist < 0.017 + elif ndim == 4 and alpha == 0.1: + assert final_l1_dist < 0.009 + elif ndim == 4 and alpha == 1.0: + assert final_l1_dist < 0.062 + + +@pytest.mark.parametrize("delta", [0.1, 0.25]) +@pytest.mark.parametrize("loss", ["TB", "DB"]) +def test_box(delta: float, loss: str): + n_trajectories = 128128 + validation_interval = 500 + validation_samples = 10000 + args = BoxArgs( + delta=delta, + loss=loss, + n_trajectories=n_trajectories, + hidden_dim=128, + n_hidden=4, + batch_size=128, + lr_Z=1e-3, + validation_interval=validation_interval, + validation_samples=validation_samples, + ) + print(args) + final_jsd = train_box_main(args) + if loss == "TB" and delta == 0.1: + assert final_jsd < 0.046 + elif loss == "DB" and delta == 0.1: + assert final_jsd < 0.18 + if loss == "TB" and delta == 0.25: + assert final_jsd < 0.015 + elif loss == "DB" and delta == 0.25: + assert final_jsd < 0.027 diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 30d6d854..996f4c1f 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -82,156 +82,7 @@ def estimate_jsd(kde1, kde2): return jsd / 2.0 -if __name__ == "__main__": # noqa: C901 - parser = ArgumentParser() - - parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") 
- - parser.add_argument( - "--delta", - type=float, - default=0.25, - help="maximum distance between two successive states", - ) - - parser.add_argument( - "--seed", - type=int, - default=0, - help="Random seed, if 0 then a random seed is used", - ) - parser.add_argument( - "--batch_size", - type=int, - default=128, - help="Batch size, i.e. number of trajectories to sample per training iteration", - ) - - parser.add_argument( - "--loss", - type=str, - choices=["TB", "DB", "SubTB", "ZVar"], - default="TB", - help="Loss function to use", - ) - parser.add_argument( - "--subTB_weighting", - type=str, - default="geometric_within", - help="weighting scheme for SubTB", - ) - parser.add_argument( - "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" - ) - - parser.add_argument( - "--min_concentration", - type=float, - default=0.1, - help="minimal value for the Beta concentration parameters", - ) - - parser.add_argument( - "--max_concentration", - type=float, - default=5.1, - help="maximal value for the Beta concentration parameters", - ) - - parser.add_argument( - "--n_components", - type=int, - default=2, - help="Number of Beta distributions for P_F(s'|s)", - ) - parser.add_argument( - "--n_components_s0", - type=int, - default=4, - help="Number of Beta distributions for P_F(s'|s_0)", - ) - - parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") - parser.add_argument( - "--tied", - action="store_true", - help="Tie the parameters of PF, PB. F is never tied.", - ) - parser.add_argument( - "--hidden_dim", - type=int, - default=128, - help="Hidden dimension of the estimators' neural network modules.", - ) - parser.add_argument( - "--n_hidden", - type=int, - default=4, - help="Number of hidden layers (of size `hidden_dim`) in the estimators'" - + " neural network modules", - ) - - parser.add_argument( - "--lr", - type=float, - default=1e-3, - help="Learning rate for the estimators' modules", - ) - parser.add_argument( - "--lr_Z", - type=float, - default=1e-3, - help="Specific learning rate for logZ", - ) - parser.add_argument( - "--lr_F", - type=float, - default=1e-2, - help="Specific learning rate for the state flow function (only used for DB and SubTB losses)", - ) - parser.add_argument( - "--gamma_scheduler", - type=float, - default=0.5, - help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", - ) - parser.add_argument( - "--scheduler_milestone", - type=int, - default=2500, - help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", - ) - - parser.add_argument( - "--n_trajectories", - type=int, - default=int(3e6), - help="Total budget of trajectories to train on. " - + "Training iterations = n_trajectories // batch_size", - ) - - parser.add_argument( - "--validation_interval", - type=int, - default=500, - help="How often (in training steps) to validate the gflownet", - ) - parser.add_argument( - "--validation_samples", - type=int, - default=10000, - help="Number of validation samples to use to evaluate the probability mass function.", - ) - - parser.add_argument( - "--wandb_project", - type=str, - default="", - help="Name of the wandb project. 
If empty, don't use wandb", - ) - - args = parser.parse_args() - +def main(args): # noqa: C901 seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() torch.manual_seed(seed) @@ -297,7 +148,7 @@ def estimate_jsd(kde1, kde2): torso=None, # We do not tie the parameters of the flow function to PF logZ_value=logZ, ) - logF_estimator = ScalarEstimator(env=env, module=module) + logF_estimator = ScalarEstimator(module=module, preprocessor=env.preprocessor) if args.loss == "DB": gflownet = DBGFlowNet( @@ -378,16 +229,17 @@ def estimate_jsd(kde1, kde2): if iteration % 1000 == 0: print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}") - trajectories = gflownet.sample_trajectories(n_samples=args.batch_size) + trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) training_samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() - loss = gflownet.loss(training_samples) + loss = gflownet.loss(env, training_samples) loss.backward() for p in gflownet.parameters(): - p.grad.data.clamp_(-10, 10).nan_to_num_(0.0) + if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad + p.grad.data.clamp_(-10, 10).nan_to_num_(0.0) optimizer.step() scheduler.step() @@ -408,7 +260,7 @@ def estimate_jsd(kde1, kde2): if iteration % args.validation_interval == 0: validation_samples = gflownet.sample_terminating_states( - args.validation_samples + env, args.validation_samples ) kde = KernelDensity(kernel="exponential", bandwidth=0.1).fit( validation_samples.tensor.detach().cpu().numpy() @@ -419,3 +271,158 @@ def estimate_jsd(kde1, kde2): wandb.log({"JSD": jsd}, step=iteration) to_log.update({"JSD": jsd}) + + return jsd + + +if __name__ == "__main__": + parser = ArgumentParser() + + parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") + + parser.add_argument( + "--delta", + type=float, + default=0.25, + help="maximum distance between two successive states", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed, if 0 then a random seed is used", + ) + parser.add_argument( + "--batch_size", + type=int, + default=128, + help="Batch size, i.e. number of trajectories to sample per training iteration", + ) + + parser.add_argument( + "--loss", + type=str, + choices=["TB", "DB", "SubTB", "ZVar"], + default="TB", + help="Loss function to use", + ) + parser.add_argument( + "--subTB_weighting", + type=str, + default="geometric_within", + help="weighting scheme for SubTB", + ) + parser.add_argument( + "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" + ) + + parser.add_argument( + "--min_concentration", + type=float, + default=0.1, + help="minimal value for the Beta concentration parameters", + ) + + parser.add_argument( + "--max_concentration", + type=float, + default=5.1, + help="maximal value for the Beta concentration parameters", + ) + + parser.add_argument( + "--n_components", + type=int, + default=2, + help="Number of Beta distributions for P_F(s'|s)", + ) + parser.add_argument( + "--n_components_s0", + type=int, + default=4, + help="Number of Beta distributions for P_F(s'|s_0)", + ) + + parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") + parser.add_argument( + "--tied", + action="store_true", + help="Tie the parameters of PF, PB. 
F is never tied.", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=128, + help="Hidden dimension of the estimators' neural network modules.", + ) + parser.add_argument( + "--n_hidden", + type=int, + default=4, + help="Number of hidden layers (of size `hidden_dim`) in the estimators'" + + " neural network modules", + ) + + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate for the estimators' modules", + ) + parser.add_argument( + "--lr_Z", + type=float, + default=1e-3, + help="Specific learning rate for logZ", + ) + parser.add_argument( + "--lr_F", + type=float, + default=1e-2, + help="Specific learning rate for the state flow function (only used for DB and SubTB losses)", + ) + parser.add_argument( + "--gamma_scheduler", + type=float, + default=0.5, + help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", + ) + parser.add_argument( + "--scheduler_milestone", + type=int, + default=2500, + help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", + ) + + parser.add_argument( + "--n_trajectories", + type=int, + default=int(3e6), + help="Total budget of trajectories to train on. " + + "Training iterations = n_trajectories // batch_size", + ) + + parser.add_argument( + "--validation_interval", + type=int, + default=500, + help="How often (in training steps) to validate the gflownet", + ) + parser.add_argument( + "--validation_samples", + type=int, + default=10000, + help="Number of validation samples to use to evaluate the probability mass function.", + ) + + parser.add_argument( + "--wandb_project", + type=str, + default="", + help="Name of the wandb project. If empty, don't use wandb", + ) + + args = parser.parse_args() + + print(main(args)) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 84191f2d..a7aab784 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -23,6 +23,80 @@ from gfn.utils.common import validate from gfn.utils.modules import NeuralNet, Tabular + +def main(args): # noqa: C901 + seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + torch.manual_seed(seed) + + device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + + use_wandb = len(args.wandb_project) > 0 + if use_wandb: + wandb.init(project=args.wandb_project) + wandb.config.update(args) + + # 1. Create the environment + env = DiscreteEBM(ndim=args.ndim, alpha=args.alpha, device_str=device_str) + + # 2. Create the gflownet. + # We need a LogEdgeFlowEstimator + if args.tabular: + module = Tabular(n_states=env.n_states, output_dim=env.n_actions) + else: + module = NeuralNet( + input_dim=env.preprocessor.output_dim, + output_dim=env.n_actions, + hidden_dim=args.hidden_dim, + n_hidden_layers=args.n_hidden, + ) + estimator = DiscretePolicyEstimator( + module=module, + n_actions=env.n_actions, + preprocessor=env.preprocessor, + ) + gflownet = FMGFlowNet(estimator) + + # 3. Create the optimizer + optimizer = torch.optim.Adam(module.parameters(), lr=args.lr) + + # 4. 
Train the gflownet + + visited_terminating_states = env.States.from_batch_shape((0,)) + + states_visited = 0 + n_iterations = args.n_trajectories // args.batch_size + validation_info = {"l1_dist": float("inf")} + for iteration in trange(n_iterations): + trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) + training_samples = gflownet.to_training_samples(trajectories) + + optimizer.zero_grad() + loss = gflownet.loss(env, training_samples) + loss.backward() + optimizer.step() + + visited_terminating_states.extend(trajectories.last_states) + + states_visited += len(trajectories) + + to_log = {"loss": loss.item(), "states_visited": states_visited} + if use_wandb: + wandb.log(to_log, step=iteration) + if iteration % args.validation_interval == 0: + validation_info = validate( + env, + gflownet, + args.validation_samples, + visited_terminating_states, + ) + if use_wandb: + wandb.log(validation_info, step=iteration) + to_log.update(validation_info) + tqdm.write(f"{iteration}: {to_log}") + + return validation_info["l1_dist"] + + if __name__ == "__main__": parser = ArgumentParser() @@ -104,66 +178,4 @@ args = parser.parse_args() - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() - torch.manual_seed(seed) - - device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" - - use_wandb = len(args.wandb_project) > 0 - if use_wandb: - wandb.init(project=args.wandb_project) - wandb.config.update(args) - - # 1. Create the environment - env = DiscreteEBM(ndim=args.ndim, alpha=args.alpha) - - # 2. Create the gflownet. - # We need a LogEdgeFlowEstimator - if args.tabular: - module = Tabular(n_states=env.n_states, output_dim=env.n_actions) - else: - module = NeuralNet( - input_dim=env.preprocessor.output_dim, - output_dim=env.n_actions, - hidden_dim=args.hidden_dim, - n_hidden_layers=args.n_hidden, - ) - estimator = DiscretePolicyEstimator(env=env, module=module, forward=True) - gflownet = FMGFlowNet(estimator) - - # 3. Create the optimizer - optimizer = torch.optim.Adam(module.parameters(), lr=args.lr) - - # 4. 
Train the gflownet - - visited_terminating_states = env.States.from_batch_shape((0,)) - - states_visited = 0 - n_iterations = args.n_trajectories // args.batch_size - for iteration in trange(n_iterations): - trajectories = gflownet.sample_trajectories(n_samples=args.batch_size) - training_samples = gflownet.to_training_samples(trajectories) - - optimizer.zero_grad() - loss = gflownet.loss(training_samples) - loss.backward() - optimizer.step() - - visited_terminating_states.extend(trajectories.last_states) - - states_visited += len(trajectories) - - to_log = {"loss": loss.item(), "states_visited": states_visited} - if use_wandb: - wandb.log(to_log, step=iteration) - if iteration % args.validation_interval == 0: - validation_info = validate( - env, - gflownet, - args.validation_samples, - visited_terminating_states, - ) - if use_wandb: - wandb.log(validation_info, step=iteration) - to_log.update(validation_info) - tqdm.write(f"{iteration}: {to_log}") + print(main(args)) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 5a613faf..e9fd465c 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -31,123 +31,8 @@ from gfn.utils.common import validate from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular -if __name__ == "__main__": # noqa: C901 - parser = ArgumentParser() - - parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") - - parser.add_argument( - "--ndim", type=int, default=2, help="Number of dimensions in the environment" - ) - parser.add_argument( - "--height", type=int, default=8, help="Height of the environment" - ) - parser.add_argument("--R0", type=float, default=0.1, help="Environment's R0") - parser.add_argument("--R1", type=float, default=0.5, help="Environment's R1") - parser.add_argument("--R2", type=float, default=2.0, help="Environment's R2") - - parser.add_argument( - "--seed", - type=int, - default=0, - help="Random seed, if 0 then a random seed is used", - ) - parser.add_argument( - "--batch_size", - type=int, - default=16, - help="Batch size, i.e. number of trajectories to sample per training iteration", - ) - parser.add_argument( - "--replay_buffer_size", - type=int, - default=0, - help="If zero, no replay buffer is used. 
Otherwise, the replay buffer is used.", - ) - - parser.add_argument( - "--loss", - type=str, - choices=["FM", "TB", "DB", "SubTB", "ZVar", "ModifiedDB"], - default="TB", - help="Loss function to use", - ) - parser.add_argument( - "--subTB_weighting", - type=str, - default="geometric_within", - help="weighting scheme for SubTB", - ) - parser.add_argument( - "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" - ) - - parser.add_argument( - "--tabular", - action="store_true", - help="Use a lookup table for F, PF, PB instead of an estimator", - ) - parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") - parser.add_argument( - "--tied", action="store_true", help="Tie the parameters of PF, PB, and F" - ) - parser.add_argument( - "--hidden_dim", - type=int, - default=256, - help="Hidden dimension of the estimators' neural network modules.", - ) - parser.add_argument( - "--n_hidden", - type=int, - default=2, - help="Number of hidden layers (of size `hidden_dim`) in the estimators'" - + " neural network modules", - ) - - parser.add_argument( - "--lr", - type=float, - default=1e-3, - help="Learning rate for the estimators' modules", - ) - parser.add_argument( - "--lr_Z", - type=float, - default=0.1, - help="Specific learning rate for Z (only used for TB loss)", - ) - - parser.add_argument( - "--n_trajectories", - type=int, - default=int(1e6), - help="Total budget of trajectories to train on. " - + "Training iterations = n_trajectories // batch_size", - ) - - parser.add_argument( - "--validation_interval", - type=int, - default=100, - help="How often (in training steps) to validate the gflownet", - ) - parser.add_argument( - "--validation_samples", - type=int, - default=200000, - help="Number of validation samples to use to evaluate the probability mass function.", - ) - - parser.add_argument( - "--wandb_project", - type=str, - default="", - help="Name of the wandb project. If empty, don't use wandb", - ) - - args = parser.parse_args() +def main(args): # noqa: C901 seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() torch.manual_seed(seed) @@ -181,7 +66,11 @@ hidden_dim=args.hidden_dim, n_hidden_layers=args.n_hidden, ) - estimator = DiscretePolicyEstimator(env=env, module=module, forward=True) + estimator = DiscretePolicyEstimator( + module=module, + n_actions=env.n_actions, + preprocessor=env.preprocessor, + ) gflownet = FMGFlowNet(estimator) else: pb_module = None @@ -215,8 +104,17 @@ pb_module is not None ), f"pb_module is None. 
Command-line arguments: {args}" - pf_estimator = DiscretePolicyEstimator(env=env, module=pf_module, forward=True) - pb_estimator = DiscretePolicyEstimator(env=env, module=pb_module, forward=False) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + preprocessor=env.preprocessor, + ) + pb_estimator = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=env.preprocessor, + ) if args.loss == "ModifiedDB": gflownet = ModifiedDBGFlowNet( @@ -245,7 +143,9 @@ torso=pf_module.torso if args.tied else None, ) - logF_estimator = ScalarEstimator(env=env, module=module) + logF_estimator = ScalarEstimator( + module=module, preprocessor=env.preprocessor + ) if args.loss == "DB": gflownet = DBGFlowNet( pf=pf_estimator, @@ -320,8 +220,9 @@ states_visited = 0 n_iterations = args.n_trajectories // args.batch_size + validation_info = {"l1_dist": float("inf")} for iteration in trange(n_iterations): - trajectories = gflownet.sample_trajectories(n_samples=args.batch_size) + trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) training_samples = gflownet.to_training_samples(trajectories) if replay_buffer is not None: with torch.no_grad(): @@ -331,7 +232,7 @@ training_objects = training_samples optimizer.zero_grad() - loss = gflownet.loss(training_objects) + loss = gflownet.loss(env, training_objects) loss.backward() optimizer.step() @@ -353,3 +254,125 @@ wandb.log(validation_info, step=iteration) to_log.update(validation_info) tqdm.write(f"{iteration}: {to_log}") + + return validation_info["l1_dist"] + + +if __name__ == "__main__": + parser = ArgumentParser() + + parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") + + parser.add_argument( + "--ndim", type=int, default=2, help="Number of dimensions in the environment" + ) + parser.add_argument( + "--height", type=int, default=8, help="Height of the environment" + ) + parser.add_argument("--R0", type=float, default=0.1, help="Environment's R0") + parser.add_argument("--R1", type=float, default=0.5, help="Environment's R1") + parser.add_argument("--R2", type=float, default=2.0, help="Environment's R2") + + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed, if 0 then a random seed is used", + ) + parser.add_argument( + "--batch_size", + type=int, + default=16, + help="Batch size, i.e. number of trajectories to sample per training iteration", + ) + parser.add_argument( + "--replay_buffer_size", + type=int, + default=0, + help="If zero, no replay buffer is used. 
Otherwise, the replay buffer is used.", + ) + + parser.add_argument( + "--loss", + type=str, + choices=["FM", "TB", "DB", "SubTB", "ZVar", "ModifiedDB"], + default="TB", + help="Loss function to use", + ) + parser.add_argument( + "--subTB_weighting", + type=str, + default="geometric_within", + help="weighting scheme for SubTB", + ) + parser.add_argument( + "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" + ) + + parser.add_argument( + "--tabular", + action="store_true", + help="Use a lookup table for F, PF, PB instead of an estimator", + ) + parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") + parser.add_argument( + "--tied", action="store_true", help="Tie the parameters of PF, PB, and F" + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=256, + help="Hidden dimension of the estimators' neural network modules.", + ) + parser.add_argument( + "--n_hidden", + type=int, + default=2, + help="Number of hidden layers (of size `hidden_dim`) in the estimators'" + + " neural network modules", + ) + + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate for the estimators' modules", + ) + parser.add_argument( + "--lr_Z", + type=float, + default=0.1, + help="Specific learning rate for Z (only used for TB loss)", + ) + + parser.add_argument( + "--n_trajectories", + type=int, + default=int(1e6), + help="Total budget of trajectories to train on. " + + "Training iterations = n_trajectories // batch_size", + ) + + parser.add_argument( + "--validation_interval", + type=int, + default=100, + help="How often (in training steps) to validate the gflownet", + ) + parser.add_argument( + "--validation_samples", + type=int, + default=200000, + help="Number of validation samples to use to evaluate the probability mass function.", + ) + + parser.add_argument( + "--wandb_project", + type=str, + default="", + help="Name of the wandb project. If empty, don't use wandb", + ) + + args = parser.parse_args() + + print(main(args))
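
Note: with the argparse logic moved under `if __name__ == "__main__":`, each script's `main(args)` can also be called programmatically with any object exposing the same attribute names, which is exactly what the new test module does with its dataclasses. The sketch below illustrates that usage for `train_hypergrid.py`; it is an assumption-laden example, not part of the patch: `SmallHypergridArgs` is a hypothetical dataclass mirroring the defaults of `CommonArgs`/`HypergridArgs` in the test module, and the import path assumes the repository root is on `PYTHONPATH` with `tutorials/examples` importable as a package (adjust to a plain file import otherwise).

```python
# Minimal sketch (assumptions noted above): drive the refactored main(args)
# entry point with a dataclass instead of command-line flags.
from dataclasses import dataclass

# Assumes tutorials/examples is importable as a package from the repo root.
from tutorials.examples.train_hypergrid import main as train_hypergrid_main


@dataclass
class SmallHypergridArgs:
    # Hypothetical args container; field names match the argparse flags
    # that train_hypergrid.main(args) reads.
    no_cuda: bool = True
    seed: int = 1
    batch_size: int = 16
    replay_buffer_size: int = 0
    loss: str = "TB"
    subTB_weighting: str = "geometric_within"
    subTB_lambda: float = 0.9
    tabular: bool = False
    uniform_pb: bool = False
    tied: bool = False
    hidden_dim: int = 256
    n_hidden: int = 2
    lr: float = 1e-3
    lr_Z: float = 1e-1
    n_trajectories: int = 32000
    validation_interval: int = 100
    validation_samples: int = 200000
    wandb_project: str = ""  # empty string disables wandb logging
    ndim: int = 2
    height: int = 8
    R0: float = 0.1
    R1: float = 0.5
    R2: float = 2.0


if __name__ == "__main__":
    # main() returns the final L1 distance between the empirical and true
    # terminating-state distributions; the new tests compare this value
    # against per-configuration thresholds.
    final_l1_dist = train_hypergrid_main(SmallHypergridArgs())
    print(final_l1_dist)
```

The same pattern applies to `train_discreteebm.main` (which also returns an L1 distance) and `train_box.main` (which returns a Jensen-Shannon divergence estimate), using args objects shaped like `DiscreteEBMArgs` and `BoxArgs` from the test module.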