Fixed trainer total_flos reloading in distributed mode #11383

Merged 2 commits on Apr 23, 2021
26 changes: 13 additions & 13 deletions src/transformers/trainer.py
@@ -426,9 +426,9 @@ def __init__(

self.state = TrainerState()
self.control = TrainerControl()
- # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
- # state at each call to self.log.
- self._total_flos = None
+ # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+ # returned to 0 every time flos need to be logged
+ self.current_flos = 0
self.hp_search_backend = None
self.use_tune_checkpoints = False
default_label_names = (
@@ -1162,7 +1162,6 @@ def train(
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
- self._total_flos = self.state.total_flos
model.zero_grad()

self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
@@ -1220,7 +1219,7 @@
tr_loss += self.training_step(model, inputs)
else:
tr_loss += self.training_step(model, inputs)
- self._total_flos += float(self.floating_point_ops(inputs))
+ self.current_flos += float(self.floating_point_ops(inputs))

# Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
if self.deepspeed:
@@ -1321,9 +1320,8 @@
)

metrics = speed_metrics("train", start_time, self.state.max_steps)
- if self._total_flos is not None:
-     self.store_flos()
-     metrics["total_flos"] = self.state.total_flos
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
self.log(metrics)

self.control = self.callback_handler.on_train_end(args, self.state, self.control)
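Taken together, the training-loop side of the change follows a simple pattern: each process counts flos into a local counter on every step, and that counter is only folded into the persistent total when flos are stored or logged. The snippet below is a minimal standalone sketch of that pattern, not the Trainer code itself; the class and method names are illustrative only.

```python
# Minimal standalone sketch of the per-process flos bookkeeping used above.
# Names are illustrative, not part of the transformers API.
class FlosBookkeeping:
    def __init__(self):
        self.current_flos = 0.0  # local, per-process counter (reset after every store)
        self.total_flos = 0.0    # persistent total, saved and reloaded with the trainer state

    def on_training_step(self, step_flos: float):
        # Mirrors `self.current_flos += float(self.floating_point_ops(inputs))`
        self.current_flos += float(step_flos)

    def store_flos(self):
        # Fold the local counter into the running total, then reset it,
        # so the same flos are never counted twice across stores.
        self.total_flos += self.current_flos
        self.current_flos = 0.0
```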
@@ -1788,11 +1786,12 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):

def store_flos(self):
# Storing the number of floating-point operations that went into the model
- if self._total_flos is not None:
-     if self.args.local_rank != -1:
-         self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
-     else:
-         self.state.total_flos = self._total_flos
+ if self.args.local_rank != -1:
+     self.state.total_flos += distributed_broadcast_scalars([self.current_flos]).sum().item()
+     self.current_flos = 0
+ else:
+     self.state.total_flos = self.current_flos
+     self.current_flos = 0

def _sorted_checkpoints(
self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
@@ -1883,6 +1882,7 @@ def evaluate(
)

output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))

self.log(output.metrics)

if self.args.tpu_metrics_debug or self.args.debug:
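For readers unfamiliar with the distributed branch of the new `store_flos`, the gather-and-sum it performs can be pictured roughly as below. This is a hedged sketch that assumes an already initialized `torch.distributed` process group; `distributed_broadcast_scalars` is the helper actually called in the diff, and the explicit `all_gather` here only approximates its behaviour.

```python
# Rough sketch (assumes torch.distributed is initialized, e.g. with the gloo backend).
# Approximates what `distributed_broadcast_scalars([self.current_flos]).sum().item()`
# achieves: every process contributes its local flos count and the counts are summed.
import torch
import torch.distributed as dist

def gather_and_sum_flos(current_flos: float) -> float:
    local = torch.tensor([current_flos], dtype=torch.float64)
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)         # every rank receives every rank's counter
    return torch.cat(gathered).sum().item()  # grand total across processes

# In `store_flos` terms (distributed case):
#     self.state.total_flos += gather_and_sum_flos(self.current_flos)
#     self.current_flos = 0
```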