@@ -5158,6 +5158,115 @@ def test_trainer_works_without_model_config(self):
         )
         trainer.train()
 
+    @require_safetensors
+    def test_resume_from_interrupted_training(self):
+        """
+        Tests resuming training from a checkpoint after a simulated interruption.
+        """
+
+        # --- Helper classes and functions defined locally for this test ---
+        class DummyModel(nn.Module):
+            def __init__(self, input_dim=10, num_labels=2):
+                super().__init__()
+                self.linear = nn.Linear(input_dim, num_labels)
+
+            def forward(self, input_ids=None, attention_mask=None, labels=None):
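+                # Trainer can consume a plain dict output: its loss computation reads
+                # outputs["loss"] when labels are provided. Here `input_ids` holds dense
+                # float features rather than token indices.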
+                logits = self.linear(input_ids.float())
+                loss = None
+                if labels is not None:
+                    loss_fn = nn.CrossEntropyLoss()
+                    loss = loss_fn(logits, labels)
+                return {"loss": loss, "logits": logits}
+
+        class DummyDictDataset(torch.utils.data.Dataset):
+            def __init__(self, input_ids, attention_mask, labels):
+                self.input_ids = input_ids
+                self.attention_mask = attention_mask
+                self.labels = labels
+
+            def __len__(self):
+                return len(self.input_ids)
+
+            def __getitem__(self, idx):
+                return {
+                    "input_ids": self.input_ids[idx],
+                    "attention_mask": self.attention_mask[idx],
+                    "labels": self.labels[idx],
+                }
+
+        def create_dummy_dataset():
+            """Creates a dummy dataset for this specific test."""
+            num_samples = 13
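+            # 13 is intentionally not a multiple of the effective batch size
+            # (2 per device * 3 accumulation steps = 6), so one epoch takes
+            # ceil(13 / 6) = 3 optimizer steps and the last step accumulates
+            # fewer batches.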
+            input_dim = 10
+            dummy_input_ids = torch.rand(num_samples, input_dim)
+            dummy_attention_mask = torch.ones(num_samples, input_dim)
+            dummy_labels = torch.randint(0, 2, (num_samples,))
+            return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)
+
+        # 1. Set up a dummy model and dataset
+        model = DummyModel(input_dim=10, num_labels=2)
+        dummy_dataset = create_dummy_dataset()
+
+        # 2. First training phase (simulating an interruption)
+        output_dir_initial = self.get_auto_remove_tmp_dir()
+        training_args_initial = TrainingArguments(
+            output_dir=output_dir_initial,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,  # Save at every step
+            report_to=[],  # Disable wandb/tensorboard and other loggers
+            max_steps=2,  # Stop after step 2 to simulate interruption
+        )
+
+        trainer_initial = Trainer(
+            model=model,
+            args=training_args_initial,
+            train_dataset=dummy_dataset,
+        )
+        trainer_initial.train()
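+        # With save_steps=1 and max_steps=2, checkpoint-1 and checkpoint-2 are
+        # written to output_dir_initial before the run stops mid-epoch.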
+
+        # 3. Verify that a checkpoint was created before the "interruption"
+        checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
+        self.assertTrue(os.path.exists(checkpoint_path), f"Checkpoint not found at {checkpoint_path}")
+
+        # 4. Second training phase (resuming from the checkpoint)
+        output_dir_resumed = self.get_auto_remove_tmp_dir()
+        # Note: total steps for one epoch is ceil(13 / (2 * 3)) = 3.
+        # We stopped at step 2, so the resumed training should run for 1 more step.
+        training_args_resumed = TrainingArguments(
+            output_dir=output_dir_resumed,
+            num_train_epochs=1,
+            per_device_train_batch_size=2,
+            gradient_accumulation_steps=3,
+            save_strategy="steps",
+            save_steps=1,
+            report_to=[],
+        )
+
+        trainer_resumed = Trainer(
+            model=model,
+            args=training_args_resumed,
+            train_dataset=dummy_dataset,
+        )
+        # Resume from the interrupted checkpoint and finish the remaining training
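+        # Resuming restores global_step, optimizer/scheduler state, and RNG states
+        # from the checkpoint, and by default skips the batches already consumed
+        # in the interrupted epoch.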
+        trainer_resumed.train(resume_from_checkpoint=checkpoint_path)
+
+        # 5. Assertions: Check if the training completed and the final model was saved
+        # The training should have completed step 3.
+        # Total steps per epoch = ceil(13 samples / (2 batch_size * 3 grad_accum)) = 3
+        self.assertEqual(trainer_resumed.state.global_step, 3)
+
+        # Check that a checkpoint for the final step exists.
+        final_checkpoint_path = os.path.join(output_dir_resumed, "checkpoint-3")
+        self.assertTrue(os.path.exists(final_checkpoint_path))
+
+        # Check if the model weights file exists in the final checkpoint directory.
+        # Trainer saves non-PreTrainedModel models as `model.safetensors` by default if safetensors is available.
+        final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
+        self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")
+
 
 @require_torch
 @is_staging_test