Skip to content

Commit

Permalink
stack_states now ignores masks for non-discrete states, and fixed a bug…
Browse files Browse the repository at this point in the history
… in mask updating behaviour to prevent accumulation of errors.
  • Loading branch information
josephdviviano committed Feb 24, 2024
1 parent 7b536a2 commit 77e7e1b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,15 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool):
allow_exit: sets whether exiting can happen at any point in the
trajectory - if so, it should be set to True.
"""
# Resets masks in place to prevent side-effects across steps.
self.forward_masks[:] = True
if allow_exit:
exit_idx = torch.zeros(self.batch_shape + (1,)).to(cond.device)
else:
exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device)
self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False


def set_exit_masks(self, batch_idx):
"""Sets forward masks such that the only allowable next action is to exit.
Expand Down Expand Up @@ -456,8 +459,11 @@ def stack_states(states: List[States]):
stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
if state_example._log_rewards:
stacked_states._log_rewards = torch.stack([s._log_rewards for s in states], dim=0)
stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0)
stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0)

# We are dealing with a list of DiscreteStates instances.
if hasattr(state_example, "forward_masks"):
stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0)
stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0)

# Adds the trajectory dimension.
stacked_states.batch_shape = (stacked_states.tensor.shape[0],) + state_example.batch_shape
Expand Down

0 comments on commit 77e7e1b

Please sign in to comment.