
Commit 564fde1

Authored by rangehow, github-actions[bot], and SunMarc
FIX(trainer): ensure final checkpoint is saved when resuming training (#40347)
* fix(trainer): ensure final checkpoint is saved when resuming training
* add test
* make style && slight fix of test
* make style again
* move test code to test_trainer
* remove outdated test file
* Apply style fixes

---------

Co-authored-by: rangehow <rangehow@foxmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
1 parent 5748352 commit 564fde1
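
For orientation, the workflow this commit fixes looks roughly like the sketch below. It is a condensed, hypothetical mirror of the regression test added in this commit, not code taken from the diff: ToyModel, ToyDataset, and the output directories are invented, and safetensors is assumed to be installed; TrainingArguments, Trainer, and train(resume_from_checkpoint=...) are the public API being exercised.

import torch
from torch import nn
from transformers import Trainer, TrainingArguments


# Invented toy model/dataset for this sketch; they mirror the DummyModel and
# DummyDictDataset helpers defined in the new test further down.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, input_ids=None, labels=None):
        logits = self.linear(input_ids.float())
        loss = nn.functional.cross_entropy(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}


class ToyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 12

    def __getitem__(self, i):
        return {"input_ids": torch.rand(4), "labels": torch.tensor(i % 2)}


model = ToyModel()
dataset = ToyDataset()

# First run: checkpoint every step and stop after step 2 to simulate an interruption.
args = TrainingArguments(output_dir="out", save_strategy="steps", save_steps=1, max_steps=2, report_to=[])
Trainer(model=model, args=args, train_dataset=dataset).train()

# Second run: resume from the last saved checkpoint. With this fix, the remaining
# steps are trained and the final checkpoint is written under the new output_dir.
resume_args = TrainingArguments(output_dir="out_resumed", save_strategy="steps", save_steps=1, report_to=[])
Trainer(model=model, args=resume_args, train_dataset=dataset).train(resume_from_checkpoint="out/checkpoint-2")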

File tree

src/transformers/trainer.py
tests/trainer/test_trainer.py

2 files changed: +121, -24 lines

src/transformers/trainer.py

Lines changed: 12 additions & 24 deletions
@@ -2533,7 +2533,6 @@ def _inner_training_loop(
         start_time = time.time()
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
-        steps_trained_progress_bar = None

         # Check if continuing training from a checkpoint
         if resume_from_checkpoint is not None and os.path.isfile(
@@ -2594,18 +2593,18 @@ def _inner_training_loop(
             )
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

-            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
-                self._load_rng_state(resume_from_checkpoint)
-
+            step = -1
             rng_to_sync = False
-            steps_skipped = 0
-            if steps_trained_in_current_epoch > 0:
-                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
-                steps_skipped = steps_trained_in_current_epoch
-                steps_trained_in_current_epoch = 0
-                rng_to_sync = True

-            step = -1
+            # Handle resumption from checkpoint
+            if epoch == epochs_trained and resume_from_checkpoint is not None:
+                if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
+                    epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
+                    step = steps_trained_in_current_epoch - 1
+                    rng_to_sync = True
+                elif steps_trained_in_current_epoch == 0:
+                    self._load_rng_state(resume_from_checkpoint)
+
             epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
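
This hunk moves the resume handling to a single up-front skip that also offsets the step counter (gated on args.ignore_data_skip); the next hunk then removes the old per-batch skip-and-continue loop from the inner step loop. A minimal stand-alone sketch of the new counting pattern, using itertools.islice as a stand-in for the skip_first_batches helper; the batch values and step count are invented for illustration:

from itertools import islice

batches = list(range(6))             # stand-in dataloader with 6 batches
steps_trained_in_current_epoch = 2   # batches already consumed before the interruption

# Skip the trained batches once and start counting from the offset, so `step`
# matches what it would have been in an uninterrupted run.
step = steps_trained_in_current_epoch - 1
for batch in islice(iter(batches), steps_trained_in_current_epoch, None):
    step += 1
    print(f"training on batch {batch} at step {step}")  # prints steps 2 through 5

Because step already accounts for the skipped batches, the later epoch bookkeeping no longer needs a separate steps_skipped offset (see the last hunk below).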
@@ -2658,22 +2657,11 @@ def _inner_training_loop(

                     input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                     self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
+
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False

-                # Skip past any already trained steps if resuming training
-                if steps_trained_in_current_epoch > 0:
-                    steps_trained_in_current_epoch -= 1
-                    if steps_trained_progress_bar is not None:
-                        steps_trained_progress_bar.update(1)
-                    if steps_trained_in_current_epoch == 0:
-                        self._load_rng_state(resume_from_checkpoint)
-                    continue
-                elif steps_trained_progress_bar is not None:
-                    steps_trained_progress_bar.close()
-                    steps_trained_progress_bar = None
-
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

@@ -2765,7 +2753,7 @@ def _inner_training_loop(

                     model.zero_grad()
                     self.state.global_step += 1
-                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(
                         tr_loss,
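
A quick arithmetic check of the fractional-epoch change in this last hunk, using the numbers from the new test below (3 optimizer steps per epoch, 2 already trained before resuming). This is an illustration only, not code from the commit:

# Fractional epoch after the final optimizer step of epoch 0.
steps_in_epoch = 3
steps_trained = 2  # optimizer steps completed before the interruption

# Old bookkeeping: after skipping, `step` restarted at 0, so the skipped steps
# had to be added back via `steps_skipped`.
old_fraction = (0 + 1 + steps_trained) / steps_in_epoch

# New bookkeeping: `step` resumes at steps_trained - 1 and reaches 2 after the
# one remaining step, so no extra offset is needed.
resumed_step = (steps_trained - 1) + 1
new_fraction = (resumed_step + 1) / steps_in_epoch

assert old_fraction == new_fraction == 1.0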

tests/trainer/test_trainer.py

Lines changed: 109 additions & 0 deletions
@@ -5158,6 +5158,115 @@ def test_trainer_works_without_model_config(self):
         )
         trainer.train()

+    @require_safetensors
+    def test_resume_from_interrupted_training(self):
+        """
+        Tests resuming training from a checkpoint after a simulated interruption.
+        """
+
+        # --- Helper classes and functions defined locally for this test ---
+        class DummyModel(nn.Module):
+            def __init__(self, input_dim=10, num_labels=2):
+                super().__init__()
+                self.linear = nn.Linear(input_dim, num_labels)
+
+            def forward(self, input_ids=None, attention_mask=None, labels=None):
+                logits = self.linear(input_ids.float())
+                loss = None
+                if labels is not None:
+                    loss_fn = nn.CrossEntropyLoss()
+                    loss = loss_fn(logits, labels)
+                return {"loss": loss, "logits": logits}
+
+        class DummyDictDataset(torch.utils.data.Dataset):
+            def __init__(self, input_ids, attention_mask, labels):
+                self.input_ids = input_ids
+                self.attention_mask = attention_mask
+                self.labels = labels
+
+            def __len__(self):
+                return len(self.input_ids)
+
+            def __getitem__(self, idx):
+                return {
+                    "input_ids": self.input_ids[idx],
+                    "attention_mask": self.attention_mask[idx],
+                    "labels": self.labels[idx],
+                }
+
+        def create_dummy_dataset():
+            """Creates a dummy dataset for this specific test."""
+            num_samples = 13
+            input_dim = 10
+            dummy_input_ids = torch.rand(num_samples, input_dim)
+            dummy_attention_mask = torch.ones(num_samples, input_dim)
+            dummy_labels = torch.randint(0, 2, (num_samples,))
+            return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)
+
+        # 1. Set up a dummy model and dataset
+        model = DummyModel(input_dim=10, num_labels=2)
+        dummy_dataset = create_dummy_dataset()
+
+        # 2. First training phase (simulating an interruption)
+        output_dir_initial = self.get_auto_remove_tmp_dir()
+        training_args_initial = TrainingArguments(
+            output_dir=output_dir_initial,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,  # Save at every step
+            report_to=[],  # Disable wandb/tensorboard and other loggers
+            max_steps=2,  # Stop after step 2 to simulate interruption
+        )
+
+        trainer_initial = Trainer(
+            model=model,
+            args=training_args_initial,
+            train_dataset=dummy_dataset,
+        )
+        trainer_initial.train()
+
+        # 3. Verify that a checkpoint was created before the "interruption"
+        checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
+        self.assertTrue(os.path.exists(checkpoint_path), f"Checkpoint not found at {checkpoint_path}")
+
+        # 4. Second training phase (resuming from the checkpoint)
+        output_dir_resumed = self.get_auto_remove_tmp_dir()
+        # Note: total steps for one epoch is ceil(13 / (2*3)) = 3.
+        # We stopped at step 2, so the resumed training should run for 1 more step.
+        training_args_resumed = TrainingArguments(
+            output_dir=output_dir_resumed,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,
+            report_to=[],
+        )
+
+        trainer_resumed = Trainer(
+            model=model,
+            args=training_args_resumed,
+            train_dataset=dummy_dataset,
+        )
+        # Resume from the interrupted checkpoint and finish the remaining training
+        trainer_resumed.train(resume_from_checkpoint=checkpoint_path)
+
+        # 5. Assertions: Check if the training completed and the final model was saved
+        # The training should have completed step 3.
+        # Total steps per epoch = ceil(13 samples / (2 batch_size * 3 grad_accum)) = 3
+        self.assertEqual(trainer_resumed.state.global_step, 3)
+
+        # Check that a checkpoint for the final step exists.
+        final_checkpoint_path = os.path.join(output_dir_resumed, "checkpoint-3")
+        self.assertTrue(os.path.exists(final_checkpoint_path))
+
+        # Check if the model weights file exists in the final checkpoint directory.
+        # Trainer saves non-PreTrainedModel models as `model.safetensors` by default if safetensors is available.
+        final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
+        self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")
+

 @require_torch
 @is_staging_test
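
As the test's inline comments note, the data sizes are chosen so that exactly one optimizer step remains after the simulated interruption at global step 2. A standalone check of that arithmetic (illustrative only, not code from the commit):

import math

num_samples, batch_size, grad_accum = 13, 2, 3
steps_per_epoch = math.ceil(num_samples / (batch_size * grad_accum))  # ceil(13 / 6) = 3
assert steps_per_epoch == 3  # interrupted at step 2, so one step remains and checkpoint-3 is the final one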
