Rethinking sampling #147

Merged: 70 commits into master on Feb 16, 2024
Conversation

@josephdviviano (Collaborator) commented Nov 24, 2023

This PR is a hodgepodge of a few tweaks, bugfixes, and investigations related to the sampling logic, including a new simple continuous example.

  • estimator_outputs are now re-used when sampling off policy.
    • This is currently accomplished using padding. In a future PR this will be handled using nested tensors. I will likely wait until their API solidifies (it will apparently change in the near future).
  • policy_kwargs are passed around properly to do off-policy sampling, hopefully in a way that remains generic.
  • TODOs added throughout, based on my observations.
  • log_reward_min_clamp is now off by default and only defined in the gflownets, NOT the environment.
  • Other minor tweaks I made while chasing a ghost relating to magically changing numbers (I now believe this to be due to faulty RAM on my laptop).
  • API change: the user must now often specify explicitly whether a gflownet or sampler is running off_policy or not (see the sketch after this list). This is for efficiency. If sampling happens off policy, we save the estimator outputs (because we assume we will need them later, to evaluate log probabilities of actions under the policy). If it's done on policy, we calculate the log-probs during the forward pass.
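
To illustrate the API change, a hypothetical usage sketch (the off_policy flag comes from the description above; the surrounding names and signatures are assumptions and may differ from the actual code):

    # Hypothetical sketch only -- names and signatures are assumptions.
    sampler = Sampler(estimator=pf_estimator)

    # Off policy: estimator outputs are saved during sampling so that
    # log-probs of the actions under the policy can be evaluated later.
    trajectories = sampler.sample_trajectories(
        env,
        n_trajectories=16,
        off_policy=True,
        policy_kwargs={"temperature": 2.0},  # assumed exploration parameter
    )

    # On policy: log-probs are computed during the forward pass instead.
    trajectories = sampler.sample_trajectories(env, n_trajectories=16, off_policy=False)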

@josephdviviano (Collaborator, Author):

Sorry in advance for the large PR - feel free to be critical ...

@marpaia (Collaborator) left a comment:

In general, I think all of the code in this PR makes a lot of sense. I don't understand all of your intent as deeply as you do (of course), but I think everything here is well put together.

In general, I have some broad bits of feedback:

  • This PR is really large. While this is often necessarily the case with changes that aim to make large improvements across the whole codebase, it's always worth trying to make changes more focused when possible. I know you already know that though, so no worries.
  • We should probably come up with some strategy for dealing with the TODOs. I've worked in large projects where each TODO had to be associated with a specific GitHub issue, for example. A lot of them relate to copy semantics, which seems like a good, focused thing that we could pursue in isolation.
  • I understand your intent in introducing a very generic policy_kwargs dictionary, as it's not possible to know what parameters might be needed by continuous off-policy exploration (see the sketch after this list). I think we should keep an eye on how that winds up getting used in practice, though. It may be possible to type those parameters more strongly in the future.
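
For illustration, the generic dictionary might carry parameters like these (all names here are assumptions, not a confirmed API):

    # A discrete policy might take, e.g., a softmax temperature and epsilon-greedy noise:
    policy_kwargs = {"temperature": 2.0, "epsilon": 0.1}

    # A continuous policy might instead need something entirely different, e.g.:
    policy_kwargs = {"scale_factor": 1.5}  # widen the sampling distribution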

In general, I think this is worth merging, not least because I'm excited about #149 😄

@josephdviviano (Collaborator, Author) commented Nov 30, 2023

Thanks for the feedback!

I really like your idea of associating each TODO with an issue. That would also make it easier to go fix the thing (you just search for the issue number in the code).

I can do this in a follow up PR!

Sorry, I knew I was being naughty when I submitted this monster PR. It was essentially a grab bag of things I tried while trying to get the library to play ball for the gflownet workshop, and it would have been really annoying to split it out into various PRs post hoc. I figured it wasn't too bad because I was the only one working on it, but I agree this is horrible practice and not conducive to collaboration.

I understand your desire for strong typing on the policy_kwargs, but I'm worried it will add a lot of developer overhead. We should keep this in mind. In general, I don't want researchers to have to think too hard about software-engineering concerns when using this library; we should figure out the minimal set of good engineering practices that remain researcher friendly. To be frank, the variance in engineering skill in research is enormous, because the pipeline does not select much for engineering ability, and I think this library will have the greatest impact if we can make it accessible to as many of those people as possible.

And just to clarify the intent of the PR: I was addressing multiple points of feedback:

  • It's inefficient to do two forward passes on a neural network when one would suffice (see the sketch after this list).
  • The intention to train off or on policy was too often implicit. It wasn't obvious to new users how to sample off policy. Now everything is very explicit (and often required to be passed by the user).
  • The copy stuff arose from discussions with a collaborator at Intel who did some profiling of the library. I think that piece is far from over, but what I did was mostly trying to track down that bug with the slightly changing values.
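
The first point can be made concrete with a sketch (hypothetical names; the distribution-building call follows the estimator pattern discussed in this PR and may not match the actual signatures):

    # One forward pass during sampling; its output is kept for the loss.
    est_out = pf_estimator(states)
    dist = pf_estimator.to_probability_distribution(states, est_out, **policy_kwargs)
    actions = dist.sample()
    log_probs = dist.log_prob(actions)  # no second pass through the network
    trajectories.estimator_outputs = est_out  # re-used later when off policy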

@josephdviviano (Collaborator, Author):

I'll wait for @saleml (who is defending next week so is likely distracted at the minute) to merge. No rush!

@saleml (Collaborator) commented Dec 18, 2023

On it!

@saleml (Collaborator) left a comment:

Sorry for my very late review.

This is a great PR that will make using the library much simpler. Thanks a lot @josephdviviano.

I left a few comments, questions, and suggestions. They are minor. Hopefully the tests will pass after the fixes.

@@ -65,7 +77,7 @@ def __init__(
         self.env = env
         self.is_backward = is_backward
         self.states = (
-            states
+            states.clone()  # TODO: Do we need this clone?
Collaborator:

I don't see why we would need that

@@ -155,6 +168,12 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
            self._log_rewards[index] if self._log_rewards is not None else None
        )

        if is_tensor(self.estimator_outputs):
            estimator_outputs = self.estimator_outputs[:, index]
Collaborator:

This implicitly assumes that self.estimator_outputs is of shape max_length x n_trajectories (as is the case, for example, for self.log_probs). Would this always be the case?

I feel like things would easily break here unless we force some structure on estimator_outputs. Rather than torch.Tensor, it should be some TensorType with a specific shape, IMO.
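
Something like torchtyping could express that constraint (a sketch; the dimension names are assumptions):

    from torchtyping import TensorType

    # Make the assumed structure explicit instead of accepting any torch.Tensor:
    estimator_outputs: TensorType["max_length", "n_trajectories", "output_dim"]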

Collaborator (Author):

What do you think of simply:

        if is_tensor(self.estimator_outputs):
            estimator_outputs = self.estimator_outputs[..., index]
            estimator_outputs = estimator_outputs[:new_max_length]

?

Collaborator:

That should work!

        # Either set, or append, estimator outputs if they exist in the submitted
        # trajectory.
        if self.estimator_outputs is None and is_tensor(other.estimator_outputs):
            self.estimator_outputs = other.estimator_outputs
Collaborator:

But how would we match the indices of the trajectories to the indices of the estimator_outputs?

This feels dangerous. I suggest just throwing an error when one is None and the other is not (in either direction).
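
The suggested guard might look like this (a sketch):

    # Fail loudly if exactly one of the two Trajectories has estimator_outputs.
    if (self.estimator_outputs is None) != (other.estimator_outputs is None):
        raise ValueError(
            "Cannot extend: estimator_outputs is set on only one Trajectories."
        )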

Collaborator (Author):

I think the idea is to be able to extend an empty Trajectories instance, say with a stored buffer.

I agree it is dangerous but I think we should support this behaviour.

Admittedly, it has been some time since I looked at this, so I might be forgetting something.

Collaborator:

Fair enough!

        other_shape = np.array(other.estimator_outputs.shape)
        required_first_dim = max(self_shape[0], other_shape[0])

        # TODO: This should be a single reused function.
Collaborator:

Right! There is a function elsewhere that does something similar. Maybe for a future PR.
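
For reference, the single reused function the TODO asks for might look roughly like this (the name and padding value are assumptions):

    import torch

    def pad_dim0(t: torch.Tensor, length: int, value: float = 0.0) -> torch.Tensor:
        """Pads t with `value` along dim 0 until its first dimension is `length`."""
        if t.shape[0] >= length:
            return t
        pad = torch.full((length - t.shape[0], *t.shape[1:]), value, dtype=t.dtype)
        return torch.cat([t, pad], dim=0)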

src/gfn/env.py (outdated)
@@ -83,7 +79,7 @@ def reset(
         assert not (random and sink)

         if random and seed is not None:
-            torch.manual_seed(seed)
+            torch.manual_seed(seed)  # TODO: Improve seeding here?
Collaborator:

How?

Collaborator:

You made a set_seed function in common.py.
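
A typical set_seed helper looks roughly like this (a sketch; the actual one in common.py may differ):

    import random

    import numpy as np
    import torch

    def set_seed(seed: int) -> None:
        """Seeds all relevant RNGs for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)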

@@ -119,6 +116,7 @@ def make_random_states_tensor(
                device=env.device,
            )

        # TODO: Look into make masks - I don't think this is being called.
Collaborator:

Yes. This function can safely be deleted.

Collaborator (Author):

Removed!


    def __setitem__(
        self, index: int | Sequence[int] | Sequence[bool], states: States
    ) -> None:
        """Set particular states of the batch."""
        self.tensor[index] = states.tensor

    def clone(self) -> States:
Collaborator:

What about the batch_shape and log_reward attributes?

Collaborator (Author):

Right -- I think the easiest solution here is to use deepcopy (sketch below). What do you think?
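
A minimal sketch of the deepcopy-based clone, assuming a deep copy of all attributes (including batch_shape and any log-reward bookkeeping) is the desired behaviour:

    from copy import deepcopy

    def clone(self) -> "States":
        """Returns an independent copy of this instance, including all attributes."""
        return deepcopy(self)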

        arch.append(nn.Linear(hidden_dim, hidden_dim))
        arch.append(activation())
        self.torso = nn.Sequential(*arch)
        self.torso.hidden_dim = hidden_dim  # TODO: what is this?
Collaborator:

It's storing the hidden_dim attribute on self.torso.

Collaborator:

Awesome!

-        return states.tensor.float()
+        return (
+            states.tensor.float()
+        )  # TODO: should we typecast here? not a true identity...
Collaborator:

I don't understand the question.

Collaborator (Author):

It means that the identity preprocessor is typecasting data, which is possibly unexpected behaviour. I would expect this to return whatever tensor is already inside states, untouched (see the sketch below).
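
A strict identity would instead look like this (a sketch; the method name is an assumption):

    def preprocess(self, states: "States") -> torch.Tensor:
        # Return the underlying tensor untouched -- no dtype cast.
        return states.tensor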

@josephdviviano (Collaborator, Author):

Hey Salem -- the indexing changes I implemented to fix this (#147 (comment)) have broken the tests -- I'm working on that now. I opened a can of worms here! There's likely an elegant solution.

@josephdviviano (Collaborator, Author):

Finally, I reverted the change for that failing test. I'm not sure there's a better solution that isn't extremely complicated.

I'd like to get this PR merged but we can happily revisit this issue in a future much smaller PR.

@josephdviviano (Collaborator, Author):

@saleml I'd love you to check this before I merge :)

@saleml (Collaborator) commented Feb 16, 2024

This is some great work! Thank you Joseph for this very important PR.

I have read your replies to my comments, and seen that the tests pass. I think this can be merged as is.

Small nitpick: For this comment: #147 (comment), do you think we should add the argument to the abstract function?

@josephdviviano merged commit eedc7e8 into master on Feb 16, 2024 (3 checks passed).
@josephdviviano (Collaborator, Author):

> Small nitpick: For this comment: #147 (comment), do you think we should add the argument to the abstract function?

I added it!
