
Commit 37ff988

make style && slight fix of test
1 parent ea0ad02 commit 37ff988

File tree

2 files changed (+73, -75 lines)


src/transformers/trainer.py

Lines changed: 1 addition & 5 deletions
@@ -2523,7 +2523,6 @@ def _inner_training_loop(
         start_time = time.time()
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
-        steps_trained_progress_bar = None

         # Check if continuing training from a checkpoint
         if resume_from_checkpoint is not None and os.path.isfile(
@@ -2596,7 +2595,6 @@ def _inner_training_loop(
             elif steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)

-
             epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
@@ -2631,13 +2629,11 @@ def _inner_training_loop(
                         input_tokens = inputs[main_input_name].numel()
                         input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                         self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
-
+
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False

-
-
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

Lines changed: 72 additions & 70 deletions
@@ -1,14 +1,15 @@
 import os
-import shutil
+
 import torch
-from torch.utils.data import TensorDataset, Dataset
+import torch.nn as nn
+from torch.utils.data import Dataset
+
 from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
     Trainer,
     TrainingArguments,
 )
-import torch.nn as nn
+
+from transformers.testing_utils import TestCasePlus


 class DummyModel(nn.Module):
@@ -24,6 +25,7 @@ def forward(self, input_ids=None, attention_mask=None, labels=None):
             loss = loss_fn(logits, labels)
         return {"loss": loss, "logits": logits}

+
 class DummyDictDataset(Dataset):
     def __init__(self, input_ids, attention_mask, labels):
         self.input_ids = input_ids
@@ -40,6 +42,7 @@ def __getitem__(self, idx):
             "labels": self.labels[idx],
         }

+
 def create_dummy_dataset():
     """Creates a dummy dataset for testing."""
     num_samples = 13
@@ -49,72 +52,71 @@ def create_dummy_dataset():
     dummy_labels = torch.randint(0, 2, (num_samples,))
     return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)

-def test_resume_with_original_trainer():
-    """Tests the original transformers Trainer."""
-    print("Testing the original transformers Trainer...")
-
-    # 1. Set up a dummy model
-    model = DummyModel(input_dim=10, num_labels=2)
-    dummy_dataset = create_dummy_dataset()
-
-    # 3. First training (simulate interruption)
-    output_dir_initial = "./test_original_trainer_initial"
-    training_args_initial = TrainingArguments(
-        output_dir=output_dir_initial,
-        num_train_epochs=1,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=3,
-        save_strategy="steps",
-        save_steps=1,  # Save at every step
-        report_to=[],  # Disable wandb/tensorboard and other loggers
-        max_steps=2,  # Stop after step 2 to simulate interruption
-    )
-
-    trainer_initial = Trainer(
-        model=model,
-        args=training_args_initial,
-        train_dataset=dummy_dataset,
-    )
-    trainer_initial.train()
-
-    # Make sure we have a checkpoint before interruption
-    checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
-    assert os.path.exists(checkpoint_path)
-
-    print("Second phase")
-    # 4. Resume training from checkpoint
-    output_dir_resumed = "./test_original_trainer_resumed"
-    training_args_resumed = TrainingArguments(
-        output_dir=output_dir_resumed,
-        num_train_epochs=1,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=3,
-        save_strategy="steps",
-        save_steps=1,  # Keep the same save strategy
-    )
-
-    trainer_resumed = Trainer(
-        model=model,
-        args=training_args_resumed,
-        train_dataset=dummy_dataset,
-    )
-    # Resume from the interrupted checkpoint and finish the remaining training
-    trainer_resumed.train(resume_from_checkpoint=checkpoint_path)
-
-    # 5. Assertion: Check if the final model has been saved
-    final_model_path = os.path.join(output_dir_resumed, 'checkpoint-3', "model.safetensors")
-    try:
-        assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!"
-        print("✓ Original Trainer: Final model has been saved.")
-    except AssertionError as e:
-        print(f"✗ Original Trainer: {e}")
-
-
-    # Clean up test directories
-    shutil.rmtree(output_dir_initial)
-    shutil.rmtree(output_dir_resumed)
+
+class TestTrainerResume(TestCasePlus):
+    def test_resume_with_original_trainer(self):
+        """Tests the original transformers Trainer."""
+        print("Testing the original transformers Trainer...")
+
+        # 1. Set up a dummy model
+        model = DummyModel(input_dim=10, num_labels=2)
+        dummy_dataset = create_dummy_dataset()
+
+        # 3. First training (simulate interruption)
+        output_dir_initial = self.get_auto_remove_tmp_dir()
+        training_args_initial = TrainingArguments(
+            output_dir=output_dir_initial,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,  # Save at every step
+            report_to=[],  # Disable wandb/tensorboard and other loggers
+            max_steps=2,  # Stop after step 2 to simulate interruption
+        )
+
+        trainer_initial = Trainer(
+            model=model,
+            args=training_args_initial,
+            train_dataset=dummy_dataset,
+        )
+        trainer_initial.train()
+
+        # Make sure we have a checkpoint before interruption
+        checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
+        assert os.path.exists(checkpoint_path)
+
+        print("Second phase")
+        # 4. Resume training from checkpoint
+        output_dir_resumed = self.get_auto_remove_tmp_dir()
+        training_args_resumed = TrainingArguments(
+            output_dir=output_dir_resumed,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,  # Keep the same save strategy
+        )
+
+        trainer_resumed = Trainer(
+            model=model,
+            args=training_args_resumed,
+            train_dataset=dummy_dataset,
+        )
+        # Resume from the interrupted checkpoint and finish the remaining training
+        trainer_resumed.train(resume_from_checkpoint=checkpoint_path)
+
+        # 5. Assertion: Check if the final model has been saved
+        final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors")
+        try:
+            assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!"
+            print("✓ Original Trainer: Final model has been saved.")
+        except AssertionError as e:
+            print(f"✗ Original Trainer: {e}")


 # Run all tests
 if __name__ == "__main__":
-    test_resume_with_original_trainer()
+    import unittest
+
+    unittest.main()
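
The test refactor wraps the standalone script in a TestCasePlus subclass and swaps the hard-coded ./test_original_trainer_* output directories plus manual shutil.rmtree cleanup for self.get_auto_remove_tmp_dir(), which returns a temporary directory that is cleaned up automatically after the test. A rough standard-library sketch of that pattern, offered as an illustration rather than the transformers.testing_utils implementation:

import shutil
import tempfile
import unittest


class TmpDirTestCase(unittest.TestCase):
    """Sketch of an auto-removed temp-dir helper, similar in spirit to TestCasePlus."""

    def get_auto_remove_tmp_dir(self):
        tmp_dir = tempfile.mkdtemp()
        # addCleanup runs after the test finishes (pass or fail), so the test body
        # no longer needs explicit shutil.rmtree calls.
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        return tmp_dir

Because the test is now a regular unittest case, it can also be collected by pytest or unittest discovery; the __main__ block simply defers to unittest.main().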

0 commit comments