Fixed trainer total_flos reloading in distributed mode (huggingface#11383)

* Fixed trainer total_flos reloading in distributed mode

* logging flos at the end of training
TevenLeScao authored and Iwontbecreative committed Jul 15, 2021
1 parent 8a7c51b commit 1f3971c
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/transformers/trainer.py
@@ -426,9 +426,9 @@ def __init__(

         self.state = TrainerState()
         self.control = TrainerControl()
-        # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
-        # state at each call to self.log.
-        self._total_flos = None
+        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+        # returned to 0 every time flos need to be logged
+        self.current_flos = 0
         self.hp_search_backend = None
         self.use_tune_checkpoints = False
         default_label_names = (
@@ -1162,7 +1162,6 @@ def train(
         # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
-        self._total_flos = self.state.total_flos
         model.zero_grad()
 
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
@@ -1220,7 +1219,7 @@ def train(
                         tr_loss += self.training_step(model, inputs)
                 else:
                     tr_loss += self.training_step(model, inputs)
-                self._total_flos += float(self.floating_point_ops(inputs))
+                self.current_flos += float(self.floating_point_ops(inputs))
 
                 # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                 if self.deepspeed:
@@ -1321,9 +1320,8 @@ def train(
         )
 
         metrics = speed_metrics("train", start_time, self.state.max_steps)
-        if self._total_flos is not None:
-            self.store_flos()
-            metrics["total_flos"] = self.state.total_flos
+        self.store_flos()
+        metrics["total_flos"] = self.state.total_flos
         self.log(metrics)
 
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
@@ -1788,11 +1786,12 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):

     def store_flos(self):
         # Storing the number of floating-point operations that went into the model
-        if self._total_flos is not None:
-            if self.args.local_rank != -1:
-                self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
-            else:
-                self.state.total_flos = self._total_flos
+        if self.args.local_rank != -1:
+            self.state.total_flos += distributed_broadcast_scalars([self.current_flos]).sum().item()
+            self.current_flos = 0
+        else:
+            self.state.total_flos = self.current_flos
+            self.current_flos = 0
 
     def _sorted_checkpoints(
         self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
@@ -1883,6 +1882,7 @@ def evaluate(
         )
 
         output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
+
         self.log(output.metrics)
 
         if self.args.tpu_metrics_debug or self.args.debug:
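Taken together, the diff swaps the old tensor-style `_total_flos` accumulator for a per-process `current_flos` counter that `store_flos` sums across ranks into `self.state.total_flos` and then resets to zero. Below is a minimal, self-contained sketch of that accumulate-and-flush pattern, not the Trainer's actual code: the `FlosCounter` class is hypothetical, and `torch.distributed.all_reduce` stands in for the `distributed_broadcast_scalars(...).sum()` call used in the diff.

```python
# Hypothetical sketch of the accumulate-and-flush flos bookkeeping shown in the diff.
# `FlosCounter` is not part of transformers; it only illustrates the pattern.
import torch
import torch.distributed as dist


class FlosCounter:
    def __init__(self):
        self.current_flos = 0.0  # per-process counter, reset every time it is flushed
        self.total_flos = 0.0    # running total, analogous to TrainerState.total_flos

    def step(self, flos_this_step: float):
        # Called once per training step with the estimated floating-point operations.
        self.current_flos += flos_this_step

    def store_flos(self):
        # Flush: sum the per-process counters across ranks, fold them into the total, reset.
        if dist.is_available() and dist.is_initialized():
            device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else "cpu"
            flos = torch.tensor(self.current_flos, dtype=torch.float64, device=device)
            dist.all_reduce(flos, op=dist.ReduceOp.SUM)  # stand-in for distributed_broadcast_scalars
            self.total_flos += flos.item()
        else:
            self.total_flos += self.current_flos
        self.current_flos = 0.0
```

Because only the flos accrued since the last flush are summed and added on top of `total_flos`, a total reloaded from a checkpoint is no longer re-broadcast and re-summed across processes, which is the kind of double counting the commit title points at.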

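For reference, `total_flos` is persisted with the rest of the trainer state, so it can be inspected after training or before resuming. A small usage sketch, assuming a hypothetical checkpoint directory and the `trainer_state.json` file that `Trainer` writes alongside each checkpoint:

```python
import os

from transformers import TrainerState

# Hypothetical checkpoint directory; substitute the checkpoint produced by your own run.
checkpoint_dir = "output/checkpoint-500"

state = TrainerState.load_from_json(os.path.join(checkpoint_dir, "trainer_state.json"))
print(f"global_step={state.global_step}, total_flos={state.total_flos}")
```

On resume, the patched `store_flos` only adds newly accumulated flos on top of this stored value rather than rebuilding the total from per-process copies of it.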