diff --git a/src/gfn/env.py b/src/gfn/env.py
index 9b045ca3..510d3820 100644
--- a/src/gfn/env.py
+++ b/src/gfn/env.py
@@ -219,6 +219,10 @@ def _step(
         not_done_actions = actions[~new_sink_states_idx]
 
         new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
+        if not isinstance(new_not_done_states_tensor, torch.Tensor):
+            raise Exception(
+                "User-implemented env.step function *must* return a torch.Tensor!"
+            )
 
         new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 68b052a6..a2f810b6 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import List, Optional, Tuple
 
 import torch
@@ -7,7 +8,7 @@
 from gfn.containers import Trajectories
 from gfn.env import Env
 from gfn.modules import GFNModule
-from gfn.states import States
+from gfn.states import States, stack_states
 
 
 class Sampler:
@@ -140,9 +141,7 @@ def sample_trajectories(
             else states.is_sink_state
         )
 
-        trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [
-            states.tensor
-        ]
+        trajectories_states: List[States] = [deepcopy(states)]
         trajectories_actions: List[TT["n_trajectories", torch.long]] = []
         trajectories_logprobs: List[TT["n_trajectories", torch.float]] = []
         trajectories_dones = torch.zeros(
@@ -219,10 +218,9 @@ def sample_trajectories(
             states = new_states
             dones = dones | new_dones
 
-            trajectories_states += [states.tensor]
+            trajectories_states += [deepcopy(states)]
 
-        trajectories_states = torch.stack(trajectories_states, dim=0)
-        trajectories_states = env.states_from_tensor(trajectories_states)
+        trajectories_states = stack_states(trajectories_states)
         trajectories_actions = env.Actions.stack(trajectories_actions)
         trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0)
diff --git a/src/gfn/states.py b/src/gfn/states.py
index 53492861..cb48b130 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from copy import deepcopy
 from math import prod
-from typing import Callable, ClassVar, Optional, Sequence, cast
+from typing import Callable, ClassVar, List, Optional, Sequence, cast
 
 import torch
 from torchtyping import TensorType as TT
@@ -409,6 +409,8 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool):
             allow_exit: sets whether exiting can happen at any point in the
                 trajectory - if so, it should be set to True.
         """
+        # Resets masks in place to prevent side-effects across steps.
+        self.forward_masks[:] = True
         if allow_exit:
             exit_idx = torch.zeros(self.batch_shape + (1,)).to(cond.device)
         else:
@@ -446,3 +448,31 @@ def init_forward_masks(self, set_ones: bool = True):
             self.forward_masks = torch.ones(shape).bool()
         else:
             self.forward_masks = torch.zeros(shape).bool()
+
+
+def stack_states(states: List[States]):
+    """Given a list of states, stacks them along a new dimension (0)."""
+    state_example = states[0]  # We assume all elems of `states` are the same.
+
+    stacked_states = state_example.from_batch_shape((0, 0))  # Empty.
+    stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
+    if state_example._log_rewards is not None:
+        stacked_states._log_rewards = torch.stack(
+            [s._log_rewards for s in states], dim=0
+        )
+
+    # We are dealing with a list of DiscreteStates instances.
+    if hasattr(state_example, "forward_masks"):
+        stacked_states.forward_masks = torch.stack(
+            [s.forward_masks for s in states], dim=0
+        )
+        stacked_states.backward_masks = torch.stack(
+            [s.backward_masks for s in states], dim=0
+        )
+
+    # Adds the trajectory dimension.
+    stacked_states.batch_shape = (
+        stacked_states.tensor.shape[0],
+    ) + state_example.batch_shape
+
+    return stacked_states
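
Reviewer note: below is a minimal, self-contained sketch of the two behaviors this diff relies on. The `ToyStates` class is a hypothetical stand-in for `gfn.states.States`, not part of torchgfn, and only plain torch is used. It illustrates why the sampler now snapshots `deepcopy(states)` rather than `states.tensor` (masks are mutated in place between steps, so un-copied snapshots would all alias the final masks) and how stacking along a new dim 0 prepends a trajectory dimension to `batch_shape`, mirroring what `stack_states` does.

from copy import deepcopy
from dataclasses import dataclass

import torch


@dataclass
class ToyStates:
    """Hypothetical stand-in for gfn.states.States (illustration only)."""

    tensor: torch.Tensor         # batch_shape + state_shape
    forward_masks: torch.Tensor  # batch_shape + (n_actions,)


# A batch of 4 states in a 2-D state space, with 3 actions each.
s = ToyStates(torch.zeros(4, 2), torch.ones(4, 3, dtype=torch.bool))

snapshots = [deepcopy(s)]    # as in the updated sampler
s.forward_masks[:] = False   # masks are reset/mutated in place on the next step
snapshots.append(deepcopy(s))

# deepcopy preserves each step's masks; storing `s` itself would alias them.
assert snapshots[0].forward_masks.all()
assert not snapshots[1].forward_masks.any()

# Stacking along a new dim 0 turns batch_shape (4,) into (2, 4), matching
# stack_states' `(tensor.shape[0],) + batch_shape` update.
stacked = torch.stack([snap.tensor for snap in snapshots], dim=0)
print(stacked.shape)  # torch.Size([2, 4, 2])

The old path (`torch.stack` on raw tensors followed by `env.states_from_tensor`) rebuilt states from tensors alone, which drops the per-step `forward_masks`/`backward_masks` and `_log_rewards`; `stack_states` carries them through directly.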