
Commit

Merge pull request #163 from GFNOrg/dont_recompute_masks
Don't recompute masks
josephdviviano authored Feb 27, 2024
2 parents 9d00335 + 7996b37 · commit a565d4b
Showing 3 changed files with 40 additions and 8 deletions.
4 changes: 4 additions & 0 deletions src/gfn/env.py
@@ -219,6 +219,10 @@ def _step(
         not_done_actions = actions[~new_sink_states_idx]

         new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
+        if not isinstance(new_not_done_states_tensor, torch.Tensor):
+            raise Exception(
+                "User implemented env.step function *must* return a torch.Tensor!"
+            )

         new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor

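Why the guard: `_step` writes the returned value straight into `new_states.tensor`, so a user-implemented `step` that returns anything other than a raw tensor (a `States` wrapper, say, or a NumPy array) would otherwise fail later with a less obvious error. Below is a minimal sketch of a conforming `step` for a hypothetical user environment; the class name, state layout, and action encoding are illustrative assumptions, not library code.

import torch

class MyLineEnv:  # hypothetical user environment
    def step(self, states, actions):
        """Apply `actions` and return the raw next-state tensor."""
        next_tensor = states.tensor.clone()  # assumed shape: (batch, ndim)
        batch_idx = torch.arange(next_tensor.shape[0])
        # Move each state one unit along the coordinate its action selects
        # (actions.tensor is assumed to have shape (batch, 1)).
        next_tensor[batch_idx, actions.tensor.squeeze(-1)] += 1
        # Returning a States wrapper here would now trip the isinstance check.
        return next_tensor  # a torch.Tensor, as _step requires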
12 changes: 5 additions & 7 deletions src/gfn/samplers.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import List, Optional, Tuple

 import torch
@@ -7,7 +8,7 @@
 from gfn.containers import Trajectories
 from gfn.env import Env
 from gfn.modules import GFNModule
-from gfn.states import States
+from gfn.states import States, stack_states


 class Sampler:
@@ -140,9 +141,7 @@ def sample_trajectories(
             else states.is_sink_state
         )

-        trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [
-            states.tensor
-        ]
+        trajectories_states: List[States] = [deepcopy(states)]
         trajectories_actions: List[TT["n_trajectories", torch.long]] = []
         trajectories_logprobs: List[TT["n_trajectories", torch.float]] = []
         trajectories_dones = torch.zeros(
@@ -219,10 +218,9 @@ def sample_trajectories(
             states = new_states
             dones = dones | new_dones

-            trajectories_states += [states.tensor]
+            trajectories_states += [deepcopy(states)]

-        trajectories_states = torch.stack(trajectories_states, dim=0)
-        trajectories_states = env.states_from_tensor(trajectories_states)
+        trajectories_states = stack_states(trajectories_states)
         trajectories_actions = env.Actions.stack(trajectories_actions)
         trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0)

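The substance of the change: the sampler previously stored only `states.tensor` at each step and rebuilt the trajectory's `States` at the end via `env.states_from_tensor`, which recomputes every forward and backward mask from scratch. It now snapshots the full `States` object, masks included, and stacks the snapshots with the new `stack_states` helper (added in `src/gfn/states.py` below). The `deepcopy` is what makes the snapshots safe: masks can be updated in place between steps, so storing a bare reference would let later updates corrupt earlier entries. A toy stand-in illustrates the aliasing hazard (the `Toy` class is a hypothetical illustration, not library code):

import torch
from copy import deepcopy

class Toy:
    """Stand-in for States: a tensor plus a mask that is updated in place."""
    def __init__(self):
        self.tensor = torch.zeros(2)
        self.forward_masks = torch.ones(2, dtype=torch.bool)

s = Toy()
aliased = [s]           # what appending `states` directly would store
copied = [deepcopy(s)]  # what the new sampler code stores

s.forward_masks[:] = False        # an in-place mask update at the next step
print(aliased[0].forward_masks)   # tensor([False, False]) -- snapshot corrupted
print(copied[0].forward_masks)    # tensor([True, True])   -- snapshot intact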
32 changes: 31 additions & 1 deletion src/gfn/states.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from copy import deepcopy
 from math import prod
-from typing import Callable, ClassVar, Optional, Sequence, cast
+from typing import Callable, ClassVar, List, Optional, Sequence, cast

 import torch
 from torchtyping import TensorType as TT
@@ -409,6 +409,8 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool):
             allow_exit: sets whether exiting can happen at any point in the
                 trajectory - if so, it should be set to True.
         """
+        # Resets masks in place to prevent side-effects across steps.
+        self.forward_masks[:] = True
         if allow_exit:
             exit_idx = torch.zeros(self.batch_shape + (1,)).to(cond.device)
         else:
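The reset matters because states are now carried across steps rather than rebuilt from raw tensors, so `set_nonexit_action_masks` can be called on masks that still hold the previous step's constraints; without wiping them first, stale `False` entries accumulate. A schematic stand-in (not the library method) showing the leak the reset prevents:

import torch

forward_masks = torch.ones(4, dtype=torch.bool)

def set_nonexit(masks, cond):
    # masks[:] = True  # <- the reset added above; commented out to show the leak
    masks[:-1][cond] = False  # forbid non-exit actions wherever cond holds

set_nonexit(forward_masks, torch.tensor([True, False, False]))  # step t
set_nonexit(forward_masks, torch.tensor([False, False, True]))  # step t + 1
print(forward_masks)  # tensor([False,  True, False,  True]):
                      # index 0 is stale from step t; with the reset it would be True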
@@ -446,3 +448,31 @@ def init_forward_masks(self, set_ones: bool = True):
             self.forward_masks = torch.ones(shape).bool()
         else:
             self.forward_masks = torch.zeros(shape).bool()
+
+
+def stack_states(states: List[States]):
+    """Given a list of states, stacks them along a new dimension (0)."""
+    state_example = states[0]  # We assume all elements of `states` are the same.
+
+    stacked_states = state_example.from_batch_shape((0, 0))  # Empty.
+    stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
+    if state_example._log_rewards is not None:
+        stacked_states._log_rewards = torch.stack(
+            [s._log_rewards for s in states], dim=0
+        )
+
+    # We are dealing with a list of DiscreteStates instances.
+    if hasattr(state_example, "forward_masks"):
+        stacked_states.forward_masks = torch.stack(
+            [s.forward_masks for s in states], dim=0
+        )
+        stacked_states.backward_masks = torch.stack(
+            [s.backward_masks for s in states], dim=0
+        )
+
+    # Adds the trajectory dimension.
+    stacked_states.batch_shape = (
+        stacked_states.tensor.shape[0],
+    ) + state_example.batch_shape
+
+    return stacked_states
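
A usage sketch for the new helper, assuming the library's `HyperGrid` example environment and its `env.reset(batch_shape=...)` constructor behave as in the package's own examples:

from gfn.gym import HyperGrid
from gfn.states import stack_states

env = HyperGrid(ndim=2, height=4)
states = env.reset(batch_shape=(5,))  # a batch of 5 discrete states

stacked = stack_states([states, states, states])
print(stacked.batch_shape)          # (3, 5): the trajectory dimension is prepended
print(stacked.tensor.shape)         # torch.Size([3, 5, 2])
print(stacked.forward_masks.shape)  # torch.Size([3, 5, 3]); masks stacked, not recomputed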
