From 40ee3b666de6098d02af0cbc1abe0ec88e65b98a Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 21 Jul 2023 15:54:22 -0400 Subject: [PATCH] addressing PR comments --- src/gflownet/config.py | 3 +++ src/gflownet/data/sampling_iterator.py | 9 +++++---- src/gflownet/tasks/qm9/qm9.py | 2 +- src/gflownet/trainer.py | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index e7692f06..735eef79 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -58,6 +58,8 @@ class Config: The number of training steps after which to validate the model checkpoint_every : Optional[int] The number of training steps after which to checkpoint the model + print_every : int + The number of training steps after which to print the training loss start_at_step : int The training step to start at (default: 0) num_final_gen_steps : Optional[int] @@ -80,6 +82,7 @@ class Config: seed: int = 0 validate_every: int = 1000 checkpoint_every: Optional[int] = None + print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None num_training_steps: int = 10_000 diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index 0d967a2f..90b8b4db 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -332,10 +332,11 @@ def __iter__(self): yield batch def validate_batch(self, batch, trajs): - for actions, atypes in [ - (batch.actions, self.ctx.action_type_order), - # (batch.bck_actions, self.ctx.bck_action_type_order), - ]: + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): mask_cat = GraphActionCategorical( batch, [self.model._action_type_to_mask(t, batch) for t in atypes], diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 5c5a403b..7b1cf189 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -112,7 +112,7 @@ def set_default_hps(self, cfg: Config): def setup_env_context(self): self.ctx = MolBuildingEnvContext( - ["H", "C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True + ["C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True ) # Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories # from the dataset. For example, consider tue Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index d835037c..747485b1 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -126,7 +126,7 @@ def __init__(self, hps: Dict[str, Any], device: torch.device): self.device = device # Print the loss every `self.print_every` iterations - self.print_every = 1 + self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data self.sampling_hooks: List[Callable] = [] self.valid_sampling_hooks: List[Callable] = []