
Redundant calling of update_masks in default sampler #160

Closed

listar2000 opened this issue Feb 17, 2024 · 2 comments

listar2000 commented Feb 17, 2024

Hi, I'm currently developing my own environment, following the training script given in the example (i.e., on-policy sampling with the forward policy), and I'm a bit confused about the following redundancy:

  1. During the sequential sampling process, at each step the current state calls the update_masks method in its initializer, which sets up forward_masks and backward_masks.
  2. Once the sampling is done, the trajectories are batched into a new States object in the default samplers.py:
trajectories_states = env.States(tensor=trajectories_states)

(P.S. I think in the nightly version this has changed to)

trajectories_states = env.states_from_tensor(trajectories_states)

which essentially does the same thing. This new States object is initialized and then calls update_masks on all of the states in the trajectories again, even though those masks were already computed once in step 1. So why repeat this work instead of reusing the already computed masks?
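To make the redundancy concrete, here is a minimal, self-contained sketch of the pattern (the DiscreteStates internals below are a simplified stand-in for illustration, not torchgfn's actual implementation):

import torch

class DiscreteStates:
    """Simplified stand-in for a States class; only the mask-computation
    pattern matters here, not the real torchgfn internals."""

    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor
        self.forward_masks = torch.ones(*tensor.shape[:-1], 3, dtype=torch.bool)
        self.backward_masks = torch.ones(*tensor.shape[:-1], 3, dtype=torch.bool)
        self.update_masks()  # call #1: runs once per step during sampling

    def update_masks(self) -> None:
        # Placeholder mask rule; a real environment's logic goes here.
        self.forward_masks[...] = self.tensor.sum(-1, keepdim=True) < 10

# During sampling, each step constructs a States object (masks computed):
per_step_states = [DiscreteStates(torch.zeros(4, 2)) for _ in range(5)]

# Afterwards the per-step tensors are stacked and wrapped in a *new*
# States object, whose __init__ calls update_masks again (call #2),
# discarding the masks already computed above.
stacked = torch.stack([s.tensor for s in per_step_states], dim=0)
trajectories_states = DiscreteStates(stacked)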

Many thanks for any explanation :).

@saleml saleml self-assigned this Feb 18, 2024
@josephdviviano josephdviviano self-assigned this Feb 19, 2024

josephdviviano commented Feb 20, 2024

Thanks for this great point. We need a States.stack() method or a stack_states() function that accepts a list of states, to avoid this recomputation. Addressed in #161.

Pseudocode

from gfn.containers.utils import stack_states

while not all(dones):
    actions = env.actions_from_batch_shape((n_trajectories,))  # Dummy actions.
    valid_actions, actions_log_probs, estimator_outputs = self.sample_actions(
        env,
        states[~dones],
        save_estimator_outputs=save_estimator_outputs,
        calculate_logprobs=not skip_logprob_calculation,
        **policy_kwargs,
    )
    ...
    actions[~dones] = valid_actions
    ...
    if self.estimator.is_backward:
        new_states = env._backward_step(states, actions)
    else:
        new_states = env._step(states, actions)
    ...
    new_dones = (
        new_states.is_initial_state if self.estimator.is_backward else sink_states_mask
    ) & ~dones
    trajectories_dones[new_dones] = step
    ...
    states = new_states
    dones = dones | new_dones

    trajectories_states += [states]

# Stack the per-step States once, instead of rebuilding (and re-masking)
# them from a raw tensor.
trajectories_states = stack_states(trajectories_states, dim=0)

This stack_states function would extend all relevant attributes of the submitted states along the trajectory dimension and return a Trajectories object.
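For illustration, a minimal sketch of what such a helper could look like (this is an assumption about the eventual implementation, not the actual code from #163; returning a plain batched States rather than a full Trajectories object is a simplification):

import torch
from typing import List, TypeVar

StatesT = TypeVar("StatesT")  # stands in for the library's States class

def stack_states(states_list: List[StatesT], dim: int = 0) -> StatesT:
    """Stack a list of States along a new trajectory dimension, reusing the
    masks each States object already computed instead of calling
    update_masks again. (Sketch; the real implementation may differ.)"""
    first = states_list[0]
    stacked_tensor = torch.stack([s.tensor for s in states_list], dim=dim)

    # Bypass __init__ (and thus update_masks) and fill attributes directly.
    out = first.__class__.__new__(first.__class__)
    out.tensor = stacked_tensor
    out.forward_masks = torch.stack([s.forward_masks for s in states_list], dim=dim)
    out.backward_masks = torch.stack([s.backward_masks for s in states_list], dim=dim)
    return out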


josephdviviano commented Feb 22, 2024

We're working on this here: #163

Edit - this issue is resolved!
