Skip to content

Commit

Permalink
addressing PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Jul 21, 2023
1 parent 78ff1be commit 40ee3b6
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/gflownet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class Config:
The number of training steps after which to validate the model
checkpoint_every : Optional[int]
The number of training steps after which to checkpoint the model
print_every : int
The number of training steps after which to print the training loss
start_at_step : int
The training step to start at (default: 0)
num_final_gen_steps : Optional[int]
Expand All @@ -80,6 +82,7 @@ class Config:
seed: int = 0
validate_every: int = 1000
checkpoint_every: Optional[int] = None
print_every: int = 100
start_at_step: int = 0
num_final_gen_steps: Optional[int] = None
num_training_steps: int = 10_000
Expand Down
9 changes: 5 additions & 4 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,11 @@ def __iter__(self):
yield batch

def validate_batch(self, batch, trajs):
for actions, atypes in [
(batch.actions, self.ctx.action_type_order),
# (batch.bck_actions, self.ctx.bck_action_type_order),
]:
for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + (
[(batch.bck_actions, self.ctx.bck_action_type_order)]
if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order")
else []
):
mask_cat = GraphActionCategorical(
batch,
[self.model._action_type_to_mask(t, batch) for t in atypes],
Expand Down
2 changes: 1 addition & 1 deletion src/gflownet/tasks/qm9/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def set_default_hps(self, cfg: Config):

def setup_env_context(self):
self.ctx = MolBuildingEnvContext(
["H", "C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True
["C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True
)
# Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories
# from the dataset. For example, consider tue Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action
Expand Down
2 changes: 1 addition & 1 deletion src/gflownet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(self, hps: Dict[str, Any], device: torch.device):

self.device = device
# Print the loss every `self.print_every` iterations
self.print_every = 1
self.print_every = self.cfg.print_every
# These hooks allow us to compute extra quantities when sampling data
self.sampling_hooks: List[Callable] = []
self.valid_sampling_hooks: List[Callable] = []
Expand Down

0 comments on commit 40ee3b6

Please sign in to comment.