Fix off policy #174

Merged: 14 commits, Apr 2, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ torch = ">=1.9.0"
torchtyping = ">=0.1.4"

# dev dependencies.
black = { version = "24.2", optional = true }
black = { version = "24.3", optional = true }
flake8 = { version = "*", optional = true }
gitmopy = { version = "*", optional = true }
myst-parser = { version = "*", optional = true }
3 changes: 3 additions & 0 deletions src/gfn/containers/replay_buffer.py
@@ -3,6 +3,8 @@
import os
from typing import TYPE_CHECKING, Literal

import torch

from gfn.containers.trajectories import Trajectories
from gfn.containers.transitions import Transitions

@@ -48,6 +50,7 @@ def __init__(
elif objects_type == "states":
self.training_objects = env.states_from_batch_shape((0,))
self.terminating_states = env.states_from_batch_shape((0,))
self.terminating_states.log_rewards = torch.zeros((0,), device=env.device)
self.objects_type = "states"
else:
raise ValueError(f"Unknown objects_type: {objects_type}")
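A note on the added initialization: presetting terminating_states.log_rewards to an empty (0,)-shaped tensor, rather than leaving it unset, presumably keeps later buffer extensions a plain tensor concatenation with no None special-casing. A minimal, self-contained illustration of that pattern (the real buffer extension code is not part of this diff):

import torch

buffer_log_rewards = torch.zeros((0,))            # mirrors the new empty initialization
new_log_rewards = torch.tensor([-1.2, -0.7])      # log-rewards of freshly added terminating states
buffer_log_rewards = torch.cat([buffer_log_rewards, new_log_rewards], dim=0)
print(buffer_log_rewards)                         # tensor([-1.2000, -0.7000])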
10 changes: 8 additions & 2 deletions src/gfn/containers/trajectories.py
@@ -14,6 +14,7 @@

from gfn.containers.base import Container
from gfn.containers.transitions import Transitions
from gfn.utils.common import has_log_probs


def is_tensor(t) -> bool:
@@ -54,7 +55,7 @@ def __init__(
is_backward: bool = False,
log_rewards: TT["n_trajectories", torch.float] | None = None,
log_probs: TT["max_length", "n_trajectories", torch.float] | None = None,
estimator_outputs: torch.Tensor | None = None,
estimator_outputs: TT["batch_shape", "output_dim", torch.float] | None = None,
) -> None:
"""
Args:
@@ -325,7 +326,12 @@ def to_transitions(self) -> Transitions:
],
dim=0,
)
log_probs = self.log_probs[~self.actions.is_dummy]

# Only return logprobs if they exist.
log_probs = (
self.log_probs[~self.actions.is_dummy] if has_log_probs(self) else None
)

return Transitions(
env=self.env,
states=states,
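This container and the ones below gate on has_log_probs, imported from gfn.utils.common but not defined in this diff. A plausible sketch of such a check, shown only to make the new guards readable (the argument name and exact conditions are assumptions):

def has_log_probs(container) -> bool:
    # True only when the container carries a non-empty log_probs tensor.
    log_probs = getattr(container, "log_probs", None)
    return log_probs is not None and log_probs.nelement() > 0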
6 changes: 5 additions & 1 deletion src/gfn/containers/transitions.py
@@ -11,6 +11,7 @@
from gfn.states import States

from gfn.containers.base import Container
from gfn.utils.common import has_log_probs


class Transitions(Container):
@@ -186,7 +187,10 @@ def __getitem__(self, index: int | Sequence[int]) -> Transitions:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)
log_probs = self.log_probs[index]

# Only return logprobs if they exist.
log_probs = self.log_probs[index] if has_log_probs(self) else None

return Transitions(
env=self.env,
states=states,
1 change: 0 additions & 1 deletion src/gfn/env.py
@@ -393,7 +393,6 @@ class DiscreteEnvStates(DiscreteStates):

def make_actions_class(self) -> type[Actions]:
env = self
n_actions = self.n_actions

class DiscreteEnvActions(Actions):
action_shape = env.action_shape
71 changes: 45 additions & 26 deletions src/gfn/gflownet/base.py
@@ -13,6 +13,7 @@
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
@@ -29,14 +30,20 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):

@abstractmethod
def sample_trajectories(
self, env: Env, n_samples: int, sample_off_policy: bool
self,
env: Env,
n_samples: int,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
) -> Trajectories:
"""Sample a specific number of complete trajectories.

Args:
env: the environment to sample trajectories from.
n_samples: number of trajectories to be sampled.
sample_off_policy: whether to sample trajectories on / off policy.
save_logprobs: whether to save the logprobs of the actions - useful for on-policy learning.
save_estimator_outputs: whether to save the estimator outputs - useful for off-policy learning
with a tempered policy.
Returns:
Trajectories: sampled trajectories object.
"""
@@ -50,7 +57,9 @@ def sample_terminating_states(self, env: Env, n_samples: int) -> States:
Returns:
States: sampled terminating states object.
"""
trajectories = self.sample_trajectories(env, n_samples, sample_off_policy=False)
trajectories = self.sample_trajectories(
env, n_samples, save_estimator_outputs=False, save_logprobs=False
)
return trajectories.last_states

def logz_named_parameters(self):
@@ -76,21 +85,26 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType]):
pb: GFNModule
"""

def __init__(self, pf: GFNModule, pb: GFNModule, off_policy: bool):
def __init__(self, pf: GFNModule, pb: GFNModule):
super().__init__()
self.pf = pf
self.pb = pb
self.off_policy = off_policy

def sample_trajectories(
self, env: Env, n_samples: int, sample_off_policy: bool, **policy_kwargs
self,
env: Env,
n_samples: int,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
**policy_kwargs,
) -> Trajectories:
"""Samples trajectories, optionally with specified policy kwargs."""
sampler = Sampler(estimator=self.pf)
trajectories = sampler.sample_trajectories(
env,
n_trajectories=n_samples,
off_policy=sample_off_policy,
save_estimator_outputs=save_estimator_outputs,
save_logprobs=save_logprobs,
**policy_kwargs,
)

@@ -108,6 +122,7 @@ def get_pfs_and_pbs(
self,
trajectories: Trajectories,
fill_value: float = 0.0,
recalculate_all_logprobs: bool = False,
) -> Tuple[
TT["max_length", "n_trajectories", torch.float],
TT["max_length", "n_trajectories", torch.float],
@@ -117,17 +132,16 @@
More specifically it evaluates $\log P_F (s' \mid s)$ and $\log P_B(s \mid s')$
for each transition in each trajectory in the batch.

Useful when the policy used to sample the trajectories is different from
the one used to evaluate the loss. Otherwise we can use the logprobs directly
from the trajectories.

Note - for off policy exploration, the trajectories submitted to this method
will be sampled off policy.
The following applies unless recalculate_all_logprobs=True, in which case the logprobs are always
re-evaluated with the current self.pf:
- If the trajectories have a log_probs attribute, use it - this is usually the case for on-policy learning.
- Else, if the trajectories have an estimator_outputs attribute, transform it
into log_probs - this is usually the case for off-policy learning with a tempered policy.
- Else, re-evaluate the log_probs
using the current self.pf - this is usually the case for off-policy learning with a replay buffer.

Args:
trajectories: Trajectories to evaluate.
estimator_outputs: Optional stored estimator outputs from previous forward
sampling (encountered, for example, when sampling off policy).
fill_value: Value to use for invalid states (i.e. $s_f$ that is added to
shorter trajectories).

@@ -151,16 +165,18 @@
if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

if self.off_policy:
# We re-use the values calculated in .sample_trajectories().
if trajectories.estimator_outputs is not None:
if has_log_probs(trajectories) and not recalculate_all_logprobs:
log_pf_trajectories = trajectories.log_probs
else:
if (
trajectories.estimator_outputs is not None
and not recalculate_all_logprobs
):
estimator_outputs = trajectories.estimator_outputs[
~trajectories.actions.is_dummy
]
else:
raise Exception(
"GFlowNet is off policy, but no estimator_outputs found in Trajectories!"
)
estimator_outputs = self.pf(valid_states)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
@@ -175,9 +191,6 @@
)
log_pf_trajectories[~trajectories.actions.is_dummy] = valid_log_pf_actions

else:
log_pf_trajectories = trajectories.log_probs

non_initial_valid_states = valid_states[~valid_states.is_initial_state]
non_exit_valid_actions = valid_actions[~valid_actions.is_exit]

@@ -201,13 +214,19 @@

return log_pf_trajectories, log_pb_trajectories

def get_trajectories_scores(self, trajectories: Trajectories) -> Tuple[
def get_trajectories_scores(
self,
trajectories: Trajectories,
recalculate_all_logprobs: bool = False,
) -> Tuple[
TT["n_trajectories", torch.float],
TT["n_trajectories", torch.float],
TT["n_trajectories", torch.float],
]:
"""Given a batch of trajectories, calculate forward & backward policy scores."""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(trajectories)
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
)

assert log_pf_trajectories is not None
total_log_pf_trajectories = log_pf_trajectories.sum(dim=0)
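The docstring of get_pfs_and_pbs above describes a three-way fallback: reuse stored log_probs, otherwise rebuild them from stored estimator_outputs, otherwise re-evaluate with the current pf. The self-contained toy below mirrors only that decision order; FakeTrajectories and the returned strings are invented for illustration and are not part of the library.

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeTrajectories:
    # Invented stand-in for gfn.containers.Trajectories, holding only the two
    # attributes the fallback logic inspects.
    log_probs: Optional[torch.Tensor] = None
    estimator_outputs: Optional[torch.Tensor] = None

def pick_log_pf_source(t: FakeTrajectories, recalculate_all_logprobs: bool = False) -> str:
    # Mirrors the order described in the docstring: stored log-probs first,
    # then stored estimator outputs, then a fresh forward pass of pf.
    if t.log_probs is not None and t.log_probs.nelement() > 0 and not recalculate_all_logprobs:
        return "reuse stored log_probs (on-policy)"
    if t.estimator_outputs is not None and not recalculate_all_logprobs:
        return "rebuild log_probs from saved estimator_outputs (off-policy, tempered policy)"
    return "re-evaluate log_probs with the current pf (off-policy, e.g. replay buffer)"

print(pick_log_pf_source(FakeTrajectories(log_probs=torch.tensor([-0.3, -1.1]))))
print(pick_log_pf_source(FakeTrajectories(estimator_outputs=torch.randn(2, 4))))
print(pick_log_pf_source(FakeTrajectories()))
print(pick_log_pf_source(FakeTrajectories(log_probs=torch.tensor([-0.3])), recalculate_all_logprobs=True))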
43 changes: 30 additions & 13 deletions src/gfn/gflownet/detailed_balance.py
@@ -8,6 +8,7 @@
from gfn.env import Env
from gfn.gflownet.base import PFBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator
from gfn.utils.common import has_log_probs


class DBGFlowNet(PFBasedGFlowNet[Transitions]):
@@ -23,7 +24,6 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):

Attributes:
logF: a ScalarEstimator instance.
off_policy: If true, we need to reevaluate the log probs.
forward_looking: whether to implement the forward looking GFN loss.
log_reward_clip_min: If finite, clips log rewards to this value.
"""
@@ -33,16 +33,17 @@ def __init__(
pf: GFNModule,
pb: GFNModule,
logF: ScalarEstimator,
off_policy: bool,
forward_looking: bool = False,
log_reward_clip_min: float = -float("inf"),
):
super().__init__(pf, pb, off_policy=off_policy)
super().__init__(pf, pb)
self.logF = logF
self.forward_looking = forward_looking
self.log_reward_clip_min = log_reward_clip_min

def get_scores(self, env: Env, transitions: Transitions) -> Tuple[
def get_scores(
self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[
TT["n_transitions", float],
TT["n_transitions", float],
TT["n_transitions", float],
@@ -52,6 +53,12 @@ def get_scores(self, env: Env, transitions: Transitions) -> Tuple[
Args:
transitions: a batch of transitions.

The following applies unless recalculate_all_logprobs=True, in which case the logprobs are always
re-evaluated with the current self.pf:
- If the transitions have a log_probs attribute, use it - this is usually the case for on-policy learning.
- Else, re-evaluate the log_probs using the current self.pf - this is usually the case for
off-policy learning with a replay buffer.

Raises:
ValueError: when supplied with backward transitions.
AssertionError: when log rewards of transitions are None.
@@ -66,19 +73,20 @@ def get_scores(self, env: Env, transitions: Transitions) -> Tuple[

if states.batch_shape != tuple(actions.batch_shape):
raise ValueError("Something wrong happening with log_pf evaluations")
if not self.off_policy:

if has_log_probs(transitions) and not recalculate_all_logprobs:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
# I suppose the Transitions container should then have some
# Evaluate the log PF of the actions
module_output = self.pf(
states
) # TODO: Inefficient duplication in case of tempered policy
# The Transitions container should then have some
# estimator_outputs attribute as well, to avoid duplication here ?
# See (#156).
[Review comment, Collaborator] Why did you remove this issue reference (#156)?
[Author reply] My bad! Added back

module_output = self.pf(states) # TODO: Inefficient duplication.
valid_log_pf_actions = self.pf.to_probability_distribution(
states, module_output
).log_prob(
actions.tensor
) # Actions sampled off policy.
).log_prob(actions.tensor)

valid_log_F_s = self.logF(states).squeeze(-1)
if self.forward_looking:
@@ -147,9 +155,17 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]):
https://arxiv.org/abs/2202.13903 for more details.
"""

def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.float]:
def get_scores(
self, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> TT["n_trajectories", torch.float]:
"""DAG-GFN-style detailed balance, when all states are connected to the sink.

The following applies unless recalculate_all_logprobs=True, in which case the logprobs are always
re-evaluated with the current self.pf:
- If the transitions have a log_probs attribute, use it - this is usually the case for on-policy learning.
- Else, re-evaluate the log_probs using the current self.pf - this is usually the case for
off-policy learning with a replay buffer.

Raises:
ValueError: when backward transitions are supplied (not supported).
ValueError: when the computed scores contain `inf`.
@@ -164,7 +180,8 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.float]:
all_log_rewards = transitions.all_log_rewards[mask]
module_output = self.pf(states)
pf_dist = self.pf.to_probability_distribution(states, module_output)
if not self.off_policy:

if has_log_probs(transitions) and not recalculate_all_logprobs:
valid_log_pf_actions = transitions[mask].log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
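In the else branches above, "re-evaluate the log PF of the actions" means scoring the already-taken actions under a freshly computed policy distribution. For a discrete policy this boils down to a Categorical log_prob call; the snippet below is a generic torch.distributions illustration, not torchgfn code:

import torch
from torch.distributions import Categorical

logits = torch.randn(5, 4)                     # stand-in for the policy output over 5 states, 4 actions
actions_taken = torch.tensor([0, 3, 1, 2, 2])  # the actions recorded in the transitions
log_pf_actions = Categorical(logits=logits).log_prob(actions_taken)
print(log_pf_actions.shape)                    # torch.Size([5])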
6 changes: 4 additions & 2 deletions src/gfn/gflownet/flow_matching.py
@@ -36,7 +36,8 @@ def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
def sample_trajectories(
self,
env: Env,
off_policy: bool,
save_logprobs: bool,
save_estimator_outputs: bool = False,
n_samples: int = 1000,
**policy_kwargs: Optional[dict],
) -> Trajectories:
@@ -49,7 +50,8 @@
trajectories = sampler.sample_trajectories(
env,
n_trajectories=n_samples,
off_policy=off_policy,
save_estimator_outputs=save_estimator_outputs,
save_logprobs=save_logprobs,
**policy_kwargs,
)
return trajectories
3 changes: 1 addition & 2 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -56,7 +56,6 @@ def __init__(
pf: GFNModule,
pb: GFNModule,
logF: ScalarEstimator,
off_policy: bool,
weighting: Literal[
"DB",
"ModifiedDB",
@@ -70,7 +69,7 @@
log_reward_clip_min: float = -float("inf"),
forward_looking: bool = False,
):
super().__init__(pf, pb, off_policy=off_policy)
super().__init__(pf, pb)
self.logF = logF
self.weighting = weighting
self.lamda = lamda