Fix off policy #174

Merged: 14 commits merged into master from fix_off_policy on Apr 2, 2024

Conversation

@saleml (Collaborator) commented Mar 21, 2024:

This fixes #168.
The idea is to remove the arguments we had before (off_policy and sample_off_policy) and be explicit about what we're evaluating and storing when sampling.
When sampling on-policy, we should store the log-probs. This is the default.
When sampling off-policy with a tempered/modified PF, we should only store estimator_outputs.
When we use a replay buffer, we don't need to store anything; we should recalculate the log-probs.

Additionally, this fixes FM + ReplayBuffer, which was broken before because states extension did not take the _log_probs attribute into account.
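
For illustration, the storage rules above can be summarized as a small decision helper. This is a sketch only: the flag names save_logprobs and save_estimator_outputs are assumptions made for this example, not necessarily the exact arguments introduced by the PR.

# Sketch of the storage rules described in the PR description; the flag names
# below are illustrative assumptions, not necessarily the merged API.
def decide_what_to_store(on_policy: bool, uses_replay_buffer: bool) -> dict:
    """Return which quantities the sampler should keep for the loss."""
    if uses_replay_buffer:
        # Replay buffer: nothing needs to be stored; log-probs are recalculated.
        return {"save_logprobs": False, "save_estimator_outputs": False}
    if on_policy:
        # Default on-policy case: reuse the log-probs computed while sampling.
        return {"save_logprobs": True, "save_estimator_outputs": False}
    # Off-policy (tempered/modified PF): stored log-probs would describe the
    # exploration distribution, so keep only the raw estimator outputs.
    return {"save_logprobs": False, "save_estimator_outputs": True}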

@josephdviviano (Collaborator) left a comment:
I really like these API changes -- I have a few small questions before we approve (but this might not require any further changes to the code -- I just want to understand).

src/gfn/env.py (outdated)
@@ -393,7 +393,7 @@ class DiscreteEnvStates(DiscreteStates):

     def make_actions_class(self) -> type[Actions]:
         env = self
-        n_actions = self.n_actions
+        self.n_actions

Collaborator:
What's going on here? I find this confusing.

Collaborator:
I'm adding it back in. I'm sure this works and is potentially correct, but I find it weird, and I suspect others will as well.

Collaborator (Author):
Not sure what happened. Actually, we don't need that line at all (thanks, Pylance)! I'm removing the whole line.

Collaborator:
ok that works for me ;)

@@ -66,19 +72,20 @@ def get_scores(self, env: Env, transitions: Transitions) -> Tuple[

         if states.batch_shape != tuple(actions.batch_shape):
             raise ValueError("Something wrong happening with log_pf evaluations")
-        if not self.off_policy:
+        if (
+            transitions.log_probs is not None

Collaborator:
I'm seeing this logic a few times in the code. Should we abstract it into a utility like

def has_log_probs(obj):
    return obj.log_probs is not None and obj.log_probs.nelement() > 0

?
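
A self-contained toy check of the proposed helper (FakeTransitions below is a stand-in used only for this example, not the real Transitions container):

import torch

def has_log_probs(obj) -> bool:
    # True only when log_probs exists and is non-empty.
    return obj.log_probs is not None and obj.log_probs.nelement() > 0

class FakeTransitions:
    # Minimal stand-in for illustration only.
    def __init__(self, log_probs):
        self.log_probs = log_probs

assert not has_log_probs(FakeTransitions(None))            # nothing stored
assert not has_log_probs(FakeTransitions(torch.empty(0)))  # empty placeholder tensor
assert has_log_probs(FakeTransitions(torch.zeros(4)))      # stored log-probs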

Collaborator:
I've added this utility function.

Collaborator (Author):
Great!

# Evaluate the log PF of the actions sampled off policy.
# I suppose the Transitions container should then have some
# estimator_outputs attribute as well, to avoid duplication here ?
# See (#156).
Collaborator:
Why did you remove this issue reference (#156)?

Collaborator (Author):
My bad! Added it back.

@@ -53,7 +53,9 @@ def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:

             ValueError: if the loss is NaN.
         """
         del env  # unused
-        _, _, scores = self.get_trajectories_scores(trajectories)
+        _, _, scores = self.get_trajectories_scores(
+            trajectories, recalculate_all=recalculate_all

Collaborator:
I'm wondering if there's a more explicit name for recalculate_all -- like recalculate_all_logprobs?

Collaborator (Author):
Yes, good idea, done.
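
For illustration, the renamed flag would thread through the loss roughly as follows. This is a sketch: the signature, the default value, and the final reduction of the scores into a scalar are assumptions, not the merged code.

# Sketch only; signature, default value, and the final reduction are
# illustrative assumptions.
def loss(self, env, trajectories, recalculate_all_logprobs: bool = False):
    del env  # unused
    _, _, scores = self.get_trajectories_scores(
        trajectories, recalculate_all_logprobs=recalculate_all_logprobs
    )
    return (scores**2).mean()  # illustrative reduction into a scalar loss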

policy_kwargs: keyword arguments to be passed to the
`to_probability_distribution` method of the estimator. For example, for
DiscretePolicyEstimators, the kwargs can contain the `temperature`
parameter, `epsilon`, and `sf_bias`. In the continuous case these
kwargs will be user defined. This can be used to, for example, sample
off-policy.
debug_mode: if True, everything gets calculated.

Collaborator:
Why is debug_mode removed? If I recall, this was important for tests.

Collaborator (Author):
Isn't this the same as recalculate_all?

Collaborator:
right -- I'll change it back and add a note :)
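
As background for the policy_kwargs docstring quoted above, here is a self-contained toy example of what temperature and epsilon do to a discrete policy, and why log PF must be re-evaluated when sampling off-policy. The 3-action setup and the numeric values are illustrative only.

import torch
from torch.distributions import Categorical

logits = torch.tensor([2.0, 0.5, -1.0])      # raw estimator outputs for 3 actions

on_policy = Categorical(logits=logits)        # the distribution being trained
tempered = Categorical(logits=logits / 2.0)   # temperature = 2.0 flattens the policy
epsilon = 0.1
uniform = torch.full_like(logits, 1.0 / logits.numel())
behaviour = Categorical(probs=(1 - epsilon) * tempered.probs + epsilon * uniform)

action = behaviour.sample()           # action sampled off-policy (exploration)
log_pf = on_policy.log_prob(action)   # log PF re-evaluated under the current policy,
                                      # not under the sampling distribution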

Collaborator (Author):
Good idea! Can the function be a class method of Container?
self.has_log_prob() looks more natural than has_log_prob(self)

Collaborator:
Hmm, the only issue is that we actually use it in TrajectoryBasedGFlowNet.
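
For illustration, the two call styles under discussion, side by side. The Container class below is a toy stand-in, not the real gfn container base class.

import torch

class Container:
    # Toy stand-in for illustration only.
    def __init__(self, log_probs=None):
        self.log_probs = log_probs

    # Method style: call sites read `container.has_log_probs()`.
    def has_log_probs(self) -> bool:
        return self.log_probs is not None and self.log_probs.nelement() > 0

# Module-level style (as in the PR): callable from code that is not itself a
# container, e.g. from TrajectoryBasedGFlowNet, as `has_log_probs(trajectories)`.
def has_log_probs(obj) -> bool:
    return obj.log_probs is not None and obj.log_probs.nelement() > 0

transitions = Container(torch.zeros(4))
assert transitions.has_log_probs() == has_log_probs(transitions)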

@saleml mentioned this pull request on Apr 2, 2024.
@josephdviviano (Collaborator) left a comment:
lgtm!

@josephdviviano merged commit 74cd34e into master on Apr 2, 2024.
3 checks passed.
@josephdviviano deleted the fix_off_policy branch on April 2, 2024 at 14:57.
Development

Successfully merging this pull request may close these issues:
Replay buffer broken in train_hypergrid.py (#168)