Fix the behavior of collecting 'num_input_tokens_seen' (huggingface#29099)

fix the behavior of collecting 'num_input_tokens_seen'

See huggingface#28791 for more details.
youliangh authored and hovnatan committed Mar 27, 2024
1 parent 8cd2909 commit 4bd2d04
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/transformers/trainer.py
@@ -2097,7 +2097,12 @@ def _inner_training_loop(
                                 "a `main_input_name` attribute to the model class you are using."
                             )
                         else:
-                            self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel()
+                            input_device = inputs[main_input_name].device
+                            self.state.num_input_tokens_seen += torch.sum(
+                                self.accelerator.gather(
+                                    torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
+                                )
+                            ).item()
                     if rng_to_sync:
                         self._load_rng_state(resume_from_checkpoint)
                         rng_to_sync = False
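
Note on the change: the old line gathered the full input tensor from every process and then counted its elements, while the new lines gather one `int64` count per process and sum those counts. The sketch below simulates the difference in plain Python, with no `torch` or `accelerate` involved. Every name in it (`equal_batches`, `ragged_batches`, `count_via_full_gather`, `count_via_scalar_gather`) is hypothetical, and the shape check is only an assumption standing in for the general expectation that a concatenating gather receives identically shaped tensors from each rank. It is not a reproduction of what huggingface#28791 reports.

```python
# Hypothetical per-rank batches of token ids; none of these names come from trainer.py.
equal_batches = [
    [[101, 7, 8, 102], [101, 9, 10, 102]],   # rank 0: shape (2, 4)
    [[101, 5, 6, 102], [101, 3, 4, 102]],    # rank 1: shape (2, 4)
]
ragged_batches = [
    [[101, 7, 8, 102], [101, 9, 10, 102]],   # rank 0: shape (2, 4)
    [[101, 5, 102]],                          # rank 1: shape (1, 3)
]


def shape(batch):
    """Shape of a rectangular 2-D batch as (rows, columns)."""
    return (len(batch), len(batch[0]))


def numel(batch):
    """Element count of a rectangular 2-D batch, analogous to Tensor.numel()."""
    rows, cols = shape(batch)
    return rows * cols


def count_via_full_gather(batches):
    """Old pattern: gather whole batches, then count the gathered elements."""
    # Assumption: the simulated gather rejects mismatched per-rank shapes.
    if len({shape(b) for b in batches}) != 1:
        raise ValueError("simulated gather needs identical shapes on every rank")
    gathered = [row for b in batches for row in b]  # concatenate along dim 0
    return numel(gathered)


def count_via_scalar_gather(batches):
    """New pattern: each rank contributes one integer; sum them afterwards."""
    per_rank_counts = [numel(b) for b in batches]   # stand-in for gathering scalars
    return sum(per_rank_counts)
```

With `equal_batches` both functions return 16. With `ragged_batches` only `count_via_scalar_gather` succeeds, returning 11, because the per-rank counts are scalars and no longer depend on the batch shapes matching. Gathering one integer per rank also moves far less data than gathering every input tensor.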