From e2d1e61316ced2066aeb8b95fd09809eba40f495 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:15:27 +0200 Subject: [PATCH 01/16] Set CI for debugging --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4231ef227ec..4c21a98f4f9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,7 +21,7 @@ jobs: check_code_quality: name: Check code quality runs-on: ubuntu-latest - if: github.event.pull_request.draft == false + # if: github.event.pull_request.draft == false steps: - uses: actions/checkout@v4 - name: Set up Python 3.12 @@ -36,7 +36,7 @@ jobs: name: Tests strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + python-version: ['3.10'] fail-fast: false runs-on: group: aws-g4dn-2xlarge @@ -46,7 +46,7 @@ jobs: defaults: run: shell: bash - if: github.event.pull_request.draft == false + # if: github.event.pull_request.draft == false steps: - name: Git checkout uses: actions/checkout@v4 From 40666811f4b9f1ba13c28fc80787945016a35857 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:17:54 +0200 Subject: [PATCH 02/16] Convert syntax from unittest to pytest --- tests/slow/test_dpo_slow.py | 8 +- tests/slow/test_grpo_slow.py | 46 ++- tests/slow/test_sft_slow.py | 12 +- tests/test_activation_offloading.py | 12 +- tests/test_bco_trainer.py | 62 ++-- tests/test_best_of_n_sampler.py | 6 +- tests/test_callbacks.py | 28 +- tests/test_cli.py | 4 +- tests/test_cli_utils.py | 147 ++++----- tests/test_collators.py | 2 +- tests/test_core.py | 6 +- tests/test_cpo_trainer.py | 14 +- tests/test_data_utils.py | 204 ++++++------- tests/test_dataset_formatting.py | 74 +++-- tests/test_dpo_trainer.py | 142 ++++----- tests/test_gkd_trainer.py | 86 +++--- tests/test_grpo_trainer.py | 246 ++++++++------- tests/test_judges.py | 22 +- tests/test_kto_trainer.py | 78 +++-- ...test_modeling_geometric_mixture_wrapper.py | 14 +- tests/test_modeling_value_head.py | 104 +++---- tests/test_nash_md_trainer.py | 12 +- tests/test_online_dpo_trainer.py | 110 +++---- tests/test_orpo_trainer.py | 10 +- tests/test_peft_models.py | 54 ++-- tests/test_ppo_trainer.py | 8 +- tests/test_prm_trainer.py | 52 ++-- tests/test_reward_trainer.py | 112 ++++--- tests/test_rewards.py | 12 +- tests/test_rloo_trainer.py | 157 +++++----- tests/test_sft_trainer.py | 288 +++++++++--------- tests/test_trainers_args.py | 186 +++++------ tests/test_utils.py | 283 +++++++++-------- tests/test_vllm_client_server.py | 70 +++-- tests/test_xpo_trainer.py | 12 +- 35 files changed, 1276 insertions(+), 1407 deletions(-) diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index 3b76fd8ea07..e24362fbc88 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -151,8 +151,8 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_ peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) - self.assertIsNone(trainer.ref_model) + assert isinstance(trainer.model, PeftModel) + assert trainer.ref_model is None # train the model trainer.train() @@ -215,8 +215,8 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) - 
self.assertIsNone(trainer.ref_model) + assert isinstance(trainer.model, PeftModel) + assert trainer.ref_model is None # train the model trainer.train() diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 75dd1dc8d9e..5f4400b9115 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -103,7 +103,7 @@ def test_training_with_liger_grpo_loss(self, model_name): for n, param in previous_trainable_params.items(): new_param = model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." release_memory(model, trainer) @@ -153,20 +153,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name): # Verify PEFT adapter is properly initialized from peft import PeftModel - self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT") + assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT" # Store adapter weights before training previous_trainable_params = { n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad } - self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model") + assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model" trainer.train() # Verify adapter weights have changed after training for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." release_memory(model, trainer) @@ -199,12 +199,12 @@ def test_training_with_transformers_paged(self, model_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." release_memory(model, trainer) @@ -310,13 +310,13 @@ def reward_func(prompts, completions, **kwargs): peft_config=lora_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that LoRA parameters have changed # For VLM models, we're more permissive about which parameters can change @@ -328,7 +328,7 @@ def reward_func(prompts, completions, **kwargs): lora_params_changed = True # At least some LoRA parameters should have changed during training - self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.") + assert lora_params_changed, "No LoRA parameters were updated during training." 
except torch.OutOfMemoryError as e: self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}") @@ -378,8 +378,8 @@ def test_vlm_processor_vllm_colocate_mode(self): processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left") # Verify processor has both required attributes for VLM detection - self.assertTrue(hasattr(processor, "tokenizer")) - self.assertTrue(hasattr(processor, "image_processor")) + assert hasattr(processor, "tokenizer") + assert hasattr(processor, "image_processor") def dummy_reward_func(completions, **kwargs): return [1.0] * len(completions) @@ -438,17 +438,15 @@ def dummy_reward_func(completions, **kwargs): ) # Should detect VLM processor correctly and allow vLLM - self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode") - self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode") + assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode" + assert trainer.vllm_mode == "colocate", "Should use colocate mode" # Check if signature columns were set properly if trainer._signature_columns is not None: # Should include 'image' in signature columns for VLM processors - self.assertIn( - "image", - trainer._signature_columns, - "Should include 'image' in signature columns for VLM", - ) + assert "image" in \ + trainer._signature_columns, \ + "Should include 'image' in signature columns for VLM" # Should not emit any warnings about VLM incompatibility incompatibility_warnings = [ @@ -457,11 +455,9 @@ def dummy_reward_func(completions, **kwargs): if "does not support VLMs" in str(w_item.message) or "not compatible" in str(w_item.message).lower() ] - self.assertEqual( - len(incompatibility_warnings), - 0, - f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}", - ) + assert len(incompatibility_warnings) == \ + 0, \ + f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}" # Test passes if we get this far without exceptions @@ -525,12 +521,12 @@ def test_training_vllm(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
except Exception as e: # If vLLM fails to initialize due to hardware constraints or other issues, that's expected diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index db762df107d..7e673a9457c 100755 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -148,7 +148,7 @@ def test_sft_trainer_peft(self, model_name, packing): peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -252,7 +252,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -332,7 +332,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -372,7 +372,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing): peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -447,11 +447,11 @@ def test_train_offloading(self, model_name, packing): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" release_memory(trainer.model, trainer) diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index a80563005f5..b1e8f59d61d 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -72,10 +72,8 @@ def test_offloading_with_peft_models(self) -> None: for name_orig, grad_orig in grads_original: for name_param, param in model.named_parameters(): if name_param == name_orig and param.requires_grad and param.grad is not None: - self.assertTrue( - torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), - f"Gradient mismatch for {name_orig}", - ) + assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), \ + f"Gradient mismatch for {name_orig}" @require_torch_accelerator def test_noop_manager_with_offloading(self): @@ -105,7 +103,7 @@ def test_noop_manager_with_offloading(self): # Gradients should match as NoOpManager should have prevented offloading for g1, g2 in zip(grads1, grads2): - self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)) + assert torch.allclose(g1, g2, rtol=1e-4, atol=1e-5) @require_torch_accelerator def test_min_offload_size(self): @@ -152,6 +150,6 @@ def test_real_hf_model(self): grads2 = [p.grad.clone() for p in model.parameters()] # Check outputs and gradients match - self.assertTrue(torch.allclose(out1, out2, rtol=1e-5)) + assert torch.allclose(out1, out2, rtol=1e-5) for g1, g2 in zip(grads1, grads2): - self.assertTrue(torch.allclose(g1, g2, rtol=1e-5)) + assert torch.allclose(g1, g2, rtol=1e-5) diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index d609c3d5b90..b1fdaf8d8cf 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -26,6 +26,7 @@ from trl.trainer.bco_trainer import 
_process_tokens, _tokenize from .testing_utils import TrlTestCase, require_no_wandb, require_sklearn +import pytest if is_peft_available(): @@ -71,13 +72,13 @@ def test_train(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn def test_train_with_precompute(self): @@ -108,13 +109,13 @@ def test_train_with_precompute(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn def test_train_eval(self): @@ -158,7 +159,7 @@ def test_init_with_ref_model_is_model(self): report_to="none", ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): BCOTrainer( model=model, ref_model=model, # ref_model can't be the same as model @@ -196,13 +197,13 @@ def test_tokenize_and_process_tokens(self): batched=True, batch_size=2, ) - self.assertListEqual(tokenized_dataset["prompt"][:], dataset["prompt"][:]) - self.assertListEqual(tokenized_dataset["completion"][:], dataset["completion"][:]) - self.assertListEqual(tokenized_dataset["label"][:], dataset["label"][:]) - self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) - self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1]) + assert tokenized_dataset["prompt"][:] == dataset["prompt"][:] + assert tokenized_dataset["completion"][:] == dataset["completion"][:] + assert tokenized_dataset["label"][:] == dataset["label"][:] + assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert tokenized_dataset["answer_input_ids"][0] == [27261, 13] + assert tokenized_dataset["answer_attention_mask"][0] == [1, 1] fn_kwargs = { "prefix": "", @@ -214,14 +215,14 @@ def test_tokenize_and_process_tokens(self): "max_prompt_length": trainer.max_prompt_length, } processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs) - self.assertListEqual(processed_dataset["prompt"][:], dataset["prompt"][:]) - self.assertListEqual(processed_dataset["completion"][:], dataset["completion"][:]) - self.assertListEqual(processed_dataset["label"][:], dataset["label"][:]) - self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual(processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]) - self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1]) - self.assertListEqual(processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 
27261, 13, 151645]) + assert processed_dataset["prompt"][:] == dataset["prompt"][:] + assert processed_dataset["completion"][:] == dataset["completion"][:] + assert processed_dataset["label"][:] == dataset["label"][:] + assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645] + assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1] + assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645] @require_sklearn def test_train_without_providing_ref_model(self): @@ -249,13 +250,13 @@ def test_train_without_providing_ref_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn def test_train_udm(self): @@ -298,13 +299,13 @@ def embed_prompt(input_ids, attention_mask, model): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn @require_peft @@ -335,14 +336,14 @@ def test_train_without_providing_ref_model_with_lora(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn @require_no_wandb @@ -362,11 +363,8 @@ def test_generate_during_eval_no_wandb(self): report_to="none", ) - with self.assertRaisesRegex( - ValueError, - expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve.", - ): + with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." 
+ " Please install `wandb` or `comet-ml` to resolve."): BCOTrainer( model=model, args=training_args, @@ -440,4 +438,4 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py index 471f75c0c7c..cf6810976ec 100644 --- a/tests/test_best_of_n_sampler.py +++ b/tests/test_best_of_n_sampler.py @@ -74,8 +74,8 @@ def test_different_input_types(self): for q, expected_length in various_queries_formats: results = best_of_n.generate(q) - self.assertIsInstance(results, list) - self.assertEqual(len(results), expected_length) + assert isinstance(results, list) + assert len(results) == expected_length def test_different_sample_sizes_and_n_candidates_values(self): r""" @@ -110,4 +110,4 @@ def test_different_sample_sizes_and_n_candidates_values(self): tokenized_queries = [self.tokenizer.encode(query) for query in queries] results = best_of_n.generate(tokenized_queries) for result in results: - self.assertEqual(len(result), expected) + assert len(result) == expected diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 7904a4ae374..501641b5df9 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -119,7 +119,7 @@ def test_basic(self): trainer.train() winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] for history_row, expected_row in zip(winrate_history, self.expected_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) def test_without_ref_model(self): # Same as before, but without the ref_model attribute. 
It should use the model attribute instead @@ -145,7 +145,7 @@ def test_without_ref_model(self): trainer.train() winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] for history_row, expected_row in zip(winrate_history, self.expected_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) def test_soft_judge(self): """Test that the soft judge functionality works correctly""" @@ -188,7 +188,7 @@ def test_soft_judge(self): if "eval_avg_win_prob" in h ] for history_row, expected_row in zip(winrate_history, expected_soft_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) @require_peft def test_lora(self): @@ -222,7 +222,7 @@ def test_lora(self): trainer.train() winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] for history_row, expected_row in zip(winrate_history, self.expected_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) class LogCompletionsCallbackTester(TrlTestCase): @@ -273,12 +273,12 @@ def test_basic_wandb(self): completions = json.load(f) # Check that the columns are correct - self.assertIn("step", completions["columns"]) - self.assertIn("prompt", completions["columns"]) - self.assertIn("completion", completions["columns"]) + assert "step" in completions["columns"] + assert "prompt" in completions["columns"] + assert "completion" in completions["columns"] # Check that the prompt is in the log - self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0]) + assert self.dataset["test"][0]["prompt"] in completions["data"][0] @require_comet def test_basic_comet(self): @@ -347,7 +347,7 @@ def test_callback(self): trainer.train() last_checkpoint = get_last_checkpoint(self.tmp_dir) merged_path = os.path.join(last_checkpoint, "merged") - self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.") + assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint." def test_every_checkpoint(self): training_args = DPOConfig( @@ -374,7 +374,7 @@ def test_every_checkpoint(self): for checkpoint in checkpoints: merged_path = os.path.join(checkpoint, "merged") - self.assertTrue(os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}.") + assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}." class BEMACallbackTester(TrlTestCase): @@ -409,7 +409,7 @@ def test_model_saved(self): # Check that the BEMA model was saved and can be loaded bema_path = os.path.join(self.tmp_dir, "bema") - self.assertTrue(os.path.isdir(bema_path), "BEMA directory was not created") + assert os.path.isdir(bema_path), "BEMA directory was not created" AutoModelForCausalLM.from_pretrained(bema_path) def test_update_frequency_0(self): @@ -430,7 +430,7 @@ def test_update_frequency_0(self): # Total 9 steps (17 samples, batch size 8, 3 epochs). 
# BEMA starts after step 0 and updates every 2 steps → updates at 2, 4, 5, 8 - self.assertEqual(mock_update.call_args_list, [call(2), call(4), call(6), call(8)]) + assert mock_update.call_args_list == [call(2), call(4), call(6), call(8)] def test_update_frequency_1(self): """Test that BEMA callback respects the update frequency.""" @@ -450,7 +450,7 @@ def test_update_frequency_1(self): # Total 9 steps (17 samples, batch size 8, 3 epochs). # BEMA starts after step 0 and updates every 3 steps → updates at 3, 6, 9 - self.assertEqual(mock_update.call_args_list, [call(3), call(6), call(9)]) + assert mock_update.call_args_list == [call(3), call(6), call(9)] def test_update_frequency_2(self): """Test that BEMA callback respects the update frequency.""" @@ -470,7 +470,7 @@ def test_update_frequency_2(self): # Total 9 steps (17 samples, batch size 8, 3 epochs). # BEMA starts after step 3 and updates every 2 steps → updates at 5, 7, 9 - self.assertEqual(mock_update.call_args_list, [call(5), call(7), call(9)]) + assert mock_update.call_args_list == [call(5), call(7), call(9)] def test_no_bema(self): """Test that BEMACallback works without BEMA updates.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 23b5d6bcff7..638e1d38493 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -51,7 +51,7 @@ def test_env(self, mock_stdout): command = "trl env" with patch("sys.argv", command.split(" ")): main() - self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) + assert "TRL version: " in mock_stdout.getvalue().strip() def test_grpo(self): from trl.cli import main @@ -112,7 +112,7 @@ def test_sft_config_file(self): main() # Verify that output directory was created - self.assertTrue(os.path.exists(output_dir)) + assert os.path.exists(output_dir) if __name__ == "__main__": diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 271dd6f5e5b..01417c258bc 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -23,6 +23,7 @@ from trl.scripts.utils import DatasetConfig from .testing_utils import TrlTestCase +import pytest @dataclass @@ -40,13 +41,13 @@ class TestTrlParser(TrlTestCase): def test_init_without_config_field(self): """Test initialization without 'config' field in the dataclasses.""" parser = TrlParser(dataclass_types=[MyDataclass]) - self.assertIsInstance(parser, TrlParser) + assert isinstance(parser, TrlParser) def test_init_with_config_field(self): """Test initialization with a 'config' field in the dataclass (should raise ValueError).""" - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError) as context: TrlParser(dataclass_types=[InvalidDataclass]) - self.assertTrue("has a field named 'config'" in str(context.exception)) + assert "has a field named 'config'" in str(context.exception) @patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2")) @patch("yaml.safe_load") @@ -67,14 +68,14 @@ def test_parse_args_and_config_with_valid_config(self, mock_environ, mock_yaml_l mock_environ["VAR2"] = "value2" # Ensure that the environment variables were set correctly - self.assertEqual(mock_environ.get("VAR1"), "value1") - self.assertEqual(mock_environ.get("VAR2"), "value2") + assert mock_environ.get("VAR1") == "value1" + assert mock_environ.get("VAR2") == "value2" # Check the parsed arguments - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[0].arg2, "value") + assert len(result_args) == 1 + 
assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" @patch("builtins.open", mock_open(read_data="arg1: 2")) @patch("yaml.safe_load") @@ -90,9 +91,9 @@ def test_parse_args_and_arg_override_config(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Check the parsed arguments - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 3) + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 3 @patch("builtins.open", mock_open(read_data="env: not_a_dict")) @patch("yaml.safe_load") @@ -104,10 +105,10 @@ def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load): args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"] - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError) as context: parser.parse_args_and_config(args) - self.assertEqual(str(context.exception), "`env` field should be a dict in the YAML file.") + assert str(context.exception) == "`env` field should be a dict in the YAML file." def test_parse_args_and_config_without_config(self): """Test parse_args_and_config without the `--config` argument.""" @@ -119,10 +120,10 @@ def test_parse_args_and_config_without_config(self): result_args = parser.parse_args_and_config(args) # Check that the arguments are parsed as is - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[0].arg2, "value") + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" def test_set_defaults_with_config(self): """Test set_defaults_with_config updates the defaults.""" @@ -133,9 +134,9 @@ def test_set_defaults_with_config(self): # Ensure the default value is updated result_args = parser.parse_args_and_config([]) - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 42) + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 42 def test_parse_args_and_config_with_remaining_strings(self): parser = TrlParser(dataclass_types=[MyDataclass]) @@ -146,11 +147,11 @@ def test_parse_args_and_config_with_remaining_strings(self): result_args = parser.parse_args_and_config(args, return_remaining_strings=True) # Check that the arguments are parsed as is - self.assertEqual(len(result_args), 2) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[0].arg2, "value") - self.assertEqual(result_args[1], ["remaining"]) + assert len(result_args) == 2 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" + assert result_args[1] == ["remaining"] @patch("builtins.open", mock_open(read_data="remaining_string_in_config: abc")) @patch("yaml.safe_load") @@ -165,10 +166,10 @@ def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, m result_args = parser.parse_args_and_config(args, return_remaining_strings=True) # Check that the arguments are parsed as is - self.assertEqual(len(result_args), 2) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[1], 
["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"]) + assert len(result_args) == 2 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[1] == ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"] @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) @patch("yaml.safe_load") @@ -190,11 +191,11 @@ def test_subparsers_with_config_defaults(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Check main parser arguments - self.assertEqual(len(result_args), 1) + assert len(result_args) == 1 # Check that config values were applied to the subparser - self.assertEqual(result_args[0].arg1, 2) # Default from config - self.assertEqual(result_args[0].arg2, "config_value") # Default from config + assert result_args[0].arg1 == 2 # Default from config + assert result_args[0].arg2 == "config_value" # Default from config @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) @patch("yaml.safe_load") @@ -216,8 +217,8 @@ def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Command line arguments should override config - self.assertEqual(result_args[0].arg1, 3) - self.assertEqual(result_args[0].arg2, "config_value") # Still from config + assert result_args[0].arg1 == 3 + assert result_args[0].arg2 == "config_value" # Still from config @patch("builtins.open", mock_open(read_data="arg1: 2\nthis_arg_does_not_exist: config_value")) @patch("yaml.safe_load") @@ -236,7 +237,7 @@ def test_subparsers_with_config_defaults_and_arg_override_wrong_name(self, mock_ # Test with command line arguments overriding config args = ["subcommand", "--arg1", "3", "--config", "config.yaml"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): parser.parse_args_and_config(args) parser.parse_args_and_config(args, fail_with_unknown_args=False) @@ -263,11 +264,11 @@ def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Check main parser arguments - self.assertEqual(len(result_args), 1) + assert len(result_args) == 1 # Check that config values were applied to the subparser - self.assertEqual(result_args[0].arg1, 2) # Default from config - self.assertEqual(result_args[0].arg2, "config_value") # Default from config + assert result_args[0].arg1 == 2 # Default from config + assert result_args[0].arg2 == "config_value" # Default from config class TestGetDataset(unittest.TestCase): @@ -277,7 +278,7 @@ def test_single_dataset_with_config(self): ) result = get_dataset(mixture_config) expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") - self.assertEqual(expected["train"][:], result["train"][:]) + assert expected["train"][:] == result["train"][:] def test_single_dataset_preference_config(self): mixture_config = DatasetMixtureConfig( @@ -285,7 +286,7 @@ def test_single_dataset_preference_config(self): ) result = get_dataset(mixture_config) expected = load_dataset("trl-internal-testing/zen", "standard_preference") - self.assertEqual(expected["train"][:], result["train"][:]) + assert expected["train"][:] == result["train"][:] def test_single_dataset_streaming(self): mixture_config = DatasetMixtureConfig( @@ -294,7 +295,7 @@ def test_single_dataset_streaming(self): ) result = get_dataset(mixture_config) expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") - 
self.assertEqual(expected["train"].to_list(), list(result["train"])) + assert expected["train"].to_list() == list(result["train"]) def test_dataset_mixture_basic(self): dataset_config1 = DatasetConfig( @@ -305,15 +306,15 @@ def test_dataset_mixture_basic(self): ) mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) result = get_dataset(mixture_config) - self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) + assert isinstance(result, DatasetDict) + assert "train" in result train_dataset = result["train"] - self.assertEqual(train_dataset.column_names, ["prompt"]) + assert train_dataset.column_names == ["prompt"] prompts = train_dataset["prompt"] expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") - self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"]) + assert prompts[: len(prompts) // 2] == expected_first_half["prompt"] expected_second_half = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train") - self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"]) + assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"] def test_dataset_mixture_with_weights(self): dataset_config1 = DatasetConfig( @@ -324,17 +325,17 @@ def test_dataset_mixture_with_weights(self): ) mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) result = get_dataset(mixture_config) - self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) + assert isinstance(result, DatasetDict) + assert "train" in result train_dataset = result["train"] - self.assertEqual(train_dataset.column_names, ["prompt"]) + assert train_dataset.column_names == ["prompt"] prompts = train_dataset["prompt"] expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train[:50%]") - self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"]) + assert prompts[: len(prompts) // 2] == expected_first_half["prompt"] expected_second_half = load_dataset( "trl-internal-testing/zen", "standard_prompt_completion", split="train[:50%]" ) - self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"]) + assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"] def test_dataset_mixture_with_test_split(self): mixture_config = DatasetMixtureConfig( @@ -342,19 +343,19 @@ def test_dataset_mixture_with_test_split(self): test_split_size=2, ) result = get_dataset(mixture_config) - self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) - self.assertIn("test", result) - self.assertEqual(len(result["train"]), 15) - self.assertEqual(len(result["test"]), 2) + assert isinstance(result, DatasetDict) + assert "train" in result + assert "test" in result + assert len(result["train"]) == 15 + assert len(result["test"]) == 2 def test_empty_dataset_mixture_raises_error(self): mixture_config = DatasetMixtureConfig(datasets=[]) - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError) as context: get_dataset(mixture_config) - self.assertIn("No datasets were loaded", str(context.exception)) + assert "No datasets were loaded" in str(context.exception) def test_mixture_multiple_different_configs(self): dataset_config1 = DatasetConfig( @@ -365,9 +366,9 @@ def test_mixture_multiple_different_configs(self): ) mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) result = get_dataset(mixture_config) - 
self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) - self.assertGreater(len(result["train"]), 0) + assert isinstance(result, DatasetDict) + assert "train" in result + assert len(result["train"]) > 0 def test_trlparser_parses_yaml_config_correctly(self): # Prepare YAML content exactly like your example @@ -390,24 +391,24 @@ def test_trlparser_parses_yaml_config_correctly(self): args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0] # Assert that we got DatasetMixtureConfig instance - self.assertIsInstance(args, DatasetMixtureConfig) + assert isinstance(args, DatasetMixtureConfig) # Assert datasets list length - self.assertEqual(len(args.datasets), 2) + assert len(args.datasets) == 2 # Check first dataset dataset_config1 = args.datasets[0] - self.assertIsInstance(dataset_config1, DatasetConfig) - self.assertEqual(dataset_config1.path, "trl-internal-testing/zen") - self.assertEqual(dataset_config1.name, "standard_prompt_only") - self.assertIsNone(dataset_config1.columns) # No columns specified + assert isinstance(dataset_config1, DatasetConfig) + assert dataset_config1.path == "trl-internal-testing/zen" + assert dataset_config1.name == "standard_prompt_only" + assert dataset_config1.columns is None # No columns specified # Check second dataset dataset_config2 = args.datasets[1] - self.assertIsInstance(dataset_config2, DatasetConfig) - self.assertEqual(dataset_config2.path, "trl-internal-testing/zen") - self.assertEqual(dataset_config2.name, "standard_preference") - self.assertEqual(dataset_config2.columns, ["prompt"]) # Columns specified + assert isinstance(dataset_config2, DatasetConfig) + assert dataset_config2.path == "trl-internal-testing/zen" + assert dataset_config2.name == "standard_preference" + assert dataset_config2.columns == ["prompt"] # Columns specified def test_trlparser_parses_yaml_and_loads_dataset(self): # Prepare YAML content exactly like your example @@ -428,4 +429,4 @@ def test_trlparser_parses_yaml_and_loads_dataset(self): # Load the dataset using get_dataset result = get_dataset(args) expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") - self.assertEqual(expected["train"][:], result["train"][:]) + assert expected["train"][:] == result["train"][:] diff --git a/tests/test_collators.py b/tests/test_collators.py index b578758f027..d798f29d54d 100644 --- a/tests/test_collators.py +++ b/tests/test_collators.py @@ -26,7 +26,7 @@ def setUp(self): self.collator = DataCollatorForPreference(pad_token_id=0) def assertTensorEqual(self, tensor1, tensor2): - self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}") + assert torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}" def test_padding_behavior(self): examples = [ diff --git a/tests/test_core.py b/tests/test_core.py index bab69ca9da2..78e57c64b30 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -32,10 +32,10 @@ def setUp(self): self.test_input_unmasked = self.test_input[1:3] def test_masked_mean(self): - self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask)) + assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask) def test_masked_var(self): - self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask)) + assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask) def test_masked_whiten(self): def whiten(values: torch.Tensor) -> 
torch.Tensor: @@ -45,4 +45,4 @@ def whiten(values: torch.Tensor) -> torch.Tensor: whiten_unmasked = whiten(self.test_input_unmasked) whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] diffs = (whiten_unmasked - whiten_masked).sum() - self.assertLess(abs(diffs.item()), 0.00001) + assert abs(diffs.item()) < 0.00001 diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py index cc3e394846d..c0edd771c5e 100644 --- a/tests/test_cpo_trainer.py +++ b/tests/test_cpo_trainer.py @@ -87,13 +87,13 @@ def test_cpo_trainer(self, name, loss_type, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @parameterized.expand( [ @@ -143,14 +143,14 @@ def test_cpo_trainer_with_lora(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_compute_metrics(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") @@ -180,7 +180,7 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 def test_alphapo_trainer(self): training_args = CPOConfig( @@ -212,9 +212,9 @@ def test_alphapo_trainer(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 0a9eba7f7bb..88c63868094 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -55,7 +55,7 @@ def test_basic_user_assistant_conversation(self): {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_first_user_message_gets_image(self): """Test that only the first user message gets an image placeholder.""" @@ -73,7 +73,7 @@ def test_first_user_message_gets_image(self): {"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_multiple_images(self): """Test that multiple images are added to the first user message.""" @@ -97,7 +97,7 @@ def test_multiple_images(self): {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_system_message_transformation(self): """Test that system messages are properly transformed.""" @@ -113,7 +113,7 @@ def test_system_message_transformation(self): {"role": "user", "content": [{"type": "image"}, {"type": 
"text", "text": "What color is the sky?"}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_already_prepared_messages_unchanged(self): """Test that messages with list content are not modified.""" @@ -126,7 +126,7 @@ def test_already_prepared_messages_unchanged(self): original = copy.deepcopy(messages) prepare_multimodal_messages(messages, num_images=1) - self.assertEqual(messages, original) + assert messages == original def test_mixed_prepared_and_unprepared_messages(self): """Test handling of mixed prepared and unprepared messages.""" @@ -144,7 +144,7 @@ def test_mixed_prepared_and_unprepared_messages(self): {"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]}, ] - self.assertEqual(messages, expected) + assert messages == expected class IsConversationalTester(TrlTestCase): @@ -250,11 +250,11 @@ class IsConversationalTester(TrlTestCase): @parameterized.expand(itertools.product(conversational_examples)) def test_conversational(self, example): - self.assertTrue(is_conversational(example)) + assert is_conversational(example) @parameterized.expand(itertools.product(non_conversational_examples)) def test_non_conversational(self, example): - self.assertFalse(is_conversational(example)) + assert not is_conversational(example) class IsConversationalFromValueTester(TrlTestCase): @@ -265,7 +265,7 @@ def test_positive_1(self): {"from": "assistant", "value": "It is blue."}, ], } - self.assertTrue(is_conversational_from_value(example)) + assert is_conversational_from_value(example) def test_negative_1(self): example = { @@ -274,11 +274,11 @@ def test_negative_1(self): {"role": "assistant", "content": "It is blue."}, ], } - self.assertFalse(is_conversational_from_value(example)) + assert not is_conversational_from_value(example) def test_negative_2(self): example = {"text": "The sky is blue."} - self.assertFalse(is_conversational_from_value(example)) + assert not is_conversational_from_value(example) class ApplyChatTemplateTester(TrlTestCase): @@ -352,24 +352,24 @@ def test_apply_chat_template(self, tokenizer_id, example): result = apply_chat_template(example, tokenizer) # Checking if the result is a dictionary - self.assertIsInstance(result, dict) + assert isinstance(result, dict) # The chat template should be applied to the following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: - self.assertIn(key, result) - self.assertIsInstance(result[key], str) + assert key in result + assert isinstance(result[key], str) # Exception for messages, the key is "text" once the chat template is applied if "messages" in example: - self.assertIn("text", result) - self.assertIsInstance(result["text"], str) + assert "text" in result + assert isinstance(result["text"], str) # The label should be kept if "label" in example: - self.assertIn("label", result) - self.assertIsInstance(result["label"], bool) - self.assertEqual(result["label"], example["label"]) + assert "label" in result + assert isinstance(result["label"], bool) + assert result["label"] == example["label"] # both conversational and non-conversational examples @parameterized.expand(itertools.product(tokenizers, conversational_examples + non_conversational_examples)) @@ -378,24 +378,24 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example): result = maybe_apply_chat_template(example, tokenizer) # Checking if the result is a dictionary - self.assertIsInstance(result, dict) + assert isinstance(result, dict) # The chat template should be applied to the 
following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: - self.assertIn(key, result) - self.assertIsInstance(result[key], str) + assert key in result + assert isinstance(result[key], str) # Exception for messages, the key is "text" once the chat template is applied if "messages" in example: - self.assertIn("text", result) - self.assertIsInstance(result["text"], str) + assert "text" in result + assert isinstance(result["text"], str) # The label should be kept if "label" in example: - self.assertIn("label", result) - self.assertIsInstance(result["label"], bool) - self.assertEqual(result["label"], example["label"]) + assert "label" in result + assert isinstance(result["label"], bool) + assert result["label"] == example["label"] def test_apply_chat_template_with_tools(self): tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2") @@ -420,13 +420,13 @@ def get_current_temperature(location: str): result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature]) # Verify tools are included in the output - self.assertIn("get_current_temperature", result_with_tools["prompt"]) + assert "get_current_temperature" in result_with_tools["prompt"] # Test without tools result_without_tools = apply_chat_template(test_case, tokenizer, tools=None) # Verify tools are not included in the output - self.assertNotIn("get_current_temperature", result_without_tools["prompt"]) + assert "get_current_temperature" not in result_without_tools["prompt"] class ApplyChatTemplateHarmonyTester(TrlTestCase): @@ -459,7 +459,7 @@ def test_language_modeling(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""") - self.assertEqual(output["text"], expected) + assert output["text"] == expected def test_prompt_only(self): messages = { @@ -489,7 +489,7 @@ def test_prompt_only(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") - self.assertEqual(output["prompt"], expected) + assert output["prompt"] == expected def test_prompt_completion(self): messages = { @@ -523,8 +523,8 @@ def test_prompt_completion(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" - self.assertEqual(output["prompt"], expected_prompt) - self.assertEqual(output["completion"], expected_completion) + assert output["prompt"] == expected_prompt + assert output["completion"] == expected_completion def test_preference(self): messages = { @@ -562,9 +562,9 @@ def test_preference(self): expected_chosen = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" expected_rejected = "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>" - self.assertEqual(output["prompt"], expected_prompt) - self.assertEqual(output["chosen"], expected_chosen) - self.assertEqual(output["rejected"], expected_rejected) + assert output["prompt"] == expected_prompt + assert output["chosen"] == expected_chosen + assert output["rejected"] == expected_rejected def test_preference_with_implicit_prompt(self): messages = { 
@@ -614,8 +614,8 @@ def test_preference_with_implicit_prompt(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>""") - self.assertEqual(output["chosen"], expected_chosen) - self.assertEqual(output["rejected"], expected_rejected) + assert output["chosen"] == expected_chosen + assert output["rejected"] == expected_rejected def test_unpaired_preference(self): messages = { @@ -650,9 +650,9 @@ def test_unpaired_preference(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" - self.assertEqual(output["prompt"], expected_prompt) - self.assertEqual(output["completion"], expected_completion) - self.assertTrue(output["label"]) + assert output["prompt"] == expected_prompt + assert output["completion"] == expected_completion + assert output["label"] class UnpairPreferenceDatasetTester(TrlTestCase): @@ -675,58 +675,46 @@ class UnpairPreferenceDatasetTester(TrlTestCase): def test_unpair_preference_dataset(self): # Test that a paired dataset is correctly converted to unpaired unpaired_dataset = unpair_preference_dataset(self.paired_dataset) - self.assertEqual( - unpaired_dataset.to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", - ) + assert unpaired_dataset.to_dict() == \ + self.unpaired_dataset.to_dict(), \ + "The paired dataset should be converted to unpaired." def test_unpair_preference_dataset_dict(self): # Test that a paired dataset dict is correctly converted to unpaired paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict) - self.assertEqual( - unpaired_dataset_dict["abc"].to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", - ) + assert unpaired_dataset_dict["abc"].to_dict() == \ + self.unpaired_dataset.to_dict(), \ + "The paired dataset should be converted to unpaired." def test_maybe_unpair_preference_dataset(self): # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset) - self.assertEqual( - unpaired_dataset.to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", - ) + assert unpaired_dataset.to_dict() == \ + self.unpaired_dataset.to_dict(), \ + "The paired dataset should be converted to unpaired." def test_maybe_unpair_preference_dataset_dict(self): # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict) - self.assertEqual( - unpaired_dataset_dict["abc"].to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", - ) + assert unpaired_dataset_dict["abc"].to_dict() == \ + self.unpaired_dataset.to_dict(), \ + "The paired dataset should be converted to unpaired." 
def test_maybe_unpair_preference_dataset_already_paired(self): # Test that a paired dataset remains unchanged with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset) - self.assertEqual( - unpaired_dataset.to_dict(), - self.unpaired_dataset.to_dict(), - "The unpaired dataset should remain unchanged.", - ) + assert unpaired_dataset.to_dict() == \ + self.unpaired_dataset.to_dict(), \ + "The unpaired dataset should remain unchanged." def test_maybe_unpair_preference_dataset_dict_already_paired(self): # Test that a paired dataset dict remains unchanged with maybe_unpair_preference_dataset unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset})) - self.assertEqual( - unpaired_dataset_dict["abc"].to_dict(), - self.unpaired_dataset.to_dict(), - "The unpaired dataset should remain unchanged.", - ) + assert unpaired_dataset_dict["abc"].to_dict() == \ + self.unpaired_dataset.to_dict(), \ + "The unpaired dataset should remain unchanged." class ExtractPromptTester(TrlTestCase): @@ -767,56 +755,44 @@ class ExtractPromptTester(TrlTestCase): def test_extract_prompt_conversational(self): # Test that the prompt is correctly extracted from the dataset example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_conversational, - "The prompt is not correctly extracted from the dataset.", - ) + assert example_extracted_prompt == \ + self.example_explicit_prompt_conversational, \ + "The prompt is not correctly extracted from the dataset." def test_maybe_extract_prompt_conversational(self): # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_conversational, - "The prompt is not correctly extracted from the dataset.", - ) + assert example_extracted_prompt == \ + self.example_explicit_prompt_conversational, \ + "The prompt is not correctly extracted from the dataset." def test_maybe_extract_prompt_conversational_already_explicit(self): # Test that the prompt remains unchanged with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_conversational, - "The prompt should remain unchanged.", - ) + assert example_extracted_prompt == \ + self.example_explicit_prompt_conversational, \ + "The prompt should remain unchanged." def test_extract_prompt_standard(self): # Test that the prompt is correctly extracted from the dataset example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_standard, - "The prompt is not correctly extracted from the dataset.", - ) + assert example_extracted_prompt == \ + self.example_explicit_prompt_standard, \ + "The prompt is not correctly extracted from the dataset." 
def test_maybe_extract_prompt_standard(self): # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_standard, - "The prompt is not correctly extracted from the dataset.", - ) + assert example_extracted_prompt == \ + self.example_explicit_prompt_standard, \ + "The prompt is not correctly extracted from the dataset." def test_maybe_extract_prompt_standard_already_explicit(self): # Test that the prompt remains unchanged with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_standard, - "The prompt should remain unchanged.", - ) + assert example_extracted_prompt == \ + self.example_explicit_prompt_standard, \ + "The prompt should remain unchanged." class TestPackDatasetWrapped(TrlTestCase): @@ -832,7 +808,7 @@ def test_with_dataset(self): "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]], } dataset = pack_dataset(dataset, seq_length, strategy="wrapped") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_iterable_dataset(self): examples = { @@ -847,7 +823,7 @@ def test_with_iterable_dataset(self): } dataset = pack_dataset(dataset, seq_length, strategy="wrapped") num_examples = len(examples[next(iter(examples))]) - self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output class TestPackDatasetBfd(TrlTestCase): @@ -864,7 +840,7 @@ def test_simple(self): "seq_lengths": [[4], [3, 1]], } dataset = pack_dataset(dataset, seq_length, strategy="bfd") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_iterable_dataset(self): examples = { @@ -880,7 +856,7 @@ def test_with_iterable_dataset(self): } dataset = pack_dataset(dataset, seq_length, strategy="bfd") num_examples = len(examples[next(iter(examples))]) - self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output def test_with_truncation(self): examples = { @@ -895,7 +871,7 @@ def test_with_truncation(self): "seq_lengths": [[4], [4], [2, 1]], } dataset = pack_dataset(dataset, seq_length, strategy="bfd") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_non_power_of_2(self): examples = { @@ -910,7 +886,7 @@ def test_with_non_power_of_2(self): "seq_lengths": [[5], [4, 1], [3]], } dataset = pack_dataset(dataset, seq_length, strategy="bfd") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output class TestTruncateExamples(TrlTestCase): @@ -926,7 +902,7 @@ def test_with_dataset(self): "attention_mask": [[0, 1], [0, 0], [1]], } dataset = truncate_dataset(dataset, max_length) - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_iterable_dataset(self): examples = { @@ -941,7 +917,7 @@ def test_with_iterable_dataset(self): } dataset = truncate_dataset(dataset, max_length) num_examples = len(examples[next(iter(examples))]) - self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + assert 
next(iter(dataset.batch(batch_size=num_examples))) == expected_output def test_with_extra_column(self): examples = { @@ -957,7 +933,7 @@ def test_with_extra_column(self): "my_column": ["a", "b", "c"], } dataset = truncate_dataset(dataset, max_length) - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output class TestMaybeConvertToChatML(TrlTestCase): @@ -975,7 +951,7 @@ def test_with_conversations_key(self): {"role": "assistant", "content": "It is blue."}, ] } - self.assertEqual(maybe_convert_to_chatml(example), expected_output) + assert maybe_convert_to_chatml(example) == expected_output def test_without_conversations_key(self): # Same as before, but we don't rename the keys @@ -987,12 +963,12 @@ def test_without_conversations_key(self): "prompt": [{"role": "user", "content": "What color is the sky?"}], "completion": [{"role": "assistant", "content": "It is blue."}], } - self.assertEqual(maybe_convert_to_chatml(example), expected_output) + assert maybe_convert_to_chatml(example) == expected_output def test_not_conversional(self): # When not needed, the example should remain unchanged example = {"text": "The sky is blue."} - self.assertEqual(maybe_convert_to_chatml(example), example) + assert maybe_convert_to_chatml(example) == example def test_already_chatml(self): # When the example is already in ChatML format, it should remain unchanged @@ -1002,7 +978,7 @@ def test_already_chatml(self): {"role": "assistant", "content": "It is blue."}, ] } - self.assertEqual(maybe_convert_to_chatml(example), example) + assert maybe_convert_to_chatml(example) == example # Run the tests diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index c85845e34c3..cc132aedefb 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -44,20 +44,20 @@ def test_get_formatting_func_from_dataset_with_chatml_messages(self): # Llama tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsInstance(formatting_func, Callable) + assert isinstance(formatting_func, Callable) formatted_text = formatting_func(dataset[0]) expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?" - self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] # ChatML tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer) formatted_text = formatting_func(dataset[0]) expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" - self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] def test_get_formatting_func_from_dataset_with_chatml_conversations(self): dataset = Dataset.from_dict( @@ -73,48 +73,48 @@ def test_get_formatting_func_from_dataset_with_chatml_conversations(self): ) # Llama tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsInstance(formatting_func, Callable) + assert isinstance(formatting_func, Callable) formatted_text = formatting_func(dataset[0]) expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?" 
- self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] # ChatML tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer) formatted_text = formatting_func(dataset[0]) expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" - self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] def test_get_formatting_func_from_dataset_with_instruction(self): dataset = Dataset.from_list( [{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}] ) formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsNotNone(formatting_func) - self.assertIsInstance(formatting_func, Callable) + assert formatting_func is not None + assert isinstance(formatting_func, Callable) formatted_text = formatting_func(dataset[0]) - self.assertEqual(formatted_text, " [INST] What is 2+2? [/INST] 4") + assert formatted_text == " [INST] What is 2+2? [/INST] 4" formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [" [INST] What is 2+2? [/INST] 4"]) + assert formatted_text == [" [INST] What is 2+2? [/INST] 4"] def test_get_formatting_func_from_dataset_from_hub(self): ds_1 = load_dataset("philschmid/trl-test-instruction", split="train") ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train") for ds in [ds_1, ds_2]: formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer) - self.assertIsNotNone(formatting_func) - self.assertIsInstance(formatting_func, Callable) + assert formatting_func is not None + assert isinstance(formatting_func, Callable) ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train") formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer) - self.assertIsNone(formatting_func) + assert formatting_func is None def test_get_formatting_func_from_dataset_with_unknown_format(self): dataset = Dataset.from_dict({"text": "test"}) formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsNone(formatting_func) + assert formatting_func is None class SetupChatFormatTestCase(TrlTestCase): @@ -132,13 +132,13 @@ def test_setup_chat_format(self): _chatml = ChatMlSpecialTokens() # Check if special tokens are correctly set - self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") - self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>") - self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>") - self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token) - self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token) - self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token) - self.assertEqual((modified_model.vocab_size % 123), 0) + assert modified_tokenizer.eos_token == "<|im_end|>" + assert modified_tokenizer.pad_token == "<|im_end|>" + assert modified_tokenizer.bos_token == "<|im_start|>" + assert modified_tokenizer.eos_token == _chatml.eos_token + assert modified_tokenizer.pad_token == _chatml.pad_token + assert modified_tokenizer.bos_token == _chatml.bos_token + assert (modified_model.vocab_size % 123) == 0 def 
test_example_with_setup_model(self): modified_model, modified_tokenizer = setup_chat_format( @@ -152,10 +152,8 @@ def test_example_with_setup_model(self): ] prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) - self.assertEqual( - prompt, - "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n", - ) + assert prompt == \ + "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" class CloneChatTemplateTestCase(TrlTestCase): @@ -168,7 +166,7 @@ def test_clone(self): _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) # Check if special tokens are correctly set - self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") + assert modified_tokenizer.eos_token == "<|im_end|>" def test_clone_with_resize(self): # This tokenizer doesn't have a chat_template by default @@ -181,9 +179,9 @@ def test_clone_with_resize(self): ) # Check that the input embeddings have been resized to a multiple of 123 - self.assertEqual((modified_model.vocab_size % 123), 0) + assert (modified_model.vocab_size % 123) == 0 # Check that the input embeddings size matches the tokenizer vocabulary size - self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab)) + assert model.vocab_size == len(modified_tokenizer.vocab) def test_clone_with_resize_and_extra_tokens_already_in_vocab(self): # This tokenizer doesn't have a chat_template by default @@ -201,9 +199,9 @@ def test_clone_with_resize_and_extra_tokens_already_in_vocab(self): ) # Check that the input embeddings have been resized to a multiple of 123 - self.assertEqual((modified_model.vocab_size % 124), 0) + assert (modified_model.vocab_size % 124) == 0 # Check that the input embeddings size matches the tokenizer vocabulary size - self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab)) + assert model.vocab_size == len(modified_tokenizer.vocab) def test_apply_new_chat_template(self): # This tokenizer doesn't have a chat_template by default @@ -219,10 +217,8 @@ def test_apply_new_chat_template(self): ] prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) - self.assertEqual( - prompt, - "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n", - ) + assert prompt == \ + "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n" def test_clone_with_sequence_classification_model(self): # This tokenizer doesn't have a chat_template by default @@ -235,4 +231,4 @@ def test_clone_with_sequence_classification_model(self): _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) # Check if special tokens are correctly set - self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") + assert modified_tokenizer.eos_token == "<|im_end|>" diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 6a4bfd22301..f429726c20b 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -40,6 +40,7 @@ from trl import DPOConfig, DPOTrainer, FDivergenceType from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb +import pytest if is_vision_available(): @@ -84,14 +85,12 @@ def test_tokenize_row_no_truncation_no_special_tokens(self): ) # Assert the correct output without truncation or special 
tokens - self.assertEqual( - result, + assert result == \ { "prompt_input_ids": [464, 6766, 318], "chosen_input_ids": [4171, 2], # eos_token added "rejected_input_ids": [4077, 2], # eos_token added - }, - ) + } def test_tokenize_row_with_truncation(self): # Define the input features @@ -107,14 +106,12 @@ def test_tokenize_row_with_truncation(self): ) # Assert the correct output with truncation applied - self.assertEqual( - result, + assert result == \ { "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens "chosen_input_ids": [4171], # truncated to 1 token "rejected_input_ids": [4077], # truncated to 1 token - }, - ) + } def test_tokenize_row_with_special_tokens(self): # Define the input features @@ -130,14 +127,12 @@ def test_tokenize_row_with_special_tokens(self): ) # Assert the correct output with special tokens added - self.assertEqual( - result, + assert result == \ { "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added "chosen_input_ids": [4171, 2], # eos_token added "rejected_input_ids": [4077, 2], # eos_token added - }, - ) + } def test_tokenize_row_with_truncation_and_special_tokens(self): # Define the input features @@ -153,14 +148,12 @@ def test_tokenize_row_with_truncation_and_special_tokens(self): ) # Assert the correct output with both truncation and special tokens - self.assertEqual( - result, + assert result == \ { "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token "chosen_input_ids": [4171], # truncated to 1 token "rejected_input_ids": [4077], # truncated to 1 token - }, - ) + } class DPOTrainerTester(TrlTestCase): @@ -193,13 +186,13 @@ def test_train(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @parameterized.expand( [ @@ -241,13 +234,13 @@ def test_train_loss_types(self, loss_type): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @require_liger_kernel def test_train_encoder_decoder_liger(self): @@ -274,13 +267,13 @@ def test_train_encoder_decoder_liger(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) def test_dpo_trainer_with_weighting(self): dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") @@ -304,13 +297,13 @@ def test_dpo_trainer_with_weighting(self): trainer.train() - 
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the parameters have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             if param.sum() != 0:  # ignore 0 biases
-                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+                assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
 
     def test_train_with_multiple_loss_types(self):
         """
@@ -338,22 +331,22 @@ def test_train_with_multiple_loss_types(self):
 
         # Test that training works
         trainer.train()
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Verify SFT loss is computed in the first test too
         with torch.no_grad():
             batch = next(iter(trainer.get_train_dataloader()))
             loss, metrics = trainer.get_batch_loss_metrics(trainer.model, batch)
-            self.assertIn("nll_loss", metrics)  # SFT loss should be computed
+            assert "nll_loss" in metrics  # SFT loss should be computed
 
     def test_wrong_loss_weights_length(self):
-        with self.assertRaises(ValueError) as context:
+        with pytest.raises(ValueError) as context:
             DPOConfig(
                 output_dir=self.tmp_dir,
                 loss_type=["sigmoid", "bco_pair"],
                 loss_weights=[1.0, 0.5, 0.1],  # Wrong length
             )
-        self.assertIn("Length of loss_weights list", str(context.exception))
+        assert "Length of loss_weights list" in str(context.value)
 
     @parameterized.expand([(None,), (0.5,)])
     def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha):
@@ -386,13 +379,13 @@ def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha):
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the parameters have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             if param.sum() != 0:  # ignore 0 biases
-                self.assertFalse(torch.equal(param, new_param))
+                assert not torch.equal(param, new_param)
 
     def test_dpo_trainer_with_ref_model_is_model(self):
         training_args = DPOConfig(
@@ -404,7 +397,7 @@ def test_dpo_trainer_with_ref_model_is_model(self):
 
         dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
 
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             DPOTrainer(
                 model=self.model,
                 ref_model=self.model,  # ref_model can't be the same as model
@@ -437,13 +430,13 @@ def test_precompute_ref_batch_size(self):
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the parameters have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             if param.sum() != 0:  # ignore 0 biases
-                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+                assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
 
     @require_peft
     def test_dpo_trainer_without_providing_ref_model_with_lora(self):
@@ -486,14 +479,14 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the parameters have changed
         for n, param in previous_trainable_params.items():
             if "lora" in n:
                 new_param = trainer.model.get_parameter(n)
                 if param.sum() != 0:  # ignore 0 biases
-                    self.assertFalse(torch.equal(param, 
new_param)) + assert not torch.equal(param, new_param) def test_dpo_trainer_w_dataset_num_proc(self): training_args = DPOConfig( @@ -555,13 +548,13 @@ def test_tr_dpo_trainer(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.ref_model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @require_no_wandb def test_dpo_trainer_generate_during_eval_no_wandb(self): @@ -580,11 +573,8 @@ def test_dpo_trainer_generate_during_eval_no_wandb(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") - with self.assertRaisesRegex( - ValueError, - expected_regex="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." - " Please install `wandb`, `mlflow` or `comet-ml` to resolve.", - ): + with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." + " Please install `wandb`, `mlflow` or `comet-ml` to resolve."): DPOTrainer( model=self.model, ref_model=None, @@ -645,7 +635,7 @@ def test_dpo_lora_save(self): try: AutoModelForCausalLM.from_pretrained(self.tmp_dir) except OSError: - self.fail("Loading the saved peft adapter failed") + pytest.fail("Loading the saved peft adapter failed") @require_peft @require_torch_gpu_if_bnb_not_multi_backend_enabled @@ -826,7 +816,7 @@ def test_dpo_lora_tags(self): ) for tag in ["dpo", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_dpo_tags(self): @@ -861,7 +851,7 @@ def test_dpo_tags(self): ) for tag in ["dpo", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_dpo_lora_force_use_ref(self): @@ -895,7 +885,7 @@ def test_dpo_lora_force_use_ref(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # passing a peft_model as model and ref_model should error out, # unless you pass `force_use_ref_model` trainer = DPOTrainer( @@ -953,8 +943,8 @@ def test_dpo_trainer_dtype(self): args=training_args, train_dataset=dummy_dataset["train"], ) - self.assertEqual(trainer.model.config.dtype, torch.float16) - self.assertEqual(trainer.ref_model.config.dtype, torch.float16) + assert trainer.model.config.dtype == torch.float16 + assert trainer.ref_model.config.dtype == torch.float16 # Now test when `dtype` is provided but is wrong to either the model or the ref_model training_args = DPOConfig( @@ -965,7 +955,7 @@ def test_dpo_trainer_dtype(self): report_to="none", ) - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError) as context: _ = DPOTrainer( model=self.model_id, processing_class=self.tokenizer, @@ -973,11 +963,9 @@ def test_dpo_trainer_dtype(self): train_dataset=dummy_dataset["train"], ) - self.assertIn( - "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " - "`torch.dtype` (e.g., 'float32'), but got -1.", - str(context.exception), - ) + assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \ + "`torch.dtype` (e.g., 'float32'), but got -1." 
in \
+            str(context.value)
 
         training_args = DPOConfig(
             output_dir=self.tmp_dir,
@@ -987,7 +975,7 @@ def test_dpo_trainer_dtype(self):
             report_to="none",
         )
 
-        with self.assertRaises(ValueError) as context:
+        with pytest.raises(ValueError) as context:
             _ = DPOTrainer(
                 model=self.model_id,
                 ref_model=self.model_id,
@@ -996,11 +984,9 @@ def test_dpo_trainer_dtype(self):
                 train_dataset=dummy_dataset["train"],
             )
 
-        self.assertIn(
-            "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid "
-            "`torch.dtype` (e.g., 'float32'), but got -1.",
-            str(context.exception),
-        )
+        assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \
+            "`torch.dtype` (e.g., 'float32'), but got -1." in \
+            str(context.value)
 
     def test_dpo_loss_alpha_div_f(self):
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -1041,7 +1027,7 @@ def test_dpo_loss_alpha_div_f(self):
         losses, _, _ = trainer.dpo_loss(
             policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
         )
-        self.assertTrue(torch.isfinite(losses).cpu().numpy().all())
+        assert torch.isfinite(losses).cpu().numpy().all()
 
     def test_dpo_loss_js_div_f(self):
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
@@ -1083,7 +1069,7 @@ def test_dpo_loss_js_div_f(self):
         losses, _, _ = trainer.dpo_loss(
             policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
         )
-        self.assertTrue(torch.isfinite(losses).cpu().numpy().all())
+        assert torch.isfinite(losses).cpu().numpy().all()
 
     def test_dpo_trainer_use_logits_to_keep(self):
         model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
@@ -1199,7 +1185,7 @@ def get_current_temperature(location: str):
 
         # We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When
         # pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. 
That's # what we're checking here - self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])) + assert "get_current_temperature" in tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0]) def test_padding_free(self): model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" @@ -1235,7 +1221,7 @@ def test_padding_free(self): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) def test_compute_metrics(self): model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -1270,7 +1256,7 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 def test_train_with_length_desensitization(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -1295,13 +1281,13 @@ def test_train_with_length_desensitization(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9") @parameterized.expand( @@ -1359,20 +1345,20 @@ def test_dpo_trainer_with_liger(self, beta, loss_type): train_output = trainer.train() # Verify training completed successfully - self.assertIsNotNone(train_output) - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert train_output is not None + assert trainer.state.log_history[-1]["train_loss"] is not None # Verify loss is finite - self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"])) + assert np.isfinite(trainer.state.log_history[-1]["train_loss"]) # Check parameters have been updated for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) # Only check non-zero parameters if param.sum() != 0: - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) # Verify new parameters are finite - self.assertTrue(torch.isfinite(new_param).all()) + assert torch.isfinite(new_param).all() # Verify model can still do forward pass after training dummy_batch = next(iter(trainer.get_train_dataloader())) @@ -1382,8 +1368,8 @@ def test_dpo_trainer_with_liger(self, beta, loss_type): } with torch.no_grad(): output = trainer.model(**model_inputs) - self.assertIsNotNone(output) - self.assertFalse("loss" in output.keys()) + assert output is not None + assert not ("loss" in output.keys()) def test_train_with_iterable_dataset(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -1411,13 +1397,13 @@ def test_train_with_iterable_dataset(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if 
param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @require_vision @@ -1494,7 +1480,7 @@ def test_vdpo_trainer(self, model_id): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the trainable params have changed for n, param in previous_trainable_params.items(): @@ -1510,7 +1496,7 @@ def test_vdpo_trainer(self, model_id): # For some reason, these params are not updated. This is probably not related to TRL, but to # the model itself. We should investigate this further, but for now we just skip these params. continue - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" class TestDPOConfig(TrlTestCase): diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py index 4a0d458440c..11ddcaedc81 100644 --- a/tests/test_gkd_trainer.py +++ b/tests/test_gkd_trainer.py @@ -70,10 +70,8 @@ def test_generate_on_policy_outputs_deterministic(self): # Check if the generated texts start with the original prompts for prompt, generated_text in zip(prompts, generated_texts): - self.assertTrue( - generated_text.startswith(prompt), - f"Generated text '{generated_text}' does not start with prompt '{prompt}'", - ) + assert generated_text.startswith(prompt), \ + f"Generated text '{generated_text}' does not start with prompt '{prompt}'" # Run the generation twice and check if the outputs are identical outputs2 = GKDTrainer.generate_on_policy_outputs( @@ -83,15 +81,11 @@ def test_generate_on_policy_outputs_deterministic(self): new_input_ids2, new_attention_mask2, new_labels2 = outputs2 # Check if the two generations are identical - self.assertTrue(torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical") - self.assertTrue( - torch.all(new_attention_mask.eq(new_attention_mask2)), - "Attention masks for deterministic generations are not identical", - ) - self.assertTrue( - torch.all(new_labels.eq(new_labels2)), - "Labels for deterministic generations are not identical", - ) + assert torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical" + assert torch.all(new_attention_mask.eq(new_attention_mask2)), \ + "Attention masks for deterministic generations are not identical" + assert torch.all(new_labels.eq(new_labels2)), \ + "Labels for deterministic generations are not identical" def test_generate_on_policy_outputs(self): prompts = ["Hello, how are you?", "What's the weather like today?"] @@ -107,25 +101,25 @@ def test_generate_on_policy_outputs(self): ) # Check that outputs is a tuple of three tensors - self.assertIsInstance(outputs, tuple) - self.assertEqual(len(outputs), 3) + assert isinstance(outputs, tuple) + assert len(outputs) == 3 new_input_ids, new_attention_mask, new_labels = outputs # Check shapes batch_size = len(prompts) - self.assertEqual(new_input_ids.shape[0], batch_size) - self.assertEqual(new_attention_mask.shape[0], batch_size) - self.assertEqual(new_labels.shape[0], batch_size) + assert new_input_ids.shape[0] == batch_size + assert new_attention_mask.shape[0] == batch_size + assert new_labels.shape[0] == batch_size # Check types - self.assertIsInstance(new_input_ids, torch.Tensor) - self.assertIsInstance(new_attention_mask, torch.Tensor) - 
self.assertIsInstance(new_labels, torch.Tensor) + assert isinstance(new_input_ids, torch.Tensor) + assert isinstance(new_attention_mask, torch.Tensor) + assert isinstance(new_labels, torch.Tensor) # Check that new_input_ids and new_attention_mask have the same shape - self.assertEqual(new_input_ids.shape, new_attention_mask.shape) - self.assertEqual(new_labels.shape, new_attention_mask.shape) + assert new_input_ids.shape == new_attention_mask.shape + assert new_labels.shape == new_attention_mask.shape class TestGeneralizedJSDLoss(TrlTestCase): @@ -140,7 +134,7 @@ def setUp(self): def test_uniform_distribution(self): logits = torch.ones(1, 1, self.vocab_size) loss = GKDTrainer.generalized_jsd_loss(logits, logits) - self.assertAlmostEqual(loss.item(), 0, places=5) + assert round(abs(loss.item()-0), 5) == 0 def test_generalized_jsd_loss_edge_cases(self): # Setup @@ -152,29 +146,29 @@ def test_generalized_jsd_loss_edge_cases(self): expected_loss_beta_1 = F.kl_div( F.log_softmax(teacher_logits, dim=-1), F.softmax(student_logits, dim=-1), reduction="batchmean" ) - self.assertAlmostEqual(loss_beta_1.item(), expected_loss_beta_1.item(), places=5) + assert round(abs(loss_beta_1.item()-expected_loss_beta_1.item()), 5) == 0 # Case 2: beta = 0 (should be equivalent to KL(teacher || student)) loss_beta_0 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=0) expected_loss_beta_0 = F.kl_div( F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean" ) - self.assertAlmostEqual(loss_beta_0.item(), expected_loss_beta_0.item(), places=5) + assert round(abs(loss_beta_0.item()-expected_loss_beta_0.item()), 5) == 0 def test_output_shape(self): loss = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits) - self.assertTrue(torch.is_tensor(loss)) - self.assertEqual(loss.shape, torch.Size([])) + assert torch.is_tensor(loss) + assert loss.shape == torch.Size([]) def test_beta_values(self): loss_beta_0 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0) loss_beta_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=1) - self.assertNotEqual(loss_beta_0, loss_beta_1) + assert loss_beta_0 != loss_beta_1 def test_temperature_scaling(self): loss_temp_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=1) loss_temp_2 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=2) - self.assertNotEqual(loss_temp_1, loss_temp_2) + assert loss_temp_1 != loss_temp_2 def test_reduction_methods(self): loss_batchmean = GKDTrainer.generalized_jsd_loss( @@ -184,24 +178,24 @@ def test_reduction_methods(self): loss_mean = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="mean") loss_none = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="none") - self.assertEqual(loss_batchmean.shape, torch.Size([])) - self.assertEqual(loss_sum.shape, torch.Size([])) - self.assertEqual(loss_mean.shape, torch.Size([])) - self.assertEqual(loss_none.shape, self.student_logits.shape) + assert loss_batchmean.shape == torch.Size([]) + assert loss_sum.shape == torch.Size([]) + assert loss_mean.shape == torch.Size([]) + assert loss_none.shape == self.student_logits.shape def test_symmetry(self): student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.1) teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, 
beta=0.1) - self.assertNotEqual(student_teacher, teacher_student) + assert student_teacher != teacher_student student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.5) teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.5) - self.assertEqual(student_teacher, teacher_student) + assert student_teacher == teacher_student def test_zero_loss_for_identical_inputs(self): identical_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) loss = GKDTrainer.generalized_jsd_loss(identical_logits, identical_logits) - self.assertAlmostEqual(loss.item(), 0, places=6) + assert round(abs(loss.item()-0), 6) == 0 class GKDTrainerTester(TrlTestCase): @@ -242,9 +236,9 @@ def test_gkd_trainer(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) - self.assertIn("model.safetensors", os.listdir(self.tmp_dir + "/checkpoint-2")) + assert trainer.state.log_history[(-1)]["train_loss"] is not None + assert trainer.state.log_history[0]["eval_loss"] is not None + assert "model.safetensors" in os.listdir(self.tmp_dir + "/checkpoint-2") @require_liger_kernel @pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.") @@ -271,7 +265,7 @@ def test_gkd_trainer_with_liger(self): trainer.train() # Check we logged a train loss - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None def test_generation_config_init(self): training_args = GKDConfig(output_dir=self.tmp_dir) @@ -286,8 +280,8 @@ def test_generation_config_init(self): processing_class=self.tokenizer, ) - self.assertEqual(trainer.generation_config.pad_token_id, self.tokenizer.eos_token_id) - self.assertEqual(trainer.generation_config.eos_token_id, self.model.generation_config.eos_token_id) - self.assertEqual(trainer.generation_config.max_new_tokens, training_args.max_new_tokens) - self.assertEqual(trainer.generation_config.temperature, training_args.temperature) - self.assertEqual(trainer.generation_config.top_k, 0) + assert trainer.generation_config.pad_token_id == self.tokenizer.eos_token_id + assert trainer.generation_config.eos_token_id == self.model.generation_config.eos_token_id + assert trainer.generation_config.max_new_tokens == training_args.max_new_tokens + assert trainer.generation_config.temperature == training_args.temperature + assert trainer.generation_config.top_k == 0 diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index a839a654bca..9a442ee6102 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -148,12 +148,12 @@ def test_training(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
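The `assertAlmostEqual` conversions in `tests/test_gkd_trainer.py` above keep the mechanical `round(abs(x - y), places) == 0` form, which is correct but indirect. As an illustration only (using the `loss` tensor already computed in the test), `pytest.approx` expresses the same kind of tolerance directly; `places=5` corresponds to an absolute tolerance of roughly 5e-6:

    # Illustrative equivalent of `assert round(abs(loss.item() - 0), 5) == 0`.
    import pytest

    assert loss.item() == pytest.approx(0.0, abs=5e-6)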
@parameterized.expand([("bnpo",), ("dr_grpo",), ("dapo",)]) def test_training_loss_types(self, loss_type): @@ -180,12 +180,12 @@ def test_training_loss_types(self, loss_type): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_eval(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") @@ -233,12 +233,12 @@ def test_training_multiple_iterations(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_peft def test_training_peft(self): @@ -266,15 +266,15 @@ def test_training_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_peft def test_training_peft_with_gradient_checkpointing(self): @@ -308,22 +308,22 @@ def test_training_peft_with_gradient_checkpointing(self): ) # Verify gradient checkpointing is enabled - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Store initial parameters to check which ones change previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that only LoRA parameters have changed, base model parameters remain unchanged for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "lora" in n.lower(): # LoRA parameters should change - self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"LoRA parameter {n} has not changed." else: # Base model parameters should not change - self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.") + assert torch.equal(param, new_param), f"Base parameter {n} has changed." def test_training_different_reward_model(self): # Use a reward model different from the model: different chat template, tokenization, etc. 
@@ -357,12 +357,12 @@ def test_training_different_reward_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_standard(self): # Test if trainer can handle reward function with standard format @@ -391,12 +391,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_conversational(self): # Test if trainer can handle reward function with conversational format @@ -426,12 +426,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs(self): # Test that GRPOTrainer can be instantiated with multiple reward functions @@ -464,12 +464,12 @@ def reward_func2(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs_with_None_output(self): """Test that a valid math reward function is processed correctly while the code reward function returns None.""" @@ -508,12 +508,12 @@ def non_applicable_reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_multiple_reward_funcs_with_weights(self): """Test that GRPOTrainer can handle multiple reward functions with weights.""" @@ -548,16 +548,16 @@ def reward_func2(completions, **kwargs): trainer.train() # Check that training logs contain both reward metrics - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert "rewards/reward_func1/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func1/std" in trainer.state.log_history[-1] + assert "rewards/reward_func2/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func2/std" in trainer.state.log_history[-1] # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_mixed_reward_funcs(self): # Test if the trainer can handle a mix of reward functions and reward models @@ -586,12 +586,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_additional_column(self): # Test if trainer can handle reward function that rely on additional columns in the dataset @@ -624,12 +624,12 @@ def reward_func(completions, some_values, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_sync_ref_model(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -655,12 +655,12 @@ def test_training_with_sync_ref_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -684,12 +684,12 @@ def test_training_beta_non_zero(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_entropy_filter(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -713,12 +713,12 @@ def test_training_with_entropy_filter(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @unittest.skip("We should add a mock for the vLLM server.") @require_peft @@ -755,16 +755,16 @@ def test_training_vllm_and_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n and "original_module" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_vllm @unittest.skip("We should add a mock for the vLLM server.") @@ -793,12 +793,12 @@ def test_training_vllm_guided_decoding(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm @unittest.skip("We should add a mock for the vLLM server.") @@ -828,12 +828,12 @@ def test_training_vllm_importance_sampling_correction(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_with_additional_generation_kwargs(self): """Test that training works with additional generation kwargs.""" @@ -863,12 +863,12 @@ def test_training_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm @unittest.skip("We should add a mock for the vLLM server.") @@ -901,12 +901,12 @@ def test_training_vllm_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @parameterized.expand([(False,), ("group",), ("batch",), (True,), ("none",)]) def test_training_scale_rewards(self, scale_rewards): @@ -932,12 +932,12 @@ def test_training_scale_rewards(self, scale_rewards): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @patch("transformers.generation.utils.GenerationMixin.generate") def test_training_with_mask_truncated_completions(self, mock_generate): @@ -982,12 +982,12 @@ def fake_generate(input_ids, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_mask_truncated_completions_all_masked(self): """ @@ -1020,12 +1020,12 @@ def test_training_with_mask_truncated_completions_all_masked(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.") + assert torch.equal(param, new_param), f"Parameter {n} has changed." 
def test_warning_raised_all_rewards_none(self): """Test that a proper warning is raised when all rewards are None.""" @@ -1054,7 +1054,7 @@ def always_none_reward_func(completions, **kwargs): trainer.train() expected_warning = "All reward functions returned None for the following kwargs:" - self.assertIn(expected_warning, cm.output[0]) + assert expected_warning in cm.output[0] def test_training_num_generations_larger_than_batch_size(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1079,12 +1079,12 @@ def test_training_num_generations_larger_than_batch_size(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_delta_clipping(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1109,12 +1109,12 @@ def test_training_delta_clipping(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_dataloader_workers(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1139,12 +1139,12 @@ def test_training_multiple_dataloader_workers(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_generation_kwargs(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1169,12 +1169,12 @@ def test_training_with_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_with_reward_func_accessing_trainer_state(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1246,7 +1246,7 @@ def test_prepare_input_called_with_correct_data(self): with patch.object(GRPOTrainer, "training_step", wraps=trainer.training_step) as mock_prepare: trainer.train() # 3 epochs * 2 iterations * 2 generation batches to cover the dataset * 4 steps_per_generation - self.assertEqual(mock_prepare.call_count, 48) + assert mock_prepare.call_count == 48 for i in range(0, 8): # Generation batch repeated 8 times (steps_per_generation*num_iterations) assert mock_prepare.call_args_list[i].args[1] == expected_first_generation_batch for i in range(8, 16): @@ -1289,7 +1289,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1305,7 +1305,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_beta_non_zero(self): @@ -1335,7 +1335,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1345,7 +1345,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_peft @@ -1380,15 +1380,15 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." 
@require_vision def test_training_vlm_and_importance_sampling(self): @@ -1418,7 +1418,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1428,7 +1428,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_liger_kernel @@ -1460,7 +1460,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1470,7 +1470,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_and_prompt_truncation(self): @@ -1501,7 +1501,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1511,7 +1511,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_vllm @@ -1551,11 +1551,11 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_multi_image(self): @@ -1588,14 +1588,14 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_sequence_importance_sampling(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1621,12 +1621,12 @@ def test_training_sequence_importance_sampling(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_mismatched_reward_processing_classes_length(self): """Test that mismatched length between reward_funcs and reward_processing_classes raises error.""" @@ -1645,7 +1645,7 @@ def test_mismatched_reward_processing_classes_length(self): training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none") - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError) as context: GRPOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", reward_funcs=reward_models, @@ -1654,7 +1654,7 @@ def test_mismatched_reward_processing_classes_length(self): train_dataset=dataset, ) - self.assertIn("must match", str(context.exception)) + assert "must match" in str(context.exception) def test_correct_reward_processing_classes_list(self): """Test that correct list of reward_processing_classes works properly.""" @@ -1685,7 +1685,7 @@ def test_correct_reward_processing_classes_list(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), len(reward_models)) + assert len(trainer.reward_processing_classes) == len(reward_models) def test_single_reward_model_with_single_processing_class(self): """Test that single reward model with single processing class works.""" @@ -1709,8 +1709,8 @@ def test_single_reward_model_with_single_processing_class(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), 1) - self.assertEqual(trainer.reward_processing_classes[0], single_processing_class) + assert len(trainer.reward_processing_classes) == 1 + assert trainer.reward_processing_classes[0] == single_processing_class @pytest.mark.low_priority @@ -1731,12 +1731,12 @@ def test_add(self): self.replay_buffer.add(scores, data) # Check if the buffer contains the correct number of elements - self.assertEqual(len(self.replay_buffer.heap), 5) + assert len(self.replay_buffer.heap) == 5 # Check if the buffer maintains the min-heap property heap_scores = [item[0] for item in self.replay_buffer.heap] - self.assertEqual(heap_scores[0], min(heap_scores)) - self.assertEqual(heap_scores[0], 0.3) + assert heap_scores[0] == min(heap_scores) + assert heap_scores[0] == 0.3 def test_add_more_than_maxlen(self): # Add elements to the replay buffer @@ -1753,12 +1753,12 @@ def test_add_more_than_maxlen(self): self.replay_buffer.add(scores, data) # Check if the buffer contains the correct number of elements - self.assertEqual(len(self.replay_buffer.heap), 5) + assert len(self.replay_buffer.heap) == 5 # Check if the buffer maintains the min-heap property heap_scores = [item[0] for item in self.replay_buffer.heap] - self.assertEqual(heap_scores[0], min(heap_scores)) - self.assertEqual(heap_scores[0], 0.5) # 0.3 and 0.4 should be removed + assert heap_scores[0] == min(heap_scores) + assert heap_scores[0] == 0.5 # 0.3 and 0.4 should be removed def test_sample(self): # Add elements to the replay buffer 
@@ -1776,9 +1776,9 @@ def test_sample(self):
         sampled = self.replay_buffer.sample(num_samples=3)
 
         # Check if the sampled elements are from the buffer
-        self.assertEqual(len(sampled), 3)
+        assert len(sampled) == 3
         for item in sampled:
-            self.assertIn(item, [entry[1] for entry in self.replay_buffer.heap])
+            assert item in [entry[1] for entry in self.replay_buffer.heap]
 
 
 @pytest.mark.low_priority
@@ -1841,12 +1841,12 @@ def test_update_with_replay_buffer_no_variance(self):
 
         outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
 
-        self.assertIsNotNone(outputs)
-        self.assertIn("pixel_values", outputs)
-        self.assertIn("old_per_token_logps", outputs)
-        self.assertEqual(len(self.trainer.replay_buffer.heap), 2)
+        assert outputs is not None
+        assert "pixel_values" in outputs
+        assert "old_per_token_logps" in outputs
+        assert len(self.trainer.replay_buffer.heap) == 2
         for pid in outputs["prompt_ids"]:
-            self.assertNotIn(pid.tolist(), original_prompt_ids.tolist())
+            assert pid.tolist() not in original_prompt_ids.tolist()
 
     def test_update_with_replay_buffer_with_variance(self):
         self._prepopulate_buffer()
@@ -1855,8 +1855,8 @@ def test_update_with_replay_buffer_with_variance(self):
 
         sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
 
-        self.assertEqual(len(self.trainer.replay_buffer.heap), 4)  # grew
-        self.assertIsNone(sampled)
+        assert len(self.trainer.replay_buffer.heap) == 4  # grew
+        assert sampled is None
 
     def test_update_with_mixed_variance(self):
         self._prepopulate_buffer()
@@ -1866,16 +1866,16 @@ def test_update_with_mixed_variance(self):
 
         outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
 
-        self.assertEqual(len(self.trainer.replay_buffer.heap), 3)  # grew by 1
+        assert len(self.trainer.replay_buffer.heap) == 3  # grew by 1
 
         output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist()
         buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap]
 
         found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids)
         found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids)
-        self.assertTrue(found_from_buffer)
-        self.assertTrue(found_from_original)
-        self.assertNotIn([[1, 2], [3, 4]], output_prompt_ids)  # excluded no-variance group
+        assert found_from_buffer
+        assert found_from_original
+        assert [[1, 2], [3, 4]] not in output_prompt_ids  # excluded no-variance group
 
     def test_update_with_inputs_different_seq_len(self):
         """
@@ -1910,8 +1910,8 @@ def test_update_with_inputs_different_seq_len(self):
         outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
 
         # Seq length of current batch should be preserved
-        self.assertEqual(outputs_after_sampling["prompt_ids"].shape[-1], 3)
-        self.assertEqual(len(self.trainer.replay_buffer.heap), 3)
+        assert outputs_after_sampling["prompt_ids"].shape[-1] == 3
+        assert len(self.trainer.replay_buffer.heap) == 3
 
         output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist()
         buffered_prompt_completion_ids = [
@@ -1921,24 +1921,20 @@ def test_update_with_inputs_different_seq_len(self):
         buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)
 
         # Check for new entry with seq len 3 in buffer
-        self.assertIn([[3, 4, 5], [3, 4, 5]], buffered_prompt_ids)  # excluded no-variance group
-        self.assertIn(
-            [[1013, 1014, pad_token_id], [1015, 1016, 1017]], buffered_completion_ids
-        )  # excluded no-variance group
+        assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids  # excluded no-variance group
+        assert [[1013, 1014, pad_token_id], [1015, 1016, 1017]] in buffered_completion_ids  # excluded no-variance group
 
         # Check that sampled outputs contain one group with prompt_ids starting with a pad token
-        self.assertTrue(
-            [
+        assert [
                 [pad_token_id, 101, 102],
                 [pad_token_id, 102, 103],
-            ]
-            in output_prompt_ids
+            ] \
+            in output_prompt_ids \
             or [
                 [pad_token_id, 104, 105],
                 [pad_token_id, 106, 107],
-            ]
+            ] \
             in output_prompt_ids
-        )
 
 
 @pytest.mark.low_priority
@@ -1973,12 +1969,12 @@ def custom_reward_func(completions, **kwargs):
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the params have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
 
 class GSPOTokenTrainerTester(TrlTestCase):
@@ -2006,12 +2002,12 @@ def test_training(self):
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the params have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
 
 if __name__ == "__main__":
diff --git a/tests/test_judges.py b/tests/test_judges.py
index cce8f961a5f..9238a7f5cb2 100644
--- a/tests/test_judges.py
+++ b/tests/test_judges.py
@@ -35,17 +35,17 @@ def test_all_true_judge(self):
         judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
         prompts, completions = self._get_prompts_and_single_completions()
         judgements = judge.judge(prompts=prompts, completions=completions)
-        self.assertEqual(len(judgements), 2)
-        self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements))
+        assert len(judgements) == 2
+        assert all(judgement in {0, 1, -1} for judgement in judgements)
 
     @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
     def test_hugging_face_judge(self):
         judge = HfPairwiseJudge()
         prompts, completions = self._get_prompts_and_pairwise_completions()
         ranks = judge.judge(prompts=prompts, completions=completions)
-        self.assertEqual(len(ranks), 2)
-        self.assertTrue(all(isinstance(rank, int) for rank in ranks))
-        self.assertEqual(ranks, [0, 1])
+        assert len(ranks) == 2
+        assert all(isinstance(rank, int) for rank in ranks)
+        assert ranks == [0, 1]
 
     def load_pair_rm_judge(self):
         # When using concurrent tests, PairRM may fail to load the model while another job is still downloading.
@@ -62,15 +62,15 @@ def test_pair_rm_judge(self):
         judge = self.load_pair_rm_judge()
         prompts, completions = self._get_prompts_and_pairwise_completions()
         ranks = judge.judge(prompts=prompts, completions=completions)
-        self.assertEqual(len(ranks), 2)
-        self.assertTrue(all(isinstance(rank, int) for rank in ranks))
-        self.assertEqual(ranks, [0, 1])
+        assert len(ranks) == 2
+        assert all(isinstance(rank, int) for rank in ranks)
+        assert ranks == [0, 1]
 
     @require_llm_blender
     def test_pair_rm_judge_return_scores(self):
         judge = self.load_pair_rm_judge()
         prompts, completions = self._get_prompts_and_pairwise_completions()
         probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
-        self.assertEqual(len(probs), 2)
-        self.assertTrue(all(isinstance(prob, float) for prob in probs))
-        self.assertTrue(all(0 <= prob <= 1 for prob in probs))
+        assert len(probs) == 2
+        assert all(isinstance(prob, float) for prob in probs)
+        assert all(0 <= prob <= 1 for prob in probs)
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index 21b425fec05..ac68a00d9a7 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -23,6 +23,7 @@
 from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
 
 from .testing_utils import TrlTestCase, require_no_wandb
+import pytest
 
 
 class KTOTrainerTester(TrlTestCase):
@@ -91,13 +92,13 @@ def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_datas
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the parameters have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             if param.sum() != 0:  # ignore 0 biases
-                self.assertFalse(torch.equal(param, new_param))
+                assert not torch.equal(param, new_param)
 
     def test_kto_trainer_with_ref_model_is_model(self):
         training_args = KTOConfig(
@@ -109,7 +110,7 @@ def test_kto_trainer_with_ref_model_is_model(self):
 
         dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")
 
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             KTOTrainer(
                 model=self.model,
                 ref_model=self.model,  # ref_model can't be the same as model
@@ -149,13 +150,13 @@ def test_tokenize_and_process_tokens(self):
             batched=True,
             batch_size=2,
         )
-        self.assertListEqual(tokenized_dataset["prompt"][:], train_dataset["prompt"][:])
-        self.assertListEqual(tokenized_dataset["completion"][:], train_dataset["completion"][:])
-        self.assertListEqual(tokenized_dataset["label"][:], train_dataset["label"][:])
-        self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
-        self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
-        self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
-        self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])
+        assert tokenized_dataset["prompt"][:] == train_dataset["prompt"][:]
+        assert tokenized_dataset["completion"][:] == train_dataset["completion"][:]
+        assert tokenized_dataset["label"][:] == train_dataset["label"][:]
+        assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091]
+        assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1]
+        assert tokenized_dataset["answer_input_ids"][0] == [27261, 13]
+        assert tokenized_dataset["answer_attention_mask"][0] == [1, 1]
 
         # Test corruption of (prompt, completion) pairs for KL dataset
         for batch_size in [2, 3]:
@@ -166,18 +167,12 @@ def test_tokenize_and_process_tokens(self):
             # the last batch remains unaltered. This is a rare scenario that does not impact the training
             # process, so we exclude it from testing by iterating only up to len - 1.
             for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1):
-                self.assertListEqual(
-                    tokenized_dataset["prompt_input_ids"][i],
-                    tokenized_kl_dataset["prompt_input_ids"][i],
-                )
-                self.assertListEqual(
-                    tokenized_dataset["prompt_attention_mask"][i],
-                    tokenized_kl_dataset["prompt_attention_mask"][i],
-                )
-                self.assertNotEqual(
-                    tokenized_dataset["answer_input_ids"][i],
-                    tokenized_kl_dataset["answer_input_ids"][i],
-                )
+                assert tokenized_dataset["prompt_input_ids"][i] == \
+                    tokenized_kl_dataset["prompt_input_ids"][i]
+                assert tokenized_dataset["prompt_attention_mask"][i] == \
+                    tokenized_kl_dataset["prompt_attention_mask"][i]
+                assert tokenized_dataset["answer_input_ids"][i] != \
+                    tokenized_kl_dataset["answer_input_ids"][i]
 
         fn_kwargs = {
             "prefix": "",
@@ -189,14 +184,14 @@ def test_tokenize_and_process_tokens(self):
             "max_prompt_length": trainer.max_prompt_length,
         }
         processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
-        self.assertListEqual(processed_dataset["prompt"][:], train_dataset["prompt"][:])
-        self.assertListEqual(processed_dataset["completion"][:], train_dataset["completion"][:])
-        self.assertListEqual(processed_dataset["label"][:], train_dataset["label"][:])
-        self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
-        self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
-        self.assertListEqual(processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645])
-        self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
-        self.assertListEqual(processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645])
+        assert processed_dataset["prompt"][:] == train_dataset["prompt"][:]
+        assert processed_dataset["completion"][:] == train_dataset["completion"][:]
+        assert processed_dataset["label"][:] == train_dataset["label"][:]
+        assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091]
+        assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1]
+        assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645]
+        assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1]
+        assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645]
 
     def test_kto_trainer_without_providing_ref_model(self):
         training_args = KTOConfig(
@@ -226,13 +221,13 @@ def test_kto_trainer_without_providing_ref_model(self):
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the parameters have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             if param.sum() != 0:  # ignore 0 biases
-                self.assertFalse(torch.equal(param, new_param))
+                assert not torch.equal(param, new_param)
 
     @require_peft
     def test_kto_trainer_without_providing_ref_model_with_lora(self):
@@ -274,14 +269,14 @@ def test_kto_trainer_without_providing_ref_model_with_lora(self):
 
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None
 
         # Check that the parameters have changed
         for n,
param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @require_no_wandb def test_kto_trainer_generate_during_eval_no_wandb(self): @@ -300,11 +295,8 @@ def test_kto_trainer_generate_during_eval_no_wandb(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") - with self.assertRaisesRegex( - ValueError, - expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve.", - ): + with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve."): KTOTrainer( model=self.model, ref_model=None, @@ -365,7 +357,7 @@ def test_kto_lora_save(self): try: AutoModelForCausalLM.from_pretrained(self.tmp_dir) except OSError: - self.fail("Loading the saved peft adapter failed") + pytest.fail("Loading the saved peft adapter failed") @require_liger_kernel def test_kto_trainer_with_liger(self): @@ -389,14 +381,14 @@ def test_kto_trainer_with_liger(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) # check the params have changed - ignore 0 biases if param.sum() != 0: - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_compute_metrics(self): model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -432,4 +424,4 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py index ae6f5010821..65553b79b77 100644 --- a/tests/test_modeling_geometric_mixture_wrapper.py +++ b/tests/test_modeling_geometric_mixture_wrapper.py @@ -40,9 +40,9 @@ def test_forward(self): output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) - self.assertIsNotNone(output) - self.assertTrue(hasattr(output, "logits")) - self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size)) + assert output is not None + assert hasattr(output, "logits") + assert output.logits.shape == (1, 5, self.model.config.vocab_size) def test_mixture_coefficient(self): input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device) @@ -57,7 +57,7 @@ def test_mixture_coefficient(self): self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1 ) - self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5)) + assert torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5) def test_prepare_inputs_for_generation(self): input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device) @@ -65,6 +65,6 @@ def test_prepare_inputs_for_generation(self): inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True) - self.assertIn("input_ids", inputs) - self.assertIn("attention_mask", inputs) - self.assertFalse(inputs.get("use_cache", False)) + assert "input_ids" in 
inputs + assert "attention_mask" in inputs + assert not inputs.get("use_cache", False) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index b0a75211175..355dcb2a1c8 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -65,7 +65,7 @@ def test_value_head(self): """ for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) - self.assertTrue(hasattr(model, "v_head")) + assert hasattr(model, "v_head") def test_value_head_shape(self): r""" @@ -73,7 +73,7 @@ def test_value_head_shape(self): """ for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) - self.assertEqual(model.v_head.summary.weight.shape[0], 1) + assert model.v_head.summary.weight.shape[0] == 1 def test_value_head_init_random(self): r""" @@ -82,9 +82,7 @@ def test_value_head_init_random(self): """ for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) - self.assertFalse( - torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)) - ) + assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)) def test_value_head_not_str(self): r""" @@ -94,7 +92,7 @@ def test_value_head_not_str(self): for model_name in self.all_model_names: pretrained_model = self.transformers_model_class.from_pretrained(model_name) model = self.trl_model_class.from_pretrained(pretrained_model) - self.assertTrue(hasattr(model, "v_head")) + assert hasattr(model, "v_head") def test_from_save_trl(self): """ @@ -110,7 +108,7 @@ def test_from_save_trl(self): # Check if the weights are the same for key in model_from_save.state_dict(): - self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]) def test_from_save_trl_sharded(self): """ @@ -125,7 +123,7 @@ def test_from_save_trl_sharded(self): # Check if the weights are the same for key in model_from_save.state_dict(): - self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]) def test_from_save_transformers_sharded(self): """ @@ -143,11 +141,9 @@ def test_from_save_transformers_sharded(self): # Check if the weights are the same for key in transformers_model.state_dict(): - self.assertTrue( - torch.allclose( + assert torch.allclose( transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] ) - ) def test_from_save_transformers(self): """ @@ -166,27 +162,21 @@ def test_from_save_transformers(self): # Check if the weights are the same for key in transformers_model.state_dict(): - self.assertTrue( - torch.allclose( + assert torch.allclose( transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] ) - ) # Check if the trl model has the same keys as the transformers model # except the v_head for key in trl_model.state_dict(): if "v_head" not in key: - self.assertIn(key, transformers_model.state_dict()) + assert key in transformers_model.state_dict() # check if the weights are the same - self.assertTrue( - torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]) - ) + assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]) # check if they have the same modules - self.assertEqual( - 
set(transformers_model_from_save.state_dict().keys()), - set(transformers_model.state_dict().keys()), - ) + assert set(transformers_model_from_save.state_dict().keys()) == \ + set(transformers_model.state_dict().keys()) class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): @@ -217,7 +207,7 @@ def test_inference(self): # Check if the outputs are of the right size - here # we always output 3 values - logits, loss, and value states - self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + assert len(outputs) == EXPECTED_OUTPUT_SIZE def test_dropout_config(self): r""" @@ -229,7 +219,7 @@ def test_dropout_config(self): model = self.trl_model_class.from_pretrained(pretrained_model) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob def test_dropout_kwargs(self): r""" @@ -241,12 +231,12 @@ def test_dropout_kwargs(self): model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 @parameterized.expand(ALL_CAUSAL_LM_MODELS) def test_generate(self, model_name): @@ -271,14 +261,12 @@ def test_transformers_bf16_kwargs(self): lm_head_namings = ["lm_head", "embed_out", "output_layer"] - self.assertTrue( - any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), - "Can't test the model because it doesn't have any of the expected lm_head namings", - ) + assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), \ + "Can't test the model because it doesn't have any of the expected lm_head namings" for lm_head_naming in lm_head_namings: if hasattr(trl_model.pretrained_model, lm_head_naming): - self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16) + assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16 dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device) @@ -296,13 +284,11 @@ def test_push_to_hub(self): model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo") # check all keys - self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() for name, param in model.state_dict().items(): - self.assertTrue( - torch.allclose(param, model_from_pretrained.state_dict()[name]), - f"Parameter {name} is not the same after push_to_hub and from_pretrained", - ) + assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \ + f"Parameter {name} is not the same after push_to_hub and from_pretrained" class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): @@ -334,7 +320,7 @@ def test_inference(self): # Check if the outputs are of the right size - here # we always output 3 values - logits, loss, and value states - self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + assert len(outputs) == EXPECTED_OUTPUT_SIZE def test_dropout_config(self): r""" @@ -346,7 +332,7 @@ def 
test_dropout_config(self): model = self.trl_model_class.from_pretrained(pretrained_model) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob def test_dropout_kwargs(self): r""" @@ -358,12 +344,12 @@ def test_dropout_kwargs(self): model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 @parameterized.expand(ALL_SEQ2SEQ_MODELS) def test_generate(self, model_name): @@ -389,13 +375,11 @@ def test_push_to_hub(self): model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo") # check all keys - self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() for name, param in model.state_dict().items(): - self.assertTrue( - torch.allclose(param, model_from_pretrained.state_dict()[name]), - f"Parameter {name} is not the same after push_to_hub and from_pretrained", - ) + assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \ + f"Parameter {name} is not the same after push_to_hub and from_pretrained" def test_transformers_bf16_kwargs(self): r""" @@ -408,13 +392,11 @@ def test_transformers_bf16_kwargs(self): lm_head_namings = self.trl_model_class.lm_head_namings - self.assertTrue( - any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) - ) + assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) for lm_head_naming in lm_head_namings: if hasattr(trl_model.pretrained_model, lm_head_naming): - self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16) + assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16 dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device) @@ -453,16 +435,16 @@ def test_independent_reference(self): last_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() # before optimization ref and model are identical - self.assertTrue((first_layer_before == first_ref_layer_before).all()) - self.assertTrue((last_layer_before == last_ref_layer_before).all()) + assert (first_layer_before == first_ref_layer_before).all() + assert (last_layer_before == last_ref_layer_before).all() # ref model stays identical after optimization - self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) - self.assertTrue((last_ref_layer_before == last_ref_layer_after).all()) + assert (first_ref_layer_before == first_ref_layer_after).all() + assert (last_ref_layer_before == last_ref_layer_after).all() # optimized model changes - self.assertFalse((first_layer_before == first_layer_after).all()) - self.assertFalse((last_layer_before == last_layer_after).all()) + assert not (first_layer_before == first_layer_after).all() + assert not (last_layer_before == last_layer_after).all() def test_shared_layers(self): layer_0 = self.layer_format.format(layer=0) @@ -487,15 +469,15 @@ def test_shared_layers(self): 
second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() # before optimization ref and model are identical - self.assertTrue((first_layer_before == first_ref_layer_before).all()) - self.assertTrue((second_layer_before == second_ref_layer_before).all()) + assert (first_layer_before == first_ref_layer_before).all() + assert (second_layer_before == second_ref_layer_before).all() # ref model stays identical after optimization - self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) - self.assertTrue((second_ref_layer_before == second_ref_layer_after).all()) + assert (first_ref_layer_before == first_ref_layer_after).all() + assert (second_ref_layer_before == second_ref_layer_after).all() # first layer of optimized model stays the same - self.assertTrue((first_layer_before == first_layer_after).all()) + assert (first_layer_before == first_layer_after).all() # other layers in optimized model change - self.assertFalse((second_layer_before == second_layer_after).all()) + assert not (second_layer_before == second_layer_after).all() diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py index 4550c35e1d1..25caaff3d38 100644 --- a/tests/test_nash_md_trainer.py +++ b/tests/test_nash_md_trainer.py @@ -65,7 +65,7 @@ def test_nash_md_trainer_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft(self): @@ -93,7 +93,7 @@ def test_training_with_peft(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_and_ref_model(self): @@ -122,7 +122,7 @@ def test_training_with_peft_and_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_model_and_peft_config(self): @@ -153,7 +153,7 @@ def test_training_with_peft_model_and_peft_config(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_pre_pefted_model_implicit_ref_with_reward_model(self): @@ -184,7 +184,7 @@ def test_training_pre_pefted_model_implicit_ref_with_reward_model(self): trainer.train() - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @require_llm_blender @@ -215,4 +215,4 @@ def test_nash_md_trainer_judge_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 336db8a089f..b8eb83a0f04 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -73,7 +73,7 @@ def test_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_training_model_str(self): training_args = OnlineDPOConfig( @@ 
-98,7 +98,7 @@ def test_training_model_str(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_training_with_ref_model(self): training_args = OnlineDPOConfig( @@ -124,7 +124,7 @@ def test_training_with_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_ref_model_is_model(self): training_args = OnlineDPOConfig( @@ -136,7 +136,7 @@ def test_ref_model_is_model(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): OnlineDPOTrainer( model=self.model, ref_model=self.model, # ref_model can't be the same as model @@ -174,7 +174,7 @@ def test_training_with_peft(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_and_ref_model(self): @@ -204,7 +204,7 @@ def test_training_with_peft_and_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_model_and_peft_config(self): @@ -236,7 +236,7 @@ def test_training_with_peft_model_and_peft_config(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_llm_blender @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @@ -262,7 +262,7 @@ def test_training_with_judge(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @require_torch_accelerator @@ -293,7 +293,7 @@ def test_training_with_vllm(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_vllm def test_training_with_vllm_colocate(self): @@ -330,57 +330,57 @@ def test_training_with_vllm_colocate(self): ) # Verify vLLM setup - self.assertTrue(trainer.use_vllm) - self.assertEqual(trainer.vllm_mode, "colocate") - self.assertIsNotNone(trainer.llm) + assert trainer.use_vllm + assert trainer.vllm_mode == "colocate" + assert trainer.llm is not None # self.assertIsNone(trainer.vllm_client) # self.assertEqual(trainer.vllm_gpu_memory_utilization, 0.2) # Verify generation parameters - self.assertEqual(trainer.temperature, 0.9) - self.assertEqual(trainer.top_p, 0.95) - self.assertEqual(trainer.top_k, 50) - self.assertEqual(trainer.repetition_penalty, 1.1) + assert trainer.temperature == 0.9 + assert trainer.top_p == 0.95 + assert trainer.top_k == 50 + assert trainer.repetition_penalty == 1.1 # Verify generation config - self.assertIsNotNone(trainer.generation_config) - self.assertEqual(trainer.generation_config.temperature, 0.9) - self.assertEqual(trainer.generation_config.top_p, 0.95) - self.assertEqual(trainer.generation_config.top_k, 50) - 
self.assertEqual(trainer.generation_config.repetition_penalty, 1.1) - self.assertEqual(trainer.generation_config.max_tokens, 32) + assert trainer.generation_config is not None + assert trainer.generation_config.temperature == 0.9 + assert trainer.generation_config.top_p == 0.95 + assert trainer.generation_config.top_k == 50 + assert trainer.generation_config.repetition_penalty == 1.1 + assert trainer.generation_config.max_tokens == 32 trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_vllm_config_validation(self): """Test vLLM configuration validation""" # Test valid vllm_mode values config = OnlineDPOConfig(use_vllm=True, vllm_mode="server") - self.assertEqual(config.vllm_mode, "server") + assert config.vllm_mode == "server" config = OnlineDPOConfig(use_vllm=True, vllm_mode="colocate") - self.assertEqual(config.vllm_mode, "colocate") + assert config.vllm_mode == "colocate" # Test default values config = OnlineDPOConfig() - self.assertEqual(config.vllm_mode, "server") - self.assertIsNone(config.vllm_server_base_url) - self.assertEqual(config.vllm_server_host, "0.0.0.0") - self.assertEqual(config.vllm_server_port, 8000) - self.assertEqual(config.vllm_server_timeout, 240.0) - self.assertEqual(config.vllm_gpu_memory_utilization, 0.55) + assert config.vllm_mode == "server" + assert config.vllm_server_base_url is None + assert config.vllm_server_host == "0.0.0.0" + assert config.vllm_server_port == 8000 + assert config.vllm_server_timeout == 240.0 + assert config.vllm_gpu_memory_utilization == 0.55 # Test generation parameters - self.assertEqual(config.top_p, 1.0) - self.assertIsNone(config.top_k) - self.assertIsNone(config.min_p) - self.assertEqual(config.repetition_penalty, 1.0) - self.assertFalse(config.use_transformers_paged) - self.assertIsNone(config.cache_implementation) - self.assertIsNone(config.generation_kwargs) + assert config.top_p == 1.0 + assert config.top_k is None + assert config.min_p is None + assert config.repetition_penalty == 1.0 + assert not config.use_transformers_paged + assert config.cache_implementation is None + assert config.generation_kwargs is None def test_generation_config_setup(self): """Test that generation configuration is properly set up for both vLLM and transformers""" @@ -407,17 +407,17 @@ def test_generation_config_setup(self): ) # Verify transformers generation config - self.assertFalse(trainer.use_vllm) + assert not trainer.use_vllm # When not using vLLM, these attributes should not be set - self.assertFalse(hasattr(trainer, "llm") and trainer.llm is not None) - self.assertFalse(hasattr(trainer, "vllm_client") and trainer.vllm_client is not None) - self.assertIsNotNone(trainer.generation_config) - self.assertEqual(trainer.generation_config.temperature, 0.8) - self.assertEqual(trainer.generation_config.top_p, 0.9) - self.assertEqual(trainer.generation_config.top_k, 40) - self.assertEqual(trainer.generation_config.repetition_penalty, 1.2) - self.assertEqual(trainer.generation_config.max_new_tokens, 64) - self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs + assert not (hasattr(trainer, "llm") and trainer.llm is not None) + assert not (hasattr(trainer, "vllm_client") and trainer.vllm_client is not None) + assert trainer.generation_config is not None + assert trainer.generation_config.temperature == 0.8 + assert trainer.generation_config.top_p == 0.9 + assert trainer.generation_config.top_k == 40 + 
assert trainer.generation_config.repetition_penalty == 1.2 + assert trainer.generation_config.max_new_tokens == 64 + assert not trainer.generation_config.do_sample # From generation_kwargs @require_torch_accelerator @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @@ -447,7 +447,7 @@ def test_training_with_transformers_paged(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) def test_training_with_reward_funcs(self, config_name): @@ -475,11 +475,11 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs): ) trainer.train() - self.assertIn("train_loss", trainer.state.log_history[-1]) - self.assertEqual(len(trainer.reward_funcs), 2) - self.assertIsNotNone(trainer.reward_weights) - self.assertAlmostEqual(trainer.reward_weights[0].item(), 0.7, places=5) - self.assertAlmostEqual(trainer.reward_weights[1].item(), 0.3, places=5) + assert "train_loss" in trainer.state.log_history[-1] + assert len(trainer.reward_funcs) == 2 + assert trainer.reward_weights is not None + assert round(abs(trainer.reward_weights[0].item()-0.7), 5) == 0 + assert round(abs(trainer.reward_weights[1].item()-0.3), 5) == 0 @require_vision @@ -531,4 +531,4 @@ def test_online_dpo_vlm_trainer(self, model_id): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 5898ac8d7dd..2f444eb4dfc 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -82,13 +82,13 @@ def test_orpo_trainer(self, name, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @parameterized.expand( [ @@ -137,14 +137,14 @@ def test_orpo_trainer_with_lora(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_compute_metrics(self): model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -178,4 +178,4 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index 0543ee31c3c..3b68be8b817 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -63,7 +63,7 @@ def test_peft_requires_grad(self): model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) # Check that the value head has requires_grad=True - self.assertTrue(model.v_head.summary.weight.requires_grad) + assert 
model.v_head.summary.weight.requires_grad def test_check_peft_model_nb_trainable_params(self): r""" @@ -76,12 +76,12 @@ def test_check_peft_model_nb_trainable_params(self): # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 # Check that the number of trainable param for the non-peft model is correct non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 2428641) + assert nb_trainable_params == 2428641 def test_create_peft_model_from_config(self): r""" @@ -92,13 +92,13 @@ def test_create_peft_model_from_config(self): ) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 @require_torch_gpu_if_bnb_not_multi_backend_enabled def test_create_bnb_peft_model_from_config(self): @@ -112,8 +112,8 @@ def test_create_bnb_peft_model_from_config(self): ) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) - self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + assert nb_trainable_params == 905 + assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) causal_lm_model = AutoModelForCausalLM.from_pretrained( self.causal_lm_model_id, load_in_8bit=True, device_map="auto" @@ -121,8 +121,8 @@ def test_create_bnb_peft_model_from_config(self): trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) - self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + assert nb_trainable_params == 905 + assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) def test_save_pretrained_peft(self): r""" @@ -136,31 +136,23 @@ def test_save_pretrained_peft(self): model.save_pretrained(self.tmp_dir) # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory - self.assertTrue( - os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), - f"{self.tmp_dir}/adapter_model.safetensors does not exist", - ) - self.assertTrue( - os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist" - ) + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \ + f"{self.tmp_dir}/adapter_model.safetensors does not exist" + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), 
f"{self.tmp_dir}/adapter_config.json does not exist" # check also for `pytorch_model.bin` and make sure it only contains `v_head` weights - self.assertTrue( - os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist" - ) + assert os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist" # check that only keys that starts with `v_head` are in the dict maybe_v_head = torch.load(f"{self.tmp_dir}/pytorch_model.bin", weights_only=True) - self.assertTrue( - all(k.startswith("v_head") for k in maybe_v_head.keys()), - f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`", - ) + assert all(k.startswith("v_head") for k in maybe_v_head.keys()), \ + f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`" model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) # check all the weights are the same for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): - self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}") + assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}" def test_load_pretrained_peft(self): r""" @@ -175,18 +167,14 @@ def test_load_pretrained_peft(self): model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory - self.assertTrue( - os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), - f"{self.tmp_dir}/adapter_model.safetensors does not exist", - ) - self.assertTrue( - os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist" - ) + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \ + f"{self.tmp_dir}/adapter_model.safetensors does not exist" + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist" # check all the weights are the same for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]: - self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}") + assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}" def test_continue_training_peft_model(self): r""" @@ -200,4 +188,4 @@ def test_continue_training_peft_model(self): model = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir, is_trainable=True) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index f8b95a8e5ff..13354c36b29 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -107,8 +107,8 @@ def test_basic_training(self): policy_weights_updated = True break - self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") - self.assertTrue(policy_weights_updated, "Policy weights were not updated during training") + assert critic_weights_updated, "Critic weights were not updated during training" + assert policy_weights_updated, "Policy weights were not updated during training" @require_peft def test_peft_training(self): @@ -171,5 +171,5 @@ def test_peft_training(self): policy_weights_updated = True break - self.assertTrue(critic_weights_updated, "Critic weights were not updated 
during training") - self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training") + assert critic_weights_updated, "Critic weights were not updated during training" + assert policy_weights_updated, "Policy LoRA weights were not updated during training" diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py index e26428c5203..76398519836 100644 --- a/tests/test_prm_trainer.py +++ b/tests/test_prm_trainer.py @@ -75,13 +75,11 @@ def test_tokenize_row_no_truncation(self): is_eval=False, ) - self.assertEqual( - result, + assert result == \ { "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], - }, - ) + } def test_tokenize_row_train_on_last_step_only(self): # Define the input features @@ -102,13 +100,11 @@ def test_tokenize_row_train_on_last_step_only(self): is_eval=False, ) - self.assertEqual( - result, + assert result == \ { "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], - }, - ) + } def test_tokenize_row_prompt_truncation(self): # Define the input features @@ -130,13 +126,11 @@ def test_tokenize_row_prompt_truncation(self): is_eval=False, ) - self.assertEqual( - result, + assert result == \ { "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], - }, - ) + } def test_tokenize_row_completion_truncation(self): # Define the input features @@ -158,13 +152,11 @@ def test_tokenize_row_completion_truncation(self): is_eval=False, ) - self.assertEqual( - result, + assert result == \ { "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], - }, - ) + } def test_tokenize_row_prompt_completion_truncation(self): # Define the input features @@ -186,13 +178,11 @@ def test_tokenize_row_prompt_completion_truncation(self): is_eval=False, ) - self.assertEqual( - result, + assert result == \ { "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], - }, - ) + } def test_tokenize_row_multi_token_separator(self): # Define the input features @@ -214,13 +204,11 @@ def test_tokenize_row_multi_token_separator(self): is_eval=False, ) - self.assertEqual( - result, + assert result == \ { "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], - }, - ) + } class PRMTrainerTester(TrlTestCase): @@ -244,12 +232,12 @@ def test_train_full(self, train_on_last_step_only): previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) def test_train_full_pretokenized(self): dummy_dataset = Dataset.from_dict( @@ -297,12 +285,12 @@ def test_train_full_pretokenized(self): previous_trainable_params = {n: 
param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @require_peft def test_train_lora(self): @@ -337,17 +325,17 @@ def test_train_lora(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + assert trainer.state.log_history[(-1)]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12) # Check that the non trainable parameters have not changed for n, param in previous_non_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + assert torch.allclose(param, new_param, atol=1e-12, rtol=1e-12) def test_tags(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") @@ -355,4 +343,4 @@ def test_tags(self): trainer = PRMTrainer( model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset ) - self.assertEqual(trainer.model.model_tags, trainer._tag_names) + assert trainer.model.model_tags == trainer._tag_names diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index b4d53e16941..ba8300c78ff 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -131,12 +131,12 @@ def test_train(self, model_id): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @parameterized.expand( [ @@ -165,12 +165,12 @@ def test_train_dataset_types(self, config_name): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model(self): # Instantiate the model @@ -192,12 +192,12 @@ def test_train_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter 
{n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_from_causal_lm(self): # Get the dataset @@ -216,12 +216,12 @@ def test_train_from_causal_lm(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model_dtype(self): # Get the dataset @@ -247,7 +247,7 @@ def test_train_model_dtype(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -257,8 +257,8 @@ def test_train_model_dtype(self): continue new_param = trainer.model.get_parameter(n) # Check the torch dtype - self.assertEqual(new_param.dtype, torch.float16) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert new_param.dtype == torch.float16 + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config(self): @@ -287,15 +287,15 @@ def test_train_dense_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config(self): @@ -324,15 +324,15 @@ def test_train_moe_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_peft_model(self): @@ -361,15 +361,15 @@ def 
test_train_peft_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config_and_gradient_checkpointing(self): @@ -398,15 +398,15 @@ def test_train_dense_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config_and_gradient_checkpointing(self): @@ -435,15 +435,15 @@ def test_train_moe_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_with_peft_model_and_gradient_checkpointing(self): @@ -462,7 +462,7 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) # Verify model is a PeftModel - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -471,15 +471,15 @@ def 
test_train_with_peft_model_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_pretokenized_data(self): # Get the dataset @@ -507,12 +507,12 @@ def tokenize_example(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_iterable_dataset(self): # Get the dataset @@ -535,12 +535,12 @@ def test_train_with_iterable_dataset(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_chat_template_kwargs(self): # Get the dataset @@ -569,12 +569,12 @@ def test_train_with_chat_template_kwargs(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_model(self): # Get the dataset @@ -596,7 +596,7 @@ def test_train_with_set_chat_template_from_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -606,7 +606,7 @@ def test_train_with_set_chat_template_from_model(self): # this parameter. 
if n == "gpt_neox.final_layer_norm.bias": continue - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_path(self): # Get the dataset @@ -632,7 +632,7 @@ def test_train_with_set_chat_template_from_path(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -642,19 +642,17 @@ def test_train_with_set_chat_template_from_path(self): # this parameter. if n == "gpt_neox.final_layer_norm.bias": continue - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Check that the template saved in the output directory is the same as the one used for training template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" - self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}") + assert template_path.exists(), f"Chat template not found at {template_path}" with open(template_path) as f: template_content = f.read() with open(training_args.chat_template_path) as f: original_template_content = f.read() - self.assertEqual( - template_content, original_template_content, "Chat template content does not match the original" - ) + assert template_content == original_template_content, "Chat template content does not match the original" @unittest.skip("Skipping until we have a dataset with tool calls") def test_train_toolcall_data(self): @@ -676,12 +674,12 @@ def test_train_toolcall_data(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_eval(self): # Get the dataset @@ -700,7 +698,7 @@ def test_train_with_eval(self): trainer.train() # Check that the eval loss is not None - self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + assert trainer.state.log_history[0]["eval_loss"] is not None def test_train_with_multiple_eval_dataset(self): # Get the dataset @@ -718,8 +716,8 @@ def test_train_with_multiple_eval_dataset(self): trainer.train() # Check that the eval losses are not None - self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"]) - self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"]) + assert trainer.state.log_history[-3]["eval_data1_loss"] is not None + assert trainer.state.log_history[-2]["eval_data2_loss"] is not None def test_train_with_gradient_checkpointing(self): # Get the dataset @@ -740,12 +738,12 @@ def test_train_with_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - 
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_tag_added(self): # Get the dataset @@ -758,7 +756,7 @@ def test_tag_added(self): ) for tag in ["reward-trainer", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_tag_added_peft(self): @@ -773,7 +771,7 @@ def test_tag_added_peft(self): ) for tag in ["reward-trainer", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags def test_train_with_margin(self): # Get the dataset @@ -800,12 +798,12 @@ def add_margin(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_center_rewards_coefficient(self): # Get the dataset @@ -826,9 +824,9 @@ def test_train_with_center_rewards_coefficient(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 8b20a0ff7e9..aac6aabca0c 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -31,7 +31,7 @@ def test_valid_format(self): completions = [[{"content": completion}] for completion in completions] expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0] # All should be valid rewards = think_format_reward(completions) - self.assertEqual(rewards, expected_rewards) + assert rewards == expected_rewards def test_invalid_format(self): completions = [ @@ -48,7 +48,7 @@ def test_invalid_format(self): completions = [[{"content": completion}] for completion in completions] expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # All should be invalid rewards = think_format_reward(completions) - self.assertEqual(rewards, expected_rewards) + assert rewards == expected_rewards def test_mixed_format(self): completions = [ @@ -60,7 +60,7 @@ def test_mixed_format(self): completions = [[{"content": completion}] for completion in completions] expected_rewards = [1.0, 1.0, 0.0, 0.0] rewards = think_format_reward(completions) - self.assertEqual(rewards, expected_rewards) + assert rewards == expected_rewards class SoftOverlongPunishmentRewardTester(unittest.TestCase): @@ -70,7 +70,7 @@ def test_soft_overlong_punishment_short_completion(self): reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) completion_ids = [[1] * 50] # 50 <= 80 rewards = reward_fn(completion_ids=completion_ids) - self.assertEqual(rewards, [0]) + assert rewards == [0] def test_soft_overlong_punishment_long_completion(self): """Test soft overlong punishment reward function with a longer than max completion.""" @@ -78,14 +78,14 @@ def 
test_soft_overlong_punishment_long_completion(self): reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) completion_ids = [[1] * 110] rewards = reward_fn(completion_ids) - self.assertEqual(rewards, [-1]) + assert rewards == [-1] def test_soft_overlong_punishment_intermediate_completion(self): """Test soft overlong punishment reward function for intermediate length completion.""" reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) completion_ids = [[1] * 90] # 90 is between 80 and 100 rewards = reward_fn(completion_ids) - self.assertAlmostEqual(rewards[0], -0.5, places=4) + assert round(abs(rewards[0]--0.5), 4) == 0 if __name__ == "__main__": diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index cde52de6047..2ddb51248ce 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -30,6 +30,7 @@ from trl import RLOOConfig, RLOOTrainer from .testing_utils import TrlTestCase, require_vllm +import pytest if is_peft_available(): @@ -69,12 +70,12 @@ def test_training(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_eval(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") @@ -122,12 +123,12 @@ def test_training_multiple_iterations(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_peft def test_training_peft(self): @@ -155,15 +156,15 @@ def test_training_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." 
@require_peft def test_training_peft_with_gradient_checkpointing(self): @@ -197,22 +198,22 @@ def test_training_peft_with_gradient_checkpointing(self): ) # Verify gradient checkpointing is enabled - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Store initial parameters to check which ones change previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that only LoRA parameters have changed, base model parameters remain unchanged for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "lora" in n.lower(): # LoRA parameters should change - self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"LoRA parameter {n} has not changed." else: # Base model parameters should not change - self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.") + assert torch.equal(param, new_param), f"Base parameter {n} has changed." def test_training_different_reward_model(self): # Use a reward model different from the model: different chat template, tokenization, etc. @@ -246,12 +247,12 @@ def test_training_different_reward_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_standard(self): # Test if trainer can handle reward function with standard format @@ -280,12 +281,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_conversational(self): # Test if trainer can handle reward function with conversational format @@ -315,12 +316,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_multiple_reward_funcs(self): # Test that RLOOTrainer can be instantiated with multiple reward functions @@ -353,12 +354,12 @@ def reward_func2(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs_with_None_output(self): """Test that a valid math reward function is processed correctly while the code reward function returns None.""" @@ -397,12 +398,12 @@ def non_applicable_reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs_with_weights(self): """Test that RLOOTrainer can handle multiple reward functions with weights.""" @@ -437,16 +438,16 @@ def reward_func2(completions, **kwargs): trainer.train() # Check that training logs contain both reward metrics - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert "rewards/reward_func1/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func1/std" in trainer.state.log_history[-1] + assert "rewards/reward_func2/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func2/std" in trainer.state.log_history[-1] # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_mixed_reward_funcs(self): # Test if the trainer can handle a mix of reward functions and reward models @@ -475,12 +476,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_reward_func_additional_column(self): # Test if trainer can handle reward function that rely on additional columns in the dataset @@ -513,12 +514,12 @@ def reward_func(completions, some_values, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_sync_ref_model(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -544,12 +545,12 @@ def test_training_with_sync_ref_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_beta_zero(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -573,12 +574,12 @@ def test_training_beta_zero(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @unittest.skip("We should add a mock for the vLLM server.") @require_peft @@ -615,16 +616,16 @@ def test_training_vllm_and_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n and "original_module" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_vllm @unittest.skip("We should add a mock for the vLLM server.") @@ -653,12 +654,12 @@ def test_training_vllm_guided_decoding(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_with_additional_generation_kwargs(self): """Test that training works with additional generation kwargs.""" @@ -688,12 +689,12 @@ def test_training_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm @unittest.skip("We should add a mock for the vLLM server.") @@ -726,12 +727,12 @@ def test_training_vllm_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_normalized_advantages(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -756,12 +757,12 @@ def test_training_with_normalized_advantages(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_clipped_rewards(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -786,12 +787,12 @@ def test_training_with_clipped_rewards(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @patch("transformers.generation.utils.GenerationMixin.generate") def test_training_with_mask_truncated_completions(self, mock_generate): @@ -836,12 +837,12 @@ def fake_generate(input_ids, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_with_mask_truncated_completions_all_masked(self): """ @@ -874,12 +875,12 @@ def test_training_with_mask_truncated_completions_all_masked(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.") + assert torch.equal(param, new_param), f"Parameter {n} has changed." def test_warning_raised_all_rewards_none(self): """Test that a proper warning is raised when all rewards are None.""" @@ -908,7 +909,7 @@ def always_none_reward_func(completions, **kwargs): trainer.train() expected_warning = "All reward functions returned None for the following kwargs:" - self.assertIn(expected_warning, cm.output[0]) + assert expected_warning in cm.output[0] def test_training_num_generations_larger_than_batch_size(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -933,12 +934,12 @@ def test_training_num_generations_larger_than_batch_size(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_dataloader_workers(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -963,12 +964,12 @@ def test_training_multiple_dataloader_workers(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_generation_kwargs(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -993,12 +994,12 @@ def test_training_with_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
def test_training_with_reward_func_accessing_trainer_state(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1070,7 +1071,7 @@ def test_prepare_input_called_with_correct_data(self): with patch.object(RLOOTrainer, "training_step", wraps=trainer.training_step) as mock_prepare: trainer.train() # 3 epochs * 2 iterations * 2 generation batches to cover the dataset * 4 steps_per_generation - self.assertEqual(mock_prepare.call_count, 48) + assert mock_prepare.call_count == 48 for i in range(0, 8): # Generation batch repeated 8 times (steps_per_generation*num_iterations) assert mock_prepare.call_args_list[i].args[1] == expected_first_generation_batch for i in range(8, 16): @@ -1113,7 +1114,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1128,7 +1129,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_beta_non_zero(self): @@ -1158,7 +1159,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1168,7 +1169,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_peft @@ -1203,15 +1204,15 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." 
@require_vision def test_training_vlm_and_prompt_truncation(self): @@ -1242,7 +1243,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1252,7 +1253,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_vllm @@ -1292,11 +1293,11 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_multi_image(self): @@ -1329,11 +1330,11 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_mismatched_reward_processing_classes_length(self): """Test that mismatched length between reward_funcs and reward_processing_classes raises error.""" @@ -1352,7 +1353,7 @@ def test_mismatched_reward_processing_classes_length(self): training_args = RLOOConfig(output_dir=self.tmp_dir, report_to="none") - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError) as context: RLOOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", reward_funcs=reward_models, @@ -1361,7 +1362,7 @@ def test_mismatched_reward_processing_classes_length(self): train_dataset=dataset, ) - self.assertIn("must match", str(context.exception)) + assert "must match" in str(context.exception) def test_correct_reward_processing_classes_list(self): """Test that correct list of reward_processing_classes works properly.""" @@ -1392,7 +1393,7 @@ def test_correct_reward_processing_classes_list(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), len(reward_models)) + assert len(trainer.reward_processing_classes) == len(reward_models) def test_single_reward_model_with_single_processing_class(self): """Test that single reward model with single processing class works.""" @@ -1416,8 +1417,8 @@ def test_single_reward_model_with_single_processing_class(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), 1) - self.assertEqual(trainer.reward_processing_classes[0], single_processing_class) + assert len(trainer.reward_processing_classes) == 1 + assert trainer.reward_processing_classes[0] == single_processing_class if __name__ == "__main__": diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 19d5a1b5d70..2a558d6093f 100644 --- a/tests/test_sft_trainer.py +++ 
b/tests/test_sft_trainer.py @@ -64,7 +64,7 @@ def test_basic_padding(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -79,7 +79,7 @@ def test_completion_mask(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]])) @@ -95,7 +95,7 @@ def test_completion_only_loss_disabled(self): result = collator(examples) # Labels should not be masked when completion_only_loss=False - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -107,7 +107,7 @@ def test_padding_free_mode(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]])) @@ -122,7 +122,7 @@ def test_padding_free_with_completion_mask(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]])) @@ -139,7 +139,7 @@ def test_packing(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]])) @@ -151,7 +151,7 @@ def test_pad_to_multiple_of(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])) 
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])) @@ -163,7 +163,7 @@ def test_pad_to_multiple_of_and_padding_free(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]])) @@ -175,7 +175,7 @@ def test_custom_position_ids_but_no_padding_free(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -187,7 +187,7 @@ def test_single_example(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]])) @@ -199,7 +199,7 @@ def test_different_pad_token_id(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -221,22 +221,22 @@ def test_assistant_masks(self): def test_single_example_single_doc(self): batch_seq_lengths = [[5]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) - self.assertEqual(len(result), 1) - self.assertTrue(torch.equal(result[0], torch.arange(5))) + assert len(result) == 1 + assert torch.equal(result[0], torch.arange(5)) def test_single_example_multiple_docs(self): batch_seq_lengths = [[3, 2]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) - self.assertEqual(len(result), 1) + assert len(result) == 1 # First sequence: 0, 1, 2; second sequence: 0, 1 - self.assertTrue(torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1]))) + assert torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1])) def test_multiple_examples(self): batch_seq_lengths = [[2, 2], [3]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) - self.assertEqual(len(result), 2) - self.assertTrue(torch.equal(result[0], torch.tensor([0, 1, 0, 1]))) - self.assertTrue(torch.equal(result[1], torch.arange(3))) + assert len(result) == 2 + assert torch.equal(result[0], torch.tensor([0, 1, 0, 1])) + assert torch.equal(result[1], torch.arange(3)) class SFTTrainerTester(TrlTestCase): @@ -262,12 +262,12 @@ def test_train(self, model_id): trainer.train() # Check that the 
training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Special case for harmony def test_train_gpt_oss(self): @@ -287,12 +287,12 @@ def test_train_gpt_oss(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model(self): # Instantiate the model @@ -312,12 +312,12 @@ def test_train_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_dft_loss(self): # Get the dataset @@ -348,12 +348,12 @@ def test_train_dft_loss(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_moe_model_with_aux_loss(self): # Get the dataset @@ -375,13 +375,13 @@ def test_train_moe_model_with_aux_loss(self): trainer.train() # Check that the training loss and aux loss are not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[-1]["aux_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["aux_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_formatting_func(self): # Dummy formatting function @@ -408,12 +408,12 @@ def formatting_prompts_func(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, 
new_param), f"Parameter {n} has not changed" def test_train_model_dtype(self): # Get the dataset @@ -437,7 +437,7 @@ def test_train_model_dtype(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -447,8 +447,8 @@ def test_train_model_dtype(self): continue new_param = trainer.model.get_parameter(n) # Check the torch dtype - self.assertEqual(new_param.dtype, torch.float16) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert new_param.dtype == torch.float16 + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config(self): @@ -477,15 +477,15 @@ def test_train_dense_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config(self): @@ -514,15 +514,15 @@ def test_train_moe_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_peft_model(self): @@ -551,15 +551,15 @@ def test_train_peft_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft 
parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config_and_gradient_checkpointing(self): @@ -588,15 +588,15 @@ def test_train_dense_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config_and_gradient_checkpointing(self): @@ -625,15 +625,15 @@ def test_train_moe_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_with_peft_model_and_gradient_checkpointing(self): @@ -652,7 +652,7 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset) # Verify model is a PeftModel - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -661,15 +661,15 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We 
expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_liger_kernel def test_train_with_liger(self): @@ -689,12 +689,12 @@ def test_train_with_liger(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_non_chatml_conversational_data(self): # Get the dataset @@ -719,12 +719,12 @@ def rename_fields(example: list[dict]): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_pretokenized_data(self): # Get the dataset @@ -749,12 +749,12 @@ def tokenize_example(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_iterable_dataset(self): # Get the dataset @@ -773,12 +773,12 @@ def test_train_with_iterable_dataset(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_flash_attn def test_train_padding_free(self): @@ -804,12 +804,12 @@ def test_train_padding_free(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @parameterized.expand([("bfd",), ("wrapped",)]) @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @@ -833,12 +833,12 @@ def test_train_packing(self, packing_strategy): trainer.train() # Check that the training loss is not 
None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) @@ -863,16 +863,16 @@ def test_eval_packing(self): # Check the number of sequences in train and eval datasets num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"]) num_eval_seqs = sum(len(x) for x in trainer.eval_dataset["seq_lengths"]) - self.assertEqual(num_train_seqs, 17) # we should still have 17 seqs - self.assertEqual(num_eval_seqs, 2) # we should still have 2 seqs + assert num_train_seqs == 17 # we should still have 17 seqs + assert num_eval_seqs == 2 # we should still have 2 seqs # Check that all sequences are shorter than the max length - self.assertTrue(all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])) - self.assertTrue(all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"])) + assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]) + assert all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"]) # Check the number of sequences in train and eval datasets - self.assertEqual(len(trainer.train_dataset["input_ids"]), 3) # w/ this dataset, we end up with 46 seqs - self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1) # w/ this dataset, we end up with 6 seqs + assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs + assert len(trainer.eval_dataset["input_ids"]) == 1 # w/ this dataset, we end up with 6 seqs @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) @@ -897,17 +897,17 @@ def test_only_train_packing(self): # Check the number of sequences in train dataset num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"]) - self.assertEqual(num_train_seqs, 17) # we should still have 17 seqs + assert num_train_seqs == 17 # we should still have 17 seqs # We expect eval dataset not having "seq_lengths" as eval_packing is False - self.assertNotIn("seq_lengths", trainer.eval_dataset) + assert "seq_lengths" not in trainer.eval_dataset # Check that all sequences are shorter than the max length - self.assertTrue(all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])) + assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]) # Check the number of sequences in train and eval datasets - self.assertEqual(len(trainer.train_dataset["input_ids"]), 3) # w/ this dataset, we end up with 46 seqs - self.assertEqual(len(trainer.eval_dataset["input_ids"]), 2) # w/ this dataset, we end up with 6 seqs + assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs + assert len(trainer.eval_dataset["input_ids"]) == 2 # w/ this dataset, we end up with 6 seqs def test_train_with_chat_template_kwargs(self): # Get the dataset @@ -934,12 +934,12 @@ def test_train_with_chat_template_kwargs(self): trainer.train() # Check that the training 
loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only(self): # Get the dataset @@ -958,12 +958,12 @@ def test_train_assistant_only(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_completion_only(self): # Get the dataset @@ -982,12 +982,12 @@ def test_train_completion_only(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_completion_only_harmony(self): # Get the dataset @@ -1006,12 +1006,12 @@ def test_train_completion_only_harmony(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only_and_completion_only(self): # Get the dataset @@ -1040,12 +1040,12 @@ def add_to_completion(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only_iterable_dataset(self): # Get the dataset @@ -1066,12 +1066,12 @@ def test_train_assistant_only_iterable_dataset(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def 
test_train_with_set_chat_template_from_model(self): # Get the dataset @@ -1091,12 +1091,12 @@ def test_train_with_set_chat_template_from_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_path(self): # Get the dataset @@ -1120,24 +1120,22 @@ def test_train_with_set_chat_template_from_path(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Check that the template saved in the output directory is the same as the one used for training template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" - self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}") + assert template_path.exists(), f"Chat template not found at {template_path}" with open(template_path) as f: template_content = f.read() with open(training_args.chat_template_path) as f: original_template_content = f.read() - self.assertEqual( - template_content, original_template_content, "Chat template content does not match the original" - ) + assert template_content == original_template_content, "Chat template content does not match the original" def test_train_toolcall_data(self): # Get the dataset @@ -1156,12 +1154,12 @@ def test_train_toolcall_data(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_eval(self): # Get the dataset @@ -1180,7 +1178,7 @@ def test_train_with_eval(self): trainer.train() # Check that the eval loss is not None - self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + assert trainer.state.log_history[0]["eval_loss"] is not None def test_train_with_multiple_eval_dataset(self): # Get the dataset @@ -1198,8 +1196,8 @@ def test_train_with_multiple_eval_dataset(self): trainer.train() # Check that the eval losses are not None - self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"]) - self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"]) + assert trainer.state.log_history[-3]["eval_data1_loss"] is not None + assert trainer.state.log_history[-2]["eval_data2_loss"] is not None def test_train_with_gradient_checkpointing(self): # Get the dataset @@ -1218,12 +1216,12 @@ def test_train_with_gradient_checkpointing(self): trainer.train() # Check that the training 
loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_tag_added(self): # Get the dataset @@ -1236,7 +1234,7 @@ def test_tag_added(self): ) for tag in ["sft", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_tag_added_peft(self): @@ -1251,7 +1249,7 @@ def test_tag_added_peft(self): ) for tag in ["sft", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @parameterized.expand( [ @@ -1285,7 +1283,7 @@ def test_train_vlm(self, model_id): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -1302,9 +1300,7 @@ def test_train_vlm(self, model_id): ): # fmt: on continue - self.assertFalse( - torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" - ) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" @require_vision def test_train_vlm_prompt_completion(self): @@ -1330,12 +1326,12 @@ def test_train_vlm_prompt_completion(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing. # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs. @@ -1363,7 +1359,7 @@ def test_train_vlm_gemma_3n(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -1371,7 +1367,7 @@ def test_train_vlm_gemma_3n(self): if "model.vision_tower" in n: # The vision tower is not updated, not sure why at this point. 
continue - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" @require_vision def test_train_vlm_text_only_data(self): @@ -1393,15 +1389,15 @@ def test_train_vlm_text_only_data(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n.startswith("model.visual"): - self.assertTrue(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated") + assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated" else: - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" @require_peft def test_prompt_tuning(self): @@ -1422,16 +1418,16 @@ def test_prompt_tuning(self): trainer.train() # Check that training completed successfully - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[-1]["mean_token_accuracy"]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "base_model" in n: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "prompt_encoder" in n: # We expect the peft parameters to be different - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" else: raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}") @@ -1455,7 +1451,7 @@ def test_peft_model_with_quantization(self): # Verify that this triggers the is_qlora condition is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) - self.assertTrue(is_qlora, "Model should be detected as QLoRA (quantized)") + assert is_qlora, "Model should be detected as QLoRA (quantized)" # Create LoRA configuration suitable for QLoRA lora_config = LoraConfig( @@ -1470,7 +1466,7 @@ def test_peft_model_with_quantization(self): peft_model = get_peft_model(model, lora_config) # Verify the quantization attributes are preserved on the PeftModel - self.assertTrue(getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag") + assert getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag" # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") @@ -1489,9 +1485,9 @@ def test_peft_model_with_quantization(self): base_params_before.append(name) # Ensure we have the expected parameter distribution for QLoRA - self.assertTrue(len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially") - self.assertTrue(len(lora_params_before) > 0, "PeftModel should have 
trainable LoRA parameters") - self.assertEqual(len(base_params_before), 0, "Base model parameters should already be frozen in PeftModel") + assert len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially" + assert len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters" + assert len(base_params_before) == 0, "Base model parameters should already be frozen in PeftModel" # Initialize the trainer with the already configured PeftModel training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1) @@ -1508,32 +1504,24 @@ def test_peft_model_with_quantization(self): lora_params_after.append(name) # LoRA parameters should remain trainable - self.assertTrue( - len(trainable_params_after) > 0, - f"PeftModel should still have trainable parameters after SFTTrainer initialization. " - f"Found {len(trainable_params_after)} trainable params. " - f"This test fails without the fix for issue #3926.", - ) + assert len(trainable_params_after) > 0, \ + f"PeftModel should still have trainable parameters after SFTTrainer initialization. " \ + f"Found {len(trainable_params_after)} trainable params. " \ + f"This test fails without the fix for issue #3926." - self.assertTrue( - len(lora_params_after) > 0, - f"LoRA adapter parameters should remain trainable. " - f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original.", - ) + assert len(lora_params_after) > 0, \ + f"LoRA adapter parameters should remain trainable. " \ + f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original." # Ensure the parameter counts are preserved (no additional freezing occurred) - self.assertEqual( - len(trainable_params_before), - len(trainable_params_after), - "Number of trainable parameters should not change after SFTTrainer initialization", - ) + assert len(trainable_params_before) == \ + len(trainable_params_after), \ + "Number of trainable parameters should not change after SFTTrainer initialization" # Verify that all original LoRA parameters are still trainable - self.assertEqual( - set(lora_params_before), - set(lora_params_after), - "All original LoRA parameters should remain trainable after SFTTrainer initialization", - ) + assert set(lora_params_before) == \ + set(lora_params_after), \ + "All original LoRA parameters should remain trainable after SFTTrainer initialization" @require_peft def test_prompt_tuning_peft_model(self): @@ -1552,15 +1540,15 @@ def test_prompt_tuning_peft_model(self): trainer.train() # Check that training completed successfully - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[-1]["mean_token_accuracy"]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "base_model" in n: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "prompt_encoder" in n: # We expect the peft parameters to be different - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" else: raise 
ValueError(f"Unexpected parameter {n} in model: {trainer.model}") diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 61ab72130f2..35a8da57cd3 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -76,22 +76,22 @@ def test_bco(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.padding_value, -99) - self.assertEqual(trainer.args.truncation_mode, "keep_start") + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.padding_value == -99 + assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - self.assertEqual(trainer.args.is_encoder_decoder, True) - self.assertEqual(trainer.args.precompute_ref_log_probs, True) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.prompt_sample_size, 512) - self.assertEqual(trainer.args.min_density_ratio, 0.2) - self.assertEqual(trainer.args.max_density_ratio, 20.0) + assert trainer.args.is_encoder_decoder == True + assert trainer.args.precompute_ref_log_probs == True + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.prompt_sample_size == 512 + assert trainer.args.min_density_ratio == 0.2 + assert trainer.args.max_density_ratio == 20.0 def test_cpo(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -117,22 +117,22 @@ def test_cpo(self): dataset_num_proc=4, ) trainer = CPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.label_smoothing, 0.5) - self.assertEqual(trainer.args.loss_type, "hinge") - self.assertEqual(trainer.args.disable_dropout, False) - self.assertEqual(trainer.args.cpo_alpha, 0.5) - self.assertEqual(trainer.args.simpo_gamma, 0.2) - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.padding_value, -99) - self.assertEqual(trainer.args.truncation_mode, "keep_start") + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert trainer.args.label_smoothing == 0.5 + assert trainer.args.loss_type == "hinge" + assert trainer.args.disable_dropout == False + assert trainer.args.cpo_alpha == 0.5 + assert trainer.args.simpo_gamma == 0.2 + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.padding_value == -99 + assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - 
self.assertEqual(trainer.args.is_encoder_decoder, True) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.dataset_num_proc, 4) + assert trainer.args.is_encoder_decoder == True + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.dataset_num_proc == 4 def test_dpo(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -174,32 +174,32 @@ def test_dpo(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.label_smoothing, 0.5) - self.assertEqual(trainer.args.loss_type, "hinge") - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.pad_token, ".") - self.assertEqual(trainer.args.truncation_mode, "keep_start") - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.disable_dropout, False) + assert trainer.args.beta == 0.5 + assert trainer.args.label_smoothing == 0.5 + assert trainer.args.loss_type == "hinge" + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.pad_token == "." + assert trainer.args.truncation_mode == "keep_start" + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.disable_dropout == False # self.assertEqual(trainer.args.generate_during_eval, True) - self.assertEqual(trainer.args.precompute_ref_log_probs, True) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.model_adapter_name, "dummy_adapter") - self.assertEqual(trainer.args.ref_adapter_name, "dummy_adapter") - self.assertEqual(trainer.args.reference_free, True) - self.assertEqual(trainer.args.force_use_ref_model, True) - self.assertEqual(trainer.args.f_divergence_type, FDivergenceType.JS_DIVERGENCE) - self.assertEqual(trainer.args.f_alpha_divergence_coef, 0.5) + assert trainer.args.precompute_ref_log_probs == True + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.model_adapter_name == "dummy_adapter" + assert trainer.args.ref_adapter_name == "dummy_adapter" + assert trainer.args.reference_free == True + assert trainer.args.force_use_ref_model == True + assert trainer.args.f_divergence_type == FDivergenceType.JS_DIVERGENCE + assert trainer.args.f_alpha_divergence_coef == 0.5 # self.assertEqual(trainer.args.sync_ref_model, True) - self.assertEqual(trainer.args.ref_model_mixup_alpha, 0.5) - self.assertEqual(trainer.args.ref_model_sync_steps, 32) - self.assertEqual(trainer.args.rpo_alpha, 0.5) - self.assertEqual(trainer.args.discopop_tau, 0.1) + assert trainer.args.ref_model_mixup_alpha == 0.5 + assert trainer.args.ref_model_sync_steps == 32 + assert trainer.args.rpo_alpha == 0.5 + assert trainer.args.discopop_tau == 0.1 def test_kto(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -230,21 +230,21 @@ def test_kto(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 
64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.desirable_weight, 0.5) - self.assertEqual(trainer.args.undesirable_weight, 0.5) - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.padding_value, -99) - self.assertEqual(trainer.args.truncation_mode, "keep_start") + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert trainer.args.desirable_weight == 0.5 + assert trainer.args.undesirable_weight == 0.5 + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.padding_value == -99 + assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - self.assertEqual(trainer.args.is_encoder_decoder, True) - self.assertEqual(trainer.args.precompute_ref_log_probs, True) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.dataset_num_proc, 4) + assert trainer.args.is_encoder_decoder == True + assert trainer.args.precompute_ref_log_probs == True + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.dataset_num_proc == 4 @parameterized.expand([(False,), (True,)]) def test_nash_md(self, mixtures_coef_list): @@ -266,7 +266,7 @@ def test_nash_md(self, mixtures_coef_list): reward_funcs=reward_model, train_dataset=dataset, ) - self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6]) + assert trainer.args.mixture_coef == (0.5 if not mixtures_coef_list else [0.5, 0.6]) @parameterized.expand([(False,), (True,)]) def test_online_dpo(self, beta_list): @@ -293,11 +293,11 @@ def test_online_dpo(self, beta_list): processing_class=tokenizer, reward_processing_classes=tokenizer, ) - self.assertEqual(trainer.args.max_new_tokens, 42) - self.assertEqual(trainer.args.temperature, 0.5) - self.assertEqual(trainer.args.missing_eos_penalty, 0.33) - self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7]) - self.assertEqual(trainer.args.loss_type, "hinge") + assert trainer.args.max_new_tokens == 42 + assert trainer.args.temperature == 0.5 + assert trainer.args.missing_eos_penalty == 0.33 + assert trainer.args.beta == (0.6 if not beta_list else [0.6, 0.7]) + assert trainer.args.loss_type == "hinge" def test_orpo(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -319,12 +319,12 @@ def test_orpo(self): dataset_num_proc=4, ) trainer = ORPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.disable_dropout, False) - self.assertEqual(trainer.args.label_pad_token_id, -99) + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert trainer.args.disable_dropout == False + assert trainer.args.label_pad_token_id == -99 def test_reward(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -343,9 
+343,9 @@ def test_reward(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.center_rewards_coefficient, 0.1) + assert trainer.args.max_length == 256 + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.center_rewards_coefficient == 0.1 def test_sft(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -362,15 +362,15 @@ def test_sft(self): eval_packing=True, ) trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset) - self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field") - self.assertEqual(trainer.args.packing, True) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.neftune_noise_alpha, 0.1) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertIn("append_concat_token", trainer.args.dataset_kwargs) - self.assertEqual(trainer.args.dataset_kwargs["append_concat_token"], True) - self.assertEqual(trainer.args.eval_packing, True) + assert trainer.args.dataset_text_field == "dummy_text_field" + assert trainer.args.packing == True + assert trainer.args.max_length == 256 + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.neftune_noise_alpha == 0.1 + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert "append_concat_token" in trainer.args.dataset_kwargs + assert trainer.args.dataset_kwargs["append_concat_token"] == True + assert trainer.args.eval_packing == True @parameterized.expand([(False,), (True,)]) def test_xpo(self, alpha_list): @@ -392,4 +392,4 @@ def test_xpo(self, alpha_list): reward_funcs=reward_model, train_dataset=dataset, ) - self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6]) + assert trainer.args.alpha == (0.5 if not alpha_list else [0.5, 0.6]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4758109e08b..2bfa9bff0c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -47,6 +47,7 @@ ) from .testing_utils import TrlTestCase, require_rich +import pytest if is_peft_available(): @@ -59,14 +60,14 @@ def test_pad_1_dim_left(self): y = torch.tensor([4, 5]) output = pad((x, y), padding_value=0, padding_side="left") expected = torch.tensor([[1, 2, 3], [0, 4, 5]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_1_dim_right(self): x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5]) output = pad((x, y), padding_value=0, padding_side="right") expected = torch.tensor([[1, 2, 3], [4, 5, 0]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_2_dim_left(self): x = torch.tensor([[1, 2], [3, 4]]) @@ -78,7 +79,7 @@ def test_pad_2_dim_left(self): [[0, 0], [5, 6]], ] ) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_2_dim_right(self): x = torch.tensor([[1, 2], [3, 4]]) @@ -90,7 +91,7 @@ def test_pad_2_dim_right(self): [[5, 6], [0, 0]], ] ) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_2_dim_right_multidim(self): x = torch.tensor([[1, 2], [3, 4]]) @@ -102,7 +103,7 @@ def test_pad_2_dim_right_multidim(self): [[5, 0], [0, 0]], ] ) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_1(self): x = torch.tensor([1, 2, 3]) @@ -110,7 
+111,7 @@ def test_pad_to_multiple_of_1(self): # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_2(self): x = torch.tensor([1, 2, 3, 4, 5]) @@ -118,7 +119,7 @@ def test_pad_to_multiple_of_2(self): # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_side_left(self): x = torch.tensor([1, 2, 3, 4, 5]) @@ -126,7 +127,7 @@ def test_pad_to_multiple_of_side_left(self): # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_no_extra_padding(self): x = torch.tensor([1, 2, 3, 4]) @@ -134,7 +135,7 @@ def test_pad_to_multiple_of_no_extra_padding(self): # Already multiple of 4 output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) @require_peft @@ -143,7 +144,7 @@ def test_create_peft_config_use_peft_false(self): """Test that when use_peft is False, the function returns None.""" model_args = ModelConfig(use_peft=False) peft_config = get_peft_config(model_args) - self.assertIsNone(peft_config) + assert peft_config is None def test_create_peft_config_use_peft_true(self): """Test that when use_peft is True, the function returns a LoraConfig object.""" @@ -159,7 +160,7 @@ def test_create_peft_config_use_peft_true(self): } model_args = ModelConfig(use_peft=True, **peft_kwargs) peft_config = get_peft_config(model_args) - self.assertTrue(isinstance(peft_config, LoraConfig)) + assert isinstance(peft_config, LoraConfig) for arg, value in peft_kwargs.items(): # Test that lists of modules are converted to sets if arg == "lora_target_modules": @@ -168,7 +169,7 @@ def test_create_peft_config_use_peft_true(self): if arg in ["lora_r", "lora_task_type", "lora_target_modules", "lora_modules_to_save"]: arg = arg[len("lora_") :] if arg.startswith("lora_") else arg - self.assertEqual(getattr(peft_config, arg), value) + assert getattr(peft_config, arg) == value class TestDecodeAndStripPadding(TrlTestCase): @@ -179,12 +180,12 @@ def setUp(self): def test_example_with_padding(self): inputs = self.tokenizer(["Hello world", "Hello"], padding=True, return_tensors="pt") decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer) - self.assertEqual(decoded, ["Hello world", "Hello"]) + assert decoded == ["Hello world", "Hello"] def test_example_without_padding(self): inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt") decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer) - self.assertEqual(decoded, ["Hello", "Hello"]) + assert decoded == ["Hello", "Hello"] class TestGenerateModelCard(TrlTestCase): @@ -203,15 +204,15 @@ def test_full(self): paper_id="1234.56789", ) card_text = str(model_card) - 
self.assertIn("[username/my_base_model](https://huggingface.co/username/my_base_model)", card_text) - self.assertIn("my_model", card_text) - self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text) - self.assertIn("datasets: username/my_dataset", card_text) - self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text) - self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text) - self.assertIn("My Trainer", card_text) - self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text) - self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text) + assert "[username/my_base_model](https://huggingface.co/username/my_base_model)" in card_text + assert "my_model" in card_text + assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text + assert "datasets: username/my_dataset" in card_text + assert "](https://wandb.ai/username/project_id/runs/abcd1234)" in card_text + assert "](https://www.comet.com/username/project_id/experiment_id" in card_text + assert "My Trainer" in card_text + assert "```bibtex\n@article{my_trainer, ...}\n```" in card_text + assert "[My Paper](https://huggingface.co/papers/1234.56789)" in card_text def test_val_none(self): model_card = generate_model_card( @@ -228,9 +229,9 @@ def test_val_none(self): paper_id=None, ) card_text = str(model_card) - self.assertIn("my_model", card_text) - self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text) - self.assertIn("My Trainer", card_text) + assert "my_model" in card_text + assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text + assert "My Trainer" in card_text class TestDataCollatorForChatML(TrlTestCase): @@ -265,11 +266,11 @@ def test_data_collator_for_chatml(self): data = self.collator(self.examples) # Verify basic shapes and types - self.assertIn("input_ids", data) - self.assertIn("attention_mask", data) - self.assertIn("labels", data) - self.assertIn("prompts", data) - self.assertIn("prompt_attention_mask", data) + assert "input_ids" in data + assert "attention_mask" in data + assert "labels" in data + assert "prompts" in data + assert "prompt_attention_mask" in data # Decode input_ids and labels for verification input_ids = data["input_ids"][0].tolist() @@ -278,23 +279,21 @@ def test_data_collator_for_chatml(self): # Get the last assistant's response for comparison last_message = self.examples[0][self.messages_key][-1] - self.assertEqual(last_message["role"], "assistant", "Last message should be from assistant") + assert last_message["role"] == "assistant", "Last message should be from assistant" last_assistant_response = last_message["content"] # Verify that input_ids contain both prompt and response decoded_input = self.tokenizer.decode(input_ids) - self.assertIn(last_assistant_response, decoded_input, "Input should contain assistant's response") + assert last_assistant_response in decoded_input, "Input should contain assistant's response" # Verify that prompts only contain the conversation up to the last response decoded_prompt = self.tokenizer.decode(prompt_only) - self.assertNotIn(last_assistant_response, decoded_prompt, "Prompt should not contain assistant's response") + assert last_assistant_response not in decoded_prompt, "Prompt should not contain assistant's response" # Verify labels are -100 for non-assistant parts prompt_length = len(prompt_only) - self.assertTrue( - 
all(label == self.ignore_index for label in labels[:prompt_length]), - "Labels should be ignore_index for prompt tokens", - ) + assert all(label == self.ignore_index for label in labels[:prompt_length]), \ + "Labels should be ignore_index for prompt tokens" # Verify labels match assistant response after prompt # Add a filter to remove any trailing tokens after the first <|im_end|> @@ -310,24 +309,18 @@ def test_data_collator_for_chatml(self): response_labels.append(label) if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"): break - self.assertEqual( - response_labels, - last_assistant_response_tokens, - "Labels should match assistant response tokens", - ) + assert response_labels == \ + last_assistant_response_tokens, \ + "Labels should match assistant response tokens" # Verify there isn't a generation prompt at the end generation_prompt = "<|im_start|>assistant" - self.assertFalse( - decoded_input.strip().endswith(generation_prompt), - f"Input should not end with generation prompt '{generation_prompt}'", - ) + assert not decoded_input.strip().endswith(generation_prompt), \ + f"Input should not end with generation prompt '{generation_prompt}'" - self.assertEqual( - response_labels, - last_assistant_response_tokens, - "Labels should match assistant response tokens", - ) + assert response_labels == \ + last_assistant_response_tokens, \ + "Labels should match assistant response tokens" class TestBatchGeneration(TrlTestCase): @@ -367,9 +360,9 @@ def test_mini_batch_generation(self): max_length_query = query_responses.shape[1] max_length_logits = max_length_query - context_length - self.assertGreater(max_length_query, context_length) - self.assertEqual(query_responses.shape, (bs, max_length_query)) - self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + assert max_length_query > context_length + assert query_responses.shape == (bs, max_length_query) + assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size) def test_single_batch_generation(self): batch = [ @@ -386,9 +379,9 @@ def test_single_batch_generation(self): max_length_query = query_responses.shape[1] max_length_logits = max_length_query - context_length - self.assertGreater(max_length_query, context_length) - self.assertEqual(query_responses.shape, (bs, max_length_query)) - self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + assert max_length_query > context_length + assert query_responses.shape == (bs, max_length_query) + assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size) class TestComputeAccuracy(TrlTestCase): @@ -404,7 +397,7 @@ def test_token_classification_task(self): ) expected_accuracy = 0.5 # 2 matches, 2 mismatches result = compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 def test_token_classification_task_with_ignored_tokens_0(self): eval_pred = ( @@ -418,7 +411,7 @@ def test_token_classification_task_with_ignored_tokens_0(self): ) expected_accuracy = 1.0 # All non-ignored tokens match result = compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 def test_token_classification_task_with_ignored_tokens_1(self): eval_pred = ( @@ -432,7 +425,7 @@ def test_token_classification_task_with_ignored_tokens_1(self): ) expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored result = 
compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 def test_rewards_comparison_task(self): eval_pred = ( @@ -450,12 +443,12 @@ def test_rewards_comparison_task(self): with self.assertLogs("trl.trainer.utils", level="WARNING") as cm: result = compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 expected_warning = ( "There are 1 out of 3 instances where the predictions for both options are equal. " "These instances are ignored in the accuracy computation." ) - self.assertIn(expected_warning, cm.output[0]) + assert expected_warning in cm.output[0] class TestFlushLeft(TrlTestCase): @@ -469,9 +462,9 @@ def test_basic_case(self): expected_tensor1 = torch.tensor([[2, 3, 4], [5, 6, 0]]) expected_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 0]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) - self.assertTrue(torch.equal(new_tensor2, expected_tensor2)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) + assert torch.equal(new_tensor2, expected_tensor2) def test_single_row(self): mask = torch.tensor([[0, 0, 1, 1]]) @@ -481,8 +474,8 @@ def test_single_row(self): expected_mask = torch.tensor([[1, 1]]) expected_tensor1 = torch.tensor([[2, 3]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_shift_needed(self): mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]]) @@ -492,14 +485,14 @@ def test_no_shift_needed(self): expected_mask = torch.tensor([[1, 1], [1, 0]]) expected_tensor1 = torch.tensor([[5, 6], [7, 0]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_tensors(self): mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) new_mask = flush_left(mask) expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) + assert torch.equal(new_mask, expected_mask) class TestFlushRight(TrlTestCase): @@ -513,9 +506,9 @@ def test_basic_case(self): expected_tensor1 = torch.tensor([[2, 3, 4], [0, 5, 6]]) expected_tensor2 = torch.tensor([[7, 8, 9], [0, 10, 11]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) - self.assertTrue(torch.equal(new_tensor2, expected_tensor2)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) + assert torch.equal(new_tensor2, expected_tensor2) def test_single_row(self): mask = torch.tensor([[1, 1, 0, 0]]) @@ -525,8 +518,8 @@ def test_single_row(self): expected_mask = torch.tensor([[1, 1]]) expected_tensor1 = torch.tensor([[2, 3]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_shift_needed(self): mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]]) @@ -536,14 +529,14 @@ def test_no_shift_needed(self): expected_mask = torch.tensor([[1, 1], [0, 1]]) expected_tensor1 = 
torch.tensor([[5, 6], [0, 7]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_tensors(self): mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]) new_mask = flush_right(mask) expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) + assert torch.equal(new_mask, expected_mask) class RepeatRandomSamplerTester(TrlTestCase): @@ -564,7 +557,7 @@ def test_sampler_no_shuffle(self): sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False) sampled = list(sampler) expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] - self.assertEqual(sampled, expected) + assert sampled == expected def test_sampler_no_repeat(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] @@ -706,7 +699,7 @@ def test_print_output(self, mock_stdout): ╰──────────────────────────────────────────────────────────────────╯ """) - self.assertEqual(output, expected_output) + assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_num_samples(self, mock_stdout): @@ -741,7 +734,7 @@ def test_num_samples(self, mock_stdout): ╰─────────────────────────────────────────────╯ """), ] - self.assertIn(output, possible_outputs) + assert output in possible_outputs @patch("sys.stdout", new_callable=StringIO) def test_print_messages(self, mock_stdout): @@ -790,7 +783,7 @@ def test_print_messages(self, mock_stdout): ╰──────────────────────────────────────────────────────────────────────────────╯ """) - self.assertEqual(output, expected_output) + assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_print_messages_with_tools(self, mock_stdout): @@ -829,7 +822,7 @@ def test_print_messages_with_tools(self, mock_stdout): ╰──────────────────────────────────────────────────────────────────────────────╯ """) - self.assertEqual(output, expected_output) + assert output == expected_output class TestSelectiveLogSoftmax(TrlTestCase): @@ -848,7 +841,7 @@ def test_selective_log_softmax(self, dtype): if dtype in [torch.float16, torch.bfloat16]: # half-precision dtypes fall back to an exact method - self.assertTrue(torch.equal(actual_output, expected_output)) + assert torch.equal(actual_output, expected_output) else: torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) @@ -861,8 +854,8 @@ def test_shuffle_preserves_shape(self): shuffled = shuffle_sequence_dict(tensor_dict) - self.assertEqual(shuffled["x"].shape, x.shape) - self.assertEqual(shuffled["y"].shape, y.shape) + assert shuffled["x"].shape == x.shape + assert shuffled["y"].shape == y.shape def test_shuffle_consistent_across_tensors(self): # Use known patterns to check alignment @@ -878,13 +871,13 @@ def test_shuffle_consistent_across_tensors(self): y_val = shuffled["y"][i].item() if torch.equal(x_row, torch.tensor([10, 11])): - self.assertEqual(y_val, 1) + assert y_val == 1 elif torch.equal(x_row, torch.tensor([20, 21])): - self.assertEqual(y_val, 2) + assert y_val == 2 elif torch.equal(x_row, torch.tensor([30, 31])): - self.assertEqual(y_val, 3) + assert y_val == 3 else: - self.fail("Unexpected x row in shuffled output.") + pytest.fail("Unexpected x row in shuffled output.") def test_none_tensor_remains_none(self): x = torch.arange(6).reshape(3, 2) @@ -892,8 +885,8 @@ def test_none_tensor_remains_none(self): shuffled = shuffle_sequence_dict(tensor_dict) - 
self.assertIsNone(shuffled["y"]) - self.assertEqual(shuffled["x"].shape, x.shape) + assert shuffled["y"] is None + assert shuffled["x"].shape == x.shape def test_shuffle_with_list(self): x = torch.tensor([[10, 11], [20, 21], [30, 31]]) @@ -909,13 +902,13 @@ def test_shuffle_with_list(self): y_val = shuffled["y"][i] if torch.equal(x_row, torch.tensor([10, 11])): - self.assertEqual(y_val, "a") + assert y_val == "a" elif torch.equal(x_row, torch.tensor([20, 21])): - self.assertEqual(y_val, "b") + assert y_val == "b" elif torch.equal(x_row, torch.tensor([30, 31])): - self.assertEqual(y_val, "c") + assert y_val == "c" else: - self.fail("Unexpected x row in shuffled output.") + pytest.fail("Unexpected x row in shuffled output.") class SplitTensorDictTester(TrlTestCase): @@ -928,10 +921,10 @@ def test_split_equal_chunks(self): expected_x_chunks = torch.chunk(x, 3, dim=0) expected_y_chunks = torch.chunk(y, 3, dim=0) - self.assertEqual(len(result), 3) + assert len(result) == 3 for i in range(3): - self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) - self.assertTrue(torch.equal(result[i]["y"], expected_y_chunks[i])) + assert torch.equal(result[i]["x"], expected_x_chunks[i]) + assert torch.equal(result[i]["y"], expected_y_chunks[i]) def test_with_none_tensor(self): x = torch.arange(12).reshape(6, 2) @@ -940,10 +933,10 @@ def test_with_none_tensor(self): result = split_tensor_dict(tensor_dict, 2) expected_x_chunks = torch.chunk(x, 2, dim=0) - self.assertEqual(len(result), 2) + assert len(result) == 2 for i in range(2): - self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) - self.assertIsNone(result[i]["y"]) + assert torch.equal(result[i]["x"], expected_x_chunks[i]) + assert result[i]["y"] is None def test_with_scalar(self): x = torch.arange(12).reshape(6, 2) @@ -952,10 +945,10 @@ def test_with_scalar(self): result = split_tensor_dict(tensor_dict, 2) expected_x_chunks = torch.chunk(x, 2, dim=0) - self.assertEqual(len(result), 2) + assert len(result) == 2 for i in range(2): - self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) - self.assertTrue(torch.equal(result[i]["y"], torch.tensor(1))) + assert torch.equal(result[i]["x"], expected_x_chunks[i]) + assert torch.equal(result[i]["y"], torch.tensor(1)) class SplitPixelValuesByGridTester(TrlTestCase): @@ -966,14 +959,14 @@ def test_split_correctly_0(self): "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], list) - self.assertEqual(len(result["pixel_values"]), 2) - self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) - self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:])) - self.assertIsInstance(result["image_grid_thw"], list) - self.assertEqual(len(result["image_grid_thw"]), 2) - self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) - self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]]))) + assert isinstance(result["pixel_values"], list) + assert len(result["pixel_values"]) == 2 + assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]) + assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:]) + assert isinstance(result["image_grid_thw"], list) + assert len(result["image_grid_thw"]) == 2 + assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])) + assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]])) def 
test_split_correctly_1(self): batch = { @@ -982,19 +975,19 @@ def test_split_correctly_1(self): "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], list) - self.assertEqual(len(result["pixel_values"]), 2) - self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) - self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12])) - self.assertIsInstance(result["image_grid_thw"], list) - self.assertEqual(len(result["image_grid_thw"]), 2) - self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) - self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]]))) + assert isinstance(result["pixel_values"], list) + assert len(result["pixel_values"]) == 2 + assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]) + assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12]) + assert isinstance(result["image_grid_thw"], list) + assert len(result["image_grid_thw"]) == 2 + assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])) + assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]])) def test_missing_keys(self): batch = {"pixel_values": torch.tensor([1.0])} result = split_pixel_values_by_grid(batch) - self.assertEqual(result, batch) + assert result == batch def test_mismatched_length(self): batch = { @@ -1002,7 +995,7 @@ def test_mismatched_length(self): "num_images": [1, 1], "pixel_values": torch.randn(3, 5), # Only 3 rows } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): split_pixel_values_by_grid(batch) def test_multi_images(self): @@ -1012,14 +1005,14 @@ def test_multi_images(self): "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], list) - self.assertEqual(len(result["pixel_values"]), 2) - self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:2])) - self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][2:])) - self.assertIsInstance(result["image_grid_thw"], list) - self.assertEqual(len(result["image_grid_thw"]), 2) - self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))) - self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))) + assert isinstance(result["pixel_values"], list) + assert len(result["pixel_values"]) == 2 + assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:2]) + assert torch.equal(result["pixel_values"][1], batch["pixel_values"][2:]) + assert isinstance(result["image_grid_thw"], list) + assert len(result["image_grid_thw"]) == 2 + assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]])) + assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]])) class TruncateWithProtectedTokensTester(TrlTestCase): @@ -1032,7 +1025,7 @@ def test_basic_example(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) expected_ids = [2, 3, 5] - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids def test_no_truncation_needed(self): """Test when target length equals current length.""" @@ -1042,7 +1035,7 @@ def test_no_truncation_needed(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertEqual(new_ids, prompt_ids) + 
assert new_ids == prompt_ids def test_no_protected_tokens(self): """Test truncation with no protected tokens (normal right truncation).""" @@ -1053,7 +1046,7 @@ def test_no_protected_tokens(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) expected_ids = [3, 4, 5] # Last 3 tokens - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids def test_all_tokens_protected(self): """Test when all remaining tokens are protected.""" @@ -1064,7 +1057,7 @@ def test_all_tokens_protected(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) expected_ids = [3, 4, 5] - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids def test_too_many_protected_tokens(self): """Test error when too many protected tokens for target length.""" @@ -1072,7 +1065,7 @@ def test_too_many_protected_tokens(self): protected_tokens = [1, 2, 3, 4] target_length = 3 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) def test_single_batch_single_token(self): @@ -1083,7 +1076,7 @@ def test_single_batch_single_token(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertEqual(new_ids, prompt_ids) + assert new_ids == prompt_ids def test_order_preservation(self): """Test that relative order is preserved.""" @@ -1097,7 +1090,7 @@ def test_order_preservation(self): # Order should be: 2, 3, 30, 40 (maintaining original relative positions) expected_ids = [2, 3, 30, 40] - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids class UnsplitPixelValuesByGridTester(TrlTestCase): @@ -1108,14 +1101,14 @@ def test_unsplit_correctly(self): image_grid_thw_merged = torch.cat(image_grid_thw, dim=0) batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])} result = unsplit_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], torch.Tensor) - self.assertTrue(torch.allclose(result["pixel_values"], pixel_values_merged)) - self.assertIsInstance(result["image_grid_thw"], torch.Tensor) - self.assertTrue(torch.equal(result["image_grid_thw"], image_grid_thw_merged)) - self.assertIn("other_key", result) + assert isinstance(result["pixel_values"], torch.Tensor) + assert torch.allclose(result["pixel_values"], pixel_values_merged) + assert isinstance(result["image_grid_thw"], torch.Tensor) + assert torch.equal(result["image_grid_thw"], image_grid_thw_merged) + assert "other_key" in result def test_no_op_if_not_list(self): original = torch.randn(5, 3) batch = {"pixel_values": original} result = unsplit_pixel_values_by_grid(batch) - self.assertTrue(torch.equal(result["pixel_values"], original)) + assert torch.equal(result["pixel_values"], original) diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 08a302da41c..836cd124f16 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -27,28 +27,26 @@ class TestChunkList(TrlTestCase): def test_even_split(self): - self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 2), [[1, 2, 3], [4, 5, 6]]) + assert chunk_list([1, 2, 3, 4, 5, 6], 2) == [[1, 2, 3], [4, 5, 6]] def test_uneven_split(self): - self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 4), [[1, 2], [3, 4], [5], [6]]) + assert chunk_list([1, 2, 3, 4, 5, 6], 4) == [[1, 2], [3, 4], [5], [6]] def test_more_chunks_than_elements(self): - 
self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 8), [[1], [2], [3], [4], [5], [6], [], []]) + assert chunk_list([1, 2, 3, 4, 5, 6], 8) == [[1], [2], [3], [4], [5], [6], [], []] def test_n_equals_len(self): - self.assertEqual(chunk_list([1, 2, 3], 3), [[1], [2], [3]]) + assert chunk_list([1, 2, 3], 3) == [[1], [2], [3]] def test_n_is_1(self): - self.assertEqual(chunk_list([1, 2, 3], 1), [[1, 2, 3]]) + assert chunk_list([1, 2, 3], 1) == [[1, 2, 3]] def test_single_element_list(self): - self.assertEqual(chunk_list([42], 2), [[42], []]) + assert chunk_list([42], 2) == [[42], []] def test_any_dtype(self): - self.assertEqual( - chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2), - [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]], - ) + assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == \ + [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]] @pytest.mark.slow @@ -77,14 +75,14 @@ def test_generate(self): outputs = self.client.generate(prompts)["completion_ids"] # Check that the output is a list - self.assertIsInstance(outputs, list) + assert isinstance(outputs, list) # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + assert len(outputs) == len(prompts) # Check that the generated sequences are lists of integers for seq in outputs: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] @@ -93,18 +91,18 @@ def test_generate_with_params(self): ] # Check that the output is a list - self.assertIsInstance(outputs, list) + assert isinstance(outputs, list) # Check that the number of generated sequences is 2 times the number of prompts - self.assertEqual(len(outputs), 2 * len(prompts)) + assert len(outputs) == 2 * len(prompts) # Check that the generated sequences are lists of integers for seq in outputs: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) # Check that the length of the generated sequences is less than or equal to 32 for seq in outputs: - self.assertLessEqual(len(seq), 32) + assert len(seq) <= 32 def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -153,14 +151,14 @@ def test_generate(self): outputs = self.client.generate(prompts)["completion_ids"] # Check that the output is a list - self.assertIsInstance(outputs, list) + assert isinstance(outputs, list) # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + assert len(outputs) == len(prompts) # Check that the generated sequences are lists of integers for seq in outputs: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] @@ -169,18 +167,18 @@ def test_generate_with_params(self): ] # Check that the output is a list - self.assertIsInstance(outputs, list) + assert isinstance(outputs, list) # Check that the number of generated sequences is 2 times the number of prompts - self.assertEqual(len(outputs), 2 * len(prompts)) + assert len(outputs) == 2 * len(prompts) # Check that the generated sequences are lists of integers for seq in outputs: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) # 
Check that the length of the generated sequences is less than or equal to 32 for seq in outputs: - self.assertLessEqual(len(seq), 32) + assert len(seq) <= 32 def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -231,14 +229,14 @@ def test_generate(self): outputs = self.client.generate(prompts)["completion_ids"] # Check that the output is a list - self.assertIsInstance(outputs, list) + assert isinstance(outputs, list) # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + assert len(outputs) == len(prompts) # Check that the generated sequences are lists of integers for seq in outputs: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -289,14 +287,14 @@ def test_generate(self): outputs = self.client.generate(prompts)["completion_ids"] # Check that the output is a list - self.assertIsInstance(outputs, list) + assert isinstance(outputs, list) # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + assert len(outputs) == len(prompts) # Check that the generated sequences are lists of integers for seq in outputs: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -345,8 +343,8 @@ def test_init_communicator_with_device_int(self): # Test basic functionality prompts = ["Hello, AI!"] outputs = client.generate(prompts)["completion_ids"] - self.assertIsInstance(outputs, list) - self.assertEqual(len(outputs), len(prompts)) + assert isinstance(outputs, list) + assert len(outputs) == len(prompts) client.close_communicator() @@ -358,8 +356,8 @@ def test_init_communicator_with_device_string(self): # Test basic functionality prompts = ["Hello, AI!"] outputs = client.generate(prompts)["completion_ids"] - self.assertIsInstance(outputs, list) - self.assertEqual(len(outputs), len(prompts)) + assert isinstance(outputs, list) + assert len(outputs) == len(prompts) client.close_communicator() @@ -374,8 +372,8 @@ def test_init_communicator_with_torch_device(self): # Test basic functionality prompts = ["Hello, AI!"] outputs = client.generate(prompts)["completion_ids"] - self.assertIsInstance(outputs, list) - self.assertEqual(len(outputs), len(prompts)) + assert isinstance(outputs, list) + assert len(outputs) == len(prompts) client.close_communicator() diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index 9d50b542a03..7a69455df33 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -65,7 +65,7 @@ def test_xpo_trainer_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft(self): @@ -93,7 +93,7 @@ def test_training_with_peft(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_and_ref_model(self): @@ -122,7 +122,7 @@ def 
test_training_with_peft_and_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_model_and_peft_config(self): @@ -153,7 +153,7 @@ def test_training_with_peft_model_and_peft_config(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_pre_pefted_model_implicit_ref(self): @@ -182,7 +182,7 @@ def test_training_pre_pefted_model_implicit_ref(self): trainer.train() - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_llm_blender @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @@ -213,4 +213,4 @@ def test_xpo_trainer_judge_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] From 283c05908c65f9ada9b42daa2a3153ab8c2be8ee Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:24:26 +0200 Subject: [PATCH 03/16] Fix style --- tests/slow/test_grpo_slow.py | 8 ++-- tests/test_activation_offloading.py | 3 +- tests/test_bco_trainer.py | 9 ++-- tests/test_cli_utils.py | 2 +- tests/test_data_utils.py | 48 ++++++++++---------- tests/test_dataset_formatting.py | 12 +++-- tests/test_dpo_trainer.py | 69 +++++++++++++++-------------- tests/test_gkd_trainer.py | 17 +++---- tests/test_grpo_trainer.py | 20 ++++----- tests/test_kto_trainer.py | 20 +++++---- tests/test_modeling_value_head.py | 22 +++++---- tests/test_online_dpo_trainer.py | 4 +- tests/test_peft_models.py | 17 ++++--- tests/test_prm_trainer.py | 54 ++++++++++------------ tests/test_rewards.py | 2 +- tests/test_rloo_trainer.py | 2 +- tests/test_sft_trainer.py | 20 +++++---- tests/test_trainers_args.py | 28 ++++++------ tests/test_utils.py | 24 +++++----- tests/test_vllm_client_server.py | 6 ++- 20 files changed, 202 insertions(+), 185 deletions(-) diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 5f4400b9115..3a453714daa 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -444,9 +444,9 @@ def dummy_reward_func(completions, **kwargs): # Check if signature columns were set properly if trainer._signature_columns is not None: # Should include 'image' in signature columns for VLM processors - assert "image" in \ - trainer._signature_columns, \ + assert "image" in trainer._signature_columns, ( "Should include 'image' in signature columns for VLM" + ) # Should not emit any warnings about VLM incompatibility incompatibility_warnings = [ @@ -455,9 +455,9 @@ def dummy_reward_func(completions, **kwargs): if "does not support VLMs" in str(w_item.message) or "not compatible" in str(w_item.message).lower() ] - assert len(incompatibility_warnings) == \ - 0, \ + assert len(incompatibility_warnings) == 0, ( f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}" + ) # Test passes if we get this far without exceptions diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index b1e8f59d61d..6c9ae24b8ae 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -72,8 
+72,9 @@ def test_offloading_with_peft_models(self) -> None: for name_orig, grad_orig in grads_original: for name_param, param in model.named_parameters(): if name_param == name_orig and param.requires_grad and param.grad is not None: - assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), \ + assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), ( f"Gradient mismatch for {name_orig}" + ) @require_torch_accelerator def test_noop_manager_with_offloading(self): diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index b1fdaf8d8cf..91f905eb628 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -14,6 +14,7 @@ from functools import partial +import pytest import torch from accelerate import Accelerator from datasets import load_dataset @@ -26,7 +27,6 @@ from trl.trainer.bco_trainer import _process_tokens, _tokenize from .testing_utils import TrlTestCase, require_no_wandb, require_sklearn -import pytest if is_peft_available(): @@ -363,8 +363,11 @@ def test_generate_during_eval_no_wandb(self): report_to="none", ) - with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve."): + with pytest.raises( + ValueError, + match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve.", + ): BCOTrainer( model=model, args=training_args, diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 01417c258bc..f8196dd5377 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -17,13 +17,13 @@ from dataclasses import dataclass from unittest.mock import mock_open, patch +import pytest from datasets import DatasetDict, load_dataset from trl import DatasetMixtureConfig, TrlParser, get_dataset from trl.scripts.utils import DatasetConfig from .testing_utils import TrlTestCase -import pytest @dataclass diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 88c63868094..abb27258e3f 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -675,46 +675,46 @@ class UnpairPreferenceDatasetTester(TrlTestCase): def test_unpair_preference_dataset(self): # Test that a paired dataset is correctly converted to unpaired unpaired_dataset = unpair_preference_dataset(self.paired_dataset) - assert unpaired_dataset.to_dict() == \ - self.unpaired_dataset.to_dict(), \ + assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), ( "The paired dataset should be converted to unpaired." + ) def test_unpair_preference_dataset_dict(self): # Test that a paired dataset dict is correctly converted to unpaired paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict) - assert unpaired_dataset_dict["abc"].to_dict() == \ - self.unpaired_dataset.to_dict(), \ + assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), ( "The paired dataset should be converted to unpaired." + ) def test_maybe_unpair_preference_dataset(self): # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset) - assert unpaired_dataset.to_dict() == \ - self.unpaired_dataset.to_dict(), \ + assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), ( "The paired dataset should be converted to unpaired." 
+ ) def test_maybe_unpair_preference_dataset_dict(self): # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict) - assert unpaired_dataset_dict["abc"].to_dict() == \ - self.unpaired_dataset.to_dict(), \ + assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), ( "The paired dataset should be converted to unpaired." + ) def test_maybe_unpair_preference_dataset_already_paired(self): # Test that a paired dataset remains unchanged with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset) - assert unpaired_dataset.to_dict() == \ - self.unpaired_dataset.to_dict(), \ + assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), ( "The unpaired dataset should remain unchanged." + ) def test_maybe_unpair_preference_dataset_dict_already_paired(self): # Test that a paired dataset dict remains unchanged with maybe_unpair_preference_dataset unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset})) - assert unpaired_dataset_dict["abc"].to_dict() == \ - self.unpaired_dataset.to_dict(), \ + assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), ( "The unpaired dataset should remain unchanged." + ) class ExtractPromptTester(TrlTestCase): @@ -755,44 +755,42 @@ class ExtractPromptTester(TrlTestCase): def test_extract_prompt_conversational(self): # Test that the prompt is correctly extracted from the dataset example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational) - assert example_extracted_prompt == \ - self.example_explicit_prompt_conversational, \ + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( "The prompt is not correctly extracted from the dataset." + ) def test_maybe_extract_prompt_conversational(self): # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational) - assert example_extracted_prompt == \ - self.example_explicit_prompt_conversational, \ + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( "The prompt is not correctly extracted from the dataset." + ) def test_maybe_extract_prompt_conversational_already_explicit(self): # Test that the prompt remains unchanged with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational) - assert example_extracted_prompt == \ - self.example_explicit_prompt_conversational, \ + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( "The prompt should remain unchanged." + ) def test_extract_prompt_standard(self): # Test that the prompt is correctly extracted from the dataset example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard) - assert example_extracted_prompt == \ - self.example_explicit_prompt_standard, \ + assert example_extracted_prompt == self.example_explicit_prompt_standard, ( "The prompt is not correctly extracted from the dataset." 
+ ) def test_maybe_extract_prompt_standard(self): # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard) - assert example_extracted_prompt == \ - self.example_explicit_prompt_standard, \ + assert example_extracted_prompt == self.example_explicit_prompt_standard, ( "The prompt is not correctly extracted from the dataset." + ) def test_maybe_extract_prompt_standard_already_explicit(self): # Test that the prompt remains unchanged with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard) - assert example_extracted_prompt == \ - self.example_explicit_prompt_standard, \ - "The prompt should remain unchanged." + assert example_extracted_prompt == self.example_explicit_prompt_standard, "The prompt should remain unchanged." class TestPackDatasetWrapped(TrlTestCase): diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index cc132aedefb..78f2e7dfb1d 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -152,8 +152,10 @@ def test_example_with_setup_model(self): ] prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) - assert prompt == \ - "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" + assert ( + prompt + == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" + ) class CloneChatTemplateTestCase(TrlTestCase): @@ -217,8 +219,10 @@ def test_apply_new_chat_template(self): ] prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) - assert prompt == \ - "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n" + assert ( + prompt + == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n" + ) def test_clone_with_sequence_classification_model(self): # This tokenizer doesn't have a chat_template by default diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index f429726c20b..8c0a4efb250 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock import numpy as np +import pytest import torch from datasets import Dataset, features, load_dataset from parameterized import parameterized @@ -40,7 +41,6 @@ from trl import DPOConfig, DPOTrainer, FDivergenceType from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb -import pytest if is_vision_available(): @@ -85,12 +85,11 @@ def test_tokenize_row_no_truncation_no_special_tokens(self): ) # Assert the correct output without truncation or special tokens - assert result == \ - { - "prompt_input_ids": [464, 6766, 318], - "chosen_input_ids": [4171, 2], # eos_token added - "rejected_input_ids": [4077, 2], # eos_token added - } + assert result == { + "prompt_input_ids": [464, 6766, 318], + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + } def test_tokenize_row_with_truncation(self): # Define the input features @@ -106,12 +105,11 @@ def test_tokenize_row_with_truncation(self): ) # Assert the correct output with truncation applied - assert result == \ - { - "prompt_input_ids": 
[6766, 318], # truncated to the last 2 tokens - "chosen_input_ids": [4171], # truncated to 1 token - "rejected_input_ids": [4077], # truncated to 1 token - } + assert result == { + "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + } def test_tokenize_row_with_special_tokens(self): # Define the input features @@ -127,12 +125,11 @@ def test_tokenize_row_with_special_tokens(self): ) # Assert the correct output with special tokens added - assert result == \ - { - "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added - "chosen_input_ids": [4171, 2], # eos_token added - "rejected_input_ids": [4077, 2], # eos_token added - } + assert result == { + "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + } def test_tokenize_row_with_truncation_and_special_tokens(self): # Define the input features @@ -148,12 +145,11 @@ def test_tokenize_row_with_truncation_and_special_tokens(self): ) # Assert the correct output with both truncation and special tokens - assert result == \ - { - "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token - "chosen_input_ids": [4171], # truncated to 1 token - "rejected_input_ids": [4077], # truncated to 1 token - } + assert result == { + "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + } class DPOTrainerTester(TrlTestCase): @@ -573,8 +569,11 @@ def test_dpo_trainer_generate_during_eval_no_wandb(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") - with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." - " Please install `wandb`, `mlflow` or `comet-ml` to resolve."): + with pytest.raises( + ValueError, + match="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." + " Please install `wandb`, `mlflow` or `comet-ml` to resolve.", + ): DPOTrainer( model=self.model, ref_model=None, @@ -963,9 +962,10 @@ def test_dpo_trainer_dtype(self): train_dataset=dummy_dataset["train"], ) - assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \ - "`torch.dtype` (e.g., 'float32'), but got -1." in \ - str(context.exception) + assert ( + "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " + "`torch.dtype` (e.g., 'float32'), but got -1." in str(context.exception) + ) training_args = DPOConfig( output_dir=self.tmp_dir, @@ -984,9 +984,10 @@ def test_dpo_trainer_dtype(self): train_dataset=dummy_dataset["train"], ) - assert "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " \ - "`torch.dtype` (e.g., 'float32'), but got -1." in \ - str(context.exception) + assert ( + "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " + "`torch.dtype` (e.g., 'float32'), but got -1." 
in str(context.exception) + ) def test_dpo_loss_alpha_div_f(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -1369,7 +1370,7 @@ def test_dpo_trainer_with_liger(self, beta, loss_type): with torch.no_grad(): output = trainer.model(**model_inputs) assert output is not None - assert not ("loss" in output.keys()) + assert "loss" not in output.keys() def test_train_with_iterable_dataset(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py index 11ddcaedc81..7bdec1d4553 100644 --- a/tests/test_gkd_trainer.py +++ b/tests/test_gkd_trainer.py @@ -70,8 +70,9 @@ def test_generate_on_policy_outputs_deterministic(self): # Check if the generated texts start with the original prompts for prompt, generated_text in zip(prompts, generated_texts): - assert generated_text.startswith(prompt), \ + assert generated_text.startswith(prompt), ( f"Generated text '{generated_text}' does not start with prompt '{prompt}'" + ) # Run the generation twice and check if the outputs are identical outputs2 = GKDTrainer.generate_on_policy_outputs( @@ -82,10 +83,10 @@ def test_generate_on_policy_outputs_deterministic(self): # Check if the two generations are identical assert torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical" - assert torch.all(new_attention_mask.eq(new_attention_mask2)), \ + assert torch.all(new_attention_mask.eq(new_attention_mask2)), ( "Attention masks for deterministic generations are not identical" - assert torch.all(new_labels.eq(new_labels2)), \ - "Labels for deterministic generations are not identical" + ) + assert torch.all(new_labels.eq(new_labels2)), "Labels for deterministic generations are not identical" def test_generate_on_policy_outputs(self): prompts = ["Hello, how are you?", "What's the weather like today?"] @@ -134,7 +135,7 @@ def setUp(self): def test_uniform_distribution(self): logits = torch.ones(1, 1, self.vocab_size) loss = GKDTrainer.generalized_jsd_loss(logits, logits) - assert round(abs(loss.item()-0), 5) == 0 + assert round(abs(loss.item() - 0), 5) == 0 def test_generalized_jsd_loss_edge_cases(self): # Setup @@ -146,14 +147,14 @@ def test_generalized_jsd_loss_edge_cases(self): expected_loss_beta_1 = F.kl_div( F.log_softmax(teacher_logits, dim=-1), F.softmax(student_logits, dim=-1), reduction="batchmean" ) - assert round(abs(loss_beta_1.item()-expected_loss_beta_1.item()), 5) == 0 + assert round(abs(loss_beta_1.item() - expected_loss_beta_1.item()), 5) == 0 # Case 2: beta = 0 (should be equivalent to KL(teacher || student)) loss_beta_0 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=0) expected_loss_beta_0 = F.kl_div( F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean" ) - assert round(abs(loss_beta_0.item()-expected_loss_beta_0.item()), 5) == 0 + assert round(abs(loss_beta_0.item() - expected_loss_beta_0.item()), 5) == 0 def test_output_shape(self): loss = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits) @@ -195,7 +196,7 @@ def test_symmetry(self): def test_zero_loss_for_identical_inputs(self): identical_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) loss = GKDTrainer.generalized_jsd_loss(identical_logits, identical_logits) - assert round(abs(loss.item()-0), 6) == 0 + assert round(abs(loss.item() - 0), 6) == 0 class GKDTrainerTester(TrlTestCase): diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 
9a442ee6102..f353df444cf 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1922,19 +1922,19 @@ def test_update_with_inputs_different_seq_len(self): # Check for new entry with seq len 3 in buffer assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids # excluded no-variance group - assert [[1013, 1014, pad_token_id], [1015, 1016, 1017]] in buffered_completion_ids # excluded no-variance group + assert [ + [1013, 1014, pad_token_id], + [1015, 1016, 1017], + ] in buffered_completion_ids # excluded no-variance group # Check that sampled outputs contain one group with prompt_ids starting with a pad token assert [ - [pad_token_id, 101, 102], - [pad_token_id, 102, 103], - ] \ - in output_prompt_ids \ - or [ - [pad_token_id, 104, 105], - [pad_token_id, 106, 107], - ] \ - in output_prompt_ids + [pad_token_id, 101, 102], + [pad_token_id, 102, 103], + ] in output_prompt_ids or [ + [pad_token_id, 104, 105], + [pad_token_id, 106, 107], + ] in output_prompt_ids @pytest.mark.low_priority diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index ac68a00d9a7..da749abc62d 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -13,6 +13,7 @@ # limitations under the License. +import pytest import torch from datasets import load_dataset from parameterized import parameterized @@ -23,7 +24,6 @@ from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize from .testing_utils import TrlTestCase, require_no_wandb -import pytest class KTOTrainerTester(TrlTestCase): @@ -167,12 +167,11 @@ def test_tokenize_and_process_tokens(self): # the last batch remains unaltered. This is a rare scenario that does not impact the training # process, so we exclude it from testing by iterating only up to len - 1. for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1): - assert tokenized_dataset["prompt_input_ids"][i] == \ - tokenized_kl_dataset["prompt_input_ids"][i] - assert tokenized_dataset["prompt_attention_mask"][i] == \ - tokenized_kl_dataset["prompt_attention_mask"][i] - assert tokenized_dataset["answer_input_ids"][i] != \ - tokenized_kl_dataset["answer_input_ids"][i] + assert tokenized_dataset["prompt_input_ids"][i] == tokenized_kl_dataset["prompt_input_ids"][i] + assert ( + tokenized_dataset["prompt_attention_mask"][i] == tokenized_kl_dataset["prompt_attention_mask"][i] + ) + assert tokenized_dataset["answer_input_ids"][i] != tokenized_kl_dataset["answer_input_ids"][i] fn_kwargs = { "prefix": "", @@ -295,8 +294,11 @@ def test_kto_trainer_generate_during_eval_no_wandb(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") - with pytest.raises(ValueError, match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve."): + with pytest.raises( + ValueError, + match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." 
+ " Please install `wandb` or `comet-ml` to resolve.", + ): KTOTrainer( model=self.model, ref_model=None, diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index 355dcb2a1c8..0aa56901f82 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -142,8 +142,8 @@ def test_from_save_transformers_sharded(self): # Check if the weights are the same for key in transformers_model.state_dict(): assert torch.allclose( - transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] - ) + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) def test_from_save_transformers(self): """ @@ -163,8 +163,8 @@ def test_from_save_transformers(self): # Check if the weights are the same for key in transformers_model.state_dict(): assert torch.allclose( - transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] - ) + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) # Check if the trl model has the same keys as the transformers model # except the v_head @@ -175,8 +175,9 @@ def test_from_save_transformers(self): assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]) # check if they have the same modules - assert set(transformers_model_from_save.state_dict().keys()) == \ - set(transformers_model.state_dict().keys()) + assert set(transformers_model_from_save.state_dict().keys()) == set( + transformers_model.state_dict().keys() + ) class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): @@ -261,8 +262,9 @@ def test_transformers_bf16_kwargs(self): lm_head_namings = ["lm_head", "embed_out", "output_layer"] - assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), \ + assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), ( "Can't test the model because it doesn't have any of the expected lm_head namings" + ) for lm_head_naming in lm_head_namings: if hasattr(trl_model.pretrained_model, lm_head_naming): @@ -287,8 +289,9 @@ def test_push_to_hub(self): assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() for name, param in model.state_dict().items(): - assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \ + assert torch.allclose(param, model_from_pretrained.state_dict()[name]), ( f"Parameter {name} is not the same after push_to_hub and from_pretrained" + ) class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): @@ -378,8 +381,9 @@ def test_push_to_hub(self): assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() for name, param in model.state_dict().items(): - assert torch.allclose(param, model_from_pretrained.state_dict()[name]), \ + assert torch.allclose(param, model_from_pretrained.state_dict()[name]), ( f"Parameter {name} is not the same after push_to_hub and from_pretrained" + ) def test_transformers_bf16_kwargs(self): r""" diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index b8eb83a0f04..54e00447593 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -478,8 +478,8 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs): assert "train_loss" in trainer.state.log_history[-1] assert len(trainer.reward_funcs) == 2 assert trainer.reward_weights is not None - assert round(abs(trainer.reward_weights[0].item()-0.7), 
5) == 0 - assert round(abs(trainer.reward_weights[1].item()-0.3), 5) == 0 + assert round(abs(trainer.reward_weights[0].item() - 0.7), 5) == 0 + assert round(abs(trainer.reward_weights[1].item() - 0.3), 5) == 0 @require_vision diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index 3b68be8b817..ac62174db00 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -136,17 +136,21 @@ def test_save_pretrained_peft(self): model.save_pretrained(self.tmp_dir) # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory - assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \ + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), ( f"{self.tmp_dir}/adapter_model.safetensors does not exist" - assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist" + ) + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), ( + f"{self.tmp_dir}/adapter_config.json does not exist" + ) # check also for `pytorch_model.bin` and make sure it only contains `v_head` weights assert os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist" # check that only keys that starts with `v_head` are in the dict maybe_v_head = torch.load(f"{self.tmp_dir}/pytorch_model.bin", weights_only=True) - assert all(k.startswith("v_head") for k in maybe_v_head.keys()), \ + assert all(k.startswith("v_head") for k in maybe_v_head.keys()), ( f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`" + ) model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) @@ -167,9 +171,12 @@ def test_load_pretrained_peft(self): model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory - assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), \ + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), ( f"{self.tmp_dir}/adapter_model.safetensors does not exist" - assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist" + ) + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), ( + f"{self.tmp_dir}/adapter_config.json does not exist" + ) # check all the weights are the same for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py index 76398519836..2a8083d7492 100644 --- a/tests/test_prm_trainer.py +++ b/tests/test_prm_trainer.py @@ -75,11 +75,10 @@ def test_tokenize_row_no_truncation(self): is_eval=False, ) - assert result == \ - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], - } + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + } def test_tokenize_row_train_on_last_step_only(self): # Define the input features @@ -100,11 +99,10 @@ def test_tokenize_row_train_on_last_step_only(self): is_eval=False, ) - assert result == \ - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], - } + assert result == { + "input_ids": [0, 465, 6766, 
318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + } def test_tokenize_row_prompt_truncation(self): # Define the input features @@ -126,11 +124,10 @@ def test_tokenize_row_prompt_truncation(self): is_eval=False, ) - assert result == \ - { - "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], - "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], - } + assert result == { + "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + } def test_tokenize_row_completion_truncation(self): # Define the input features @@ -152,11 +149,10 @@ def test_tokenize_row_completion_truncation(self): is_eval=False, ) - assert result == \ - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], - } + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], + } def test_tokenize_row_prompt_completion_truncation(self): # Define the input features @@ -178,11 +174,10 @@ def test_tokenize_row_prompt_completion_truncation(self): is_eval=False, ) - assert result == \ - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], - } + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], + } def test_tokenize_row_multi_token_separator(self): # Define the input features @@ -204,11 +199,10 @@ def test_tokenize_row_multi_token_separator(self): is_eval=False, ) - assert result == \ - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], - } + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], + } class PRMTrainerTester(TrlTestCase): diff --git a/tests/test_rewards.py b/tests/test_rewards.py index aac6aabca0c..21827b6b4ea 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -85,7 +85,7 @@ def test_soft_overlong_punishment_intermediate_completion(self): reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) completion_ids = [[1] * 90] # 90 is between 80 and 100 rewards = reward_fn(completion_ids) - assert round(abs(rewards[0]--0.5), 4) == 0 + assert round(abs(rewards[0] - -0.5), 4) == 0 if __name__ == "__main__": diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 2ddb51248ce..f79c4ca3f08 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -15,6 +15,7 @@ import unittest from unittest.mock import patch +import pytest import torch from datasets import load_dataset from parameterized import parameterized @@ -30,7 +31,6 @@ from trl import RLOOConfig, RLOOTrainer from .testing_utils import TrlTestCase, require_vllm -import pytest if is_peft_available(): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 2a558d6093f..b60ea48bfaa 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1504,24 +1504,26 @@ def test_peft_model_with_quantization(self): 
lora_params_after.append(name) # LoRA parameters should remain trainable - assert len(trainable_params_after) > 0, \ - f"PeftModel should still have trainable parameters after SFTTrainer initialization. " \ - f"Found {len(trainable_params_after)} trainable params. " \ + assert len(trainable_params_after) > 0, ( + f"PeftModel should still have trainable parameters after SFTTrainer initialization. " + f"Found {len(trainable_params_after)} trainable params. " f"This test fails without the fix for issue #3926." + ) - assert len(lora_params_after) > 0, \ - f"LoRA adapter parameters should remain trainable. " \ + assert len(lora_params_after) > 0, ( + f"LoRA adapter parameters should remain trainable. " f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original." + ) # Ensure the parameter counts are preserved (no additional freezing occurred) - assert len(trainable_params_before) == \ - len(trainable_params_after), \ + assert len(trainable_params_before) == len(trainable_params_after), ( "Number of trainable parameters should not change after SFTTrainer initialization" + ) # Verify that all original LoRA parameters are still trainable - assert set(lora_params_before) == \ - set(lora_params_after), \ + assert set(lora_params_before) == set(lora_params_after), ( "All original LoRA parameters should remain trainable after SFTTrainer initialization" + ) @require_peft def test_prompt_tuning_peft_model(self): diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 35a8da57cd3..b9a809ba071 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -84,8 +84,8 @@ def test_bco(self): assert trainer.args.padding_value == -99 assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - assert trainer.args.is_encoder_decoder == True - assert trainer.args.precompute_ref_log_probs == True + assert trainer.args.is_encoder_decoder + assert trainer.args.precompute_ref_log_probs assert trainer.args.model_init_kwargs == {"trust_remote_code": True} assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} assert trainer.args.dataset_num_proc == 4 @@ -123,14 +123,14 @@ def test_cpo(self): assert trainer.args.beta == 0.5 assert trainer.args.label_smoothing == 0.5 assert trainer.args.loss_type == "hinge" - assert trainer.args.disable_dropout == False + assert not trainer.args.disable_dropout assert trainer.args.cpo_alpha == 0.5 assert trainer.args.simpo_gamma == 0.2 assert trainer.args.label_pad_token_id == -99 assert trainer.args.padding_value == -99 assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - assert trainer.args.is_encoder_decoder == True + assert trainer.args.is_encoder_decoder assert trainer.args.model_init_kwargs == {"trust_remote_code": True} assert trainer.args.dataset_num_proc == 4 @@ -183,16 +183,16 @@ def test_dpo(self): assert trainer.args.max_length == 256 assert trainer.args.max_prompt_length == 64 assert trainer.args.max_completion_length == 64 - assert trainer.args.disable_dropout == False + assert not trainer.args.disable_dropout # self.assertEqual(trainer.args.generate_during_eval, True) - assert trainer.args.precompute_ref_log_probs == True + assert trainer.args.precompute_ref_log_probs assert trainer.args.dataset_num_proc == 4 assert trainer.args.model_init_kwargs == {"trust_remote_code": True} assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} assert 
trainer.args.model_adapter_name == "dummy_adapter" assert trainer.args.ref_adapter_name == "dummy_adapter" - assert trainer.args.reference_free == True - assert trainer.args.force_use_ref_model == True + assert trainer.args.reference_free + assert trainer.args.force_use_ref_model assert trainer.args.f_divergence_type == FDivergenceType.JS_DIVERGENCE assert trainer.args.f_alpha_divergence_coef == 0.5 # self.assertEqual(trainer.args.sync_ref_model, True) @@ -240,8 +240,8 @@ def test_kto(self): assert trainer.args.padding_value == -99 assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - assert trainer.args.is_encoder_decoder == True - assert trainer.args.precompute_ref_log_probs == True + assert trainer.args.is_encoder_decoder + assert trainer.args.precompute_ref_log_probs assert trainer.args.model_init_kwargs == {"trust_remote_code": True} assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} assert trainer.args.dataset_num_proc == 4 @@ -323,7 +323,7 @@ def test_orpo(self): assert trainer.args.max_prompt_length == 64 assert trainer.args.max_completion_length == 64 assert trainer.args.beta == 0.5 - assert trainer.args.disable_dropout == False + assert not trainer.args.disable_dropout assert trainer.args.label_pad_token_id == -99 def test_reward(self): @@ -363,14 +363,14 @@ def test_sft(self): ) trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset) assert trainer.args.dataset_text_field == "dummy_text_field" - assert trainer.args.packing == True + assert trainer.args.packing assert trainer.args.max_length == 256 assert trainer.args.dataset_num_proc == 4 assert trainer.args.neftune_noise_alpha == 0.1 assert trainer.args.model_init_kwargs == {"trust_remote_code": True} assert "append_concat_token" in trainer.args.dataset_kwargs - assert trainer.args.dataset_kwargs["append_concat_token"] == True - assert trainer.args.eval_packing == True + assert trainer.args.dataset_kwargs["append_concat_token"] + assert trainer.args.eval_packing @parameterized.expand([(False,), (True,)]) def test_xpo(self, alpha_list): diff --git a/tests/test_utils.py b/tests/test_utils.py index 2bfa9bff0c2..0dfa00d06e9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ from unittest.mock import patch import numpy as np +import pytest import torch from datasets import load_dataset from parameterized import parameterized @@ -47,7 +48,6 @@ ) from .testing_utils import TrlTestCase, require_rich -import pytest if is_peft_available(): @@ -292,8 +292,9 @@ def test_data_collator_for_chatml(self): # Verify labels are -100 for non-assistant parts prompt_length = len(prompt_only) - assert all(label == self.ignore_index for label in labels[:prompt_length]), \ + assert all(label == self.ignore_index for label in labels[:prompt_length]), ( "Labels should be ignore_index for prompt tokens" + ) # Verify labels match assistant response after prompt # Add a filter to remove any trailing tokens after the first <|im_end|> @@ -309,18 +310,15 @@ def test_data_collator_for_chatml(self): response_labels.append(label) if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"): break - assert response_labels == \ - last_assistant_response_tokens, \ - "Labels should match assistant response tokens" + assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens" # Verify there isn't a generation prompt at the end generation_prompt = "<|im_start|>assistant" - assert not 
decoded_input.strip().endswith(generation_prompt), \ + assert not decoded_input.strip().endswith(generation_prompt), ( f"Input should not end with generation prompt '{generation_prompt}'" + ) - assert response_labels == \ - last_assistant_response_tokens, \ - "Labels should match assistant response tokens" + assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens" class TestBatchGeneration(TrlTestCase): @@ -397,7 +395,7 @@ def test_token_classification_task(self): ) expected_accuracy = 0.5 # 2 matches, 2 mismatches result = compute_accuracy(eval_pred) - assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 def test_token_classification_task_with_ignored_tokens_0(self): eval_pred = ( @@ -411,7 +409,7 @@ def test_token_classification_task_with_ignored_tokens_0(self): ) expected_accuracy = 1.0 # All non-ignored tokens match result = compute_accuracy(eval_pred) - assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 def test_token_classification_task_with_ignored_tokens_1(self): eval_pred = ( @@ -425,7 +423,7 @@ def test_token_classification_task_with_ignored_tokens_1(self): ) expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored result = compute_accuracy(eval_pred) - assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 def test_rewards_comparison_task(self): eval_pred = ( @@ -443,7 +441,7 @@ def test_rewards_comparison_task(self): with self.assertLogs("trl.trainer.utils", level="WARNING") as cm: result = compute_accuracy(eval_pred) - assert round(abs(result["accuracy"]-expected_accuracy), 7) == 0 + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 expected_warning = ( "There are 1 out of 3 instances where the predictions for both options are equal. " "These instances are ignored in the accuracy computation." 
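The hunks above all apply the same mechanical rewrite: unittest assertion helpers become bare assert statements, whose operands pytest introspects and prints on failure. A minimal standalone sketch of the pattern, using a hypothetical divide() helper rather than anything from the TRL test suite:

import pytest


def divide(a: float, b: float) -> float:
    # Hypothetical helper, only here to illustrate the assertion styles.
    if b == 0:
        raise ValueError("division by zero")
    return a / b


class TestDivide:
    def test_exact_and_boolean_checks(self):
        # unittest: self.assertEqual(...), self.assertIsNotNone(...), self.assertTrue(...)
        assert divide(6, 3) == 2
        assert divide(3, 3) is not None
        assert divide(4, 2) > 1

    def test_approximate_value(self):
        # unittest: self.assertAlmostEqual(divide(1, 3), 0.3333, places=4)
        # the patch keeps the round(abs(x - y), n) == 0 form; pytest.approx is equivalent here
        assert round(abs(divide(1, 3) - 0.3333), 4) == 0
        assert divide(1, 3) == pytest.approx(0.3333, abs=1e-4)

    def test_raises(self):
        # unittest: with self.assertRaises(ValueError): ...
        with pytest.raises(ValueError, match="division by zero"):
            divide(1, 0)

pytest.raises(..., match=...) performs a regular-expression search against the exception message, which is why a later commit in this series wraps literal messages containing parentheses in re.escape.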
diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 836cd124f16..62d22583091 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -45,8 +45,10 @@ def test_single_element_list(self): assert chunk_list([42], 2) == [[42], []] def test_any_dtype(self): - assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == \ - [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]] + assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == [ + [1, "two", 3.0], + [{"four": 4}, ["f", "i", "v", "e"]], + ] @pytest.mark.slow From 7c463f6f94da1b780fcfc4cfc2bf79cf80433586 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 11:25:07 +0200 Subject: [PATCH 04/16] Fix pytest.raises with match arg --- tests/test_cli_utils.py | 11 +++-------- tests/test_dpo_trainer.py | 23 +++++++++-------------- tests/test_grpo_trainer.py | 4 +--- tests/test_rloo_trainer.py | 4 +--- 4 files changed, 14 insertions(+), 28 deletions(-) diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index f8196dd5377..708640c55fd 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -45,9 +45,8 @@ def test_init_without_config_field(self): def test_init_with_config_field(self): """Test initialization with a 'config' field in the dataclass (should raise ValueError).""" - with pytest.raises(ValueError) as context: + with pytest.raises(ValueError, match="has a field named 'config'"): TrlParser(dataclass_types=[InvalidDataclass]) - assert "has a field named 'config'" in str(context.exception) @patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2")) @patch("yaml.safe_load") @@ -105,11 +104,9 @@ def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load): args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"] - with pytest.raises(ValueError) as context: + with pytest.raises(ValueError, match="`env` field should be a dict in the YAML file."): parser.parse_args_and_config(args) - assert str(context.exception) == "`env` field should be a dict in the YAML file." 
- def test_parse_args_and_config_without_config(self): """Test parse_args_and_config without the `--config` argument.""" parser = TrlParser(dataclass_types=[MyDataclass]) @@ -352,11 +349,9 @@ def test_dataset_mixture_with_test_split(self): def test_empty_dataset_mixture_raises_error(self): mixture_config = DatasetMixtureConfig(datasets=[]) - with pytest.raises(ValueError) as context: + with pytest.raises(ValueError, match="No datasets were loaded"): get_dataset(mixture_config) - assert "No datasets were loaded" in str(context.exception) - def test_mixture_multiple_different_configs(self): dataset_config1 = DatasetConfig( path="trl-internal-testing/zen", name="conversational_preference", split="train", columns=["prompt"] diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 8c0a4efb250..0e4afd4a1bd 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -336,13 +336,12 @@ def test_train_with_multiple_loss_types(self): assert "nll_loss" in metrics # SFT loss should be computed def test_wrong_loss_weights_length(self): - with pytest.raises(ValueError) as context: + with pytest.raises(ValueError, match="Length of loss_weights list"): DPOConfig( output_dir=self.tmp_dir, loss_type=["sigmoid", "bco_pair"], loss_weights=[1.0, 0.5, 0.1], # Wrong length ) - assert "Length of loss_weights list" in str(context.exception) @parameterized.expand([(None,), (0.5,)]) def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha): @@ -954,7 +953,10 @@ def test_dpo_trainer_dtype(self): report_to="none", ) - with pytest.raises(ValueError) as context: + with pytest.raises( + ValueError, + match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.", + ): _ = DPOTrainer( model=self.model_id, processing_class=self.tokenizer, @@ -962,11 +964,6 @@ def test_dpo_trainer_dtype(self): train_dataset=dummy_dataset["train"], ) - assert ( - "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " - "`torch.dtype` (e.g., 'float32'), but got -1." in str(context.exception) - ) - training_args = DPOConfig( output_dir=self.tmp_dir, per_device_train_batch_size=2, @@ -975,7 +972,10 @@ def test_dpo_trainer_dtype(self): report_to="none", ) - with pytest.raises(ValueError) as context: + with pytest.raises( + ValueError, + match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.", + ): _ = DPOTrainer( model=self.model_id, ref_model=self.model_id, @@ -984,11 +984,6 @@ def test_dpo_trainer_dtype(self): train_dataset=dummy_dataset["train"], ) - assert ( - "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " - "`torch.dtype` (e.g., 'float32'), but got -1." 
in str(context.exception) - ) - def test_dpo_loss_alpha_div_f(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" tokenizer = AutoTokenizer.from_pretrained(model_id) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index f353df444cf..785d8426f68 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1645,7 +1645,7 @@ def test_mismatched_reward_processing_classes_length(self): training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none") - with pytest.raises(ValueError) as context: + with pytest.raises(ValueError, match="must match"): GRPOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", reward_funcs=reward_models, @@ -1654,8 +1654,6 @@ def test_mismatched_reward_processing_classes_length(self): train_dataset=dataset, ) - assert "must match" in str(context.exception) - def test_correct_reward_processing_classes_list(self): """Test that correct list of reward_processing_classes works properly.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index f79c4ca3f08..23f709d466a 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1353,7 +1353,7 @@ def test_mismatched_reward_processing_classes_length(self): training_args = RLOOConfig(output_dir=self.tmp_dir, report_to="none") - with pytest.raises(ValueError) as context: + with pytest.raises(ValueError, match="must match"): RLOOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", reward_funcs=reward_models, @@ -1362,8 +1362,6 @@ def test_mismatched_reward_processing_classes_length(self): train_dataset=dataset, ) - assert "must match" in str(context.exception) - def test_correct_reward_processing_classes_list(self): """Test that correct list of reward_processing_classes works properly.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") From 82107e4e17ad61cb664e75f4f3d5502ce5aa9fdf Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:43:17 +0200 Subject: [PATCH 05/16] Use re.escape for regex special characters --- tests/test_dpo_trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 0e4afd4a1bd..58d04930b80 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import sys import unittest from unittest.mock import MagicMock @@ -955,7 +956,9 @@ def test_dpo_trainer_dtype(self): with pytest.raises( ValueError, - match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.", + match=re.escape( + "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1." + ), ): _ = DPOTrainer( model=self.model_id, @@ -974,7 +977,9 @@ def test_dpo_trainer_dtype(self): with pytest.raises( ValueError, - match="Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1.", + match=re.escape( + "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1." 
+ ), ): _ = DPOTrainer( model=self.model_id, From d93dd49826aae03c6210beea898d49da069c9407 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 13:03:52 +0200 Subject: [PATCH 06/16] Remove unittest from TrlTestCase --- tests/testing_utils.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 85026a53947..61a7935c544 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -14,13 +14,12 @@ import functools import random -import shutil import signal -import tempfile import unittest import warnings import psutil +import pytest import torch from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available from transformers.testing_utils import torch_device @@ -119,18 +118,10 @@ def judge(self, prompts, completions, shuffle_order=True, return_scores=False): return [random.random() for _ in range(len(prompts))] -class TrlTestCase(unittest.TestCase): - """ - Base test case for TRL tests. Sets up a temporary directory for testing. - """ - - def setUp(self): - super().setUp() - self.tmp_dir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.tmp_dir) - super().tearDown() +class TrlTestCase: + @pytest.fixture(autouse=True) + def set_tmp_dir(self, tmp_path): + self.tmp_dir = str(tmp_path) def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable: From 56adebe96a9c33e8c1cf47d3e2c63a3c860101ce Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 13:41:26 +0200 Subject: [PATCH 07/16] Replace setUp/tearDown --- tests/slow/test_dpo_slow.py | 6 ++-- tests/slow/test_grpo_slow.py | 6 ++-- tests/slow/test_sft_slow.py | 6 ++-- tests/test_callbacks.py | 12 +++----- tests/test_collators.py | 3 +- tests/test_core.py | 3 +- tests/test_cpo_trainer.py | 3 +- tests/test_dataset_formatting.py | 6 ++-- tests/test_dpo_trainer.py | 6 ++-- tests/test_gkd_trainer.py | 8 ++--- tests/test_grpo_trainer.py | 8 ++--- tests/test_kto_trainer.py | 3 +- ...test_modeling_geometric_mixture_wrapper.py | 3 +- tests/test_modeling_value_head.py | 12 +++----- tests/test_nash_md_trainer.py | 3 +- tests/test_online_dpo_trainer.py | 3 +- tests/test_orpo_trainer.py | 3 +- tests/test_peft_models.py | 3 +- tests/test_ppo_trainer.py | 3 +- tests/test_prm_trainer.py | 6 ++-- tests/test_rich_progress_callback.py | 3 +- tests/test_utils.py | 9 ++---- tests/test_vllm_client_server.py | 30 +++++++------------ tests/test_xpo_trainer.py | 3 +- 24 files changed, 52 insertions(+), 99 deletions(-) diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index e24362fbc88..30f7ffacb18 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -38,8 +38,7 @@ @require_torch_accelerator @require_peft class DPOTrainerSlowTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference") self.peft_config = LoraConfig( lora_alpha=16, @@ -50,11 +49,10 @@ def setUp(self): ) self.max_length = 128 - def tearDown(self): + def teardown_method(self): gc.collect() backend_empty_cache(torch_device) gc.collect() - super().tearDown() @parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS))) def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits): 
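With unittest.TestCase gone from the base class, per-test initialisation in the hunks that follow moves to pytest's xunit-style hooks: setUp/tearDown become setup_method/teardown_method and setUpClass/tearDownClass become setup_class/teardown_class, all called by pytest on plain classes. A small self-contained sketch of the lifecycle, with hypothetical names:

import gc


class TestLifecycleSketch:
    @classmethod
    def setup_class(cls):
        # runs once before the tests of this class (unittest: setUpClass)
        cls.shared_config = {"model_id": "dummy-model"}

    def setup_method(self):
        # runs before every test (unittest: setUp)
        self.buffer = []

    def teardown_method(self):
        # runs after every test (unittest: tearDown)
        self.buffer.clear()
        gc.collect()

    @classmethod
    def teardown_class(cls):
        # runs once after all tests of this class (unittest: tearDownClass)
        cls.shared_config = None

    def test_buffer_starts_empty(self):
        assert self.buffer == []
        assert self.shared_config["model_id"] == "dummy-model"

Because these hooks are discovered by name rather than inherited from TestCase, the super().setUp()/super().tearDown() calls are simply dropped in these hunks.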
diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 3a453714daa..5c1164fa420 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -55,17 +55,15 @@ @pytest.mark.slow @require_torch_accelerator class GRPOTrainerSlowTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test") self.max_length = 128 - def tearDown(self): + def teardown_method(self): gc.collect() backend_empty_cache(torch_device) gc.collect() - super().tearDown() @parameterized.expand(MODELS_TO_TEST) @require_liger_kernel diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index 7e673a9457c..d0811f5db40 100755 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -45,8 +45,7 @@ @require_torch_accelerator @require_peft class SFTTrainerSlowTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]") self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]") self.max_length = 128 @@ -58,11 +57,10 @@ def setUp(self): task_type="CAUSAL_LM", ) - def tearDown(self): + def teardown_method(self): gc.collect() backend_empty_cache(torch_device) gc.collect() - super().tearDown() @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) def test_sft_trainer_str(self, model_name, packing): diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 501641b5df9..448e11bed8f 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -67,8 +67,7 @@ def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processi class WinRateCallbackTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -226,8 +225,7 @@ def test_lora(self): class LogCompletionsCallbackTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer.pad_token = self.tokenizer.eos_token @@ -321,8 +319,7 @@ def test_basic_comet(self): @require_mergekit class MergeModelCallbackTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") @@ -378,8 +375,7 @@ def test_every_checkpoint(self): class BEMACallbackTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer.pad_token = self.tokenizer.eos_token diff --git 
a/tests/test_collators.py b/tests/test_collators.py index d798f29d54d..3159184f558 100644 --- a/tests/test_collators.py +++ b/tests/test_collators.py @@ -21,8 +21,7 @@ class TestDataCollatorForPreference(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.collator = DataCollatorForPreference(pad_token_id=0) def assertTensorEqual(self, tensor1, tensor2): diff --git a/tests/test_core.py b/tests/test_core.py index 78e57c64b30..45449e5815b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -25,8 +25,7 @@ class CoreTester(TrlTestCase): A wrapper class for testing core utils functions """ - def setUp(self): - super().setUp() + def setup_method(self): self.test_input = torch.Tensor([1, 2, 3, 4]) self.test_mask = torch.Tensor([0, 1, 1, 0]) self.test_input_unmasked = self.test_input[1:3] diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py index c0edd771c5e..4f36be3bc84 100644 --- a/tests/test_cpo_trainer.py +++ b/tests/test_cpo_trainer.py @@ -26,8 +26,7 @@ class CPOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index 78f2e7dfb1d..49b9b460a86 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -24,8 +24,7 @@ class DatasetFormattingTestCase(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1") self.chatml_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -118,8 +117,7 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self): class SetupChatFormatTestCase(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") # remove built-in chat_template to simulate a model having no chat_template diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 58d04930b80..db0f96e9fb2 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -49,8 +49,7 @@ class TestTokenizeRow(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Set up the mock tokenizer with specific behaviors self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) self.tokenizer.bos_token_id = 0 @@ -154,8 +153,7 @@ def test_tokenize_row_with_truncation_and_special_tokens(self): class DPOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py index 7bdec1d4553..a1b62cd993e 100644 --- a/tests/test_gkd_trainer.py +++ b/tests/test_gkd_trainer.py @@ -29,7 +29,7 @@ class TestGKDTrainer(TrlTestCase): @classmethod - def setUpClass(cls): + def setup_class(cls): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" cls.tokenizer = AutoTokenizer.from_pretrained(model_id) cls.tokenizer.pad_token = 
cls.tokenizer.eos_token @@ -124,8 +124,7 @@ def test_generate_on_policy_outputs(self): class TestGeneralizedJSDLoss(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.batch_size = 2 self.seq_length = 3 self.vocab_size = 5 @@ -200,8 +199,7 @@ def test_zero_loss_for_identical_inputs(self): class GKDTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.teacher_model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 785d8426f68..30a99199ee9 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1712,8 +1712,8 @@ def test_single_reward_model_with_single_processing_class(self): @pytest.mark.low_priority -class TestReplayBuffer(unittest.TestCase): - def setUp(self): +class TestReplayBuffer: + def setup_method(self): self.replay_buffer = ReplayBuffer(max_size=5) def test_add(self): @@ -1780,8 +1780,8 @@ def test_sample(self): @pytest.mark.low_priority -class TestUpdateWithReplayBuffer(unittest.TestCase): - def setUp(self): +class TestUpdateWithReplayBuffer: + def setup_method(self): config = GRPOWithReplayBufferConfig( replay_buffer_size=5, ) diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index da749abc62d..6e303ebbe83 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -27,8 +27,7 @@ class KTOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py index 65553b79b77..7dcd89f757e 100644 --- a/tests/test_modeling_geometric_mixture_wrapper.py +++ b/tests/test_modeling_geometric_mixture_wrapper.py @@ -22,8 +22,7 @@ class TestGeometricMixtureWrapper(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index 0aa56901f82..a7bfd08eeb9 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -55,8 +55,7 @@ class VHeadModelTester(TrlTestCase): trl_model_class = None transformers_model_class = None - def setUp(self): - super().setUp() + def setup_method(self): self.device = "cuda" if torch.cuda.is_available() else "cpu" def test_value_head(self): @@ -189,10 +188,9 @@ class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): trl_model_class = AutoModelForCausalLMWithValueHead transformers_model_class = AutoModelForCausalLM - def tearDown(self): + def teardown_method(self): # free memory gc.collect() - super().tearDown() def test_inference(self): r""" @@ -303,10 +301,9 @@ class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): trl_model_class = AutoModelForSeq2SeqLMWithValueHead transformers_model_class = AutoModelForSeq2SeqLM - def tearDown(self): + def teardown_method(self): # free memory gc.collect() - super().tearDown() def test_inference(self): r""" @@ 
-409,8 +406,7 @@ def test_transformers_bf16_kwargs(self): class ReferenceModelTest(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel") self.test_input = torch.tensor([[0, 1, 2, 3]]) self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1) diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py index 25caaff3d38..90df42b48a7 100644 --- a/tests/test_nash_md_trainer.py +++ b/tests/test_nash_md_trainer.py @@ -29,8 +29,7 @@ class TestNashMDTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 54e00447593..4552edafaef 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -36,8 +36,7 @@ class TestOnlineDPOTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 2f444eb4dfc..8ab8819675a 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -26,8 +26,7 @@ class ORPOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index ac62174db00..b8bacb7cce2 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -33,8 +33,7 @@ @require_peft class PeftModelTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.causal_lm_model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.lora_config = LoraConfig( r=16, diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 13354c36b29..c4c69d24e10 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -30,8 +30,7 @@ class TestPPOTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Set up the models and tokenizer using the test model self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py index 2a8083d7492..f435c1295ab 100644 --- a/tests/test_prm_trainer.py +++ b/tests/test_prm_trainer.py @@ -31,8 +31,7 @@ class TestTokenizeRow(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Set up the mock tokenizer with specific behaviors self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) self.tokenizer.bos_token_id = 0 @@ -206,8 +205,7 @@ def test_tokenize_row_multi_token_separator(self): class PRMTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForTokenClassification.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id) 
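These conversions rely on the earlier testing_utils.py change: the shared base class now provides its temporary directory through an autouse fixture built on pytest's tmp_path instead of tempfile calls in setUp/tearDown. A minimal sketch of that pattern, with a hypothetical base-class name:

import pathlib

import pytest


class TmpDirBaseSketch:
    # Hypothetical stand-in for a shared test base class.
    @pytest.fixture(autouse=True)
    def set_tmp_dir(self, tmp_path):
        # pytest creates a fresh per-test directory and handles its cleanup,
        # so no tearDown or shutil.rmtree is required.
        self.tmp_dir = str(tmp_path)


class TestWriteToTmpDir(TmpDirBaseSketch):
    def test_can_write_into_tmp_dir(self):
        target = pathlib.Path(self.tmp_dir) / "output.txt"
        target.write_text("ok")
        assert target.read_text() == "ok"

Keeping self.tmp_dir as a string preserves the existing f"{self.tmp_dir}/..." usages in the tests, since tmp_path itself is a pathlib.Path.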
diff --git a/tests/test_rich_progress_callback.py b/tests/test_rich_progress_callback.py index d9069481263..d246b694b72 100644 --- a/tests/test_rich_progress_callback.py +++ b/tests/test_rich_progress_callback.py @@ -34,8 +34,7 @@ def forward(self, x): @require_rich class TestRichProgressCallback(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.dummy_model = DummyModel() self.dummy_train_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 5) self.dummy_val_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 101) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0dfa00d06e9..d663328e244 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -173,8 +173,7 @@ def test_create_peft_config_use_peft_true(self): class TestDecodeAndStripPadding(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") def test_example_with_padding(self): @@ -235,8 +234,7 @@ def test_val_none(self): class TestDataCollatorForChatML(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Initialize the tokenizer self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") if self.tokenizer.pad_token is None: @@ -322,8 +320,7 @@ def test_data_collator_for_chatml(self): class TestBatchGeneration(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Initialize the tokenizer self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 62d22583091..23c2080289c 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -57,7 +57,7 @@ class TestVLLMClientServer(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -115,9 +115,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the client cls.client.close_communicator() @@ -133,7 +131,7 @@ class TestVLLMClientServerBaseURL(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -191,9 +189,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the client cls.client.close_communicator() @@ -208,7 +204,7 @@ class TestVLLMClientServerTP(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -249,9 +245,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the 
client cls.client.close_communicator() @@ -266,7 +260,7 @@ class TestVLLMClientServerDP(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -307,9 +301,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the client cls.client.close_communicator() @@ -326,7 +318,7 @@ class TestVLLMClientServerDeviceParameter(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -380,9 +372,7 @@ def test_init_communicator_with_torch_device(self): client.close_communicator() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to # kill the server process and its children explicitly. kill_process(cls.server_process) diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index 7a69455df33..bbf51bf83a8 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -29,8 +29,7 @@ class TestXPOTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) From 8c30eec2f7c6878b7a7eed26f8f20d0b31863722 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 13:53:44 +0200 Subject: [PATCH 08/16] Replace unittest.skipUnless --- tests/testing_utils.py | 81 ++++++++---------------------------------- 1 file changed, 14 insertions(+), 67 deletions(-) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 61a7935c544..fdc91ccab81 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -15,7 +15,6 @@ import functools import random import signal -import unittest import warnings import psutil @@ -29,72 +28,20 @@ from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available -# transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use -# in our test suite. We therefore need to implement our own version of this function. -def require_bitsandbytes(test_case): - """ - Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available. - """ - return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) - - -def require_comet(test_case): - """ - Decorator marking a test that requires Comet. Skips the test if Comet is not available. - """ - return unittest.skipUnless(is_comet_available(), "test requires comet_ml")(test_case) - - -def require_llm_blender(test_case): - """ - Decorator marking a test that requires llm-blender. Skips the test if llm-blender is not available. 
- """ - return unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")(test_case) - - -def require_mergekit(test_case): - """ - Decorator marking a test that requires mergekit. Skips the test if mergekit is not available. - """ - return unittest.skipUnless(is_mergekit_available(), "test requires mergekit")(test_case) - - -def require_rich(test_case): - """ - Decorator marking a test that requires rich. Skips the test if rich is not available. - """ - return unittest.skipUnless(is_rich_available(), "test requires rich")(test_case) - - -def require_sklearn(test_case): - """ - Decorator marking a test that requires sklearn. Skips the test if sklearn is not available. - """ - return unittest.skipUnless(is_sklearn_available() and is_joblib_available(), "test requires sklearn")(test_case) - - -def require_vllm(test_case): - """ - Decorator marking a test that requires vllm. Skips the test if vllm is not available. - """ - return unittest.skipUnless(is_vllm_available(), "test requires vllm")(test_case) - - -def require_no_wandb(test_case): - """ - Decorator marking a test that requires no wandb. Skips the test if wandb is available. - """ - return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case) - - -def require_3_accelerators(test_case): - """ - Decorator marking a test that requires at least 3 accelerators. Skips the test if 3 accelerators are not available. - """ - torch_accelerator_module = getattr(torch, torch_device, torch.cuda) - return unittest.skipUnless( - torch_accelerator_module.device_count() >= 3, f"test requires at least 3 {torch_device}s" - )(test_case) +require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") +require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") +require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") +require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit") +require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich") +require_sklearn = pytest.mark.skipif( + not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn" +) +require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm") +require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb") +require_3_accelerators = pytest.mark.skipif( + not (getattr(torch, torch_device, torch.cuda).device_count() >= 3), + reason=f"test requires at least 3 {torch_device}s", +) class RandomBinaryJudge(BaseBinaryJudge): From 4630ed4ae57645742c023fa53cd102426dadc1db Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 14:03:25 +0200 Subject: [PATCH 09/16] Remove remaining TestCase --- tests/test_cli_utils.py | 3 +-- tests/test_data_utils.py | 2 +- tests/test_rewards.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 708640c55fd..a7cda2fddf8 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -13,7 +13,6 @@ # limitations under the License. 
import tempfile -import unittest from dataclasses import dataclass from unittest.mock import mock_open, patch @@ -268,7 +267,7 @@ def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load): assert result_args[0].arg2 == "config_value" # Default from config -class TestGetDataset(unittest.TestCase): +class TestGetDataset: def test_single_dataset_with_config(self): mixture_config = DatasetMixtureConfig( datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")] diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index abb27258e3f..dcb10a59537 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -40,7 +40,7 @@ from .testing_utils import TrlTestCase -class PrepareMultimodalMessagesTester(unittest.TestCase): +class TestPrepareMultimodalMessages: def test_basic_user_assistant_conversation(self): """Test basic conversation with user and assistant messages.""" messages = [ diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 21827b6b4ea..46fddb5b19d 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -63,7 +63,7 @@ def test_mixed_format(self): assert rewards == expected_rewards -class SoftOverlongPunishmentRewardTester(unittest.TestCase): +class TestSoftOverlongPunishmentReward: def test_soft_overlong_punishment_short_completion(self): """Test soft overlong punishment reward function with a short completion.""" # length 50, with max=100 and soft cache=20, reward should be 0. From d9e8bf51e87365d01f862699d1b987529860b727 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 14:21:39 +0200 Subject: [PATCH 10/16] Fix naming of tests --- tests/slow/test_dpo_slow.py | 2 +- tests/slow/test_grpo_slow.py | 2 +- tests/slow/test_sft_slow.py | 2 +- tests/test_bco_trainer.py | 2 +- tests/test_best_of_n_sampler.py | 2 +- tests/test_callbacks.py | 8 ++++---- tests/test_core.py | 2 +- tests/test_cpo_trainer.py | 2 +- tests/test_data_utils.py | 12 ++++++------ tests/test_dataset_formatting.py | 6 +++--- tests/test_dpo_trainer.py | 4 ++-- tests/test_gkd_trainer.py | 4 ++-- tests/test_grpo_trainer.py | 6 +++--- tests/test_kto_trainer.py | 2 +- tests/test_modeling_value_head.py | 8 ++++---- tests/test_online_dpo_trainer.py | 2 +- tests/test_orpo_trainer.py | 2 +- tests/test_peft_models.py | 2 +- tests/test_prm_trainer.py | 2 +- tests/test_reward_trainer.py | 2 +- tests/test_rewards.py | 2 +- tests/test_rloo_trainer.py | 2 +- tests/test_sft_trainer.py | 4 ++-- tests/test_trainers_args.py | 2 +- tests/test_utils.py | 12 ++++++------ 25 files changed, 48 insertions(+), 48 deletions(-) diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index 30f7ffacb18..498c71179ac 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -37,7 +37,7 @@ @pytest.mark.slow @require_torch_accelerator @require_peft -class DPOTrainerSlowTester(TrlTestCase): +class TestDPOTrainerSlow(TrlTestCase): def setup_method(self): self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference") self.peft_config = LoraConfig( diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 5c1164fa420..1e594d0aedd 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -54,7 +54,7 @@ @pytest.mark.slow @require_torch_accelerator -class GRPOTrainerSlowTester(TrlTestCase): +class TestGRPOTrainerSlow(TrlTestCase): def setup_method(self): self.train_dataset = 
load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test") diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index d0811f5db40..114a9069943 100755 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -44,7 +44,7 @@ @pytest.mark.slow @require_torch_accelerator @require_peft -class SFTTrainerSlowTester(TrlTestCase): +class TestSFTTrainerSlow(TrlTestCase): def setup_method(self): self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]") self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]") diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index 91f905eb628..07ce1178837 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -33,7 +33,7 @@ from peft import LoraConfig -class BCOTrainerTester(TrlTestCase): +class TestBCOTrainer(TrlTestCase): @parameterized.expand( [ ("standard_preference",), diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py index cf6810976ec..d52538c71d0 100644 --- a/tests/test_best_of_n_sampler.py +++ b/tests/test_best_of_n_sampler.py @@ -27,7 +27,7 @@ def queries_to_scores(list_of_strings): return [torch.rand(1).item() for _ in list_of_strings] -class BestOfNSamplerTester(TrlTestCase): +class TestBestOfNSampler(TrlTestCase): """ Tests the BestOfNSampler class """ diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 448e11bed8f..c41d68b7efa 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -66,7 +66,7 @@ def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processi self.ref_model = ref_model -class WinRateCallbackTester(TrlTestCase): +class TestWinRateCallback(TrlTestCase): def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -224,7 +224,7 @@ def test_lora(self): assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) -class LogCompletionsCallbackTester(TrlTestCase): +class TestLogCompletionsCallback(TrlTestCase): def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -318,7 +318,7 @@ def test_basic_comet(self): @require_mergekit -class MergeModelCallbackTester(TrlTestCase): +class TestMergeModelCallback(TrlTestCase): def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -374,7 +374,7 @@ def test_every_checkpoint(self): assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}." 
-class BEMACallbackTester(TrlTestCase): +class TestBEMACallback(TrlTestCase): def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") diff --git a/tests/test_core.py b/tests/test_core.py index 45449e5815b..85d99615be9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,7 +20,7 @@ from .testing_utils import TrlTestCase -class CoreTester(TrlTestCase): +class TestCore(TrlTestCase): """ A wrapper class for testing core utils functions """ diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py index 4f36be3bc84..d2c926a6735 100644 --- a/tests/test_cpo_trainer.py +++ b/tests/test_cpo_trainer.py @@ -25,7 +25,7 @@ from .testing_utils import TrlTestCase -class CPOTrainerTester(TrlTestCase): +class TestCPOTrainer(TrlTestCase): def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index dcb10a59537..324680e6331 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -147,7 +147,7 @@ def test_mixed_prepared_and_unprepared_messages(self): assert messages == expected -class IsConversationalTester(TrlTestCase): +class TestIsConversational(TrlTestCase): conversational_examples = [ { # Language modeling "messages": [ @@ -257,7 +257,7 @@ def test_non_conversational(self, example): assert not is_conversational(example) -class IsConversationalFromValueTester(TrlTestCase): +class TestIsConversationalFromValue(TrlTestCase): def test_positive_1(self): example = { "conversations": [ @@ -281,7 +281,7 @@ def test_negative_2(self): assert not is_conversational_from_value(example) -class ApplyChatTemplateTester(TrlTestCase): +class TestApplyChatTemplate(TrlTestCase): tokenizers = [ "trl-internal-testing/tiny-CohereForCausalLM", "trl-internal-testing/tiny-DbrxForCausalLM", @@ -429,7 +429,7 @@ def get_current_temperature(location: str): assert "get_current_temperature" not in result_without_tools["prompt"] -class ApplyChatTemplateHarmonyTester(TrlTestCase): +class TestApplyChatTemplateHarmony(TrlTestCase): def test_language_modeling(self): messages = { "messages": [ @@ -655,7 +655,7 @@ def test_unpaired_preference(self): assert output["label"] -class UnpairPreferenceDatasetTester(TrlTestCase): +class TestUnpairPreferenceDataset(TrlTestCase): paired_dataset = Dataset.from_dict( { "prompt": ["The sky is", "The sun is"], @@ -717,7 +717,7 @@ def test_maybe_unpair_preference_dataset_dict_already_paired(self): ) -class ExtractPromptTester(TrlTestCase): +class TestExtractPrompt(TrlTestCase): example_implicit_prompt_conversational = { "chosen": [ {"role": "user", "content": "What color is the sky?"}, diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index 49b9b460a86..80f65f964de 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -23,7 +23,7 @@ from .testing_utils import TrlTestCase -class DatasetFormattingTestCase(TrlTestCase): +class TestDatasetFormatting(TrlTestCase): def setup_method(self): self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1") self.chatml_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -116,7 +116,7 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self): assert 
formatting_func is None -class SetupChatFormatTestCase(TrlTestCase): +class TestSetupChatFormat(TrlTestCase): def setup_method(self): self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -156,7 +156,7 @@ def test_example_with_setup_model(self): ) -class CloneChatTemplateTestCase(TrlTestCase): +class TestCloneChatTemplate(TrlTestCase): def test_clone(self): # This tokenizer doesn't have a chat_template by default tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index db0f96e9fb2..c7e5ebac50b 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -152,7 +152,7 @@ def test_tokenize_row_with_truncation_and_special_tokens(self): } -class DPOTrainerTester(TrlTestCase): +class TestDPOTrainer(TrlTestCase): def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -1406,7 +1406,7 @@ def test_train_with_iterable_dataset(self): @require_vision -class DPOVisionTrainerTester(TrlTestCase): +class TestDPOVisionTrainer(TrlTestCase): @parameterized.expand( [ # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975 diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py index a1b62cd993e..b311ce2b0b6 100644 --- a/tests/test_gkd_trainer.py +++ b/tests/test_gkd_trainer.py @@ -27,7 +27,7 @@ from .testing_utils import TrlTestCase -class TestGKDTrainer(TrlTestCase): +class TestGKDTrainerGenerateOnPolicy(TrlTestCase): @classmethod def setup_class(cls): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -198,7 +198,7 @@ def test_zero_loss_for_identical_inputs(self): assert round(abs(loss.item() - 0), 6) == 0 -class GKDTrainerTester(TrlTestCase): +class TestGKDTrainer(TrlTestCase): def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 30a99199ee9..3b17b3554b7 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -43,7 +43,7 @@ from peft import LoraConfig, PeftModel -class GetHighEntropyMaskTester(TrlTestCase): +class TestGetHighEntropyMask(TrlTestCase): def get_high_entropy_mask(self, entropies, mask, threshold): """Helper method to test the get_high_entropy_mask functionality.""" # Create a mock trainer with minimal setup @@ -115,7 +115,7 @@ def test_compute_entropy_all_masked(self): torch.testing.assert_close(entropy_mask, expected_mask) -class GRPOTrainerTester(TrlTestCase): +class TestGRPOTrainer(TrlTestCase): def test_init_minimal(self): # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1975,7 +1975,7 @@ def custom_reward_func(completions, **kwargs): assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
-class GSPOTokenTrainerTester(TrlTestCase): +class TestGSPOTokenTrainer(TrlTestCase): def test_training(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index 6e303ebbe83..7cf9466eafa 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -26,7 +26,7 @@ from .testing_utils import TrlTestCase, require_no_wandb -class KTOTrainerTester(TrlTestCase): +class TestKTOTrainer(TrlTestCase): def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index a7bfd08eeb9..65716cde99f 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -50,7 +50,7 @@ class BaseTester: - class VHeadModelTester(TrlTestCase): + class TestVHeadModel(TrlTestCase): all_model_names = None trl_model_class = None transformers_model_class = None @@ -179,7 +179,7 @@ def test_from_save_transformers(self): ) -class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): +class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): """ Testing suite for v-head models. """ @@ -292,7 +292,7 @@ def test_push_to_hub(self): ) -class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): +class TestSeq2SeqValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): """ Testing suite for v-head models. """ @@ -405,7 +405,7 @@ def test_transformers_bf16_kwargs(self): _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input) -class ReferenceModelTest(TrlTestCase): +class TestReferenceModel(TrlTestCase): def setup_method(self): self.model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel") self.test_input = torch.tensor([[0, 1, 2, 3]]) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 4552edafaef..981ed9dc47f 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -482,7 +482,7 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs): @require_vision -class OnlineDPOVisionTrainerTester(TrlTestCase): +class TestOnlineDPOVisionTrainer(TrlTestCase): @parameterized.expand( [ ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 8ab8819675a..07adcd0d651 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -25,7 +25,7 @@ from .testing_utils import TrlTestCase -class ORPOTrainerTester(TrlTestCase): +class TestORPOTrainer(TrlTestCase): def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index b8bacb7cce2..9d946db923c 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -32,7 +32,7 @@ @require_peft -class PeftModelTester(TrlTestCase): +class TestPeftModel(TrlTestCase): def setup_method(self): self.causal_lm_model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.lora_config = LoraConfig( diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py index f435c1295ab..52f4da21df2 100644 --- a/tests/test_prm_trainer.py +++ b/tests/test_prm_trainer.py @@ -204,7 +204,7 @@ def test_tokenize_row_multi_token_separator(self): } 
-class PRMTrainerTester(TrlTestCase): +class TestPRMTrainer(TrlTestCase): def setup_method(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForTokenClassification.from_pretrained(model_id) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index ba8300c78ff..9f86c8a9111 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -108,7 +108,7 @@ def test_collate_with_margin(self): torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2])) -class RewardTrainerTester(TrlTestCase): +class TestRewardTrainer(TrlTestCase): @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",), diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 46fddb5b19d..0f584d1b58a 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -19,7 +19,7 @@ from .testing_utils import TrlTestCase -class ThinkFormatRewardTester(TrlTestCase): +class TestThinkFormatReward(TrlTestCase): def test_valid_format(self): completions = [ "This is my reasoning.This is my answer.", # Simple, one-line reasoning diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 23f709d466a..18b402bd940 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -37,7 +37,7 @@ from peft import LoraConfig, PeftModel -class RLOOTrainerTester(TrlTestCase): +class TestRLOOTrainer(TrlTestCase): def test_init_minimal(self): # Test that RLOOTrainer can be instantiated with only model, reward_model and train_dataset dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index b60ea48bfaa..d5cbcc07b76 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -33,7 +33,7 @@ from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model -class DFTLossTester(TrlTestCase): +class TestDFTLoss(TrlTestCase): def test_dft_loss(self): batch_size = 2 seq_len = 3 @@ -239,7 +239,7 @@ def test_multiple_examples(self): assert torch.equal(result[1], torch.arange(3)) -class SFTTrainerTester(TrlTestCase): +class TestSFTTrainer(TrlTestCase): @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",), diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index b9a809ba071..b76110d5f17 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -44,7 +44,7 @@ from .testing_utils import TrlTestCase, require_sklearn -class TrainerArgTester(TrlTestCase): +class TestTrainerArg(TrlTestCase): @require_sklearn def test_bco(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" diff --git a/tests/test_utils.py b/tests/test_utils.py index d663328e244..9bf457b590e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -534,7 +534,7 @@ def test_no_tensors(self): assert torch.equal(new_mask, expected_mask) -class RepeatRandomSamplerTester(TrlTestCase): +class TestRepeatRandomSampler(TrlTestCase): def test_sampler(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=2) @@ -841,7 +841,7 @@ def test_selective_log_softmax(self, dtype): torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) -class ShuffleSequenceDictTester(TrlTestCase): +class TestShuffleSequenceDict(TrlTestCase): def test_shuffle_preserves_shape(self): x = torch.arange(6).reshape(3, 2) y = torch.arange(3).reshape(3, 1) @@ -906,7 +906,7 @@ def 
test_shuffle_with_list(self): pytest.fail("Unexpected x row in shuffled output.") -class SplitTensorDictTester(TrlTestCase): +class TestSplitTensorDict(TrlTestCase): def test_split_equal_chunks(self): x = torch.arange(12).reshape(6, 2) y = torch.arange(6).reshape(6, 1) @@ -946,7 +946,7 @@ def test_with_scalar(self): assert torch.equal(result[i]["y"], torch.tensor(1)) -class SplitPixelValuesByGridTester(TrlTestCase): +class TestSplitPixelValuesByGrid(TrlTestCase): def test_split_correctly_0(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), @@ -1010,7 +1010,7 @@ def test_multi_images(self): assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]])) -class TruncateWithProtectedTokensTester(TrlTestCase): +class TestTruncateWithProtectedTokens(TrlTestCase): def test_basic_example(self): """Test the basic example from the problem description.""" prompt_ids = [1, 2, 3, 4, 5] @@ -1088,7 +1088,7 @@ def test_order_preservation(self): assert new_ids == expected_ids -class UnsplitPixelValuesByGridTester(TrlTestCase): +class TestUnsplitPixelValuesByGrid(TrlTestCase): def test_unsplit_correctly(self): pixel_values = [torch.randn(4, 5), torch.randn(2, 5)] pixel_values_merged = torch.cat(pixel_values, dim=0) From 01d1262eec50d15525075df4c71125bebc2972e4 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 14:59:55 +0200 Subject: [PATCH 11/16] Fix naming of subclass --- tests/test_modeling_value_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index 65716cde99f..95445b31421 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -50,7 +50,7 @@ class BaseTester: - class TestVHeadModel(TrlTestCase): + class VHeadModelTester(TrlTestCase): all_model_names = None trl_model_class = None transformers_model_class = None From 309e76d40f1f05f0138ab32b5a809be29979c6dc Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:12:03 +0200 Subject: [PATCH 12/16] Replace assertLogs --- tests/test_grpo_trainer.py | 6 +++--- tests/test_rloo_trainer.py | 6 +++--- tests/test_utils.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 3b17b3554b7..80d5eb97ec3 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1027,7 +1027,7 @@ def test_training_with_mask_truncated_completions_all_masked(self): new_param = trainer.model.get_parameter(n) assert torch.equal(param, new_param), f"Parameter {n} has changed." 
- def test_warning_raised_all_rewards_none(self): + def test_warning_raised_all_rewards_none(self, caplog): """Test that a proper warning is raised when all rewards are None.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1050,11 +1050,11 @@ def always_none_reward_func(completions, **kwargs): train_dataset=dataset, ) - with self.assertLogs("trl.trainer.grpo_trainer", level="WARNING") as cm: + with caplog.at_level("WARNING", logger="trl.trainer.grpo_trainer"): trainer.train() expected_warning = "All reward functions returned None for the following kwargs:" - assert expected_warning in cm.output[0] + assert expected_warning in caplog.text def test_training_num_generations_larger_than_batch_size(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 18b402bd940..a8554414dff 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -882,7 +882,7 @@ def test_training_with_mask_truncated_completions_all_masked(self): new_param = trainer.model.get_parameter(n) assert torch.equal(param, new_param), f"Parameter {n} has changed." - def test_warning_raised_all_rewards_none(self): + def test_warning_raised_all_rewards_none(self, caplog): """Test that a proper warning is raised when all rewards are None.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -905,11 +905,11 @@ def always_none_reward_func(completions, **kwargs): train_dataset=dataset, ) - with self.assertLogs("trl.trainer.rloo_trainer", level="WARNING") as cm: + with caplog.at_level("WARNING", logger="trl.trainer.rloo_trainer"): trainer.train() expected_warning = "All reward functions returned None for the following kwargs:" - assert expected_warning in cm.output[0] + assert expected_warning in caplog.text def test_training_num_generations_larger_than_batch_size(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/tests/test_utils.py b/tests/test_utils.py index 9bf457b590e..cdedcf95ad9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -422,7 +422,7 @@ def test_token_classification_task_with_ignored_tokens_1(self): result = compute_accuracy(eval_pred) assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 - def test_rewards_comparison_task(self): + def test_rewards_comparison_task(self, caplog): eval_pred = ( np.array( [ @@ -435,7 +435,7 @@ def test_rewards_comparison_task(self): ) expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored) - with self.assertLogs("trl.trainer.utils", level="WARNING") as cm: + with caplog.at_level("WARNING", logger="trl.trainer.utils"): result = compute_accuracy(eval_pred) assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 @@ -443,7 +443,7 @@ def test_rewards_comparison_task(self): "There are 1 out of 3 instances where the predictions for both options are equal. " "These instances are ignored in the accuracy computation." ) - assert expected_warning in cm.output[0] + assert expected_warning in caplog.text class TestFlushLeft(TrlTestCase): From 0d7815cd60b63a12f0ea1c1eb7069f68a77b5640 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:50:32 +0200 Subject: [PATCH 13/16] Revert "Set CI for debugging" This reverts commit e2d1e61316ced2066aeb8b95fd09809eba40f495. 
--- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4c21a98f4f9..4231ef227ec 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,7 +21,7 @@ jobs: check_code_quality: name: Check code quality runs-on: ubuntu-latest - # if: github.event.pull_request.draft == false + if: github.event.pull_request.draft == false steps: - uses: actions/checkout@v4 - name: Set up Python 3.12 @@ -36,7 +36,7 @@ jobs: name: Tests strategy: matrix: - python-version: ['3.10'] + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] fail-fast: false runs-on: group: aws-g4dn-2xlarge @@ -46,7 +46,7 @@ jobs: defaults: run: shell: bash - # if: github.event.pull_request.draft == false + if: github.event.pull_request.draft == false steps: - name: Git checkout uses: actions/checkout@v4 From 9428f4f564774de9ef03857ec6113e05fa1fa165 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 19:32:18 +0200 Subject: [PATCH 14/16] Replace require_peft --- tests/slow/test_dpo_slow.py | 4 ++-- tests/slow/test_grpo_slow.py | 3 +-- tests/slow/test_sft_slow.py | 3 +-- tests/test_activation_offloading.py | 4 ++-- tests/test_bco_trainer.py | 3 +-- tests/test_callbacks.py | 5 ++--- tests/test_cpo_trainer.py | 3 +-- tests/test_dpo_trainer.py | 3 +-- tests/test_grpo_trainer.py | 4 ++-- tests/test_kto_trainer.py | 4 ++-- tests/test_nash_md_trainer.py | 3 +-- tests/test_online_dpo_trainer.py | 4 ++-- tests/test_orpo_trainer.py | 3 +-- tests/test_peft_models.py | 7 ++----- tests/test_ppo_trainer.py | 3 +-- tests/test_prm_trainer.py | 3 +-- tests/test_reward_trainer.py | 3 +-- tests/test_rloo_trainer.py | 4 ++-- tests/test_sft_trainer.py | 4 ++-- tests/test_utils.py | 3 +-- tests/test_xpo_trainer.py | 3 +-- tests/testing_utils.py | 3 ++- 22 files changed, 32 insertions(+), 47 deletions(-) diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index 498c71179ac..26feb388c6b 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -21,12 +21,12 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -from transformers.testing_utils import backend_empty_cache, require_peft, require_torch_accelerator, torch_device +from transformers.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device from transformers.utils import is_peft_available from trl import DPOConfig, DPOTrainer -from ..testing_utils import TrlTestCase, require_bitsandbytes +from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 63eecdf1398..17798d10b11 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -35,7 +35,6 @@ backend_empty_cache, require_flash_attn, require_liger_kernel, - require_peft, require_torch_accelerator, torch_device, ) @@ -44,7 +43,7 @@ from trl import GRPOConfig, GRPOTrainer from trl.trainer.utils import get_kbit_device_map -from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm +from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm from .testing_constants import MODELS_TO_TEST diff --git 
a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index 114a9069943..b6928b697c7 100755 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -24,7 +24,6 @@ from transformers.testing_utils import ( backend_empty_cache, require_liger_kernel, - require_peft, require_torch_accelerator, require_torch_multi_accelerator, torch_device, @@ -33,7 +32,7 @@ from trl import SFTConfig, SFTTrainer -from ..testing_utils import TrlTestCase, require_bitsandbytes +from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index 6c9ae24b8ae..d1a9ea921f5 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -16,12 +16,12 @@ import torch from torch import nn from transformers import AutoModelForCausalLM -from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device +from transformers.testing_utils import require_torch_accelerator, torch_device from transformers.utils import is_peft_available from trl.models.activation_offloading import NoOpManager, OffloadActivations -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index 30214a6f781..7b7f0414438 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -20,13 +20,12 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import BCOConfig, BCOTrainer from trl.trainer.bco_trainer import _process_tokens, _tokenize -from .testing_utils import TrlTestCase, require_no_wandb, require_sklearn +from .testing_utils import TrlTestCase, require_no_wandb, require_peft, require_sklearn if is_peft_available(): diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index c41d68b7efa..316c9b35ae1 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -18,11 +18,10 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments -from transformers.testing_utils import require_peft, require_wandb +from transformers.testing_utils import require_wandb from transformers.trainer_utils import get_last_checkpoint from transformers.utils import is_peft_available -from tests.testing_utils import require_comet, require_mergekit from trl import ( BasePairwiseJudge, BEMACallback, @@ -34,7 +33,7 @@ ) from trl.mergekit_utils import MergeConfig -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft if is_peft_available(): diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py index d2c926a6735..56792f608dc 100644 --- a/tests/test_cpo_trainer.py +++ b/tests/test_cpo_trainer.py @@ -17,12 +17,11 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from transformers.testing_utils import require_peft from trl import CPOConfig, CPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -from .testing_utils import TrlTestCase +from .testing_utils 
import TrlTestCase, require_peft class TestCPOTrainer(TrlTestCase): diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 8455909c242..1dcecdb79f4 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -34,14 +34,13 @@ from transformers.testing_utils import ( get_device_properties, require_liger_kernel, - require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled, require_vision, ) from trl import DPOConfig, DPOTrainer, FDivergenceType -from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb +from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft if is_vision_available(): diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index e3fa06eeb58..1e504ffe4a8 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -25,7 +25,7 @@ AutoModelForSequenceClassification, AutoTokenizer, ) -from transformers.testing_utils import require_liger_kernel, require_peft, require_vision +from transformers.testing_utils import require_liger_kernel, require_vision from transformers.utils import is_peft_available from trl import GRPOConfig, GRPOTrainer @@ -36,7 +36,7 @@ ) from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer -from .testing_utils import TrlTestCase, require_vllm +from .testing_utils import TrlTestCase, require_peft, require_vllm if is_peft_available(): diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index c51cf5a0342..e2c325149f2 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -18,12 +18,12 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from transformers.testing_utils import require_liger_kernel, require_peft +from transformers.testing_utils import require_liger_kernel from trl import KTOConfig, KTOTrainer from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize -from .testing_utils import TrlTestCase, require_no_wandb +from .testing_utils import TrlTestCase, require_no_wandb, require_peft class TestKTOTrainer(TrlTestCase): diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py index 90df42b48a7..d6026e73443 100644 --- a/tests/test_nash_md_trainer.py +++ b/tests/test_nash_md_trainer.py @@ -16,12 +16,11 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import NashMDConfig, NashMDTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender +from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft if is_peft_available(): diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 33bc7bdd3b8..9d2e2ccf8da 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -18,12 +18,12 @@ from packaging.version import Version from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft, require_torch_accelerator, require_vision +from transformers.testing_utils import require_torch_accelerator, require_vision from transformers.utils import is_peft_available, is_vision_available 
from trl import OnlineDPOConfig, OnlineDPOTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_vllm +from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft, require_vllm if is_peft_available(): diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 07adcd0d651..dedfc4c36c9 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -17,12 +17,11 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from transformers.testing_utils import require_peft from trl import ORPOConfig, ORPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft class TestORPOTrainer(TrlTestCase): diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index 9d946db923c..508ad175565 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -16,15 +16,12 @@ import torch from transformers import AutoModelForCausalLM -from transformers.testing_utils import ( - require_peft, - require_torch_gpu_if_bnb_not_multi_backend_enabled, -) +from transformers.testing_utils import require_torch_gpu_if_bnb_not_multi_backend_enabled from transformers.utils import is_peft_available from trl import AutoModelForCausalLMWithValueHead -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index c4c69d24e10..6e62e742115 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -16,13 +16,12 @@ import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import PPOConfig, PPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py index 52f4da21df2..16876c6df62 100644 --- a/tests/test_prm_trainer.py +++ b/tests/test_prm_trainer.py @@ -18,12 +18,11 @@ from datasets import Dataset, load_dataset from parameterized import parameterized from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import PRMConfig, PRMTrainer -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index 9f86c8a9111..f5645ae445f 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -19,13 +19,12 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import RewardConfig, RewardTrainer from trl.trainer.reward_trainer import DataCollatorForPreference -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): diff --git 
a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 9c267674f47..de3c7c2a159 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -25,12 +25,12 @@ AutoModelForSequenceClassification, AutoTokenizer, ) -from transformers.testing_utils import require_peft, require_vision +from transformers.testing_utils import require_vision from transformers.utils import is_peft_available from trl import RLOOConfig, RLOOTrainer -from .testing_utils import TrlTestCase, require_vllm +from .testing_utils import TrlTestCase, require_peft, require_vllm if is_peft_available(): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index d5cbcc07b76..482c93df7e3 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -20,13 +20,13 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_peft, require_vision +from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_vision from transformers.utils import is_peft_available from trl import SFTConfig, SFTTrainer from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss -from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes +from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft if is_peft_available(): diff --git a/tests/test_utils.py b/tests/test_utils.py index cdedcf95ad9..14b3fb570af 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -22,7 +22,6 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import ModelConfig @@ -47,7 +46,7 @@ unsplit_pixel_values_by_grid, ) -from .testing_utils import TrlTestCase, require_rich +from .testing_utils import TrlTestCase, require_peft, require_rich if is_peft_available(): diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index e4b60a810cf..4d41471187c 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -16,12 +16,11 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import XPOConfig, XPOTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender +from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft if is_peft_available(): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index fdc91ccab81..64904678446 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -22,7 +22,7 @@ import torch from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available from transformers.testing_utils import torch_device -from transformers.utils import is_rich_available +from transformers.utils import is_peft_available, is_rich_available from trl import BaseBinaryJudge, BasePairwiseJudge from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available @@ -32,6 +32,7 @@ require_comet = pytest.mark.skipif(not is_comet_available(), 
reason="test requires comet_ml") require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit") +require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft") require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich") require_sklearn = pytest.mark.skipif( not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn" From bc60f7619e658b9d87007a5edac01e3b3b0fc7ee Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 19:38:59 +0200 Subject: [PATCH 15/16] Replace require_vision --- tests/test_dpo_trainer.py | 3 +-- tests/test_grpo_trainer.py | 4 ++-- tests/test_online_dpo_trainer.py | 11 +++++++++-- tests/test_rloo_trainer.py | 3 +-- tests/test_sft_trainer.py | 4 ++-- tests/testing_utils.py | 3 ++- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 1dcecdb79f4..aea42c8d8d0 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -35,12 +35,11 @@ get_device_properties, require_liger_kernel, require_torch_gpu_if_bnb_not_multi_backend_enabled, - require_vision, ) from trl import DPOConfig, DPOTrainer, FDivergenceType -from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft +from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft, require_vision if is_vision_available(): diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 1e504ffe4a8..d63520bea43 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -25,7 +25,7 @@ AutoModelForSequenceClassification, AutoTokenizer, ) -from transformers.testing_utils import require_liger_kernel, require_vision +from transformers.testing_utils import require_liger_kernel from transformers.utils import is_peft_available from trl import GRPOConfig, GRPOTrainer @@ -36,7 +36,7 @@ ) from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer -from .testing_utils import TrlTestCase, require_peft, require_vllm +from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm if is_peft_available(): diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 9d2e2ccf8da..f8706770371 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -18,12 +18,19 @@ from packaging.version import Version from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_torch_accelerator, require_vision +from transformers.testing_utils import require_torch_accelerator from transformers.utils import is_peft_available, is_vision_available from trl import OnlineDPOConfig, OnlineDPOTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft, require_vllm +from .testing_utils import ( + RandomPairwiseJudge, + TrlTestCase, + require_llm_blender, + require_peft, + require_vision, + require_vllm, +) if is_peft_available(): diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index de3c7c2a159..005cd064a4d 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -25,12 +25,11 @@ AutoModelForSequenceClassification, 
AutoTokenizer, ) -from transformers.testing_utils import require_vision from transformers.utils import is_peft_available from trl import RLOOConfig, RLOOTrainer -from .testing_utils import TrlTestCase, require_peft, require_vllm +from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm if is_peft_available(): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 482c93df7e3..5d1aacf2876 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -20,13 +20,13 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_vision +from transformers.testing_utils import require_flash_attn, require_liger_kernel from transformers.utils import is_peft_available from trl import SFTConfig, SFTTrainer from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss -from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft +from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft, require_vision if is_peft_available(): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 64904678446..cbe677255b5 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -22,7 +22,7 @@ import torch from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available from transformers.testing_utils import torch_device -from transformers.utils import is_peft_available, is_rich_available +from transformers.utils import is_peft_available, is_rich_available, is_vision_available from trl import BaseBinaryJudge, BasePairwiseJudge from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available @@ -37,6 +37,7 @@ require_sklearn = pytest.mark.skipif( not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn" ) +require_vision = pytest.mark.skipif(not is_vision_available(), reason="test requires vision") require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm") require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb") require_3_accelerators = pytest.mark.skipif( From bf326893823ff186250f3256e5e9454f3dd422b7 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 2 Oct 2025 07:56:26 +0200 Subject: [PATCH 16/16] Replace unittest skip and remove main --- tests/test_cli.py | 10 +++------- tests/test_data_utils.py | 6 ------ tests/test_dpo_trainer.py | 11 +++-------- tests/test_grpo_trainer.py | 15 +++++---------- tests/test_judges.py | 5 +++-- tests/test_modeling_value_head.py | 6 +++--- tests/test_reward_trainer.py | 4 ++-- tests/test_rewards.py | 5 ----- tests/test_rloo_trainer.py | 13 ++++--------- 9 files changed, 23 insertions(+), 52 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 638e1d38493..48087f5054c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -15,18 +15,18 @@ import os import sys -import unittest from io import StringIO from unittest.mock import patch +import pytest import yaml from .testing_utils import TrlTestCase -@unittest.skipIf( +@pytest.mark.skipif( sys.version_info < (3, 10), - "Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests " + 
reason="Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests " "to fail on Python <3.10.", # let's say it's a known issue, but not expected to be fixed, because too niche ) class TestCLI(TrlTestCase): @@ -113,7 +113,3 @@ def test_sft_config_file(self): # Verify that output directory was created assert os.path.exists(output_dir) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 324680e6331..8fe8a24bd50 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -15,7 +15,6 @@ import copy import itertools import textwrap -import unittest from time import strftime from datasets import Dataset, DatasetDict @@ -977,8 +976,3 @@ def test_already_chatml(self): ] } assert maybe_convert_to_chatml(example) == example - - -# Run the tests -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index aea42c8d8d0..1d1e94fcf99 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -14,7 +14,6 @@ import re import sys -import unittest from unittest.mock import MagicMock import numpy as np @@ -714,9 +713,9 @@ def test_dpo_lora_bf16_autocast_llama(self): ) @require_bitsandbytes @require_peft - @unittest.skipIf( + @pytest.mark.skipif( get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8, - "Skipping because bf16 not supported on CUDA GPU with capability < 8.0", + reason="Skipping because bf16 not supported on CUDA GPU with capability < 8.0", ) def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval): from peft import LoraConfig @@ -1301,7 +1300,7 @@ def test_train_with_length_desensitization(self): ] ) @require_liger_kernel - @unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9") + @pytest.mark.skipif(not (sys.version_info >= (3, 10)), reason="Liger kernel is not supported on Python 3.9") def test_dpo_trainer_with_liger(self, beta, loss_type): """Test DPO trainer with Liger loss enabled across supported loss types. @@ -1512,7 +1511,3 @@ def test_f_divergence_type(self, f_divergence_type, as_string: bool): # Serialization: TrainingArguments.to_dict should yield the enum's string value configparser_dict = training_args.to_dict() assert configparser_dict["f_divergence_type"] == f_divergence_type.value - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index d63520bea43..fb7e76f1594 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from unittest.mock import patch import pytest @@ -720,9 +719,9 @@ def test_training_with_entropy_filter(self): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @unittest.skip("We should add a mock for the vLLM server.") @require_peft @require_vllm + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_and_peft(self): """Test that training works with vLLM for generation.""" model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM @@ -767,7 +766,7 @@ def test_training_vllm_and_peft(self): assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." 
@require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_guided_decoding(self): """Test that training works with vLLM for generation with guided decoding.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -801,7 +800,7 @@ def test_training_vllm_guided_decoding(self): assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_importance_sampling_correction(self): """Test that training works with vLLM for generation with guided decoding.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -871,7 +870,7 @@ def test_training_with_additional_generation_kwargs(self): assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_with_additional_generation_kwargs(self): """Test that training works with vLLM and additional generation kwargs.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1521,7 +1520,7 @@ def reward_func(completions, **kwargs): ) @require_vision @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") @@ -2006,7 +2005,3 @@ def test_training(self): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_judges.py b/tests/test_judges.py index 9238a7f5cb2..bba1ffffbbb 100644 --- a/tests/test_judges.py +++ b/tests/test_judges.py @@ -13,7 +13,8 @@ # limitations under the License. import time -import unittest + +import pytest from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge @@ -38,7 +39,7 @@ def test_all_true_judge(self): assert len(judgements) == 2 assert all(judgement in {0, 1, -1} for judgement in judgements) - @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.") + @pytest.mark.skip(reason="This test needs to be run manually since it requires a valid Hugging Face API key.") def test_hugging_face_judge(self): judge = HfPairwiseJudge() prompts, completions = self._get_prompts_and_pairwise_completions() diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index 95445b31421..a2fde5a12c4 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -13,8 +13,8 @@ # limitations under the License. 
import gc -import unittest +import pytest import torch from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig @@ -273,7 +273,7 @@ def test_transformers_bf16_kwargs(self): # check dummy forward pass works in half precision _ = trl_model(dummy_input) - @unittest.skip("This test needs to be run manually due to HF token issue.") + @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.") def test_push_to_hub(self): for model_name in self.all_model_names: model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name) @@ -364,7 +364,7 @@ def test_generate(self, model_name): # Just check if the generation works _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config) - @unittest.skip("This test needs to be run manually due to HF token issue.") + @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.") def test_push_to_hub(self): for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index f5645ae445f..ab6d6656e99 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -13,8 +13,8 @@ # limitations under the License. import pathlib -import unittest +import pytest import torch from datasets import load_dataset from parameterized import parameterized @@ -653,7 +653,7 @@ def test_train_with_set_chat_template_from_path(self): original_template_content = f.read() assert template_content == original_template_content, "Chat template content does not match the original" - @unittest.skip("Skipping until we have a dataset with tool calls") + @pytest.mark.skip(reason="Skipping until we have a dataset with tool calls") def test_train_toolcall_data(self): # Get the dataset dataset = load_dataset("trl-internal-testing/toolcall", split="train") diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 0f584d1b58a..0764ce5d9ea 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from trl.rewards import get_soft_overlong_punishment, think_format_reward @@ -86,7 +85,3 @@ def test_soft_overlong_punishment_intermediate_completion(self): completion_ids = [[1] * 90] # 90 is between 80 and 100 rewards = reward_fn(completion_ids) assert round(abs(rewards[0] - -0.5), 4) == 0 - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 005cd064a4d..1de4eca479e 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from unittest.mock import patch import pytest @@ -580,9 +579,9 @@ def test_training_beta_zero(self): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
- @unittest.skip("We should add a mock for the vLLM server.") @require_peft @require_vllm + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_and_peft(self): """Test that training works with vLLM for generation.""" model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM @@ -627,7 +626,7 @@ def test_training_vllm_and_peft(self): assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_guided_decoding(self): """Test that training works with vLLM for generation with guided decoding.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -696,7 +695,7 @@ def test_training_with_additional_generation_kwargs(self): assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_with_additional_generation_kwargs(self): """Test that training works with vLLM and additional generation kwargs.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1262,7 +1261,7 @@ def reward_func(completions, **kwargs): ) @require_vision @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") @@ -1416,7 +1415,3 @@ def test_single_reward_model_with_single_processing_class(self): assert len(trainer.reward_processing_classes) == 1 assert trainer.reward_processing_classes[0] == single_processing_class - - -if __name__ == "__main__": - unittest.main()