@@ -1,14 +1,15 @@
 import os
-import shutil
+
 import torch
-from torch.utils.data import TensorDataset, Dataset
+import torch.nn as nn
+from torch.utils.data import Dataset
+
 from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
     Trainer,
     TrainingArguments,
 )
-import torch.nn as nn
+
+from transformers.testing_utils import TestCasePlus


 class DummyModel(nn.Module):
@@ -24,6 +25,7 @@ def forward(self, input_ids=None, attention_mask=None, labels=None):
         loss = loss_fn(logits, labels)
         return {"loss": loss, "logits": logits}

+
 class DummyDictDataset(Dataset):
     def __init__(self, input_ids, attention_mask, labels):
         self.input_ids = input_ids
@@ -40,6 +42,7 @@ def __getitem__(self, idx):
4042 "labels" : self .labels [idx ],
4143 }
4244
45+
4346def create_dummy_dataset ():
4447 """Creates a dummy dataset for testing."""
4548 num_samples = 13
@@ -49,72 +52,71 @@ def create_dummy_dataset():
     dummy_labels = torch.randint(0, 2, (num_samples,))
     return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)

-def test_resume_with_original_trainer():
-    """Tests the original transformers Trainer."""
-    print("Testing the original transformers Trainer...")
-
-    # 1. Set up a dummy model
-    model = DummyModel(input_dim=10, num_labels=2)
-    dummy_dataset = create_dummy_dataset()
-
-    # 3. First training (simulate interruption)
-    output_dir_initial = "./test_original_trainer_initial"
-    training_args_initial = TrainingArguments(
-        output_dir=output_dir_initial,
-        num_train_epochs=1,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=3,
-        save_strategy="steps",
-        save_steps=1,  # Save at every step
-        report_to=[],  # Disable wandb/tensorboard and other loggers
-        max_steps=2,  # Stop after step 2 to simulate interruption
-    )
-
-    trainer_initial = Trainer(
-        model=model,
-        args=training_args_initial,
-        train_dataset=dummy_dataset,
-    )
-    trainer_initial.train()
-
-    # Make sure we have a checkpoint before interruption
-    checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
-    assert os.path.exists(checkpoint_path)
-
-    print("Second phase")
-    # 4. Resume training from checkpoint
-    output_dir_resumed = "./test_original_trainer_resumed"
-    training_args_resumed = TrainingArguments(
-        output_dir=output_dir_resumed,
-        num_train_epochs=1,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=3,
-        save_strategy="steps",
-        save_steps=1,  # Keep the same save strategy
-    )
-
-    trainer_resumed = Trainer(
-        model=model,
-        args=training_args_resumed,
-        train_dataset=dummy_dataset,
-    )
-    # Resume from the interrupted checkpoint and finish the remaining training
-    trainer_resumed.train(resume_from_checkpoint=checkpoint_path)
-
-    # 5. Assertion: Check if the final model has been saved
-    final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors")
-    try:
-        assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!"
-        print("✓ Original Trainer: Final model has been saved.")
-    except AssertionError as e:
-        print(f"✗ Original Trainer: {e}")
-
-
-    # Clean up test directories
-    shutil.rmtree(output_dir_initial)
-    shutil.rmtree(output_dir_resumed)
+
+class TestTrainerResume(TestCasePlus):
+    def test_resume_with_original_trainer(self):
+        """Tests the original transformers Trainer."""
+        print("Testing the original transformers Trainer...")
+
+        # 1. Set up a dummy model
+        model = DummyModel(input_dim=10, num_labels=2)
+        dummy_dataset = create_dummy_dataset()
+
+        # 3. First training (simulate interruption)
+        output_dir_initial = self.get_auto_remove_tmp_dir()
+        training_args_initial = TrainingArguments(
+            output_dir=output_dir_initial,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,  # Save at every step
+            report_to=[],  # Disable wandb/tensorboard and other loggers
+            max_steps=2,  # Stop after step 2 to simulate interruption
+        )
+
+        trainer_initial = Trainer(
+            model=model,
+            args=training_args_initial,
+            train_dataset=dummy_dataset,
+        )
+        trainer_initial.train()
+
+        # Make sure we have a checkpoint before interruption
+        checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
+        assert os.path.exists(checkpoint_path)
+
+        print("Second phase")
+        # 4. Resume training from checkpoint
+        output_dir_resumed = self.get_auto_remove_tmp_dir()
+        training_args_resumed = TrainingArguments(
+            output_dir=output_dir_resumed,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,  # Keep the same save strategy
+        )
+
+        trainer_resumed = Trainer(
+            model=model,
+            args=training_args_resumed,
+            train_dataset=dummy_dataset,
+        )
+        # Resume from the interrupted checkpoint and finish the remaining training
+        trainer_resumed.train(resume_from_checkpoint=checkpoint_path)
+
+        # 5. Assertion: Check if the final model has been saved
+        final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors")
+        try:
+            assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!"
+            print("✓ Original Trainer: Final model has been saved.")
+        except AssertionError as e:
+            print(f"✗ Original Trainer: {e}")


 # Run all tests
 if __name__ == "__main__":
-    test_resume_with_original_trainer()
+    import unittest
+
+    unittest.main()
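
For readers wondering where "checkpoint-2" and "checkpoint-3" come from: the diff does not spell this out, but it follows from the Trainer's step accounting under the hyperparameters in the test. A quick sketch of that arithmetic, assuming the usual ceiling-based rounding where a trailing partial gradient-accumulation window still triggers an optimizer update:

import math

# Hyperparameters taken from the test above.
num_samples = 13
per_device_train_batch_size = 2
gradient_accumulation_steps = 3

# 13 samples at batch size 2 -> 7 batches per epoch (last batch is partial).
batches_per_epoch = math.ceil(num_samples / per_device_train_batch_size)  # 7

# 7 batches accumulated 3 at a time -> 3 optimizer steps per epoch
# (the trailing partial accumulation still produces an update).
steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 3

# First run: max_steps=2 with save_steps=1 -> last checkpoint is checkpoint-2.
# Resumed run: finishes the epoch's remaining step -> checkpoint-3.
print(steps_per_epoch)  # 3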
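The other substantive change is dropping the manual shutil.rmtree cleanup in favor of TestCasePlus.get_auto_remove_tmp_dir() from transformers.testing_utils, which hands out temporary directories that are removed automatically when the test finishes. As a rough illustration of that pattern (a standard-library sketch, not the actual transformers implementation):

import shutil
import tempfile
import unittest


class AutoTmpDirTestCase(unittest.TestCase):
    """Minimal sketch of the auto-removed tmp dir pattern (illustrative only)."""

    def get_auto_remove_tmp_dir(self):
        # Create a fresh temporary directory and register its removal as a
        # unittest cleanup, so tests need no manual shutil.rmtree calls.
        tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        return tmp_dir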