Dont recompute masks #163

Merged: 23 commits into master on Feb 27, 2024

Conversation

josephdviviano (Collaborator)

This isn't working @saleml -- please see the issue in samplers.py.

You can reproduce the error with tutorials/examples/train_hypergrid_simple.py

…proposed a new method for stacking a list of states into a trajectory; as the assert statements show, the tensor is correct, but the forward_masks are not.
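
For context, a minimal sketch of what stacking a list of states into a trajectory could look like, assuming each state exposes tensor, forward_masks, and backward_masks attributes (the function name and dict return type are illustrative, not the torchgfn API):

```python
import torch

# Illustrative sketch only: stack per-step state tensors and masks along a
# new leading "time" dimension to form a trajectory.
def stack_states_sketch(states: list) -> dict:
    return {
        "tensor": torch.stack([s.tensor for s in states], dim=0),
        "forward_masks": torch.stack([s.forward_masks for s in states], dim=0),
        "backward_masks": torch.stack([s.backward_masks for s in states], dim=0),
    }
```
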
saleml (Collaborator) commented Feb 24, 2024

Some debugging:

First, I changed the batch_size to 3 in the script. Then, with a breakpoint at the assertion error, I see that the forward masks of all the steps within the batch of 3 trajectories are the same.

So here is the interesting part. If I add a breakpoint at lines 195 and 199 of samplers.py, I get the following:

[Screenshot: debugger output at the breakpoints in samplers.py]

The first mask is OK: initially we have 3 copies of s0, so the masks should all be True. Once we call _step, we haven't explicitly modified the list trajectories_states_b, yet the forward_masks of its only element change, just because we called _step.
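
A minimal, self-contained sketch (plain PyTorch, not torchgfn code) of the aliasing effect described above: when a stored state and the live state share the same mask tensor, an in-place update made during _step is visible through the stored state as well.

```python
import torch

class ToyState:
    """Toy stand-in for a States object; only holds a forward mask."""
    def __init__(self, forward_masks: torch.Tensor):
        self.forward_masks = forward_masks

live = ToyState(torch.ones(3, 4, dtype=torch.bool))
stored = [ToyState(live.forward_masks)]   # shares the SAME tensor, no copy

# Something like update_masks mutating the live state's mask in place...
live.forward_masks[:] = False

# ...also changes the mask of the state stored earlier in the list.
print(stored[0].forward_masks.any())  # tensor(False): the stored mask changed too
```
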

So the problem happens here in env.py
[Screenshot: the relevant code in env.py]

And looking at the update_masks function implemented in Hypergrid:

[Screenshot: the update_masks function implemented in Hypergrid]

it seems to me that the masks are changed in place and that the problem is due to #149. I don't remember whether tests were passing in that PR (I haven't reviewed it).

Let me know what you think

Comment on lines 459 to 460
stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0)
stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0)
saleml (Collaborator)

this should only be implemented for DiscreteStates, not all States
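
A hedged sketch of such a guard, assuming DiscreteStates can be imported from gfn.states and that only discrete states carry mask attributes (the helper name is illustrative):

```python
import torch
from gfn.states import DiscreteStates  # assumed import path

def stack_masks_if_discrete(stacked_states, states):
    # Masks only exist on DiscreteStates, so only stack them in that case;
    # generic States have no forward_masks / backward_masks to stack.
    if all(isinstance(s, DiscreteStates) for s in states):
        stacked_states.forward_masks = torch.stack(
            [s.forward_masks for s in states], dim=0
        )
        stacked_states.backward_masks = torch.stack(
            [s.backward_masks for s in states], dim=0
        )
    return stacked_states
```
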

josephdviviano (Collaborator, Author)

Good catch! These tests pass fine - I think the in-place update of masks is desirable behaviour, except in this case, where we want to accumulate a trajectory of states.

To reduce user error, the base States class could have a clone method, which returns deepcopy(self).

When I was messing around, I tried copying the states. I should have instead used deepcopy, which prevents the forward masks from being updated in place.
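
A minimal sketch of the suggested clone method, assuming a plain States base class (illustrative only, not necessarily how it was implemented in the PR):

```python
from copy import deepcopy

class States:
    """Sketch of a base States class carrying a clone helper."""

    def __init__(self, tensor):
        self.tensor = tensor

    def clone(self) -> "States":
        # deepcopy copies every tensor attribute (including forward_masks /
        # backward_masks on subclasses), so later in-place mask updates on
        # the original cannot leak into the stored copy.
        return deepcopy(self)
```
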

josephdviviano (Collaborator, Author)

OK @saleml, figured it out - check line 413 here: 77e7e1b

Before setting the False elements, since we are doing in-place operations, we must first set all values to True to prevent side effects over multiple steps (self.forward_masks[:] = True).

This, plus using deepcopy where appropriate, fixes the issue, and we no longer recompute masks.
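
A hedged sketch of the in-place update pattern being described: reset the whole mask to True, then write only the False entries for the current step (the HyperGrid-style rule shown is illustrative, not the library's exact code):

```python
import torch

def update_forward_masks_in_place(forward_masks: torch.Tensor,
                                  states_tensor: torch.Tensor,
                                  height: int) -> None:
    # 1) Reset everything to True first, so False entries written on a
    #    previous step cannot survive this in-place update.
    forward_masks[:] = True
    # 2) Then disable only the currently invalid actions, e.g. forbidding an
    #    increment along any coordinate that already sits at height - 1.
    forward_masks[..., :-1] = states_tensor != height - 1
```
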

@josephdviviano josephdviviano self-assigned this Feb 24, 2024
@josephdviviano josephdviviano marked this pull request as ready for review February 24, 2024 20:47
@josephdviviano josephdviviano added the bug Something isn't working label Feb 24, 2024
checks whether user-defined env.step method returns the expected type
@josephdviviano josephdviviano merged commit a565d4b into master Feb 27, 2024
3 checks passed
@josephdviviano josephdviviano deleted the dont_recompute_masks branch February 27, 2024 14:26
@josephdviviano josephdviviano restored the dont_recompute_masks branch February 27, 2024 14:26
@josephdviviano josephdviviano deleted the dont_recompute_masks branch April 2, 2024 14:36
Labels: bug (Something isn't working)
2 participants