Skip to content

Commit

Permalink
Fix two bugs with --logging_first_step (#8193)
Browse files Browse the repository at this point in the history
* make sure that logging_first_step evaluates

* fix bug with incorrect loss on logging_first_step

* fix style

* logging_first_step only logs, not evals
  • Loading branch information
abisee authored Oct 30, 2020
1 parent 689ff74 commit 8f1c960
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

tr_loss = torch.tensor(0.0).to(self.args.device)
self._logging_loss_scalar = 0
self._globalstep_last_logged = 0
self._total_flos = self.state.total_flos
model.zero_grad()

Expand Down Expand Up @@ -849,14 +850,17 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
self.state.global_step - self._globalstep_last_logged
)
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
self._logging_loss_scalar = tr_loss_scalar
self._globalstep_last_logged = self.state.global_step

self.log(logs)

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class TrainingArguments:
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})

logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"})
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
save_total_limit: Optional[int] = field(
Expand Down

0 comments on commit 8f1c960

Please sign in to comment.