     backend_empty_cache,
     require_flash_attn,
     require_liger_kernel,
-    require_peft,
     require_torch_accelerator,
     torch_device,
 )
 from trl import GRPOConfig, GRPOTrainer
 from trl.trainer.utils import get_kbit_device_map

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
 from .testing_constants import MODELS_TO_TEST



 @pytest.mark.slow
 @require_torch_accelerator
-class GRPOTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestGRPOTrainerSlow(TrlTestCase):
+    def setup_method(self):
         self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
         self.max_length = 128

-    def tearDown(self):
+    def teardown_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
         gc.collect()
-        super().tearDown()

     @parameterized.expand(MODELS_TO_TEST)
     @require_liger_kernel
@@ -103,7 +100,7 @@ def test_training_with_liger_grpo_loss(self, model_name):

         for n, param in previous_trainable_params.items():
             new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -153,20 +150,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name):
         # Verify PEFT adapter is properly initialized
         from peft import PeftModel

-        self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
+        assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT"

         # Store adapter weights before training
         previous_trainable_params = {
             n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
         }
-        self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
+        assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model"

         trainer.train()

         # Verify adapter weights have changed after training
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -199,12 +196,12 @@ def test_training_with_transformers_paged(self, model_name):

         trainer.train()

-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None

         # Check that the params have changed
         for n, param in previous_trainable_params.items():
             new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -310,13 +307,13 @@ def reward_func(prompts, completions, **kwargs):
                 peft_config=lora_config,
             )

-            self.assertIsInstance(trainer.model, PeftModel)
+            assert isinstance(trainer.model, PeftModel)

             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

             trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

             # Check that LoRA parameters have changed
             # For VLM models, we're more permissive about which parameters can change
@@ -328,7 +325,7 @@ def reward_func(prompts, completions, **kwargs):
                     lora_params_changed = True

             # At least some LoRA parameters should have changed during training
-            self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.")
+            assert lora_params_changed, "No LoRA parameters were updated during training."

         except torch.OutOfMemoryError as e:
             self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
@@ -378,8 +375,8 @@ def test_vlm_processor_vllm_colocate_mode(self):
         processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left")

         # Verify processor has both required attributes for VLM detection
-        self.assertTrue(hasattr(processor, "tokenizer"))
-        self.assertTrue(hasattr(processor, "image_processor"))
+        assert hasattr(processor, "tokenizer")
+        assert hasattr(processor, "image_processor")

         def dummy_reward_func(completions, **kwargs):
             return [1.0] * len(completions)
@@ -438,16 +435,14 @@ def dummy_reward_func(completions, **kwargs):
             )

             # Should detect VLM processor correctly and allow vLLM
-            self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode")
-            self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode")
+            assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode"
+            assert trainer.vllm_mode == "colocate", "Should use colocate mode"

             # Check if signature columns were set properly
             if trainer._signature_columns is not None:
                 # Should include 'image' in signature columns for VLM processors
-                self.assertIn(
-                    "image",
-                    trainer._signature_columns,
-                    "Should include 'image' in signature columns for VLM",
+                assert "image" in trainer._signature_columns, (
+                    "Should include 'image' in signature columns for VLM"
                 )

             # Should not emit any warnings about VLM incompatibility
@@ -457,10 +452,8 @@ def dummy_reward_func(completions, **kwargs):
                 if "does not support VLMs" in str(w_item.message)
                 or "not compatible" in str(w_item.message).lower()
             ]
-            self.assertEqual(
-                len(incompatibility_warnings),
-                0,
-                f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}",
+            assert len(incompatibility_warnings) == 0, (
+                f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}"
             )

             # Test passes if we get this far without exceptions
@@ -525,12 +518,12 @@ def test_training_vllm(self):

             trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

             # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
-                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+                assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         except Exception as e:
             # If vLLM fails to initialize due to hardware constraints or other issues, that's expected
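
Taken together, the hunks above apply one consistent migration: the test class drops unittest-style conventions in favor of plain pytest ones, i.e. a `Test*`-prefixed class name, `setup_method`/`teardown_method` hooks instead of `setUp`/`tearDown` with `super()` calls, and bare `assert` statements with messages instead of `self.assertTrue`/`assertFalse`/`assertEqual`/`assertIsNotNone`/`assertIn`. A minimal, self-contained sketch of the target style follows, assuming only pytest and torch; the class and test names are illustrative and not part of TRL:

import gc

import pytest
import torch


@pytest.mark.slow
class TestMigrationSketch:  # hypothetical class, for illustration only
    def setup_method(self):
        # pytest hook that runs before each test, replacing unittest's setUp()
        self.param = torch.zeros(2)

    def teardown_method(self):
        # pytest hook that runs after each test, replacing unittest's tearDown()
        gc.collect()

    def test_param_changes(self):
        previous = self.param.clone()
        self.param += 1.0
        # plain asserts with a message replace self.assertFalse(...) and friends
        assert not torch.equal(previous, self.param), "Parameter has not changed."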