Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reset loss to zero on logging in Trainer to avoid bfloat16 issues #8561

Merged
merged 5 commits into from
Nov 18, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,8 +727,10 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()

# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(self.args.device)
self._logging_loss_scalar = 0
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = 0
self._total_flos = self.state.total_flos
model.zero_grad()
Expand Down Expand Up @@ -843,23 +845,26 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.log({"total_flos": self.state.total_flos})

self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()

return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)

def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
self.state.global_step - self._globalstep_last_logged
)
# reset tr_loss to zero
tr_loss -= tr_loss

logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
self._logging_loss_scalar = tr_loss_scalar
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step

self.log(logs)
Expand Down