Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Apr 2, 2024
1 parent 89c72b5 commit 9ae95a5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ def get_pfs_and_pbs(
if has_log_probs(trajectories) and not recalculate_all_logprobs:
log_pf_trajectories = trajectories.log_probs
else:
if trajectories.estimator_outputs is not None and not recalculate_all_logprobs:
if (
trajectories.estimator_outputs is not None
and not recalculate_all_logprobs
):
estimator_outputs = trajectories.estimator_outputs[
~trajectories.actions.is_dummy
]
Expand Down
10 changes: 8 additions & 2 deletions src/gfn/gflownet/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def __init__(
self.log_reward_clip_min = log_reward_clip_min

def loss(
self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False
self,
env: Env,
trajectories: Trajectories,
recalculate_all_logprobs: bool = False,
) -> TT[0, float]:
"""Trajectory balance loss.
Expand Down Expand Up @@ -83,7 +86,10 @@ def __init__(
self.log_reward_clip_min = log_reward_clip_min

def loss(
self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False
self,
env: Env,
trajectories: Trajectories,
recalculate_all_logprobs: bool = False,
) -> TT[0, float]:
"""Log Partition Variance loss.
Expand Down

0 comments on commit 9ae95a5

Please sign in to comment.