     backend_empty_cache,
     require_flash_attn,
     require_liger_kernel,
-    require_peft,
     require_torch_accelerator,
     torch_device,
 )
 from trl import GRPOConfig, GRPOTrainer
 from trl.trainer.utils import get_kbit_device_map

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
 from .testing_constants import MODELS_TO_TEST



 @pytest.mark.slow
 @require_torch_accelerator
-class GRPOTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestGRPOTrainerSlow(TrlTestCase):
+    def setup_method(self):
         self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
         self.max_length = 128

-    def tearDown(self):
+    def teardown_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
         gc.collect()
-        super().tearDown()

     @parameterized.expand(MODELS_TO_TEST)
     @require_liger_kernel
@@ -103,7 +100,7 @@ def test_training_with_liger_grpo_loss(self, model_name):

         for n, param in previous_trainable_params.items():
             new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -153,20 +150,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name):
         # Verify PEFT adapter is properly initialized
         from peft import PeftModel

-        self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
+        assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT"

         # Store adapter weights before training
         previous_trainable_params = {
             n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
         }
-        self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
+        assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model"

         trainer.train()

         # Verify adapter weights have changed after training
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -199,12 +196,12 @@ def test_training_with_transformers_paged(self, model_name):

         trainer.train()

-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None

         # Check that the params have changed
         for n, param in previous_trainable_params.items():
             new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -310,13 +307,13 @@ def reward_func(prompts, completions, **kwargs):
                 peft_config=lora_config,
             )

-            self.assertIsInstance(trainer.model, PeftModel)
+            assert isinstance(trainer.model, PeftModel)

             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

             trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

             # Check that LoRA parameters have changed
             # For VLM models, we're more permissive about which parameters can change
@@ -328,7 +325,7 @@ def reward_func(prompts, completions, **kwargs):
                         lora_params_changed = True

             # At least some LoRA parameters should have changed during training
-            self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.")
+            assert lora_params_changed, "No LoRA parameters were updated during training."

         except torch.OutOfMemoryError as e:
             self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
@@ -378,8 +375,8 @@ def test_vlm_processor_vllm_colocate_mode(self):
         processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left")

         # Verify processor has both required attributes for VLM detection
-        self.assertTrue(hasattr(processor, "tokenizer"))
-        self.assertTrue(hasattr(processor, "image_processor"))
+        assert hasattr(processor, "tokenizer")
+        assert hasattr(processor, "image_processor")

         def dummy_reward_func(completions, **kwargs):
             return [1.0] * len(completions)
@@ -438,16 +435,14 @@ def dummy_reward_func(completions, **kwargs):
                     )

                     # Should detect VLM processor correctly and allow vLLM
-                    self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode")
-                    self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode")
+                    assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode"
+                    assert trainer.vllm_mode == "colocate", "Should use colocate mode"

                     # Check if signature columns were set properly
                     if trainer._signature_columns is not None:
                         # Should include 'image' in signature columns for VLM processors
-                        self.assertIn(
-                            "image",
-                            trainer._signature_columns,
-                            "Should include 'image' in signature columns for VLM",
+                        assert "image" in trainer._signature_columns, (
+                            "Should include 'image' in signature columns for VLM"
                         )

                     # Should not emit any warnings about VLM incompatibility
@@ -457,10 +452,8 @@ def dummy_reward_func(completions, **kwargs):
                         if "does not support VLMs" in str(w_item.message)
                         or "not compatible" in str(w_item.message).lower()
                     ]
-                    self.assertEqual(
-                        len(incompatibility_warnings),
-                        0,
-                        f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}",
+                    assert len(incompatibility_warnings) == 0, (
+                        f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}"
                     )

                     # Test passes if we get this far without exceptions
@@ -525,12 +518,12 @@ def test_training_vllm(self):

             trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

             # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
-                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+                assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         except Exception as e:
             # If vLLM fails to initialize due to hardware constraints or other issues, that's expected